├── Screenshots ├── jg2.png ├── modifi_unet.png └── preprocession.jpg ├── README.md ├── proprocess ├── N4BiasFieldCorrection.py ├── data.py ├── register.py └── dicom_nii.py ├── train.py └── unet_model.py /Screenshots/jg2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zoukai214/CT-Synthetic-MR-Images/HEAD/Screenshots/jg2.png -------------------------------------------------------------------------------- /Screenshots/modifi_unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zoukai214/CT-Synthetic-MR-Images/HEAD/Screenshots/modifi_unet.png -------------------------------------------------------------------------------- /Screenshots/preprocession.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zoukai214/CT-Synthetic-MR-Images/HEAD/Screenshots/preprocession.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ct_to_mri 2 | # Input CT data and use a U-net to synthesize MRI 3 | 来自论文《Whole Brain Segmentation and Labeling from CT Using Synthetic MR Images》,主要用于通过分割模型,将CT数据合成MRI数据。 4 | 5 | ## 运行环境 6 | keras>1.3 7 | Tensorflow>1.0 8 | SimpleITK=1.1.0 9 | ## 数据收集 10 | 收集同一人的CT/MRI数据,需要对数据进行N4偏差场矫正,白质均值正则化,再将两个数据进行刚性配准。 11 | 12 | ## 模型的修改 13 | ![改进的unet](Screenshots/modifi_unet.png) 14 | 15 | 按上图对Unet进行修改,参考unet_model.py文件 16 | 17 | ## 数据预处理 18 | ![数据预处理](Screenshots/preprocession.jpg) 19 | 如上图对CT,MRI数据进行数据预处理 20 | ## 训练 21 | train.py 22 | -------------------------------------------------------------------------------- /proprocess/N4BiasFieldCorrection.py: 
# --------------------------------------------------------------------------------
import SimpleITK as sitk
import os


def N4(input_dir, OUTPUT_DIR):
    """Run N4 bias-field correction on every '*_mri.nii' found under input_dir.

    Corrected volumes are written to OUTPUT_DIR under the same base name.
    """
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            # MRI volumes are named '<subject>_mri.nii'; match on the suffix
            # after the last underscore.
            num = file.rfind('_')
            suffix = file[num + 1:]
            if suffix == "mri.nii":
                mri_path = os.path.join(root, file)
                print(mri_path)
                inputImage = sitk.ReadImage(mri_path)
                # Otsu threshold yields a rough head mask (foreground label 1)
                # so the bias field is estimated from tissue voxels only.
                maskImage = sitk.OtsuThreshold(inputImage, 0, 1, 200)

                inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)

                corrector = sitk.N4BiasFieldCorrectionImageFilter()
                output = corrector.Execute(inputImage, maskImage)
                # BUG FIX: 'file' already ends in '.nii', so the original
                # '{}.nii'.format(file) wrote '<name>.nii.nii', which also
                # broke the '_mri.nii' suffix matching in downstream scripts.
                base = os.path.splitext(file)[0]
                sitk.WriteImage(output, os.path.join(OUTPUT_DIR, '{}.nii'.format(base)))


if __name__ == "__main__":
    input_dir = "/media/zoukai/软件/DATA/Brain_register_ct/"
    OUTPUT_DIR = "/media/zoukai/软件/DATA/Brain_N4/"
    N4(input_dir, OUTPUT_DIR)
# --------------------------------------------------------------------------------
# /proprocess/data.py:
# --------------------------------------------------------------------------------
# Build a (5120, 128, 128, 1) patch array (.npy) used as training/test data.
import glob
import os
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from PIL import Image
from skimage import transform
# BUG FIX: the original contained '%matplotlib inline', which is IPython magic
# and a SyntaxError in a plain .py file; it has been commented out.  The
# duplicate 'import numpy as np' was also removed.
# %matplotlib inline

# 1. Collect the .nii volumes in the input folder.
image_path = "/home/zoukai/Data/CTMR/dwicrop/8_trainct"
image_name_arr = glob.glob(os.path.join(image_path, "*.nii"))

