├── requirements.txt
├── .idea
├── dictionaries
│ └── Administrator.xml
├── markdown-navigator
│ └── profiles_settings.xml
├── vcs.xml
├── encodings.xml
├── vagrant.xml
├── modules.xml
├── deployment.xml
├── 3DUNET.iml
├── misc.xml
└── markdown-navigator.xml
├── .gitignore
├── README.md
├── utils
├── yaml_utils.py
└── nii_utils.py
├── model.py
└── train.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | nibabel
3 | numpy
4 | pathlib
5 | keras
6 | random
7 | scikit-image
--------------------------------------------------------------------------------
/.idea/dictionaries/Administrator.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | .gitignore
3 | .idea/workspace.xml
4 | __pycache__/
5 | utils/__pycache__/
6 | _
7 | .idea
8 |
--------------------------------------------------------------------------------
/.idea/markdown-navigator/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vagrant.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3DUNET
2 | 3D u-net keras(Simple version)
3 |
4 | This is a simple version of a 3D U-Net.
5 |
6 | Dataset preparation and the batch generator are in train.py.
7 | The model, with its Dice coefficient metric (dsc_metric) and Dice loss (dsc_loss), is built with Keras in model.py.
8 |
9 | To train the model, run:
10 | ```python
11 | python train.py
12 | ```
13 |
--------------------------------------------------------------------------------
/utils/yaml_utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from pathlib import Path
3 |
4 |
def read(path):
    """Parse the YAML file at ``path`` with the safe loader.

    :param path: location of the YAML file (anything ``Path`` accepts).
    :return: the parsed document.
    """
    yaml_path = Path(path)
    with yaml_path.open('r') as handle:
        return yaml.load(handle, Loader=yaml.SafeLoader)
9 |
10 |
def write(path, data):
    """Dump ``data`` as YAML into the file at ``path``.

    The file is opened in write mode, so any previous content is replaced.

    :param path: destination file (anything ``Path`` accepts).
    :param data: object handed to ``yaml.dump``.
    """
    target = Path(path)
    with target.open('w') as stream:
        yaml.dump(data, stream)
14 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.idea/3DUNET.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/utils/nii_utils.py:
--------------------------------------------------------------------------------
1 | import nibabel
2 | from pathlib import Path
3 |
4 |
def nii_reader(path):
    """Load the image at ``path`` with nibabel and return its ``get_fdata()`` array.

    :param path: image location; converted to ``str`` before loading.
    """
    return nibabel.load(str(path)).get_fdata()
9 |
10 |
def nii_header_reader(path):
    """Read the header, affine and spacing of the image at ``path``.

    :param path: image location; converted to ``str`` before loading.
    :return: dict with the keys ``'header'``, ``'affine'`` and ``'spacing'``;
        the spacing is entries 1, 2 and 3 of the header's ``pixdim`` field.
    """
    nii = nibabel.load(str(path))
    hdr = nii.header
    pixdim = hdr.get('pixdim')
    spacing = (pixdim[1], pixdim[2], pixdim[3])
    return {'header': hdr, 'affine': nii.affine, 'spacing': spacing}
17 |
18 |
def nii_writer(path, header, image_array):
    """Save ``image_array`` as a NIfTI-1 image at ``path``.

    Missing parent directories of ``path`` are created first.

    :param path: destination file.
    :param header: mapping providing ``'affine'`` and ``'header'`` entries,
        which are passed on to ``nibabel.Nifti1Image``.
    :param image_array: image data to store.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    nifti = nibabel.Nifti1Image(image_array, affine=header['affine'], header=header['header'])
    nibabel.save(nifti, str(path))
23 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | Internationalization
13 |
14 |
15 | JSON and JSON5
16 |
17 |
18 | Python
19 |
20 |
21 | XML
22 |
23 |
24 |
25 |
26 | Angular
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from keras import Input, Model
2 | from keras.layers import BatchNormalization, concatenate, Conv3D, Activation, MaxPooling3D, UpSampling3D, \
3 | Deconvolution3D
4 | from keras.optimizers import Adam
5 | from keras import backend as K
6 |
# Run the backend in channels-first mode; the layers in this module use axis 1
# as the channel axis. Only `set_image_data_format` is called: the older
# `set_image_dim_ordering('th')` spelling requests the same setting but is not
# available in newer Keras 2 releases, where it fails with AttributeError.
K.set_image_data_format("channels_first")
9 |
10 |
def dice_coefficient(y_true, y_pred, smooth=1.):
    """Smoothed Dice overlap of two tensors.

    Both tensors are flattened and the result is
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    :param y_true: reference tensor.
    :param y_pred: predicted tensor.
    :param smooth: constant added to numerator and denominator.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator
