├── DyFA
│   ├── DenseVnet3D.py
│   ├── DyFA_Model.py
│   ├── Preprocessing_utlities.py
│   ├── config.py
│   ├── loss_funnction_And_matrics.py
│   ├── resume_training_using_check_point.py
│   └── tfrecords_utilities.py
├── README.md
├── SPIE_presentation.pptx
├── StFA
│   ├── DenseVnet3D.py
│   ├── Preprocessing_utlities.py
│   ├── StFA_Model.py
│   ├── config.py
│   ├── loss_funnction_And_matrics.py
│   ├── resume_training_using_check_point.py
│   └── tfrecords_utilities.py
└── figure
    ├── Model_Architecture.png
    ├── dataset.png
    └── results.png

/DyFA/DenseVnet3D.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | import tensorflow as tf
5 | 
6 | 
7 | 
8 | ##########---tf bilinear UpSampling3D
9 | def up_sampling(input_tensor, scale):
10 |     net = tf.keras.layers.TimeDistributed(tf.keras.layers.UpSampling2D(size=(scale, scale), interpolation='bilinear'))(input_tensor)
11 |     net = tf.keras.layers.Permute((2, 1, 3, 4))(net)  # (B, z, H, W, C) -> (B, H, z, W, C)
12 |     net = tf.keras.layers.TimeDistributed(tf.keras.layers.UpSampling2D(size=(scale, 1), interpolation='bilinear'))(net)
13 |     net = tf.keras.layers.Permute((2, 1, 3, 4))(net)  # (B, H, z, W, C) -> (B, z, H, W, C)
14 |     return net
15 | 
16 | #######-----Bottleneck
17 | def Bottleneck(x, nb_filter, increase_factor=4., weight_decay=1e-4):
18 |     inter_channel = int(nb_filter * increase_factor)
19 |     x = tf.keras.layers.Conv3D(inter_channel, (1, 1, 1),
20 |                                kernel_initializer='he_normal',
21 |                                padding='same',
22 |                                use_bias=False,
23 |                                kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
24 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
25 |     x = tf.nn.relu6(x)
26 |     return x
27 | 
28 | #####------------>>> Convolutional Block
29 | def conv_block(input, nb_filter, kernal_size=(3, 3, 3), dilation_rate=1,
30 |                bottleneck=False, dropout_rate=None, weight_decay=1e-4):
31 |     ''' Apply BatchNorm, Relu, 3x3x3 Conv3D, optional bottleneck block and dropout
32 |     Args:
33 |         input: Input tensor
34 |         nb_filter: number of filters
35 |         bottleneck: add bottleneck block
36 |         dropout_rate: dropout rate
37 |         weight_decay: weight decay factor
38 |     Returns: tensor with batch_norm, relu and convolution3D added (optional bottleneck)
39 |     '''
40 | 
41 | 
42 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(input)
43 |     x = tf.nn.relu6(x)
44 | 
45 |     if bottleneck:
46 |         inter_channel = nb_filter  # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua
47 |         x = tf.keras.layers.Conv3D(inter_channel, (1, 1, 1),
48 |                                    kernel_initializer='he_normal',
49 |                                    padding='same',
50 |                                    use_bias=False,
51 |                                    kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
52 |         x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
53 |         x = tf.nn.relu6(x)
54 | 
55 |     x = tf.keras.layers.Conv3D(nb_filter, kernal_size,
56 |                                dilation_rate=dilation_rate,
57 |                                kernel_initializer='he_normal',
58 |                                padding='same',
59 |                                use_bias=False)(x)
60 |     if dropout_rate:
61 |         x = tf.keras.layers.SpatialDropout3D(dropout_rate)(x)
62 |     return x
63 | 
64 | ##--------------------DenseBlock-------####
65 | def dense_block(x, nb_layers, growth_rate, kernal_size=(3, 3, 3),
66 |                 dilation_list=None,
67 |                 bottleneck=True, dropout_rate=None, weight_decay=1e-4,
68 |                 return_concat_list=False):
69 |     ''' Build a dense_block where the output of each conv_block is fed to subsequent ones
70 |     Args:
71 |         x: input tensor
72 |         nb_layers: the number of layers of conv_block to append to the model.
73 |         growth_rate: number of filters added by each conv_block
74 |         kernal_size: convolution kernel size
75 |         dilation_list: dilation rate(s); an int is broadcast to all nb_layers layers
76 |         bottleneck: add bottleneck block
77 |         dropout_rate: dropout rate
78 |         weight_decay: weight decay factor
79 |         return_concat_list: return the list of feature maps along with the actual output
80 |     Returns: tensor with nb_layers of conv_block appended
81 |     '''
82 | 
83 |     if dilation_list is None:
84 |         dilation_list = [1] * nb_layers
85 |     elif type(dilation_list) is int:
86 |         dilation_list = [dilation_list] * nb_layers
87 |     else:
88 |         if len(dilation_list) != nb_layers:
89 |             raise ValueError('the length of dilation_list should be equal to nb_layers %d' % nb_layers)
90 | 
91 |     x_list = [x]
92 | 
93 |     for i in range(nb_layers):
94 |         cb = conv_block(x, growth_rate, kernal_size, dilation_list[i],
95 |                         bottleneck, dropout_rate, weight_decay)
96 |         x_list.append(cb)
97 |         if i == 0:
98 |             x = cb
99 |         else:
100 |             x = tf.keras.layers.concatenate([x, cb], axis=-1)
101 | 
102 |     if return_concat_list:
103 |         return x, x_list
104 |     else:
105 |         return x
106 | 
107 | ###---------transition_block
108 | def transition_block(input, nb_filter, compression=1.0, weight_decay=1e-4,
109 |                      pool_kernal=(3, 3, 3), pool_strides=(2, 2, 2)):
110 |     ''' Apply BatchNorm, Relu, 1x1x1 Conv3D with optional compression, then AveragePooling3D
111 |     Args:
112 |         input: input tensor
113 |         nb_filter: number of filters
114 |         compression: calculated as 1 - reduction. Reduces the number of feature maps
115 |                      in the transition block.
116 |         pool_kernal, pool_strides: kernel size and strides of the average pooling
117 |         weight_decay: weight decay factor
118 |     Returns: keras tensor, after applying batch_norm, relu, 1x1x1 conv and average pooling
119 |     '''
120 | 
121 | 
122 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(input)
123 |     x = tf.nn.relu6(x)
124 |     x = tf.keras.layers.Conv3D(int(nb_filter * compression), (1, 1, 1),
125 |                                kernel_initializer='he_normal',
126 |                                padding='same',
127 |                                use_bias=False,
128 |                                kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
129 |     x = tf.keras.layers.AveragePooling3D(pool_kernal, strides=pool_strides)(x)
130 | 
131 |     return x
132 | 
133 | ###---Transition up block
134 | def transition_up_block(input, nb_filters, compression=1.0,
135 |                         kernal_size=(3, 3, 3), pool_strides=(2, 2, 2),
136 |                         type='deconv', weight_decay=1E-4):
137 |     ''' Upscaling block (factor = 2)
138 |     Args:
139 |         input: tensor
140 |         nb_filters: number of filters
141 |         type: can be 'upsampling' or 'deconv'. Determines the type of upsampling performed
142 |         weight_decay: weight decay factor
143 |     Returns: keras tensor, after applying the upsampling operation.
144 |     '''
145 | 
146 |     if type == 'upsampling':
147 |         x = tf.keras.layers.UpSampling3D(size=kernal_size)(input)  # UpSampling3D has no bilinear mode; it upsamples by nearest-neighbour
148 |         x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
149 |         x = tf.nn.relu6(x)
150 |         x = tf.keras.layers.Conv3D(int(nb_filters * compression), (1, 1, 1),
151 |                                    kernel_initializer='he_normal',
152 |                                    padding='same',
153 |                                    use_bias=False,
154 |                                    kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
155 | 
156 |     else:
157 |         x = tf.keras.layers.Conv3DTranspose(int(nb_filters * compression),
158 |                                             kernal_size,
159 |                                             strides=pool_strides,
160 |                                             activation='relu',
161 |                                             padding='same',
162 |                                             kernel_initializer='he_normal', kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(input)
163 | 
164 |     return x
165 | 
166 | 
167 | 
168 | def DenseVnet3D(inputs,
169 |                 nb_classes=1,
170 |                 encoder_nb_layers=(5, 8, 8),
171 |                 growth_rate=(4, 8, 12),
172 |                 dilation_list=(5, 3, 1),
173 |                 dropout_rate=0.25,
174 |                 weight_decay=1e-4,
175 |                 init_conv_filters=24):
176 |     """ 3D DenseVNet Implementation by f.i.tushar, tf 2.0.
177 |     This is a tensorflow 2.0 Implementation of the paper:
178 |     Gibson et al., "Automatic multi-organ segmentation on abdominal CT with
179 |     dense V-networks" 2018.
180 | 
181 |     Reference Implementations: i) https://github.com/baibaidj/vision4med/blob/5c23f57c2836bfabd7bd95a024a0a0b776b181b5/nets/DenseVnet.py
182 |                                ii) https://niftynet.readthedocs.io/en/dev/_modules/niftynet/network/dense_vnet.html#DenseVNet
183 | 
184 |     Input
185 |       |
186 |     --[ DFS ]-----------------------[ Conv ]------------[ Conv ]------[+]-->
187 |        |                               |                    |          |
188 |        -----[ DFS ]---------------[ Conv ]------            |          |
189 |               |                        |                    |
190 |               -----[ DFS ]-------[ Conv ]---------          |
191 |                                            [ Prior ]--------
192 |     Args:
193 |         inputs: Input, input shape should be (Batch,D,H,W,channels)
194 |         nb_classes: number of classes
195 |         encoder_nb_layers: number of layers in each dense_block
196 |         growth_rate: number of filters in each DenseBlock
197 |         dilation_list: dilation rate at each level
198 |         dropout_rate: dropout rate
199 |         weight_decay: weight decay
200 |     Returns: the segmentation prediction for the given input shape
201 |     """
202 |     #--|Getting the Input
203 |     img_input = inputs
204 |     input_shape = tf.shape(img_input)  # Input shape
205 |     nb_dense_block = len(encoder_nb_layers)  # number of dense blocks / resolution levels
206 | 
207 |     # Initial convolution
208 |     x = tf.keras.layers.Conv3D(init_conv_filters, (5, 5, 5),
209 |                                strides=2,
210 |                                kernel_initializer='he_normal',
211 |                                padding='same',
212 |                                name='initial_conv3D',
213 |                                use_bias=False,
214 |                                kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(img_input)
215 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
216 |     x = tf.nn.relu6(x)
217 | 
218 |     # Making the skip list for concatenation
219 |     skip_list = []
220 | 
221 |     # Add dense blocks
222 |     for block_idx in range(nb_dense_block):
223 |         '''
224 |         |--Input for dense_block is as follows
225 |         |---#x=Input,
226 |             #encoder_nb_layers[block_idx]=number of layers in this dense_block
227 |             #growth_rate[block_idx]=number of filters in this DenseBlock
228 |             #dilation_list=dilation rate.
229 | 
230 |         '''
231 |         x = dense_block(x, encoder_nb_layers[block_idx],
232 |                         growth_rate[block_idx],
233 |                         kernal_size=(3, 3, 3),
234 |                         dilation_list=dilation_list[block_idx],
235 |                         dropout_rate=dropout_rate,
236 |                         weight_decay=weight_decay,
237 |                         )
238 | 
239 |         # Skip connection
240 |         skip_list.append(x)
241 |         # Pooling
242 |         x = tf.keras.layers.AveragePooling3D((2, 2, 2))(x)
243 |         # x = __transition_block(x, nb_filter,compression=compression,weight_decay=weight_decay,pool_kernal=(3, 3, 3),pool_strides=(2, 2, 2))
244 | 
245 | 
246 |     ## Convolution on the third-resolution features, then upsample.
247 |     x_level3 = conv_block(skip_list[-1], growth_rate[2], bottleneck=True, dropout_rate=dropout_rate)
248 |     x_level3 = up_sampling(x_level3, scale=4)
249 |     # x_level3 = UpSampling3D(size = (4,4,4))(x_level3)
250 | 
251 |     ## Convolution on the second-resolution features, then upsample.
252 |     x_level2 = conv_block(skip_list[-2], growth_rate[1], bottleneck=True, dropout_rate=dropout_rate)
253 |     x_level2 = up_sampling(x_level2, scale=2)
254 |     # x_level2 = UpSampling3D(size=(2, 2, 2))(x_level2)
255 | 
256 |     ## Convolution on the first-resolution features
257 |     x_level1 = conv_block(skip_list[-3], growth_rate[0], bottleneck=True, dropout_rate=dropout_rate)
258 |     #x_level1 = up_sampling(x_level1, scale=2)
259 |     x = tf.keras.layers.Concatenate()([x_level3, x_level2, x_level1])
260 | 
261 |     ###--Final Convolution---
262 |     x = conv_block(x, 24, bottleneck=False, dropout_rate=dropout_rate)
263 |     ##----Upsampling to the final output----#####
264 |     x = up_sampling(x, scale=2)
265 | 
266 |     ####------Prediction---------------###
267 |     if nb_classes == 1:
268 |         x = tf.keras.layers.Conv3D(nb_classes, 1, activation='sigmoid', padding='same', use_bias=False)(x)
269 |     elif nb_classes > 1:
270 |         x = tf.keras.layers.Conv3D(nb_classes, 1, activation='softmax', padding='same', use_bias=False)(x)
271 |     #x = tf.argmax(x, axis=-1)
272 |     print(x)
273 | 
274 |     # Create model.
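    # (Note: the stride-2 initial conv halves the spatial resolution and the
    # final up_sampling(scale=2) restores it, so the prediction has shape
    # (Batch, D, H, W, nb_classes) for an input of shape (Batch, D, H, W, channels).)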
275 | model = tf.keras.Model(img_input, x, name='DenseVnet3D') 276 | return model 277 | ''' 278 | ###################----Demo Usages----############# 279 | INPUT_PATCH_SIZE=[384,192,192,1] 280 | NUMBER_OF_CLASSES=1 281 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 282 | 283 | #Model_3D=DenseVnet3D(inputs,nb_classes=1,encoder_nb_layers=(5, 8, 8),growth_rate=(4, 8, 12),dilation_list=(5, 3, 1)) 284 | Model_3D=DenseVnet3D(inputs,nb_classes=1,encoder_nb_layers=(4, 8, 16),growth_rate=(12,24,24),dilation_list=(5, 10, 10),dropout_rate=0.25) 285 | Model_3D.summary() 286 | tf.keras.utils.plot_model(Model_3D, 'DenseVnet3D.png',show_shapes=True) 287 | ''' 288 | -------------------------------------------------------------------------------- /DyFA/DyFA_Model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.layers import Conv3D, Input, MaxPooling3D, Dropout, concatenate, UpSampling3D 4 | import tensorflow as tf 5 | from config import* 6 | from loss_funnction_And_matrics import* 7 | import numpy as np 8 | from DenseVnet3D import DenseVnet3D 9 | #from Unet3D import Unet3D 10 | 11 | ####----Residual Blocks used for Resnet3D 12 | def Residual_Block(inputs, 13 | out_filters, 14 | kernel_size=(3, 3, 3), 15 | strides=(1, 1, 1), 16 | use_bias=False, 17 | activation=tf.nn.relu6, 18 | kernel_initializer=tf.keras.initializers.VarianceScaling(distribution='uniform'), 19 | bias_initializer=tf.zeros_initializer(), 20 | kernel_regularizer=tf.keras.regularizers.l2(l=0.001), 21 | bias_regularizer=None, 22 | **kwargs): 23 | 24 | 25 | conv_params={'padding': 'same', 26 | 'use_bias': use_bias, 27 | 'kernel_initializer': kernel_initializer, 28 | 'bias_initializer': bias_initializer, 29 | 'kernel_regularizer': kernel_regularizer, 30 | 'bias_regularizer': bias_regularizer} 31 | 32 | in_filters = inputs.get_shape().as_list()[-1] 33 | x=inputs 34 | orig_x=x 35 | 36 | ##building 37 | # Adjust the strided conv kernel size to prevent losing information 38 | k = [s * 2 if s > 1 else k for k, s in zip(kernel_size, strides)] 39 | 40 | if np.prod(strides) != 1: 41 | orig_x = tf.keras.layers.MaxPool3D(pool_size=strides,strides=strides,padding='valid')(orig_x) 42 | 43 | ##sub-unit-0 44 | x=tf.keras.layers.BatchNormalization()(x) 45 | x=activation(x) 46 | x=tf.keras.layers.Conv3D(filters=out_filters,kernel_size=k,strides=strides,**conv_params)(x) 47 | 48 | ##sub-unit-1 49 | x=tf.keras.layers.BatchNormalization()(x) 50 | x=activation(x) 51 | x=tf.keras.layers.Conv3D(filters=out_filters,kernel_size=kernel_size,strides=(1,1,1),**conv_params)(x) 52 | 53 | # Handle differences in input and output filter sizes 54 | if in_filters < out_filters: 55 | orig_x = tf.pad(tensor=orig_x,paddings=[[0, 0]] * (len(x.get_shape().as_list()) - 1) + [[ 56 | int(np.floor((out_filters - in_filters) / 2.)), 57 | int(np.ceil((out_filters - in_filters) / 2.))]]) 58 | 59 | elif in_filters > out_filters: 60 | orig_x = tf.keras.layers.Conv3D(filters=out_filters,kernel_size=kernel_size,strides=(1,1,1),**conv_params)(orig_x) 61 | 62 | x += orig_x 63 | return x 64 | 65 | 66 | 67 | ## Resnet----3D 68 | def Resnet3D(inputs, 69 | num_classes, 70 | num_res_units=TRAIN_NUM_RES_UNIT, 71 | filters=TRAIN_NUM_FILTERS, 72 | strides=TRAIN_STRIDES, 73 | use_bias=False, 74 | activation=TRAIN_CLASSIFY_ACTICATION, 75 | kernel_initializer=TRAIN_KERNAL_INITIALIZER, 76 | 
bias_initializer=tf.zeros_initializer(), 77 | kernel_regularizer=tf.keras.regularizers.l2(l=0.001), 78 | bias_regularizer=None, 79 | **kwargs): 80 | conv_params = {'padding': 'same', 81 | 'use_bias': use_bias, 82 | 'kernel_initializer': kernel_initializer, 83 | 'bias_initializer': bias_initializer, 84 | 'kernel_regularizer': kernel_regularizer, 85 | 'bias_regularizer': bias_regularizer} 86 | 87 | 88 | ##building 89 | k = [s * 2 if s > 1 else 3 for s in strides[0]] 90 | 91 | 92 | #Input 93 | x = inputs 94 | #1st-convo 95 | x=tf.keras.layers.Conv3D(filters[0], k, strides[0], **conv_params)(x) 96 | 97 | for res_scale in range(1, len(filters)): 98 | x = Residual_Block( 99 | inputs=x, 100 | out_filters=filters[res_scale], 101 | strides=strides[res_scale], 102 | activation=activation, 103 | name='unit_{}_0'.format(res_scale)) 104 | for i in range(1, num_res_units): 105 | x = Residual_Block( 106 | inputs=x, 107 | out_filters=filters[res_scale], 108 | strides=(1, 1, 1), 109 | activation=activation, 110 | name='unit_{}_{}'.format(res_scale, i)) 111 | 112 | 113 | x=tf.keras.layers.BatchNormalization()(x) 114 | x=activation(x) 115 | #axis = tuple(range(len(x.get_shape().as_list())))[1:-1] 116 | #x = tf.reduce_mean(x, axis=axis, name='global_avg_pool') 117 | x=tf.keras.layers.GlobalAveragePooling3D()(x) 118 | x =tf.keras.layers.Dropout(0.5)(x) 119 | classifier=tf.keras.layers.Dense(units=num_classes,activation='sigmoid')(x) 120 | 121 | #model = tf.keras.Model(inputs=inputs, outputs=classifier) 122 | #model.compile(optimizer=Adam(lr=TRAIN_CLASSIFY_LEARNING_RATE), loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.AUC()]) 123 | 124 | return classifier 125 | 126 | ### Final Model 127 | def DyFAModel_WithUnet(Unet_Model_Path,Input_shape,num_classes_clf,num_classes_for_seg): 128 | 129 | ###----Loading Segmentation Module---### 130 | inputs = tf.keras.Input(shape=Input_shape, name='CT') 131 | model_3DUnet=Unet3D(inputs,num_classes_for_seg) 132 | 133 | #-| Loading the Best Segmentation Weight 134 | model_3DUnet.load_weights(Unet_Model_Path) 135 | #-| Making the Segmentation Model Non-Trainable 136 | model_3DUnet.trainable = False 137 | 138 | #--| Getting the Features from Different Resolutions 139 | f_r1=(model_3DUnet.get_layer('Feature_R1').output) 140 | f_r2=(model_3DUnet.get_layer('Feature_R2').output) 141 | f_r3=(model_3DUnet.get_layer('Feature_R3').output) 142 | f_r4=(model_3DUnet.get_layer('Feature_R4').output) 143 | #f_r5=(model_3DUnet.get_layer('Feature_R5').output) 144 | last_predict=(model_3DUnet.get_layer('conv3d_17').output) 145 | #-| Upsampling the lower Resolution FA 146 | up2=(UpSampling3D(size = (2,2,2))(f_r2)) 147 | up3=(UpSampling3D(size = (4,4,4))(f_r3)) 148 | up4=(UpSampling3D(size = (8,8,8))(f_r4)) 149 | #up5=(UpSampling3D(size = (16,16,16))(f_r5)) 150 | #-| Concatenate the FAs 151 | FA_concatination=concatenate([f_r1,up2,up3,up4,last_predict],axis=-1) 152 | 153 | #-|| DyFA- Pass the Concatinated Feature to 1x1x1 convolution to get a 1 channel Volume. 
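    # (Note: the 1x1x1 conv is trainable while the segmentation backbone is
    # frozen, so the aggregation weights keep adapting during classifier
    # training; this is what makes the feature aggregation "dynamic" (DyFA),
    # in contrast to the fixed aggregation of the StFA variant.)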
154 | DyFA=tf.keras.layers.Conv3D(1, 1, name='DyFA')(FA_concatination) 155 | 156 | #-|| Making a HxWxDx2 channel Input data for the DyFA Classification Model 157 | DyFA_INPUT=concatenate([DyFA,inputs],axis=-1) 158 | 159 | DyFA_Model_output=Resnet3D(DyFA_INPUT,num_classes=num_classes_clf) 160 | 161 | Final_DyFAmodel=tf.keras.Model(inputs=inputs, outputs=DyFA_Model_output) 162 | 163 | return Final_DyFAmodel 164 | 165 | 166 | def DyFAModel_withDenseVnet(DenseVnet3D_Model_Path,Input_shape,num_classes_clf,num_classes_for_seg): 167 | 168 | ###----Loading Segmentation Module---### 169 | inputs = tf.keras.Input(shape=Input_shape, name='CT') 170 | model_3DDenseVnet=DenseVnet3D(inputs,nb_classes=SEG_NUMBER_OF_CLASSES,encoder_nb_layers=NUM_DENSEBLOCK_EACH_RESOLUTION,growth_rate=NUM_OF_FILTER_EACH_RESOLUTION,dilation_list=DILATION_RATE,dropout_rate=DROPOUT_RATE) 171 | #-| Loading the Best Segmentation Weight 172 | model_3DDenseVnet.load_weights(DenseVnet3D_Model_Path) 173 | model_3DDenseVnet.summary() 174 | #-| Making the Segmentation Model Non-Trainable 175 | model_3DDenseVnet.trainable = False 176 | #-| Getting the features 177 | f_60_192_96_96=(model_3DDenseVnet.get_layer('concatenate_25').output) 178 | #last_predict=(model_3DDenseVnet.get_layer('conv3d_63').output) 179 | #-| Upsampling the lower Resolution FA 180 | upsampled_F=(UpSampling3D(size = (2,2,2))(f_60_192_96_96)) 181 | #-| Concatenate the FAs 182 | #FA_concatination=concatenate([upsampled_F,last_predict],axis=-1) 183 | 184 | #-|| DyFA- Pass the Concatinated Feature to 1x1x1 convolution to get a 1 channel Volume. 185 | DyFA=tf.keras.layers.Conv3D(1, 1, name='DyFA')(upsampled_F) 186 | 187 | #-|| Making a HxWxDx2 channel Input data for the DyFA Classification Model 188 | DyFA_INPUT=concatenate([DyFA,inputs],axis=-1) 189 | 190 | DyFA_Model_output=Resnet3D(DyFA_INPUT,num_classes=num_classes_clf) 191 | 192 | Final_DyFAmodel=tf.keras.Model(inputs=inputs, outputs=DyFA_Model_output) 193 | 194 | return Final_DyFAmodel 195 | -------------------------------------------------------------------------------- /DyFA/Preprocessing_utlities.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | from __future__ import print_function 3 | from __future__ import division 4 | from __future__ import absolute_import 5 | import numpy as np 6 | from scipy.ndimage.interpolation import map_coordinates 7 | from scipy.ndimage.filters import gaussian_filter 8 | import tensorflow as tf 9 | 10 | 11 | def whitening(image): 12 | """Whitening. Normalises image to zero mean and unit variance.""" 13 | 14 | image = image.astype(np.float32) 15 | 16 | mean = np.mean(image) 17 | std = np.std(image) 18 | 19 | if std > 0: 20 | ret = (image - mean) / std 21 | else: 22 | ret = image * 0. 23 | return ret 24 | 25 | 26 | def normalise_zero_one(image): 27 | """Image normalisation. Normalises image to fit [0, 1] range.""" 28 | 29 | image = image.astype(np.float32) 30 | 31 | minimum = np.min(image) 32 | maximum = np.max(image) 33 | 34 | if maximum > minimum: 35 | ret = (image - minimum) / (maximum - minimum) 36 | else: 37 | ret = image * 0. 38 | return ret 39 | 40 | 41 | def normalise_one_one(image): 42 | """Image normalisation. Normalises image to fit [-1, 1] range.""" 43 | 44 | ret = normalise_zero_one(image) 45 | ret *= 2. 46 | ret -= 1. 
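    # equivalent to 2 * normalise_zero_one(image) - 1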
47 | return ret 48 | 49 | 50 | def flip(imagelist, axis=1): 51 | """Randomly flip spatial dimensions 52 | Args: 53 | imagelist (np.ndarray or list or tuple): image(s) to be flipped 54 | axis (int): axis along which to flip the images 55 | Returns: 56 | np.ndarray or list or tuple: same as imagelist but randomly flipped 57 | along axis 58 | """ 59 | 60 | # Check if a single image or a list of images has been passed 61 | was_singular = False 62 | if isinstance(imagelist, np.ndarray): 63 | imagelist = [imagelist] 64 | was_singular = True 65 | 66 | # With a probility of 0.5 flip the image(s) across `axis` 67 | do_flip = np.random.random(1) 68 | if do_flip > 0.5: 69 | for i in range(len(imagelist)): 70 | imagelist[i] = np.flip(imagelist[i], axis=axis) 71 | if was_singular: 72 | return imagelist[0] 73 | return imagelist 74 | 75 | 76 | def add_gaussian_offset(image, sigma=0.1): 77 | """ 78 | Add Gaussian offset to an image. Adds the offset to each channel 79 | independently. 80 | Args: 81 | image (np.ndarray): image to add noise to 82 | sigma (float): stddev of the Gaussian distribution to generate noise 83 | from 84 | Returns: 85 | np.ndarray: same as image but with added offset to each channel 86 | """ 87 | 88 | offsets = np.random.normal(0, sigma, ([1] * (image.ndim - 1) + [image.shape[-1]])) 89 | image += offsets 90 | return image 91 | 92 | 93 | def add_gaussian_noise(image, sigma=0.05): 94 | """ 95 | Add Gaussian noise to an image 96 | Args: 97 | image (np.ndarray): image to add noise to 98 | sigma (float): stddev of the Gaussian distribution to generate noise 99 | from 100 | Returns: 101 | np.ndarray: same as image but with added offset to each channel 102 | """ 103 | 104 | image += np.random.normal(0, sigma, image.shape) 105 | return image 106 | 107 | 108 | def elastic_transform(image, alpha, sigma): 109 | """ 110 | Elastic deformation of images as described in [1]. 111 | [1] Simard, Steinkraus and Platt, "Best Practices for Convolutional 112 | Neural Networks applied to Visual Document Analysis", in Proc. of the 113 | International Conference on Document Analysis and Recognition, 2003. 
114 | Based on gist https://gist.github.com/erniejunior/601cdf56d2b424757de5 115 | Args: 116 | image (np.ndarray): image to be deformed 117 | alpha (list): scale of transformation for each dimension, where larger 118 | values have more deformation 119 | sigma (list): Gaussian window of deformation for each dimension, where 120 | smaller values have more localised deformation 121 | Returns: 122 | np.ndarray: deformed image 123 | """ 124 | 125 | assert len(alpha) == len(sigma), \ 126 | "Dimensions of alpha and sigma are different" 127 | 128 | channelbool = image.ndim - len(alpha) 129 | out = np.zeros((len(alpha) + channelbool, ) + image.shape) 130 | 131 | # Generate a Gaussian filter, leaving channel dimensions zeroes 132 | for jj in range(len(alpha)): 133 | array = (np.random.rand(*image.shape) * 2 - 1) 134 | out[jj] = gaussian_filter(array, sigma[jj], 135 | mode="constant", cval=0) * alpha[jj] 136 | 137 | # Map mask to indices 138 | shapes = list(map(lambda x: slice(0, x, None), image.shape)) 139 | grid = np.broadcast_arrays(*np.ogrid[shapes]) 140 | indices = list(map((lambda x: np.reshape(x, (-1, 1))), grid + np.array(out))) 141 | 142 | # Transform image based on masked indices 143 | transformed_image = map_coordinates(image, indices, order=0, 144 | mode='reflect').reshape(image.shape) 145 | 146 | return transformed_image 147 | 148 | def extract_class_balanced_example_array(image, 149 | label, 150 | example_size=[1, 64, 64], 151 | n_examples=1, 152 | classes=2, 153 | class_weights=None): 154 | """Extract training examples from an image (and corresponding label) subject 155 | to class balancing. Returns an image example array and the 156 | corresponding label array. 157 | 158 | Args: 159 | image (np.ndarray): image to extract class-balanced patches from 160 | label (np.ndarray): labels to use for balancing the classes 161 | example_size (list or tuple): shape of the patches to extract 162 | n_examples (int): number of patches to extract in total 163 | classes (int or list or tuple): number of classes or list of classes 164 | to extract 165 | 166 | Returns: 167 | np.ndarray, np.ndarray: class-balanced patches extracted from full 168 | images with the shape [batch, example_size..., image_channels] 169 | """ 170 | assert image.shape[:-1] == label.shape, 'Image and label shape must match' 171 | assert image.ndim - 1 == len(example_size), \ 172 | 'Example size doesnt fit image size' 173 | #assert all([i_s >= e_s for i_s, e_s in zip(image.shape, example_size)]), \ 174 | #'Image must be larger than example shape' 175 | rank = len(example_size) 176 | 177 | 178 | 179 | if isinstance(classes, int): 180 | classes = tuple(range(classes)) 181 | n_classes = len(classes) 182 | 183 | 184 | if class_weights is None: 185 | n_ex_per_class = np.ones(n_classes).astype(int) * int(np.round(n_examples / n_classes)) 186 | else: 187 | assert len(class_weights) == n_classes, \ 188 | 'Class_weights must match number of classes' 189 | class_weights = np.array(class_weights) 190 | n_ex_per_class = np.round((class_weights / class_weights.sum()) * n_examples).astype(int) 191 | 192 | # Compute an example radius to define the region to extract around a 193 | # center location 194 | ex_rad = np.array(list(zip(np.floor(np.array(example_size) / 2.0), 195 | np.ceil(np.array(example_size) / 2.0))), 196 | dtype=np.int) 197 | 198 | class_ex_images = [] 199 | class_ex_lbls = [] 200 | min_ratio = 1. 
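    # For each class: sample up to n_ex_per_class random center voxels of that
    # class, clamp the centers so each patch fits inside the volume, and crop
    # example_size patches around them.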
201 | for c_idx, c in enumerate(classes): 202 | # Get valid, random center locations belonging to that class 203 | idx = np.argwhere(label == c) 204 | 205 | ex_images = [] 206 | ex_lbls = [] 207 | 208 | if len(idx) == 0 or n_ex_per_class[c_idx] == 0: 209 | class_ex_images.append([]) 210 | class_ex_lbls.append([]) 211 | continue 212 | 213 | # Extract random locations 214 | r_idx_idx = np.random.choice(len(idx), 215 | size=min(n_ex_per_class[c_idx], len(idx)), 216 | replace=False).astype(int) 217 | r_idx = idx[r_idx_idx] 218 | 219 | # Shift the random to valid locations if necessary 220 | r_idx = np.array( 221 | [np.array([max(min(r[dim], image.shape[dim] - ex_rad[dim][1]), 222 | ex_rad[dim][0]) for dim in range(rank)]) 223 | for r in r_idx]) 224 | 225 | for i in range(len(r_idx)): 226 | # Extract class-balanced examples from the original image 227 | slicer = [slice(r_idx[i][dim] - ex_rad[dim][0], r_idx[i][dim] + ex_rad[dim][1]) for dim in range(rank)] 228 | 229 | ex_image = image[slicer][np.newaxis, :] 230 | 231 | ex_lbl = label[slicer][np.newaxis, :] 232 | 233 | # Concatenate them and return the examples 234 | ex_images = np.concatenate((ex_images, ex_image), axis=0) \ 235 | if (len(ex_images) != 0) else ex_image 236 | ex_lbls = np.concatenate((ex_lbls, ex_lbl), axis=0) \ 237 | if (len(ex_lbls) != 0) else ex_lbl 238 | 239 | class_ex_images.append(ex_images) 240 | class_ex_lbls.append(ex_lbls) 241 | 242 | ratio = n_ex_per_class[c_idx] / len(ex_images) 243 | min_ratio = ratio if ratio < min_ratio else min_ratio 244 | 245 | indices = np.floor(n_ex_per_class * min_ratio).astype(int) 246 | 247 | ex_images = np.concatenate([cimage[:idxs] for cimage, idxs in zip(class_ex_images, indices) 248 | if len(cimage) > 0], axis=0) 249 | ex_lbls = np.concatenate([clbl[:idxs] for clbl, idxs in zip(class_ex_lbls, indices) 250 | if len(clbl) > 0], axis=0) 251 | 252 | return ex_images, ex_lbls 253 | 254 | def resize_image_with_crop_or_pad(image, img_size=(64, 64, 64), **kwargs): 255 | """Image resizing. Resizes image by cropping or padding dimension 256 | to fit specified size. 
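    Example (illustrative): with img_size=(64, 64, 64), a (100, 32, 64) input
    is center-cropped along the first dimension and padded evenly on both
    sides of the second (the padding mode comes from **kwargs), while the
    third is left unchanged.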
257 | Args: 258 | image (np.ndarray): image to be resized 259 | img_size (list or tuple): new image size 260 | kwargs (): additional arguments to be passed to np.pad 261 | Returns: 262 | np.ndarray: resized image 263 | """ 264 | 265 | assert isinstance(image, (np.ndarray, np.generic)) 266 | assert (image.ndim - 1 == len(img_size) or image.ndim == len(img_size)), \ 267 | 'Example size doesnt fit image size' 268 | 269 | # Get the image dimensionality 270 | rank = len(img_size) 271 | 272 | # Create placeholders for the new shape 273 | from_indices = [[0, image.shape[dim]] for dim in range(rank)] 274 | to_padding = [[0, 0] for dim in range(rank)] 275 | 276 | slicer = [slice(None)] * rank 277 | 278 | # For each dimensions find whether it is supposed to be cropped or padded 279 | for i in range(rank): 280 | if image.shape[i] < img_size[i]: 281 | to_padding[i][0] = (img_size[i] - image.shape[i]) // 2 282 | to_padding[i][1] = img_size[i] - image.shape[i] - to_padding[i][0] 283 | else: 284 | from_indices[i][0] = int(np.floor((image.shape[i] - img_size[i]) / 2.)) 285 | from_indices[i][1] = from_indices[i][0] + img_size[i] 286 | 287 | # Create slicer object to crop or leave each dimension 288 | slicer[i] = slice(from_indices[i][0], from_indices[i][1]) 289 | 290 | # Pad the cropped image to extend the missing dimension 291 | return np.pad(image[slicer], to_padding, **kwargs) 292 | 293 | 294 | def extract_random_example_array(image_list,example_size=[1, 64, 64],n_examples=1): 295 | 296 | """Randomly extract training examples from image (and a corresponding label). 297 | Returns an image example array and the corresponding label array. 298 | Args: 299 | image_list (np.ndarray or list or tuple): image(s) to extract random 300 | patches from 301 | example_size (list or tuple): shape of the patches to extract 302 | n_examples (int): number of patches to extract in total 303 | Returns: 304 | np.ndarray, np.ndarray: class-balanced patches extracted from full 305 | images with the shape [batch, example_size..., image_channels] 306 | """ 307 | 308 | assert n_examples > 0 309 | 310 | was_singular = False 311 | if isinstance(image_list, np.ndarray): 312 | image_list = [image_list] 313 | was_singular = True 314 | 315 | assert all([i_s >= e_s for i_s, e_s in zip(image_list[0].shape, example_size)]), \ 316 | 'Image must be bigger than example shape' 317 | assert (image_list[0].ndim - 1 == len(example_size) or image_list[0].ndim == len(example_size)), \ 318 | 'Example size doesnt fit image size' 319 | 320 | for i in image_list: 321 | if len(image_list) > 1: 322 | assert (i.ndim - 1 == image_list[0].ndim or i.ndim == image_list[0].ndim or i.ndim + 1 == image_list[0].ndim),\ 323 | 'Example size doesnt fit image size' 324 | 325 | assert all([i0_s == i_s for i0_s, i_s in zip(image_list[0].shape, i.shape)]), \ 326 | 'Image shapes must match' 327 | 328 | rank = len(example_size) 329 | 330 | # Extract random examples from image and label 331 | valid_loc_range = [image_list[0].shape[i] - example_size[i] for i in range(rank)] 332 | 333 | rnd_loc = [np.random.randint(valid_loc_range[dim], size=n_examples) 334 | if valid_loc_range[dim] > 0 335 | else np.zeros(n_examples, dtype=int) for dim in range(rank)] 336 | 337 | examples = [[]] * len(image_list) 338 | for i in range(n_examples): 339 | slicer = [slice(rnd_loc[dim][i], rnd_loc[dim][i] + example_size[dim]) 340 | for dim in range(rank)] 341 | 342 | for j in range(len(image_list)): 343 | ex_image = image_list[j][slicer][np.newaxis] 344 | # Concatenate and return the examples 
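            # (examples[j] starts out as an empty list, so the first extracted
            # patch initialises the array before later patches are concatenated)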
345 | examples[j] = np.concatenate((examples[j], ex_image), axis=0) \ 346 | if (len(examples[j]) != 0) else ex_image 347 | 348 | if was_singular: 349 | return examples[0] 350 | return examples 351 | -------------------------------------------------------------------------------- /DyFA/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from loss_funnction_And_matrics import* 3 | import math 4 | ###---Number-of-GPU 5 | NUM_OF_GPU=2 6 | #["gpu:1","gpu:2","gpu:3"] 7 | DISTRIIBUTED_STRATEGY_GPUS=["gpu:0","gpu:1"] 8 | ###-----SEGMENATTION----### 9 | SEGMENTATION_MODEL_PATH='/image_data/Scripts/April_Model/DyFA_61FC1X1_April17_2020/LungSEG_DenseVnet_2.60_4998.h5' 10 | SEGMENTATION_NUM_OF_CLASSES=31 11 | #####-----Configure DenseVnet3D---########## 12 | SEG_NUMBER_OF_CLASSES=31 13 | SEG_INPUT_PATCH_SIZE=(128,160,160, 1) 14 | NUM_DENSEBLOCK_EACH_RESOLUTION=(4, 8, 16) 15 | NUM_OF_FILTER_EACH_RESOLUTION=(12,24,24) 16 | DILATION_RATE=(5, 10, 10) 17 | DROPOUT_RATE=0.25 18 | ###----Resume-Training 19 | RESUME_TRAINING=0 20 | RESUME_TRAIING_MODEL='/image_data/Scripts/April_Model/DyFA_61FC1X1_April17_2020/Model_DyFA_61FC1X1_April17_2020/' 21 | TRAINING_INITIAL_EPOCH=0 22 | ##Network Configuration 23 | NUMBER_OF_CLASSES=5 24 | INPUT_PATCH_SIZE=(128,160,160, 1) 25 | TRAIN_NUM_RES_UNIT=3 26 | TRAIN_NUM_FILTERS=(16, 32, 64, 128) 27 | TRAIN_STRIDES=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)) 28 | TRAIN_CLASSIFY_ACTICATION=tf.nn.relu6 29 | TRAIN_KERNAL_INITIALIZER=tf.keras.initializers.VarianceScaling(distribution='uniform') 30 | ##Training Hyper-Parameter 31 | ##Training Hyper-Parameter 32 | TRAIN_CLASSIFY_LEARNING_RATE =1e-4 33 | TRAIN_CLASSIFY_LOSS=Weighted_BCTL 34 | OPTIMIZER=tf.keras.optimizers.Adam(lr=TRAIN_CLASSIFY_LEARNING_RATE,epsilon=1e-5) 35 | TRAIN_CLASSIFY_METRICS=tf.keras.metrics.AUC() 36 | BATCH_SIZE=12 37 | TRAINING_STEP_PER_EPOCH=math.ceil((3514)/BATCH_SIZE) 38 | VALIDATION_STEP=math.ceil((759)/BATCH_SIZE) 39 | TRAING_EPOCH=300 40 | NUMBER_OF_PARALLEL_CALL=6 41 | PARSHING=3*BATCH_SIZE 42 | #--Callbacks----- 43 | ModelCheckpoint_MOTITOR='val_loss' 44 | TRAINING_SAVE_MODEL_PATH='/image_data/Scripts/April_Model/DyFA_61FC1X1_April17_2020/Model_DyFA_61FC1X1_April17_2020/' 45 | TRAINING_CSV='DyFA_61FC1X1_April17_2020.csv' 46 | LOG_NAME="Log_DyFA_61FC1X1_April17_2020" 47 | MODEL_SAVING_NAME="DyFAModel61FC1X1_{val_loss:.2f}_{epoch}.h5" 48 | 49 | #### 50 | TRAINING_TF_RECORDS='/image_data/nobackup/Lung_CenterPatch_2mm_March27_2020/tf/Train_tfrecords/' 51 | VALIDATION_TF_RECORDS='/image_data/nobackup/Lung_CenterPatch_2mm_March27_2020/tf/Val_tfrecords/' 52 | -------------------------------------------------------------------------------- /DyFA/loss_funnction_And_matrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ####---Loss 4 | @tf.function 5 | def macro_soft_f1(y, y_hat): 6 | """Compute the macro soft F1-score as a cost (average 1 - soft-F1 across all labels). 7 | Use probability values instead of binary predictions. 
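    Unlike the thresholded F1-score, this soft version is differentiable, so it
    can be minimised directly by gradient descent.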
8 | 9 | Args: 10 | y (int32 Tensor): targets array of shape (BATCH_SIZE, N_LABELS) 11 | y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) 12 | 13 | Returns: 14 | cost (scalar Tensor): value of the cost function for the batch 15 | """ 16 | y = tf.cast(y, tf.float32) 17 | y_hat = tf.cast(y_hat, tf.float32) 18 | tp = tf.reduce_sum(y_hat * y, axis=0) 19 | fp = tf.reduce_sum(y_hat * (1 - y), axis=0) 20 | fn = tf.reduce_sum((1 - y_hat) * y, axis=0) 21 | soft_f1 = 2*tp / (2*tp + fn + fp + 1e-16) 22 | cost = 1 - soft_f1 # reduce 1 - soft-f1 in order to increase soft-f1 23 | macro_cost = tf.reduce_mean(cost) # average on all labels 24 | return macro_cost 25 | 26 | 27 | ###Matrics 28 | @tf.function 29 | def macro_f1(y, y_hat, thresh=0.5): 30 | """Compute the macro F1-score on a batch of observations (average F1 across labels) 31 | 32 | Args: 33 | y (int32 Tensor): labels array of shape (BATCH_SIZE, N_LABELS) 34 | y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) 35 | thresh: probability value above which we predict positive 36 | 37 | Returns: 38 | macro_f1 (scalar Tensor): value of macro F1 for the batch 39 | """ 40 | y_pred = tf.cast(tf.greater(y_hat, thresh), tf.float32) 41 | tp = tf.cast(tf.math.count_nonzero(y_pred * y, axis=0), tf.float32) 42 | fp = tf.cast(tf.math.count_nonzero(y_pred * (1 - y), axis=0), tf.float32) 43 | fn = tf.cast(tf.math.count_nonzero((1 - y_pred) * y, axis=0), tf.float32) 44 | f1 = 2*tp / (2*tp + fn + fp + 1e-16) 45 | macro_f1 = tf.reduce_mean(f1) 46 | return macro_f1 47 | 48 | 49 | 50 | @tf.function 51 | def Weighted_BCTL(y_true, y_pred): 52 | 53 | # Manually calculate the weighted cross entropy. 54 | # Formula is qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 55 | # where z are labels, x is logits, and q is the weight. 
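    # Here q is a per-batch "balanced" class weight: (P + N) / P scales the
    # positive term and (P + N) / N the negative term, where P and N are the
    # numbers of positive and negative labels in the batch.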
56 |     # Since the values passed in are sigmoid outputs, sigmoid(x) is replaced by y_pred.
57 |     # An epsilon of 1e-6 is added to stop a zero being passed into the log.
58 |     # q is the balanced class weight computed per batch (see below).
59 | 
60 |     ## count the positive and negative labels
61 | 
62 |     y_true = tf.cast(y_true, tf.float32)
63 |     y_pred = tf.cast(y_pred, tf.float32)
64 | 
65 |     P=tf.cast(tf.math.count_nonzero(y_true), tf.float32)
66 |     N=tf.cast(tf.shape(tf.where(y_true==0))[0], tf.float32)  # len() fails on a dynamic dimension inside tf.function
67 | 
68 |     BP1=(P+N)/P  # balanced weight for the positive term
69 |     BP=tf.cast(BP1,tf.float32)
70 | 
71 |     BN=(P+N)/N  # balanced weight for the negative term
72 |     BN=tf.cast(BN,tf.float32)
73 | 
74 | 
75 |     x_1 =BP*(y_true * -tf.math.log(y_pred + 1e-6))
76 |     x_2 =BN*((1 - y_true) * -tf.math.log(1 - y_pred + 1e-6))
77 | 
78 |     cost=tf.add(x_1, x_2)
79 |     cost_a=tf.reduce_mean(cost)
80 |     return cost_a
81 | 
82 | 

--------------------------------------------------------------------------------
/DyFA/resume_training_using_check_point.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import tensorflow as tf
3 | '''
4 | tf.config.optimizer.set_jit(True)
5 | gpus = tf.config.experimental.list_physical_devices('GPU')
6 | if gpus:
7 |     # Restrict TensorFlow to only use the first GPU
8 |     try:
9 |         tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
10 |     except RuntimeError as e:
11 |         # Visible devices must be set at program startup
12 |         print(e)
13 | '''
14 | 
15 | from tensorflow.keras.optimizers import Adam
16 | from config import*
17 | import os
18 | import datetime
19 | from DyFA_Model import*
20 | from tfrecords_utilities import decode_ct
21 | import numpy as np
22 | import random
23 | 
24 | ####----Getting the list of tfrecords
25 | def getting_list(path):
26 |     a=[file for file in os.listdir(path) if file.endswith('.tfrecords')]
27 |     all_tfrecoeds=random.sample(a, len(a))
28 |     #all_tfrecoeds.sort(key=lambda f: int(filter(str.isdigit, f)))
29 |     list_of_tfrecords=[]
30 |     for i in range(len(all_tfrecoeds)):
31 |         tf_path=path+all_tfrecoeds[i]
32 |         list_of_tfrecords.append(tf_path)
33 |     return list_of_tfrecords
34 | 
35 | #--Training data decoder
36 | def load_training_tfrecords(record_mask_file,batch_size):
37 |     dataset=tf.data.Dataset.list_files(record_mask_file).interleave(lambda x: tf.data.TFRecordDataset(x),cycle_length=NUMBER_OF_PARALLEL_CALL,num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
38 |     dataset=dataset.map(decode_ct,num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
39 |     batched_dataset=dataset.prefetch(PARSHING)
40 |     return batched_dataset
41 | 
42 | #--Validation data decoder
43 | def load_validation_tfrecords(record_mask_file,batch_size):
44 |     dataset=tf.data.Dataset.list_files(record_mask_file).interleave(tf.data.TFRecordDataset,cycle_length=NUMBER_OF_PARALLEL_CALL,num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
45 |     dataset=dataset.map(decode_ct,num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
46 |     batched_dataset=dataset.prefetch(PARSHING)
47 |     return batched_dataset
48 | 
49 | 
50 | def Training():
51 | 
52 |     #TensorBoard
53 |     logdir = os.path.join(LOG_NAME, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
54 |     tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
55 |     ##csv_logger
56 |     csv_logger = tf.keras.callbacks.CSVLogger(TRAINING_CSV)
57 |     ##Model-checkpoints
58 |     path=TRAINING_SAVE_MODEL_PATH
59 |     model_path=os.path.join(path, MODEL_SAVING_NAME)
60 |     Model_callback=
tf.keras.callbacks.ModelCheckpoint(filepath=model_path,save_best_only=False,save_weights_only=True,monitor=ModelCheckpoint_MOTITOR,verbose=1) 61 | ##----Preparing Data 62 | tf_train=getting_list(TRAINING_TF_RECORDS) 63 | tf_val=getting_list(VALIDATION_TF_RECORDS) 64 | traing_data=load_training_tfrecords(tf_train,BATCH_SIZE) 65 | Val_batched_dataset=load_validation_tfrecords(tf_val,BATCH_SIZE) 66 | 67 | if (NUM_OF_GPU==1): 68 | if RESUME_TRAINING==1: 69 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 70 | Model_3D=DyFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 71 | Model_3D.load_weights(RESUME_TRAIING_MODEL) 72 | initial_epoch_of_training=TRAINING_INITIAL_EPOCH 73 | print('Resume-Training From-Epoch{}-Loading-Model-from_{}'.format(initial_epoch_of_training,RESUME_TRAIING_MODEL)) 74 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 75 | Model_3D.summary() 76 | else: 77 | initial_epoch_of_training=0 78 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 79 | Model_3D=DyFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 80 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 81 | Model_3D.summary() 82 | Model_3D.fit(traing_data, 83 | steps_per_epoch=TRAINING_STEP_PER_EPOCH, 84 | epochs=TRAING_EPOCH, 85 | initial_epoch=initial_epoch_of_training, 86 | validation_data=Val_batched_dataset, 87 | validation_steps=VALIDATION_STEP, 88 | callbacks=[tensorboard_callback,csv_logger,Model_callback]) 89 | 90 | ###Multigpu---- 91 | else: 92 | mirrored_strategy = tf.distribute.MirroredStrategy(DISTRIIBUTED_STRATEGY_GPUS) 93 | with mirrored_strategy.scope(): 94 | if RESUME_TRAINING==1: 95 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 96 | Model_3D=DyFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 97 | Model_3D.load_weights(RESUME_TRAIING_MODEL) 98 | initial_epoch_of_training=TRAINING_INITIAL_EPOCH 99 | print('Resume-Training From-Epoch{}-Loading-Model-from_{}'.format(initial_epoch_of_training,RESUME_TRAIING_MODEL)) 100 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 101 | Model_3D.summary() 102 | else: 103 | initial_epoch_of_training=0 104 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 105 | Model_3D=DyFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 106 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 107 | Model_3D.summary() 108 | Model_3D.fit(traing_data, 109 | steps_per_epoch=TRAINING_STEP_PER_EPOCH, 110 | epochs=TRAING_EPOCH, 111 | initial_epoch=initial_epoch_of_training, 112 | validation_data=Val_batched_dataset, 113 | validation_steps=VALIDATION_STEP, 114 | callbacks=[tensorboard_callback,csv_logger,Model_callback]) 115 | 116 | if __name__ == '__main__': 117 | Training() 118 | -------------------------------------------------------------------------------- /DyFA/tfrecords_utilities.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import absolute_import, division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import pandas as pd 6 | import SimpleITK as sitk 7 | from Preprocessing_utlities import extract_class_balanced_example_array 8 | from Preprocessing_utlities import resize_image_with_crop_or_pad 9 | from scipy.ndimage.interpolation import map_coordinates 10 | from scipy.ndimage.filters import gaussian_filter 11 | from config import* 12 | 13 | 14 | ########################-------Fucntions for tf records 15 | # The following functions can be used to convert a value to a type compatible 16 | # with tf.Example. 17 | def _bytes_feature(value): 18 | """Returns a bytes_list from a string / byte.""" 19 | if isinstance(value, type(tf.constant(0))): 20 | value = value.numpy() # BytesList won't unpack a string from an EagerTensor. 21 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 22 | 23 | def _float_feature(value): 24 | """Returns a float_list from a float / double.""" 25 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 26 | 27 | def _int64_feature(value): 28 | """Returns an int64_list from a bool / enum / int / uint.""" 29 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 30 | 31 | 32 | def flow_from_df(dataframe: pd.DataFrame, chunk_size): 33 | for start_row in range(0, dataframe.shape[0], chunk_size): 34 | end_row = min(start_row + chunk_size, dataframe.shape[0]) 35 | yield dataframe.iloc[start_row:end_row, :] 36 | 37 | 38 | def creat_tfrecord(df,extraction_perameter,tf_name): 39 | 40 | read_csv=df.as_matrix() 41 | patch_params = extraction_perameter 42 | 43 | img_list=[] 44 | mask_list=[] 45 | lbl_list=[] 46 | id_name=[] 47 | 48 | for Data in read_csv: 49 | img_path = Data[4] 50 | subject_id = img_path.split('/')[-1].split('.')[0] 51 | Subject_lbl=Data[5:10] 52 | print(Subject_lbl.shape) 53 | 54 | print('Subject ID-{}'.format(subject_id)) 55 | print('Labels--{}'.format(Subject_lbl)) 56 | 57 | #Img 58 | img_sitk = sitk.ReadImage(img_path, sitk.sitkFloat32) 59 | image= sitk.GetArrayFromImage(img_sitk) 60 | #Mask 61 | mask_fn = str(Data[10]) 62 | mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_fn)).astype(np.int32) 63 | print(mask.shape) 64 | print(image.shape) 65 | 66 | patch_size=patch_params['example_size'] 67 | img_shape=image.shape 68 | 69 | ###----padding_data_if_needed 70 | #####----z dimention-----###### 71 | if (patch_size[0] >=img_shape[0]): 72 | dimention1=patch_size[0]+10 73 | else: 74 | dimention1=img_shape[0] 75 | 76 | #####----x dimention-----###### 77 | if (patch_size[1] >=img_shape[1]): 78 | dimention2=patch_size[1]+10 79 | else: 80 | dimention2=img_shape[1] 81 | 82 | #####----Y dimention-----###### 83 | if (patch_size[2] >=img_shape[2]): 84 | dimention3=patch_size[2]+10 85 | else: 86 | dimention3=img_shape[2] 87 | print('------before padding image shape--{}-----'.format(image.shape)) 88 | image=resize_image_with_crop_or_pad(image, [dimention1,dimention2,dimention3], mode='symmetric') 89 | mask=resize_image_with_crop_or_pad(mask, [dimention1,dimention2,dimention3], mode='symmetric') 90 | print('######before padding image shape--{}#####'.format(image.shape)) 91 | 92 | 93 | 94 | img_shape=image.shape 95 | image= np.expand_dims(image, axis=3) 96 | 97 | images,masks = extract_class_balanced_example_array( 98 | image,mask, 99 | example_size=patch_params['example_size'], 100 | n_examples=patch_params['n_examples'], 101 | classes=4,class_weights=[0,0,1,1]) 102 | 103 | print(images.shape) 104 | 
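        # Flatten the extracted patches into parallel lists; every patch keeps
        # the subject-level labels and gets an id of the form '<subject_id>_<e>'.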
105 | for e in range(patch_params['n_examples']): 106 | img_list.append(images[e][:,:,:,0]) 107 | #print(images[e][:,:,:,0].shape) 108 | mask_list.append(masks[e][:,:,:]) 109 | #print('Mask-Shape=={}'.format(masks[e][:,:,:].shape)) 110 | lbl_list.append(Subject_lbl) 111 | patch_name=str(subject_id+'_{}'.format(e)) 112 | #Converting_string_bytes 113 | patch_name =bytes(patch_name, 'utf-8') 114 | #print(patch_name) 115 | id_name.append(patch_name) 116 | 117 | print('This Rfrecords will contain--{}--Pathes--of-size--{}'.format(len(id_name),patch_params['example_size'])) 118 | 119 | record_mask_file = tf_name 120 | with tf.io.TFRecordWriter(record_mask_file) as writer: 121 | for e in range(len(img_list)): 122 | feature = {'label1': _int64_feature(lbl_list[e][0]), 123 | 'label2': _int64_feature(lbl_list[e][1]), 124 | 'label3': _int64_feature(lbl_list[e][2]), 125 | 'label4': _int64_feature(lbl_list[e][3]), 126 | 'label5': _int64_feature(lbl_list[e][4]), 127 | 'image':_bytes_feature(img_list[e].tostring()), 128 | 'mask':_bytes_feature(mask_list[e].tostring()), 129 | 'Height':_int64_feature(patch_params['example_size'][0]), 130 | 'Weight':_int64_feature(patch_params['example_size'][1]), 131 | 'Depth':_int64_feature(patch_params['example_size'][2]), 132 | 'label_shape':_int64_feature(5), 133 | 'Sub_id':_bytes_feature(id_name[e]) 134 | } 135 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 136 | writer.write(example.SerializeToString()) 137 | 138 | writer.close() 139 | 140 | return 141 | 142 | 143 | @tf.function 144 | def decode_ct(Serialized_example): 145 | 146 | features={ 147 | 'label1': tf.io.FixedLenFeature([],tf.int64), 148 | 'label2': tf.io.FixedLenFeature([],tf.int64), 149 | 'label3': tf.io.FixedLenFeature([],tf.int64), 150 | 'label4': tf.io.FixedLenFeature([],tf.int64), 151 | 'label5': tf.io.FixedLenFeature([],tf.int64), 152 | 'image':tf.io.FixedLenFeature([],tf.string), 153 | 'mask':tf.io.FixedLenFeature([],tf.string), 154 | 'Height':tf.io.FixedLenFeature([],tf.int64), 155 | 'Weight':tf.io.FixedLenFeature([],tf.int64), 156 | 'Depth':tf.io.FixedLenFeature([],tf.int64), 157 | 'label_shape':tf.io.FixedLenFeature([],tf.int64), 158 | 'Sub_id':tf.io.FixedLenFeature([],tf.string) 159 | 160 | } 161 | examples=tf.io.parse_single_example(Serialized_example,features) 162 | ##Decode_image_float 163 | image_1 = tf.io.decode_raw(examples['image'], float) 164 | #Decode_mask_as_int32 165 | #mask_1 = tf.io.decode_raw(examples['mask'], tf.int32) 166 | ##Subject id is already in bytes format 167 | #sub_id=examples['Sub_id'] 168 | 169 | 170 | img_shape=[examples['Height'],examples['Weight'],examples['Depth']] 171 | #img_shape2=[img_shape[0],img_shape[1],img_shape[2]] 172 | print(img_shape) 173 | #Resgapping_the_data 174 | img=tf.reshape(image_1,img_shape) 175 | #Because CNN expect(batch,H,W,D,CHANNEL) 176 | img=tf.expand_dims(img, axis=-1) 177 | #mask=tf.reshape(mask_1,img_shape) 178 | #mask=tf.expand_dims(mask, axis=-1) 179 | ###casting_values 180 | img=tf.cast(img, tf.float32) 181 | #mask=tf.cast(mask,tf.int32) 182 | 183 | lbl=[examples['label1'],examples['label2'],examples['label3'],examples['label4'],examples['label5']] 184 | ##Transpossing the Multilabels 185 | #lbl=tf.linalg.matrix_transpose(lbl) 186 | return img,lbl 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
Weakly-Supervised-3D-Classification-of-Chest-CT-using-Aggregated-Multi-Resolution-Deep-Segmentation-Features
2 | This repo contains the updated implementation of our paper "Weakly supervised 3D classification of chest CT using aggregated multi-resolution deep segmentation features", Proc. SPIE 11314, Medical Imaging 2020: Computer-Aided Diagnosis, 1131408 (16 March 2020).
3 | 
4 | * Version-1: Implemented the Segmentation Module and the Classification Module separately, in TensorFlow 1.x.
5 | Can be seen here: https://github.com/anindox8/Deep-Segmentation-Features-for-Weakly-Supervised-3D-Disease-Classification-in-Chest-CT
6 | 
7 | * Version-2: Updated implementation. To reduce computational expense, the Segmentation Module and the Classification Module are combined; the updated implementation is in TensorFlow 2.0 and trains 2 times faster than Version-1. The project has also moved from a multi-class to a multi-label classification setup (follow the SPIE presentation for a clear idea).
8 | 
9 | If our work helps in your task or project, please cite it at (https://doi.org/10.1117/12.2550857). This work was presented at SPIE Medical Imaging 2020, Houston, Texas, United States. The presentation can be seen here: https://www.spiedigitallibrary.org/conference-proceedings-of-spie/11314/2550857/Weakly-supervised-3D-classification-of-chest-CT-using-aggregated-multi/10.1117/12.2550857.full?SSO=1
10 | 
11 | ![Model architecture for the multi-label approach (version-2 implementation)](https://github.com/fitushar/Weakly-Supervised-3D-Classification-of-Chest-CT-using-Aggregated-Multi-Resolution-Deep-Segmentation-/blob/master/figure/Model_Architecture.png)
12 | 
13 | ## Citation
14 | ```
15 | Anindo Saha*, Fakrul I. Tushar*, Khrystyna Faryna, Vincent M. D'Anniballe, Rui Hou,
16 | Maciej A. Mazurowski, Geoffrey D. Rubin M.D., and Joseph Y. Lo
17 | "Weakly supervised 3D classification of chest CT using aggregated multi-resolution deep segmentation features",
18 | Proc. SPIE 11314, Medical Imaging 2020: Computer-Aided Diagnosis, 1131408 (16 March 2020);
19 | https://doi.org/10.1117/12.2550857
20 | (*Authors with equal contribution to this work.)
21 | ```
22 | 
23 | ## Directories and Files
24 | * i) DyFA -|--> Dynamic Feature Aggregation model and training script.
25 | ```ruby
26 | a) config.py |-- Configuration file to train the DyFA model
27 | b) DenseVnet3D.py |-- 3D implementation of DenseVNet (Segmentation Module)
28 | c) DyFA_Model.py |-- DyFA model (Segmentation + Classification Modules)
29 | d) loss_funnction_And_matrics |-- Loss functions.
30 | e) resume_training_using_check_point |-- Training script
31 | f) tfrecords_utilities |-- Tfrecords decoding functions
32 | ```
33 | * ii) StFA -|--> Static Feature Aggregation model and training script.
34 | ```ruby
35 | a) config.py |-- Configuration file to train the StFA model
36 | b) DenseVnet3D.py |-- 3D implementation of DenseVNet (Segmentation Module)
37 | c) StFA_Model.py |-- StFA model (Segmentation + Classification Modules)
38 | d) loss_funnction_And_matrics |-- Loss functions.
39 | e) resume_training_using_check_point |-- Training script
40 | f) tfrecords_utilities |-- Tfrecords decoding functions
41 | ```
42 | * iii) figure -|--> Figures used in this repo
43 | * iv) SPIE_presentation -|--> SPIE presentation
44 | 
45 | 
46 | ## How to run
47 | 
48 | To run the model, all you need to do is configure `config.py` to match your requirements
and use the command:
49 | 
50 | * `python resume_training_using_check_point.py`
51 | 
52 | * `config.py`
53 | ```ruby
54 | import tensorflow as tf
55 | from loss_funnction_And_matrics import*
56 | import math
57 | ###---Number-of-GPU
58 | NUM_OF_GPU=2
59 | #["gpu:1","gpu:2","gpu:3"]
60 | DISTRIIBUTED_STRATEGY_GPUS=["gpu:0","gpu:1"]
61 | ###-----SEGMENTATION----###
62 | SEGMENTATION_MODEL_PATH='/Path/of/the/Segmentation Module/weight/Model.h5'
63 | SEGMENTATION_NUM_OF_CLASSES=31
64 | #####-----Configure DenseVnet3D---##########
65 | SEG_NUMBER_OF_CLASSES=31
66 | SEG_INPUT_PATCH_SIZE=(128,160,160, 1)
67 | NUM_DENSEBLOCK_EACH_RESOLUTION=(4, 8, 16)
68 | NUM_OF_FILTER_EACH_RESOLUTION=(12,24,24)
69 | DILATION_RATE=(5, 10, 10)
70 | DROPOUT_RATE=0.25
71 | ###----Resume-Training
72 | '''
73 | if you want to resume training from saved weights, set
74 | RESUME_TRAINING=1
75 | '''
76 | RESUME_TRAINING=0
77 | RESUME_TRAIING_MODEL='/Path/of/the/model/weight/Model.h5'
78 | TRAINING_INITIAL_EPOCH=0
79 | ##Network Configuration
80 | NUMBER_OF_CLASSES=5
81 | INPUT_PATCH_SIZE=(128,160,160, 1)
82 | TRAIN_NUM_RES_UNIT=3
83 | TRAIN_NUM_FILTERS=(16, 32, 64, 128)
84 | TRAIN_STRIDES=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2))
85 | TRAIN_CLASSIFY_ACTICATION=tf.nn.relu6
86 | TRAIN_KERNAL_INITIALIZER=tf.keras.initializers.VarianceScaling(distribution='uniform')
87 | ##Training Hyper-Parameters
88 | TRAIN_CLASSIFY_LEARNING_RATE =1e-4
89 | TRAIN_CLASSIFY_LOSS=Weighted_BCTL
90 | OPTIMIZER=tf.keras.optimizers.Adam(lr=TRAIN_CLASSIFY_LEARNING_RATE,epsilon=1e-5)
91 | TRAIN_CLASSIFY_METRICS=tf.keras.metrics.AUC()
92 | BATCH_SIZE=12
93 | TRAINING_STEP_PER_EPOCH=math.ceil((3514)/BATCH_SIZE)
94 | VALIDATION_STEP=math.ceil((759)/BATCH_SIZE)
95 | TRAING_EPOCH=300
96 | NUMBER_OF_PARALLEL_CALL=6
97 | PARSHING=3*BATCH_SIZE
98 | #--Callbacks-----
99 | ModelCheckpoint_MOTITOR='val_loss'
100 | TRAINING_SAVE_MODEL_PATH='/Path/to/save/model/weights/'
101 | TRAINING_CSV='DyFA_61FC1X1_April17_2020.csv'
102 | LOG_NAME="Log_DyFA_61FC1X1_April17_2020"
103 | MODEL_SAVING_NAME="DyFAModel61FC1X1_{val_loss:.2f}_{epoch}.h5"
104 | ####
105 | TRAINING_TF_RECORDS='/Training/tfrecords/path/'
106 | VALIDATION_TF_RECORDS='/Val/tfrecords/path/'
107 | ```
108 | 
109 | ## Multi-label Data Statistics
110 | ![Multi-label Data Statistics](https://github.com/fitushar/Weakly-Supervised-3D-Classification-of-Chest-CT-using-Aggregated-Multi-Resolution-Deep-Segmentation-/blob/master/figure/dataset.png)
111 | 
112 | ## Results
113 | ![Classification Results](https://github.com/fitushar/Weakly-Supervised-3D-Classification-of-Chest-CT-using-Aggregated-Multi-Resolution-Deep-Segmentation-/blob/master/figure/results.png)
114 | 
115 | 
116 | 

--------------------------------------------------------------------------------
/SPIE_presentation.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fitushar/WeaklySupervised-3D-Classification-of-Chest-CT-using-Aggregated-MultiResolution-Segmentation-Feature/30975d90c8c7f84e498e8f54746c5b71b535d9d3/SPIE_presentation.pptx

--------------------------------------------------------------------------------
/StFA/DenseVnet3D.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | import tensorflow as tf
5 | 
6 | 
7 | 
8 | ##########---tf bilinear UpSampling3D
9 | def up_sampling(input_tensor,
scale): 10 | net = tf.keras.layers.TimeDistributed(tf.keras.layers.UpSampling2D(size=(scale, scale), interpolation='bilinear'))(input_tensor) 11 | net = tf.keras.layers.Permute((2, 1, 3, 4))(net) # (B, z, H, W, C) -> (B, H, z, w, c) 12 | net = tf.keras.layers.TimeDistributed(tf.keras.layers.UpSampling2D(size=(scale, 1), interpolation='bilinear'))(net) 13 | net = tf.keras.layers.Permute((2, 1, 3, 4))(net) # (B, z, H, W, C) -> (B, H, z, w, c) 14 | return net 15 | 16 | #######-----Bottleneck 17 | def Bottleneck(x, nb_filter, increase_factor=4., weight_decay=1e-4): 18 | inter_channel = int(nb_filter * increase_factor) 19 | x = tf.keras.layers.Conv3D(inter_channel, (1, 1, 1), 20 | kernel_initializer='he_normal', 21 | padding='same', 22 | use_bias=False, 23 | kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x) 24 | x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x) 25 | x = tf.nn.relu6(x) 26 | return x 27 | 28 | #####------------>>> Convolutional Block 29 | def conv_block(input, nb_filter, kernal_size=(3, 3, 3), dilation_rate=1, 30 | bottleneck=False, dropout_rate=None, weight_decay=1e-4): 31 | ''' Apply BatchNorm, Relu, 3x3X3 Conv3D, optional bottleneck block and dropout 32 | Args: 33 | input: Input tensor 34 | nb_filter: number of filters 35 | bottleneck: add bottleneck block 36 | dropout_rate: dropout rate 37 | weight_decay: weight decay factor 38 | Returns: tensor with batch_norm, relu and convolution3D added (optional bottleneck) 39 | ''' 40 | 41 | 42 | x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(input) 43 | x = tf.nn.relu6(x) 44 | 45 | if bottleneck: 46 | inter_channel = nb_filter # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua 47 | x = tf.keras.layers.Conv3D(inter_channel, (1, 1, 1), 48 | kernel_initializer='he_normal', 49 | padding='same', 50 | use_bias=False, 51 | kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x) 52 | x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x) 53 | x = tf.nn.relu6(x) 54 | 55 | x = tf.keras.layers.Conv3D(nb_filter, kernal_size, 56 | dilation_rate=dilation_rate, 57 | kernel_initializer='he_normal', 58 | padding='same', 59 | use_bias=False)(x) 60 | if dropout_rate: 61 | x = tf.keras.layers.SpatialDropout3D(dropout_rate)(x) 62 | return x 63 | 64 | ##--------------------DenseBlock-------#### 65 | def dense_block(x, nb_layers, growth_rate, kernal_size=(3, 3, 3), 66 | dilation_list=None, 67 | bottleneck=True, dropout_rate=None, weight_decay=1e-4, 68 | return_concat_list=False): 69 | ''' Build a dense_block where the output of each conv_block is fed to subsequent ones 70 | Args: 71 | x: input tensor 72 | nb_layers: the number of layers of conv_block to append to the model. 
73 |         growth_rate: number of filters added by each conv_block
74 |         kernal_size: convolution kernel size
75 |         dilation_list: dilation rate(s) for the conv_blocks
76 |         bottleneck: add bottleneck blocks
77 |         dropout_rate: dropout rate
78 |         weight_decay: weight decay factor
79 |         return_concat_list: return the list of feature maps along with the actual output
80 |     Returns: tensor with nb_layers of conv_block appended
81 |     '''
82 | 
83 |     if dilation_list is None:
84 |         dilation_list = [1] * nb_layers
85 |     elif type(dilation_list) is int:
86 |         dilation_list = [dilation_list] * nb_layers
87 |     else:
88 |         if len(dilation_list) != nb_layers:
89 |             raise ValueError('the length of dilation_list should be equal to nb_layers (%d)' % nb_layers)
90 | 
91 |     x_list = [x]
92 | 
93 |     for i in range(nb_layers):
94 |         cb = conv_block(x, growth_rate, kernal_size, dilation_list[i],
95 |                         bottleneck, dropout_rate, weight_decay)
96 |         x_list.append(cb)
97 |         if i == 0:
98 |             x = cb
99 |         else:
100 |             x = tf.keras.layers.concatenate([x, cb], axis=-1)
101 | 
102 |     if return_concat_list:
103 |         return x, x_list
104 |     else:
105 |         return x
106 | 
107 | ###---------transition_block
108 | def transition_block(input, nb_filter, compression=1.0, weight_decay=1e-4,
109 |                      pool_kernal=(3, 3, 3), pool_strides=(2, 2, 2)):
110 |     ''' Apply BatchNorm, Relu and a 1x1x1 Conv3D with optional compression, then AveragePooling3D
111 |     Args:
112 |         input: input tensor
113 |         nb_filter: number of filters
114 |         compression: calculated as 1 - reduction. Reduces the number of feature maps
115 |                      in the transition block.
116 |         pool_kernal: pooling kernel size
117 |         weight_decay: weight decay factor
118 |     Returns: keras tensor, after applying batch_norm, relu-conv and average pooling
119 |     '''
120 | 
121 | 
122 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(input)
123 |     x = tf.nn.relu6(x)
124 |     x = tf.keras.layers.Conv3D(int(nb_filter * compression), (1, 1, 1),
125 |                                kernel_initializer='he_normal',
126 |                                padding='same',
127 |                                use_bias=False,
128 |                                kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
129 |     x = tf.keras.layers.AveragePooling3D(pool_kernal, strides=pool_strides)(x)
130 | 
131 |     return x
132 | 
133 | ###---Transition up block
134 | def transition_up_block(input, nb_filters, compression=1.0,
135 |                         kernal_size=(3, 3, 3), pool_strides=(2, 2, 2),
136 |                         type='deconv', weight_decay=1E-4):
137 |     ''' Upscaling block (factor = 2)
138 |     Args:
139 |         input: tensor
140 |         nb_filters: number of filters
141 |         type: 'upsampling' or 'deconv'. Determines the type of upsampling performed
142 |         weight_decay: weight decay factor
143 |     Returns: keras tensor, after applying the upsampling operation.
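    Note: 'upsampling' uses a parameter-free resize followed by a learned 1x1x1
    projection, while the default 'deconv' learns the upsampling itself through
    a strided Conv3DTranspose.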
144 |     '''
145 | 
146 |     if type == 'upsampling':
147 |         x = tf.keras.layers.UpSampling3D(size=pool_strides)(input)  # UpSampling3D is nearest-neighbour and has no 'interpolation' argument; the factor comes from pool_strides
148 |         x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
149 |         x = tf.nn.relu6(x)
150 |         x = tf.keras.layers.Conv3D(int(nb_filters * compression), (1, 1, 1),
151 |                                    kernel_initializer='he_normal',
152 |                                    padding='same',
153 |                                    use_bias=False,
154 |                                    kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(x)
155 | 
156 |     else:
157 |         x = tf.keras.layers.Conv3DTranspose(int(nb_filters * compression),
158 |                                             kernal_size,
159 |                                             strides=pool_strides,
160 |                                             activation='relu',
161 |                                             padding='same',
162 |                                             kernel_initializer='he_normal', kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(input)
163 | 
164 |     return x
165 | 
166 | 
167 | 
168 | def DenseVnet3D(inputs,
169 |                 nb_classes=1,
170 |                 encoder_nb_layers=(5, 8, 8),
171 |                 growth_rate=(4, 8, 12),
172 |                 dilation_list=(5, 3, 1),
173 |                 dropout_rate=0.25,
174 |                 weight_decay=1e-4,
175 |                 init_conv_filters=24):
176 |     """ 3D DenseVNet implementation by f.i.tushar, TF 2.0.
177 |     This is a TensorFlow 2.0 implementation of the paper:
178 |     Gibson et al., "Automatic multi-organ segmentation on abdominal CT with
179 |     dense V-networks", 2018.
180 | 
181 |     Reference implementations: i) https://github.com/baibaidj/vision4med/blob/5c23f57c2836bfabd7bd95a024a0a0b776b181b5/nets/DenseVnet.py
182 |     ii) https://niftynet.readthedocs.io/en/dev/_modules/niftynet/network/dense_vnet.html#DenseVNet
183 | 
184 |     Input
185 |     |
186 |     --[ DFS ]-----------------------[ Conv ]------------[ Conv ]------[+]-->
187 |     | | | |
188 |     -----[ DFS ]---------------[ Conv ]------ | |
189 |     | | |
190 |     -----[ DFS ]-------[ Conv ]--------- |
191 |     [ Prior ]---
192 |     Args:
193 |         inputs: input tensor; shape should be (Batch, D, H, W, channels)
194 |         nb_classes: number of classes
195 |         encoder_nb_layers: number of conv_block layers in each dense_block
196 |         growth_rate: number of filters per conv_block in each dense_block
197 |         dilation_list: dilation rate at each level
198 |         dropout_rate: dropout rate
199 |         weight_decay: weight decay
200 |     Returns: the segmentation prediction for the given input shape
201 |     """
202 |     #--|Getting the input
203 |     img_input = inputs
204 |     input_shape = tf.shape(img_input) # Input shape
205 |     nb_dense_block = len(encoder_nb_layers) # Number of dense blocks
206 | 
207 |     # Initial convolution
208 |     x = tf.keras.layers.Conv3D(init_conv_filters, (5, 5, 5),
209 |                                strides=2,
210 |                                kernel_initializer='he_normal',
211 |                                padding='same',
212 |                                name='initial_conv3D',
213 |                                use_bias=False,
214 |                                kernel_regularizer=tf.keras.regularizers.l2(weight_decay))(img_input)
215 |     x = tf.keras.layers.BatchNormalization(epsilon=1.1e-5)(x)
216 |     x = tf.nn.relu6(x)
217 | 
218 |     # Keep a skip list for the later concatenation
219 |     skip_list = []
220 | 
221 |     # Add dense blocks
222 |     for block_idx in range(nb_dense_block):
223 |         '''
224 |         |--Input for dense_block is as follows
225 |         |---#x=Input,
226 |             #encoder_nb_layers[block_idx]=Number of layers in this dense_block
227 |             #growth_rate[block_idx]= Number of filters in this dense_block
228 |             #dilation_list= Dilation rate.
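            After each dense block the output is recorded in skip_list and then
            AveragePooling3D halves the spatial size, so with the stride-2 stem the
            three skips sit at 1/2, 1/4 and 1/8 of the input resolution; that is why
            they are upsampled by x1, x2 and x4 below before being concatenated.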
229 | 
230 |         '''
231 |         x = dense_block(x, encoder_nb_layers[block_idx],
232 |                         growth_rate[block_idx],
233 |                         kernal_size=(3, 3, 3),
234 |                         dilation_list=dilation_list[block_idx],
235 |                         dropout_rate=dropout_rate,
236 |                         weight_decay=weight_decay,
237 |                         )
238 | 
239 |         # Skip connection
240 |         skip_list.append(x)
241 |         #Pooling
242 |         x = tf.keras.layers.AveragePooling3D((2, 2, 2))(x)
243 |         # x = __transition_block(x, nb_filter,compression=compression,weight_decay=weight_decay,pool_kernal=(3, 3, 3),pool_strides=(2, 2, 2))
244 | 
245 | 
246 |     ## Convolution on the third-resolution features, then upsample.
247 |     x_level3 = conv_block(skip_list[-1], growth_rate[2], bottleneck=True, dropout_rate=dropout_rate)
248 |     x_level3 = up_sampling(x_level3, scale=4)
249 |     # x_level3 = UpSampling3D(size = (4,4,4))(x_level3)
250 | 
251 |     ## Convolution on the second-resolution features, then upsample.
252 |     x_level2 = conv_block(skip_list[-2], growth_rate[1], bottleneck=True, dropout_rate=dropout_rate)
253 |     x_level2 = up_sampling(x_level2, scale=2)
254 |     # x_level2 = UpSampling3D(size=(2, 2, 2))(x_level2)
255 | 
256 |     ## Convolution on the first-resolution features
257 |     x_level1 = conv_block(skip_list[-3], growth_rate[0], bottleneck=True, dropout_rate=dropout_rate)
258 |     #x_level1 = up_sampling(x_level1, scale=2)
259 |     x = tf.keras.layers.Concatenate()([x_level3, x_level2, x_level1])
260 | 
261 |     ###--Final convolution---
262 |     x = conv_block(x, 24, bottleneck=False, dropout_rate=dropout_rate)
263 |     ##----Upsample to the final output resolution----#####
264 |     x = up_sampling(x, scale=2)
265 | 
266 |     ####------Prediction---------------###
267 |     if nb_classes == 1:
268 |         x = tf.keras.layers.Conv3D(nb_classes, 1, activation='sigmoid', padding='same', use_bias=False)(x)
269 |     elif nb_classes > 1:
270 |         x = tf.keras.layers.Conv3D(nb_classes, 1, activation='softmax', padding='same', use_bias=False)(x)
271 |     #x = tf.argmax(x, axis=-1)
272 | 
273 | 
274 |     # Create model.
275 |     model = tf.keras.Model(img_input, x, name='DenseVnet3D')
276 |     return model
277 | '''
278 | ###################----Demo usage----#############
279 | INPUT_PATCH_SIZE=[384,192,192,1]
280 | NUMBER_OF_CLASSES=1
281 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT')
282 | 
283 | #Model_3D=DenseVnet3D(inputs,nb_classes=1,encoder_nb_layers=(5, 8, 8),growth_rate=(4, 8, 12),dilation_list=(5, 3, 1))
284 | Model_3D=DenseVnet3D(inputs,nb_classes=1,encoder_nb_layers=(4, 8, 16),growth_rate=(12,24,24),dilation_list=(5, 10, 10),dropout_rate=0.25)
285 | Model_3D.summary()
286 | tf.keras.utils.plot_model(Model_3D, 'DenseVnet3D.png',show_shapes=True)
287 | '''
288 | 
-------------------------------------------------------------------------------- /StFA/Preprocessing_utlities.py: --------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | from __future__ import print_function
3 | from __future__ import division
4 | from __future__ import absolute_import
5 | import numpy as np
6 | from scipy.ndimage.interpolation import map_coordinates
7 | from scipy.ndimage.filters import gaussian_filter
8 | import tensorflow as tf
9 | 
10 | 
11 | def whitening(image):
12 |     """Whitening. Normalises image to zero mean and unit variance."""
13 | 
14 |     image = image.astype(np.float32)
15 | 
16 |     mean = np.mean(image)
17 |     std = np.std(image)
18 | 
19 |     if std > 0:
20 |         ret = (image - mean) / std
21 |     else:
22 |         ret = image * 0.
23 |     return ret
24 | 
25 | 
26 | def normalise_zero_one(image):
27 |     """Image normalisation.
Normalises image to fit [0, 1] range.""" 28 | 29 | image = image.astype(np.float32) 30 | 31 | minimum = np.min(image) 32 | maximum = np.max(image) 33 | 34 | if maximum > minimum: 35 | ret = (image - minimum) / (maximum - minimum) 36 | else: 37 | ret = image * 0. 38 | return ret 39 | 40 | 41 | def normalise_one_one(image): 42 | """Image normalisation. Normalises image to fit [-1, 1] range.""" 43 | 44 | ret = normalise_zero_one(image) 45 | ret *= 2. 46 | ret -= 1. 47 | return ret 48 | 49 | 50 | def flip(imagelist, axis=1): 51 | """Randomly flip spatial dimensions 52 | Args: 53 | imagelist (np.ndarray or list or tuple): image(s) to be flipped 54 | axis (int): axis along which to flip the images 55 | Returns: 56 | np.ndarray or list or tuple: same as imagelist but randomly flipped 57 | along axis 58 | """ 59 | 60 | # Check if a single image or a list of images has been passed 61 | was_singular = False 62 | if isinstance(imagelist, np.ndarray): 63 | imagelist = [imagelist] 64 | was_singular = True 65 | 66 | # With a probility of 0.5 flip the image(s) across `axis` 67 | do_flip = np.random.random(1) 68 | if do_flip > 0.5: 69 | for i in range(len(imagelist)): 70 | imagelist[i] = np.flip(imagelist[i], axis=axis) 71 | if was_singular: 72 | return imagelist[0] 73 | return imagelist 74 | 75 | 76 | def add_gaussian_offset(image, sigma=0.1): 77 | """ 78 | Add Gaussian offset to an image. Adds the offset to each channel 79 | independently. 80 | Args: 81 | image (np.ndarray): image to add noise to 82 | sigma (float): stddev of the Gaussian distribution to generate noise 83 | from 84 | Returns: 85 | np.ndarray: same as image but with added offset to each channel 86 | """ 87 | 88 | offsets = np.random.normal(0, sigma, ([1] * (image.ndim - 1) + [image.shape[-1]])) 89 | image += offsets 90 | return image 91 | 92 | 93 | def add_gaussian_noise(image, sigma=0.05): 94 | """ 95 | Add Gaussian noise to an image 96 | Args: 97 | image (np.ndarray): image to add noise to 98 | sigma (float): stddev of the Gaussian distribution to generate noise 99 | from 100 | Returns: 101 | np.ndarray: same as image but with added offset to each channel 102 | """ 103 | 104 | image += np.random.normal(0, sigma, image.shape) 105 | return image 106 | 107 | 108 | def elastic_transform(image, alpha, sigma): 109 | """ 110 | Elastic deformation of images as described in [1]. 111 | [1] Simard, Steinkraus and Platt, "Best Practices for Convolutional 112 | Neural Networks applied to Visual Document Analysis", in Proc. of the 113 | International Conference on Document Analysis and Recognition, 2003. 
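    In short: a random displacement field is smoothed with a Gaussian of width
    `sigma`, scaled by `alpha`, and the image is resampled at the displaced
    coordinates.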
114 | Based on gist https://gist.github.com/erniejunior/601cdf56d2b424757de5 115 | Args: 116 | image (np.ndarray): image to be deformed 117 | alpha (list): scale of transformation for each dimension, where larger 118 | values have more deformation 119 | sigma (list): Gaussian window of deformation for each dimension, where 120 | smaller values have more localised deformation 121 | Returns: 122 | np.ndarray: deformed image 123 | """ 124 | 125 | assert len(alpha) == len(sigma), \ 126 | "Dimensions of alpha and sigma are different" 127 | 128 | channelbool = image.ndim - len(alpha) 129 | out = np.zeros((len(alpha) + channelbool, ) + image.shape) 130 | 131 | # Generate a Gaussian filter, leaving channel dimensions zeroes 132 | for jj in range(len(alpha)): 133 | array = (np.random.rand(*image.shape) * 2 - 1) 134 | out[jj] = gaussian_filter(array, sigma[jj], 135 | mode="constant", cval=0) * alpha[jj] 136 | 137 | # Map mask to indices 138 | shapes = list(map(lambda x: slice(0, x, None), image.shape)) 139 | grid = np.broadcast_arrays(*np.ogrid[shapes]) 140 | indices = list(map((lambda x: np.reshape(x, (-1, 1))), grid + np.array(out))) 141 | 142 | # Transform image based on masked indices 143 | transformed_image = map_coordinates(image, indices, order=0, 144 | mode='reflect').reshape(image.shape) 145 | 146 | return transformed_image 147 | 148 | def extract_class_balanced_example_array(image, 149 | label, 150 | example_size=[1, 64, 64], 151 | n_examples=1, 152 | classes=2, 153 | class_weights=None): 154 | """Extract training examples from an image (and corresponding label) subject 155 | to class balancing. Returns an image example array and the 156 | corresponding label array. 157 | 158 | Args: 159 | image (np.ndarray): image to extract class-balanced patches from 160 | label (np.ndarray): labels to use for balancing the classes 161 | example_size (list or tuple): shape of the patches to extract 162 | n_examples (int): number of patches to extract in total 163 | classes (int or list or tuple): number of classes or list of classes 164 | to extract 165 | 166 | Returns: 167 | np.ndarray, np.ndarray: class-balanced patches extracted from full 168 | images with the shape [batch, example_size..., image_channels] 169 | """ 170 | assert image.shape[:-1] == label.shape, 'Image and label shape must match' 171 | assert image.ndim - 1 == len(example_size), \ 172 | 'Example size doesnt fit image size' 173 | #assert all([i_s >= e_s for i_s, e_s in zip(image.shape, example_size)]), \ 174 | #'Image must be larger than example shape' 175 | rank = len(example_size) 176 | 177 | 178 | 179 | if isinstance(classes, int): 180 | classes = tuple(range(classes)) 181 | n_classes = len(classes) 182 | 183 | 184 | if class_weights is None: 185 | n_ex_per_class = np.ones(n_classes).astype(int) * int(np.round(n_examples / n_classes)) 186 | else: 187 | assert len(class_weights) == n_classes, \ 188 | 'Class_weights must match number of classes' 189 | class_weights = np.array(class_weights) 190 | n_ex_per_class = np.round((class_weights / class_weights.sum()) * n_examples).astype(int) 191 | 192 | # Compute an example radius to define the region to extract around a 193 | # center location 194 | ex_rad = np.array(list(zip(np.floor(np.array(example_size) / 2.0), 195 | np.ceil(np.array(example_size) / 2.0))), 196 | dtype=np.int) 197 | 198 | class_ex_images = [] 199 | class_ex_lbls = [] 200 | min_ratio = 1. 
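    # For each requested class: gather every voxel carrying that label, sample up
    # to n_ex_per_class[c_idx] of them as patch centres, clamp the centres so a
    # full patch fits inside the volume, then slice matching patches out of both
    # image and label.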
201 | for c_idx, c in enumerate(classes): 202 | # Get valid, random center locations belonging to that class 203 | idx = np.argwhere(label == c) 204 | 205 | ex_images = [] 206 | ex_lbls = [] 207 | 208 | if len(idx) == 0 or n_ex_per_class[c_idx] == 0: 209 | class_ex_images.append([]) 210 | class_ex_lbls.append([]) 211 | continue 212 | 213 | # Extract random locations 214 | r_idx_idx = np.random.choice(len(idx), 215 | size=min(n_ex_per_class[c_idx], len(idx)), 216 | replace=False).astype(int) 217 | r_idx = idx[r_idx_idx] 218 | 219 | # Shift the random to valid locations if necessary 220 | r_idx = np.array( 221 | [np.array([max(min(r[dim], image.shape[dim] - ex_rad[dim][1]), 222 | ex_rad[dim][0]) for dim in range(rank)]) 223 | for r in r_idx]) 224 | 225 | for i in range(len(r_idx)): 226 | # Extract class-balanced examples from the original image 227 | slicer = [slice(r_idx[i][dim] - ex_rad[dim][0], r_idx[i][dim] + ex_rad[dim][1]) for dim in range(rank)] 228 | 229 | ex_image = image[slicer][np.newaxis, :] 230 | 231 | ex_lbl = label[slicer][np.newaxis, :] 232 | 233 | # Concatenate them and return the examples 234 | ex_images = np.concatenate((ex_images, ex_image), axis=0) \ 235 | if (len(ex_images) != 0) else ex_image 236 | ex_lbls = np.concatenate((ex_lbls, ex_lbl), axis=0) \ 237 | if (len(ex_lbls) != 0) else ex_lbl 238 | 239 | class_ex_images.append(ex_images) 240 | class_ex_lbls.append(ex_lbls) 241 | 242 | ratio = n_ex_per_class[c_idx] / len(ex_images) 243 | min_ratio = ratio if ratio < min_ratio else min_ratio 244 | 245 | indices = np.floor(n_ex_per_class * min_ratio).astype(int) 246 | 247 | ex_images = np.concatenate([cimage[:idxs] for cimage, idxs in zip(class_ex_images, indices) 248 | if len(cimage) > 0], axis=0) 249 | ex_lbls = np.concatenate([clbl[:idxs] for clbl, idxs in zip(class_ex_lbls, indices) 250 | if len(clbl) > 0], axis=0) 251 | 252 | return ex_images, ex_lbls 253 | 254 | def resize_image_with_crop_or_pad(image, img_size=(64, 64, 64), **kwargs): 255 | """Image resizing. Resizes image by cropping or padding dimension 256 | to fit specified size. 
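    Each dimension is treated independently: dimensions larger than the target
    are centre-cropped, smaller ones are padded on both sides via np.pad.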
257 | Args: 258 | image (np.ndarray): image to be resized 259 | img_size (list or tuple): new image size 260 | kwargs (): additional arguments to be passed to np.pad 261 | Returns: 262 | np.ndarray: resized image 263 | """ 264 | 265 | assert isinstance(image, (np.ndarray, np.generic)) 266 | assert (image.ndim - 1 == len(img_size) or image.ndim == len(img_size)), \ 267 | 'Example size doesnt fit image size' 268 | 269 | # Get the image dimensionality 270 | rank = len(img_size) 271 | 272 | # Create placeholders for the new shape 273 | from_indices = [[0, image.shape[dim]] for dim in range(rank)] 274 | to_padding = [[0, 0] for dim in range(rank)] 275 | 276 | slicer = [slice(None)] * rank 277 | 278 | # For each dimensions find whether it is supposed to be cropped or padded 279 | for i in range(rank): 280 | if image.shape[i] < img_size[i]: 281 | to_padding[i][0] = (img_size[i] - image.shape[i]) // 2 282 | to_padding[i][1] = img_size[i] - image.shape[i] - to_padding[i][0] 283 | else: 284 | from_indices[i][0] = int(np.floor((image.shape[i] - img_size[i]) / 2.)) 285 | from_indices[i][1] = from_indices[i][0] + img_size[i] 286 | 287 | # Create slicer object to crop or leave each dimension 288 | slicer[i] = slice(from_indices[i][0], from_indices[i][1]) 289 | 290 | # Pad the cropped image to extend the missing dimension 291 | return np.pad(image[slicer], to_padding, **kwargs) 292 | 293 | 294 | def extract_random_example_array(image_list,example_size=[1, 64, 64],n_examples=1): 295 | 296 | """Randomly extract training examples from image (and a corresponding label). 297 | Returns an image example array and the corresponding label array. 298 | Args: 299 | image_list (np.ndarray or list or tuple): image(s) to extract random 300 | patches from 301 | example_size (list or tuple): shape of the patches to extract 302 | n_examples (int): number of patches to extract in total 303 | Returns: 304 | np.ndarray, np.ndarray: class-balanced patches extracted from full 305 | images with the shape [batch, example_size..., image_channels] 306 | """ 307 | 308 | assert n_examples > 0 309 | 310 | was_singular = False 311 | if isinstance(image_list, np.ndarray): 312 | image_list = [image_list] 313 | was_singular = True 314 | 315 | assert all([i_s >= e_s for i_s, e_s in zip(image_list[0].shape, example_size)]), \ 316 | 'Image must be bigger than example shape' 317 | assert (image_list[0].ndim - 1 == len(example_size) or image_list[0].ndim == len(example_size)), \ 318 | 'Example size doesnt fit image size' 319 | 320 | for i in image_list: 321 | if len(image_list) > 1: 322 | assert (i.ndim - 1 == image_list[0].ndim or i.ndim == image_list[0].ndim or i.ndim + 1 == image_list[0].ndim),\ 323 | 'Example size doesnt fit image size' 324 | 325 | assert all([i0_s == i_s for i0_s, i_s in zip(image_list[0].shape, i.shape)]), \ 326 | 'Image shapes must match' 327 | 328 | rank = len(example_size) 329 | 330 | # Extract random examples from image and label 331 | valid_loc_range = [image_list[0].shape[i] - example_size[i] for i in range(rank)] 332 | 333 | rnd_loc = [np.random.randint(valid_loc_range[dim], size=n_examples) 334 | if valid_loc_range[dim] > 0 335 | else np.zeros(n_examples, dtype=int) for dim in range(rank)] 336 | 337 | examples = [[]] * len(image_list) 338 | for i in range(n_examples): 339 | slicer = [slice(rnd_loc[dim][i], rnd_loc[dim][i] + example_size[dim]) 340 | for dim in range(rank)] 341 | 342 | for j in range(len(image_list)): 343 | ex_image = image_list[j][slicer][np.newaxis] 344 | # Concatenate and return the examples 
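            # examples[j] grows one patch at a time: [np.newaxis] above adds a
            # batch axis, and np.concatenate stacks successive patches along it.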
345 | examples[j] = np.concatenate((examples[j], ex_image), axis=0) \ 346 | if (len(examples[j]) != 0) else ex_image 347 | 348 | if was_singular: 349 | return examples[0] 350 | return examples 351 | -------------------------------------------------------------------------------- /StFA/StFA_Model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.layers import Conv3D, Input, MaxPooling3D, Dropout, concatenate, UpSampling3D 4 | import tensorflow as tf 5 | from config import* 6 | from loss_funnction_And_matrics import* 7 | import numpy as np 8 | from DenseVnet3D import DenseVnet3D 9 | #from Unet3D import Unet3D 10 | 11 | ####----Residual Blocks used for Resnet3D 12 | def Residual_Block(inputs, 13 | out_filters, 14 | kernel_size=(3, 3, 3), 15 | strides=(1, 1, 1), 16 | use_bias=False, 17 | activation=tf.nn.relu6, 18 | kernel_initializer=tf.keras.initializers.VarianceScaling(distribution='uniform'), 19 | bias_initializer=tf.zeros_initializer(), 20 | kernel_regularizer=tf.keras.regularizers.l2(l=0.001), 21 | bias_regularizer=None, 22 | **kwargs): 23 | 24 | 25 | conv_params={'padding': 'same', 26 | 'use_bias': use_bias, 27 | 'kernel_initializer': kernel_initializer, 28 | 'bias_initializer': bias_initializer, 29 | 'kernel_regularizer': kernel_regularizer, 30 | 'bias_regularizer': bias_regularizer} 31 | 32 | in_filters = inputs.get_shape().as_list()[-1] 33 | x=inputs 34 | orig_x=x 35 | 36 | ##building 37 | # Adjust the strided conv kernel size to prevent losing information 38 | k = [s * 2 if s > 1 else k for k, s in zip(kernel_size, strides)] 39 | 40 | if np.prod(strides) != 1: 41 | orig_x = tf.keras.layers.MaxPool3D(pool_size=strides,strides=strides,padding='valid')(orig_x) 42 | 43 | ##sub-unit-0 44 | x=tf.keras.layers.BatchNormalization()(x) 45 | x=activation(x) 46 | x=tf.keras.layers.Conv3D(filters=out_filters,kernel_size=k,strides=strides,**conv_params)(x) 47 | 48 | ##sub-unit-1 49 | x=tf.keras.layers.BatchNormalization()(x) 50 | x=activation(x) 51 | x=tf.keras.layers.Conv3D(filters=out_filters,kernel_size=kernel_size,strides=(1,1,1),**conv_params)(x) 52 | 53 | # Handle differences in input and output filter sizes 54 | if in_filters < out_filters: 55 | orig_x = tf.pad(tensor=orig_x,paddings=[[0, 0]] * (len(x.get_shape().as_list()) - 1) + [[ 56 | int(np.floor((out_filters - in_filters) / 2.)), 57 | int(np.ceil((out_filters - in_filters) / 2.))]]) 58 | 59 | elif in_filters > out_filters: 60 | orig_x = tf.keras.layers.Conv3D(filters=out_filters,kernel_size=kernel_size,strides=(1,1,1),**conv_params)(orig_x) 61 | 62 | x += orig_x 63 | return x 64 | 65 | 66 | 67 | ## Resnet----3D 68 | def Resnet3D(inputs, 69 | num_classes, 70 | num_res_units=TRAIN_NUM_RES_UNIT, 71 | filters=TRAIN_NUM_FILTERS, 72 | strides=TRAIN_STRIDES, 73 | use_bias=False, 74 | activation=TRAIN_CLASSIFY_ACTICATION, 75 | kernel_initializer=TRAIN_KERNAL_INITIALIZER, 76 | bias_initializer=tf.zeros_initializer(), 77 | kernel_regularizer=tf.keras.regularizers.l2(l=0.001), 78 | bias_regularizer=None, 79 | **kwargs): 80 | conv_params = {'padding': 'same', 81 | 'use_bias': use_bias, 82 | 'kernel_initializer': kernel_initializer, 83 | 'bias_initializer': bias_initializer, 84 | 'kernel_regularizer': kernel_regularizer, 85 | 'bias_regularizer': bias_regularizer} 86 | 87 | 88 | ##building 89 | k = [s * 2 if s > 1 else 3 for s in strides[0]] 90 | 91 | 92 | 
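    # With TRAIN_STRIDES[0] = (1, 1, 1) (see config.py) k above evaluates to a
    # plain 3x3x3 stem kernel; a strided first stage would instead get a doubled
    # kernel so that no voxel is skipped.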
#Input 93 | x = inputs 94 | #1st-convo 95 | x=tf.keras.layers.Conv3D(filters[0], k, strides[0], **conv_params)(x) 96 | 97 | for res_scale in range(1, len(filters)): 98 | x = Residual_Block( 99 | inputs=x, 100 | out_filters=filters[res_scale], 101 | strides=strides[res_scale], 102 | activation=activation, 103 | name='unit_{}_0'.format(res_scale)) 104 | for i in range(1, num_res_units): 105 | x = Residual_Block( 106 | inputs=x, 107 | out_filters=filters[res_scale], 108 | strides=(1, 1, 1), 109 | activation=activation, 110 | name='unit_{}_{}'.format(res_scale, i)) 111 | 112 | 113 | x=tf.keras.layers.BatchNormalization()(x) 114 | x=activation(x) 115 | #axis = tuple(range(len(x.get_shape().as_list())))[1:-1] 116 | #x = tf.reduce_mean(x, axis=axis, name='global_avg_pool') 117 | x=tf.keras.layers.GlobalAveragePooling3D()(x) 118 | x =tf.keras.layers.Dropout(0.5)(x) 119 | classifier=tf.keras.layers.Dense(units=num_classes,activation='sigmoid')(x) 120 | 121 | #model = tf.keras.Model(inputs=inputs, outputs=classifier) 122 | #model.compile(optimizer=Adam(lr=TRAIN_CLASSIFY_LEARNING_RATE), loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.AUC()]) 123 | 124 | return classifier 125 | 126 | ### Final Model 127 | def DyFAModel_WithUnet(Unet_Model_Path,Input_shape,num_classes_clf,num_classes_for_seg): 128 | 129 | ###----Loading Segmentation Module---### 130 | inputs = tf.keras.Input(shape=Input_shape, name='CT') 131 | model_3DUnet=Unet3D(inputs,num_classes_for_seg) 132 | 133 | #-| Loading the Best Segmentation Weight 134 | model_3DUnet.load_weights(Unet_Model_Path) 135 | #-| Making the Segmentation Model Non-Trainable 136 | model_3DUnet.trainable = False 137 | 138 | #--| Getting the Features from Different Resolutions 139 | f_r1=(model_3DUnet.get_layer('Feature_R1').output) 140 | f_r2=(model_3DUnet.get_layer('Feature_R2').output) 141 | f_r3=(model_3DUnet.get_layer('Feature_R3').output) 142 | f_r4=(model_3DUnet.get_layer('Feature_R4').output) 143 | #f_r5=(model_3DUnet.get_layer('Feature_R5').output) 144 | last_predict=(model_3DUnet.get_layer('conv3d_17').output) 145 | #-| Upsampling the lower Resolution FA 146 | up2=(UpSampling3D(size = (2,2,2))(f_r2)) 147 | up3=(UpSampling3D(size = (4,4,4))(f_r3)) 148 | up4=(UpSampling3D(size = (8,8,8))(f_r4)) 149 | #up5=(UpSampling3D(size = (16,16,16))(f_r5)) 150 | #-| Concatenate the FAs 151 | FA_concatination=concatenate([f_r1,up2,up3,up4,last_predict],axis=-1) 152 | 153 | #-|| DyFA- Pass the Concatinated Feature to 1x1x1 convolution to get a 1 channel Volume. 
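    #    This learned projection is what makes the aggregation "dynamic" (DyFA):
    #    how the multi-resolution features collapse into one channel is trained
    #    jointly with the classifier. Note that this Unet-based variant assumes a
    #    Unet3D module whose import is commented out at the top of this file.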
154 | DyFA=tf.keras.layers.Conv3D(1, 1, name='DyFA')(FA_concatination) 155 | 156 | #-|| Making a HxWxDx2 channel Input data for the DyFA Classification Model 157 | DyFA_INPUT=concatenate([DyFA,inputs],axis=-1) 158 | 159 | DyFA_Model_output=Resnet3D(DyFA_INPUT,num_classes=num_classes_clf) 160 | 161 | Final_DyFAmodel=tf.keras.Model(inputs=inputs, outputs=DyFA_Model_output) 162 | 163 | return Final_DyFAmodel 164 | 165 | 166 | def StFAModel_withDenseVnet(DenseVnet3D_Model_Path,Input_shape,num_classes_clf,num_classes_for_seg): 167 | 168 | ###----Loading Segmentation Module---### 169 | inputs = tf.keras.Input(shape=Input_shape, name='CT') 170 | model_3DDenseVnet=DenseVnet3D(inputs,nb_classes=SEG_NUMBER_OF_CLASSES,encoder_nb_layers=NUM_DENSEBLOCK_EACH_RESOLUTION,growth_rate=NUM_OF_FILTER_EACH_RESOLUTION,dilation_list=DILATION_RATE,dropout_rate=DROPOUT_RATE) 171 | #-| Loading the Best Segmentation Weight 172 | model_3DDenseVnet.load_weights(DenseVnet3D_Model_Path) 173 | model_3DDenseVnet.summary() 174 | #-| Making the Segmentation Model Non-Trainable 175 | model_3DDenseVnet.trainable = False 176 | #-| Getting the features 177 | f_60_192_96_96=(model_3DDenseVnet.get_layer('concatenate_25').output) 178 | last_predict=(model_3DDenseVnet.get_layer('conv3d_63').output) 179 | #-| Upsampling the lower Resolution FA 180 | upsampled_F=(UpSampling3D(size = (2,2,2))(f_60_192_96_96)) 181 | #-| Concatenate the FAs 182 | #FA_concatination=concatenate([upsampled_F,last_predict],axis=-1) #not using last layeroutput 183 | FA_concatination=tf.math.reduce_mean(upsampled_F,axis=-1) 184 | FA_concatination=(FA_concatination-tf.math.reduce_min(FA_concatination))/(tf.math.reduce_max(FA_concatination)-tf.math.reduce_min(FA_concatination)) 185 | FA_concatination=tf.expand_dims(FA_concatination, axis=-1) 186 | #-|| DyFA- Pass the Concatinated Feature to 1x1x1 convolution to get a 1 channel Volume. 
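    #    In this StFA (static) variant the learned 1x1x1 projection stays commented
    #    out: the upsampled features are instead collapsed by a fixed channel-wise
    #    mean and min-max normalised above, so the aggregation itself has no
    #    trainable parameters.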
187 | #DyFA=tf.keras.layers.Conv3D(1, 1, name='DyFA')(FA_concatination) 188 | 189 | #-|| Making a HxWxDx2 channel Input data for the DyFA Classification Model 190 | DyFA_INPUT=concatenate([FA_concatination,inputs],axis=-1) 191 | 192 | DyFA_Model_output=Resnet3D(DyFA_INPUT,num_classes=num_classes_clf) 193 | 194 | Final_DyFAmodel=tf.keras.Model(inputs=inputs, outputs=DyFA_Model_output) 195 | 196 | return Final_DyFAmodel 197 | -------------------------------------------------------------------------------- /StFA/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from loss_funnction_And_matrics import* 3 | import math 4 | ###---Number-of-GPU 5 | NUM_OF_GPU=1 6 | #["gpu:1","gpu:2","gpu:3"] 7 | DISTRIIBUTED_STRATEGY_GPUS=["gpu:0"] 8 | ###-----SEGMENATTION----### 9 | SEGMENTATION_MODEL_PATH='/image_data/Scripts/April_Model/DyFA_61FAvg_April17_2020/LungSEG_DenseVnet_2.60_4998.h5' 10 | SEGMENTATION_NUM_OF_CLASSES=31 11 | #####-----Configure DenseVnet3D---########## 12 | SEG_NUMBER_OF_CLASSES=31 13 | SEG_INPUT_PATCH_SIZE=(128,160,160, 1) 14 | NUM_DENSEBLOCK_EACH_RESOLUTION=(4, 8, 16) 15 | NUM_OF_FILTER_EACH_RESOLUTION=(12,24,24) 16 | DILATION_RATE=(5, 10, 10) 17 | DROPOUT_RATE=0.25 18 | ###----Resume-Training 19 | RESUME_TRAINING=1 20 | RESUME_TRAIING_MODEL='/image_data/Scripts/April_Model/DyFA_61FAvg_April17_2020/Model_DyFA_61FAvg_April17_2020/DyFAModel60FAvg_9.62_55.h5' 21 | TRAINING_INITIAL_EPOCH=55 22 | ##Network Configuration 23 | NUMBER_OF_CLASSES=5 24 | INPUT_PATCH_SIZE=(128,160,160, 1) 25 | TRAIN_NUM_RES_UNIT=3 26 | TRAIN_NUM_FILTERS=(16, 32, 64, 128) 27 | TRAIN_STRIDES=((1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)) 28 | TRAIN_CLASSIFY_ACTICATION=tf.nn.relu6 29 | TRAIN_KERNAL_INITIALIZER=tf.keras.initializers.VarianceScaling(distribution='uniform') 30 | ##Training Hyper-Parameter 31 | ##Training Hyper-Parameter 32 | TRAIN_CLASSIFY_LEARNING_RATE =1e-4 33 | TRAIN_CLASSIFY_LOSS=Weighted_BCTL 34 | OPTIMIZER=tf.keras.optimizers.Adam(lr=TRAIN_CLASSIFY_LEARNING_RATE,epsilon=1e-5) 35 | TRAIN_CLASSIFY_METRICS=tf.keras.metrics.AUC() 36 | BATCH_SIZE=6 37 | TRAINING_STEP_PER_EPOCH=math.ceil((3514)/BATCH_SIZE) 38 | VALIDATION_STEP=math.ceil((759)/BATCH_SIZE) 39 | TRAING_EPOCH=300 40 | NUMBER_OF_PARALLEL_CALL=3 41 | PARSHING=2*BATCH_SIZE 42 | #--Callbacks----- 43 | ModelCheckpoint_MOTITOR='val_loss' 44 | TRAINING_SAVE_MODEL_PATH='/image_data/Scripts/April_Model/DyFA_61FAvg_April17_2020/Model_DyFA_61FAvg_April17_2020/' 45 | TRAINING_CSV='DyFA_61FAvg_April17_2020.csv' 46 | LOG_NAME="Log_DyFA_60FAvg_April17_2020" 47 | MODEL_SAVING_NAME="DyFAModel60FAvg_{val_loss:.2f}_{epoch}.h5" 48 | 49 | #### 50 | TRAINING_TF_RECORDS='/image_data/nobackup/Lung_CenterPatch_2mm_March27_2020/tf/Train_tfrecords/' 51 | VALIDATION_TF_RECORDS='/image_data/nobackup/Lung_CenterPatch_2mm_March27_2020/tf/Val_tfrecords/' 52 | -------------------------------------------------------------------------------- /StFA/loss_funnction_And_matrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ####---Loss 4 | @tf.function 5 | def macro_soft_f1(y, y_hat): 6 | """Compute the macro soft F1-score as a cost (average 1 - soft-F1 across all labels). 7 | Use probability values instead of binary predictions. 
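    Using probabilities keeps the TP/FP/FN counts, and therefore the F1
    surrogate, differentiable, so the score can be optimised directly by
    gradient descent.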
8 | 9 | Args: 10 | y (int32 Tensor): targets array of shape (BATCH_SIZE, N_LABELS) 11 | y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) 12 | 13 | Returns: 14 | cost (scalar Tensor): value of the cost function for the batch 15 | """ 16 | y = tf.cast(y, tf.float32) 17 | y_hat = tf.cast(y_hat, tf.float32) 18 | tp = tf.reduce_sum(y_hat * y, axis=0) 19 | fp = tf.reduce_sum(y_hat * (1 - y), axis=0) 20 | fn = tf.reduce_sum((1 - y_hat) * y, axis=0) 21 | soft_f1 = 2*tp / (2*tp + fn + fp + 1e-16) 22 | cost = 1 - soft_f1 # reduce 1 - soft-f1 in order to increase soft-f1 23 | macro_cost = tf.reduce_mean(cost) # average on all labels 24 | return macro_cost 25 | 26 | 27 | ###Matrics 28 | @tf.function 29 | def macro_f1(y, y_hat, thresh=0.5): 30 | """Compute the macro F1-score on a batch of observations (average F1 across labels) 31 | 32 | Args: 33 | y (int32 Tensor): labels array of shape (BATCH_SIZE, N_LABELS) 34 | y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) 35 | thresh: probability value above which we predict positive 36 | 37 | Returns: 38 | macro_f1 (scalar Tensor): value of macro F1 for the batch 39 | """ 40 | y_pred = tf.cast(tf.greater(y_hat, thresh), tf.float32) 41 | tp = tf.cast(tf.math.count_nonzero(y_pred * y, axis=0), tf.float32) 42 | fp = tf.cast(tf.math.count_nonzero(y_pred * (1 - y), axis=0), tf.float32) 43 | fn = tf.cast(tf.math.count_nonzero((1 - y_pred) * y, axis=0), tf.float32) 44 | f1 = 2*tp / (2*tp + fn + fp + 1e-16) 45 | macro_f1 = tf.reduce_mean(f1) 46 | return macro_f1 47 | 48 | 49 | 50 | @tf.function 51 | def Weighted_BCTL(y_true, y_pred): 52 | 53 | # Manually calculate the weighted cross entropy. 54 | # Formula is qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 55 | # where z are labels, x is logits, and q is the weight. 
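    # Here q is derived from the batch itself: positives are scaled by (P + N) / P
    # and negatives by (P + N) / N, so the sparser side of a multi-label batch
    # contributes proportionally more to the loss.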
56 |     # Since the values passed are from a sigmoid (assumed in this case),
57 |     # sigmoid(x) is replaced by y_pred below.
58 |     # 1e-6 is added as an epsilon to avoid passing a zero into the log.
59 | 
60 |     ##get the positive labels
61 | 
62 |     y_true = tf.cast(y_true, tf.float32)
63 |     y_pred = tf.cast(y_pred, tf.float32)
64 | 
65 |     P = tf.cast(tf.math.count_nonzero(y_true), tf.float32)        # number of positive labels in the batch
66 |     N = tf.cast(tf.shape(tf.where(y_true == 0))[0], tf.float32)   # number of negative labels (len() fails on dynamic tensors in graph mode)
67 | 
68 |     BP1 = (P + N) / P   # balance factor for the positive term
69 |     BP = tf.cast(BP1, tf.float32)
70 | 
71 |     BN = (P + N) / N    # balance factor for the negative term
72 |     BN = tf.cast(BN, tf.float32)
73 | 
74 | 
75 |     x_1 = BP * (y_true * -tf.math.log(y_pred + 1e-6))
76 |     x_2 = BN * ((1 - y_true) * -tf.math.log(1 - y_pred + 1e-6))
77 | 
78 |     cost = tf.add(x_1, x_2)
79 |     cost_a = tf.reduce_mean(cost)
80 |     return cost_a
81 | 
82 | 
-------------------------------------------------------------------------------- /StFA/resume_training_using_check_point.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import tensorflow as tf
3 | 
4 | tf.config.optimizer.set_jit(True)
5 | gpus = tf.config.experimental.list_physical_devices('GPU')
6 | if gpus:
7 |     # Restrict TensorFlow to only use the first GPU
8 |     try:
9 |         tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
10 |     except RuntimeError as e:
11 |         # Visible devices must be set at program startup
12 |         print(e)
13 | 
14 | 
15 | from tensorflow.keras.optimizers import Adam
16 | from config import*
17 | import os
18 | import datetime
19 | from StFA_Model import*
20 | from tfrecords_utilities import decode_ct
21 | import numpy as np
22 | import random
23 | 
24 | ####----Gather the tfrecords
25 | def getting_list(path):
26 |     a = [file for file in os.listdir(path) if file.endswith('.tfrecords')]
27 |     all_tfrecoeds = random.sample(a, len(a))
28 |     #all_tfrecoeds.sort(key=lambda f: int(filter(str.isdigit, f)))
29 |     list_of_tfrecords = []
30 |     for i in range(len(all_tfrecoeds)):
31 |         tf_path = path + all_tfrecoeds[i]
32 |         list_of_tfrecords.append(tf_path)
33 |     return list_of_tfrecords
34 | 
35 | #--Training decoder
36 | def load_training_tfrecords(record_mask_file, batch_size):
37 |     dataset = tf.data.Dataset.list_files(record_mask_file).interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=NUMBER_OF_PARALLEL_CALL, num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
38 |     dataset = dataset.map(decode_ct, num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
39 |     batched_dataset = dataset.prefetch(PARSHING)
40 |     return batched_dataset
41 | 
42 | #--Validation decoder
43 | def load_validation_tfrecords(record_mask_file, batch_size):
44 |     dataset = tf.data.Dataset.list_files(record_mask_file).interleave(tf.data.TFRecordDataset, cycle_length=NUMBER_OF_PARALLEL_CALL, num_parallel_calls=NUMBER_OF_PARALLEL_CALL)
45 |     dataset = dataset.map(decode_ct, num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size)
46 |     batched_dataset = dataset.prefetch(PARSHING)
47 |     return batched_dataset
48 | 
49 | 
50 | def Training():
51 | 
52 |     #TensorBoard
53 |     logdir = os.path.join(LOG_NAME, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
54 |     tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
55 |     ##csv_logger
56 |     csv_logger = tf.keras.callbacks.CSVLogger(TRAINING_CSV)
57 |     ##Model-checkpoints
58 |     path = TRAINING_SAVE_MODEL_PATH
59 |     model_path = os.path.join(path, MODEL_SAVING_NAME)
60 |     Model_callback=
tf.keras.callbacks.ModelCheckpoint(filepath=model_path,save_best_only=False,save_weights_only=True,monitor=ModelCheckpoint_MOTITOR,verbose=1) 61 | ##----Preparing Data 62 | tf_train=getting_list(TRAINING_TF_RECORDS) 63 | tf_val=getting_list(VALIDATION_TF_RECORDS) 64 | traing_data=load_training_tfrecords(tf_train,BATCH_SIZE) 65 | Val_batched_dataset=load_validation_tfrecords(tf_val,BATCH_SIZE) 66 | 67 | if (NUM_OF_GPU==1): 68 | if RESUME_TRAINING==1: 69 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 70 | Model_3D=StFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 71 | Model_3D.load_weights(RESUME_TRAIING_MODEL) 72 | initial_epoch_of_training=TRAINING_INITIAL_EPOCH 73 | print('Resume-Training From-Epoch{}-Loading-Model-from_{}'.format(initial_epoch_of_training,RESUME_TRAIING_MODEL)) 74 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 75 | Model_3D.summary() 76 | else: 77 | initial_epoch_of_training=0 78 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 79 | Model_3D=StFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 80 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 81 | Model_3D.summary() 82 | Model_3D.fit(traing_data, 83 | steps_per_epoch=TRAINING_STEP_PER_EPOCH, 84 | epochs=TRAING_EPOCH, 85 | initial_epoch=initial_epoch_of_training, 86 | validation_data=Val_batched_dataset, 87 | validation_steps=VALIDATION_STEP, 88 | callbacks=[tensorboard_callback,csv_logger,Model_callback]) 89 | 90 | ###Multigpu---- 91 | else: 92 | mirrored_strategy = tf.distribute.MirroredStrategy(DISTRIIBUTED_STRATEGY_GPUS) 93 | with mirrored_strategy.scope(): 94 | if RESUME_TRAINING==1: 95 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 96 | Model_3D=StFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 97 | Model_3D.load_weights(RESUME_TRAIING_MODEL) 98 | initial_epoch_of_training=TRAINING_INITIAL_EPOCH 99 | print('Resume-Training From-Epoch{}-Loading-Model-from_{}'.format(initial_epoch_of_training,RESUME_TRAIING_MODEL)) 100 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[tf.keras.metrics.AUC(),tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 101 | Model_3D.summary() 102 | else: 103 | initial_epoch_of_training=0 104 | inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') 105 | Model_3D=StFAModel_withDenseVnet(SEGMENTATION_MODEL_PATH,INPUT_PATCH_SIZE,NUMBER_OF_CLASSES,SEGMENTATION_NUM_OF_CLASSES) 106 | Model_3D.compile(optimizer=OPTIMIZER, loss=[TRAIN_CLASSIFY_LOSS], metrics=[TRAIN_CLASSIFY_METRICS,tf.keras.metrics.Precision(),tf.keras.metrics.Recall()]) 107 | Model_3D.summary() 108 | Model_3D.fit(traing_data, 109 | steps_per_epoch=TRAINING_STEP_PER_EPOCH, 110 | epochs=TRAING_EPOCH, 111 | initial_epoch=initial_epoch_of_training, 112 | validation_data=Val_batched_dataset, 113 | validation_steps=VALIDATION_STEP, 114 | callbacks=[tensorboard_callback,csv_logger,Model_callback]) 115 | 116 | if __name__ == '__main__': 117 | Training() 118 | -------------------------------------------------------------------------------- /StFA/tfrecords_utilities.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import absolute_import, division, print_function, unicode_literals
2 | 
3 | import tensorflow as tf
4 | import numpy as np
5 | import pandas as pd
6 | import SimpleITK as sitk
7 | from Preprocessing_utlities import extract_class_balanced_example_array
8 | from Preprocessing_utlities import resize_image_with_crop_or_pad
9 | from scipy.ndimage.interpolation import map_coordinates
10 | from scipy.ndimage.filters import gaussian_filter
11 | from config import*
12 | 
13 | 
14 | ########################-------Functions for tfrecords
15 | # The following functions can be used to convert a value to a type compatible
16 | # with tf.Example.
17 | def _bytes_feature(value):
18 |     """Returns a bytes_list from a string / byte."""
19 |     if isinstance(value, type(tf.constant(0))):
20 |         value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
21 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
22 | 
23 | def _float_feature(value):
24 |     """Returns a float_list from a float / double."""
25 |     return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
26 | 
27 | def _int64_feature(value):
28 |     """Returns an int64_list from a bool / enum / int / uint."""
29 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
30 | 
31 | 
32 | def flow_from_df(dataframe: pd.DataFrame, chunk_size):
33 |     for start_row in range(0, dataframe.shape[0], chunk_size):
34 |         end_row = min(start_row + chunk_size, dataframe.shape[0])
35 |         yield dataframe.iloc[start_row:end_row, :]
36 | 
37 | 
38 | def creat_tfrecord(df, extraction_perameter, tf_name):
39 | 
40 |     read_csv = df.to_numpy() # DataFrame.as_matrix() was removed in pandas 1.0
41 |     patch_params = extraction_perameter
42 | 
43 |     img_list = []
44 |     mask_list = []
45 |     lbl_list = []
46 |     id_name = []
47 | 
48 |     for Data in read_csv:
49 |         img_path = Data[4]
50 |         subject_id = img_path.split('/')[-1].split('.')[0]
51 |         Subject_lbl = Data[5:10]
52 |         print(Subject_lbl.shape)
53 | 
54 |         print('Subject ID-{}'.format(subject_id))
55 |         print('Labels--{}'.format(Subject_lbl))
56 | 
57 |         #Img
58 |         img_sitk = sitk.ReadImage(img_path, sitk.sitkFloat32)
59 |         image = sitk.GetArrayFromImage(img_sitk)
60 |         #Mask
61 |         mask_fn = str(Data[10])
62 |         mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_fn)).astype(np.int32)
63 |         print(mask.shape)
64 |         print(image.shape)
65 | 
66 |         patch_size = patch_params['example_size']
67 |         img_shape = image.shape
68 | 
69 |         ###----Pad the data if needed
70 |         #####----z dimension-----######
71 |         if (patch_size[0] >= img_shape[0]):
72 |             dimention1 = patch_size[0] + 10
73 |         else:
74 |             dimention1 = img_shape[0]
75 | 
76 |         #####----x dimension-----######
77 |         if (patch_size[1] >= img_shape[1]):
78 |             dimention2 = patch_size[1] + 10
79 |         else:
80 |             dimention2 = img_shape[1]
81 | 
82 |         #####----y dimension-----######
83 |         if (patch_size[2] >= img_shape[2]):
84 |             dimention3 = patch_size[2] + 10
85 |         else:
86 |             dimention3 = img_shape[2]
87 |         print('------before padding image shape--{}-----'.format(image.shape))
88 |         image = resize_image_with_crop_or_pad(image, [dimention1, dimention2, dimention3], mode='symmetric')
89 |         mask = resize_image_with_crop_or_pad(mask, [dimention1, dimention2, dimention3], mode='symmetric')
90 |         print('######after padding image shape--{}#####'.format(image.shape))
91 | 
92 | 
93 | 
94 |         img_shape = image.shape
95 |         image = np.expand_dims(image, axis=3)
96 | 
97 |         images, masks = extract_class_balanced_example_array(
98 |             image, mask,
99 |             example_size=patch_params['example_size'],
100 |             n_examples=patch_params['n_examples'],
101 |             classes=4, class_weights=[0, 0, 1, 1])
102 | 
103 |         print(images.shape)
104 | 
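        # Each extracted patch becomes one tf.train.Example below: the five binary
        # labels are stored as int64 scalars, image and mask as raw bytes, and the
        # patch shape and subject id travel along so decode_ct can restore the volume.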
105 |         for e in range(patch_params['n_examples']):
106 |             img_list.append(images[e][:, :, :, 0])
107 |             #print(images[e][:,:,:,0].shape)
108 |             mask_list.append(masks[e][:, :, :])
109 |             #print('Mask-Shape=={}'.format(masks[e][:,:,:].shape))
110 |             lbl_list.append(Subject_lbl)
111 |             patch_name = str(subject_id + '_{}'.format(e))
112 |             #Convert the string to bytes
113 |             patch_name = bytes(patch_name, 'utf-8')
114 |             #print(patch_name)
115 |             id_name.append(patch_name)
116 | 
117 |     print('This tfrecord will contain--{}--patches--of-size--{}'.format(len(id_name), patch_params['example_size']))
118 | 
119 |     record_mask_file = tf_name
120 |     with tf.io.TFRecordWriter(record_mask_file) as writer:
121 |         for e in range(len(img_list)):
122 |             feature = {'label1': _int64_feature(lbl_list[e][0]),
123 |                        'label2': _int64_feature(lbl_list[e][1]),
124 |                        'label3': _int64_feature(lbl_list[e][2]),
125 |                        'label4': _int64_feature(lbl_list[e][3]),
126 |                        'label5': _int64_feature(lbl_list[e][4]),
127 |                        'image': _bytes_feature(img_list[e].tobytes()),  # ndarray.tostring() is deprecated
128 |                        'mask': _bytes_feature(mask_list[e].tobytes()),
129 |                        'Height': _int64_feature(patch_params['example_size'][0]),
130 |                        'Weight': _int64_feature(patch_params['example_size'][1]),
131 |                        'Depth': _int64_feature(patch_params['example_size'][2]),
132 |                        'label_shape': _int64_feature(5),
133 |                        'Sub_id': _bytes_feature(id_name[e])
134 |                        }
135 |             example = tf.train.Example(features=tf.train.Features(feature=feature))
136 |             writer.write(example.SerializeToString())
137 | 
138 |     # the with-block closes the writer automatically
139 | 
140 |     return
141 | 
142 | 
143 | @tf.function
144 | def decode_ct(Serialized_example):
145 | 
146 |     features = {
147 |         'label1': tf.io.FixedLenFeature([], tf.int64),
148 |         'label2': tf.io.FixedLenFeature([], tf.int64),
149 |         'label3': tf.io.FixedLenFeature([], tf.int64),
150 |         'label4': tf.io.FixedLenFeature([], tf.int64),
151 |         'label5': tf.io.FixedLenFeature([], tf.int64),
152 |         'image': tf.io.FixedLenFeature([], tf.string),
153 |         'mask': tf.io.FixedLenFeature([], tf.string),
154 |         'Height': tf.io.FixedLenFeature([], tf.int64),
155 |         'Weight': tf.io.FixedLenFeature([], tf.int64),
156 |         'Depth': tf.io.FixedLenFeature([], tf.int64),
157 |         'label_shape': tf.io.FixedLenFeature([], tf.int64),
158 |         'Sub_id': tf.io.FixedLenFeature([], tf.string)
159 | 
160 |     }
161 |     examples = tf.io.parse_single_example(Serialized_example, features)
162 |     ##Decode the image as float32
163 |     image_1 = tf.io.decode_raw(examples['image'], tf.float32)
164 |     #Decode the mask as int32
165 |     #mask_1 = tf.io.decode_raw(examples['mask'], tf.int32)
166 |     ##Subject id is already in bytes format
167 |     #sub_id=examples['Sub_id']
168 | 
169 | 
170 |     img_shape = [examples['Height'], examples['Weight'], examples['Depth']]
171 |     #img_shape2=[img_shape[0],img_shape[1],img_shape[2]]
172 |     print(img_shape)
173 |     #Reshape the data
174 |     img = tf.reshape(image_1, img_shape)
175 |     #Because the CNN expects (batch, H, W, D, CHANNEL)
176 |     img = tf.expand_dims(img, axis=-1)
177 |     #mask=tf.reshape(mask_1,img_shape)
178 |     #mask=tf.expand_dims(mask, axis=-1)
179 |     ###cast values
180 |     img = tf.cast(img, tf.float32)
181 |     #mask=tf.cast(mask,tf.int32)
182 | 
183 |     lbl = [examples['label1'], examples['label2'], examples['label3'], examples['label4'], examples['label5']]
184 |     ##Transposing the multi-labels
185 |     #lbl=tf.linalg.matrix_transpose(lbl)
186 |     return img, lbl
187 | 
-------------------------------------------------------------------------------- /figure/Model_Architecture.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/fitushar/WeaklySupervised-3D-Classification-of-Chest-CT-using-Aggregated-MultiResolution-Segmentation-Feature/30975d90c8c7f84e498e8f54746c5b71b535d9d3/figure/Model_Architecture.png -------------------------------------------------------------------------------- /figure/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fitushar/WeaklySupervised-3D-Classification-of-Chest-CT-using-Aggregated-MultiResolution-Segmentation-Feature/30975d90c8c7f84e498e8f54746c5b71b535d9d3/figure/dataset.png -------------------------------------------------------------------------------- /figure/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fitushar/WeaklySupervised-3D-Classification-of-Chest-CT-using-Aggregated-MultiResolution-Segmentation-Feature/30975d90c8c7f84e498e8f54746c5b71b535d9d3/figure/results.png --------------------------------------------------------------------------------