├── UNET.txt ├── README.md ├── one_hot_encoding.py ├── augmentation.py ├── preprocessing.py ├── Unet1.py ├── data_train.py └── main_.py /UNET.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unet 2 | The LGE MRI used in the study were collected from 45 patients, of which 9 were randomly selected for testing. To augment the training data, I registered the training images to other image spaces using a set of artificially generated rigid, affine and deformable transformations, resulting in 5405 2D slices. 3 | I used Dice coefficient as metrics for evaluation of segmentation accuracy. The Dice of LV blood pool, Dice of Myocardium and Dice of RV blood pool on test data have reached 0.90,0.81 and 0.83 respectively. 4 | -------------------------------------------------------------------------------- /one_hot_encoding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 5 16:45:45 2020 4 | 5 | @author: DELL 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import glob 11 | PATH = 'D:/img' 12 | SIZE = 240 13 | def change(path): 14 | #path = 'D:/grade4.2/pj1/label1/aug_label_0_8741.png' 15 | img = cv2.imread(path) 16 | img1 = np.array(img) 17 | #img2= np.array(img) 18 | #print(img1[:,200,:]) 19 | #print(img1[:,120,:]) 20 | img1 = img1/255. 21 | #print(img) 22 | return img1 23 | #print(img) 24 | #img /= np.std(img,axis=0) 25 | #sum_ = np.sum(img==[84,1,68]*a) 26 | #print(sum_) 27 | #img 28 | #a = img[:,:,2]#.reshape(240,240,1) 29 | #a = a.reshape(240,240,1) 30 | #img = np.concatenate((img,a),axis=2) 31 | 32 | #print(img1[:,120]) 33 | z = img1[:,:,0].reshape(240,240,1)#mask 34 | a = img2[:,:,0].reshape(240,240,1)#huang 35 | b = img1[:,:,2].reshape(240,240,1)#lan 36 | c = img2[:,:,2].reshape(240,240,1) 37 | 38 | #print(z[:,120]) 39 | #a[a<50]=0 40 | z[z<80]=0 41 | z[z>90]=0 42 | z[z>0]=1 43 | #print(z[:,120]) 44 | a[a<130]=0 45 | a[a>=150]=0 46 | a[a>0]=1 47 | b[b<=230]=0 48 | b[b>230]=1 49 | #print(b[:,120]) 50 | #c[c<120]=0 51 | c[c>155]=0 52 | c[c<144]=0 53 | c[c>0]=1 54 | #print(c[:,200]) 55 | img3 = np.concatenate((z,a),axis=2) 56 | img3 = np.concatenate((img3,b),axis=2) 57 | img3 = np.concatenate((img3,c),axis=2) 58 | #print(img3[:,120,:]) 59 | #cv2.imwrite('D:/grade4.2/pj1/merge/2.png',img1) 60 | 61 | return img3 62 | 63 | def main(): 64 | paths = glob.glob(PATH+'/*.png')[:111] 65 | #paths = paths[800:] 66 | #print(paths[1]) 67 | imgs = np.ndarray((len(paths),SIZE,SIZE,3),dtype=np.float32) 68 | i = 0 69 | for name in paths: 70 | img = change(name) 71 | imgs[i] = img 72 | #print(imgs[i]) 73 | i+=1 74 | print(i) 75 | #print(imgs[5,:,160,:]) 76 | np.save(PATH+'/T2try_img.npy',imgs) 77 | main() -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 5 16:44:26 2020 4 | 5 | @author: DELL 6 | """ 7 | 8 | import glob 9 | import random 10 | from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img 11 | from PIL import Image 12 | path = 'D:/img' 13 | path_ = 'D:/label' 14 | def load_data(path): 15 | img = load_img(path) 16 | img = img_to_array(img) 17 | img = img.reshape((1,)+img.shape) 18 | return img 19 | def image_augmentation(img,label,augnum):#num means batch_size 20 | image_datagen = ImageDataGenerator(rotation_range = 0.2, 21 | width_shift_range = 0.1, 22 | height_shift_range = 0.1, 23 | shear_range = 0.1, 24 | zoom_range = 0.1, 25 | fill_mode = 'nearest') 26 | label_datagen = ImageDataGenerator(rotation_range = 0.2, 27 | width_shift_range = 0.1, 28 | height_shift_range = 0.1, 29 | shear_range = 0.1, 30 | zoom_range = 0.1, 31 | fill_mode = 'nearest') 32 | 33 | #random.seed(1) 34 | seed = random.randint(1,100000) 35 | n = 0 36 | for batch in image_datagen.flow(img,batch_size=1,save_to_dir=path,save_prefix='aug',save_format='png',seed=seed): 37 | n +=1 38 | if n > augnum: 39 | break 40 | n = 0 41 | 42 | for batch in image_datagen.flow(label,batch_size=1,save_to_dir=path_,save_prefix='aug_label',save_format='png',seed=seed): 43 | n +=1 44 | if n >augnum: 45 | break 46 | 47 | return 48 | 49 | img_path = glob.glob(path+'/*.png') 50 | #img_path = img_path[:5] 51 | label_path = glob.glob(path_+'/*.png') 52 | #label_path = label_path[:5] 53 | for i in range(len(img_path)): 54 | img = load_data(img_path[i]) 55 | label = load_data(label_path[i]) 56 | image_augmentation(img,label,15) 57 | ''' 58 | img = load_data(path) 59 | image_augmentation(img,5) 60 | ''' -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 5 16:38:28 2020 4 | 5 | @author: DELL 6 | """ 7 | 8 | import SimpleITK as sitk 9 | import sys,os 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | IMAGE_PATH = '' 13 | IMAGE_FORMAT = '.png' 14 | LABEL_PATH = '' 15 | LABEL_FORMAT = '.png' 16 | 17 | path_ = '' 18 | path_lb = '' 19 | #200=pool of left ventricle,500=myocardium of left ventricle,600=pool of right ventricle 20 | LABEL_NUM = [200,500,600] 21 | #img = sitk.ReadImage(path) 22 | #data = sitk.GetArrayFromImage(img) 23 | #all_one_array = np.ones_like(data[0]) 24 | def load_img(path): 25 | img = sitk.ReadImage(path) 26 | data = sitk.GetArrayFromImage(img) 27 | return data 28 | ''' 29 | def clip_img(img,path): 30 | for i in range(img.shape[0]) 31 | clip = img[i] 32 | clip = clip[0:240,0:240] 33 | clip_file = os.path.join(IMAGE_PATH,path+str(i)+IMAGE_FORMAT) 34 | plt.imsave(clip_file,clip) 35 | ''' 36 | def clip_all(img,label,path1,path2): 37 | all_one_array = np.ones_like(img[0]) 38 | a = '' 39 | for i in range(label.shape[0]): 40 | true_sum = np.sum(label[i] == LABEL_NUM[0]*all_one_array)+np.sum(label[i]==LABEL_NUM[1]*all_one_array)+np.sum(label[i]==LABEL_NUM[2]*all_one_array) 41 | if true_sum>0: 42 | clip = img[i] 43 | clip_ = label[i] 44 | clip = clip[100:340,100:340] 45 | clip_ = clip_[100:340,100:340] 46 | clip_file = os.path.join(IMAGE_PATH,path1+str(i)+ IMAGE_FORMAT) 47 | clip_file_ = os.path.join(LABEL_PATH,path2+str(i)+LABEL_FORMAT) 48 | plt.imsave(clip_file,clip) 49 | plt.imsave(clip_file_,clip_) 50 | else: 51 | a += str(i) 52 | return a 53 | 54 | 55 | 56 | def main(): 57 | path_ = 'D:/grade4.2/pj1/MSCMRSeg/MSCMRSeg/LGE/img' 58 | path_lb = 'D:/grade4.2/pj1/MSCMRSeg/MSCMRSeg/LGE/lab' 59 | dirs_ = os.listdir(path_) 60 | #dirs_ = dirs[0::3] 61 | dirs_lb = os.listdir(path_lb) 62 | total_list = [] 63 | for i in range(len(dirs_)): 64 | img = load_img(os.path.join(path_,dirs_[i])) 65 | label = load_img(os.path.join(path_lb,dirs_lb[i])) 66 | a = clip_all(img,label,os.path.splitext(os.path.splitext(dirs_[i])[0])[0],os.path.splitext(os.path.splitext(dirs_lb[i])[0])[0]) 67 | total_list.append(a) 68 | total_list = np.array(total_list) 69 | np.save('D:/grade4.2/pj1/list.npy',total_list) 70 | ''' 71 | for dir in dirs_: 72 | img = load_img(os.path.join(path,dir)) 73 | clip_img(img,os.path.splitext(os.path.splitext(dir)[0])[0]) 74 | dirs_lb = os.listdir(path_lb) 75 | 76 | for dir in dirs_lb: 77 | img = load_img(os.path.join(path_lb,dir)) 78 | clip_label(img,os.path.splitext(os.path.splitext(dir)[0])[0]) 79 | ''' 80 | main() 81 | -------------------------------------------------------------------------------- /Unet1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Mar 8 15:23:22 2020 4 | 5 | @author: DELL 6 | """ 7 | import os 8 | import glob 9 | from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img 10 | from PIL import Image 11 | import numpy as np 12 | import tensorflow as tf 13 | from tensorflow.python.keras.callbacks import ModelCheckpoint, LearningRateScheduler 14 | from tensorflow.python.keras.models import * 15 | from tensorflow.python.keras.layers import BatchNormalization,concatenate,Input, Conv2D, Activation,MaxPooling2D, Dense,UpSampling2D, Dropout, Cropping2D,Flatten,Add,Multiply 16 | import random 17 | from tensorflow.python import keras 18 | from tensorflow.python.keras import backend 19 | import matplotlib.pyplot as plt 20 | from tensorflow.python.keras import * 21 | import math 22 | SIZE = 240 23 | IMG_PATH = './LGE' 24 | LABEL_PATH = './LGE' 25 | IMG_PATH1 = './data3' 26 | LABEL_PATH1 = './data3' 27 | class LossHistory(keras.callbacks.Callback): 28 | def on_train_begin(self, logs={}): 29 | self.losses = {'batch':[], 'epoch':[]} 30 | self.accuracy = {'batch':[], 'epoch':[]} 31 | self.val_loss = {'batch':[], 'epoch':[]} 32 | self.val_acc = {'batch':[], 'epoch':[]} 33 | 34 | def on_batch_end(self, batch, logs={}): 35 | self.losses['batch'].append(logs.get('loss')) 36 | self.accuracy['batch'].append(logs.get('acc')) 37 | self.val_loss['batch'].append(logs.get('val_loss')) 38 | self.val_acc['batch'].append(logs.get('val_acc')) 39 | 40 | def on_epoch_end(self, batch, logs={}): 41 | self.losses['epoch'].append(logs.get('loss')) 42 | self.accuracy['epoch'].append(logs.get('acc')) 43 | self.val_loss['epoch'].append(logs.get('val_loss')) 44 | self.val_acc['epoch'].append(logs.get('val_acc')) 45 | 46 | def loss_plot(self, loss_type): 47 | iters = range(len(self.losses[loss_type])) 48 | plt.figure() 49 | # acc 50 | plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc') 51 | # loss 52 | plt.plot(iters, self.losses[loss_type], 'g', label='train loss') 53 | if loss_type == 'epoch': 54 | # val_acc 55 | plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc') 56 | # val_loss 57 | plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss') 58 | plt.grid(True) 59 | plt.xlabel(loss_type) 60 | plt.ylabel('acc-loss') 61 | plt.legend(loc="upper right") 62 | plt.show() 63 | 64 | def scheduler(epoch): 65 | #lr = backend.get_value(model.optimizer.lr) 66 | # initial learningrate=0.01 67 | if epoch == 0: 68 | lr = 0.001 69 | return lr 70 | else: 71 | lr = backend.get_value(model.optimizer.lr) 72 | #backend.set_value(model.optimizer.lr,lr*math.exp(-0.3*epoch)) 73 | backend.set_value(model.optimizer.lr,lr*math.exp(-0.3*epoch)) 74 | print("lr changed to {}".format(lr*math.exp(-0.5*epoch))) 75 | return backend.get_value(model.optimizer.lr) 76 | def load_data(path): 77 | img = load_img(path) 78 | img = img_to_array(img) 79 | img = img.reshape((1,)+img.shape) 80 | return img 81 | def image_augmentation(img,label,augnum):#num means batch_size 82 | #label = label.reshape((1,)+label.shape) 83 | 84 | image_datagen = ImageDataGenerator(rotation_range = 0.2, 85 | width_shift_range = 0.2, 86 | height_shift_range = 0.2, 87 | shear_range = 0.2, 88 | zoom_range = 0.2, 89 | fill_mode = 'nearest') 90 | label_datagen = ImageDataGenerator(rotation_range = 0.2, 91 | width_shift_range = 0.2, 92 | height_shift_range = 0.2, 93 | shear_range = 0.2, 94 | zoom_range = 0.2, 95 | fill_mode = 'nearest') 96 | seed = random.randint(1,10000) 97 | n = 0 98 | seed_ = 1 99 | image_datagen.fit(img,seed=seed_) 100 | label_datagen.fit(label,seed=seed_) 101 | for batch in image_datagen.flow(img,batch_size=1,save_to_dir=IMG_PATH,save_prefix='aug',save_format='png',seed=seed): 102 | n +=1 103 | if n > augnum: 104 | break 105 | n = 0 106 | for batch in label_datagen.flow(label,batch_size=1,save_to_dir=LABEL_PATH,save_prefix='aug_label',save_format='png',seed=seed): 107 | n +=1 108 | if n >augnum: 109 | break 110 | return 111 | ''' 112 | seed = 1 113 | image_datagen.fit(img,augment=True,seed=seed) 114 | label_datagen.fit(label,augment=True,seed=seed) 115 | image_generator = image_datagen.flow_from_directory(IMG_PATH,class_mode=None,seed=seed) 116 | label_generator = label_datagen.flow_from_directory(LABEL_PATH,class_mode=None,seed=seed) 117 | train_generator = zip(image_generator,label_generator) 118 | return train_generator 119 | ''' 120 | def create_train_data(): 121 | i = 0 122 | print("Creating training images...") 123 | imgs = glob.glob(IMG_PATH +'/*.png') 124 | #labels = glob.glob(LABEL_PATH+'/*.png') 125 | imgdatas = np.ndarray((len(imgs),SIZE,SIZE,3),dtype=np.int8) 126 | #imglabels = np.ndarray((len(labels),SIZE,SIZE,4),dtype=np.int8) 127 | for imgname in imgs: 128 | img = load_img(imgname) 129 | img = img_to_array(img) 130 | img = img/255. 131 | img -= np.mean(img) 132 | img /= np.std(img,axis=0) 133 | imgdatas[i] = img 134 | i += 1 135 | i = 0 136 | ''' 137 | for imgname in labels: 138 | img = load_img(imgname) 139 | img = img_to_array(img) 140 | img0 = img[:,:,2].reshape(240,240,1)#mask 141 | img1 = img[:,:,2].reshape(240,240,1)#yellow 142 | img2 = img[:,:,0].reshape(240,240,1)#blue 143 | img3 = img[:,:,1].reshape(240,240,1)#green 144 | img0[img0<80]=0 145 | img0[img0>90]=0 146 | img0[img0>0]=1 147 | img1[img1<130]=0 148 | img1[img1>150]=0 149 | img1[img1>0]=1 150 | img2[img2>=230]=1 151 | img2[img2<230]=0 152 | img3[img3>=200]=1 153 | img3[img3<200]=0 154 | img = np.concatenate((img0,img1),axis=2) 155 | img = np.concatenate((img,img2),axis=2) 156 | img = np.concatenate((img,img3),axis=2) 157 | #img /= np.std(img,axis=0) 158 | imglabels[i] = img 159 | i += 1 160 | ''' 161 | print('loading done') 162 | np.save(IMG_PATH + '/imgs_train.npy',imgdatas) 163 | #np.save(LABEL_PATH + '/labels_train.npy',imglabels) 164 | print('saved') 165 | return 166 | def load_train_data(): 167 | print('loading') 168 | img_train = np.load(IMG_PATH1+'/imgs_train.npy') 169 | img_train -= np.mean(img_train2) 170 | img_train /= np.std(img_train2,axis=0) 171 | img_label = np.load(LABEL_PATH1+'/labels_train.npy') 172 | return img_train,img_label 173 | def dice_coef(y_true,y_pred): 174 | sum1 = 2*tf.reduce_sum(y_true*y_pred,axis=(0,1,2)) 175 | sum2 = tf.reduce_sum(y_true**2+y_pred**2,axis=(0,1,2)) 176 | dice = (sum1+0.1)/(sum2+0.1) 177 | dice = tf.reduce_mean(dice) 178 | return dice 179 | def dice_coef_loss(y_true,y_pred): 180 | return 1.-dice_coef(y_true,y_pred) 181 | def dice_Score_0(y_true,y_pred): 182 | sum1 = tf.reduce_sum(y_true*y_pred,axis=(0,1,2)) 183 | sum2 = tf.reduce_sum(y_true**2+y_pred**2,axis=(0,1,2)) 184 | #sum2 = tf.reduce_sum(y_true**2,axis=(0,1,2)) 185 | dice = 2*sum1/sum2 186 | return dice[0] 187 | def dice_Score_1(y_true,y_pred): 188 | sum1 = tf.reduce_sum(y_true*y_pred,axis=(0,1,2)) 189 | sum2 = tf.reduce_sum(y_true**2+y_pred**2,axis=(0,1,2)) 190 | #sum2 = tf.reduce_sum(y_true**2,axis=(0,1,2)) 191 | dice = 2*sum1/sum2 192 | return dice[1] 193 | def dice_Score_2(y_true,y_pred): 194 | sum1 = tf.reduce_sum(y_true*y_pred,axis=(0,1,2)) 195 | sum2 = tf.reduce_sum(y_true**2+y_pred**2,axis=(0,1,2)) 196 | #sum2 = tf.reduce_sum(y_true**2,axis=(0,1,2)) 197 | dice = 2*sum1/sum2 198 | return dice[2] 199 | def dice_Score_3(y_true,y_pred): 200 | sum1 = tf.reduce_sum(y_true*y_pred,axis=(0,1,2)) 201 | sum2 = tf.reduce_sum(y_true**2+y_pred**2,axis=(0,1,2)) 202 | #sum2 = tf.reduce_sum(y_true**2,axis=(0,1,2)) 203 | dice = 2*sum1/sum2 204 | return dice[3] 205 | def Unet(num_class, image_size): 206 | 207 | inputs = Input(shape=[image_size, image_size, 3]) 208 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(inputs) 209 | #conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same')(conv1) 210 | conv1= Conv2D(64,3,padding='same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv1) 211 | nor1 = BatchNormalization(momentum=.99,epsilon=0.001, 212 | center=True,scale=True,beta_initializer='zeros',gamma_initializer='ones', 213 | moving_mean_initializer='zeros',moving_variance_initializer='ones')(conv1) 214 | act1 = Activation('relu')(nor1) 215 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 216 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(pool1) 217 | #conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same')(conv2) 218 | conv2 = Conv2D(128,3,padding='same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv2) 219 | nor2 = BatchNormalization(momentum=.99,epsilon=0.001, 220 | center=True,scale=True,beta_initializer='zeros',gamma_initializer='ones', 221 | moving_mean_initializer='zeros',moving_variance_initializer='ones')(conv2) 222 | pool2 = MaxPooling2D(pool_size=(2, 2))(nor2) 223 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(pool2) 224 | #conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same')(conv3) 225 | conv3 = Conv2D(256,3,padding='same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv3) 226 | nor3 = BatchNormalization(momentum=.99,epsilon=0.001, 227 | center=True,scale=True,beta_initializer='zeros',gamma_initializer='ones', 228 | moving_mean_initializer='zeros',moving_variance_initializer='ones')(conv3) 229 | pool3 = MaxPooling2D(pool_size=(2, 2))(nor3) 230 | 231 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same')(pool3) 232 | #conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same')(conv4) 233 | conv4 = Conv2D(512,3,padding='same')(conv4) 234 | nor4 = BatchNormalization(momentum=.99,epsilon=0.001, 235 | center=True,scale=True,beta_initializer='zeros',gamma_initializer='ones', 236 | moving_mean_initializer='zeros',moving_variance_initializer='ones')(conv4) 237 | #drop4 = Dropout(0.5)(conv4) 238 | pool4 = MaxPooling2D(pool_size=(2, 2))(nor4) 239 | 240 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(pool4) 241 | #conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same')(conv5) 242 | flatten1 = Flatten()(conv5) 243 | dense1 = Dense(1024,activation='relu')(flatten1) 244 | dense2 = Dense(256,activation='relu')(dense1) 245 | dense3 = Dense(3,activation='softmax')(dense2) 246 | conv5 = Conv2D(1024,3,padding='same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv5) 247 | nor5 = BatchNormalization(momentum=.99,epsilon=0.001, 248 | center=True,scale=True,beta_initializer='zeros',gamma_initializer='ones', 249 | moving_mean_initializer='zeros',moving_variance_initializer='ones')(conv5) 250 | #drop5 = Dropout(0.5)(nor5) 251 | up6 = Conv2D(512, 2, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(UpSampling2D(size = (2,2))(nor5)) 252 | merge6 = concatenate([nor4,up6], axis = 3) 253 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(merge6) 254 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv6) 255 | 256 | up7 = Conv2D(256, 2, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(UpSampling2D(size = (2,2))(conv6)) 257 | merge7 = concatenate([nor3,up7], axis = 3) 258 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(merge7) 259 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv7) 260 | 261 | up8 = Conv2D(128, 2, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(UpSampling2D(size = (2,2))(conv7)) 262 | merge8 = concatenate([nor2,up8], axis = 3) 263 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(merge8) 264 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv8) 265 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv8) 266 | 267 | up9 = Conv2D(64, 2, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(UpSampling2D(size = (2,2))(conv8)) 268 | merge9 = concatenate([nor1,up9], axis = 3) 269 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(merge9) 270 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same',kernel_initializer='glorot_normal',bias_initializer='zeros')(conv9) 271 | # conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same')(conv9) 272 | conv10 = Conv2D(num_class, 1, activation = 'sigmoid')(conv9) 273 | model = Model(inputs = inputs, outputs = conv10) 274 | model.compile(optimizer = 'adam', 275 | loss=dice_coef_loss, metrics = [dice_Score_0,dice_Score_1,dice_Score_2,dice_Score_3] ) 276 | ''' 277 | model = Model(inputs = [inputs], outputs =[conv10,dense3]) 278 | model.compile(optimizer = 'adam', 279 | loss={ 280 | 'conv2d_23':dice_coef_loss, 281 | 'dense_2':'sparse_categorical_crossentropy'}, 282 | loss_weights={ 283 | 'conv2d_23':0.5, 284 | 'dense_2':0.5}, 285 | metrics = { 286 | 'conv2d':[dice_Score_0,dice_Score_1,dice_Score_2,dice_Score_3], 287 | 'dense_2':['accuracy']}) 288 | ''' 289 | return model 290 | 291 | def train(): 292 | ''' 293 | if os.path.exists(IMG_PATH+'/imgs_train.npy')==False: 294 | img_path = glob.glob(IMG_PATH+'/*.png') 295 | label_path = glob.glob(LABEL_PATH+'/*.png') 296 | print('loading') 297 | for i in range(len(img_path)): 298 | if i%10==0: 299 | print(i) 300 | img = load_data(img_path[i]) 301 | label = load_data(label_path[i]) 302 | image_augmentation(img,label,10) 303 | ''' 304 | # create_train_data() 305 | img_train,img_label = load_train_data() 306 | print('loaded') 307 | global model 308 | model = Unet(4,240) 309 | #os.mkdir('my_log_dir_0330_1') 310 | model_checkpoint = ModelCheckpoint('./SNn.h5', 311 | monitor='loss',verbose=1, save_best_only=True) 312 | history = LossHistory() 313 | reduce_lr = LearningRateScheduler(scheduler) 314 | #callbacks = [ 315 | # keras.callbacks.TensorBoard( 316 | # log_dir='my_log_dir', 317 | # histogram_freq=1, 318 | # ),history,reduce_lr,model_checkpoint] 319 | model.fit(img_train,img_label,batch_size=3,epochs=10,validation_split=0.1, 320 | shuffle=True,verbose=1,callbacks=[reduce_lr,model_checkpoint]) 321 | history.loss_plot('epoch') 322 | model.save('./SNn.h5') 323 | return 324 | 325 | def save_img(): 326 | imgs = np.load(Test_PATH + '/test_result.npy') 327 | for i in range(imgs.shape[0]): 328 | img = imgs[i] 329 | img = array_to_img(img) 330 | img.save(Test_PATH+'/result%d.png'%(i)) 331 | return 332 | 333 | train() 334 | -------------------------------------------------------------------------------- /data_train.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import os 3 | import numpy as np 4 | import shutil 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import cv2 8 | 9 | 10 | def to_standard(img): 11 | img[np.isnan(img)] = 0 12 | for k in range(np.shape(img)[2]): 13 | st = img[:, :, k] 14 | if np.amin(st) != np.amax(st): 15 | st -= np.amin(st) 16 | st /= np.amax(st) 17 | st *= 255 18 | return img 19 | 20 | 21 | def to_2d(data="train_for_SRNN"): 22 | folders = os.listdir(data) 23 | for folder in folders: 24 | if '2d' not in folder: 25 | count = 0 26 | print("Current Folder: {}".format(folder)) 27 | folder_path = os.path.join(data, folder) 28 | 29 | save_path = folder_path + '_2d' 30 | if not os.path.exists(save_path): 31 | os.makedirs(save_path) 32 | else: 33 | shutil.rmtree(save_path) 34 | os.makedirs(save_path) 35 | 36 | files = os.listdir(folder_path) 37 | for file in files: 38 | file_name = file[0: file.find('.')] 39 | file_path = os.path.join(folder_path, file) 40 | label = to_standard(nib.load(file_path).get_fdata()) 41 | print(label.shape) 42 | 43 | slices = label.shape[2] 44 | for s in range(slices): 45 | position = s / slices 46 | new_name = os.path.join(save_path, (str(s) + '_{:.1f}_' + file_name + '_lab.png').format(position)) 47 | 48 | slice_temp = label[:, :, s] 49 | slice_img = Image.fromarray(alter_intensity(slice_temp.round().astype(np.uint8), folder)) 50 | 51 | slice_img.save(new_name, "PNG", dpi=[300, 300], quality=95) 52 | count += 1 53 | print("Finishing processing: {}".format(new_name)) 54 | 55 | with open("num_record.txt", 'a') as f: 56 | f.write(folder + ': ' + str(count // 2) + '\n') 57 | 58 | 59 | def verify_intensity(data="train_for_SRNN"): 60 | folders = os.listdir(data) 61 | for folder in folders: 62 | if '2d' in folder: 63 | folder_path = os.path.join(data, folder) 64 | files = os.listdir(folder_path) 65 | for file in files: 66 | file_path = os.path.join(folder_path, file) 67 | lab = Image.open(file_path) 68 | lab_array = np.array(lab) 69 | intensity = np.unique(lab_array) 70 | print(intensity) 71 | 72 | 73 | def alter_intensity(img, folder): 74 | if 'seg' in folder: 75 | img[img == 51] = 200 76 | img[img == 153] = 88 77 | img[img == 255] = 244 78 | else: 79 | img[img == 85] = 88 80 | img[img == 170] = 200 81 | img[img == 255] = 244 82 | return img 83 | 84 | 85 | def unite_folders(data="train_for_SRNN"): 86 | save = "in_use" 87 | if not os.path.exists(save): 88 | os.makedirs(save) 89 | else: 90 | shutil.rmtree(save) 91 | os.makedirs(save) 92 | 93 | folders = os.listdir(data) 94 | for folder in folders: 95 | if '2d' in folder: 96 | folder_path = os.path.join(data, folder) 97 | files = os.listdir(folder_path) 98 | for file in files: 99 | file_path = os.path.join(folder_path, file) 100 | save_path = os.path.join(save, file) 101 | shutil.copy(file_path, save_path) 102 | print("Done.") 103 | 104 | 105 | def verify_shape(data="crop"): 106 | files = os.listdir(data) 107 | save_regular = [] 108 | save_strange = [] 109 | center_strange = [] 110 | for file in files: 111 | file_path = os.path.join(data, file) 112 | lab = np.array(Image.open(file_path)) 113 | if 'gt' not in file: 114 | save_regular.append(lab.shape) 115 | else: 116 | save_strange.append(lab.shape) 117 | center_strange.append(compute_center(lab)) 118 | print(save_regular) 119 | print(save_strange) 120 | # print(center_strange) 121 | 122 | 123 | def compute_center(label): 124 | label = np.expand_dims(label, axis=-1) 125 | points = np.where(label > 0) 126 | return np.array([[np.average(points[0][points[2] == j]), np.average(points[1][points[2] == j])] for j in 127 | range(label.shape[-1])]) 128 | 129 | 130 | # old method, don't use it 131 | def resize_and_crop(data='in_use'): 132 | files = os.listdir(data) 133 | save_path = "resize" 134 | if not os.path.exists(save_path): 135 | os.makedirs(save_path) 136 | else: 137 | shutil.rmtree(save_path) 138 | os.makedirs(save_path) 139 | 140 | for file in files: 141 | file_path = os.path.join(data, file) 142 | lab = Image.open(file_path) 143 | if 'gt' in file: 144 | pass 145 | height, width = lab.size 146 | lab_new = lab.resize((int(2.5*height), int(2.5*width)), Image.NEAREST) 147 | print(lab_new.size) 148 | print(np.unique(lab_new)) 149 | lab_new.save("resize/" + file) 150 | else: 151 | shutil.copy(file_path, os.path.join(save_path, file)) 152 | lab_new = np.lib.pad(lab, 60, 'constant', constant_values=0) 153 | Image.fromarray(lab_new).save("resize/" + file) 154 | 155 | 156 | # old method, don't use it 157 | def center_crop_2d(data='resize', center_roi=(120, 120, 1)): 158 | 159 | files = os.listdir(data) 160 | save_path = "crop" 161 | if not os.path.exists(save_path): 162 | os.makedirs(save_path) 163 | else: 164 | shutil.rmtree(save_path) 165 | os.makedirs(save_path) 166 | 167 | for file in files: 168 | 169 | file_path = os.path.join(data, file) 170 | label = np.expand_dims(Image.open(file_path), axis=-1) 171 | 172 | if 'gt' in file: 173 | assert np.all(np.array(center_roi) <= np.array(label.shape)), print( 174 | 'Patch size exceeds dimensions.') 175 | 176 | center = compute_center(label) 177 | where_are_nan = np.isnan(center) 178 | center[where_are_nan] = int(label.shape[0] // 2) 179 | 180 | x = np.array([center[i][0] for i in range(label.shape[-1])]).astype(np.int) 181 | y = np.array([center[i][1] for i in range(label.shape[-1])]).astype(np.int) 182 | 183 | x = x[0] 184 | y = y[0] 185 | 186 | beginx = x - center_roi[0] 187 | beginy = y - center_roi[1] 188 | endx = x + center_roi[0] 189 | endy = y + center_roi[1] 190 | 191 | gt = label[beginx:endx, beginy:endy, :] 192 | Image.fromarray(np.squeeze(gt)).save(os.path.join(save_path, file)) 193 | 194 | else: 195 | gt = label 196 | Image.fromarray(np.squeeze(gt)).save(os.path.join(save_path, file)) 197 | 198 | 199 | def center_crop_xz(data='in_use', roi=(120, 120)): 200 | 201 | files = os.listdir(data) 202 | save_path = "cropped_in_use" 203 | if not os.path.exists(save_path): 204 | os.makedirs(save_path) 205 | else: 206 | shutil.rmtree(save_path) 207 | os.makedirs(save_path) 208 | 209 | for file in tqdm(files): 210 | file_path = os.path.join(data, file) 211 | label = np.array(Image.open(file_path)) 212 | 213 | if 'gt' in file: 214 | assert np.all(np.array(roi) <= np.array(label.shape)), print( 215 | 'Patch size exceeds dimensions.') 216 | 217 | center = compute_center(label) 218 | where_are_nan = np.isnan(center) 219 | center[where_are_nan] = int(label.shape[0] // 2) 220 | center_2d = np.array(center[0], dtype=np.int32) 221 | window_size = roi[0] * 2 222 | 223 | begin = np.where(center_2d - window_size // 2 < 0, 224 | 0, 225 | center_2d - window_size // 2) 226 | 227 | end = np.where(center_2d - window_size // 2 + window_size > label.shape[:], 228 | label.shape[:], 229 | center_2d - window_size // 2 + window_size) 230 | 231 | offset1 = np.where(center_2d - window_size // 2 < 0, 232 | window_size // 2 - center_2d, 233 | 0) 234 | 235 | offset2 = np.where(center_2d - window_size // 2 + window_size > label.shape[:], 236 | center_2d - window_size // 2 + window_size - label.shape[:], 237 | 0) 238 | 239 | label_crop = label[begin[0]:end[0], begin[1]:end[1]] 240 | 241 | label_pad = np.pad(label_crop, pad_width=((offset1[0], offset2[0]), 242 | (offset1[1], offset2[1])), 243 | mode='constant') 244 | print(label_pad.shape) 245 | 246 | Image.fromarray(np.squeeze(label_pad)).save(os.path.join(save_path, file)) 247 | 248 | else: 249 | label_pad = np.pad(label, 60, 'constant', constant_values=0) 250 | print(label_pad.shape) 251 | Image.fromarray(np.squeeze(label_pad)).save(os.path.join(save_path, file)) 252 | 253 | 254 | def center_crop_old_data(data_path='../dataset/train_2d', 255 | save_path="../dataset/train_2d_crop", 256 | roi=(120, 120)): 257 | print("Currently crop ROI for dataset: ", data_path) 258 | files = os.listdir(data_path) 259 | 260 | if not os.path.exists(save_path): 261 | os.makedirs(save_path) 262 | else: 263 | shutil.rmtree(save_path) 264 | os.makedirs(save_path) 265 | 266 | for file in files: 267 | if 'lab' in file: 268 | label_path = os.path.join(data_path, file) 269 | label = np.array(Image.open(label_path)) 270 | image_name = file.replace("_lab.png", "_img.png") 271 | image = np.array(Image.open(os.path.join(data_path, image_name))) 272 | 273 | assert np.all(np.array(roi) <= np.array(label.shape)), print( 274 | 'Patch size exceeds dimensions.') 275 | 276 | center = compute_center(label) 277 | where_are_nan = np.isnan(center) 278 | center[where_are_nan] = int(label.shape[0] // 2) 279 | center_2d = np.array(center[0], dtype=np.int32) 280 | window_size = roi[0] * 2 281 | 282 | begin = np.where(center_2d - window_size // 2 < 0, 283 | 0, 284 | center_2d - window_size // 2) 285 | 286 | end = np.where(center_2d - window_size // 2 + window_size > label.shape[:], 287 | label.shape[:], 288 | center_2d - window_size // 2 + window_size) 289 | 290 | offset1 = np.where(center_2d - window_size // 2 < 0, 291 | window_size // 2 - center_2d, 292 | 0) 293 | 294 | offset2 = np.where(center_2d - window_size // 2 + window_size > label.shape[:], 295 | center_2d - window_size // 2 + window_size - label.shape[:], 296 | 0) 297 | 298 | label_crop = label[begin[0]:end[0], begin[1]:end[1]] 299 | image_crop = image[begin[0]:end[0], begin[1]:end[1]] 300 | 301 | label_pad = np.pad(label_crop, pad_width=((offset1[0], offset2[0]), 302 | (offset1[1], offset2[1])), 303 | mode='constant') 304 | image_pad = np.pad(image_crop, pad_width=((offset1[0], offset2[0]), 305 | (offset1[1], offset2[1])), 306 | mode='constant') 307 | # print(label_pad.shape) 308 | # print(image_pad.shape) 309 | 310 | Image.fromarray(np.squeeze(label_pad)).save(os.path.join(save_path, file)) 311 | Image.fromarray(np.squeeze(image_pad)).save(os.path.join(save_path, image_name)) 312 | print("Done: ", image_name) 313 | 314 | 315 | def center_crop_old_old(data_path='../dataset/train_2d', 316 | save_path="../dataset/train_2d_crop", 317 | roi=(120, 120, 1)): 318 | print("Currently crop ROI for dataset: ", data_path) 319 | files = os.listdir(data_path) 320 | 321 | if not os.path.exists(save_path): 322 | os.makedirs(save_path) 323 | else: 324 | shutil.rmtree(save_path) 325 | os.makedirs(save_path) 326 | 327 | for file in files: 328 | if 'lab' in file: 329 | label_path = os.path.join(data_path, file) 330 | label = np.array(Image.open(label_path)) 331 | image_name = file.replace("_lab.png", "_img.png") 332 | image = np.array(Image.open(os.path.join(data_path, image_name))) 333 | label = np.expand_dims(label, axis=-1) 334 | image = np.expand_dims(image, axis=-1) 335 | 336 | assert np.all(np.array(roi) <= np.array(label.shape)), print( 337 | 'Patch size exceeds dimensions.') 338 | center = compute_center(label) 339 | where_are_nan = np.isnan(center) 340 | center[where_are_nan] = int(label.shape[0] // 2) 341 | 342 | x = np.array([center[i][0] for i in range(label.shape[-1])]).astype(np.int) 343 | y = np.array([center[i][1] for i in range(label.shape[-1])]).astype(np.int) 344 | 345 | x = x[0] 346 | y = y[0] 347 | 348 | beginx = x - roi[0] 349 | beginy = y - roi[1] 350 | endx = x + roi[0] 351 | endy = y + roi[1] 352 | 353 | gt = label[beginx:endx, beginy:endy, :] 354 | img = image[beginx:endx, beginy:endy, :] 355 | 356 | Image.fromarray(np.squeeze(gt)).save(os.path.join(save_path, file)) 357 | Image.fromarray(np.squeeze(img)).save(os.path.join(save_path, image_name)) 358 | print("Done: ", image_name) 359 | 360 | 361 | def check_shape(data="test_crop"): 362 | 363 | files = os.listdir(data) 364 | for file in files: 365 | file_path = os.path.join(data, file) 366 | lab = Image.open(file_path) 367 | # if 'gt' not in file: 368 | print("Shape of current label: ", lab.size) 369 | print("Intensities of current label: ", np.unique(lab)) 370 | 371 | 372 | class LabelErosion: 373 | 374 | def __init__(self, data_path="cropped_in_use", save_path="noise_in_use", debug=False): 375 | self.data_path = data_path 376 | self.save_path = save_path 377 | self.debug = debug 378 | self.intensities = self.get_intensities() 379 | 380 | def change_intensity(self, img_raw): 381 | img = np.zeros_like(img_raw) 382 | for i in range(len(self.intensities)): 383 | if i < len(self.intensities) - 1: 384 | img[img_raw == self.intensities[i]] = self.intensities[i+1] 385 | else: 386 | img[img_raw == self.intensities[i]] = self.intensities[0] 387 | return img 388 | 389 | def get_intensities(self): 390 | images = os.listdir(self.data_path) 391 | image_path = os.path.join(self.data_path, images[0]) 392 | img = cv2.imread(image_path, 0) 393 | intensities = np.unique(img) 394 | return intensities 395 | 396 | def sample_class(self, array): 397 | return array[np.random.randint(0, 3)] 398 | 399 | def produce_noise(self, image): 400 | image_path = os.path.join(self.data_path, image) 401 | img = cv2.imread(image_path, 0) 402 | 403 | img_1 = self.change_intensity(img) 404 | img_2 = self.change_intensity(img_1) 405 | img_3 = self.change_intensity(img_2) 406 | 407 | candidate = np.stack([img_1, img_2, img_3], axis=-1) 408 | selection = np.random.uniform(0, 1, img.shape) 409 | 410 | sampled_img = np.apply_along_axis(self.sample_class, axis=-1, arr=candidate) 411 | # print(sampled_img.shape) 412 | 413 | result = np.where(selection < 0.1, sampled_img, img) 414 | 415 | if self.debug: 416 | cv2.imshow("img_raw", img) 417 | cv2.imshow("img_1", img_1) 418 | cv2.imshow("img_2", img_2) 419 | cv2.imshow("img_3", img_3) 420 | cv2.imshow("sampled", result) 421 | cv2.waitKey(0) 422 | cv2.destroyAllWindows() 423 | 424 | return result 425 | 426 | def main(self): 427 | if not os.path.exists(self.save_path): 428 | os.makedirs(self.save_path) 429 | else: 430 | shutil.rmtree(self.save_path) 431 | os.makedirs(self.save_path) 432 | 433 | images = os.listdir(self.data_path) 434 | for image in images: 435 | noise_image = self.produce_noise(image) 436 | image = "noise_" + image 437 | cv2.imwrite(os.path.join(self.save_path, image), noise_image) 438 | print("Finish process: ", image) 439 | 440 | 441 | if __name__ == '__main__': 442 | 443 | # to_2d() 444 | # verify_intensity() 445 | # unite_folders() 446 | # verify_shape() 447 | # resize_and_crop() 448 | # center_crop_2d() 449 | # center_crop_xz() 450 | # check_shape() 451 | 452 | # center_crop_old_data(data_path="test/test_img", save_path="test/test_save_1") 453 | # center_crop_old_data() 454 | # center_crop_old_old(data_path="test/test_img", save_path="test/test_save") 455 | center_crop_old_old() 456 | 457 | # L = LabelErosion(data_path="../dataset/train_2d_crop", save_path="../dataset/noise_train_2d_crop") 458 | # L.main() 459 | 460 | 461 | 462 | 463 | 464 | -------------------------------------------------------------------------------- /main_.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import, unicode_literals 2 | from TensorflowCode.core import SRNN_structure as SRNN, util, GTLabelProvider as label_provider, \ 3 | image_util_SCN as image_util, unet_SCN as unet 4 | import numpy as np 5 | import click 6 | import os 7 | import logging 8 | from datetime import datetime 9 | 10 | t = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 11 | 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" #按照顺序排列GPU 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 15 | 16 | CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) 17 | 18 | 19 | @click.command(context_settings=CONTEXT_SETTINGS) 20 | @click.option('--saliency', default='Base_Unet', help='Key word for this running time') 21 | @click.option('--run_times', default=1, type=click.IntRange(min=1, clamp=True), help='network training times') 22 | @click.option('--time', default=t, help='the current time or the time when the model to restore was trained') 23 | @click.option('--trainer_learning_rate', default=0.001, type=click.FloatRange(min=1e-8, clamp=True), 24 | help='network learning rate') 25 | @click.option('--train_validation_batch_size', default=5, type=click.IntRange(min=1), 26 | help='the number of validation cases') 27 | @click.option('--test_n_files', default=15, type=click.IntRange(min=1), help='the number of test cases') 28 | @click.option('--train_original_search_path', default='../dataset/train_original/*.nii.gz', 29 | help='search pattern to find all original training data and label images') 30 | @click.option('--srnn_search_path', default='../dataset/crop/*.png', 31 | help='search pattern to find all ground truth label to train SRNN') 32 | @click.option('--train_search_path', default='../dataset/train_data_2d/*.png', 33 | help='search pattern to find all training data and label images') 34 | @click.option('--train_data_suffix', default='_img.png', help='suffix pattern for the training data images') 35 | @click.option('--train_label_suffix', default='_lab.png', help='suffix pattern for the training label images') 36 | @click.option('--train_shuffle_data', default=True, type=bool, 37 | help='whether the order of training files should be randomized after each epoch') 38 | @click.option('--train_crop_patch', default=False, type=bool, 39 | help='whether patches of a certain size need to be cropped for training') 40 | @click.option('--train_patch_size', default=(-1, -1, -1), 41 | type=(click.IntRange(min=-1), click.IntRange(min=-1), click.IntRange(min=-1)), 42 | help='size of the training patches') 43 | @click.option('--train_channels', default=1, type=click.IntRange(min=1), help='number of training data channels') 44 | @click.option('--train_n_class', default=4, type=click.IntRange(min=1), 45 | help='number of training label classes, including the background') 46 | @click.option('--train_contain_foreground', default=False, type=bool, 47 | help='if the training patches should contain foreground') 48 | @click.option('--train_label_intensity', default=(0, 88, 200, 244), multiple=True, 49 | type=click.IntRange(min=0), help='list of intensities of the training ground truths') 50 | @click.option('--net_layers', default=5, type=click.IntRange(min=2), 51 | help='number of convolutional blocks in the down-sampling path') 52 | @click.option('--net_features_root', default=16, type=click.IntRange(min=1), 53 | help='number of features of the first convolution layer') 54 | @click.option('--net_cost_name', default=u'exponential_logarithmic', 55 | type=click.Choice(["cross_entropy", "weighted_cross_entropy", "dice_loss", 56 | "generalized_dice_loss", "cross_entropy+dice_loss", 57 | "weighted_cross_entropy+generalized_dice_loss", 58 | "exponential_logarithmic"]), help='type of the cost function') 59 | @click.option('--net_regularizer_type', default=None, 60 | type=click.Choice(['L2_norm', 'L1_norm', 'anatomical_constraint']), 61 | help='type of regularization') 62 | @click.option('--net_regularization_coefficient', default=5e-4, type=click.FloatRange(min=0), 63 | help='regularization coefficient') 64 | @click.option('--net_srnn_model_path', default='./autoencoder_trained_%s', 65 | help='path where to restore the SRNN auto-encoder parameters for regularization') 66 | @click.option('--trainer_batch_size', default=32, type=click.IntRange(min=1, clamp=True), 67 | help='batch size for each training iteration') 68 | @click.option('--trainer_optimizer_name', default='adam', type=click.Choice(['momentum', 'adam']), 69 | help='type of the optimizer to use (momentum or adam)') 70 | @click.option('--train_model_path', default='./unet_trained_%s_%s/No_%d', help='path where to store checkpoints') 71 | @click.option('--train_training_iters', default=638, type=click.IntRange(min=1), 72 | help='number of training iterations during each epoch') 73 | @click.option('--train_epochs', default=30, type=click.IntRange(min=1), help='number of epochs') 74 | @click.option('--train_dropout_rate', default=0.2, type=click.FloatRange(min=0, max=1), help='dropout probability') 75 | @click.option('--train_clip_gradient', default=False, type=bool, 76 | help='whether to apply gradient clipping with L2 norm threshold 1.0') 77 | @click.option('--train_display_step', default=100, type=click.IntRange(min=1), 78 | help='number of steps till outputting stats') 79 | @click.option('--train_prediction_path', default='./validation_prediction_%s_%s/No_%d', 80 | help='path where to save predictions on each epoch') 81 | @click.option('--train_restore', default=False, type=bool, help='whether previous model checkpoint need restoring') 82 | @click.option('--test_search_path', default='../dataset/test_data/*.nii.gz', 83 | help='a search pattern to find all test data and label images') 84 | @click.option('--test_data_suffix', default='_img.nii.gz', help='suffix pattern for the test data images') 85 | @click.option('--test_label_suffix', default='_lab.nii.gz', help='suffix pattern for the test label images') 86 | @click.option('--test_shuffle_data', default=False, type=bool, 87 | help='whether the order of the loaded test files path should be randomized') 88 | @click.option('--test_channels', default=1, type=click.IntRange(min=1), help='number of test data channels') 89 | @click.option('--test_n_class', default=4, type=click.IntRange(min=1), 90 | help='number of test label classes, including the background') 91 | @click.option('--test_label_intensity', default=(0, 88, 200, 244), multiple=True, 92 | type=click.IntRange(min=0), 93 | help='tuple of intensities of the test ground truths') 94 | @click.option('--test_prediction_path', default=u'./test_prediction_%s_%s/No_%d', 95 | help='path where to save test predictions') 96 | @click.option('--val_search_path', default='../dataset/val_data/*.nii.gz', 97 | help='a search pattern to find all validation data and label images') 98 | @click.option('--val_data_suffix', default='_img.nii.gz', help='suffix pattern for the val data images') 99 | @click.option('--val_label_suffix', default='_lab.nii.gz', help='suffix pattern for the val label images') 100 | @click.option('--val_shuffle_data', default=False, type=bool, 101 | help='whether the order of the loaded val files path should be randomized') 102 | @click.option('--val_channels', default=1, type=click.IntRange(min=1), help='number of val data channels') 103 | @click.option('--val_n_class', default=4, type=click.IntRange(min=1), 104 | help='number of val label classes, including the background') 105 | @click.option('--val_label_intensity', default=(0, 88, 200, 244), multiple=True, 106 | type=click.IntRange(min=0), 107 | help='tuple of intensities of the test ground truths') 108 | @click.option('--train_center_crop', default=True, type=bool, 109 | help='whether to extract roi from center during training') 110 | @click.option('--train_center_roi', default=(120, 120, 1), multiple=True, type=click.IntRange(min=0), 111 | help='roi size you want to extract during training') 112 | @click.option('--test_center_crop', default=True, type=bool, 113 | help='whether to extract roi from center while testing') 114 | @click.option('--test_center_roi', default=(120, 120, 1), multiple=True, type=click.IntRange(min=0), 115 | help='roi size you want to extract while testing') 116 | @click.option('--scn_parameter', default=5e-4, type=click.FloatRange(min=0), 117 | help='weight of spatial constraint') 118 | @click.option('--scn_button', default=False, type=bool, 119 | help='whether to use spatial constraint') 120 | def run(run_times, time, train_search_path, train_data_suffix, train_label_suffix, train_shuffle_data, train_crop_patch, 121 | train_patch_size, train_channels, train_n_class, train_contain_foreground, train_label_intensity, 122 | train_original_search_path, net_layers, net_features_root, net_cost_name, net_regularizer_type, 123 | net_regularization_coefficient, net_srnn_model_path, trainer_batch_size, trainer_optimizer_name, 124 | trainer_learning_rate, train_validation_batch_size, train_model_path, train_training_iters, train_epochs, 125 | train_dropout_rate, train_clip_gradient, train_display_step, train_prediction_path, train_restore, 126 | test_search_path, test_data_suffix, test_label_suffix, test_shuffle_data, test_channels, test_n_class, 127 | test_label_intensity, test_n_files, test_prediction_path, val_search_path, val_data_suffix, val_label_suffix, 128 | val_shuffle_data, val_channels, val_n_class, val_label_intensity, saliency, train_center_crop, train_center_roi, 129 | test_center_crop, test_center_roi, scn_parameter, scn_button, srnn_search_path 130 | ): 131 | 132 | if train_restore: 133 | assert time != t, "The time when the model to restore was trained is not the time now! " #断言 134 | 135 | train_acc_table = np.array([]) 136 | train_dice_table = np.array([]) 137 | train_auc_table = np.array([]) 138 | train_sens_table = np.array([]) 139 | train_spec_table = np.array([]) 140 | 141 | test_acc_table = np.array([]) 142 | test_dice_table = np.array([]) 143 | test_auc_table = np.array([]) 144 | 145 | for i in range(run_times): 146 | train_data_provider = image_util.ImageDataProvider(search_path=train_search_path, 147 | data_suffix=train_data_suffix, 148 | label_suffix=train_label_suffix, 149 | shuffle_data=train_shuffle_data, 150 | crop_patch=train_crop_patch, 151 | patch_size=train_patch_size, 152 | channels=train_channels, 153 | n_class=train_n_class, 154 | contain_foreground=train_contain_foreground, 155 | label_intensity=train_label_intensity, 156 | center_crop=train_center_crop, 157 | center_roi=train_center_roi, 158 | inference_phase=False 159 | ) 160 | 161 | SRNN_data_provider = label_provider.GTLabelProvider(search_path=srnn_search_path, 162 | label_suffix=train_label_suffix, 163 | shuffle_data=train_shuffle_data, 164 | channels=train_channels, 165 | n_class=train_n_class, 166 | label_intensity=train_label_intensity, 167 | center_crop=True 168 | ) 169 | 170 | train_original_data_provider = image_util.ImageDataProvider(search_path=train_original_search_path, 171 | data_suffix=test_data_suffix, 172 | label_suffix=test_label_suffix, 173 | shuffle_data=False, 174 | crop_patch=False, 175 | patch_size=train_patch_size, 176 | channels=train_channels, 177 | n_class=train_n_class, 178 | contain_foreground=train_contain_foreground, 179 | label_intensity=train_label_intensity, 180 | center_crop=train_center_crop, 181 | center_roi=train_center_roi, 182 | inference_phase=True 183 | ) 184 | 185 | test_data_provider = image_util.ImageDataProvider(search_path=test_search_path, 186 | data_suffix=test_data_suffix, 187 | label_suffix=test_label_suffix, 188 | shuffle_data=test_shuffle_data, 189 | crop_patch=False, 190 | channels=test_channels, 191 | n_class=test_n_class, 192 | label_intensity=test_label_intensity, 193 | center_crop=test_center_crop, 194 | center_roi=test_center_roi, 195 | inference_phase=True) 196 | 197 | val_data_provider = image_util.ImageDataProvider(search_path=val_search_path, 198 | data_suffix=val_data_suffix, 199 | label_suffix=val_label_suffix, 200 | shuffle_data=val_shuffle_data, 201 | crop_patch=False, 202 | channels=val_channels, 203 | n_class=val_n_class, 204 | label_intensity=val_label_intensity, 205 | center_crop=test_center_crop, 206 | center_roi=test_center_roi, 207 | inference_phase=True) 208 | 209 | print("lalalala") 210 | if net_regularizer_type == 'anatomical_constraint': 211 | if os.path.exists(net_srnn_model_path % saliency): 212 | print("SRNN has already been trained.") 213 | else: 214 | logging.info("Train SRNN with 45 patient data...") 215 | srnn = SRNN.AutoEncoder(batch_size=trainer_batch_size, 216 | cost_kwargs={'regularizer_type': 'L1_norm', 217 | 'regularization_coefficient': 5e-4}) 218 | srnn.train(train_data_provider, net_srnn_model_path % saliency) 219 | 220 | # logging.info("Train SRNN with additional ground truth labels...") 221 | # tf.reset_default_graph() 222 | # srnn_2 = SRNN.AutoEncoder(batch_size=trainer_batch_size, 223 | # cost_kwargs={'regularizer_type': 'L2_norm', 224 | # 'regularization_coefficient': 5e-4}) 225 | # srnn_2.train(SRNN_data_provider, net_srnn_model_path % saliency, training_iters=34, restore=True) 226 | logging.info("Done pre-train.") 227 | 228 | net = unet.UNet(layers=net_layers, features_root=net_features_root, channels=train_channels, 229 | n_class=train_n_class, batch_size=trainer_batch_size, cost_name=net_cost_name, 230 | sc_coefficient=scn_parameter, need_sc=scn_button, 231 | cost_kwargs={'regularizer_type': net_regularizer_type, 232 | 'regularization_coefficient': net_regularization_coefficient, 233 | 'srnn_model_path': (net_srnn_model_path % saliency)}) 234 | 235 | trainer = unet.Trainer(net, batch_size=trainer_batch_size, optimizer_name=trainer_optimizer_name, 236 | opt_kwargs={'learning_rate': trainer_learning_rate}, dropout=train_dropout_rate) 237 | 238 | path, train_acc, train_dice, train_auc, train_sens, train_spec = trainer.train(train_data_provider, 239 | val_data_provider, 240 | train_original_data_provider, 241 | train_validation_batch_size, 242 | model_path=train_model_path 243 | % (saliency, time, i), 244 | training_iters=train_training_iters, 245 | epochs=train_epochs, 246 | clip_gradient=train_clip_gradient, 247 | display_step=train_display_step, 248 | prediction_path=train_prediction_path 249 | % (saliency, time, i), 250 | restore=train_restore) 251 | train_acc_table = np.hstack((train_acc_table, train_acc)) 252 | train_dice_table = np.hstack((train_dice_table, train_dice)) 253 | train_auc_table = np.hstack((train_auc_table, train_auc)) 254 | train_sens_table = np.hstack((train_sens_table, train_sens)) 255 | train_spec_table = np.hstack((train_spec_table, train_spec)) 256 | 257 | train_summary_path = './train_summary_%s_%s' % (saliency, time) 258 | if not os.path.exists(train_summary_path): 259 | logging.info('Allocating {:}'.format(train_summary_path)) 260 | os.makedirs(train_summary_path) 261 | np.savez(os.path.join(train_summary_path, 'No_%d.npz' % i), acc=train_acc, dice=train_dice, auc=train_auc, 262 | sens=train_sens, spec=train_spec) 263 | 264 | test_data_provider.reset_index() 265 | test_data, test_labels, test_affine, _ = test_data_provider(test_n_files) 266 | predictions = net.predict(path, test_data) 267 | 268 | test_acc = unet.acc_rate(predictions, test_labels) 269 | test_dice = unet.dice_score(predictions, test_labels) 270 | test_auc = unet.auc_score(predictions, test_labels) 271 | 272 | test_acc_table = np.hstack((test_acc_table, test_acc)) 273 | test_dice_table = np.hstack((test_dice_table, test_dice)) 274 | test_auc_table = np.hstack((test_auc_table, test_auc)) 275 | 276 | dice_score_path = './dice_score_%s_%s' % (saliency, time) 277 | if not os.path.exists(dice_score_path): 278 | logging.info('Allocating {:}'.format(dice_score_path)) 279 | os.makedirs(dice_score_path) 280 | np.save(os.path.join(dice_score_path, 'No_%d.npy' % i), test_dice) 281 | print("##################################################") 282 | print("Mean Dice score= {:.4f}".format(np.mean(test_dice))) 283 | 284 | for j in range(len(test_data)): 285 | test_data[j] = np.expand_dims(test_data[j], axis=0).transpose((0, 2, 3, 1, 4)) 286 | test_labels[j] = np.expand_dims(test_labels[j], axis=0).transpose((0, 2, 3, 1, 4)) 287 | predictions[j] = np.expand_dims(predictions[j], axis=0).transpose((0, 2, 3, 1, 4)) 288 | 289 | util.save_prediction(test_data, test_labels, predictions, test_prediction_path % (saliency, time, i)) 290 | util.save_prediction_1(predictions, test_affine, test_prediction_path % (saliency, time, i)) 291 | util.save_prediction_2(predictions, test_prediction_path % (saliency, time, i)) 292 | 293 | test_summary_path = './test_summary_%s_%s' % (saliency, time) 294 | if not os.path.exists(test_summary_path): 295 | logging.info('Allocating {:}'.format(test_summary_path)) 296 | os.makedirs(test_summary_path) 297 | np.savez(os.path.join(test_summary_path, 'No_%d.npz' % i), acc=test_acc, dice=test_dice, auc=test_auc) 298 | 299 | mean_train_acc = np.mean(np.reshape(train_acc_table, [run_times, -1]), axis=0) 300 | mean_train_dice = np.mean(np.reshape(train_dice_table, [run_times, -1]), axis=0) 301 | mean_train_auc = np.mean(np.reshape(train_auc_table, [run_times, -1]), axis=0) 302 | mean_train_sens = np.mean(np.reshape(train_sens_table, [run_times, -1]), axis=0) 303 | mean_train_spec = np.mean(np.reshape(train_spec_table, [run_times, -1]), axis=0) 304 | mean_test_acc = np.mean(np.reshape(test_acc_table, [run_times, -1]), axis=0) 305 | mean_test_dice = np.mean(np.reshape(train_dice_table, [run_times, -1]), axis=0) 306 | mean_test_auc = np.mean(np.reshape(train_auc_table, [run_times, -1]), axis=0) 307 | 308 | np.savez('./mean_train_summary_%s_%s.npz' % (saliency, time), acc=mean_train_acc, auc=mean_train_auc, 309 | sens=mean_train_sens, spec=mean_train_spec, dice=mean_train_dice) 310 | np.savez('./mean_test_summary_%s_%s.npz' % (saliency, time), acc=mean_test_acc, auc=mean_test_auc, 311 | dice=mean_test_dice) 312 | 313 | 314 | if __name__ == '__main__': 315 | run() 316 | --------------------------------------------------------------------------------