# 2. Pre-allocate the patch array and the crop geometry.
# BUG FIX: patches hold CT values shifted by +1000 (roughly 0..3000), so the
# original uint8 buffer silently truncated them; store float32 instead.
imgs = np.ndarray((5120, 128, 128, 1), dtype=np.float32)
output_shape = np.array([128, 128, 1])
stride = np.array([128, 128, 1])

# 3. Walk every volume and cut out 128x128 patches.
c = 0
for index, item in enumerate(image_name_arr):
    ct_image = sitk.ReadImage(item, sitk.sitkFloat32)
    img = sitk.GetArrayFromImage(ct_image).astype(np.float32)
    img = img.transpose(2, 1, 0)  # sitk returns (z, y, x); work in (x, y, z)
    # Shift CT values so they are non-negative.
    img = img + 1000.

    # 192x192 slices: four overlapping 128x128 crops per slice (64 px overlap).
    if img.shape[0] == 192:
        for i in range(img.shape[2]):
            for h in range(2):
                for w in range(2):
                    # Renamed from 'image_path', which shadowed the input
                    # directory variable above.
                    patch = img[h * (stride[0] - 64):h * (stride[0] - 64) + output_shape[0],
                                w * (stride[1] - 64):w * (stride[1] - 64) + output_shape[1],
                                i * stride[2]:i * stride[2] + output_shape[2]]
                    imgs[c] = patch
                    c += 1
        print(c)
    else:
        # Already 128x128: one patch per slice.
        for i in range(img.shape[2]):
            imgs[c] = img[:, :, i * stride[2]:i * stride[2] + output_shape[2]]
            c += 1
        print(c)

np.save('../data/input_npy/8_trainct.npy', imgs)
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
from unet_model import *
from keras.callbacks import TensorBoard

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def train_and_predict():
    """Load the CT/MRI patch arrays and fit the U-net (CT -> synthetic MRI)."""
    print("-" * 30)
    print("Loading and preprocessing train data ...")
    print("-" * 30)

    imgs_train = np.load('/home/zoukai/project/unet-master/unet-master/data/input_npy/7_trainct.npy')
    imgs_mask_train = np.load('/home/zoukai/project/unet-master/unet-master/data/input_npy/7_trainmri.npy')

    imgs_train = imgs_train.astype('float32')
    imgs_mask_train = imgs_mask_train.astype('float32')

    print("-" * 30)
    print("Create and compiling model...")
    print('-' * 30)
    model = unet()
    # Keep only the weights with the best (lowest) training loss.
    model_checkpoint = ModelCheckpoint('7_normalization.hdf5', monitor='loss', verbose=1, save_best_only=True)
    print("-" * 30)
    print("Fitting model ...")
    print("-" * 30)

    tb = TensorBoard(log_dir='./logs',       # log directory
                     histogram_freq=1,       # epochs between histogram computations (0 = off)
                     batch_size=32,          # batch size used to compute histograms
                     write_graph=True,       # store the model graph
                     write_grads=False,      # gradient histograms
                     write_images=False,     # weight visualisation
                     embeddings_freq=0,
                     embeddings_layer_names=None,
                     embeddings_metadata=None)

    model.fit(imgs_train, imgs_mask_train, batch_size=16, epochs=1000, validation_split=0.2, verbose=1, shuffle=True,
              callbacks=[model_checkpoint, tb])


if __name__ == "__main__":
    train_and_predict()
# --------------------------------------------------------------------------------
# /proprocess/register.py:
# --------------------------------------------------------------------------------
import os
import SimpleITK as sitk