16 |
17 |
def dice_coefficient_loss(y_true, y_pred):
    """Return the negated ``dice_coefficient`` so that a lower loss means more overlap."""
    score = dice_coefficient(y_true, y_pred)
    return -score
20 |
21 |
def convolution_block(input_layer, n_filters, kernel=(3, 3, 3), padding='same', strides=(1, 1, 1)):
    """Apply Conv3D, then BatchNormalization over axis 1, then a ReLU activation.

    :param input_layer: tensor the block is applied to.
    :param n_filters: number of convolution filters.
    :param kernel: convolution kernel size.
    :param padding: padding mode handed to ``Conv3D``.
    :param strides: convolution strides.
    :return: output tensor of the activation.
    """
    convolved = Conv3D(n_filters, kernel, padding=padding, strides=strides)(input_layer)
    normalized = BatchNormalization(axis=1)(convolved)
    relu = Activation('relu')
    return relu(normalized)
26 |
27 |
28 | def unet_3d(input_shape, n_base_filters=32):  # Build and compile the 3D U-Net; returns the compiled keras Model.
29 | _input = Input(input_shape)
30 | _block = _input
31 | bridge_list = list()  # encoder outputs kept for the concatenations in the decoder loop
32 |
33 | for layer_depth in range(4):  # encoder: n_base_filters * 2 ** layer_depth filters per level
34 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth))
35 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth))
36 | if layer_depth < 4 - 1:  # every level except the last one
37 | bridge_list.append(_block)
38 | _block = MaxPooling3D(pool_size=(2, 2, 2))(_block)
39 |
40 | for layer_depth in reversed(range(4)):  # decoder: visits the levels from 3 down to 0
41 | if layer_depth < 4 - 1:
42 | _block = UpSampling3D(size=(2, 2, 2))(_block) # or change to Deconvolution3D
43 | # _block = Deconvolution3D(filters=n_base_filters * (2 ** layer_depth), kernel_size=(2, 2, 2),
44 | # strides=(2, 2, 2))
45 | _block = concatenate([_block, bridge_list[layer_depth]], axis=1)  # join with the stored encoder output along axis 1
46 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth))  # NOTE(review): indentation is not visible in this listing -- confirm whether these two blocks belong inside the `if` above
47 | _block = convolution_block(input_layer=_block, n_filters=n_base_filters * (2 ** layer_depth))
48 |
49 | final_convolution = Conv3D(1, (1, 1, 1), activation='sigmoid')(_block)  # single-filter 1x1x1 convolution with sigmoid output
50 | model = Model(inputs=_input, outputs=final_convolution)
51 | model.compile(optimizer=Adam(lr=0.00001), loss=dice_coefficient_loss, metrics=[dice_coefficient])  # Adam with lr=1e-5; Dice-based loss and metric
52 | return model
53 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping, ReduceLROnPlateau
4 | from utils import nii_utils, yaml_utils
5 | from model import unet_3d
6 | from random import shuffle
7 | from skimage import transform
8 |
9 | output_path = Path('E:/Dataset/BraTS_2018')
10 |
11 |
def create_data_yaml(path):
    """Pair each T2 volume with its segmentation and write train/test YAML lists.

    Returns immediately when ``t22seg_train.yaml`` already exists under the
    module-level ``output_path``. Otherwise every file in ``<path>/t2`` is
    matched with the path obtained by replacing ``'t2'`` with ``'seg'``.
    Pairs whose segmentation file is missing, or whose two arrays differ in
    shape, are left out. The shuffled pairs are split 80% / 20% into
    ``t22seg_train.yaml`` and ``t22seg_test.yaml`` inside ``output_path``.

    :param path: dataset root that contains the ``t2`` directory.
    """
    train_yaml = output_path / 't22seg_train.yaml'
    test_yaml = output_path / 't22seg_test.yaml'
    if train_yaml.exists():
        return

    paired_data = list()
    for t2_file in (Path(path) / 't2').iterdir():
        # str.replace swaps every 't2' in the full path, the directory name included.
        seg_file = str(t2_file).replace('t2', 'seg')
        if not Path(seg_file).is_file():
            continue  # no segmentation counterpart: skip instead of failing the whole scan
        t2_image = nii_utils.nii_reader(str(t2_file))
        seg_image = nii_utils.nii_reader(seg_file)
        if t2_image.shape == seg_image.shape:  # check dataset
            paired_data.append({'t2': str(t2_file), 'seg': seg_file})

    shuffle(paired_data)
    split_index = 8 * len(paired_data) // 10
    yaml_utils.write(str(train_yaml), paired_data[:split_index])  # train 80%
    yaml_utils.write(str(test_yaml), paired_data[split_index:])  # test 20%
