├── requirements.txt ├── .idea ├── dictionaries │ └── Administrator.xml ├── markdown-navigator │ └── profiles_settings.xml ├── vcs.xml ├── encodings.xml ├── vagrant.xml ├── modules.xml ├── deployment.xml ├── 3DUNET.iml ├── misc.xml └── markdown-navigator.xml ├── .gitignore ├── README.md ├── utils ├── yaml_utils.py └── nii_utils.py ├── model.py └── train.py /requirements.txt: -------------------------------------------------------------------------------- 1 | yaml 2 | nibabel 3 | numpy 4 | pathlib 5 | keras 6 | random 7 | skimage -------------------------------------------------------------------------------- /.idea/dictionaries/Administrator.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | .gitignore 3 | .idea/workspace.xml 4 | __pycache__/ 5 | utils/__pycache__/ 6 | _ 7 | .idea 8 | -------------------------------------------------------------------------------- /.idea/markdown-navigator/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/vagrant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- 
/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3DUNET 2 | 3D U-Net in Keras (simple version) 3 | 4 | This is a simple version of a 3D U-Net. 5 | 6 | Dataset preparation and the batch generator are in train.py. 7 | The model, with its Dice metric and Dice loss, is built with Keras in model.py. 8 | 9 | To train the model, run: 10 | ```python 11 | python train.py 12 | ``` 13 | -------------------------------------------------------------------------------- /utils/yaml_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from pathlib import Path 3 | 4 | 5 | def read(path): 6 | with Path(path).open('r') as file: 7 | params = yaml.load(file, Loader=yaml.SafeLoader) 8 | return params 9 | 10 | 11 | def write(path, data): 12 | with Path(path).open('w') as file: 13 | yaml.dump(data, file) 14 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/3DUNET.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /utils/nii_utils.py: -------------------------------------------------------------------------------- 1 | import nibabel 2 | from pathlib import Path 3 | 4 | 5 | def nii_reader(path): 6 | image = nibabel.load(str(path)) 7 | image_array = image.get_fdata() 8 | return image_array 9 | 10 | 11 | def nii_header_reader(path): 12 | image = 
nibabel.load(str(path)) 13 | image_header = image.header 14 | pix_dim = image_header.get('pixdim') 15 | image_affine = image.affine 16 | return {'header': image_header, 'affine': image_affine, 'spacing': (pix_dim[1], pix_dim[2], pix_dim[3])} 17 | 18 | 19 | def nii_writer(path, header, image_array): 20 | Path(path).parent.mkdir(parents=True, exist_ok=True) 21 | image = nibabel.Nifti1Image(image_array, affine=header['affine'], header=header['header']) 22 | nibabel.save(image, str(path)) 23 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | Internationalization 13 | 14 | 15 | JSON and JSON5 16 | 17 | 18 | Python 19 | 20 | 21 | XML 22 | 23 | 24 | 25 | 26 | Angular 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from keras import Input, Model 2 | from keras.layers import BatchNormalization, concatenate, Conv3D, Activation, MaxPooling3D, UpSampling3D, \ 3 | Deconvolution3D 4 | from keras.optimizers import Adam 5 | from keras import backend as K 6 | 7 | K.set_image_dim_ordering('th') 8 | K.set_image_data_format("channels_first") 9 | 10 | 11 | def dice_coefficient(y_true, y_pred, smooth=1.): 12 | y_true_f = K.flatten(y_true) 13 | y_pred_f = K.flatten(y_pred) 14 | intersection = K.sum(y_true_f * y_pred_f) 15 | return (2. 
* intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 16 | 17 | 18 | def dice_coefficient_loss(y_true, y_pred): 19 | return -dice_coefficient(y_true, y_pred) 20 | 21 | 22 | def convolution_block(input_layer, n_filters, kernel=(3, 3, 3), padding='same', strides=(1, 1, 1)): 23 | block = Conv3D(n_filters, kernel, padding=padding, strides=strides)(input_layer) 24 | layer = BatchNormalization(axis=1)(block) 25 | return Activation('relu')(layer) 26 | 27 | 28 | def unet_3d(input_shape, n_base_filters=32): 29 | _input = Input(input_shape) 30 | _block = _input 31 | bridge_list = list() 32 | 33 | for layer_depth in range(4): 34 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth)) 35 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth)) 36 | if layer_depth < 4 - 1: 37 | bridge_list.append(_block) 38 | _block = MaxPooling3D(pool_size=(2, 2, 2))(_block) 39 | 40 | for layer_depth in reversed(range(4)): 41 | if layer_depth < 4 - 1: 42 | _block = UpSampling3D(size=(2, 2, 2))(_block) # alternatively, switch to Deconvolution3D: 43 | # _block = Deconvolution3D(filters=n_base_filters * (2 ** layer_depth), kernel_size=(2, 2, 2), 44 | # strides=(2, 2, 2)) 45 | _block = concatenate([_block, bridge_list[layer_depth]], axis=1) # skip connection: join with the encoder output saved at this depth, along axis 1 46 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth)) 47 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth)) 48 | 49 | final_convolution = Conv3D(1, (1, 1, 1), activation='sigmoid')(_block) # 1x1x1 convolution down to a single sigmoid output filter 50 | model = Model(inputs=_input, outputs=final_convolution) 51 | model.compile(optimizer=Adam(lr=0.00001), loss=dice_coefficient_loss, metrics=[dice_coefficient]) 52 | return model 53 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import 
Path 3 | from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping, ReduceLROnPlateau 4 | from utils import nii_utils, yaml_utils 5 | from model import unet_3d 6 | from random import shuffle 7 | from skimage import transform 8 | 9 | output_path = Path('E:/Dataset/BraTS_2018') 10 | 11 | 12 | def create_data_yaml(path): 13 | if Path(str(output_path / 't22seg_train.yaml')).exists(): 14 | return 15 | paired_data = list() 16 | path = Path(path) / 't2' 17 | for t1_file in path.iterdir(): 18 | seg_file = str(t1_file).replace('t2', 'seg') 19 | t1_image = nii_utils.nii_reader(str(t1_file)) 20 | seg_image = nii_utils.nii_reader(str(seg_file)) 21 | if t1_image.shape == seg_image.shape: # check dataset 22 | paired_data.append({'t2': str(t1_file), 'seg': str(seg_file)}) 23 | shuffle(paired_data) 24 | yaml_utils.write(str(output_path / 't22seg_train.yaml'), paired_data[:8 * len(paired_data) // 10]) # train 80% 25 | yaml_utils.write(str(output_path / 't22seg_test.yaml'), paired_data[8 * len(paired_data) // 10:]) # test 20% 26 | 27 | 28 | def data_generator(data_list, batch_size): 29 | batch_x_list = list() 30 | batch_y_list = list() 31 | while True: 32 | for i in data_list: 33 | t2_model = nii_utils.nii_reader(i['t2']) 34 | t2_model = transform.resize(t2_model, (64, 64, 32)) 35 | seg_model = nii_utils.nii_reader(i['seg']) 36 | seg_model = transform.resize(seg_model, (64, 64, 32)) 37 | batch_x_list.append([t2_model]) 38 | batch_y_list.append([seg_model]) 39 | if len(batch_x_list) == batch_size: 40 | yield np.asarray(batch_x_list), np.asarray(batch_y_list) 41 | batch_x_list = list() 42 | batch_y_list = list() 43 | 44 | 45 | def data_loader(): 46 | train_list = yaml_utils.read(str(output_path / 't22seg_train.yaml')) 47 | train_generator = data_generator(train_list, batch_size=6) 48 | 49 | test_list = yaml_utils.read(str(output_path / 't22seg_test.yaml')) 50 | test_generator = data_generator(test_list, batch_size=12) 51 | return train_generator, len(train_list), 
test_generator, len(test_list) 52 | 53 | 54 | if __name__ == '__main__': 55 | create_data_yaml(output_path) # first deal with dataset 56 | 57 | train_generator, train_steps, validation_generator, validation_steps = data_loader() # second create generator 58 | 59 | _model = unet_3d(input_shape=(1, 64, 64, 32)) # third create model (channels,x,y,z) 60 | 61 | Path('_').mkdir(parents=True, exist_ok=True) # create file in fold _ for finding and deleting easily 62 | _model.fit_generator(generator=train_generator, steps_per_epoch=train_steps, epochs=200, # final train model 63 | validation_data=validation_generator, validation_steps=validation_steps, 64 | callbacks=[ModelCheckpoint('_/tumor_segmentation_model.h5', save_best_only=True), 65 | CSVLogger('_/training.log', append=True), 66 | ReduceLROnPlateau(factor=0.5, patience=50, verbose=1), 67 | EarlyStopping(verbose=1, patience=None)]) 68 | -------------------------------------------------------------------------------- /.idea/markdown-navigator.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 36 | 37 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | --------------------------------------------------------------------------------