├── README.md
├── model
│   ├── custom_layers.py
│   ├── densenet3d_regression.py
│   ├── resnet_3d.py
│   ├── sfcn_keras.py
│   └── vgg_16.py
└── regression_train.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Brain_age_prediction
Neuroimaging-based brain age prediction using a 3D DenseNet
--------------------------------------------------------------------------------
/model/custom_layers.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
author: Gao Huang, Zhuang Liu, Kilian Q. Weinberger, Laurens van der Maaten
Densely Connected Convolutional Networks
arXiv:1608.06993
(See https://github.com/flyyufelix/DenseNet-Keras/blob/master/densenet161.py)
"""

# Alternative imports for the tensorflow.keras namespace, kept for reference:
# from tensorflow.python.keras.layers import Layer, InputSpec
# try:
#     from tensorflow.keras import initializations
# except ImportError:
#     from tensorflow.keras import initializers as initializations
# import tensorflow.keras.backend as K

from keras.layers import Layer, InputSpec
try:
    from keras import initializations
except ImportError:
    from keras import initializers as initializations
import keras.backend as K


class Scale(Layer):
    '''Custom layer for DenseNet, used together with BatchNormalization.

    Learns a set of weights and biases used for scaling the input data.
    The output is simply an element-wise multiplication of the input
    with a scale plus a shift:
        out = in * gamma + beta,
    where 'gamma' and 'beta' are the learned weights and biases.

    # Arguments
        axis: integer, axis along which to normalize in mode 0. For instance,
            if your input tensor has shape (samples, channels, rows, cols),
            set axis to 1 to normalize per feature map (channels axis).
        momentum: momentum in the computation of the exponential average
            of the mean and standard deviation of the data, for
            feature-wise normalization.
        weights: initialization weights.
            List of 2 Numpy arrays, with shapes:
            `[(input_shape,), (input_shape,)]`
        beta_init: name of the initialization function for the shift parameter
            (see [initializations](../initializations.md)), or alternatively,
            a Theano/TensorFlow function to use for weights initialization.
            Only relevant if you don't pass a `weights` argument.
        gamma_init: name of the initialization function for the scale parameter
            (see [initializations](../initializations.md)), or alternatively,
            a Theano/TensorFlow function to use for weights initialization.
            Only relevant if you don't pass a `weights` argument.
    '''
    def __init__(self, weights=None, axis=-1, momentum=0.9,
                 beta_init='zero', gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # TensorFlow >= 1.0.0 compatibility
        self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        # self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
        # self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
        self._trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
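
# A minimal usage sketch (added for illustration, not part of the original
# file): Scale is meant to follow BatchNormalization, exactly as in the
# DenseNet blocks in densenet3d_regression.py; the input shape is hypothetical.
#
# from keras.layers import Input, BatchNormalization
# from keras.models import Model
#
# inp = Input(shape=(32, 32, 32, 16))
# h = BatchNormalization(epsilon=1.1e-5, axis=-1)(inp)
# h = Scale(axis=-1)(h)  # out = h * gamma + beta, learned per channel
# demo = Model(inp, h)
# demo.summary()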
--------------------------------------------------------------------------------
/model/densenet3d_regression.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@ Modified by Jeyeon Lee
author: Gao Huang, Zhuang Liu, Kilian Q. Weinberger, Laurens van der Maaten
Densely Connected Convolutional Networks
arXiv:1608.06993
(See https://github.com/flyyufelix/DenseNet-Keras/blob/master/densenet161.py)
"""


from keras.models import Model
from keras.layers import (
    Input, concatenate, ZeroPadding3D,
    Conv3D, Dense, Dropout, Activation,
    AveragePooling3D, GlobalAveragePooling3D, MaxPooling3D,
    BatchNormalization
)
import keras.backend as K

from custom_layers import Scale


def build_densenet_forCAM(input_shape, densenettype):
    '''Instantiate a 3D DenseNet for regression, with a final convolution and
    global average pooling head for class activation mapping (CAM).

    # Arguments
        input_shape: shape of the input volume, e.g. (121, 145, 121, 1)
        densenettype: selects the layer configuration; -2, -1, 0 are reduced
            CAM variants, 1/2/3 are DenseNet-121/-161/-169
    # Returns
        A Keras model instance.
    '''
    eps = 1.1e-5
    nb_dense_block = 4
    growth_rate = 48
    nb_filter = 96
    reduction = 0.0
    dropout_rate = 0.0
    weight_decay = 1e-4
    weights_path = None
    # compute compression factor
    compression = 1.0 - reduction

    # Channels-last ordering: features concatenate along axis 4
    global concat_axis
    concat_axis = 4
    img_input = Input(shape=input_shape, name='data')

    # From the architecture for ImageNet (Table 1 in the paper)
    if densenettype == -2:
        nb_dense_block = 2
        nb_filter = 64
        nb_layers = [3, 6]            # reduced DenseNet-CAM
    elif densenettype == -1:
        nb_filter = 64
        nb_layers = [3, 6, 6, 4]      # reduced DenseNet-CAM
    elif densenettype == 0:
        nb_filter = 64
        nb_layers = [3, 6, 12, 8]     # DenseNet-CAM
    elif densenettype == 1:
        nb_filter = 64
        nb_layers = [6, 12, 24, 16]   # DenseNet-121
    elif densenettype == 2:
        nb_filter = 96
        nb_layers = [6, 12, 36, 24]   # DenseNet-161
    elif densenettype == 3:
        nb_filter = 64
        nb_layers = [6, 12, 32, 32]   # DenseNet-169

    # Initial convolution
    x = ZeroPadding3D((3, 3, 3), name='conv1_zeropadding')(img_input)
    if densenettype in (-2, -1, 0):
        x = Conv3D(nb_filter, (5, 5, 5), strides=(2, 2, 2), name='conv1', use_bias=False)(x)  # (7, 7, 7) in the original DenseNet
    else:
        # NOTE: the original code had no conv1 for types 1-3, which looks
        # unintended; the (7, 7, 7) stem from the paper is assumed here.
        x = Conv3D(nb_filter, (7, 7, 7), strides=(2, 2, 2), name='conv1', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name='conv1_bn')(x)
    x = Scale(axis=concat_axis, name='conv1_scale')(x)
    x = Activation('relu', name='relu1')(x)
    x = ZeroPadding3D((1, 1, 1), name='pool1_zeropadding')(x)
    x = MaxPooling3D((3, 3, 3), strides=(2, 2, 2), name='pool1')(x)

    # Add dense blocks, each followed by a transition block
    for block_idx in range(nb_dense_block - 1):
        stage = block_idx + 2
        x, nb_filter = dense_block(x, stage, nb_layers[block_idx], nb_filter, growth_rate,
                                   dropout_rate=dropout_rate, weight_decay=weight_decay)

        x = transition_block(x, stage, nb_filter, compression=compression,
                             dropout_rate=dropout_rate, weight_decay=weight_decay)
        nb_filter = int(nb_filter * compression)

    final_stage = stage + 1
    x, nb_filter = dense_block(x, final_stage, nb_layers[-1], nb_filter, growth_rate,
                               dropout_rate=dropout_rate, weight_decay=weight_decay)

    x = BatchNormalization(epsilon=eps, axis=concat_axis, name='conv' + str(final_stage) + '_blk_bn')(x)
    x = Scale(axis=concat_axis, name='conv' + str(final_stage) + '_blk_scale')(x)
    x = Activation('relu', name='relu' + str(final_stage) + '_blk')(x)

    ##########################################################################
    # Add a last convolution and global average pooling for CAM
    CAM_conv = Conv3D(filters=x._keras_shape[4],
                      kernel_size=(3, 3, 3),
                      strides=(1, 1, 1), padding="same",
                      name='CAM_conv')(x)
    flatten1 = GlobalAveragePooling3D(name='CAM_pool')(CAM_conv)
    ##########################################################################

    # Single linear output for age regression
    x = Dense(1, name='fc')(flatten1)
    x = Activation('linear')(x)
    model = Model(img_input, x, name='densenetregression')

    if weights_path is not None:
        model.load_weights(weights_path)

    return model


def conv_block(x, stage, branch, nb_filter, dropout_rate=None, weight_decay=1e-4):
    '''Apply BatchNorm, ReLU, bottleneck 1x1x1 Conv3D, 3x3x3 Conv3D, and optional dropout
    # Arguments
        x: input tensor
        stage: index for dense block
        branch: layer index within each dense block
        nb_filter: number of filters
        dropout_rate: dropout rate
        weight_decay: weight decay factor
    '''
    eps = 1.1e-5
    conv_name_base = 'conv' + str(stage) + '_' + str(branch)
    relu_name_base = 'relu' + str(stage) + '_' + str(branch)

    # 1x1x1 convolution (bottleneck layer)
    inter_channel = nb_filter * 4
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base + '_x1_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base + '_x1_scale')(x)
    x = Activation('relu', name=relu_name_base + '_x1')(x)
    x = Conv3D(inter_channel, (1, 1, 1), name=conv_name_base + '_x1', use_bias=False)(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    # 3x3x3 convolution
    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base + '_x2_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base + '_x2_scale')(x)
    x = Activation('relu', name=relu_name_base + '_x2')(x)
    x = ZeroPadding3D((1, 1, 1), name=conv_name_base + '_x2_zeropadding')(x)
    x = Conv3D(nb_filter, (3, 3, 3), name=conv_name_base + '_x2', use_bias=False)(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    return x


def transition_block(x, stage, nb_filter, compression=1.0, dropout_rate=None, weight_decay=1e-4):
    '''Apply BatchNorm, 1x1x1 convolution, average pooling, optional compression and dropout
    # Arguments
        x: input tensor
        stage: index for dense block
        nb_filter: number of filters
        compression: calculated as 1 - reduction. Reduces the number of
            feature maps in the transition block.
        dropout_rate: dropout rate
        weight_decay: weight decay factor
    '''
    eps = 1.1e-5
    conv_name_base = 'conv' + str(stage) + '_blk'
    relu_name_base = 'relu' + str(stage) + '_blk'
    pool_name_base = 'pool' + str(stage)

    x = BatchNormalization(epsilon=eps, axis=concat_axis, name=conv_name_base + '_bn')(x)
    x = Scale(axis=concat_axis, name=conv_name_base + '_scale')(x)
    x = Activation('relu', name=relu_name_base)(x)
    x = Conv3D(int(nb_filter * compression), (1, 1, 1), name=conv_name_base, use_bias=False)(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    x = AveragePooling3D((2, 2, 2), strides=(2, 2, 2), name=pool_name_base)(x)

    return x


def dense_block(x, stage, nb_layers, nb_filter, growth_rate, dropout_rate=None, weight_decay=1e-4, grow_nb_filters=True):
    '''Build a dense_block where the output of each conv_block is fed to subsequent ones
    # Arguments
        x: input tensor
        stage: index for dense block
        nb_layers: the number of conv_block layers to append to the model
        nb_filter: number of filters
        growth_rate: growth rate
        dropout_rate: dropout rate
        weight_decay: weight decay factor
        grow_nb_filters: flag deciding whether the number of filters may grow
    '''
    concat_feat = x

    for i in range(nb_layers):
        branch = i + 1
        x = conv_block(concat_feat, stage, branch, growth_rate, dropout_rate, weight_decay)
        concat_feat = concatenate([concat_feat, x], name='concat_' + str(stage) + '_' + str(branch))
        # concat_feat = add([concat_feat, x], name='concat_' + str(stage) + '_' + str(branch))

        if grow_nb_filters:
            nb_filter += growth_rate

    return concat_feat, nb_filter
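
# A minimal build sketch (added for illustration, not part of the original
# file); the input shape matches the MNI-space volumes in regression_train.py.
# Each dense block adds growth_rate channels per conv_block, e.g. nb_filter=64
# with 3 layers and growth_rate=48 yields 64 + 3 * 48 = 208 channels.
#
# model = build_densenet_forCAM((121, 145, 121, 1), 0)
# model.compile(optimizer='adam', loss='mean_absolute_error')
# model.summary()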
--------------------------------------------------------------------------------
/model/resnet_3d.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 23 21:30:40 2020
@author: J.Lee
"""

from __future__ import (
    absolute_import,
    division,
    print_function,
    unicode_literals
)
import six
from math import ceil
from keras.models import Model
from keras.layers import (
    Input,
    Activation,
    Dense,
    Flatten
)
from keras.layers.convolutional import (
    Conv3D,
    AveragePooling3D,
    MaxPooling3D
)
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras import backend as K


def _bn_relu(input):
    """Helper to build a BN -> relu block (by @raghakot)."""
    norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
    return Activation("relu")(norm)


def _conv_bn_relu3D(**conv_params):
    """Helper to build a conv3d -> BN -> relu block."""
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    strides = conv_params.setdefault("strides", (1, 1, 1))
    kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
    padding = conv_params.setdefault("padding", "same")
    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1e-4))

    def f(input):
        conv = Conv3D(filters=filters, kernel_size=kernel_size,
                      strides=strides, kernel_initializer=kernel_initializer,
                      padding=padding,
                      kernel_regularizer=kernel_regularizer)(input)
        return _bn_relu(conv)

    return f


def _bn_relu_conv3d(**conv_params):
    """Helper to build a BN -> relu -> conv3d block."""
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    strides = conv_params.setdefault("strides", (1, 1, 1))
    kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
    padding = conv_params.setdefault("padding", "same")
    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1e-4))

    def f(input):
        activation = _bn_relu(input)
        return Conv3D(filters=filters, kernel_size=kernel_size,
                      strides=strides, kernel_initializer=kernel_initializer,
                      padding=padding,
                      kernel_regularizer=kernel_regularizer)(activation)
    return f


def _shortcut3d(input, residual):
    """3D shortcut to match input and residual, merged with "sum"."""
    stride_dim1 = ceil(input._keras_shape[DIM1_AXIS]
                       / residual._keras_shape[DIM1_AXIS])
    stride_dim2 = ceil(input._keras_shape[DIM2_AXIS]
                       / residual._keras_shape[DIM2_AXIS])
    stride_dim3 = ceil(input._keras_shape[DIM3_AXIS]
                       / residual._keras_shape[DIM3_AXIS])
    equal_channels = residual._keras_shape[CHANNEL_AXIS] \
        == input._keras_shape[CHANNEL_AXIS]

    shortcut = input
    if stride_dim1 > 1 or stride_dim2 > 1 or stride_dim3 > 1 \
            or not equal_channels:
        shortcut = Conv3D(
            filters=residual._keras_shape[CHANNEL_AXIS],
            kernel_size=(1, 1, 1),
            strides=(stride_dim1, stride_dim2, stride_dim3),
            kernel_initializer="he_normal", padding="valid",
            kernel_regularizer=l2(1e-4)
        )(input)
    return add([shortcut, residual])


def _residual_block3d(block_function, filters, kernel_regularizer, repetitions,
                      is_first_layer=False):
    def f(input):
        for i in range(repetitions):
            strides = (1, 1, 1)
            if i == 0 and not is_first_layer:
                strides = (2, 2, 2)
            input = block_function(filters=filters, strides=strides,
                                   kernel_regularizer=kernel_regularizer,
                                   is_first_block_of_first_layer=(
                                       is_first_layer and i == 0)
                                   )(input)
        return input

    return f


def basic_block(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
                is_first_block_of_first_layer=False):
    """Basic 3 x 3 x 3 convolution block. Extended from raghakot's 2D impl."""
    def f(input):
        if is_first_block_of_first_layer:
            # don't repeat bn->relu since we just did bn->relu->maxpool
            conv1 = Conv3D(filters=filters, kernel_size=(3, 3, 3),
                           strides=strides, padding="same",
                           kernel_initializer="he_normal",
                           kernel_regularizer=kernel_regularizer
                           )(input)
        else:
            conv1 = _bn_relu_conv3d(filters=filters,
                                    kernel_size=(3, 3, 3),
                                    strides=strides,
                                    kernel_regularizer=kernel_regularizer
                                    )(input)

        residual = _bn_relu_conv3d(filters=filters, kernel_size=(3, 3, 3),
                                   kernel_regularizer=kernel_regularizer
                                   )(conv1)
        return _shortcut3d(input, residual)

    return f


def bottleneck(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
               is_first_block_of_first_layer=False):
    """Bottleneck 1x1x1 -> 3x3x3 -> 1x1x1 convolution block.
    Extended from raghakot's 2D impl."""
    def f(input):
        if is_first_block_of_first_layer:
            # don't repeat bn->relu since we just did bn->relu->maxpool
            conv_1_1 = Conv3D(filters=filters, kernel_size=(1, 1, 1),
                              strides=strides, padding="same",
                              kernel_initializer="he_normal",
                              kernel_regularizer=kernel_regularizer
                              )(input)
        else:
            conv_1_1 = _bn_relu_conv3d(filters=filters, kernel_size=(1, 1, 1),
                                       strides=strides,
                                       kernel_regularizer=kernel_regularizer
                                       )(input)

        conv_3_3 = _bn_relu_conv3d(filters=filters, kernel_size=(3, 3, 3),
                                   kernel_regularizer=kernel_regularizer
                                   )(conv_1_1)
        residual = _bn_relu_conv3d(filters=filters * 4, kernel_size=(1, 1, 1),
                                   kernel_regularizer=kernel_regularizer
                                   )(conv_3_3)

        return _shortcut3d(input, residual)

    return f


def _handle_data_format():
    global DIM1_AXIS
    global DIM2_AXIS
    global DIM3_AXIS
    global CHANNEL_AXIS
    if K.image_data_format() == 'channels_last':
        DIM1_AXIS = 1
        DIM2_AXIS = 2
        DIM3_AXIS = 3
        CHANNEL_AXIS = 4
    else:
        CHANNEL_AXIS = 1
        DIM1_AXIS = 2
        DIM2_AXIS = 3
        DIM3_AXIS = 4


def _get_block(identifier):
    if isinstance(identifier, six.string_types):
        res = globals().get(identifier)
        if not res:
            raise ValueError('Invalid {}'.format(identifier))
        return res
    return identifier


class Resnet3DBuilder(object):
    """ResNet3D."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions, reg_factor):
        """Instantiate a vanilla ResNet3D keras model.
        # Arguments
            input_shape: tuple of input shape in the format
                (conv_dim1, conv_dim2, conv_dim3, channels) if dim_ordering='tf',
                (channels, conv_dim1, conv_dim2, conv_dim3) if dim_ordering='th'
            num_outputs: number of outputs at the final layer (softmax for
                classification when > 1, linear for regression when 1)
            block_fn: unit block to use, {'basic_block', 'bottleneck'}
            repetitions: repetitions of unit blocks
        # Returns
            model: a 3D ResNet model that takes a 5D tensor (volumetric images
            in batch) as input and returns a 1D vector (prediction) as output.
        """
        _handle_data_format()
        if len(input_shape) != 4:
            raise ValueError("Input shape should be a tuple "
                             "(conv_dim1, conv_dim2, conv_dim3, channels) "
                             "for tensorflow as backend or "
                             "(channels, conv_dim1, conv_dim2, conv_dim3) "
                             "for theano as backend")

        block_fn = _get_block(block_fn)
        input = Input(shape=input_shape)
        # first conv
        conv1 = _conv_bn_relu3D(filters=64, kernel_size=(7, 7, 7),
                                strides=(2, 2, 2),
                                kernel_regularizer=l2(reg_factor)
                                )(input)
        pool1 = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2),
                             padding="same")(conv1)

        # repeat blocks
        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block3d(block_fn, filters=filters,
                                      kernel_regularizer=l2(reg_factor),
                                      repetitions=r, is_first_layer=(i == 0)
                                      )(block)
            filters *= 2

        # last activation
        block_output = _bn_relu(block)

        # average pooling and prediction head
        pool2 = AveragePooling3D(pool_size=(block._keras_shape[DIM1_AXIS],
                                            block._keras_shape[DIM2_AXIS],
                                            block._keras_shape[DIM3_AXIS]),
                                 strides=(1, 1, 1))(block_output)
        flatten1 = Flatten()(pool2)
        if num_outputs > 1:
            dense = Dense(units=num_outputs,
                          kernel_initializer="he_normal",
                          activation="softmax",
                          kernel_regularizer=l2(reg_factor))(flatten1)
        else:
            dense = Dense(units=num_outputs,
                          kernel_initializer="he_normal",
                          activation="linear",
                          kernel_regularizer=l2(reg_factor))(flatten1)

        model = Model(inputs=input, outputs=dense)
        return model

    @staticmethod
    def build_resnet_18(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 18."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [2, 2, 2, 2], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_34(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 34."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [3, 4, 6, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_50(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 50."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 6, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_101(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 101."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 23, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_152(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 152."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 8, 36, 3], reg_factor=reg_factor)
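
# A minimal usage sketch (added for illustration, not part of the original
# file); num_outputs=1 selects the linear regression head used for age
# prediction in regression_train.py:
#
# model = Resnet3DBuilder.build_resnet_101((121, 145, 121, 1), 1)
# model.compile(optimizer='adam', loss='mean_absolute_error')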
--------------------------------------------------------------------------------
/model/sfcn_keras.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Jeyeon Lee
"""

from keras.layers import (Input, Activation, Conv3D, Dense, MaxPooling3D,
                          BatchNormalization, GlobalAveragePooling3D, Dropout)
from keras.models import Model


def SFCN(input_shape, dropout):
    channel_number = [32, 64, 128, 256, 256, 64]
    output_dim = 1
    # NOTE: the original code overrode the argument with dropout = True here,
    # so dropout was always applied; the parameter is now respected.
    n_layer = len(channel_number)
    img_input = Input(shape=input_shape, name='data')
    for i in range(n_layer):
        out_channel = channel_number[i]
        if i == 0:
            x = Conv3D(out_channel, kernel_size=3, strides=1, padding='same', name='conv_%d' % i)(img_input)
            x = BatchNormalization(axis=-1)(x)
            x = MaxPooling3D(pool_size=2, strides=2, name='pool_%d' % i)(x)
            x = Activation('relu')(x)
        elif i < n_layer - 1:
            x = Conv3D(out_channel, kernel_size=3, strides=1, padding='same', name='conv_%d' % i)(x)
            x = BatchNormalization(axis=-1)(x)
            x = MaxPooling3D(pool_size=2, strides=2, name='pool_%d' % i)(x)
            x = Activation('relu')(x)
        else:
            x = Conv3D(out_channel, kernel_size=(1, 1, 1), strides=(1, 1, 1), padding='same', name='conv_%d' % i)(x)
            x = BatchNormalization(axis=-1)(x)
            x = Activation('relu')(x)
    x = GlobalAveragePooling3D(name='GAP')(x)
    if dropout is True:
        x = Dropout(0.5)(x)
    x = Dense(output_dim, name='fc')(x)
    x = Activation('linear')(x)
    model = Model(inputs=img_input, outputs=x, name='SFCN')
    return model

# model = SFCN((121, 145, 121, 1), dropout=True)
# model.summary()
--------------------------------------------------------------------------------
/model/vgg_16.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 27 03:25:16 2020
@author: J.Lee
Model_01_VGGnet
"""

from keras.layers import (Input, Activation, Conv3D, Dense, MaxPooling3D,
                          GlobalAveragePooling3D)
from keras.models import Model


def vgg16_3D(input_image_size, model_class_num):

    # kernel_init = keras.initializers.glorot_uniform()
    # bias_init = keras.initializers.Constant(value=0.2)
    conv1_filt_num = 64
    conv2_filt_num = 128
    conv3_filt_num = 256
    conv4_filt_num = 512
    FC_first_second = 4096
    FC_last = model_class_num

    input_layer = Input(shape=input_image_size)

    conv1_1 = Conv3D(conv1_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv1_1')(input_layer)
    conv1_2 = Conv3D(conv1_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv1_2')(conv1_1)
    pool1 = MaxPooling3D(pool_size=2, strides=2, name='pool1')(conv1_2)

    conv2_1 = Conv3D(conv2_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv2_1')(pool1)
    conv2_2 = Conv3D(conv2_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv2_2')(conv2_1)
    pool2 = MaxPooling3D(pool_size=2, strides=2, name='pool2')(conv2_2)

    conv3_1 = Conv3D(conv3_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv3_1')(pool2)
    conv3_2 = Conv3D(conv3_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv3_2')(conv3_1)
    conv3_3 = Conv3D(conv3_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv3_3')(conv3_2)
    pool3 = MaxPooling3D(pool_size=2, strides=2, name='pool3')(conv3_3)

    conv4_1 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv4_1')(pool3)
    conv4_2 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv4_2')(conv4_1)
    conv4_3 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv4_3')(conv4_2)
    pool4 = MaxPooling3D(pool_size=2, strides=2, name='pool4')(conv4_3)

    conv5_1 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv5_1')(pool4)
    conv5_2 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv5_2')(conv5_1)
    conv5_3 = Conv3D(conv4_filt_num, kernel_size=3, strides=1, padding='same', activation='relu', name='conv5_3')(conv5_2)
    pool5 = MaxPooling3D(pool_size=2, strides=2, name='pool5')(conv5_3)

    flatten1 = GlobalAveragePooling3D(name='CAM_pool')(pool5)

    # flatten_6 = Flatten()(pool5)
    FC1 = Dense(FC_first_second, activation='relu', name='fc1')(flatten1)
    FC2 = Dense(FC_last, activation='relu', name='fc2')(FC1)

    # The 'linear' activation is a no-op after the relu above; the relu on
    # fc2 constrains outputs to be non-negative, which suits age regression.
    outputs = Activation('linear')(FC2)

    model = Model(inputs=input_layer, outputs=outputs, name='vgg_3D')

    return model
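
# A minimal usage sketch (added for illustration, not part of the original
# file); model_class_num=1 gives the single-output head used in
# regression_train.py:
#
# model = vgg16_3D((121, 145, 121, 1), 1)
# model.compile(optimizer='sgd', loss='mean_absolute_error')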
--------------------------------------------------------------------------------
/regression_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 31 11:36:37 2019

for 3D CAM (DenseNet, ResNet101, VGG16, SFCN)

@author: J.Lee
"""
import inspect
import os
import h5py
import scipy.io as sio
import numpy as np
from sklearn.metrics import mean_absolute_error
from keras.optimizers import SGD, Adam, RMSprop
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from keras.utils import plot_model
from densenet3d_regression import build_densenet_forCAM
from resnet_3d import Resnet3DBuilder
from vgg_16 import vgg16_3D
from sfcn_keras import SFCN
from datetime import datetime

#################### VARIABLES ####################
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
j = 4  # nth fold, 0-4
learning_rate = 0.001
modality = 'mri'
model_id = 2  # 0: densenet, 1: resnet101, 2: VGG16, 3: SFCN
datapath = '/home/m186870/data_j6/dl_data/MRI_masked_bc/regression'
datafilename = 'mri_age_optB4_v2_fold' + str(j) + '.mat'
codepath = '/home/m186870/data_j6/dl_data/dl_code'
lru, bs = 0.5, 4           # learning-rate decay per iteration, batch size
fit_iter, fit_ep = 10, 15  # outer training iterations, epochs per iteration
###################################################

currentcode = inspect.getfile(inspect.currentframe())
print('\n [Info] Running code: ', currentcode)


def Data_Load(mat_file_name):
    os.chdir(datapath)
    mat_contents = h5py.File(mat_file_name, 'r')
    # Transpose because MATLAB v7.3 (HDF5) files store arrays in column-major order
    X_Train = np.transpose(mat_contents['X_Train'])
    Y_Train = np.transpose(mat_contents['Y_Train'])
    X_Val = np.transpose(mat_contents['X_Val'])
    Y_Val = np.transpose(mat_contents['Y_Val'])
    X_Test = np.transpose(mat_contents['X_Test'])
    Y_Test = np.transpose(mat_contents['Y_Test'])

    nanidx = np.isnan(X_Train)
    X_Train[nanidx] = 0
    nanidx = np.isnan(X_Test)
    X_Test[nanidx] = 0
    nanidx = np.isnan(X_Val)
    X_Val[nanidx] = 0

    # Make sure the label vectors are column vectors
    if Y_Train.shape[0] < Y_Train.shape[1]:
        Y_Train = np.transpose(Y_Train)
    if Y_Test.shape[0] < Y_Test.shape[1]:
        Y_Test = np.transpose(Y_Test)
    if Y_Val.shape[0] < Y_Val.shape[1]:
        Y_Val = np.transpose(Y_Val)

    print('\tDatafile: ', datapath, datafilename)
    print('\tX_Train shape :', X_Train.shape)
    print('\tY_Train shape :', Y_Train.shape)
    print('\tX_Val shape :', X_Val.shape)
    print('\tY_Val shape :', Y_Val.shape)
    print('\tX_Test shape :', X_Test.shape)
    print('\tY_Test shape :', Y_Test.shape)
    return X_Train, Y_Train, X_Val, Y_Val, X_Test, Y_Test


os.chdir(codepath)
print(codepath)

# %% Build Model
if model_id == 0:  # densenet3d
    modelname = 'densenet3dregression'
    model = build_densenet_forCAM((121, 145, 121, 1), 0)
    opt = 'Adam'
    outpath = datapath + '/paper_' + modality + '_age_optb4_densenet_lr' + str(learning_rate) + '_' + opt + '/'
elif model_id == 1:  # ResNet
    modelname = 'Resnet101'
    model = Resnet3DBuilder.build_resnet_101((121, 145, 121, 1), 1)
    opt = 'Adam'
    outpath = datapath + '/paper_' + modality + '_age_optb4_resnet_lr' + str(learning_rate) + '_' + opt + '/'
elif model_id == 2:  # VGG
    modelname = 'VGG16'
    model = vgg16_3D((121, 145, 121, 1), 1)
    opt = 'SGD'
    outpath = datapath + '/paper_' + modality + '_age_optb4_vgg_lr' + str(learning_rate) + '_' + opt + '/'
elif model_id == 3:  # SFCN
    modelname = 'SFCN'
    # the original passed the string 'False', which SFCN then overrode to True;
    # dropout is kept enabled here to match that behavior
    model = SFCN((121, 145, 121, 1), True)
    opt = 'SGD'
    outpath = datapath + '/paper_' + modality + '_age_optb4_sfcn_lr' + str(learning_rate) + '_' + opt + '/'

##################################
if not os.path.exists(datapath):
    os.makedirs(datapath)
os.chdir(datapath)
X_Train, Y_Train, X_Val, Y_Val, X_Test, Y_Test = Data_Load(datafilename)
# Add a trailing channel dimension for the 3D CNNs
X_Train = np.reshape(X_Train, (X_Train.shape[0], X_Train.shape[1], X_Train.shape[2], X_Train.shape[3], 1))
X_Val = np.reshape(X_Val, (X_Val.shape[0], X_Val.shape[1], X_Val.shape[2], X_Val.shape[3], 1))
X_Test = np.reshape(X_Test, (X_Test.shape[0], X_Test.shape[1], X_Test.shape[2], X_Test.shape[3], 1))

print(datetime.now())
print('\n [Info] Data loading done')

if not os.path.exists(outpath):
    os.makedirs(outpath)
print('\tProcessing path: ', outpath)

os.chdir(outpath)
print('\n [Info] Model set: ', modelname)
plot_model(model, to_file=modelname + '.pdf', show_shapes=True)
model.summary()
with open(modelname + '.txt', 'w') as f2:
    model.summary(print_fn=lambda x: f2.write(x + '\n'))

# %% Training
print("\n [Info] Training Start!")

for i in range(fit_iter):
    print("\t Validating setnum:", str(j), "-Training iter:", str(i + 1))
    # the "best" and "past best" checkpoints share one file, so each iteration
    # resumes from the best weights found so far
    filepath_weights_best = './weights.best_' + str(j) + 'fold.h5'
    filepath_weights_best_past = './weights.best_' + str(j) + 'fold.h5'

    if i > 0:
        model.load_weights(filepath_weights_best_past)
        learning_rate *= lru
        print('\tcurrent learning_rate : ', learning_rate)

    if 'Adam' in opt:
        if model_id == 2:
            optm = Adam(lr=learning_rate)
        else:
            optm = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    elif 'SGD' in opt:
        optm = SGD(lr=learning_rate)
    elif 'RMSprop' in opt:
        optm = RMSprop(lr=learning_rate, rho=0.9, epsilon=1e-08, decay=0.0)
    model.compile(loss='mean_absolute_error', optimizer=optm)

    checkpoint = ModelCheckpoint(filepath_weights_best, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=6, verbose=1, mode='auto')
    csv_logger = CSVLogger('training.log', separator=',', append=True)
    callbacks_list = [csv_logger, checkpoint, earlystop]
    trained_model = model.fit(X_Train, Y_Train, batch_size=bs, epochs=fit_ep, verbose=1, shuffle=True,
                              callbacks=callbacks_list, validation_data=(X_Val, Y_Val))

    # list all metrics recorded in the training history
    print(trained_model.history.keys())

    model.load_weights(filepath_weights_best)

    y_pred = model.predict(X_Test, batch_size=bs)
    mae = mean_absolute_error(Y_Test, y_pred)

    y_valpred = model.predict(X_Val, batch_size=bs)
    valmae = mean_absolute_error(Y_Val, y_valpred)

    print("[INFO] Test Mean absolute error: {:.4f}".format(mae))
    with open('Acc_' + str(j) + 'fold_iter' + str(i + 1) + '.txt', 'w') as f:
        print("\nlearning_rate:" + str(learning_rate), file=f)
        print("\nValidation mean absolute error: {:.4f}".format(valmae), file=f)
        print("\nTest mean absolute error: {:.4f}".format(mae), file=f)

    outname = 'regression_score_test_' + str(j) + 'fold.mat'
    sio.savemat(outpath + outname, {'Ypred': y_pred, 'Ydata': Y_Test})
--------------------------------------------------------------------------------