def rig_register(CT_DIR, MRI_DIR, OUTPUT_DIR, OUTPUT_NAME):
    """Rigidly register the MRI (moving) onto the CT (fixed) volume.

    Writes the resampled MRI as '<OUTPUT_NAME>.nii' and the final transform
    as '<OUTPUT_NAME>.tfm' into OUTPUT_DIR.
    """
    fixed_image = sitk.ReadImage(CT_DIR, sitk.sitkFloat32)
    # Keep only the first 20 slices (the DWI portion of the volume).
    fixed_image = fixed_image[:, :, :20]
    moving_image = sitk.ReadImage(MRI_DIR, sitk.sitkFloat32)

    # Geometry-centred rigid (Euler 3D) initial alignment.
    initial_transform = sitk.CenteredTransformInitializer(fixed_image,
                                                          moving_image,
                                                          sitk.Euler3DTransform(),
                                                          sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # NOTE(review): the original also resampled with the initial transform
    # here; that result was immediately overwritten after registration and
    # has been removed as a dead store.
    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer settings.
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100,
                                                      convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
32 | registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1]) 33 | registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0]) 34 | registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() 35 | 36 | # Don't optimize in-place, we would possibly like to run this cell multiple times. 37 | registration_method.SetInitialTransform(initial_transform, inPlace=False) 38 | final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 39 | sitk.Cast(moving_image, sitk.sitkFloat32)) 40 | moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelID()) 41 | 42 | sitk.WriteImage(moving_resampled, os.path.join(OUTPUT_DIR, '{}.nii'.format(OUTPUT_NAME))) 43 | sitk.WriteTransform(final_transform, os.path.join(OUTPUT_DIR, '{}.tfm'.format(OUTPUT_NAME))) 44 | 45 | # if __name__ == '__main__': 46 | # #输入ct文件的路径 47 | # CT_DIR = "/home/zoukai/Data/CTMR/CT/20180703_1704151HeadRoutines004a002.nii" 48 | # #输入mri文件的路径 49 | # MRI_DIR = "/home/zoukai/Data/CTMR/MR/sub-25634_ses-1_T1w.nii" 50 | # #输出文件夹的路径 51 | # OUTPUT_DIR = "/home/zoukai/Data/CTMR/MR_2/" 52 | # #输出文件的名字 53 | # OUTPUT_NAME = "ZK_007" 54 | # rig_register(CT_DIR, MRI_DIR,OUTPUT_DIR, OUTPUT_NAME) 55 | -------------------------------------------------------------------------------- /proprocess/dicom_nii.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding:utf-8 -*- 3 | import SimpleITK as sitk 4 | import os 5 | import numpy as np 6 | 7 | def ReadDicom(root,OUTPUT_DIR): 8 | for filename in os.listdir(root): 9 | pathname = os.path.join(root,filename) 10 | if (os.path.isdir(pathname)): 11 | for two_file in os.listdir(pathname): 12 | two_file_dir = os.path.join(pathname,two_file) 13 | if (os.path.isdir(two_file_dir)): 14 | if two_file == "CT": 15 | two_pathname = os.path.join(two_file_dir,"DICOM") 16 | #print(two_file_dir) 17 | reader = 
sitk.ImageSeriesReader() 18 | dicom_names = reader.GetGDCMSeriesFileNames(two_pathname) 19 | reader.SetFileNames(dicom_names) 20 | image = reader.Execute() 21 | sitk.WriteImage(image, os.path.join(OUTPUT_DIR, '{}_{}.nii'.format(filename,two_file))) 22 | 23 | def regieter(input_dir,OUTPUT_DIR): 24 | L = [] 25 | for root, dirs, files in os.walk(input_dir): 26 | for file in files: 27 | num = file.rfind('_') 28 | ct_name = file[num+1:] 29 | mri_name = file[:num] 30 | if ct_name == "CT.nii": 31 | zk_mri_path = mri_name+'_mri.nii' 32 | mri_path = os.path.join(root,zk_mri_path) 33 | if os.path.exists(mri_path): 34 | MRI_DIR = mri_path 35 | else: 36 | MRI_DIR = mri_name +'_DWI.nii' 37 | MRI_DIR = os.path.join(root,MRI_DIR) 38 | 39 | CT_DIR = os.path.join(root,file) 40 | print(CT_DIR,MRI_DIR) 41 | # fixed_image = sitk.ReadImage(fdata("training_001_ct.mha"), sitk.sitkFloat32) 42 | fixed_image = sitk.ReadImage(MRI_DIR, sitk.sitkFloat32) 43 | ## 取其中的体素spacing 44 | fixed_image_spaceing = fixed_image.GetSpacing() 45 | moving_image = sitk.ReadImage(CT_DIR, sitk.sitkFloat32) 46 | # 取DWI数据,即取mri中的前20层 47 | fixed_image = sitk.GetImageFromArray(sitk.GetArrayFromImage(fixed_image)[0]) 48 | fixed_image.SetSpacing(fixed_image_spaceing) 49 | 50 | 51 | 52 | initial_transform = sitk.CenteredTransformInitializer(fixed_image, 53 | moving_image, 54 | sitk.Euler3DTransform(), 55 | sitk.CenteredTransformInitializerFilter.GEOMETRY) 56 | 57 | 58 | registration_method = sitk.ImageRegistrationMethod() 59 | 60 | # Similarity metric settings. 61 | registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) 62 | registration_method.SetMetricSamplingStrategy(registration_method.RANDOM) 63 | registration_method.SetMetricSamplingPercentage(0.01) 64 | 65 | registration_method.SetInterpolator(sitk.sitkLinear) 66 | 67 | # Optimizer settings. 
68 | registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, 69 | convergenceMinimumValue=1e-6, 70 | convergenceWindowSize=10) 71 | registration_method.SetOptimizerScalesFromPhysicalShift() 72 | 73 | # Setup for the multi-resolution framework. 74 | 75 | registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1]) 76 | registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0]) 77 | registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() 78 | 79 | # Don't optimize in-place, we would possibly like to run this cell multiple times. 80 | registration_method.SetInitialTransform(initial_transform, inPlace=False) 81 | final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 82 | sitk.Cast(moving_image, sitk.sitkFloat32)) 83 | moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, 84 | moving_image.GetPixelID()) 85 | 86 | sitk.WriteImage(moving_resampled, os.path.join(OUTPUT_DIR, 'sys_{}_mri.nii'.format(mri_name))) 87 | #sitk.WriteTransform(final_transform, os.path.join(OUTPUT_DIR, 'sys_{}.tfm'.format(mri_name))) 88 | 89 | 90 | def sys_mri(input_dir,OUTPUT_DIR): 91 | for root, dirs, files in os.walk(input_dir): 92 | for file in files: 93 | num = file.rfind('_') 94 | ct_name = file[num+1:] 95 | mri_name = file[:num] 96 | if ct_name == "CT.nii": 97 | zk_mri_path = mri_name+'_mri.nii' 98 | mri_path = os.path.join(root,zk_mri_path) 99 | if os.path.exists(mri_path): 100 | MRI_DIR = mri_path 101 | else: 102 | MRI_DIR = mri_name +'_DWI.nii' 103 | MRI_DIR = os.path.join(root,MRI_DIR) 104 | MRI_IMAGE = sitk.ReadImage(MRI_DIR, sitk.sitkFloat32) 105 | mri_image_spaceing = MRI_IMAGE.GetSpacing() 106 | mri_image = sitk.GetImageFromArray(sitk.GetArrayFromImage(MRI_IMAGE)[0]) 107 | mri_image.SetSpacing(mri_image_spaceing) 108 | sitk.WriteImage(mri_image, os.path.join(OUTPUT_DIR, 'sys_{}_mri.nii'.format(mri_name))) 109 | #root = 
"/media/zoukai/软件/yikai/Brain_MRI2_nanshan/" 110 | #OUTPUT_DIR = "/home/zoukai/Data/CTMR/DWI2" 111 | #ReadDicom(root,OUTPUT_DIR) 112 | input_dir = "/media/zoukai/软件/DATA/Brain_MRI2/" 113 | OUTPUT_DIR = "/media/zoukai/软件/DATA/Brain_3D_MRI/" 114 | sys_mri(input_dir,OUTPUT_DIR) 115 | #regieter(input_dir,OUTPUT_DIR) 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /unet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import skimage.io as io 4 | import skimage.transform as trans 5 | import numpy as np 6 | from keras.models import * 7 | # from keras.layers import * 8 | from keras.layers import MaxPooling2D, UpSampling2D, Dropout 9 | from keras.layers.convolutional import Conv2D, Conv2DTranspose 10 | from keras.layers.merge import concatenate 11 | from keras.optimizers import * 12 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 13 | from keras import backend as keras 14 | 15 | from keras.layers import BatchNormalization 16 | 17 | 18 | def my_loss(y_true, y_pred): 19 | 20 | m,n_H,n_W,n_C = y_pred.get_shape().as_list() 21 | y_true_unrolled = keras.transpose(keras.reshape(y_true,[-1])) 22 | y_pred_unrolled = keras.transpose(keras.reshape(y_pred,[-1])) 23 | J_content = keras.sum(keras.square(y_pred-y_true)/(4*n_H*n_W*n_C)) 24 | return J_content 25 | 26 | def unet(pretrained_weights=None, input_size=(128, 128, 1)): 27 | inputs = Input(input_size) 28 | # conv1 = model.add(Conv2D(64,3,activation='relu',padding='same',kernel_initial='he_normal')(inputs)) 29 | conv1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) 30 | conv1 = BatchNormalization(axis=-1)(conv1) 31 | conv1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1) 32 | conv1 = BatchNormalization(axis=-1)(conv1) 33 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 34 | 35 | 
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) 36 | conv2 = BatchNormalization(axis=-1)(conv2) 37 | conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2) 38 | conv2 = BatchNormalization(axis=-1)(conv2) 39 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 40 | 41 | conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) 42 | conv3 = BatchNormalization(axis=-1)(conv3) 43 | conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3) 44 | conv3 = BatchNormalization(axis=-1)(conv3) 45 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 46 | 47 | conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) 48 | conv4 = BatchNormalization(axis=-1)(conv4) 49 | conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4) 50 | conv4 = BatchNormalization(axis=-1)(conv4) 51 | #drop4 = Dropout(0.2)(conv4) 52 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 53 | 54 | conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) 55 | conv5 = BatchNormalization(axis=-1)(conv5) 56 | conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5) 57 | conv5 = BatchNormalization(axis=-1)(conv5) 58 | #drop5 = Dropout(0.2)(conv5) 59 | 60 | # up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) 61 | #up6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), activation='relu', padding='same',kernel_initializer='he_normal')(conv5) 62 | up6 = UpSampling2D((2, 2))(conv5) 63 | # merge6 = merge([drop4,up6], mode = 'concat', concat_axis = 3) 64 | merge6 = concatenate([conv4, up6], axis=3) 65 | conv6 = Conv2D(256, 5, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) 66 | conv6 = 
BatchNormalization(axis=-1)(conv6) 67 | conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6) 68 | conv6 = BatchNormalization(axis=-1)(conv6) 69 | 70 | # up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) 71 | #up7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), activation='relu', padding='same',kernel_initializer='he_normal')(conv6) 72 | up7 = UpSampling2D((2,2))(conv6) 73 | # merge7 = merge([conv3,up7], mode = 'concat', concat_axis = 3) 74 | merge7 = concatenate([conv3, up7], axis=3) 75 | conv7 = Conv2D(128, 5, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) 76 | conv7 = BatchNormalization(axis=-1)(conv7) 77 | conv7 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7) 78 | conv7 = BatchNormalization(axis=-1)(conv7) 79 | 80 | # up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) 81 | #up8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), activation='relu', padding='same',kernel_initializer='he_normal')(conv7) 82 | up8 = UpSampling2D((2,2))(conv7) 83 | merge8 = concatenate([conv2, up8], axis=3) 84 | conv8 = Conv2D(64, 5, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) 85 | conv8 = BatchNormalization(axis=-1)(conv8) 86 | conv8 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8) 87 | conv8 = BatchNormalization(axis=-1)(conv8) 88 | 89 | # up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) 90 | #up9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), activation='relu', padding='same',kernel_initializer='he_normal')(conv8) 91 | up9 = UpSampling2D((2,2))(conv8) 92 | merge9 = concatenate([conv1, up9], axis=3) 93 | conv9 = Conv2D(32, 5, activation='relu', padding='same', 
kernel_initializer='he_normal')(merge9) 94 | conv9 = BatchNormalization(axis=-1)(conv9) 95 | conv9 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9) 96 | conv9 = BatchNormalization(axis=-1)(conv9) 97 | 98 | merge10 = concatenate([inputs,conv9],axis=3) 99 | #conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge10) 100 | conv10 = Conv2D(1, 1, activation='relu',padding='same')(merge10) 101 | 102 | model = Model(input=inputs, output=conv10) 103 | 104 | #model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy']) 105 | model.compile(optimizer=Adam(lr=1e-4), loss='mean_squared_error', metrics=['accuracy','mae']) 106 | # model.summary() 107 | 108 | if (pretrained_weights): 109 | model.load_weights(pretrained_weights) 110 | 111 | return model 112 | --------------------------------------------------------------------------------