├── .gitignore ├── utils ├── __init__.py ├── getTrainData.py ├── eval.py └── utils.py ├── model ├── __init__.py ├── metrics.py ├── callBack.py └── network.py ├── README.md ├── configTrain.py ├── pre-process ├── preprocess.py ├── Untitled.ipynb └── proprecess.ipynb └── train.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.gz 2 | *.zip 3 | .ipynb_checkpoints/ 4 | __pycache__/ 5 | model_save/ 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . import getTrainData 3 | from . import eval -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import network 2 | from . import callBack 3 | from . import metrics 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3DUnet-keras 2 | 3 | 4 | 3D Unet biomedical segmentation model powered by tensorpack with fast io speed. 5 | 6 | Borrow a lot of codes from https://github.com/tkuanlun350/3DUnet-Tensorflow-Brats18. 7 | I streamlined the code and changed it to the keras version. 8 | 9 | ## Dependencies 10 | - Python 3; 11 | - TensorFlow 1.12.0; 12 | - Keras 2.2.4; 13 | ``` 14 | -------------------------------------------------------------------------------- /utils/getTrainData.py: -------------------------------------------------------------------------------- 1 | from keras.utils import to_categorical 2 | import numpy as np 3 | from .utils import crop_brain_region, get_roi 4 | 5 | 6 | def image_gen3d(ls, batch_size=(48,224,224)): 7 | 8 | out_img = [] 9 | out_mask = [] 10 | 11 | while True: 12 | np.random.shuffle(ls) 13 | for data in ls: 14 | img = data 15 | mask = data[:-10] + 'mask.nii.gz' 16 | 17 | img, mask, weight, original_shape, bbox = crop_brain_region([img], mask) 18 | img, mask = get_roi(img, mask) 19 | 20 | out_img += [img] 21 | out_mask += [mask] 22 | if len(out_img) >= batch_size: 23 | yield np.stack(out_img, 0), to_categorical(np.stack(out_mask, 0)) 24 | out_img, out_mask=[], [] 25 | -------------------------------------------------------------------------------- /configTrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | PATCH_SIZE = (48,224,224) 5 | center = "valid" #?? 6 | BATCH_SIZE = 2 7 | DATA_SAMPLING = 'all_positive' 8 | 9 | # model network 10 | DEPTH = 5 11 | RESIDUAL = True 12 | DEEP_SUPERVISION = True 13 | FILTER_GROW = True 14 | INSTANCE_NORM = True 15 | NUM_CLASS = 6 16 | BASE_FILTER = 16 17 | 18 | #callbacks list 19 | monitor = 'val_loss'#'val_loss' 20 | mode = 'min' 21 | early_p = 20 #the patience time to stop training while the loss is not down 22 | reduce_lr_p = 10 #the patience time to reduce lr 23 | 24 | min_lr = 1e-8 25 | lr = 1e-4 26 | 27 | 28 | 29 | #Keep relevant training record files 30 | time = '2020' 31 | name = 'one' 32 | 33 | path = os.path.join('model_save/', time, name) 34 | if not os.path.isdir(path): 35 | os.makedirs(path) 36 | 37 | log_csv_name = 'log' + '.csv' 38 | log_csv_path = os.path.join(path, log_csv_name) 39 | 40 | best_model_save_name = '{}_weights.best.hdf5'.format('pretrain_model') 41 | best_model_save_path = os.path.join(path, best_model_save_name) 42 | 43 | model_save_name = 'model' + '.h5' 44 | model_save_path = os.path.join(path, model_save_name) 45 | 46 | conf_save_path = os.path.join(path, 'conf.csv') 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /model/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | 4 | from keras import backend as K 5 | 6 | 7 | def dice_coefficient(y_true, y_pred, smooth=1.): 8 | y_true_f = K.flatten(y_true) 9 | y_pred_f = K.flatten(y_pred) 10 | intersection = K.sum(y_true_f * y_pred_f) 11 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 12 | 13 | 14 | def dice_coefficient_loss(y_true, y_pred): 15 | return -dice_coefficient(y_true, y_pred) 16 | 17 | 18 | def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=0.00001): 19 | """ 20 | Weighted dice coefficient. Default axis assumes a "channels first" data structure 21 | :param smooth: 22 | :param y_true: 23 | :param y_pred: 24 | :param axis: 25 | :return: 26 | """ 27 | return K.mean(2. * (K.sum(y_true * y_pred, 28 | axis=axis) + smooth/2)/(K.sum(y_true, 29 | axis=axis) + K.sum(y_pred, 30 | axis=axis) + smooth)) 31 | 32 | 33 | def weighted_dice_coefficient_loss(y_true, y_pred): 34 | return -weighted_dice_coefficient(y_true, y_pred) 35 | 36 | 37 | def label_wise_dice_coefficient(y_true, y_pred, label_index): 38 | return dice_coefficient(y_true[:, label_index], y_pred[:, label_index]) 39 | 40 | 41 | def get_label_dice_coefficient_function(label_index): 42 | f = partial(label_wise_dice_coefficient, label_index=label_index) 43 | f.__setattr__('__name__', 'label_{0}_dice_coef'.format(label_index)) 44 | return f 45 | 46 | 47 | dice_coef = dice_coefficient 48 | dice_coef_loss = dice_coefficient_loss -------------------------------------------------------------------------------- /model/callBack.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import CSVLogger 2 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau 3 | 4 | def call(weight_path, #The best model save path while training , mutil gou model 5 | monitor, #loss 6 | mode, #min 7 | reduce_lr_p, # 8 | early_p, # 9 | log_csv_path, # Save log 10 | save_best_only = True, 11 | save_weights_only = True, 12 | factor=0.5, 13 | min_delta=0.0001, 14 | cooldown=2, 15 | min_lr=1e-6, 16 | verbose=1 17 | ): 18 | 19 | csv_logger = CSVLogger(log_csv_path, append=True, separator=',') 20 | 21 | checkpoint = ModelCheckpoint(weight_path, 22 | mode = mode, 23 | monitor = monitor, 24 | verbose = verbose, 25 | save_best_only = save_best_only, 26 | save_weights_only = save_weights_only) 27 | 28 | reduceLROnPlat = ReduceLROnPlateau(monitor = monitor, 29 | factor = factor, 30 | patience = reduce_lr_p, 31 | verbose = 1, 32 | mode = mode, 33 | epsilon = min_delta, ## 34 | cooldown = cooldown, ## 35 | min_lr = min_lr) 36 | 37 | early = EarlyStopping(monitor = monitor, 38 | mode = mode, 39 | patience = early_p) # probably needs to be more patient 40 | 41 | callbacks_list = [checkpoint, csv_logger, reduceLROnPlat, early] 42 | return callbacks_list -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras.models import Model 3 | from keras.layers import Activation, Input, concatenate, BatchNormalization 4 | from keras.layers import Conv3D, UpSampling3D, Conv3DTranspose 5 | from keras.layers import add 6 | from keras.layers import LeakyReLU, Reshape, Lambda 7 | from keras.initializers import RandomNormal 8 | import keras 9 | import numpy as np 10 | import configTrain 11 | 12 | def myConv(x_in, nf, strides=1, kernel_size = 3): 13 | """ 14 | specific convolution module including convolution followed by leakyrelu 15 | """ 16 | x_out = Conv3D(nf, kernel_size=3, padding='same', kernel_initializer='he_normal', strides=strides)(x_in) 17 | x_out = BatchNormalization()(x_out) 18 | x_out = LeakyReLU(0.2)(x_out) 19 | return x_out 20 | 21 | 22 | def Unet3dBlock(l, n_feat): 23 | if configTrain.RESIDUAL: 24 | l_in = l 25 | for i in range(2): 26 | l = myConv(l, n_feat) 27 | return add([l_in, l]) if configTrain.RESIDUAL else l 28 | 29 | 30 | def UnetUpsample(l, num_filters): 31 | l = UpSampling3D()(l) 32 | l = myConv(l, num_filters) 33 | return l 34 | 35 | 36 | BASE_FILTER = configTrain.BASE_FILTER 37 | 38 | def unet3d(vol_size): 39 | inputs = Input(shape=vol_size) 40 | depth = configTrain.DEPTH 41 | filters = [] 42 | down_list = [] 43 | deep_supervision = None 44 | layer = myConv(inputs, BASE_FILTER) 45 | 46 | for d in range(depth): 47 | if configTrain.FILTER_GROW: 48 | num_filters = BASE_FILTER * (2**d) 49 | else: 50 | num_filters = BASE_FILTER 51 | filters.append(num_filters) 52 | layer = Unet3dBlock(layer, n_feat = num_filters) 53 | down_list.append(layer) 54 | if d != depth - 1: 55 | layer = myConv(layer, num_filters*2, strides=2) 56 | 57 | for d in range(depth-2, -1, -1): 58 | layer = UnetUpsample(layer, filters[d]) 59 | layer = concatenate([layer, down_list[d]]) 60 | layer = myConv(layer, filters[d]) 61 | layer = myConv(layer, filters[d], kernel_size = 1) 62 | 63 | if configTrain.DEEP_SUPERVISION: 64 | if 0< d < 3: 65 | pred = myConv(layer, configTrain.NUM_CLASS) 66 | if deep_supervision is None: 67 | deep_supervision = pred 68 | else: 69 | deep_supervision = add([pred, deep_supervision]) 70 | deep_supervision = UpSampling3D()(deep_supervision) 71 | 72 | layer = myConv(layer, configTrain.NUM_CLASS, kernel_size = 1) 73 | 74 | if configTrain.DEEP_SUPERVISION: 75 | layer = add([layer, deep_supervision]) 76 | layer = Conv3D(configTrain.NUM_CLASS, kernel_size = 1)(layer) 77 | x = Activation('softmax', name='softmax')(layer) 78 | # pri 79 | model = Model(inputs=[inputs], outputs=[x]) 80 | return model -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: eval.py 3 | import tqdm 4 | import os 5 | import numpy as np 6 | from tensorpack.utils.utils import get_tqdm_kwargs 7 | from .utils import * 8 | 9 | def segment_one_image(data, model_func): 10 | """ 11 | perform inference and unpad the volume to original shape 12 | """ 13 | img, _, weight, original_shape, bbox = crop_brain_region([data], None, False) 14 | temp_size = original_shape 15 | temp_bbox = bbox 16 | prob = batch_segmentation(img, model_func, data_shape = configTrain.PATCH_SIZE) 17 | pred = np.argmax(prob, axis=-1) 18 | out_label = np.asarray(pred, np.int16) 19 | final_label = np.zeros(temp_size, np.int16) 20 | final_label = set_ND_volume_roi_with_bounding_box_range(final_label, temp_bbox[0], temp_bbox[1], out_label) 21 | return final_label 22 | 23 | def batch_segmentation(temp_imgs, model, data_shape=[19, 180, 160]): 24 | batch_size = configTrain.BATCH_SIZE 25 | data_channel = 1 26 | class_num = configTrain.NUM_CLASS 27 | image_shape = temp_imgs[0].shape 28 | label_shape = [data_shape[0], data_shape[1], data_shape[2]] 29 | D, H, W = image_shape 30 | input_center = [int(D/2), int(H/2), int(W/2)] 31 | temp_prob1 = np.zeros([D, H, W, class_num]) 32 | 33 | sub_image_batches = [] 34 | for center_slice in range(int(label_shape[0]/2), D + int(label_shape[0]/2), label_shape[0]): 35 | center_slice = min(center_slice, D - int(label_shape[0]/2)) 36 | sub_image_batch = [] 37 | for chn in range(data_channel): 38 | temp_input_center = [center_slice, input_center[1], input_center[2]] 39 | sub_image = extract_roi_from_volume( 40 | temp_imgs[chn], temp_input_center, data_shape, fill="zero") 41 | sub_image_batch.append(sub_image) 42 | sub_image_batch = np.asanyarray(sub_image_batch, np.float32) 43 | sub_image_batches.append(sub_image_batch) 44 | 45 | total_batch = len(sub_image_batches) 46 | max_mini_batch = int((total_batch+batch_size-1)/batch_size) 47 | sub_label_idx1 = 0 48 | for mini_batch_idx in range(max_mini_batch): 49 | data_mini_batch = sub_image_batches[mini_batch_idx*batch_size: 50 | min((mini_batch_idx+1)*batch_size, total_batch)] 51 | if(mini_batch_idx == max_mini_batch - 1): 52 | for idx in range(batch_size - (total_batch - mini_batch_idx*batch_size)): 53 | data_mini_batch.append(np.zeros([data_channel] + list(data_shape))) 54 | data_mini_batch = np.asarray(data_mini_batch, np.float32) 55 | data_mini_batch = np.transpose(data_mini_batch, [0, 2, 3, 4, 1]) 56 | prob_mini_batch1 = model.predict_on_batch(data_mini_batch) 57 | 58 | for batch_idx in range(prob_mini_batch1.shape[0]): 59 | center_slice = sub_label_idx1*label_shape[0] + int(label_shape[0]/2) 60 | center_slice = min(center_slice, D - int(label_shape[0]/2)) 61 | temp_input_center = [center_slice, input_center[1], input_center[2], int(class_num/2)] 62 | sub_prob = np.reshape(prob_mini_batch1[batch_idx], label_shape + [class_num]) 63 | temp_prob1 = set_roi_to_volume(temp_prob1, temp_input_center, sub_prob) 64 | sub_label_idx1 = sub_label_idx1 + 1 65 | 66 | return temp_prob1 67 | 68 | -------------------------------------------------------------------------------- /pre-process/preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import warnings 4 | import shutil 5 | import argparse 6 | import SimpleITK as sitk 7 | import numpy as np 8 | from tqdm import tqdm 9 | from nipype.interfaces.ants import N4BiasFieldCorrection 10 | 11 | def N4BiasFieldCorrect(filename, output_filename): 12 | normalized = N4BiasFieldCorrection() 13 | normalized.inputs.input_image = filename 14 | normalized.inputs.output_image = output_filename 15 | normalized.run() 16 | return None 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--data', help='training data path', default="/data/dataset/BRATS2018/training/") 21 | parser.add_argument('--out', help="output path", default="./N4_Normalized") 22 | parser.add_argument('--mode', help="output path", default="training") 23 | args = parser.parse_args() 24 | if args.mode == 'test': 25 | BRATS_data = glob.glob(args.data + "/*") 26 | patient_ids = [x.split("/")[-1] for x in BRATS_data] 27 | print("Processing Testing data ...") 28 | for idx, file_name in tqdm(enumerate(BRATS_data), total=len(BRATS_data)): 29 | mod = glob.glob(file_name+"/*.nii*") 30 | output_dir = "{}/test/{}/".format(args.out, patient_ids[idx]) 31 | if not os.path.exists(output_dir): 32 | os.makedirs(output_dir) 33 | for mod_file in mod: 34 | if 'flair' not in mod_file and 'seg' not in mod_file: 35 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 36 | N4BiasFieldCorrect(mod_file, output_path) 37 | else: 38 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 39 | shutil.copy(mod_file, output_path) 40 | else: 41 | HGG_data = glob.glob(args.data + "HGG/*") 42 | LGG_data = glob.glob(args.data + "LGG/*") 43 | hgg_patient_ids = [x.split("/")[-1] for x in HGG_data] 44 | lgg_patient_ids = [x.split("/")[-1] for x in LGG_data] 45 | print("Processing HGG ...") 46 | for idx, file_name in tqdm(enumerate(HGG_data), total=len(HGG_data)): 47 | mod = glob.glob(file_name+"/*.nii*") 48 | output_dir = "{}/HGG/{}/".format(args.out, hgg_patient_ids[idx]) 49 | if not os.path.exists(output_dir): 50 | os.makedirs(output_dir) 51 | for mod_file in mod: 52 | if 'flair' not in mod_file and 'seg' not in mod_file: 53 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 54 | N4BiasFieldCorrect(mod_file, output_path) 55 | else: 56 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 57 | shutil.copy(mod_file, output_path) 58 | print("Processing LGG ...") 59 | for idx, file_name in tqdm(enumerate(LGG_data), total=len(LGG_data)): 60 | mod = glob.glob(file_name+"/*.nii*") 61 | output_dir = "{}/LGG/{}/".format(args.out, lgg_patient_ids[idx]) 62 | if not os.path.exists(output_dir): 63 | os.makedirs(output_dir) 64 | for mod_file in mod: 65 | if 'flair' not in mod_file and 'seg' not in mod_file: 66 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 67 | N4BiasFieldCorrect(mod_file, output_path) 68 | else: 69 | output_path = "{}/{}".format(output_dir, mod_file.split("/")[-1]) 70 | shutil.copy(mod_file, output_path) 71 | 72 | 73 | 74 | if __name__ == "__main__": 75 | main() -------------------------------------------------------------------------------- /pre-process/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import SimpleITK as sitk" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 13, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "file = '/public/lixin/lung/lung-lobe-seg/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.237215747217294006286437405216_img.nii.gz'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 14, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "data = sitk.GetArrayFromImage(sitk.ReadImage(file))" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 15, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "(230, 256, 256)" 39 | ] 40 | }, 41 | "execution_count": 15, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "data.shape" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 17, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "(125, 256, 256)" 59 | ] 60 | }, 61 | "execution_count": 17, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "mask = '/public/lixin/lung/lung-lobe-seg/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.237215747217294006286437405216_mask.nii.gz'\n", 68 | "mask_data = sitk.GetArrayFromImage(sitk.ReadImage(mask))\n", 69 | "mask_data.shape" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import numpy as np " 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "array([ 0, 4, 5, 6, 7, 8, 512, 516, 517, 518, 519, 520],\n", 90 | " dtype=uint16)" 91 | ] 92 | }, 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "np.unique(data)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "# 4,5,6,7,8" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 7, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "img_path = '/public/lixin/lung/lung-lobe-seg/matchedfile/9/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.416701701108520592702405866796.mhd'" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 8, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "import ants" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 9, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "img = ants.image_read(img_path)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 10, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "(512, 512, 277)" 145 | ] 146 | }, 147 | "execution_count": 10, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "img.shape" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 11, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "new_img = ants.resample_image(img, (256,256,int(img.shape[2]/2)), True, 0)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 25, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | "(256, 256, 80)" 174 | ] 175 | }, 176 | "execution_count": 25, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "new_img.shape" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 27, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "" 201 | ] 202 | }, 203 | "execution_count": 27, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "new_img.min" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 18, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stderr", 219 | "output_type": "stream", 220 | "text": [ 221 | "Using TensorFlow backend.\n" 222 | ] 223 | }, 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "Segmentation Models: using `keras` framework.\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "import segmentation_models as sm" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Python 3", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.6.8" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } 259 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | import SimpleITK as sitk 5 | import pickle 6 | from scipy import ndimage 7 | import configTrain 8 | import copy 9 | 10 | 11 | def crop_brain_region(imgs, mask, with_gt=True): 12 | 13 | WINDOW_LEVEL = (1000,-500) 14 | volume_list = [] 15 | for idx, mod in enumerate(imgs): 16 | volume = sitk.ReadImage(mod) 17 | img_window = sitk.Cast(sitk.IntensityWindowing(volume, 18 | windowMinimum=WINDOW_LEVEL[1] - WINDOW_LEVEL[0] / 2.0, 19 | windowMaximum=WINDOW_LEVEL[1] + WINDOW_LEVEL[0] / 2.0), 20 | sitk.sitkUInt8) 21 | img_window_arr = sitk.GetArrayFromImage(img_window) 22 | if idx == 0: 23 | original_shape = img_window_arr.shape #return 24 | margin = 1 25 | bbmin, bbmax = get_none_zero_region(img_window_arr, margin) 26 | weight = np.asarray(img_window_arr > 0, np.float32) #return 27 | img_window_arr_crop = crop_ND_volume_with_bounding_box(img_window_arr, bbmin, bbmax) 28 | img_window_arr_crop_norm = itensity_normalize_one_volume(img_window_arr_crop) 29 | volume_list.append(img_window_arr_crop_norm) 30 | if with_gt: 31 | label = load_nifty_volume_as_array(mask) 32 | 33 | label[label == 1] = 0 34 | label[label == 4] = 1 35 | label[label == 5] = 2 36 | label[label == 6] = 3 37 | label[label == 7] = 4 38 | label[label == 8] = 5 39 | label[label > 10] = 0 40 | 41 | label = crop_ND_volume_with_bounding_box(label, bbmin, bbmax) 42 | return volume_list, label, weight, original_shape, [bbmin, bbmax] 43 | else: 44 | return volume_list, None, weight, original_shape, [bbmin, bbmax] 45 | 46 | def get_roi(img, mask): 47 | volume_shape = img[0].shape 48 | center_point = get_random_roi_sampling_center(volume_shape, configTrain.PATCH_SIZE, configTrain.center) 49 | sub_label = extract_roi_from_volume(mask, 50 | center_point, 51 | configTrain.PATCH_SIZE, 52 | fill = 'zero') 53 | 54 | sub_img = extract_roi_from_volume(img[0], 55 | center_point, 56 | configTrain.PATCH_SIZE, 57 | fill = 'zero') 58 | 59 | axis = [1,2,3,0] 60 | sub_data = np.transpose([sub_img], axis) 61 | sub_label = np.transpose(sub_label[np.newaxis, ...], axis) 62 | return sub_data, sub_label 63 | 64 | 65 | def get_random_roi_sampling_center(input_shape, output_shape, sample_mode='full'): 66 | """ 67 | get a random coordinate representing the center of a roi for sampling 68 | inputs: 69 | input_shape: the shape of sampled volume 70 | output_shape: the desired roi shape 71 | sample_mode: 'valid': the entire roi should be inside the input volume 72 | 'full': only the roi centre should be inside the input volume 73 | bounding_box: the bounding box which the roi center should be limited to 74 | outputs: 75 | center: the output center coordinate of a roi 76 | """ 77 | center = [] 78 | for i in range(len(input_shape)): 79 | if sample_mode == 'full': 80 | x0 = 0 81 | x1 = input_shape[i] 82 | elif sample_mode == 'valid': 83 | x0 = int(output_shape[i]/2) 84 | x1 = input_shape[i] - x0 85 | 86 | if x1 <= x0: #?? 87 | centeri = int((x0 + x1)/2) 88 | else: 89 | centeri = random.randint(x0, x1) 90 | center.append(centeri) 91 | return center 92 | 93 | def crop_ND_volume_with_bounding_box(volume, min_idx, max_idx): 94 | """ 95 | crop/extract a subregion form an nd image. 96 | """ 97 | dim = len(volume.shape) 98 | assert(dim >= 2 and dim <= 5) 99 | if(dim == 2): 100 | output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), 101 | range(min_idx[1], max_idx[1] + 1))] 102 | elif(dim == 3): 103 | output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), 104 | range(min_idx[1], max_idx[1] + 1), 105 | range(min_idx[2], max_idx[2] + 1))] 106 | elif(dim == 4): 107 | output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), 108 | range(min_idx[1], max_idx[1] + 1), 109 | range(min_idx[2], max_idx[2] + 1), 110 | range(min_idx[3], max_idx[3] + 1))] 111 | elif(dim == 5): 112 | output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), 113 | range(min_idx[1], max_idx[1] + 1), 114 | range(min_idx[2], max_idx[2] + 1), 115 | range(min_idx[3], max_idx[3] + 1), 116 | range(min_idx[4], max_idx[4] + 1))] 117 | else: 118 | raise ValueError("the dimension number shoud be 2 to 5") 119 | return output 120 | 121 | def extract_roi_from_volume(volume, in_center, output_shape, fill = 'random'): 122 | """ 123 | extract a roi from a 3d volume 124 | inputs: 125 | volume: the input 3D volume 126 | in_center: the center of the roi 127 | output_shape: the size of the roi 128 | fill: 'random' or 'zero', the mode to fill roi region where is outside of the input volume 129 | outputs: 130 | output: the roi volume 131 | """ 132 | input_shape = volume.shape 133 | if(fill == 'random'): 134 | output = np.random.normal(0, 1, size = output_shape) 135 | else: 136 | output = np.zeros(output_shape) 137 | r0max = [int(x/2) for x in output_shape] 138 | r1max = [output_shape[i] - r0max[i] for i in range(len(r0max))] 139 | r0 = [min(r0max[i], in_center[i]) for i in range(len(r0max))] 140 | r1 = [min(r1max[i], input_shape[i] - in_center[i]) for i in range(len(r0max))] 141 | out_center = r0max 142 | 143 | output[np.ix_(range(out_center[0] - r0[0], out_center[0] + r1[0]), 144 | range(out_center[1] - r0[1], out_center[1] + r1[1]), 145 | range(out_center[2] - r0[2], out_center[2] + r1[2]))] = \ 146 | volume[np.ix_(range(in_center[0] - r0[0], in_center[0] + r1[0]), 147 | range(in_center[1] - r0[1], in_center[1] + r1[1]), 148 | range(in_center[2] - r0[2], in_center[2] + r1[2]))] 149 | return output 150 | 151 | def load_nifty_volume_as_array(filename): 152 | """ 153 | load nifty image into numpy array, and transpose it based on the [z,y,x] axis order 154 | The output array shape is like [Depth, Height, Width] 155 | inputs: 156 | filename: the input file name, should be *.nii or *.nii.gz 157 | outputs: 158 | data: a numpy data array 159 | """ 160 | img = sitk.ReadImage(filename) 161 | img_arr = sitk.GetArrayFromImage(img) 162 | return img_arr 163 | 164 | def itensity_normalize_one_volume(volume): 165 | """ 166 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 167 | inputs: 168 | volume: the input nd volume 169 | outputs: 170 | out: the normalized nd volume 171 | """ 172 | pixels = volume[volume > 0] 173 | mean = pixels.mean() 174 | std = pixels.std() 175 | out = (volume - mean)/std 176 | out_random = np.zeros(volume.shape) 177 | out[volume == 0] = out_random[volume == 0] 178 | return out 179 | 180 | def get_ND_bounding_box(label, margin): 181 | """ 182 | get the bounding box of the non-zero region of an ND volume 183 | """ 184 | input_shape = label.shape 185 | if(type(margin) is int ): 186 | margin = [margin]*len(input_shape) 187 | assert(len(input_shape) == len(margin)) 188 | indxes = np.nonzero(label) 189 | idx_min = [] 190 | idx_max = [] 191 | for i in range(len(input_shape)): 192 | idx_min.append(indxes[i].min()) 193 | idx_max.append(indxes[i].max()) 194 | 195 | for i in range(len(input_shape)): 196 | idx_min[i] = max(idx_min[i] - margin[i], 0) 197 | idx_max[i] = min(idx_max[i] + margin[i], input_shape[i] - 1) 198 | return idx_min, idx_max 199 | 200 | def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume): 201 | """ 202 | set a subregion to an nd image. 203 | """ 204 | dim = len(bb_min) 205 | out = volume 206 | if(dim == 2): 207 | out[np.ix_(range(bb_min[0], bb_max[0] + 1), 208 | range(bb_min[1], bb_max[1] + 1))] = sub_volume 209 | elif(dim == 3): 210 | out[np.ix_(range(bb_min[0], bb_max[0] + 1), 211 | range(bb_min[1], bb_max[1] + 1), 212 | range(bb_min[2], bb_max[2] + 1))] = sub_volume 213 | elif(dim == 4): 214 | out[np.ix_(range(bb_min[0], bb_max[0] + 1), 215 | range(bb_min[1], bb_max[1] + 1), 216 | range(bb_min[2], bb_max[2] + 1), 217 | range(bb_min[3], bb_max[3] + 1))] = sub_volume 218 | else: 219 | raise ValueError("array dimension should be 2, 3 or 4") 220 | return out 221 | 222 | def set_roi_to_volume(volume, center, sub_volume): 223 | """ 224 | set the content of an roi of a 3d/4d volume to a sub volume 225 | inputs: 226 | volume: the input 3D/4D volume 227 | center: the center of the roi 228 | sub_volume: the content of sub volume 229 | outputs: 230 | output_volume: the output 3D/4D volume 231 | """ 232 | volume_shape = volume.shape 233 | patch_shape = sub_volume.shape 234 | output_volume = volume 235 | for i in range(len(center)): 236 | if(center[i] >= volume_shape[i]): 237 | return output_volume 238 | r0max = [int(x/2) for x in patch_shape] 239 | r1max = [patch_shape[i] - r0max[i] for i in range(len(r0max))] 240 | r0 = [min(r0max[i], center[i]) for i in range(len(r0max))] 241 | r1 = [min(r1max[i], volume_shape[i] - center[i]) for i in range(len(r0max))] 242 | patch_center = r0max 243 | 244 | if(len(center) == 3): 245 | output_volume[np.ix_(range(center[0] - r0[0], center[0] + r1[0]), 246 | range(center[1] - r0[1], center[1] + r1[1]), 247 | range(center[2] - r0[2], center[2] + r1[2]))] = \ 248 | sub_volume[np.ix_(range(patch_center[0] - r0[0], patch_center[0] + r1[0]), 249 | range(patch_center[1] - r0[1], patch_center[1] + r1[1]), 250 | range(patch_center[2] - r0[2], patch_center[2] + r1[2]))] 251 | elif(len(center) == 4): 252 | output_volume[np.ix_(range(center[0] - r0[0], center[0] + r1[0]), 253 | range(center[1] - r0[1], center[1] + r1[1]), 254 | range(center[2] - r0[2], center[2] + r1[2]), 255 | range(center[3] - r0[3], center[3] + r1[3]))] = \ 256 | sub_volume[np.ix_(range(patch_center[0] - r0[0], patch_center[0] + r1[0]), 257 | range(patch_center[1] - r0[1], patch_center[1] + r1[1]), 258 | range(patch_center[2] - r0[2], patch_center[2] + r1[2]), 259 | range(patch_center[3] - r0[3], patch_center[3] + r1[3]))] 260 | else: 261 | raise ValueError("array dimension should be 3 or 4") 262 | return output_volume 263 | 264 | def get_none_zero_region(im, margin): 265 | """ 266 | get the bounding box of the non-zero region of an ND volume 267 | """ 268 | input_shape = im.shape 269 | if(type(margin) is int ): 270 | margin = [margin]*len(input_shape) 271 | assert(len(input_shape) == len(margin)) 272 | indxes = np.nonzero(im) 273 | idx_min = [] 274 | idx_max = [] 275 | for i in range(len(input_shape)): 276 | idx_min.append(indxes[i].min()) 277 | idx_max.append(indxes[i].max()) 278 | 279 | for i in range(len(input_shape)): 280 | idx_min[i] = max(idx_min[i] - margin[i], 0) 281 | idx_max[i] = min(idx_max[i] + margin[i], input_shape[i] - 1) 282 | return idx_min, idx_max 283 | 284 | 285 | -------------------------------------------------------------------------------- /pre-process/proprecess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd \n", 10 | "import numpy as np \n", 11 | "import os \n", 12 | "import ants" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "# Get Id " 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "fold = '/public/lixin/lung/lung-lobe-seg/matchedfile'" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "fold_list = os.listdir(fold)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "47" 49 | ] 50 | }, 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "len(fold_list)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 5, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "['50',\n", 69 | " '25',\n", 70 | " '26',\n", 71 | " '9',\n", 72 | " '33',\n", 73 | " '20',\n", 74 | " '28',\n", 75 | " '46',\n", 76 | " '43',\n", 77 | " '31',\n", 78 | " '11',\n", 79 | " '36',\n", 80 | " '35',\n", 81 | " '23',\n", 82 | " '27',\n", 83 | " '15',\n", 84 | " '37',\n", 85 | " '12',\n", 86 | " '44',\n", 87 | " '17',\n", 88 | " '24',\n", 89 | " '19',\n", 90 | " '48',\n", 91 | " '6',\n", 92 | " '13',\n", 93 | " '34',\n", 94 | " '41',\n", 95 | " '49',\n", 96 | " '18',\n", 97 | " '38',\n", 98 | " '32',\n", 99 | " '10',\n", 100 | " '8',\n", 101 | " '30',\n", 102 | " '3',\n", 103 | " '42',\n", 104 | " '16',\n", 105 | " '39',\n", 106 | " '1',\n", 107 | " '22',\n", 108 | " '5',\n", 109 | " '2',\n", 110 | " '47',\n", 111 | " '45',\n", 112 | " '21',\n", 113 | " '40',\n", 114 | " '4']" 115 | ] 116 | }, 117 | "execution_count": 5, 118 | "metadata": {}, 119 | "output_type": "execute_result" 120 | } 121 | ], 122 | "source": [ 123 | "fold_list" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "fold_nii = '/public/lixin/lung/lung-lobe-seg/data/'" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 7, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "num = 0\n", 142 | "for fo in fold_list:\n", 143 | " path = os.listdir(os.path.join(fold, fo))\n", 144 | "# print(path)\n", 145 | "\n", 146 | " for file in path:\n", 147 | "# print(file[-4:])\n", 148 | " if file[-4:] == 'nrrd':\n", 149 | " ID = file.split('_')[0]\n", 150 | "# print(ID)\n", 151 | "# mask = ants.image_read(os.path.join(fold, fo, file))\n", 152 | "# print(mask.shape)\n", 153 | "# new_mask = ants.resample_image(mask, (256,256,int(mask.shape[2]/2)), True, 1)\n", 154 | "# print(new_mask.shape)\n", 155 | "# new_mask_path = os.path.join(fold_nii, ID + '_mask.nii.gz')\n", 156 | "# ants.image_write(new_mask, new_mask_path) \n", 157 | "\n", 158 | " if file == 'data':\n", 159 | "# print('ID', ID)\n", 160 | "# print(fo, 'fo')\n", 161 | " img_fold = os.path.join(fold, fo, file)\n", 162 | " path_img = os.listdir(img_fold)\n", 163 | " for img in path_img:\n", 164 | " if img[-3:] == 'mhd':\n", 165 | "# ID = img.split('.')[0]\n", 166 | "# print(os.path.join(fold, fo, 'data', img))\n", 167 | " da_img = ants.image_read(os.path.join(fold, fo, 'data', img))\n", 168 | "# print('da_img', da_img.shape)\n", 169 | " new_img_path = os.path.join(fold_nii, ID + '_img.nii.gz')\n", 170 | " if not os.path.exists(new_img_path):\n", 171 | " print('continue')\n", 172 | " print('ID', ID)\n", 173 | " print(fo, 'fo')\n", 174 | " num += 1\n", 175 | " print(num)\n", 176 | " continue\n", 177 | "\n", 178 | "# print(os.path.join(fold, fo, 'data', img))\n", 179 | "# new_da_img = ants.resample_image(da_img, (256, 256, int(da_img.shape[2]/2)), True, 0)\n", 180 | "# print('new_da_img', new_da_img.shape)\n", 181 | " \n", 182 | "# ants.image_write(new_da_img, new_img_path)\n", 183 | "# num += 1\n", 184 | "# print(num)\n", 185 | "\n", 186 | "\n", 187 | " # print(path)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 5, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "import glob" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 6, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "all_path = glob.glob(os.path.join(fold, '*' ,'data', '*.mhd'))" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 7, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "all_mask_path = glob.glob(os.path.join(fold, '*', '*.nrrd'))" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 8, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/plain": [ 225 | "47" 226 | ] 227 | }, 228 | "execution_count": 8, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "len(all_path)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 20, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "['/public/lixin/lung/lung-lobe-seg/matchedfile/50/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.199069398344356765037879821616.mhd',\n", 246 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/25/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.297251044869095073091780740645.mhd',\n", 247 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/26/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.307835307280028057486413359377.mhd',\n", 248 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/9/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.416701701108520592702405866796.mhd',\n", 249 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/33/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.300392272203629213913702120739.mhd',\n", 250 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/20/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.503980049263254396021509831276.mhd',\n", 251 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/28/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.413896555982844732694353377538.mhd',\n", 252 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/46/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.842980983137518332429408284002.mhd',\n", 253 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/43/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.122914038048856168343065566972.mhd',\n", 254 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/31/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.330643702676971528301859647742.mhd']" 255 | ] 256 | }, 257 | "execution_count": 20, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "all_path[:10]" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 21, 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "['/public/lixin/lung/lung-lobe-seg/matchedfile/50/1.3.6.1.4.1.14519.5.2.1.6279.6001.199069398344356765037879821616_LobeSegmentation.nrrd',\n", 275 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/25/1.3.6.1.4.1.14519.5.2.1.6279.6001.297251044869095073091780740645_LobeSegmentation.nrrd',\n", 276 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/26/1.3.6.1.4.1.14519.5.2.1.6279.6001.307835307280028057486413359377_LobeSegmentation.nrrd',\n", 277 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/9/1.3.6.1.4.1.14519.5.2.1.6279.6001.416701701108520592702405866796_LobeSegmentation.nrrd',\n", 278 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/33/1.3.6.1.4.1.14519.5.2.1.6279.6001.300392272203629213913702120739_LobeSegmentation.nrrd',\n", 279 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/20/1.3.6.1.4.1.14519.5.2.1.6279.6001.503980049263254396021509831276_LobeSegmentation.nrrd',\n", 280 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/28/1.3.6.1.4.1.14519.5.2.1.6279.6001.413896555982844732694353377538_LobeSegmentation.nrrd',\n", 281 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/46/1.3.6.1.4.1.14519.5.2.1.6279.6001.842980983137518332429408284002_LobeSegmentation.nrrd',\n", 282 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/43/1.3.6.1.4.1.14519.5.2.1.6279.6001.122914038048856168343065566972_LobeSegmentation.nrrd',\n", 283 | " '/public/lixin/lung/lung-lobe-seg/matchedfile/31/1.3.6.1.4.1.14519.5.2.1.6279.6001.330643702676971528301859647742_LobeSegmentation.nrrd']" 284 | ] 285 | }, 286 | "execution_count": 21, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "all_mask_path[:10]" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 22, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "data": { 302 | "text/plain": [ 303 | "'/public/lixin/lung/lung-lobe-seg/matchedfile/50/data/1.3.6.1.4.1.14519.5.2.1.6279.6001.199069398344356765037879821616.mhd'" 304 | ] 305 | }, 306 | "execution_count": 22, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "all_path[0]" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 23, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stdout", 322 | "output_type": "stream", 323 | "text": [ 324 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.199069398344356765037879821616\n", 325 | "new_da_img (256, 256, 236)\n", 326 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.297251044869095073091780740645\n", 327 | "new_da_img (256, 256, 224)\n", 328 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.307835307280028057486413359377\n", 329 | "new_da_img (256, 256, 66)\n", 330 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.416701701108520592702405866796\n", 331 | "new_da_img (256, 256, 138)\n", 332 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.300392272203629213913702120739\n", 333 | "new_da_img (256, 256, 205)\n", 334 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.503980049263254396021509831276\n", 335 | "new_da_img (256, 256, 56)\n", 336 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.413896555982844732694353377538\n", 337 | "new_da_img (256, 256, 77)\n", 338 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.842980983137518332429408284002\n", 339 | "new_da_img (256, 256, 115)\n", 340 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.122914038048856168343065566972\n", 341 | "new_da_img (256, 256, 64)\n", 342 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.330643702676971528301859647742\n", 343 | "new_da_img (256, 256, 128)\n", 344 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.199220738144407033276946096708\n", 345 | "new_da_img (256, 256, 142)\n", 346 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.188209889686363159853715266493\n", 347 | "new_da_img (256, 256, 73)\n", 348 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.188385286346390202873004762827\n", 349 | "new_da_img (256, 256, 190)\n", 350 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.868211851413924881662621747734\n", 351 | "new_da_img (256, 256, 123)\n", 352 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.237215747217294006286437405216\n", 353 | "new_da_img (256, 256, 125)\n", 354 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.193964947698259739624715468431\n", 355 | "new_da_img (256, 256, 230)\n", 356 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.404364125369979066736354549484\n", 357 | "new_da_img (256, 256, 88)\n", 358 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.286061375572911414226912429210\n", 359 | "new_da_img (256, 256, 128)\n", 360 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.255999614855292116767517149228\n", 361 | "new_da_img (256, 256, 68)\n", 362 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.249530219848512542668813996730\n", 363 | "new_da_img (256, 256, 240)\n", 364 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.194465340552956447447896167830\n", 365 | "new_da_img (256, 256, 64)\n", 366 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.333145094436144085379032922488\n", 367 | "new_da_img (256, 256, 69)\n", 368 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.670107649586205629860363487713\n", 369 | "new_da_img (256, 256, 120)\n", 370 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.261678072503577216586082745513\n", 371 | "new_da_img (256, 256, 89)\n", 372 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.134638281277099121660656324702\n", 373 | "new_da_img (256, 256, 62)\n", 374 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.183982839679953938397312236359\n", 375 | "new_da_img (256, 256, 69)\n", 376 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.192256506776434538421891524301\n", 377 | "new_da_img (256, 256, 93)\n", 378 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.146987333806092287055399155268\n", 379 | "new_da_img (256, 256, 73)\n", 380 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119806527488108718706404165837\n", 381 | "new_da_img (256, 256, 280)\n", 382 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.295298571102631191572192562523\n", 383 | "new_da_img (256, 256, 116)\n", 384 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.323859712968543712594665815359\n", 385 | "new_da_img (256, 256, 133)\n", 386 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.317087518531899043292346860596\n", 387 | "new_da_img (256, 256, 237)\n", 388 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119209873306155771318545953948\n", 389 | "new_da_img (256, 256, 259)\n", 390 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.183843376225716802567192412456\n", 391 | "new_da_img (256, 256, 66)\n", 392 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.177685820605315926524514718990\n", 393 | "new_da_img (256, 256, 136)\n", 394 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.187451715205085403623595258748\n", 395 | "new_da_img (256, 256, 62)\n", 396 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.888291896309937415860209787179\n", 397 | "new_da_img (256, 256, 136)\n", 398 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.259018373683540453277752706262\n", 399 | "new_da_img (256, 256, 65)\n", 400 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059\n", 401 | "new_da_img (256, 256, 80)\n", 402 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.151669338315069779994664893123\n", 403 | "new_da_img (256, 256, 62)\n", 404 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.272042302501586336192628818865\n", 405 | "new_da_img (256, 256, 140)\n", 406 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119304665257760307862874140576\n", 407 | "new_da_img (256, 256, 150)\n", 408 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.112740418331256326754121315800\n", 409 | "new_da_img (256, 256, 74)\n", 410 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.332453873575389860371315979768\n", 411 | "new_da_img (256, 256, 102)\n", 412 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.898642529028521482602829374444\n", 413 | "new_da_img (256, 256, 100)\n", 414 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.316911475886263032009840828684\n", 415 | "new_da_img (256, 256, 103)\n", 416 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.112767175295249119452142211437\n", 417 | "new_da_img (256, 256, 58)\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "for file in all_path:\n", 423 | " ID = file.split('/')[-1][:-4]\n", 424 | " print(ID)\n", 425 | " da_img = ants.image_read(file)\n", 426 | " new_img_path = os.path.join(fold_nii, ID + '_img.nii.gz')\n", 427 | "# if os.path.exists(new_img_path):\n", 428 | "# print('continue')\n", 429 | "# continue \n", 430 | " new_da_img = ants.resample_image(da_img, (256, 256, int(da_img.shape[2]/2)), True, 0)\n", 431 | " print('new_da_img', new_da_img.shape)\n", 432 | " ants.image_write(new_da_img, new_img_path) " 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 11, 438 | "metadata": {}, 439 | "outputs": [ 440 | { 441 | "name": "stdout", 442 | "output_type": "stream", 443 | "text": [ 444 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.199069398344356765037879821616\n", 445 | "new_da_img (256, 256, 236)\n", 446 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.297251044869095073091780740645\n", 447 | "new_da_img (256, 256, 224)\n", 448 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.307835307280028057486413359377\n", 449 | "new_da_img (256, 256, 66)\n", 450 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.416701701108520592702405866796\n", 451 | "new_da_img (256, 256, 138)\n", 452 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.300392272203629213913702120739\n", 453 | "new_da_img (256, 256, 205)\n", 454 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.503980049263254396021509831276\n", 455 | "new_da_img (256, 256, 56)\n", 456 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.413896555982844732694353377538\n", 457 | "new_da_img (256, 256, 77)\n", 458 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.842980983137518332429408284002\n", 459 | "new_da_img (256, 256, 115)\n", 460 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.122914038048856168343065566972\n", 461 | "new_da_img (256, 256, 64)\n", 462 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.330643702676971528301859647742\n", 463 | "new_da_img (256, 256, 128)\n", 464 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.199220738144407033276946096708\n", 465 | "new_da_img (256, 256, 142)\n", 466 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.188209889686363159853715266493\n", 467 | "new_da_img (256, 256, 73)\n", 468 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.188385286346390202873004762827\n", 469 | "new_da_img (256, 256, 190)\n", 470 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.868211851413924881662621747734\n", 471 | "new_da_img (256, 256, 123)\n", 472 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.237215747217294006286437405216\n", 473 | "new_da_img (256, 256, 125)\n", 474 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.193964947698259739624715468431\n", 475 | "new_da_img (256, 256, 230)\n", 476 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.404364125369979066736354549484\n", 477 | "new_da_img (256, 256, 88)\n", 478 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.286061375572911414226912429210\n", 479 | "new_da_img (256, 256, 128)\n", 480 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.255999614855292116767517149228\n", 481 | "new_da_img (256, 256, 68)\n", 482 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.249530219848512542668813996730\n", 483 | "new_da_img (256, 256, 240)\n", 484 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.194465340552956447447896167830\n", 485 | "new_da_img (256, 256, 64)\n", 486 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.333145094436144085379032922488\n", 487 | "new_da_img (256, 256, 69)\n", 488 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.670107649586205629860363487713\n", 489 | "new_da_img (256, 256, 120)\n", 490 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.261678072503577216586082745513\n", 491 | "new_da_img (256, 256, 89)\n", 492 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.134638281277099121660656324702\n", 493 | "new_da_img (256, 256, 62)\n", 494 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.183982839679953938397312236359\n", 495 | "new_da_img (256, 256, 69)\n", 496 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.192256506776434538421891524301\n", 497 | "new_da_img (256, 256, 93)\n", 498 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.146987333806092287055399155268\n", 499 | "new_da_img (256, 256, 73)\n", 500 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119806527488108718706404165837\n", 501 | "new_da_img (256, 256, 280)\n", 502 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.295298571102631191572192562523\n", 503 | "new_da_img (256, 256, 116)\n", 504 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.323859712968543712594665815359\n", 505 | "new_da_img (256, 256, 133)\n", 506 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.317087518531899043292346860596\n", 507 | "new_da_img (256, 256, 237)\n", 508 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119209873306155771318545953948\n", 509 | "new_da_img (256, 256, 259)\n", 510 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.183843376225716802567192412456\n", 511 | "new_da_img (256, 256, 66)\n", 512 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.177685820605315926524514718990\n", 513 | "new_da_img (256, 256, 136)\n", 514 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.187451715205085403623595258748\n", 515 | "new_da_img (256, 256, 62)\n", 516 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.888291896309937415860209787179\n", 517 | "new_da_img (256, 256, 136)\n", 518 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.259018373683540453277752706262\n", 519 | "new_da_img (256, 256, 65)\n", 520 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059\n", 521 | "new_da_img (256, 256, 80)\n", 522 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.151669338315069779994664893123\n", 523 | "new_da_img (256, 256, 62)\n", 524 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.272042302501586336192628818865\n", 525 | "new_da_img (256, 256, 140)\n", 526 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.119304665257760307862874140576\n", 527 | "new_da_img (256, 256, 150)\n", 528 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.112740418331256326754121315800\n", 529 | "new_da_img (256, 256, 74)\n", 530 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.332453873575389860371315979768\n", 531 | "new_da_img (256, 256, 102)\n", 532 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.898642529028521482602829374444\n", 533 | "new_da_img (256, 256, 100)\n", 534 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.316911475886263032009840828684\n", 535 | "new_da_img (256, 256, 103)\n", 536 | "1.3.6.1.4.1.14519.5.2.1.6279.6001.112767175295249119452142211437\n", 537 | "new_da_img (256, 256, 58)\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "for file in all_mask_path:\n", 543 | " ID = file.split('/')[-1]\n", 544 | " ID = ID.split('_')[0]\n", 545 | " print(ID)\n", 546 | " da_img = ants.image_read(file)\n", 547 | " new_img_path = os.path.join(fold_nii, ID + '_mask.nii.gz')\n", 548 | "# if os.path.exists(new_img_path):\n", 549 | "# print('continue')\n", 550 | "# continue \n", 551 | " new_da_img = ants.resample_image(da_img, (256, 256, int(da_img.shape[2]/2)), True, 1)\n", 552 | " print('new_da_img', new_da_img.shape)\n", 553 | " ants.image_write(new_da_img, new_img_path) " 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [] 562 | } 563 | ], 564 | "metadata": { 565 | "kernelspec": { 566 | "display_name": "Python 3", 567 | "language": "python", 568 | "name": "python3" 569 | }, 570 | "language_info": { 571 | "codemirror_mode": { 572 | "name": "ipython", 573 | "version": 3 574 | }, 575 | "file_extension": ".py", 576 | "mimetype": "text/x-python", 577 | "name": "python", 578 | "nbconvert_exporter": "python", 579 | "pygments_lexer": "ipython3", 580 | "version": "3.6.8" 581 | } 582 | }, 583 | "nbformat": 4, 584 | "nbformat_minor": 2 585 | } 586 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import os \n", 18 | "import glob\n", 19 | "import random\n", 20 | "import numpy as np \n", 21 | "import pandas as pd \n", 22 | "import SimpleITK as sitk\n", 23 | "from keras.optimizers import Adam\n", 24 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", 25 | "\n", 26 | "import model.metrics as mm\n", 27 | "import model.callBack as mc\n", 28 | "from model.network import unet3d\n", 29 | "from utils.utils import *\n", 30 | "from utils.getTrainData import image_gen3d\n", 31 | "import configTrain" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "folder = '/public/lixin/lung/lung-lobe-seg/data/'" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "img_list = glob.glob(os.path.join(folder, '*_img.nii.gz'))" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "47" 61 | ] 62 | }, 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "len(img_list)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "train_list = img_list[:37]\n", 79 | "val_list = img_list[37:]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "10" 91 | ] 92 | }, 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "len(val_list)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "aug_gen = image_gen3d(train_list) \n", 109 | "valid_x, valid_y = next(image_gen3d(val_list, len(val_list))) #" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 8, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "model = unet3d((48,224,224,1))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 9, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "__________________________________________________________________________________________________\n", 131 | "Layer (type) Output Shape Param # Connected to \n", 132 | "==================================================================================================\n", 133 | "input_1 (InputLayer) (None, 48, 224, 224, 0 \n", 134 | "__________________________________________________________________________________________________\n", 135 | "conv3d_1 (Conv3D) (None, 48, 224, 224, 448 input_1[0][0] \n", 136 | "__________________________________________________________________________________________________\n", 137 | "batch_normalization_1 (BatchNor (None, 48, 224, 224, 64 conv3d_1[0][0] \n", 138 | "__________________________________________________________________________________________________\n", 139 | "leaky_re_lu_1 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_1[0][0] \n", 140 | "__________________________________________________________________________________________________\n", 141 | "conv3d_2 (Conv3D) (None, 48, 224, 224, 6928 leaky_re_lu_1[0][0] \n", 142 | "__________________________________________________________________________________________________\n", 143 | "batch_normalization_2 (BatchNor (None, 48, 224, 224, 64 conv3d_2[0][0] \n", 144 | "__________________________________________________________________________________________________\n", 145 | "leaky_re_lu_2 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_2[0][0] \n", 146 | "__________________________________________________________________________________________________\n", 147 | "conv3d_3 (Conv3D) (None, 48, 224, 224, 6928 leaky_re_lu_2[0][0] \n", 148 | "__________________________________________________________________________________________________\n", 149 | "batch_normalization_3 (BatchNor (None, 48, 224, 224, 64 conv3d_3[0][0] \n", 150 | "__________________________________________________________________________________________________\n", 151 | "leaky_re_lu_3 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_3[0][0] \n", 152 | "__________________________________________________________________________________________________\n", 153 | "add_1 (Add) (None, 48, 224, 224, 0 leaky_re_lu_1[0][0] \n", 154 | " leaky_re_lu_3[0][0] \n", 155 | "__________________________________________________________________________________________________\n", 156 | "conv3d_4 (Conv3D) (None, 24, 112, 112, 13856 add_1[0][0] \n", 157 | "__________________________________________________________________________________________________\n", 158 | "batch_normalization_4 (BatchNor (None, 24, 112, 112, 128 conv3d_4[0][0] \n", 159 | "__________________________________________________________________________________________________\n", 160 | "leaky_re_lu_4 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_4[0][0] \n", 161 | "__________________________________________________________________________________________________\n", 162 | "conv3d_5 (Conv3D) (None, 24, 112, 112, 27680 leaky_re_lu_4[0][0] \n", 163 | "__________________________________________________________________________________________________\n", 164 | "batch_normalization_5 (BatchNor (None, 24, 112, 112, 128 conv3d_5[0][0] \n", 165 | "__________________________________________________________________________________________________\n", 166 | "leaky_re_lu_5 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_5[0][0] \n", 167 | "__________________________________________________________________________________________________\n", 168 | "conv3d_6 (Conv3D) (None, 24, 112, 112, 27680 leaky_re_lu_5[0][0] \n", 169 | "__________________________________________________________________________________________________\n", 170 | "batch_normalization_6 (BatchNor (None, 24, 112, 112, 128 conv3d_6[0][0] \n", 171 | "__________________________________________________________________________________________________\n", 172 | "leaky_re_lu_6 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_6[0][0] \n", 173 | "__________________________________________________________________________________________________\n", 174 | "add_2 (Add) (None, 24, 112, 112, 0 leaky_re_lu_4[0][0] \n", 175 | " leaky_re_lu_6[0][0] \n", 176 | "__________________________________________________________________________________________________\n", 177 | "conv3d_7 (Conv3D) (None, 12, 56, 56, 6 55360 add_2[0][0] \n", 178 | "__________________________________________________________________________________________________\n", 179 | "batch_normalization_7 (BatchNor (None, 12, 56, 56, 6 256 conv3d_7[0][0] \n", 180 | "__________________________________________________________________________________________________\n", 181 | "leaky_re_lu_7 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_7[0][0] \n", 182 | "__________________________________________________________________________________________________\n", 183 | "conv3d_8 (Conv3D) (None, 12, 56, 56, 6 110656 leaky_re_lu_7[0][0] \n", 184 | "__________________________________________________________________________________________________\n", 185 | "batch_normalization_8 (BatchNor (None, 12, 56, 56, 6 256 conv3d_8[0][0] \n", 186 | "__________________________________________________________________________________________________\n", 187 | "leaky_re_lu_8 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_8[0][0] \n", 188 | "__________________________________________________________________________________________________\n", 189 | "conv3d_9 (Conv3D) (None, 12, 56, 56, 6 110656 leaky_re_lu_8[0][0] \n", 190 | "__________________________________________________________________________________________________\n", 191 | "batch_normalization_9 (BatchNor (None, 12, 56, 56, 6 256 conv3d_9[0][0] \n", 192 | "__________________________________________________________________________________________________\n", 193 | "leaky_re_lu_9 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_9[0][0] \n", 194 | "__________________________________________________________________________________________________\n", 195 | "add_3 (Add) (None, 12, 56, 56, 6 0 leaky_re_lu_7[0][0] \n", 196 | " leaky_re_lu_9[0][0] \n", 197 | "__________________________________________________________________________________________________\n", 198 | "conv3d_10 (Conv3D) (None, 6, 28, 28, 12 221312 add_3[0][0] \n", 199 | "__________________________________________________________________________________________________\n", 200 | "batch_normalization_10 (BatchNo (None, 6, 28, 28, 12 512 conv3d_10[0][0] \n", 201 | "__________________________________________________________________________________________________\n", 202 | "leaky_re_lu_10 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_10[0][0] \n", 203 | "__________________________________________________________________________________________________\n", 204 | "conv3d_11 (Conv3D) (None, 6, 28, 28, 12 442496 leaky_re_lu_10[0][0] \n", 205 | "__________________________________________________________________________________________________\n", 206 | "batch_normalization_11 (BatchNo (None, 6, 28, 28, 12 512 conv3d_11[0][0] \n", 207 | "__________________________________________________________________________________________________\n", 208 | "leaky_re_lu_11 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_11[0][0] \n", 209 | "__________________________________________________________________________________________________\n", 210 | "conv3d_12 (Conv3D) (None, 6, 28, 28, 12 442496 leaky_re_lu_11[0][0] \n", 211 | "__________________________________________________________________________________________________\n", 212 | "batch_normalization_12 (BatchNo (None, 6, 28, 28, 12 512 conv3d_12[0][0] \n", 213 | "__________________________________________________________________________________________________\n", 214 | "leaky_re_lu_12 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_12[0][0] \n", 215 | "__________________________________________________________________________________________________\n", 216 | "add_4 (Add) (None, 6, 28, 28, 12 0 leaky_re_lu_10[0][0] \n", 217 | " leaky_re_lu_12[0][0] \n", 218 | "__________________________________________________________________________________________________\n", 219 | "conv3d_13 (Conv3D) (None, 3, 14, 14, 25 884992 add_4[0][0] \n", 220 | "__________________________________________________________________________________________________\n", 221 | "batch_normalization_13 (BatchNo (None, 3, 14, 14, 25 1024 conv3d_13[0][0] \n", 222 | "__________________________________________________________________________________________________\n", 223 | "leaky_re_lu_13 (LeakyReLU) (None, 3, 14, 14, 25 0 batch_normalization_13[0][0] \n", 224 | "__________________________________________________________________________________________________\n", 225 | "conv3d_14 (Conv3D) (None, 3, 14, 14, 25 1769728 leaky_re_lu_13[0][0] \n", 226 | "__________________________________________________________________________________________________\n", 227 | "batch_normalization_14 (BatchNo (None, 3, 14, 14, 25 1024 conv3d_14[0][0] \n", 228 | "__________________________________________________________________________________________________\n", 229 | "leaky_re_lu_14 (LeakyReLU) (None, 3, 14, 14, 25 0 batch_normalization_14[0][0] \n", 230 | "__________________________________________________________________________________________________\n", 231 | "conv3d_15 (Conv3D) (None, 3, 14, 14, 25 1769728 leaky_re_lu_14[0][0] \n", 232 | "__________________________________________________________________________________________________\n", 233 | "batch_normalization_15 (BatchNo (None, 3, 14, 14, 25 1024 conv3d_15[0][0] \n", 234 | "__________________________________________________________________________________________________\n", 235 | "leaky_re_lu_15 (LeakyReLU) (None, 3, 14, 14, 25 0 batch_normalization_15[0][0] \n", 236 | "__________________________________________________________________________________________________\n", 237 | "add_5 (Add) (None, 3, 14, 14, 25 0 leaky_re_lu_13[0][0] \n", 238 | " leaky_re_lu_15[0][0] \n", 239 | "__________________________________________________________________________________________________\n", 240 | "up_sampling3d_1 (UpSampling3D) (None, 6, 28, 28, 25 0 add_5[0][0] \n", 241 | "__________________________________________________________________________________________________\n", 242 | "conv3d_16 (Conv3D) (None, 6, 28, 28, 12 884864 up_sampling3d_1[0][0] \n", 243 | "__________________________________________________________________________________________________\n", 244 | "batch_normalization_16 (BatchNo (None, 6, 28, 28, 12 512 conv3d_16[0][0] \n", 245 | "__________________________________________________________________________________________________\n", 246 | "leaky_re_lu_16 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_16[0][0] \n", 247 | "__________________________________________________________________________________________________\n", 248 | "concatenate_1 (Concatenate) (None, 6, 28, 28, 25 0 leaky_re_lu_16[0][0] \n", 249 | " add_4[0][0] \n", 250 | "__________________________________________________________________________________________________\n", 251 | "conv3d_17 (Conv3D) (None, 6, 28, 28, 12 884864 concatenate_1[0][0] \n", 252 | "__________________________________________________________________________________________________\n", 253 | "batch_normalization_17 (BatchNo (None, 6, 28, 28, 12 512 conv3d_17[0][0] \n", 254 | "__________________________________________________________________________________________________\n", 255 | "leaky_re_lu_17 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_17[0][0] \n", 256 | "__________________________________________________________________________________________________\n", 257 | "conv3d_18 (Conv3D) (None, 6, 28, 28, 12 442496 leaky_re_lu_17[0][0] \n", 258 | "__________________________________________________________________________________________________\n", 259 | "batch_normalization_18 (BatchNo (None, 6, 28, 28, 12 512 conv3d_18[0][0] \n", 260 | "__________________________________________________________________________________________________\n", 261 | "leaky_re_lu_18 (LeakyReLU) (None, 6, 28, 28, 12 0 batch_normalization_18[0][0] \n", 262 | "__________________________________________________________________________________________________\n", 263 | "up_sampling3d_2 (UpSampling3D) (None, 12, 56, 56, 1 0 leaky_re_lu_18[0][0] \n", 264 | "__________________________________________________________________________________________________\n", 265 | "conv3d_19 (Conv3D) (None, 12, 56, 56, 6 221248 up_sampling3d_2[0][0] \n", 266 | "__________________________________________________________________________________________________\n", 267 | "batch_normalization_19 (BatchNo (None, 12, 56, 56, 6 256 conv3d_19[0][0] \n", 268 | "__________________________________________________________________________________________________\n", 269 | "leaky_re_lu_19 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_19[0][0] \n", 270 | "__________________________________________________________________________________________________\n", 271 | "concatenate_2 (Concatenate) (None, 12, 56, 56, 1 0 leaky_re_lu_19[0][0] \n", 272 | " add_3[0][0] \n", 273 | "__________________________________________________________________________________________________\n", 274 | "conv3d_20 (Conv3D) (None, 12, 56, 56, 6 221248 concatenate_2[0][0] \n", 275 | "__________________________________________________________________________________________________\n", 276 | "batch_normalization_20 (BatchNo (None, 12, 56, 56, 6 256 conv3d_20[0][0] \n", 277 | "__________________________________________________________________________________________________\n", 278 | "leaky_re_lu_20 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_20[0][0] \n", 279 | "__________________________________________________________________________________________________\n", 280 | "conv3d_21 (Conv3D) (None, 12, 56, 56, 6 110656 leaky_re_lu_20[0][0] \n", 281 | "__________________________________________________________________________________________________\n", 282 | "batch_normalization_21 (BatchNo (None, 12, 56, 56, 6 256 conv3d_21[0][0] \n", 283 | "__________________________________________________________________________________________________\n", 284 | "leaky_re_lu_21 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_21[0][0] \n", 285 | "__________________________________________________________________________________________________\n", 286 | "up_sampling3d_4 (UpSampling3D) (None, 24, 112, 112, 0 leaky_re_lu_21[0][0] \n", 287 | "__________________________________________________________________________________________________\n", 288 | "conv3d_23 (Conv3D) (None, 24, 112, 112, 55328 up_sampling3d_4[0][0] \n", 289 | "__________________________________________________________________________________________________\n", 290 | "batch_normalization_23 (BatchNo (None, 24, 112, 112, 128 conv3d_23[0][0] \n", 291 | "__________________________________________________________________________________________________\n", 292 | "leaky_re_lu_23 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_23[0][0] \n", 293 | "__________________________________________________________________________________________________\n", 294 | "concatenate_3 (Concatenate) (None, 24, 112, 112, 0 leaky_re_lu_23[0][0] \n", 295 | " add_2[0][0] \n", 296 | "__________________________________________________________________________________________________\n", 297 | "conv3d_24 (Conv3D) (None, 24, 112, 112, 55328 concatenate_3[0][0] \n", 298 | "__________________________________________________________________________________________________\n", 299 | "batch_normalization_24 (BatchNo (None, 24, 112, 112, 128 conv3d_24[0][0] \n", 300 | "__________________________________________________________________________________________________\n", 301 | "leaky_re_lu_24 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_24[0][0] \n", 302 | "__________________________________________________________________________________________________\n", 303 | "conv3d_25 (Conv3D) (None, 24, 112, 112, 27680 leaky_re_lu_24[0][0] \n", 304 | "__________________________________________________________________________________________________\n", 305 | "batch_normalization_25 (BatchNo (None, 24, 112, 112, 128 conv3d_25[0][0] \n", 306 | "__________________________________________________________________________________________________\n", 307 | "leaky_re_lu_25 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_25[0][0] \n", 308 | "__________________________________________________________________________________________________\n", 309 | "up_sampling3d_6 (UpSampling3D) (None, 48, 224, 224, 0 leaky_re_lu_25[0][0] \n", 310 | "__________________________________________________________________________________________________\n", 311 | "conv3d_27 (Conv3D) (None, 48, 224, 224, 13840 up_sampling3d_6[0][0] \n", 312 | "__________________________________________________________________________________________________\n", 313 | "batch_normalization_27 (BatchNo (None, 48, 224, 224, 64 conv3d_27[0][0] \n", 314 | "__________________________________________________________________________________________________\n", 315 | "leaky_re_lu_27 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_27[0][0] \n", 316 | "__________________________________________________________________________________________________\n", 317 | "concatenate_4 (Concatenate) (None, 48, 224, 224, 0 leaky_re_lu_27[0][0] \n", 318 | " add_1[0][0] \n", 319 | "__________________________________________________________________________________________________\n", 320 | "conv3d_28 (Conv3D) (None, 48, 224, 224, 13840 concatenate_4[0][0] \n", 321 | "__________________________________________________________________________________________________\n", 322 | "batch_normalization_28 (BatchNo (None, 48, 224, 224, 64 conv3d_28[0][0] \n", 323 | "__________________________________________________________________________________________________\n", 324 | "leaky_re_lu_28 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_28[0][0] \n", 325 | "__________________________________________________________________________________________________\n", 326 | "conv3d_29 (Conv3D) (None, 48, 224, 224, 6928 leaky_re_lu_28[0][0] \n", 327 | "__________________________________________________________________________________________________\n", 328 | "conv3d_22 (Conv3D) (None, 12, 56, 56, 6 10374 leaky_re_lu_21[0][0] \n", 329 | "__________________________________________________________________________________________________\n", 330 | "batch_normalization_29 (BatchNo (None, 48, 224, 224, 64 conv3d_29[0][0] \n", 331 | "__________________________________________________________________________________________________\n", 332 | "conv3d_26 (Conv3D) (None, 24, 112, 112, 5190 leaky_re_lu_25[0][0] \n", 333 | "__________________________________________________________________________________________________\n", 334 | "batch_normalization_22 (BatchNo (None, 12, 56, 56, 6 24 conv3d_22[0][0] \n", 335 | "__________________________________________________________________________________________________\n", 336 | "leaky_re_lu_29 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_29[0][0] \n", 337 | "__________________________________________________________________________________________________\n", 338 | "batch_normalization_26 (BatchNo (None, 24, 112, 112, 24 conv3d_26[0][0] \n", 339 | "__________________________________________________________________________________________________\n", 340 | "leaky_re_lu_22 (LeakyReLU) (None, 12, 56, 56, 6 0 batch_normalization_22[0][0] \n", 341 | "__________________________________________________________________________________________________\n", 342 | "conv3d_30 (Conv3D) (None, 48, 224, 224, 2598 leaky_re_lu_29[0][0] \n", 343 | "__________________________________________________________________________________________________\n", 344 | "leaky_re_lu_26 (LeakyReLU) (None, 24, 112, 112, 0 batch_normalization_26[0][0] \n", 345 | "__________________________________________________________________________________________________\n", 346 | "up_sampling3d_3 (UpSampling3D) (None, 24, 112, 112, 0 leaky_re_lu_22[0][0] \n", 347 | "__________________________________________________________________________________________________\n", 348 | "batch_normalization_30 (BatchNo (None, 48, 224, 224, 24 conv3d_30[0][0] \n", 349 | "__________________________________________________________________________________________________\n", 350 | "add_6 (Add) (None, 24, 112, 112, 0 leaky_re_lu_26[0][0] \n", 351 | " up_sampling3d_3[0][0] \n", 352 | "__________________________________________________________________________________________________\n", 353 | "leaky_re_lu_30 (LeakyReLU) (None, 48, 224, 224, 0 batch_normalization_30[0][0] \n", 354 | "__________________________________________________________________________________________________\n", 355 | "up_sampling3d_5 (UpSampling3D) (None, 48, 224, 224, 0 add_6[0][0] \n", 356 | "__________________________________________________________________________________________________\n", 357 | "add_7 (Add) (None, 48, 224, 224, 0 leaky_re_lu_30[0][0] \n", 358 | " up_sampling3d_5[0][0] \n", 359 | "__________________________________________________________________________________________________\n", 360 | "conv3d_31 (Conv3D) (None, 48, 224, 224, 42 add_7[0][0] \n", 361 | "__________________________________________________________________________________________________\n", 362 | "softmax (Activation) (None, 48, 224, 224, 0 conv3d_31[0][0] \n", 363 | "==================================================================================================\n", 364 | "Total params: 8,856,372\n", 365 | "Trainable params: 8,851,920\n", 366 | "Non-trainable params: 4,452\n", 367 | "__________________________________________________________________________________________________\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "model.summary()" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "# Train" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 17, 385 | "metadata": { 386 | "scrolled": true 387 | }, 388 | "outputs": [ 389 | { 390 | "name": "stderr", 391 | "output_type": "stream", 392 | "text": [ 393 | "/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/callbacks.py:1065: UserWarning: `epsilon` argument is deprecated and will be removed, use `min_delta` instead.\n", 394 | " warnings.warn('`epsilon` argument is deprecated and '\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "callbacks_list = call(\n", 400 | " weight_path = configTrain.best_model_save_path, \n", 401 | " monitor = configTrain.monitor , \n", 402 | " mode = configTrain.mode, \n", 403 | " reduce_lr_p = configTrain.reduce_lr_p, \n", 404 | " early_p = configTrain.early_p, \n", 405 | " log_csv_path = configTrain.log_csv_path,\n", 406 | " min_lr = configTrain.min_lr,\n", 407 | " factor = 0.5,\n", 408 | " min_delta = 0.0001,\n", 409 | " cooldown = 2,\n", 410 | " verbose = 1)" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 10, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "n_labels = 6\n", 420 | "label_wise_dice_metrics = [mm.get_label_dice_coefficient_function(index) for index in range(n_labels)]\n", 421 | "metrics = label_wise_dice_metrics" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 19, 427 | "metadata": { 428 | "scrolled": false 429 | }, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "Segmentation Models: using `keras` framework.\n", 436 | "Epoch 1/200\n", 437 | "20/20 [==============================] - 77s 4s/step - loss: -0.9857 - label_0_dice_coef: 0.9853 - label_1_dice_coef: 0.9875 - label_2_dice_coef: 0.9888 - label_3_dice_coef: 0.9890 - label_4_dice_coef: 0.9882 - label_5_dice_coef: 0.9884 - val_loss: -0.9788 - val_label_0_dice_coef: 0.9757 - val_label_1_dice_coef: 0.9783 - val_label_2_dice_coef: 0.9794 - val_label_3_dice_coef: 0.9790 - val_label_4_dice_coef: 0.9803 - val_label_5_dice_coef: 0.9805\n", 438 | "\n", 439 | "Epoch 00001: val_loss improved from inf to -0.97878, saving model to model_save/2019-10-16-2/one/pretrain_model_weights.best.hdf5\n", 440 | "Epoch 2/200\n", 441 | "20/20 [==============================] - 50s 2s/step - loss: -0.9864 - label_0_dice_coef: 0.9824 - label_1_dice_coef: 0.9862 - label_2_dice_coef: 0.9877 - label_3_dice_coef: 0.9885 - label_4_dice_coef: 0.9889 - label_5_dice_coef: 0.9892 - val_loss: -0.9788 - val_label_0_dice_coef: 0.9775 - val_label_1_dice_coef: 0.9809 - val_label_2_dice_coef: 0.9815 - val_label_3_dice_coef: 0.9805 - val_label_4_dice_coef: 0.9812 - val_label_5_dice_coef: 0.9816\n", 442 | "\n", 443 | "Epoch 00002: val_loss improved from -0.97878 to -0.97878, saving model to model_save/2019-10-16-2/one/pretrain_model_weights.best.hdf5\n", 444 | "Epoch 3/200\n", 445 | "20/20 [==============================] - 65s 3s/step - loss: -0.9859 - label_0_dice_coef: 0.9871 - label_1_dice_coef: 0.9885 - label_2_dice_coef: 0.9891 - label_3_dice_coef: 0.9893 - label_4_dice_coef: 0.9888 - label_5_dice_coef: 0.9886 - val_loss: -0.9811 - val_label_0_dice_coef: 0.9790 - val_label_1_dice_coef: 0.9827 - val_label_2_dice_coef: 0.9841 - val_label_3_dice_coef: 0.9835 - val_label_4_dice_coef: 0.9837 - val_label_5_dice_coef: 0.9839\n", 446 | "\n", 447 | "Epoch 00003: val_loss improved from -0.97878 to -0.98114, saving model to model_save/2019-10-16-2/one/pretrain_model_weights.best.hdf5\n", 448 | "Epoch 4/200\n", 449 | "20/20 [==============================] - 65s 3s/step - loss: -0.9869 - label_0_dice_coef: 0.9833 - label_1_dice_coef: 0.9861 - label_2_dice_coef: 0.9874 - label_3_dice_coef: 0.9870 - label_4_dice_coef: 0.9878 - label_5_dice_coef: 0.9883 - val_loss: -0.9786 - val_label_0_dice_coef: 0.9755 - val_label_1_dice_coef: 0.9788 - val_label_2_dice_coef: 0.9805 - val_label_3_dice_coef: 0.9799 - val_label_4_dice_coef: 0.9811 - val_label_5_dice_coef: 0.9813\n", 450 | "\n", 451 | "Epoch 00004: val_loss did not improve from -0.98114\n", 452 | "Epoch 5/200\n", 453 | " 9/20 [============>.................] - ETA: 34s - loss: -0.9876 - label_0_dice_coef: 0.9867 - label_1_dice_coef: 0.9884 - label_2_dice_coef: 0.9897 - label_3_dice_coef: 0.9904 - label_4_dice_coef: 0.9903 - label_5_dice_coef: 0.9901" 454 | ] 455 | }, 456 | { 457 | "ename": "KeyboardInterrupt", 458 | "evalue": "", 459 | "output_type": "error", 460 | "traceback": [ 461 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 462 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 463 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mvalidation_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mvalid_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mcallbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcallbacks_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mworkers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m )]\n", 464 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/legacy/interfaces.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m warnings.warn('Update your `' + object_name + '` call to the ' +\n\u001b[1;32m 90\u001b[0m 'Keras 2 API: ' + signature, stacklevel=2)\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_function\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 465 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 1416\u001b[0m \u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1417\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1418\u001b[0;31m initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m 1419\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1420\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0minterfaces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_generator_methods_support\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 466 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/engine/training_generator.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 215\u001b[0m outs = model.train_on_batch(x, y,\n\u001b[1;32m 216\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 217\u001b[0;31m class_weight=class_weight)\n\u001b[0m\u001b[1;32m 218\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 467 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight)\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0mins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1217\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1218\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0munpack_singleton\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 468 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 469 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 470 | "\u001b[0;32m/opt/anaconda3/envs/py36_tensorflow-1.12.2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1437\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1439\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1440\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 471 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 472 | ] 473 | } 474 | ], 475 | "source": [ 476 | "import segmentation_models as sm\n", 477 | "# from keras.utils import multi_gpu_model\n", 478 | "# muti_model = multi_gpu_model(model, gpus=3) \n", 479 | "model.load_weights(configTrain.best_model_save_path) \n", 480 | "\n", 481 | "model.compile(\n", 482 | " optimizer = Adam(lr = configTrain.lr),\n", 483 | " loss = mm.dice_coefficient_loss, #sm.losses.bce_jaccard_loss,\n", 484 | " metrics = metrics#[sm.metrics.iou_score],\n", 485 | ")\n", 486 | "\n", 487 | "loss_history = [model.fit_generator(\n", 488 | " aug_gen, \n", 489 | " steps_per_epoch = 20, \n", 490 | " epochs = 200, \n", 491 | " validation_data = (valid_x, valid_y),\n", 492 | " callbacks = callbacks_list,\n", 493 | " workers = 1 \n", 494 | " )]" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 20, 500 | "metadata": {}, 501 | "outputs": [ 502 | { 503 | "name": "stdout", 504 | "output_type": "stream", 505 | "text": [ 506 | "Done!\n" 507 | ] 508 | } 509 | ], 510 | "source": [ 511 | "model.load_weights(configTrain.best_model_save_path) \n", 512 | "model.save(configTrain.model_save_path)\n", 513 | "print('Done!')" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [] 522 | } 523 | ], 524 | "metadata": { 525 | "kernelspec": { 526 | "display_name": "Python 3", 527 | "language": "python", 528 | "name": "python3" 529 | }, 530 | "language_info": { 531 | "codemirror_mode": { 532 | "name": "ipython", 533 | "version": 3 534 | }, 535 | "file_extension": ".py", 536 | "mimetype": "text/x-python", 537 | "name": "python", 538 | "nbconvert_exporter": "python", 539 | "pygments_lexer": "ipython3", 540 | "version": "3.6.8" 541 | } 542 | }, 543 | "nbformat": 4, 544 | "nbformat_minor": 2 545 | } 546 | --------------------------------------------------------------------------------