# Repository layout:
#   dice.py  generator.py  train.py  preprocess.py  model.py  augmentation.py

# -----------------------------------------------------------------------------
# /dice.py
# -----------------------------------------------------------------------------
import keras.backend as K

# Additive smoothing term: keeps the ratio defined when both masks are empty
# (the coefficient is then 1 instead of 0/0).
smooth = 1.


def dice_coef(y_true, y_pred):
    """Soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are flattened over *all* axes, including the batch axis, so
    the score is pooled over the whole batch rather than averaged per sample.
    The denominator uses sums of squares: ``sum(t*t) + sum(p*p)``.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    denominator = K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f)
    return (2. * intersection + smooth) / (denominator + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss ``1 - dice_coef(y_true, y_pred)``."""
    return 1. - dice_coef(y_true, y_pred)


# -----------------------------------------------------------------------------
# /generator.py
# -----------------------------------------------------------------------------
import numpy as np


def generator(images, masks, batch_size=8):
    """Yield an endless stream of random ``(images, masks)`` batches.

    Each batch holds ``batch_size`` samples drawn uniformly *with replacement*
    along axis 0. The same indices are used for ``images`` and ``masks``, so
    pairs stay aligned.

    Args:
        images: array of samples, indexed along axis 0.
        masks: array of targets with the same length as ``images``.
        batch_size: samples per batch (default 8, the former hard-coded value).

    Raises:
        ValueError: on the first ``next()`` if ``images`` and ``masks`` differ
            in length, or if they are empty.
    """
    n_samples = images.shape[0]
    if masks.shape[0] != n_samples:
        raise ValueError('images and masks differ in length: %d vs %d'
                         % (n_samples, masks.shape[0]))
    while True:
        idx = np.random.randint(n_samples, size=batch_size)
        yield images[idx], masks[idx]


# -----------------------------------------------------------------------------
# /train.py
# -----------------------------------------------------------------------------
import os

import keras.backend as K
import tensorflow as tf
from keras.callbacks import EarlyStopping

from generator import *
from model import unet
from preprocess import *


