├── unet_config.py
├── Unet3D.py
├── README.md
└── Train_Unet3D.py

/unet_config.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import math

###---Number-of-GPUs
NUM_OF_GPU=4
DISTRIIBUTED_STRATEGY_GPUS=["gpu:0","gpu:1","gpu:2","gpu:3"]

'''
To resume training from saved weights, set
RESUME_TRAINING=1
'''
###----Resume-Training
RESUME_TRAINING=1
RESUME_TRAIING_MODEL='/Path/of/the/model/weight/Model.h5'
TRAINING_INITIAL_EPOCH=1381

NUMBER_OF_CLASSES=1
INPUT_PATCH_SIZE=(384,192,192,1)

##Training Hyper-Parameters
TRAIN_CLASSIFY_LEARNING_RATE=1e-4
#TRAIN_CLASSIFY_LOSS=tf.keras.losses.binary_crossentropy
OPTIMIZER=tf.keras.optimizers.Adam(learning_rate=TRAIN_CLASSIFY_LEARNING_RATE,epsilon=1e-5)
#TRAIN_CLASSIFY_METRICS=tf.keras.metrics.binary_accuracy
BATCH_SIZE=4
TRAINING_STEP_PER_EPOCH=math.ceil(76/BATCH_SIZE)
VALIDATION_STEP=math.ceil(6/BATCH_SIZE)
TRAING_EPOCH=1600
NUMBER_OF_PARALLEL_CALL=4
PARSHING=2*BATCH_SIZE

#--Callbacks-----
ModelCheckpoint_MOTITOR='LUNGSegVal_loss'
TRAINING_SAVE_MODEL_PATH='/Path/to/save/model/weight/Model.h5'
TRAINING_CSV='LungSEG_Model_March30_2020.csv'

####---TFRecord-paths
TRAINING_TF_RECORDS='/Training/tfrecords/path/'
VALIDATION_TF_RECORDS='/Val/tfrecords/path/'
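# The hard-coded 76 and 6 above are presumably the number of training and
# validation volumes in the TFRecord sets; with BATCH_SIZE=4 this works out to
#   TRAINING_STEP_PER_EPOCH = ceil(76/4) = 19
#   VALIDATION_STEP         = ceil(6/4)  = 2
# PARSHING = 2*BATCH_SIZE = 8 is the number of batches the input pipeline keeps
# prefetched. Adjust the two counts to match your own dataset sizes.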
--------------------------------------------------------------------------------
/Unet3D.py:
--------------------------------------------------------------------------------
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv3D, Input, MaxPooling3D, Dropout, concatenate, UpSampling3D
import tensorflow as tf


def Unet3D(inputs, num_classes):
    x = inputs
    # ---Encoder: four levels of (Conv3D x2 -> MaxPooling3D), 8 -> 64 filters
    conv1 = Conv3D(8, 3, activation='relu', padding='same', data_format="channels_last")(x)
    conv1 = Conv3D(8, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv1)
    conv2 = Conv3D(16, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv3D(16, 3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv2)
    conv3 = Conv3D(32, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv3D(32, 3, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv3)
    conv4 = Conv3D(64, 3, activation='relu', padding='same')(pool3)
    conv4 = Conv3D(64, 3, activation='relu', padding='same')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling3D(pool_size=(2, 2, 2))(drop4)

    # ---Bottleneck
    conv5 = Conv3D(128, 3, activation='relu', padding='same')(pool4)
    conv5 = Conv3D(128, 3, activation='relu', padding='same')(conv5)
    drop5 = Dropout(0.5)(conv5)

    # ---Decoder: UpSampling3D + skip connections from the matching encoder level
    up6 = Conv3D(64, 2, activation='relu', padding='same')(UpSampling3D(size=(2, 2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=-1)
    conv6 = Conv3D(64, 3, activation='relu', padding='same')(merge6)
    conv6 = Conv3D(64, 3, activation='relu', padding='same')(conv6)

    up7 = Conv3D(32, 2, activation='relu', padding='same')(UpSampling3D(size=(2, 2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=-1)
    conv7 = Conv3D(32, 3, activation='relu', padding='same')(merge7)
    conv7 = Conv3D(32, 3, activation='relu', padding='same')(conv7)

    up8 = Conv3D(16, 2, activation='relu', padding='same')(UpSampling3D(size=(2, 2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=-1)
    conv8 = Conv3D(16, 3, activation='relu', padding='same')(merge8)
    conv8 = Conv3D(16, 3, activation='relu', padding='same')(conv8)

    up9 = Conv3D(8, 2, activation='relu', padding='same')(UpSampling3D(size=(2, 2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=-1)
    conv9 = Conv3D(8, 3, activation='relu', padding='same')(merge9)
    conv9 = Conv3D(8, 3, activation='relu', padding='same')(conv9)
    # Single-channel sigmoid output (binary segmentation; num_classes is expected to be 1)
    conv10 = Conv3D(1, 1, activation='sigmoid')(conv9)
    model = Model(inputs=inputs, outputs=conv10)
    #model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
    return model
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 3DUnet_tensorflow2.0
This repo implements a 3D U-Net in TensorFlow 2.0.

## Files:
* i) `unet_config.py` --> all network and training configuration
* ii) `Unet3D.py` --> network architecture
* iii) `Train_Unet3D.py` --> training script: TFRecord decoder, tf.data reading pipeline, training loop, and the loss/metric functions (binary Dice coefficient and Dice loss)

## How to run
To run the model, all you need to do is configure `unet_config.py` for your requirements and then launch `Train_Unet3D.py`.
```python
###---Number-of-GPUs
NUM_OF_GPU=4
DISTRIIBUTED_STRATEGY_GPUS=["gpu:0","gpu:1","gpu:2","gpu:3"]

'''
To resume training from saved weights, set
RESUME_TRAINING=1
'''
###----Resume-Training
RESUME_TRAINING=1
RESUME_TRAIING_MODEL='/Path/of/the/model/weight/Model.h5'
TRAINING_INITIAL_EPOCH=1381

NUMBER_OF_CLASSES=1
INPUT_PATCH_SIZE=(384,192,192,1)

##Training Hyper-Parameters
TRAIN_CLASSIFY_LEARNING_RATE=1e-4
#TRAIN_CLASSIFY_LOSS=tf.keras.losses.binary_crossentropy
OPTIMIZER=tf.keras.optimizers.Adam(learning_rate=TRAIN_CLASSIFY_LEARNING_RATE,epsilon=1e-5)
#TRAIN_CLASSIFY_METRICS=tf.keras.metrics.binary_accuracy
BATCH_SIZE=4
TRAINING_STEP_PER_EPOCH=math.ceil(76/BATCH_SIZE)
VALIDATION_STEP=math.ceil(6/BATCH_SIZE)
TRAING_EPOCH=1600
NUMBER_OF_PARALLEL_CALL=4
PARSHING=2*BATCH_SIZE

#--Callbacks-----
ModelCheckpoint_MOTITOR='LUNGSegVal_loss'
TRAINING_SAVE_MODEL_PATH='/Path/to/save/model/weight/Model.h5'
TRAINING_CSV='LungSEG_Model_March30_2020.csv'

####---TFRecord-paths
TRAINING_TF_RECORDS='/Training/tfrecords/path/'
VALIDATION_TF_RECORDS='/Val/tfrecords/path/'
```
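Once the configuration is in place, training is started by running `Train_Unet3D.py`. The snippet below is a condensed sketch of what its single-GPU path does with these settings (the multi-GPU path builds and compiles the same model inside a `tf.distribute.MirroredStrategy` scope):

```python
import tensorflow as tf
from unet_config import *
from Unet3D import Unet3D
from Train_Unet3D import (dice_coe, dice_loss, getting_list,
                          load_training_tfrecords, load_validation_tfrecords)

# Build the 3D U-Net on the configured patch size and compile it with the Dice loss.
inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
model = Unet3D(inputs, num_classes=NUMBER_OF_CLASSES)
model.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy', dice_coe])

# Stream the training/validation TFRecords and fit.
train_ds = load_training_tfrecords(getting_list(TRAINING_TF_RECORDS), BATCH_SIZE)
val_ds = load_validation_tfrecords(getting_list(VALIDATION_TF_RECORDS), BATCH_SIZE)
model.fit(train_ds,
          steps_per_epoch=TRAINING_STEP_PER_EPOCH,
          epochs=TRAING_EPOCH,
          validation_data=val_ds,
          validation_steps=VALIDATION_STEP)
```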
## Dice Loss
```python
def dice_coe(y_true, y_pred, loss_type='jaccard', smooth=1.):

    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])

    intersection = tf.reduce_sum(y_true_f * y_pred_f)

    if loss_type == 'jaccard':
        union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f))

    elif loss_type == 'sorensen':
        union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f)

    else:
        raise ValueError("Unknown `loss_type`: %s" % loss_type)

    return (2. * intersection + smooth) / (union + smooth)


def dice_loss(y_true, y_pred, loss_type='jaccard', smooth=1.):

    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)

    intersection = tf.reduce_sum(y_true_f * y_pred_f)

    if loss_type == 'jaccard':
        union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f))

    elif loss_type == 'sorensen':
        union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f)

    else:
        raise ValueError("Unknown `loss_type`: %s" % loss_type)

    return 1 - (2. * intersection + smooth) / (union + smooth)
```
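As a quick numeric check (toy values for illustration only, assuming the two functions above are in scope): with two foreground voxels of which the prediction recovers one, the Jaccard-style union is 3, so the Dice coefficient is 0.75 and the Dice loss is 0.25.

```python
import tensorflow as tf

# Toy example -- illustrative values only, not data from this repo.
y_true = tf.constant([1., 1., 0., 0.])
y_pred = tf.constant([1., 0., 0., 0.])

# intersection = 1; jaccard union = sum(y_pred^2) + sum(y_true^2) = 1 + 2 = 3
print(dice_coe(y_true, y_pred).numpy())   # (2*1 + 1) / (3 + 1)  -> 0.75
print(dice_loss(y_true, y_pred).numpy())  # 1 - 0.75             -> 0.25
```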
--------------------------------------------------------------------------------
/Train_Unet3D.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
'''
tf.config.optimizer.set_jit(True)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to only use the first GPU
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    except RuntimeError as e:
        # Visible devices must be set at program startup
        print(e)
'''

from tensorflow.keras.optimizers import Adam
from unet_config import *
import os
import datetime
from Unet3D import Unet3D
import numpy as np
import random


def dice_coe(y_true, y_pred, loss_type='jaccard', smooth=1.):

    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])

    intersection = tf.reduce_sum(y_true_f * y_pred_f)

    if loss_type == 'jaccard':
        union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f))

    elif loss_type == 'sorensen':
        union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f)

    else:
        raise ValueError("Unknown `loss_type`: %s" % loss_type)

    return (2. * intersection + smooth) / (union + smooth)


def dice_loss(y_true, y_pred, loss_type='jaccard', smooth=1.):

    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)

    intersection = tf.reduce_sum(y_true_f * y_pred_f)

    if loss_type == 'jaccard':
        union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f))

    elif loss_type == 'sorensen':
        union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f)

    else:
        raise ValueError("Unknown `loss_type`: %s" % loss_type)

    return 1 - (2. * intersection + smooth) / (union + smooth)


@tf.function
def decode_SEGct(Serialized_example):

    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'mask': tf.io.FixedLenFeature([], tf.string),
        'Height': tf.io.FixedLenFeature([], tf.int64),
        'Weight': tf.io.FixedLenFeature([], tf.int64),
        'Depth': tf.io.FixedLenFeature([], tf.int64),
        'Sub_id': tf.io.FixedLenFeature([], tf.string)
    }
    examples = tf.io.parse_single_example(Serialized_example, features)
    ##Decode the raw image bytes (the dtype here must match the one used when the
    ##TFRecords were written)
    image_1 = tf.io.decode_raw(examples['image'], float)
    #Decode the mask as int32
    mask_1 = tf.io.decode_raw(examples['mask'], tf.int32)
    ##Subject id is already in bytes format
    #sub_id=examples['Sub_id']
    img_shape = [examples['Height'], examples['Weight'], examples['Depth']]
    print(img_shape)  # printed once per trace, not per example
    #Reshape the flat buffers back into volumes
    img = tf.reshape(image_1, img_shape)
    mask = tf.reshape(mask_1, img_shape)
    #Because the CNN expects (batch, H, W, D, channel)
    img = tf.expand_dims(img, axis=-1)
    mask = tf.expand_dims(mask, axis=-1)
    ###Cast values
    img = tf.cast(img, tf.float32)
    mask = tf.cast(mask, tf.int32)

    return img, mask


def getting_list(path):
    #Collect every .tfrecords file in `path`, in shuffled order
    a = [file for file in os.listdir(path) if file.endswith('.tfrecords')]
    all_tfrecords = random.sample(a, len(a))
    list_of_tfrecords = []
    for i in range(len(all_tfrecords)):
        tf_path = path + all_tfrecords[i]
        list_of_tfrecords.append(tf_path)
    return list_of_tfrecords


#--Training dataset pipeline
def load_training_tfrecords(record_mask_file, batch_size):
    dataset = tf.data.Dataset.list_files(record_mask_file).interleave(
        lambda x: tf.data.TFRecordDataset(x),
        cycle_length=NUMBER_OF_PARALLEL_CALL, num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
    dataset = dataset.map(decode_SEGct, num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
    batched_dataset = dataset.prefetch(PARSHING)
    return batched_dataset


#--Validation dataset pipeline
def load_validation_tfrecords(record_mask_file, batch_size):
    dataset = tf.data.Dataset.list_files(record_mask_file).interleave(
        tf.data.TFRecordDataset,
        cycle_length=NUMBER_OF_PARALLEL_CALL, num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
    dataset = dataset.map(decode_SEGct, num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
    batched_dataset = dataset.prefetch(PARSHING)
    return batched_dataset
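# Each element yielded by these pipelines is an (image, mask) pair of shape
# (batch, Height, Width, Depth, 1). With the settings in unet_config.py the record
# files are read four at a time (NUMBER_OF_PARALLEL_CALL), decoded in parallel,
# batched into groups of BATCH_SIZE=4, and PARSHING=8 batches are kept prefetched.
# The datasets repeat for TRAING_EPOCH epochs, so the amount of data actually seen
# is controlled by steps_per_epoch / validation_steps in model.fit() below.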
def Training():

    #TensorBoard
    logdir = os.path.join("LungSEG_Log_March30_2020", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
    ##CSV logger
    csv_logger = tf.keras.callbacks.CSVLogger(TRAINING_CSV)
    ##Model checkpoints (TRAINING_SAVE_MODEL_PATH is used as a directory here;
    ##checkpoint files are written into it)
    path = TRAINING_SAVE_MODEL_PATH
    model_path = os.path.join(path, "LungSEGModel_{val_loss:.2f}_{epoch}.h5")
    Model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=model_path, save_best_only=False, save_weights_only=True, monitor=ModelCheckpoint_MOTITOR, verbose=1)

    tf_train = getting_list(TRAINING_TF_RECORDS)
    tf_val = getting_list(VALIDATION_TF_RECORDS)

    training_data = load_training_tfrecords(tf_train, BATCH_SIZE)
    Val_batched_dataset = load_validation_tfrecords(tf_val, BATCH_SIZE)

    if NUM_OF_GPU == 1:

        if RESUME_TRAINING == 1:
            inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
            Model_3D = Unet3D(inputs, num_classes=NUMBER_OF_CLASSES)
            Model_3D.load_weights(RESUME_TRAIING_MODEL)
            initial_epoch_of_training = TRAINING_INITIAL_EPOCH
            Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy', dice_coe])
            Model_3D.summary()
        else:
            initial_epoch_of_training = 0
            inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
            Model_3D = Unet3D(inputs, num_classes=NUMBER_OF_CLASSES)
            Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy', dice_coe])
            Model_3D.summary()

        Model_3D.fit(training_data,
                     steps_per_epoch=TRAINING_STEP_PER_EPOCH,
                     epochs=TRAING_EPOCH,
                     initial_epoch=initial_epoch_of_training,
                     validation_data=Val_batched_dataset,
                     validation_steps=VALIDATION_STEP,
                     callbacks=[tensorboard_callback, csv_logger, Model_callback])

    ###---Multi-GPU: build and compile the same model inside a MirroredStrategy scope
    else:
        mirrored_strategy = tf.distribute.MirroredStrategy(DISTRIIBUTED_STRATEGY_GPUS)
        with mirrored_strategy.scope():
            if RESUME_TRAINING == 1:
                inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
                Model_3D = Unet3D(inputs, num_classes=NUMBER_OF_CLASSES)
                Model_3D.load_weights(RESUME_TRAIING_MODEL)
                initial_epoch_of_training = TRAINING_INITIAL_EPOCH
                Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy', dice_coe])
                Model_3D.summary()
            else:
                initial_epoch_of_training = 0
                inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
                Model_3D = Unet3D(inputs, num_classes=NUMBER_OF_CLASSES)
                Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy', dice_coe])
                Model_3D.summary()

        Model_3D.fit(training_data,
                     steps_per_epoch=TRAINING_STEP_PER_EPOCH,
                     epochs=TRAING_EPOCH,
                     initial_epoch=initial_epoch_of_training,
                     validation_data=Val_batched_dataset,
                     validation_steps=VALIDATION_STEP,
                     callbacks=[tensorboard_callback, csv_logger, Model_callback])


if __name__ == '__main__':
    Training()
--------------------------------------------------------------------------------