├── example
│   ├── brain_tumor_aug.png
│   ├── brain_tumor_aug.pptx
│   ├── brain_tumor_data.png
│   └── brain_tumor_data.pptx
├── .gitignore
├── README.md
├── train.py
├── prepare_data_with_valid.py
└── model.py

/example/brain_tumor_aug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_aug.png
--------------------------------------------------------------------------------
/example/brain_tumor_aug.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_aug.pptx
--------------------------------------------------------------------------------
/example/brain_tumor_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_data.png
--------------------------------------------------------------------------------
/example/brain_tumor_data.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_data.pptx
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
tensorlayer/__pycache__
tensorlayer/.DS_Store
.DS_Store
dist
build/
tensorlayer.egg-info
data/.DS_Store
*.pyc
*.gz

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# U-Net Brain Tumor Segmentation

🚀 Feb 2019: the data processing implementation in this repo is not the fastest approach (the code needs updating; contributions are welcome). You can use the TensorFlow Dataset API instead.

This repo shows you how to train a U-Net for brain tumor segmentation. By default, you need to download the training set of the [BRATS 2017](http://braintumorsegmentation.org) dataset, which has 210 HGG and 75 LGG volumes, and put the `data` folder in the same directory as the scripts:

```bash
data
  -- Brats17TrainingData
  -- train_dev_all
model.py
train.py
...
```

### About the data
Note that, according to the license, users have to apply for the dataset from BRATS; please do **NOT** contact me for the dataset. Many thanks.

Fig 1: Brain Image

* Each volume has 4 scanning modalities: FLAIR, T1, T1c and T2.
* Each volume has a segmentation map with 4 labels:

```
Label 0: background
Label 1: necrotic and non-enhancing tumor
Label 2: edema
Label 4: enhancing tumor
```

`prepare_data_with_valid.py` splits the training set into two folds, one for training and one for validation. By default it uses only part of the data to speed up training (`DATA_SIZE = 'half'` keeps the first 100 HGG and 30 LGG volumes); to use all of the data, change `DATA_SIZE = 'half'` to `DATA_SIZE = 'all'`.

### About the method

- Network and Loss: since we train with [dice loss](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#dice-coefficient), each network predicts only one label (Label 1, 2 or 4, or all tumor labels merged into one). We evaluate the performance using [hard dice](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#hard-dice-coefficient) and [IoU](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#iou-coefficient).

- Data augmentation: random left-right flip, rotation, shifting, shearing, zooming and, most importantly, [elastic transformation](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#elastic-transform); see ["Automatic Brain Tumor Detection and Segmentation Using U-Net Based Fully Convolutional Networks"](https://arxiv.org/pdf/1705.03820.pdf) for details.

Fig 2: Data augmentation
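
The `*_multi` augmentation calls in `train.py` (for example `tl.prepro.flip_axis_multi` and `tl.prepro.shift_multi`) draw one set of random parameters and apply it to the four modality slices and the label map together, so the mask stays aligned with the image. The NumPy sketch below illustrates only that idea for a left-right flip plus an integer shift. `flip_and_shift_multi` is a hypothetical helper, not TensorLayer's implementation; rotation, shearing, zooming and elastic transformation are omitted, and the zero fill only approximates `fill_mode='constant'`.

```python
import numpy as np


def flip_and_shift_multi(arrays, max_shift=24, rng=None):
    """Apply ONE random left-right flip and ONE random integer shift to every
    array in `arrays` (each shaped (H, W, C)), keeping channels and label aligned.

    Hypothetical sketch, not TensorLayer code. Vacated pixels are filled with 0.
    Assumes max_shift is smaller than both H and W (24 px is 10% of a 240 px
    slice, loosely mirroring wrg=hrg=0.10 in train.py).
    """
    rng = np.random.default_rng() if rng is None else rng
    do_flip = rng.random() < 0.5  # drawn once and shared by all arrays
    dy, dx = (int(v) for v in rng.integers(-max_shift, max_shift + 1, size=2))
    out = []
    for a in arrays:
        if do_flip:
            a = a[:, ::-1, :]  # axis=1 is left-right, as in distort_imgs
        h, w = a.shape[:2]
        shifted = np.zeros_like(a)
        # Copy the overlapping window; rows/columns shifted out are dropped.
        shifted[max(dy, 0):h + min(dy, 0), max(dx, 0):w + min(dx, 0)] = \
            a[max(-dy, 0):h + min(-dy, 0), max(-dx, 0):w + min(-dx, 0)]
        out.append(shifted)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = np.arange(16, dtype=np.float32).reshape(4, 4, 1)  # toy "image"
    y = (x > 7).astype(np.float32)                         # toy label map
    x_aug, y_aug = flip_and_shift_multi([x, y], max_shift=1, rng=rng)
    # Both arrays received the same transform, so the label still matches.
    print(np.array_equal(y_aug, (x_aug > 7).astype(np.float32)))
```

Drawing independent random parameters per array would misalign the label map from the image, which is why `distort_imgs` passes `[x1, x2, x3, x4, y]` to every `*_multi` call.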
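
Because each network predicts a single binary label, `train.py` converts the BRATS label map into a 0/1 target according to `--task`: `> 0` for `all`, `== 1` for `necrotic`, `== 2` for `edema` and `== 4` for `enhance`. A dictionary-based sketch of the same mapping follows; `TASK_TO_MASK` and `binary_target` are hypothetical names, and raising `ValueError` stands in for the script's `exit(...)` call.

```python
import numpy as np

# Hypothetical refactoring of the if/elif chain in train.py's main().
TASK_TO_MASK = {
    "all": lambda seg: seg > 0,        # whole tumor: labels 1, 2 and 4
    "necrotic": lambda seg: seg == 1,  # necrotic and non-enhancing tumor
    "edema": lambda seg: seg == 2,     # edema
    "enhance": lambda seg: seg == 4,   # enhancing tumor
}


def binary_target(seg, task="all"):
    """Turn a BRATS label map (values 0, 1, 2, 4) into a 0/1 integer target."""
    if task not in TASK_TO_MASK:
        raise ValueError("Unknown task %s" % task)
    return TASK_TO_MASK[task](np.asarray(seg)).astype(int)


seg = np.array([[0, 1],
                [2, 4]])
print(binary_target(seg, "all"))      # every non-zero label becomes 1
print(binary_target(seg, "enhance"))  # only label 4 becomes 1
```

The same function works unchanged on the `(N, 240, 240, 1)` target arrays used in `train.py`, since the comparisons are element-wise.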
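
Training minimizes `1 - dice_coe(...)`, while hard dice and IoU are reported for evaluation. The NumPy sketch below shows the standard definitions, summing over every element so that one score is produced per batch, as `axis=[0,1,2,3]` does in `train.py`. It is not TensorLayer's code: `tl.cost.dice_coe` supports more than one formulation (the commented-out `'jaccard'` argument in `train.py` refers to one) plus a smoothing term, so its values may differ slightly; the `eps` default here is an assumption.

```python
import numpy as np


def soft_dice(pred, target, eps=1e-5):
    """Soft Dice (Sorensen form) on probabilities: 2*sum(P*T) / (sum(P) + sum(T))."""
    inter = np.sum(pred * target)
    return (2.0 * inter + eps) / (np.sum(pred) + np.sum(target) + eps)


def hard_dice(pred, target, threshold=0.5, eps=1e-5):
    """Dice after binarising the prediction at `threshold`."""
    return soft_dice((pred > threshold).astype(np.float64), target, eps)


def iou(pred, target, threshold=0.5, eps=1e-5):
    """Intersection over union of the binarised prediction and the target."""
    p, t = pred > threshold, target > 0.5
    inter = np.logical_and(p, t).sum()
    union = np.logical_or(p, t).sum()
    return (inter + eps) / (union + eps)


pred = np.array([0.9, 0.6, 0.4, 0.1])    # sigmoid outputs
target = np.array([1.0, 1.0, 1.0, 0.0])  # binary ground truth
print("dice loss:", 1 - soft_dice(pred, target))  # about 0.24
print("hard dice:", hard_dice(pred, target))      # about 0.8
print("IoU:", iou(pred, target))                  # about 0.667
```

An all-zero prediction on a non-empty target gives a dice loss very close to 1, which is what a run whose loss stays at 1 is producing.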

### Start training

HGG and LGG volumes are trained together. Because each network handles only one task, set `--task` to `all`, `necrotic`, `edema` or `enhance`; `all` means the network learns to segment the whole tumor (labels 1, 2 and 4 merged).

```
python train.py --task=all
```

Note that if the loss stays at 1 from the beginning, the network has failed to converge (the reported loss is 1 - dice, so a loss of 1 means a dice score of 0); please try restarting training.

### Citation
If you find this project useful, we would be grateful if you cite the TensorLayer paper:

```
@article{tensorlayer2017,
  author = {Dong, Hao and Supratak, Akara and Mai, Luo and Liu, Fangde and Oehmichen, Axel and Yu, Simiao and Guo, Yike},
  journal = {ACM Multimedia},
  title = {{TensorLayer: A Versatile Library for Efficient Deep Learning Development}},
  url = {http://tensorlayer.org},
  year = {2017}
}
```

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | import tensorflow as tf 5 | import tensorlayer as tl 6 | import numpy as np 7 | import os, time, model 8 | 9 | def distort_imgs(data): 10 | """ data augumentation """ 11 | x1, x2, x3, x4, y = data 12 | # x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y], # previous without this, hard-dice=83.7 13 | # axis=0, is_random=True) # up down 14 | x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y], 15 | axis=1, is_random=True) # left right 16 | x1, x2, x3, x4, y = tl.prepro.elastic_transform_multi([x1, x2, x3, x4, y], 17 | alpha=720, sigma=24, is_random=True) 18 | x1, x2, x3, x4, y = tl.prepro.rotation_multi([x1, x2, x3, x4, y], rg=20, 19 | is_random=True, fill_mode='constant') # nearest, constant 20 | x1, x2, x3, x4, y = tl.prepro.shift_multi([x1, x2, x3, x4, y], wrg=0.10, 21 | hrg=0.10, is_random=True, fill_mode='constant') 22 | x1, x2, x3, x4, y = tl.prepro.shear_multi([x1, x2, x3, x4, y], 0.05, 23 | is_random=True, fill_mode='constant') 24 | x1, x2, x3, x4, y = tl.prepro.zoom_multi([x1, x2, x3, x4, y], 25 | zoom_range=[0.9, 1.1], is_random=True, 26 | fill_mode='constant') 27 | return x1, x2, x3, x4, y 28 | 29 | def vis_imgs(X, y, path): 30 | """ show one slice """ 31 | if y.ndim == 2: 32 | y = y[:,:,np.newaxis] 33 | assert X.ndim == 3 34 | tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis], 35 | X[:,:,1,np.newaxis], X[:,:,2,np.newaxis], 36 | X[:,:,3,np.newaxis], y]), size=(1, 5), 37 | image_path=path) 38 | 39 | def vis_imgs2(X, y_, y, path): 40 | """ show one slice with target """ 41 | if y.ndim == 2: 42 | y = y[:,:,np.newaxis] 43 | if y_.ndim == 2: 44 | y_ = y_[:,:,np.newaxis] 45 | assert X.ndim == 3 46 | tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis], 47 | X[:,:,1,np.newaxis], X[:,:,2,np.newaxis], 48 | X[:,:,3,np.newaxis], y_, y]), size=(1, 6), 49 | image_path=path) 50 | 51 | def main(task='all'): 52 | ## Create folder to save trained model and result images 53 | save_dir = 
"checkpoint" 54 | tl.files.exists_or_mkdir(save_dir) 55 | tl.files.exists_or_mkdir("samples/{}".format(task)) 56 | 57 | ###======================== LOAD DATA ===================================### 58 | ## by importing this, you can load a training set and a validation set. 59 | # you will get X_train_input, X_train_target, X_dev_input and X_dev_target 60 | # there are 4 labels in targets: 61 | # Label 0: background 62 | # Label 1: necrotic and non-enhancing tumor 63 | # Label 2: edema 64 | # Label 4: enhancing tumor 65 | import prepare_data_with_valid as dataset 66 | X_train = dataset.X_train_input 67 | y_train = dataset.X_train_target[:,:,:,np.newaxis] 68 | X_test = dataset.X_dev_input 69 | y_test = dataset.X_dev_target[:,:,:,np.newaxis] 70 | 71 | if task == 'all': 72 | y_train = (y_train > 0).astype(int) 73 | y_test = (y_test > 0).astype(int) 74 | elif task == 'necrotic': 75 | y_train = (y_train == 1).astype(int) 76 | y_test = (y_test == 1).astype(int) 77 | elif task == 'edema': 78 | y_train = (y_train == 2).astype(int) 79 | y_test = (y_test == 2).astype(int) 80 | elif task == 'enhance': 81 | y_train = (y_train == 4).astype(int) 82 | y_test = (y_test == 4).astype(int) 83 | else: 84 | exit("Unknow task %s" % task) 85 | 86 | ###======================== HYPER-PARAMETERS ============================### 87 | batch_size = 10 88 | lr = 0.0001 89 | # lr_decay = 0.5 90 | # decay_every = 100 91 | beta1 = 0.9 92 | n_epoch = 100 93 | print_freq_step = 100 94 | 95 | ###======================== SHOW DATA ===================================### 96 | # show one slice 97 | X = np.asarray(X_train[80]) 98 | y = np.asarray(y_train[80]) 99 | # print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761 100 | # print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1 101 | nw, nh, nz = X.shape 102 | vis_imgs(X, y, 'samples/{}/_train_im.png'.format(task)) 103 | # show data augumentation results 104 | for i in range(10): 105 | x_flair, x_t1, x_t1ce, x_t2, label = 
distort_imgs([X[:,:,0,np.newaxis], X[:,:,1,np.newaxis], 106 | X[:,:,2,np.newaxis], X[:,:,3,np.newaxis], y])#[:,:,np.newaxis]]) 107 | # print(x_flair.shape, x_t1.shape, x_t1ce.shape, x_t2.shape, label.shape) # (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1) 108 | X_dis = np.concatenate((x_flair, x_t1, x_t1ce, x_t2), axis=2) 109 | # print(X_dis.shape, X_dis.min(), X_dis.max()) # (240, 240, 4) -0.380588233471 2.62376139209 110 | vis_imgs(X_dis, label, 'samples/{}/_train_im_aug{}.png'.format(task, i)) 111 | 112 | with tf.device('/cpu:0'): 113 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 114 | with tf.device('/gpu:0'): #<- remove it if you train on CPU or other GPU 115 | ###======================== DEFIINE MODEL =======================### 116 | ## nz is 4 as we input all Flair, T1, T1c and T2. 117 | t_image = tf.placeholder('float32', [batch_size, nw, nh, nz], name='input_image') 118 | ## labels are either 0 or 1 119 | t_seg = tf.placeholder('float32', [batch_size, nw, nh, 1], name='target_segment') 120 | ## train inference 121 | net = model.u_net(t_image, is_train=True, reuse=False, n_out=1) 122 | ## test inference 123 | net_test = model.u_net(t_image, is_train=False, reuse=True, n_out=1) 124 | 125 | ###======================== DEFINE LOSS =========================### 126 | ## train losses 127 | out_seg = net.outputs 128 | dice_loss = 1 - tl.cost.dice_coe(out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5) 129 | iou_loss = tl.cost.iou_coe(out_seg, t_seg, axis=[0,1,2,3]) 130 | dice_hard = tl.cost.dice_hard_coe(out_seg, t_seg, axis=[0,1,2,3]) 131 | loss = dice_loss 132 | 133 | ## test losses 134 | test_out_seg = net_test.outputs 135 | test_dice_loss = 1 - tl.cost.dice_coe(test_out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5) 136 | test_iou_loss = tl.cost.iou_coe(test_out_seg, t_seg, axis=[0,1,2,3]) 137 | test_dice_hard = tl.cost.dice_hard_coe(test_out_seg, t_seg, axis=[0,1,2,3]) 138 | 139 | 
###======================== DEFINE TRAIN OPTS =======================### 140 | t_vars = tl.layers.get_variables_with_name('u_net', True, True) 141 | with tf.device('/gpu:0'): 142 | with tf.variable_scope('learning_rate'): 143 | lr_v = tf.Variable(lr, trainable=False) 144 | train_op = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(loss, var_list=t_vars) 145 | 146 | ###======================== LOAD MODEL ==============================### 147 | tl.layers.initialize_global_variables(sess) 148 | ## load existing model if possible 149 | tl.files.load_and_assign_npz(sess=sess, name=save_dir+'/u_net_{}.npz'.format(task), network=net) 150 | 151 | ###======================== TRAINING ================================### 152 | for epoch in range(0, n_epoch+1): 153 | epoch_time = time.time() 154 | ## update decay learning rate at the beginning of a epoch 155 | # if epoch !=0 and (epoch % decay_every == 0): 156 | # new_lr_decay = lr_decay ** (epoch // decay_every) 157 | # sess.run(tf.assign(lr_v, lr * new_lr_decay)) 158 | # log = " ** new learning rate: %f" % (lr * new_lr_decay) 159 | # print(log) 160 | # elif epoch == 0: 161 | # sess.run(tf.assign(lr_v, lr)) 162 | # log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay) 163 | # print(log) 164 | 165 | total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0 166 | for batch in tl.iterate.minibatches(inputs=X_train, targets=y_train, 167 | batch_size=batch_size, shuffle=True): 168 | images, labels = batch 169 | step_time = time.time() 170 | ## data augumentation for a batch of Flair, T1, T1c, T2 images 171 | # and label maps synchronously. 
172 | data = tl.prepro.threading_data([_ for _ in zip(images[:,:,:,0, np.newaxis], 173 | images[:,:,:,1, np.newaxis], images[:,:,:,2, np.newaxis], 174 | images[:,:,:,3, np.newaxis], labels)], 175 | fn=distort_imgs) # (10, 5, 240, 240, 1) 176 | b_images = data[:,0:4,:,:,:] # (10, 4, 240, 240, 1) 177 | b_labels = data[:,4,:,:,:] 178 | b_images = b_images.transpose((0,2,3,1,4)) 179 | b_images.shape = (batch_size, nw, nh, nz) 180 | 181 | ## update network 182 | _, _dice, _iou, _diceh, out = sess.run([train_op, 183 | dice_loss, iou_loss, dice_hard, net.outputs], 184 | {t_image: b_images, t_seg: b_labels}) 185 | total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh 186 | n_batch += 1 187 | 188 | ## you can show the predition here: 189 | # vis_imgs2(b_images[0], b_labels[0], out[0], "samples/{}/_tmp.png".format(task)) 190 | # exit() 191 | 192 | # if _dice == 1: # DEBUG 193 | # print("DEBUG") 194 | # vis_imgs2(b_images[0], b_labels[0], out[0], "samples/{}/_debug.png".format(task)) 195 | 196 | if n_batch % print_freq_step == 0: 197 | print("Epoch %d step %d 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)" 198 | % (epoch, n_batch, _dice, _diceh, _iou, time.time()-step_time)) 199 | 200 | ## check model fail 201 | if np.isnan(_dice): 202 | exit(" ** NaN loss found during training, stop training") 203 | if np.isnan(out).any(): 204 | exit(" ** NaN found in output images during training, stop training") 205 | 206 | print(" ** Epoch [%d/%d] train 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)" % 207 | (epoch, n_epoch, total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch, time.time()-epoch_time)) 208 | 209 | ## save a predition of training set 210 | for i in range(batch_size): 211 | if np.max(b_images[i]) > 0: 212 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/train_{}.png".format(task, epoch)) 213 | break 214 | elif i == batch_size-1: 215 | vis_imgs2(b_images[i], b_labels[i], out[i], 
"samples/{}/train_{}.png".format(task, epoch)) 216 | 217 | ###======================== EVALUATION ==========================### 218 | total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0 219 | for batch in tl.iterate.minibatches(inputs=X_test, targets=y_test, 220 | batch_size=batch_size, shuffle=True): 221 | b_images, b_labels = batch 222 | _dice, _iou, _diceh, out = sess.run([test_dice_loss, 223 | test_iou_loss, test_dice_hard, net_test.outputs], 224 | {t_image: b_images, t_seg: b_labels}) 225 | total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh 226 | n_batch += 1 227 | 228 | print(" **"+" "*17+"test 1-dice: %f hard-dice: %f iou: %f (2d no distortion)" % 229 | (total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch)) 230 | print(" task: {}".format(task)) 231 | ## save a predition of test set 232 | for i in range(batch_size): 233 | if np.max(b_images[i]) > 0: 234 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/test_{}.png".format(task, epoch)) 235 | break 236 | elif i == batch_size-1: 237 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/test_{}.png".format(task, epoch)) 238 | 239 | ###======================== SAVE MODEL ==========================### 240 | tl.files.save_npz(net.all_params, name=save_dir+'/u_net_{}.npz'.format(task), sess=sess) 241 | 242 | if __name__ == "__main__": 243 | import argparse 244 | parser = argparse.ArgumentParser() 245 | 246 | parser.add_argument('--task', type=str, default='all', help='all, necrotic, edema, enhance') 247 | 248 | args = parser.parse_args() 249 | 250 | main(args.task) 251 | -------------------------------------------------------------------------------- /prepare_data_with_valid.py: -------------------------------------------------------------------------------- 1 | import tensorlayer as tl 2 | import numpy as np 3 | import os, csv, random, gc, pickle 4 | import nibabel as nib 5 | 6 | 7 | """ 8 | In seg file 9 | -------------- 10 | Label 1: necrotic and non-enhancing tumor 
11 | Label 2: edema  12 | Label 4: enhancing tumor 13 | Label 0: background 14 | 15 | MRI 16 | ------- 17 | whole/complete tumor: 1 2 4 18 | core: 1 4 19 | enhance: 4 20 | """ 21 | ###============================= SETTINGS ===================================### 22 | DATA_SIZE = 'half' # (small, half or all) 23 | 24 | save_dir = "data/train_dev_all/" 25 | if not os.path.exists(save_dir): 26 | os.makedirs(save_dir) 27 | 28 | HGG_data_path = "data/Brats17TrainingData/HGG" 29 | LGG_data_path = "data/Brats17TrainingData/LGG" 30 | survival_csv_path = "data/Brats17TrainingData/survival_data.csv" 31 | ###==========================================================================### 32 | 33 | survival_id_list = [] 34 | survival_age_list =[] 35 | survival_peroid_list = [] 36 | 37 | with open(survival_csv_path, 'r') as f: 38 | reader = csv.reader(f) 39 | next(reader) 40 | for idx, content in enumerate(reader): 41 | survival_id_list.append(content[0]) 42 | survival_age_list.append(float(content[1])) 43 | survival_peroid_list.append(float(content[2])) 44 | 45 | print(len(survival_id_list)) #163 46 | 47 | if DATA_SIZE == 'all': 48 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path) 49 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path) 50 | elif DATA_SIZE == 'half': 51 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:100]# DEBUG WITH SMALL DATA 52 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:30] # DEBUG WITH SMALL DATA 53 | elif DATA_SIZE == 'small': 54 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:50] # DEBUG WITH SMALL DATA 55 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:20] # DEBUG WITH SMALL DATA 56 | else: 57 | exit("Unknow DATA_SIZE") 58 | print(len(HGG_path_list), len(LGG_path_list)) #210 #75 59 | 60 | HGG_name_list = [os.path.basename(p) for p in HGG_path_list] 61 | LGG_name_list = [os.path.basename(p) for p in LGG_path_list] 62 | 63 | survival_id_from_HGG = [] 64 | 
survival_id_from_LGG = [] 65 | for i in survival_id_list: 66 | if i in HGG_name_list: 67 | survival_id_from_HGG.append(i) 68 | elif i in LGG_name_list: 69 | survival_id_from_LGG.append(i) 70 | else: 71 | print(i) 72 | 73 | print(len(survival_id_from_HGG), len(survival_id_from_LGG)) #163, 0 74 | 75 | # use 42 from 210 (in 163 subset) and 15 from 75 as 0.8/0.2 train/dev split 76 | 77 | # use 126/42/42 from 210 (in 163 subset) and 45/15/15 from 75 as 0.6/0.2/0.2 train/dev/test split 78 | index_HGG = list(range(0, len(survival_id_from_HGG))) 79 | index_LGG = list(range(0, len(LGG_name_list))) 80 | # random.shuffle(index_HGG) 81 | # random.shuffle(index_HGG) 82 | 83 | if DATA_SIZE == 'all': 84 | dev_index_HGG = index_HGG[-84:-42] 85 | test_index_HGG = index_HGG[-42:] 86 | tr_index_HGG = index_HGG[:-84] 87 | dev_index_LGG = index_LGG[-30:-15] 88 | test_index_LGG = index_LGG[-15:] 89 | tr_index_LGG = index_LGG[:-30] 90 | elif DATA_SIZE == 'half': 91 | dev_index_HGG = index_HGG[-30:] # DEBUG WITH SMALL DATA 92 | test_index_HGG = index_HGG[-5:] 93 | tr_index_HGG = index_HGG[:-30] 94 | dev_index_LGG = index_LGG[-10:] # DEBUG WITH SMALL DATA 95 | test_index_LGG = index_LGG[-5:] 96 | tr_index_LGG = index_LGG[:-10] 97 | elif DATA_SIZE == 'small': 98 | dev_index_HGG = index_HGG[35:42] # DEBUG WITH SMALL DATA 99 | # print(index_HGG, dev_index_HGG) 100 | # exit() 101 | test_index_HGG = index_HGG[41:42] 102 | tr_index_HGG = index_HGG[0:35] 103 | dev_index_LGG = index_LGG[7:10] # DEBUG WITH SMALL DATA 104 | test_index_LGG = index_LGG[9:10] 105 | tr_index_LGG = index_LGG[0:7] 106 | 107 | survival_id_dev_HGG = [survival_id_from_HGG[i] for i in dev_index_HGG] 108 | survival_id_test_HGG = [survival_id_from_HGG[i] for i in test_index_HGG] 109 | survival_id_tr_HGG = [survival_id_from_HGG[i] for i in tr_index_HGG] 110 | 111 | survival_id_dev_LGG = [LGG_name_list[i] for i in dev_index_LGG] 112 | survival_id_test_LGG = [LGG_name_list[i] for i in test_index_LGG] 113 | survival_id_tr_LGG = 
[LGG_name_list[i] for i in tr_index_LGG] 114 | 115 | survival_age_dev = [survival_age_list[survival_id_list.index(i)] for i in survival_id_dev_HGG] 116 | survival_age_test = [survival_age_list[survival_id_list.index(i)] for i in survival_id_test_HGG] 117 | survival_age_tr = [survival_age_list[survival_id_list.index(i)] for i in survival_id_tr_HGG] 118 | 119 | survival_period_dev = [survival_peroid_list[survival_id_list.index(i)] for i in survival_id_dev_HGG] 120 | survival_period_test = [survival_peroid_list[survival_id_list.index(i)] for i in survival_id_test_HGG] 121 | survival_period_tr = [survival_peroid_list[survival_id_list.index(i)] for i in survival_id_tr_HGG] 122 | 123 | data_types = ['flair', 't1', 't1ce', 't2'] 124 | data_types_mean_std_dict = {i: {'mean': 0.0, 'std': 1.0} for i in data_types} 125 | 126 | # calculate mean and std for all data types 127 | 128 | # preserving_ratio = 0.0 129 | # preserving_ratio = 0.01 # 0.118 removed 130 | # preserving_ratio = 0.05 # 0.213 removed 131 | # preserving_ratio = 0.10 # 0.359 removed 132 | 133 | #==================== LOAD ALL IMAGES' PATH AND COMPUTE MEAN/ STD 134 | for i in data_types: 135 | data_temp_list = [] 136 | for j in HGG_name_list: 137 | img_path = os.path.join(HGG_data_path, j, j + '_' + i + '.nii.gz') 138 | img = nib.load(img_path).get_data() 139 | data_temp_list.append(img) 140 | 141 | for j in LGG_name_list: 142 | img_path = os.path.join(LGG_data_path, j, j + '_' + i + '.nii.gz') 143 | img = nib.load(img_path).get_data() 144 | data_temp_list.append(img) 145 | 146 | data_temp_list = np.asarray(data_temp_list) 147 | m = np.mean(data_temp_list) 148 | s = np.std(data_temp_list) 149 | data_types_mean_std_dict[i]['mean'] = m 150 | data_types_mean_std_dict[i]['std'] = s 151 | del data_temp_list 152 | print(data_types_mean_std_dict) 153 | 154 | with open(save_dir + 'mean_std_dict.pickle', 'wb') as f: 155 | pickle.dump(data_types_mean_std_dict, f, protocol=4) 156 | 157 | 158 | ##==================== GET 
NORMALIZE IMAGES 159 | X_train_input = [] 160 | X_train_target = [] 161 | # X_train_target_whole = [] # 1 2 4 162 | # X_train_target_core = [] # 1 4 163 | # X_train_target_enhance = [] # 4 164 | 165 | X_dev_input = [] 166 | X_dev_target = [] 167 | # X_dev_target_whole = [] # 1 2 4 168 | # X_dev_target_core = [] # 1 4 169 | # X_dev_target_enhance = [] # 4 170 | 171 | print(" HGG Validation") 172 | for i in survival_id_dev_HGG: 173 | all_3d_data = [] 174 | for j in data_types: 175 | img_path = os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz') 176 | img = nib.load(img_path).get_data() 177 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std'] 178 | img = img.astype(np.float32) 179 | all_3d_data.append(img) 180 | 181 | seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz') 182 | seg_img = nib.load(seg_path).get_data() 183 | seg_img = np.transpose(seg_img, (1, 0, 2)) 184 | for j in range(all_3d_data[0].shape[2]): 185 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2) 186 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist() 187 | combined_array.astype(np.float32) 188 | X_dev_input.append(combined_array) 189 | 190 | seg_2d = seg_img[:, :, j] 191 | # whole = np.zeros_like(seg_2d) 192 | # core = np.zeros_like(seg_2d) 193 | # enhance = np.zeros_like(seg_2d) 194 | # for index, x in np.ndenumerate(seg_2d): 195 | # if x == 1: 196 | # whole[index] = 1 197 | # core[index] = 1 198 | # if x == 2: 199 | # whole[index] = 1 200 | # if x == 4: 201 | # whole[index] = 1 202 | # core[index] = 1 203 | # enhance[index] = 1 204 | # X_dev_target_whole.append(whole) 205 | # X_dev_target_core.append(core) 206 | # X_dev_target_enhance.append(enhance) 207 | seg_2d.astype(int) 208 | X_dev_target.append(seg_2d) 209 | del all_3d_data 210 | gc.collect() 211 | print("finished {}".format(i)) 212 | 213 | print(" LGG Validation") 214 | for i in 
survival_id_dev_LGG: 215 | all_3d_data = [] 216 | for j in data_types: 217 | img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz') 218 | img = nib.load(img_path).get_data() 219 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std'] 220 | img = img.astype(np.float32) 221 | all_3d_data.append(img) 222 | 223 | seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz') 224 | seg_img = nib.load(seg_path).get_data() 225 | seg_img = np.transpose(seg_img, (1, 0, 2)) 226 | for j in range(all_3d_data[0].shape[2]): 227 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2) 228 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist() 229 | combined_array.astype(np.float32) 230 | X_dev_input.append(combined_array) 231 | 232 | seg_2d = seg_img[:, :, j] 233 | # whole = np.zeros_like(seg_2d) 234 | # core = np.zeros_like(seg_2d) 235 | # enhance = np.zeros_like(seg_2d) 236 | # for index, x in np.ndenumerate(seg_2d): 237 | # if x == 1: 238 | # whole[index] = 1 239 | # core[index] = 1 240 | # if x == 2: 241 | # whole[index] = 1 242 | # if x == 4: 243 | # whole[index] = 1 244 | # core[index] = 1 245 | # enhance[index] = 1 246 | # X_dev_target_whole.append(whole) 247 | # X_dev_target_core.append(core) 248 | # X_dev_target_enhance.append(enhance) 249 | seg_2d.astype(int) 250 | X_dev_target.append(seg_2d) 251 | del all_3d_data 252 | gc.collect() 253 | print("finished {}".format(i)) 254 | 255 | X_dev_input = np.asarray(X_dev_input, dtype=np.float32) 256 | X_dev_target = np.asarray(X_dev_target)#, dtype=np.float32) 257 | # print(X_dev_input.shape) 258 | # print(X_dev_target.shape) 259 | 260 | # with open(save_dir + 'dev_input.pickle', 'wb') as f: 261 | # pickle.dump(X_dev_input, f, protocol=4) 262 | # with open(save_dir + 'dev_target.pickle', 'wb') as f: 263 | # pickle.dump(X_dev_target, f, protocol=4) 264 | 265 | # del X_dev_input, X_dev_target 266 
| 267 | print(" HGG Train") 268 | for i in survival_id_tr_HGG: 269 | all_3d_data = [] 270 | for j in data_types: 271 | img_path = os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz') 272 | img = nib.load(img_path).get_data() 273 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std'] 274 | img = img.astype(np.float32) 275 | all_3d_data.append(img) 276 | 277 | seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz') 278 | seg_img = nib.load(seg_path).get_data() 279 | seg_img = np.transpose(seg_img, (1, 0, 2)) 280 | for j in range(all_3d_data[0].shape[2]): 281 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2) 282 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist() 283 | combined_array.astype(np.float32) 284 | X_train_input.append(combined_array) 285 | 286 | seg_2d = seg_img[:, :, j] 287 | # whole = np.zeros_like(seg_2d) 288 | # core = np.zeros_like(seg_2d) 289 | # enhance = np.zeros_like(seg_2d) 290 | # for index, x in np.ndenumerate(seg_2d): 291 | # if x == 1: 292 | # whole[index] = 1 293 | # core[index] = 1 294 | # if x == 2: 295 | # whole[index] = 1 296 | # if x == 4: 297 | # whole[index] = 1 298 | # core[index] = 1 299 | # enhance[index] = 1 300 | # X_train_target_whole.append(whole) 301 | # X_train_target_core.append(core) 302 | # X_train_target_enhance.append(enhance) 303 | seg_2d.astype(int) 304 | X_train_target.append(seg_2d) 305 | del all_3d_data 306 | print("finished {}".format(i)) 307 | # print(len(X_train_target)) 308 | 309 | 310 | print(" LGG Train") 311 | for i in survival_id_tr_LGG: 312 | all_3d_data = [] 313 | for j in data_types: 314 | img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz') 315 | img = nib.load(img_path).get_data() 316 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std'] 317 | img = img.astype(np.float32) 318 | all_3d_data.append(img) 319 | 320 
    seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz')
    seg_img = nib.load(seg_path).get_data()
    seg_img = np.transpose(seg_img, (1, 0, 2))
    for j in range(all_3d_data[0].shape[2]):
        # stack the four modalities of slice j into one 2D sample with 4 channels
        combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)
        combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()
        # astype() returns a copy, so the result has to be assigned
        combined_array = combined_array.astype(np.float32)
        X_train_input.append(combined_array)

        seg_2d = seg_img[:, :, j]
        # whole = np.zeros_like(seg_2d)
        # core = np.zeros_like(seg_2d)
        # enhance = np.zeros_like(seg_2d)
        # for index, x in np.ndenumerate(seg_2d):
        #     if x == 1:
        #         whole[index] = 1
        #         core[index] = 1
        #     if x == 2:
        #         whole[index] = 1
        #     if x == 4:
        #         whole[index] = 1
        #         core[index] = 1
        #         enhance[index] = 1
        # X_train_target_whole.append(whole)
        # X_train_target_core.append(core)
        # X_train_target_enhance.append(enhance)
        seg_2d = seg_2d.astype(int)
        X_train_target.append(seg_2d)
    del all_3d_data
    print("finished {}".format(i))

X_train_input = np.asarray(X_train_input, dtype=np.float32)
X_train_target = np.asarray(X_train_target)#, dtype=np.float32)
# print(X_train_input.shape)
# print(X_train_target.shape)

# with open(save_dir + 'train_input.pickle', 'wb') as f:
#     pickle.dump(X_train_input, f, protocol=4)
# with open(save_dir + 'train_target.pickle', 'wb') as f:
#     pickle.dump(X_train_target, f, protocol=4)


# X_train_target_whole = np.asarray(X_train_target_whole)
# X_train_target_core = np.asarray(X_train_target_core)
# X_train_target_enhance = np.asarray(X_train_target_enhance)


# X_dev_target_whole = np.asarray(X_dev_target_whole)
# X_dev_target_core = np.asarray(X_dev_target_core)
# X_dev_target_enhance = np.asarray(X_dev_target_enhance)


# print(X_train_target_whole.shape)
# print(X_train_target_core.shape)
# print(X_train_target_enhance.shape)

# print(X_dev_target_whole.shape)
# print(X_dev_target_core.shape)
# print(X_dev_target_enhance.shape)


# with open(save_dir + 'train_target_whole.pickle', 'wb') as f:
#     pickle.dump(X_train_target_whole, f, protocol=4)

# with open(save_dir + 'train_target_core.pickle', 'wb') as f:
#     pickle.dump(X_train_target_core, f, protocol=4)

# with open(save_dir + 'train_target_enhance.pickle', 'wb') as f:
#     pickle.dump(X_train_target_enhance, f, protocol=4)

# with open(save_dir + 'dev_target_whole.pickle', 'wb') as f:
#     pickle.dump(X_dev_target_whole, f, protocol=4)

# with open(save_dir + 'dev_target_core.pickle', 'wb') as f:
#     pickle.dump(X_dev_target_core, f, protocol=4)

# with open(save_dir + 'dev_target_enhance.pickle', 'wb') as f:
#     pickle.dump(X_dev_target_enhance, f, protocol=4)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import numpy as np


def u_net(x, is_train=False, reuse=False, n_out=1):
    _, nx, ny, nz = x.get_shape().as_list()
    with tf.variable_scope("u_net", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        inputs = InputLayer(x, name='inputs')
        conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, name='conv1_1')
        conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='conv1_2')
        pool1 = MaxPool2d(conv1, (2, 2), name='pool1')
        conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, name='conv2_1')
        conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='conv2_2')
        pool2 = MaxPool2d(conv2, (2, 2), name='pool2')
        conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, name='conv3_1')
        conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='conv3_2')
        pool3 = MaxPool2d(conv3, (2, 2), name='pool3')
        conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, name='conv4_1')
        conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='conv4_2')
        pool4 = MaxPool2d(conv4, (2, 2), name='pool4')
        conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, name='conv5_1')
        conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, name='conv5_2')

        # integer division keeps the DeConv2d out_size values integral under Python 3
        up4 = DeConv2d(conv5, 512, (3, 3), (nx // 8, ny // 8), (2, 2), name='deconv4')
        up4 = ConcatLayer([up4, conv4], 3, name='concat4')
        conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, name='uconv4_1')
        conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='uconv4_2')
        up3 = DeConv2d(conv4, 256, (3, 3), (nx // 4, ny // 4), (2, 2), name='deconv3')
        up3 = ConcatLayer([up3, conv3], 3, name='concat3')
        conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, name='uconv3_1')
        conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='uconv3_2')
        up2 = DeConv2d(conv3, 128, (3, 3), (nx // 2, ny // 2), (2, 2), name='deconv2')
        up2 = ConcatLayer([up2, conv2], 3, name='concat2')
        conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, name='uconv2_1')
        conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='uconv2_2')
        up1 = DeConv2d(conv2, 64, (3, 3), (nx, ny), (2, 2), name='deconv1')
        up1 = ConcatLayer([up1, conv1], 3, name='concat1')
        conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, name='uconv1_1')
        conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='uconv1_2')
        conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')
        return conv1

# def u_net(x, is_train=False, reuse=False, pad='SAME', n_out=2):
#     """ Original U-Net for cell segmentation
#     http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
#     Original x is [batch_size, 572, 572, ?], pad is VALID
#     """
#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
#     nx = int(x._shape[1])
#     ny = int(x._shape[2])
#     nz = int(x._shape[3])
#     print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
#
#     w_init = tf.truncated_normal_initializer(stddev=0.01)
#     b_init = tf.constant_initializer(value=0.0)
#     with tf.variable_scope("u_net", reuse=reuse):
#         tl.layers.set_name_reuse(reuse)
#         inputs = InputLayer(x, name='inputs')
#
#         conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv1_1')
#         conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv1_2')
#         pool1 = MaxPool2d(conv1, (2, 2), padding=pad, name='pool1')
#
#         conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv2_1')
#         conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv2_2')
#         pool2 = MaxPool2d(conv2, (2, 2), padding=pad, name='pool2')
#
#         conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv3_1')
#         conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv3_2')
#         pool3 = MaxPool2d(conv3, (2, 2), padding=pad, name='pool3')
#
#         conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv4_1')
#         conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv4_2')
#         pool4 = MaxPool2d(conv4, (2, 2), padding=pad, name='pool4')
#
#         conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv5_1')
#         conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='conv5_2')
#
#         print(" * After conv: %s" % conv5.outputs)
#
#         up4 = DeConv2d(conv5, 512, (3, 3), out_size=(nx/8, ny/8),
#                        strides=(2, 2), padding=pad, act=None,
#                        W_init=w_init, b_init=b_init, name='deconv4')
#         up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')
#         conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv4_1')
#         conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv4_2')
#
#         up3 = DeConv2d(conv4, 256, (3, 3), out_size=(nx/4, ny/4),
#                        strides=(2, 2), padding=pad, act=None,
#                        W_init=w_init, b_init=b_init, name='deconv3')
#         up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')
#         conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv3_1')
#         conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv3_2')
#
#         up2 = DeConv2d(conv3, 128, (3, 3), out_size=(nx/2, ny/2),
#                        strides=(2, 2), padding=pad, act=None,
#                        W_init=w_init, b_init=b_init, name='deconv2')
#         up2 = ConcatLayer([up2, conv2], concat_dim=3, name='concat2')
#         conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv2_1')
#         conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv2_2')
#
#         up1 = DeConv2d(conv2, 64, (3, 3), out_size=(nx/1, ny/1),
#                        strides=(2, 2), padding=pad, act=None,
#                        W_init=w_init, b_init=b_init, name='deconv1')
#         up1 = ConcatLayer([up1, conv1], concat_dim=3, name='concat1')
#         conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv1_1')
#         conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,
#                        W_init=w_init, b_init=b_init, name='uconv1_2')
#
#         conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')
#         print(" * Output: %s" % conv1.outputs)
#
#         # logits0 = conv1.outputs[:,:,:,0]            # segmentation
#         # logits1 = conv1.outputs[:,:,:,1]            # edge
#         # logits0 = tf.expand_dims(logits0, axis=3)
#         # logits1 = tf.expand_dims(logits1, axis=3)
#         return conv1


def u_net_bn(x, is_train=False, reuse=False, batch_size=None, pad='SAME', n_out=1):
    """U-Net with batch normalization, following the pix2pix generator
    ("Image-to-Image Translation with Conditional Adversarial Networks").

    The DeConv2d out_size values below are hard-coded for 240x240 inputs.
    """
    nx = int(x._shape[1])
    ny = int(x._shape[2])
    nz = int(x._shape[3])
    print(" * Input: size of image: %d %d %d" % (nx, ny, nz))

    w_init = tf.truncated_normal_initializer(stddev=0.01)
    b_init = tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope("u_net", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        inputs = InputLayer(x, name='inputs')

        conv1 = Conv2d(inputs, 64, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv1')
        conv2 = Conv2d(conv1, 128, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv2')
        conv2 = BatchNormLayer(conv2, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn2')

        conv3 = Conv2d(conv2, 256, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv3')
        conv3 = BatchNormLayer(conv3, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn3')

        conv4 = Conv2d(conv3, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv4')
        conv4 = BatchNormLayer(conv4, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn4')

        conv5 = Conv2d(conv4, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv5')
        conv5 = BatchNormLayer(conv5, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn5')

        conv6 = Conv2d(conv5, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv6')
        conv6 = BatchNormLayer(conv6, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn6')

        conv7 = Conv2d(conv6, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv7')
        conv7 = BatchNormLayer(conv7, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn7')

        conv8 = Conv2d(conv7, 512, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding=pad, W_init=w_init, b_init=b_init, name='conv8')
        print(" * After conv: %s" % conv8.outputs)
        # exit()
        # print(nx/8)
        up7 = DeConv2d(conv8, 512, (4, 4), out_size=(2, 2), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv7')
        up7 = BatchNormLayer(up7, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn7')

        # print(up6.outputs)
        up6 = ConcatLayer([up7, conv7], concat_dim=3, name='concat6')
        up6 = DeConv2d(up6, 1024, (4, 4), out_size=(4, 4), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv6')
        up6 = BatchNormLayer(up6, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn6')
        # print(up6.outputs)
        # exit()

        up5 = ConcatLayer([up6, conv6], concat_dim=3, name='concat5')
        up5 = DeConv2d(up5, 1024, (4, 4), out_size=(8, 8), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv5')
        up5 = BatchNormLayer(up5, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn5')
        # print(up5.outputs)
        # exit()

        up4 = ConcatLayer([up5, conv5], concat_dim=3, name='concat4')
        up4 = DeConv2d(up4, 1024, (4, 4), out_size=(15, 15), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv4')
        up4 = BatchNormLayer(up4, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn4')

        up3 = ConcatLayer([up4, conv4], concat_dim=3, name='concat3')
        up3 = DeConv2d(up3, 256, (4, 4), out_size=(30, 30), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv3')
        up3 = BatchNormLayer(up3, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn3')

        up2 = ConcatLayer([up3, conv3], concat_dim=3, name='concat2')
        up2 = DeConv2d(up2, 128, (4, 4), out_size=(60, 60), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv2')
        up2 = BatchNormLayer(up2, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn2')

        up1 = ConcatLayer([up2, conv2], concat_dim=3, name='concat1')
        up1 = DeConv2d(up1, 64, (4, 4), out_size=(120, 120), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv1')
        up1 = BatchNormLayer(up1, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn1')

        up0 = ConcatLayer([up1, conv1], concat_dim=3, name='concat0')
        up0 = DeConv2d(up0, 64, (4, 4), out_size=(240, 240), strides=(2, 2),
                       padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv0')
        up0 = BatchNormLayer(up0, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn0')
        # print(up0.outputs)
        # exit()

        out = Conv2d(up0, n_out, (1, 1), act=tf.nn.sigmoid, name='out')

        print(" * Output: %s" % out.outputs)
        # exit()

    return out

## old implementation
# def u_net_2d_64_1024_deconv(x, n_out=2):
#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
#     nx = int(x._shape[1])
#     ny = int(x._shape[2])
#     nz = int(x._shape[3])
#     print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
#
#     w_init = tf.truncated_normal_initializer(stddev=0.01)
#     b_init = tf.constant_initializer(value=0.0)
#     inputs = InputLayer(x, name='inputs')
#
#     conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
#     conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
#
#     conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
#     conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
#     pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')
#
#     conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
#     conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
#
#     conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
#     conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
#
#     conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
#     conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
#
#     print(" * After conv: %s" % conv5.outputs)
#
#     up4 = DeConv2d(conv5, 512, (3, 3), out_size=(nx/8, ny/8), strides=(2, 2),
#                    padding='SAME', act=None, W_init=w_init, b_init=b_init, name='deconv4')
#     up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')
#     conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_1')
#     conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_2')
#
#     up3 = DeConv2d(conv4, 256, (3, 3), out_size=(nx/4, ny/4), strides=(2, 2),
#                    padding='SAME', act=None, W_init=w_init, b_init=b_init, name='deconv3')
#     up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')
#     conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_1')
#     conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_2')
#
#     up2 = DeConv2d(conv3, 128, (3, 3), out_size=(nx/2, ny/2), strides=(2, 2),
#                    padding='SAME', act=None, W_init=w_init, b_init=b_init, name='deconv2')
#     up2 = ConcatLayer([up2, conv2], concat_dim=3, name='concat2')
#     conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_1')
#     conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_2')
#
#     up1 = DeConv2d(conv2, 64, (3, 3), out_size=(nx/1, ny/1), strides=(2, 2),
#                    padding='SAME', act=None, W_init=w_init, b_init=b_init, name='deconv1')
#     up1 = ConcatLayer([up1, conv1], concat_dim=3, name='concat1')
#     conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_1')
#     conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_2')
#
#     conv1 = Conv2d(conv1, n_out, (1, 1), act=None, name='uconv1')
#     print(" * Output: %s" % conv1.outputs)
#     outputs = tl.act.pixel_wise_softmax(conv1.outputs)
#     return conv1, outputs
#
#
# def u_net_2d_32_1024_upsam(x, n_out=2):
#     """
#     https://github.com/jocicmarko/ultrasound-nerve-segmentation
#     """
#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
#     batch_size = int(x._shape[0])
#     nx = int(x._shape[1])
#     ny = int(x._shape[2])
#     nz = int(x._shape[3])
#     print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
#     ## define initializer
#     w_init = tf.truncated_normal_initializer(stddev=0.01)
#     b_init = tf.constant_initializer(value=0.0)
#     inputs = InputLayer(x, name='inputs')
#
#     conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
#     conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
#
#     conv2 = Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
#     conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
#     pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')
#
#     conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
#     conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
#
#     conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
#     conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
#
#     conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
#     conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
#     pool5 = MaxPool2d(conv5, (2, 2), padding='SAME', name='pool6')
#
#     # hao add
#     conv6 = Conv2d(pool5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')
#     conv6 = Conv2d(conv6, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')
#
#     print(" * After conv: %s" % conv6.outputs)
#
#     # hao add
#     up7 = UpSampling2dLayer(conv6, (15, 15), is_scale=False, method=1, name='up7')
#     up7 = ConcatLayer([up7, conv5], concat_dim=3, name='concat7')
#     conv7 = Conv2d(up7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')
#     conv7 = Conv2d(conv7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')
#
#     # print(nx/8,ny/8) # 30 30
#     up8 = UpSampling2dLayer(conv7, (2, 2), method=1, name='up8')
#     up8 = ConcatLayer([up8, conv4], concat_dim=3, name='concat8')
#     conv8 = Conv2d(up8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')
#     conv8 = Conv2d(conv8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')
#
#     up9 = UpSampling2dLayer(conv8, (2, 2), method=1, name='up9')
#     up9 = ConcatLayer([up9, conv3], concat_dim=3, name='concat9')
#     conv9 = Conv2d(up9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')
#     conv9 = Conv2d(conv9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')
#
#     up10 = UpSampling2dLayer(conv9, (2, 2), method=1, name='up10')
#     up10 = ConcatLayer([up10, conv2], concat_dim=3, name='concat10')
#     conv10 = Conv2d(up10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_1')
#     conv10 = Conv2d(conv10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_2')
#
#     up11 = UpSampling2dLayer(conv10, (2, 2), method=1, name='up11')
#     up11 = ConcatLayer([up11, conv1], concat_dim=3, name='concat11')
#     conv11 = Conv2d(up11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_1')
#     conv11 = Conv2d(conv11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_2')
#
#     conv12 = Conv2d(conv11, n_out, (1, 1), act=None, name='conv12')
#     print(" * Output: %s" % conv12.outputs)
#     outputs = tl.act.pixel_wise_softmax(conv12.outputs)
#     return conv12, outputs
#
#
# def u_net_2d_32_512_upsam(x, n_out=2):
#     """
#     https://github.com/jocicmarko/ultrasound-nerve-segmentation
#     """
#     from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
#     batch_size = int(x._shape[0])
#     nx = int(x._shape[1])
#     ny = int(x._shape[2])
#     nz = int(x._shape[3])
#     print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
#     ## define initializer
#     w_init = tf.truncated_normal_initializer(stddev=0.01)
#     b_init = tf.constant_initializer(value=0.0)
#     inputs = InputLayer(x, name='inputs')
#     # inputs = Input((1, img_rows, img_cols))
#     conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
#     # print(conv1.outputs) # (10, 240, 240, 32)
#     # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
#     conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
#     # print(conv1.outputs) # (10, 240, 240, 32)
#     # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
#     pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
#     # pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
#     # print(pool1.outputs) # (10, 120, 120, 32)
#     # exit()
#     conv2 = Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
#     # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
#     conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
#     # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
#     pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')
#     # pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
#
#     conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
#     # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
#     conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
#     # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
#     pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
#     # pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
#     # print(pool3.outputs) # (10, 30, 30, 64)
#
#     conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
#     # print(conv4.outputs) # (10, 30, 30, 256)
#     # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)
#     conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
#     # print(conv4.outputs) # (10, 30, 30, 256) != (10, 30, 30, 512)
#     # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)
#     pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
#     # pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
#
#     conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
#     # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)
#     conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
#     # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)
#     # print(conv5.outputs) # (10, 15, 15, 512)
#     print(" * After conv: %s" % conv5.outputs)
#     # print(nx/8,ny/8) # 30 30
#     up6 = UpSampling2dLayer(conv5, (2, 2), name='up6')
#     # print(up6.outputs) # (10, 30, 30, 512) == (10, 30, 30, 512)
#     up6 = ConcatLayer([up6, conv4], concat_dim=3, name='concat6')
#     # print(up6.outputs) # (10, 30, 30, 768)
#     # up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
#     conv6 = Conv2d(up6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')
#     # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)
#     conv6 = Conv2d(conv6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')
#     # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)
#
#     up7 = UpSampling2dLayer(conv6, (2, 2), name='up7')
#     up7 = ConcatLayer([up7, conv3], concat_dim=3, name='concat7')
#     # up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
#     conv7 = Conv2d(up7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')
#     # conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)
#     conv7 = Conv2d(conv7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')
#     # conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)
#
#     up8 = UpSampling2dLayer(conv7, (2, 2), name='up8')
#     up8 = ConcatLayer([up8, conv2], concat_dim=3, name='concat8')
#     # up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
#     conv8 = Conv2d(up8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')
#     # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)
#     conv8 = Conv2d(conv8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')
#     # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)
#
#     up9 = UpSampling2dLayer(conv8, (2, 2), name='up9')
#     up9 = ConcatLayer([up9, conv1], concat_dim=3, name='concat9')
#     # up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
#     conv9 = Conv2d(up9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')
#     # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)
#     conv9 = Conv2d(conv9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')
#     # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)
#
#     conv10 = Conv2d(conv9, n_out, (1, 1), act=None, name='conv9')
#     # conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)
#     print(" * Output: %s" % conv10.outputs)
#     outputs = tl.act.pixel_wise_softmax(conv10.outputs)
#     return conv10, outputs


if __name__ == "__main__":
    pass
    # main()
--------------------------------------------------------------------------------
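The per-slice loop in prepare_data_with_valid.py (an `np.stack(..., axis=2)` followed by `np.transpose(..., (1, 0, 2))` for every slice `j`) can also be written as one array operation per volume. The sketch below assumes four co-registered volumes of identical shape `(X, Y, Z)`, as the loop does; `volumes_to_slices` and `segmentation_to_slices` are hypothetical helpers, not functions from this repo.

```python
import numpy as np


def volumes_to_slices(volumes):
    """Hypothetical helper: four (X, Y, Z) volumes -> Z samples of shape (Y, X, 4).

    Equivalent to stacking volumes[k][:, :, j] on axis 2 and transposing
    with (1, 0, 2) for every slice j, then casting to float32.
    """
    # (X, Y, Z) x 4 -> (X, Y, Z, 4): the modality becomes the channel axis
    stacked = np.stack(volumes, axis=-1)
    # (X, Y, Z, 4) -> (Z, Y, X, 4): slice index first, in-plane axes swapped
    return np.transpose(stacked, (2, 1, 0, 3)).astype(np.float32)


def segmentation_to_slices(seg):
    """Hypothetical helper: label volume (X, Y, Z) -> Z integer label maps (Y, X).

    Same result as transposing the volume with (1, 0, 2) and taking seg[:, :, j].
    """
    return np.transpose(seg, (2, 1, 0)).astype(int)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vols = [rng.standard_normal((6, 5, 3)) for _ in range(4)]
    seg = rng.integers(0, 5, size=(6, 5, 3))
    print(volumes_to_slices(vols).shape, segmentation_to_slices(seg).shape)
    # (3, 5, 6, 4) (3, 5, 6)
```

Building the whole array at once avoids the Python-level list of per-slice arrays, although the full float32 tensor for the dataset still has to fit in memory.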
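The commented-out `np.ndenumerate` loop in prepare_data_with_valid.py derives three binary targets from the BRATS label map: label 1 belongs to whole and core, label 2 to whole only, and label 4 to whole, core and enhance. A vectorized equivalent, written as a hypothetical helper `brats_region_masks` (not part of this repo), could look like this; like `np.zeros_like(seg_2d)` in the loop, the masks keep the dtype of the input.

```python
import numpy as np


def brats_region_masks(seg_2d):
    """Hypothetical helper: BRATS label map -> (whole, core, enhance) masks.

    Mapping taken from the commented-out loop:
      label 1 -> whole, core
      label 2 -> whole
      label 4 -> whole, core, enhance
    """
    seg_2d = np.asarray(seg_2d)
    whole = np.isin(seg_2d, (1, 2, 4)).astype(seg_2d.dtype)
    core = np.isin(seg_2d, (1, 4)).astype(seg_2d.dtype)
    enhance = (seg_2d == 4).astype(seg_2d.dtype)
    return whole, core, enhance


if __name__ == "__main__":
    seg = np.array([[0, 1], [2, 4]])
    whole, core, enhance = brats_region_masks(seg)
    print(whole.tolist(), core.tolist(), enhance.tolist())
    # [[0, 1], [1, 1]] [[0, 1], [0, 1]] [[0, 0], [0, 1]]
```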
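`u_net_bn` hard-codes its DeConv2d `out_size` values (2, 4, 8, 15, 30, 60, 120, 240), and `u_net` derives them as `nx/8` to `nx`. Both follow from TensorFlow's `'SAME'` padding rule, under which a stride-2 convolution or 2x2 max-pool produces `ceil(n / 2)` pixels; that is why 15 appears after 30 and 8 after 15. The sketch below, using hypothetical helpers `encoder_sizes` and `decoder_out_sizes`, reproduces the schedule and can be used to check another input size before editing the hard-coded values.

```python
def encoder_sizes(n, n_down):
    """Hypothetical helper: spatial size after each of n_down stride-2 steps
    with 'SAME' padding, where the output size is ceil(n / 2)."""
    sizes = []
    for _ in range(n_down):
        n = -(-n // 2)  # ceiling division
        sizes.append(n)
    return sizes


def decoder_out_sizes(n, n_down):
    """Hypothetical helper: out_size of each up-step, deepest first.

    Each up-step has to restore the size of the encoder level it is
    concatenated with, so this is the encoder schedule read backwards.
    """
    levels = [n] + encoder_sizes(n, n_down)
    return levels[-2::-1]


if __name__ == "__main__":
    # u_net_bn: eight stride-2 Conv2d layers on a 240x240 slice
    print(encoder_sizes(240, 8))      # [120, 60, 30, 15, 8, 4, 2, 1]
    print(decoder_out_sizes(240, 8))  # [2, 4, 8, 15, 30, 60, 120, 240]
    # u_net: four 2x2 max-pools, so out_size runs nx/8, nx/4, nx/2, nx
    print(decoder_out_sizes(240, 4))  # [30, 60, 120, 240]
```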