26 |
27 |
def data_generator(data_list, batch_size, target_shape=(64, 64, 32)):
    """Yield ``(inputs, targets)`` batches forever, cycling over ``data_list``.

    Each entry's ``'t2'`` and ``'seg'`` files are read, resized to
    ``target_shape`` and wrapped in a one-element list, so a batch array has
    shape ``(batch_size, 1) + target_shape``. A batch that is only partly
    filled at the end of a pass is completed with entries from the next pass.

    :param data_list: sequence of dicts with ``'t2'`` and ``'seg'`` file paths.
    :param batch_size: number of samples per yielded batch.
    :param target_shape: shape every volume is resized to.
    :raises ValueError: on the first ``next()`` call, when ``data_list`` is
        empty or ``batch_size`` is smaller than 1. Without this check the
        loop below would run forever without yielding anything.
    """
    if not data_list:
        raise ValueError('data_list must contain at least one t2/seg pair')
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    batch_x_list = list()
    batch_y_list = list()
    while True:
        for pair in data_list:
            t2_volume = transform.resize(nii_utils.nii_reader(pair['t2']), target_shape)
            # NOTE(review): the mask goes through the same default interpolation as
            # the image, which presumably blends label values at region borders --
            # confirm whether nearest-neighbour resizing (order=0) is wanted here.
            seg_volume = transform.resize(nii_utils.nii_reader(pair['seg']), target_shape)
            batch_x_list.append([t2_volume])
            batch_y_list.append([seg_volume])
            if len(batch_x_list) == batch_size:
                yield np.asarray(batch_x_list), np.asarray(batch_y_list)
                batch_x_list = list()
                batch_y_list = list()
43 |
44 |
def data_loader():
    """Create the train and test generators together with their step counts.

    The pair lists are read from ``t22seg_train.yaml`` and ``t22seg_test.yaml``
    under the module-level ``output_path``. Train batches hold 6 samples and
    test batches hold 12.

    :return: ``(train_generator, train_steps, test_generator, test_steps)``.
        Each step count is the number of batches needed to cover the list
        once (rounded up), which is the unit the caller passes to Keras as
        ``steps_per_epoch`` / ``validation_steps``. Returning the raw sample
        count instead would make every epoch run several passes over the data.
    """
    train_batch_size = 6
    test_batch_size = 12

    train_list = yaml_utils.read(str(output_path / 't22seg_train.yaml'))
    train_generator = data_generator(train_list, batch_size=train_batch_size)

    test_list = yaml_utils.read(str(output_path / 't22seg_test.yaml'))
    test_generator = data_generator(test_list, batch_size=test_batch_size)

    # Integer ceiling division: a trailing partial batch still counts as one step.
    train_steps = (len(train_list) + train_batch_size - 1) // train_batch_size
    test_steps = (len(test_list) + test_batch_size - 1) // test_batch_size
    return train_generator, train_steps, test_generator, test_steps
52 |
53 |
if __name__ == '__main__':
    create_data_yaml(output_path)  # first deal with dataset

    # second create generator
    train_generator, train_steps, validation_generator, validation_steps = data_loader()

    _model = unet_3d(input_shape=(1, 64, 64, 32))  # third create model (channels,x,y,z)

    Path('_').mkdir(parents=True, exist_ok=True)  # outputs go to folder _ so they are easy to find and delete
    training_callbacks = [
        ModelCheckpoint('_/tumor_segmentation_model.h5', save_best_only=True),
        CSVLogger('_/training.log', append=True),
        ReduceLROnPlateau(factor=0.5, patience=50, verbose=1),
        # No EarlyStopping here: it was configured with patience=None, and the
        # callback compares its wait counter against `patience`, which fails
        # for None on the first epoch without improvement. Add it back with an
        # integer patience if early stopping is wanted.
    ]
    _model.fit_generator(generator=train_generator, steps_per_epoch=train_steps, epochs=200,  # final train model
                         validation_data=validation_generator, validation_steps=validation_steps,
                         callbacks=training_callbacks)
68 |
--------------------------------------------------------------------------------
/.idea/markdown-navigator.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------