def train(image_path, mask_path, batch_size=8, epochs=10, steps_per_epoch=None,
          gpus='0,1', weights_path='./weight.h5'):
    """Train the 3-D U-Net on the ``.npy`` arrays at ``image_path``/``mask_path``.

    Args:
        image_path, mask_path: files passed to ``preprocess_data_train``
            (size=64, replica=3, 90/10 train/validation split).
        batch_size: samples per generator batch.
        epochs: maximum number of epochs; EarlyStopping(patience=4) on the
            validation loss may stop sooner.
        steps_per_epoch: batches per epoch. Defaults to ``len(image_train)``,
            as before. NOTE(review): with ``batch_size`` samples per step this
            is ``batch_size`` passes over the training set per "epoch". The
            usual Keras convention is ``ceil(len(image_train) / batch_size)``.
        gpus: value for ``CUDA_VISIBLE_DEVICES``; ``None`` leaves the
            environment untouched.
        weights_path: where the trained weights are written.
    """
    # CUDA_VISIBLE_DEVICES must be set before TensorFlow initialises its GPU
    # devices, i.e. before the session below is created. It used to be set
    # inside the session (and with a space in "0, 1"), where it had no effect.
    if gpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus

    print('load data>>>>')
    image_train, image_valid, mask_train, mask_valid = preprocess_data_train(
        image_path, mask_path, size=64, replica=3, split=True)
    print('data loading complete!')

    if steps_per_epoch is None:
        steps_per_epoch = len(image_train)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(graph=tf.get_default_graph(), config=config) as sess:
        K.set_session(sess)
        # No explicit global_variables_initializer(): the former call ran
        # before any model existed, so it initialised nothing. Keras
        # initialises the model's variables in this session itself.
        model = unet(lr=1e-4)
        print('model loaded>>>>')
        model.summary()

        print('fitting model>>>>')
        model.fit_generator(
            generator=generator(image_train, mask_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=(image_valid, mask_valid),
            verbose=1,
            callbacks=[EarlyStopping(patience=4)])
        model.save_weights(weights_path)


if __name__ == '__main__':
    train(image_path='../image.npy', mask_path='../mask.npy')


# -----------------------------------------------------------------------------
# /preprocess.py
# -----------------------------------------------------------------------------
import numpy as np
from augmentation import data_augmentation


def load_data(path):
    """Load a NumPy array from ``path``.

    Data is expected in (height, width, n_samples) layout -- TODO confirm
    against whatever produced the ``.npy`` files.
    """
    return np.load(path)


def normalize(image, hu_min=-1000, hu_max=0):
    """Clip intensities to ``[hu_min, hu_max]`` and rescale linearly to [0, 1].

    The defaults keep the original window of (-1000, 0), described in the
    original code as Hounsfield units.
    """
    return (np.clip(image, hu_min, hu_max) - hu_min) / (hu_max - hu_min)


def crop_data(images, size=64):
    """Cut each slice of an (H, W, N) stack into non-overlapping size x size tiles.

    Returns an array of shape (size, size, n_tiles), where
    n_tiles = N * (H/size) * (W/size); e.g. (64, 64, N*64) for 512x512 slices.
    H and W must be divisible by ``size``.

    Tile ordering, for square slices with n = W // size tile columns: tile
    ``t`` lies in row band ``t // (n*N)``, slice ``(t % (n*N)) // n`` and
    column block ``(t % (n*N)) % n``. In other words, it runs over column
    blocks first, then slices, then row bands. ``recover`` inverts exactly
    this ordering.
    """
    return images.reshape((size, images.shape[1], -1), order='F').reshape((size, size, -1))


def reshape_data(images):
    """Move the tile axis first and add a channel axis.

    Turns (size, size, n_tiles) into (n_tiles, size, size, 1). The
    ``(2, 1, 0)`` transpose also swaps the two in-plane axes of every tile.
    """
    return np.transpose(images, (2, 1, 0))[..., None]


def get_batches(images, size):
    """Group consecutive tiles into volumes of ``size`` tiles each.

    (n_tiles, size, size, 1) -> (n_tiles // size, size, size, size, 1).

    NOTE(review): because of the tile ordering produced by ``crop_data``, the
    first (depth) axis of a volume is not a run of consecutive slices at one
    (row, col) position. For 512x512 slices with size=64, it stacks the 8
    column tiles of each of 8 consecutive slices within one 64-row band. The
    3-D convolutions therefore see mixed context along depth. Confirm this
    is intended.

    Raises:
        ValueError: if the tile count is zero or not a multiple of ``size``.
    """
    n_tiles = images.shape[0]
    if n_tiles == 0 or n_tiles % size:
        raise ValueError('cannot group %d tiles into volumes of %d' % (n_tiles, size))
    return images.reshape((n_tiles // size, size) + images.shape[1:])


def get_split(data_train, ratio=0.9):
    """Split along axis 0 into the first ``ratio`` fraction and the remainder."""
    x = int(data_train.shape[0] * ratio)
    return data_train[:x], data_train[x:]


def preprocess_data_train(image_path, mask_path, size=64, replica=None, split=True):
    """Load, normalise, optionally augment, tile and batch image/mask pairs.

    Args:
        image_path, mask_path: ``.npy`` files, presumably (H, W, N) as in
            ``load_data``.
        size: tile edge and volume depth.
        replica: total number of copies of the data to keep. Copies 2..replica
            are independently augmented versions of the original slices.
            ``None`` (or 1) means no augmentation.
        split: if true, return a 90/10 train/validation split of the volumes.

    Returns:
        ``(image_train, image_valid, mask_train, mask_valid)`` when ``split``
        is true, otherwise ``(image, mask)``.

    NOTE(review): the split is taken after replication and augmentation, so
    validation volumes can be augmented copies of regions whose originals are
    in the training set. Confirm whether the split should happen first.
    """
    image = normalize(load_data(image_path))
    mask = load_data(mask_path)

    if replica is not None:
        # Augment a fresh copy of the ORIGINAL data for every extra replica.
        # data_augmentation works in place, and the old code fed it the
        # previous replica, so augmentations compounded (two flips cancelled,
        # giving exact duplicates).
        image_parts, mask_parts = [image], [mask]
        for _ in range(1, replica):
            img_aug, msk_aug = data_augmentation(np.copy(image), np.copy(mask), size)
            image_parts.append(img_aug)
            mask_parts.append(msk_aug)
        image = np.concatenate(image_parts, axis=-1)
        mask = np.concatenate(mask_parts, axis=-1)

    image = get_batches(reshape_data(crop_data(image, size)), size)
    mask = get_batches(reshape_data(crop_data(mask, size)), size)

    if split:
        image_train, image_valid = get_split(image)
        mask_train, mask_valid = get_split(mask)
        return image_train, image_valid, mask_train, mask_valid
    return image, mask


def recover(images, size, full_size=512):
    """Invert crop_data + reshape_data + get_batches back to (full_size, full_size, N).

    ``images`` may be anything holding size x size tiles in the order produced
    by those functions, e.g. model output of shape (B, size, size, size, 1).
    ``full_size`` is the (square) slice edge; the default 512 is the former
    hard-coded value.
    """
    tiles = np.asarray(images).reshape(-1, size, size)
    tiles = np.transpose(tiles, (2, 1, 0))
    return tiles.reshape((size, full_size, -1)).reshape((full_size, full_size, -1), order='F')


# -----------------------------------------------------------------------------
# /model.py
# -----------------------------------------------------------------------------
from keras.layers import Input, BatchNormalization, MaxPool3D, Conv3D, UpSampling3D, Concatenate, Activation
from keras.models import Model
from keras.optimizers import Adam
from dice import *


def _conv_bn_relu(x, filters):
    """Apply Conv3D(3, 'same', he_normal) -> BatchNorm -> ReLU to ``x``.

    Returns ``(conv, activation)``. The raw conv output is returned as well
    because the decoder's skip connections concatenate it.
    """
    conv = Conv3D(filters, 3, padding='same', kernel_initializer='he_normal')(x)
    norm = BatchNormalization(axis=-1)(conv)
    return conv, Activation('relu')(norm)


def unet(lr, input_shape=(64, 64, 64, 1)):
    """Build and compile the 3-D U-Net.

    Structure:
      - Encoder: three stages with filters 8/16, 16/32 and 32/64, each
        followed by MaxPool3D(2). Every spatial dimension of ``input_shape``
        must therefore be divisible by 8.
      - Bottleneck: 64/128 filters.
      - Decoder: mirrors the encoder with UpSampling3D(2) and skip
        concatenation.
      - Output: a 1x1x1 sigmoid convolution with one channel.

    The model is compiled with Adam(lr) on ``dice_coef_loss`` and reports
    ``dice_coef``. Layers are instantiated in the same order as the former
    hand-unrolled version, so previously saved weight files should still load.

    NOTE(review), kept unchanged for weight compatibility:
      - Skip connections concatenate the pre-BatchNorm conv outputs rather
        than the activated features.
      - The last 3x3x3 conv has no BatchNorm/ReLU, so it composes linearly
        with the 1x1x1 output conv.
    """
    inputs = Input(input_shape)

    # Encoder
    _, x = _conv_bn_relu(inputs, 8)
    skip1, x = _conv_bn_relu(x, 16)
    x = MaxPool3D(2)(x)

    _, x = _conv_bn_relu(x, 16)
    skip2, x = _conv_bn_relu(x, 32)
    x = MaxPool3D(2)(x)

    _, x = _conv_bn_relu(x, 32)
    skip3, x = _conv_bn_relu(x, 64)
    x = MaxPool3D(2)(x)

    # Bottleneck
    _, x = _conv_bn_relu(x, 64)
    _, x = _conv_bn_relu(x, 128)

    # Decoder (upsampling layer created before the Concatenate, as originally)
    up = UpSampling3D(2)(x)
    x = Concatenate(axis=-1)([skip3, up])
    _, x = _conv_bn_relu(x, 64)
    _, x = _conv_bn_relu(x, 64)

    up = UpSampling3D(2)(x)
    x = Concatenate(axis=-1)([skip2, up])
    _, x = _conv_bn_relu(x, 32)
    _, x = _conv_bn_relu(x, 32)

    up = UpSampling3D(2)(x)
    x = Concatenate(axis=-1)([skip1, up])
    _, x = _conv_bn_relu(x, 16)
    x = Conv3D(16, 3, padding='same', kernel_initializer='he_normal')(x)
    outputs = Conv3D(1, 1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(lr=lr), loss=dice_coef_loss, metrics=[dice_coef])

    return model


# -----------------------------------------------------------------------------
# /augmentation.py
# -----------------------------------------------------------------------------
import numpy as np
from scipy import signal, ndimage

# Every op below takes a 2-D slice plus ``seed``. Random parameters come from
# a private ``np.random.RandomState(seed)``, so calling an op with the same
# seed on an image slice and on its mask slice applies the identical
# transform. The global NumPy RNG is not reseeded (the former
# ``np.random.seed`` calls did reseed it). Interpolating ops also accept
# ``order``: the default 3 matches scipy's default, and 0 (nearest
# neighbour) keeps label masks binary.


def shift(image, max_amt=0.2, seed=42, order=3):
    """Translate a 2-D image by random integer offsets.

    The row and column offsets are each drawn from ``[-max, max)``, where
    ``max = int(dim * max_amt)`` for that axis. The two offsets are drawn
    independently; the old code reseeded before each draw, which forced
    x == y (always a diagonal shift). An axis whose ``max`` is 0 is not
    shifted.
    """
    rng = np.random.RandomState(seed)
    max_x = int(image.shape[0] * max_amt)
    max_y = int(image.shape[1] * max_amt)
    x = rng.randint(low=-max_x, high=max_x) if max_x > 0 else 0
    y = rng.randint(low=-max_y, high=max_y) if max_y > 0 else 0
    return ndimage.shift(image, shift=[x, y], order=order)


def flipx(image, seed=42):
    """Return a copy flipped along axis 0 (``seed`` is unused)."""
    return np.copy(image)[::-1, :]


def flipy(image, seed=42):
    """Return a copy flipped along axis 1 (``seed`` is unused)."""
    return np.copy(image)[:, ::-1]


def rotate(image, seed=42, order=3):
    """Rotate by a random whole angle in [1, 359] degrees, keeping the shape."""
    angle = np.random.RandomState(seed).randint(1, 360)
    return ndimage.rotate(image, angle=angle, reshape=False, order=order)


def zoom(image, seed=42, order=3):
    """Rescale by a random factor in [0.5, 1.5) and restore the input shape.

    Zooming in (factor > 1) is centre-cropped back to the input shape.
    Zooming out (factor < 1) is zero-padded at the bottom/right edges.
    Cropping is computed per axis, so odd and non-square shapes are handled.
    """
    shape = image.shape
    z = np.random.RandomState(seed).uniform(0.5, 1.5)
    zoomed = ndimage.zoom(image, z, order=order)
    if z > 1:
        r0 = (zoomed.shape[0] - shape[0]) // 2
        c0 = (zoomed.shape[1] - shape[1]) // 2
        return zoomed[r0:r0 + shape[0], c0:c0 + shape[1]]
    if z < 1:
        pad = ((0, shape[0] - zoomed.shape[0]), (0, shape[1] - zoomed.shape[1]))
        return np.pad(zoomed, pad, mode='constant', constant_values=0)
    return zoomed


def smooth(image, seed=42):
    """Gaussian-blur the image with a random sigma in [0.6, 1.3).

    Sigma is drawn from ``seed``, like every other op, so all slices of one
    chunk get the same blur.
    """
    sigma = np.random.RandomState(seed).uniform(0.6, 1.3)
    return ndimage.gaussian_filter(image, sigma=sigma)


# Candidate ops, chosen uniformly per chunk; the index order matches the
# former dict {0: flipx, 1: shift, ...}.
_OPS = (flipx, shift, flipy, rotate, zoom, smooth)
# Ops that resample pixels; masks go through them with order=0.
_INTERPOLATING_OPS = (shift, rotate, zoom)


def data_augmentation(image, mask, size=64):
    """Augment an (H, W, N) image/mask stack in place, ``size`` slices at a time.

    Each chunk of ``size`` consecutive slices (along axis 2) gets one randomly
    chosen op with one random seed, so every slice in the chunk and its mask
    receive the same transform. Op and seed come from the global NumPy RNG,
    so repeated calls give different augmentations. ``smooth`` only changes
    intensities and is not applied to the mask.

    Returns:
        The same ``(image, mask)`` arrays, modified in place.
    """
    n_slices = image.shape[2]
    for start in range(0, n_slices, size):
        op = _OPS[np.random.randint(0, len(_OPS))]
        seed = np.random.randint(0, 2 ** 31 - 1)

        for k in range(start, min(start + size, n_slices)):
            image[:, :, k] = op(image[:, :, k], seed=seed)
            if op is smooth:
                continue
            if op in _INTERPOLATING_OPS:
                # Nearest-neighbour keeps label values exact; cubic
                # interpolation produced non-binary / overshooting labels.
                mask[:, :, k] = op(mask[:, :, k], seed=seed, order=0)
            else:
                mask[:, :, k] = op(mask[:, :, k], seed=seed)

    return image, mask