├── example
│   ├── brain_tumor_aug.png
│   ├── brain_tumor_aug.pptx
│   ├── brain_tumor_data.png
│   └── brain_tumor_data.pptx
├── .gitignore
├── README.md
├── train.py
├── prepare_data_with_valid.py
└── model.py
/example/brain_tumor_aug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_aug.png
--------------------------------------------------------------------------------
/example/brain_tumor_aug.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_aug.pptx
--------------------------------------------------------------------------------
/example/brain_tumor_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_data.png
--------------------------------------------------------------------------------
/example/brain_tumor_data.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/u-net-brain-tumor/HEAD/example/brain_tumor_data.pptx
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | tensorlayer/__pycache__
2 | tensorlayer/.DS_Store
3 | .DS_Store
4 | dist
5 | build/
6 | tensorlayer.egg-info
7 | data/.DS_Store
8 | *.pyc
9 | *.gz
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # U-Net Brain Tumor Segmentation
2 |
3 | 🚀 Feb 2019: the data-processing implementation in this repo is not the fastest (the code needs updating, contributions are welcome); you can use the TensorFlow Dataset API instead.
4 |
5 | This repo shows how to train a U-Net for brain tumor segmentation. By default, you need to download the training set of the [BRATS 2017](http://braintumorsegmentation.org) dataset, which has 210 HGG and 75 LGG volumes, and place the `data` folder alongside the scripts, in the following layout:
6 |
7 | ```bash
8 | data
9 | -- Brats17TrainingData
10 | -- train_dev_all
11 | model.py
12 | train.py
13 | ...
14 | ```
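
Before training, you can sanity-check that the data sits where the scripts expect it. A minimal sketch (the paths are the defaults hard-coded in `prepare_data_with_valid.py`; the `expected` list itself is only illustrative):

```python
import os

# Default locations used by prepare_data_with_valid.py; adjust if your copy lives elsewhere.
expected = [
    "data/Brats17TrainingData/HGG",
    "data/Brats17TrainingData/LGG",
    "data/Brats17TrainingData/survival_data.csv",
]
for path in expected:
    status = "ok" if os.path.exists(path) else "MISSING"
    print("{:8s} {}".format(status, path))
```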
15 |
16 | ### About the data
17 | Note that, according to the license, users have to apply for the dataset from BraTS themselves; please do **NOT** contact me for the dataset. Many thanks.
18 |
19 |
20 | ![Brain tumor data](example/brain_tumor_data.png)
21 |
22 | *Fig 1: Brain Image*
23 |
24 |
25 | * Each volume has 4 scan modalities: FLAIR, T1, T1c and T2.
26 | * Each volume has 4 segmentation labels:
27 |
28 | ```
29 | Label 0: background
30 | Label 1: necrotic and non-enhancing tumor
31 | Label 2: edema
32 | Label 4: enhancing tumor
33 | ```
34 |
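`train.py` turns this multi-label map into a binary mask for whichever task is being trained; the mapping is essentially the following (`binarize_labels` is only an illustrative wrapper around the `task` branches in `train.py`):

```python
import numpy as np

def binarize_labels(y, task="all"):
    """Illustrative wrapper mirroring the task handling in train.py."""
    if task == "all":          # any tumor voxel (labels 1, 2 and 4)
        return (y > 0).astype(int)
    if task == "necrotic":     # label 1
        return (y == 1).astype(int)
    if task == "edema":        # label 2
        return (y == 2).astype(int)
    if task == "enhance":      # label 4
        return (y == 4).astype(int)
    raise ValueError("Unknown task %s" % task)
```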
35 | `prepare_data_with_valid.py` splits the training set into two folds for training and validation. By default it uses only half of the data for the sake of training speed; if you want to use all the data, change `DATA_SIZE = 'half'` to `'all'`.
36 |
37 | ### About the method
38 |
39 | - Network and loss: In this experiment we train the network with the [dice loss](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#dice-coefficient), so each network predicts only one label (Label 1, 2 or 4). We evaluate the performance using the [hard dice](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#hard-dice-coefficient) and [IoU](http://tensorlayer.readthedocs.io/en/latest/modules/cost.html#iou-coefficient) coefficients; a minimal sketch of the soft dice loss follows this list.
40 |
41 | - Data augmentation: includes random left-right flipping, rotation, shifting, shearing, zooming and, most importantly, [elastic transformation](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#elastic-transform); see ["Automatic Brain Tumor Detection and Segmentation Using U-Net Based Fully Convolutional Networks"](https://arxiv.org/pdf/1705.03820.pdf) for details. All transforms are applied to the four modalities and the label map synchronously (see `distort_imgs` in `train.py`); a condensed sketch follows Fig 2 below.
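
A minimal NumPy sketch of the soft dice coefficient behind the training loss (the repo itself uses `tl.cost.dice_coe`; `soft_dice` and the `smooth` constant here are illustrative, not the library implementation):

```python
import numpy as np

def soft_dice(output, target, smooth=1e-5):
    """Soft (Sorensen) dice coefficient between a predicted mask in [0, 1] and a binary target.
    Illustrative only; train.py uses tl.cost.dice_coe and minimizes 1 - dice."""
    inter = np.sum(output * target)
    union = np.sum(output) + np.sum(target)
    return (2.0 * inter + smooth) / (union + smooth)

# A perfect prediction gives dice ~ 1, so the training loss 1 - dice ~ 0.
y_true = np.zeros((240, 240, 1))
y_true[100:140, 100:140] = 1
print(1 - soft_dice(y_true, y_true))  # ~0.0
```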
42 |
43 |
44 | ![Brain tumor data augmentation](example/brain_tumor_aug.png)
45 |
46 | *Fig 2: Data augmentation*
47 |
48 |
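The synchronized augmentation shown in Fig 2 is implemented in `distort_imgs` in `train.py`; a condensed sketch (same TensorLayer `*_multi` helpers and parameters, chained so the four modalities and the label map receive identical random transforms; `augment` is just an illustrative name):

```python
import tensorlayer as tl

def augment(x_flair, x_t1, x_t1ce, x_t2, y):
    """Apply the same random flip, elastic transform, rotation, shift, shear and zoom
    to the four modalities and the label map (condensed from distort_imgs in train.py)."""
    data = [x_flair, x_t1, x_t1ce, x_t2, y]
    data = tl.prepro.flip_axis_multi(data, axis=1, is_random=True)                       # left-right flip
    data = tl.prepro.elastic_transform_multi(data, alpha=720, sigma=24, is_random=True)  # elastic deformation
    data = tl.prepro.rotation_multi(data, rg=20, is_random=True, fill_mode='constant')   # rotation up to 20 degrees
    data = tl.prepro.shift_multi(data, wrg=0.10, hrg=0.10, is_random=True, fill_mode='constant')
    data = tl.prepro.shear_multi(data, 0.05, is_random=True, fill_mode='constant')
    data = tl.prepro.zoom_multi(data, zoom_range=[0.9, 1.1], is_random=True, fill_mode='constant')
    return data  # [flair, t1, t1ce, t2, label], each still of shape (240, 240, 1)
```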
49 | ### Start training
50 |
51 | We train HGG and LGG together. As one network handles only one task, set `task` to `all`, `necrotic`, `edema` or `enhance`; `all` means learning to segment all tumor labels.
52 |
53 | ```bash
54 | python train.py --task=all
55 | ```
56 |
57 | Note that if the loss sticks at 1 at the beginning, the network is not converging (a dice loss of 1 means the prediction has no overlap with the target); please try restarting the training.
58 |
59 | ### Citation
60 | If you find this project useful, we would be grateful if you cite the TensorLayer paper:
61 |
62 | ```
63 | @article{tensorlayer2017,
64 | author = {Dong, Hao and Supratak, Akara and Mai, Luo and Liu, Fangde and Oehmichen, Axel and Yu, Simiao and Guo, Yike},
65 | journal = {ACM Multimedia},
66 | title = {{TensorLayer: A Versatile Library for Efficient Deep Learning Development}},
67 | url = {http://tensorlayer.org},
68 | year = {2017}
69 | }
70 | ```
71 |
72 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf8 -*-
3 |
4 | import tensorflow as tf
5 | import tensorlayer as tl
6 | import numpy as np
7 | import os, time, model
8 |
9 | def distort_imgs(data):
10 |     """ data augmentation """
11 | x1, x2, x3, x4, y = data
12 | # x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y], # previous without this, hard-dice=83.7
13 | # axis=0, is_random=True) # up down
14 | x1, x2, x3, x4, y = tl.prepro.flip_axis_multi([x1, x2, x3, x4, y],
15 | axis=1, is_random=True) # left right
16 | x1, x2, x3, x4, y = tl.prepro.elastic_transform_multi([x1, x2, x3, x4, y],
17 | alpha=720, sigma=24, is_random=True)
18 | x1, x2, x3, x4, y = tl.prepro.rotation_multi([x1, x2, x3, x4, y], rg=20,
19 | is_random=True, fill_mode='constant') # nearest, constant
20 | x1, x2, x3, x4, y = tl.prepro.shift_multi([x1, x2, x3, x4, y], wrg=0.10,
21 | hrg=0.10, is_random=True, fill_mode='constant')
22 | x1, x2, x3, x4, y = tl.prepro.shear_multi([x1, x2, x3, x4, y], 0.05,
23 | is_random=True, fill_mode='constant')
24 | x1, x2, x3, x4, y = tl.prepro.zoom_multi([x1, x2, x3, x4, y],
25 | zoom_range=[0.9, 1.1], is_random=True,
26 | fill_mode='constant')
27 | return x1, x2, x3, x4, y
28 |
29 | def vis_imgs(X, y, path):
30 | """ show one slice """
31 | if y.ndim == 2:
32 | y = y[:,:,np.newaxis]
33 | assert X.ndim == 3
34 | tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis],
35 | X[:,:,1,np.newaxis], X[:,:,2,np.newaxis],
36 | X[:,:,3,np.newaxis], y]), size=(1, 5),
37 | image_path=path)
38 |
39 | def vis_imgs2(X, y_, y, path):
40 | """ show one slice with target """
41 | if y.ndim == 2:
42 | y = y[:,:,np.newaxis]
43 | if y_.ndim == 2:
44 | y_ = y_[:,:,np.newaxis]
45 | assert X.ndim == 3
46 | tl.vis.save_images(np.asarray([X[:,:,0,np.newaxis],
47 | X[:,:,1,np.newaxis], X[:,:,2,np.newaxis],
48 | X[:,:,3,np.newaxis], y_, y]), size=(1, 6),
49 | image_path=path)
50 |
51 | def main(task='all'):
52 | ## Create folder to save trained model and result images
53 | save_dir = "checkpoint"
54 | tl.files.exists_or_mkdir(save_dir)
55 | tl.files.exists_or_mkdir("samples/{}".format(task))
56 |
57 | ###======================== LOAD DATA ===================================###
58 | ## by importing this, you can load a training set and a validation set.
59 | # you will get X_train_input, X_train_target, X_dev_input and X_dev_target
60 | # there are 4 labels in targets:
61 | # Label 0: background
62 | # Label 1: necrotic and non-enhancing tumor
63 | # Label 2: edema
64 | # Label 4: enhancing tumor
65 | import prepare_data_with_valid as dataset
66 | X_train = dataset.X_train_input
67 | y_train = dataset.X_train_target[:,:,:,np.newaxis]
68 | X_test = dataset.X_dev_input
69 | y_test = dataset.X_dev_target[:,:,:,np.newaxis]
70 |
71 | if task == 'all':
72 | y_train = (y_train > 0).astype(int)
73 | y_test = (y_test > 0).astype(int)
74 | elif task == 'necrotic':
75 | y_train = (y_train == 1).astype(int)
76 | y_test = (y_test == 1).astype(int)
77 | elif task == 'edema':
78 | y_train = (y_train == 2).astype(int)
79 | y_test = (y_test == 2).astype(int)
80 | elif task == 'enhance':
81 | y_train = (y_train == 4).astype(int)
82 | y_test = (y_test == 4).astype(int)
83 | else:
84 |         exit("Unknown task %s" % task)
85 |
86 | ###======================== HYPER-PARAMETERS ============================###
87 | batch_size = 10
88 | lr = 0.0001
89 | # lr_decay = 0.5
90 | # decay_every = 100
91 | beta1 = 0.9
92 | n_epoch = 100
93 | print_freq_step = 100
94 |
95 | ###======================== SHOW DATA ===================================###
96 | # show one slice
97 | X = np.asarray(X_train[80])
98 | y = np.asarray(y_train[80])
99 | # print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761
100 | # print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1
101 | nw, nh, nz = X.shape
102 | vis_imgs(X, y, 'samples/{}/_train_im.png'.format(task))
103 |     # show data augmentation results
104 | for i in range(10):
105 | x_flair, x_t1, x_t1ce, x_t2, label = distort_imgs([X[:,:,0,np.newaxis], X[:,:,1,np.newaxis],
106 | X[:,:,2,np.newaxis], X[:,:,3,np.newaxis], y])#[:,:,np.newaxis]])
107 | # print(x_flair.shape, x_t1.shape, x_t1ce.shape, x_t2.shape, label.shape) # (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1) (240, 240, 1)
108 | X_dis = np.concatenate((x_flair, x_t1, x_t1ce, x_t2), axis=2)
109 | # print(X_dis.shape, X_dis.min(), X_dis.max()) # (240, 240, 4) -0.380588233471 2.62376139209
110 | vis_imgs(X_dis, label, 'samples/{}/_train_im_aug{}.png'.format(task, i))
111 |
112 | with tf.device('/cpu:0'):
113 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
114 | with tf.device('/gpu:0'): #<- remove it if you train on CPU or other GPU
115 |         ###======================== DEFINE MODEL ========================###
116 | ## nz is 4 as we input all Flair, T1, T1c and T2.
117 | t_image = tf.placeholder('float32', [batch_size, nw, nh, nz], name='input_image')
118 | ## labels are either 0 or 1
119 | t_seg = tf.placeholder('float32', [batch_size, nw, nh, 1], name='target_segment')
120 | ## train inference
121 | net = model.u_net(t_image, is_train=True, reuse=False, n_out=1)
122 | ## test inference
123 | net_test = model.u_net(t_image, is_train=False, reuse=True, n_out=1)
124 |
125 | ###======================== DEFINE LOSS =========================###
126 | ## train losses
127 | out_seg = net.outputs
128 | dice_loss = 1 - tl.cost.dice_coe(out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5)
129 | iou_loss = tl.cost.iou_coe(out_seg, t_seg, axis=[0,1,2,3])
130 | dice_hard = tl.cost.dice_hard_coe(out_seg, t_seg, axis=[0,1,2,3])
131 | loss = dice_loss
132 |
133 | ## test losses
134 | test_out_seg = net_test.outputs
135 | test_dice_loss = 1 - tl.cost.dice_coe(test_out_seg, t_seg, axis=[0,1,2,3])#, 'jaccard', epsilon=1e-5)
136 | test_iou_loss = tl.cost.iou_coe(test_out_seg, t_seg, axis=[0,1,2,3])
137 | test_dice_hard = tl.cost.dice_hard_coe(test_out_seg, t_seg, axis=[0,1,2,3])
138 |
139 | ###======================== DEFINE TRAIN OPTS =======================###
140 | t_vars = tl.layers.get_variables_with_name('u_net', True, True)
141 | with tf.device('/gpu:0'):
142 | with tf.variable_scope('learning_rate'):
143 | lr_v = tf.Variable(lr, trainable=False)
144 | train_op = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(loss, var_list=t_vars)
145 |
146 | ###======================== LOAD MODEL ==============================###
147 | tl.layers.initialize_global_variables(sess)
148 | ## load existing model if possible
149 | tl.files.load_and_assign_npz(sess=sess, name=save_dir+'/u_net_{}.npz'.format(task), network=net)
150 |
151 | ###======================== TRAINING ================================###
152 | for epoch in range(0, n_epoch+1):
153 | epoch_time = time.time()
154 |         ## update decay learning rate at the beginning of an epoch
155 | # if epoch !=0 and (epoch % decay_every == 0):
156 | # new_lr_decay = lr_decay ** (epoch // decay_every)
157 | # sess.run(tf.assign(lr_v, lr * new_lr_decay))
158 | # log = " ** new learning rate: %f" % (lr * new_lr_decay)
159 | # print(log)
160 | # elif epoch == 0:
161 | # sess.run(tf.assign(lr_v, lr))
162 | # log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay)
163 | # print(log)
164 |
165 | total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0
166 | for batch in tl.iterate.minibatches(inputs=X_train, targets=y_train,
167 | batch_size=batch_size, shuffle=True):
168 | images, labels = batch
169 | step_time = time.time()
170 |             ## data augmentation for a batch of Flair, T1, T1c, T2 images
171 | # and label maps synchronously.
172 | data = tl.prepro.threading_data([_ for _ in zip(images[:,:,:,0, np.newaxis],
173 | images[:,:,:,1, np.newaxis], images[:,:,:,2, np.newaxis],
174 | images[:,:,:,3, np.newaxis], labels)],
175 | fn=distort_imgs) # (10, 5, 240, 240, 1)
176 | b_images = data[:,0:4,:,:,:] # (10, 4, 240, 240, 1)
177 | b_labels = data[:,4,:,:,:]
178 | b_images = b_images.transpose((0,2,3,1,4))
179 | b_images.shape = (batch_size, nw, nh, nz)
180 |
181 | ## update network
182 | _, _dice, _iou, _diceh, out = sess.run([train_op,
183 | dice_loss, iou_loss, dice_hard, net.outputs],
184 | {t_image: b_images, t_seg: b_labels})
185 | total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh
186 | n_batch += 1
187 |
188 |             ## you can show the prediction here:
189 | # vis_imgs2(b_images[0], b_labels[0], out[0], "samples/{}/_tmp.png".format(task))
190 | # exit()
191 |
192 | # if _dice == 1: # DEBUG
193 | # print("DEBUG")
194 | # vis_imgs2(b_images[0], b_labels[0], out[0], "samples/{}/_debug.png".format(task))
195 |
196 | if n_batch % print_freq_step == 0:
197 | print("Epoch %d step %d 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)"
198 | % (epoch, n_batch, _dice, _diceh, _iou, time.time()-step_time))
199 |
200 |             ## check for model failure
201 | if np.isnan(_dice):
202 | exit(" ** NaN loss found during training, stop training")
203 | if np.isnan(out).any():
204 | exit(" ** NaN found in output images during training, stop training")
205 |
206 | print(" ** Epoch [%d/%d] train 1-dice: %f hard-dice: %f iou: %f took %fs (2d with distortion)" %
207 | (epoch, n_epoch, total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch, time.time()-epoch_time))
208 |
209 |         ## save a prediction of the training set
210 | for i in range(batch_size):
211 | if np.max(b_images[i]) > 0:
212 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/train_{}.png".format(task, epoch))
213 | break
214 | elif i == batch_size-1:
215 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/train_{}.png".format(task, epoch))
216 |
217 | ###======================== EVALUATION ==========================###
218 | total_dice, total_iou, total_dice_hard, n_batch = 0, 0, 0, 0
219 | for batch in tl.iterate.minibatches(inputs=X_test, targets=y_test,
220 | batch_size=batch_size, shuffle=True):
221 | b_images, b_labels = batch
222 | _dice, _iou, _diceh, out = sess.run([test_dice_loss,
223 | test_iou_loss, test_dice_hard, net_test.outputs],
224 | {t_image: b_images, t_seg: b_labels})
225 | total_dice += _dice; total_iou += _iou; total_dice_hard += _diceh
226 | n_batch += 1
227 |
228 | print(" **"+" "*17+"test 1-dice: %f hard-dice: %f iou: %f (2d no distortion)" %
229 | (total_dice/n_batch, total_dice_hard/n_batch, total_iou/n_batch))
230 | print(" task: {}".format(task))
231 |         ## save a prediction of the test set
232 | for i in range(batch_size):
233 | if np.max(b_images[i]) > 0:
234 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/test_{}.png".format(task, epoch))
235 | break
236 | elif i == batch_size-1:
237 | vis_imgs2(b_images[i], b_labels[i], out[i], "samples/{}/test_{}.png".format(task, epoch))
238 |
239 | ###======================== SAVE MODEL ==========================###
240 | tl.files.save_npz(net.all_params, name=save_dir+'/u_net_{}.npz'.format(task), sess=sess)
241 |
242 | if __name__ == "__main__":
243 | import argparse
244 | parser = argparse.ArgumentParser()
245 |
246 | parser.add_argument('--task', type=str, default='all', help='all, necrotic, edema, enhance')
247 |
248 | args = parser.parse_args()
249 |
250 | main(args.task)
251 |
--------------------------------------------------------------------------------
/prepare_data_with_valid.py:
--------------------------------------------------------------------------------
1 | import tensorlayer as tl
2 | import numpy as np
3 | import os, csv, random, gc, pickle
4 | import nibabel as nib
5 |
6 |
7 | """
8 | In seg file
9 | --------------
10 | Label 1: necrotic and non-enhancing tumor
11 | Label 2: edema
12 | Label 4: enhancing tumor
13 | Label 0: background
14 |
15 | MRI
16 | -------
17 | whole/complete tumor: 1 2 4
18 | core: 1 4
19 | enhance: 4
20 | """
21 | ###============================= SETTINGS ===================================###
22 | DATA_SIZE = 'half' # (small, half or all)
23 |
24 | save_dir = "data/train_dev_all/"
25 | if not os.path.exists(save_dir):
26 | os.makedirs(save_dir)
27 |
28 | HGG_data_path = "data/Brats17TrainingData/HGG"
29 | LGG_data_path = "data/Brats17TrainingData/LGG"
30 | survival_csv_path = "data/Brats17TrainingData/survival_data.csv"
31 | ###==========================================================================###
32 |
33 | survival_id_list = []
34 | survival_age_list =[]
35 | survival_period_list = []
36 |
37 | with open(survival_csv_path, 'r') as f:
38 | reader = csv.reader(f)
39 | next(reader)
40 | for idx, content in enumerate(reader):
41 | survival_id_list.append(content[0])
42 | survival_age_list.append(float(content[1]))
43 |         survival_period_list.append(float(content[2]))
44 |
45 | print(len(survival_id_list)) #163
46 |
47 | if DATA_SIZE == 'all':
48 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)
49 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)
50 | elif DATA_SIZE == 'half':
51 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:100]# DEBUG WITH SMALL DATA
52 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:30] # DEBUG WITH SMALL DATA
53 | elif DATA_SIZE == 'small':
54 | HGG_path_list = tl.files.load_folder_list(path=HGG_data_path)[0:50] # DEBUG WITH SMALL DATA
55 | LGG_path_list = tl.files.load_folder_list(path=LGG_data_path)[0:20] # DEBUG WITH SMALL DATA
56 | else:
57 |     exit("Unknown DATA_SIZE")
58 | print(len(HGG_path_list), len(LGG_path_list)) #210 #75
59 |
60 | HGG_name_list = [os.path.basename(p) for p in HGG_path_list]
61 | LGG_name_list = [os.path.basename(p) for p in LGG_path_list]
62 |
63 | survival_id_from_HGG = []
64 | survival_id_from_LGG = []
65 | for i in survival_id_list:
66 | if i in HGG_name_list:
67 | survival_id_from_HGG.append(i)
68 | elif i in LGG_name_list:
69 | survival_id_from_LGG.append(i)
70 | else:
71 | print(i)
72 |
73 | print(len(survival_id_from_HGG), len(survival_id_from_LGG)) #163, 0
74 |
75 | # use 42 from 210 (in 163 subset) and 15 from 75 as 0.8/0.2 train/dev split
76 |
77 | # use 126/42/42 from 210 (in 163 subset) and 45/15/15 from 75 as 0.6/0.2/0.2 train/dev/test split
78 | index_HGG = list(range(0, len(survival_id_from_HGG)))
79 | index_LGG = list(range(0, len(LGG_name_list)))
80 | # random.shuffle(index_HGG)
81 | # random.shuffle(index_HGG)
82 |
83 | if DATA_SIZE == 'all':
84 | dev_index_HGG = index_HGG[-84:-42]
85 | test_index_HGG = index_HGG[-42:]
86 | tr_index_HGG = index_HGG[:-84]
87 | dev_index_LGG = index_LGG[-30:-15]
88 | test_index_LGG = index_LGG[-15:]
89 | tr_index_LGG = index_LGG[:-30]
90 | elif DATA_SIZE == 'half':
91 | dev_index_HGG = index_HGG[-30:] # DEBUG WITH SMALL DATA
92 | test_index_HGG = index_HGG[-5:]
93 | tr_index_HGG = index_HGG[:-30]
94 | dev_index_LGG = index_LGG[-10:] # DEBUG WITH SMALL DATA
95 | test_index_LGG = index_LGG[-5:]
96 | tr_index_LGG = index_LGG[:-10]
97 | elif DATA_SIZE == 'small':
98 | dev_index_HGG = index_HGG[35:42] # DEBUG WITH SMALL DATA
99 | # print(index_HGG, dev_index_HGG)
100 | # exit()
101 | test_index_HGG = index_HGG[41:42]
102 | tr_index_HGG = index_HGG[0:35]
103 | dev_index_LGG = index_LGG[7:10] # DEBUG WITH SMALL DATA
104 | test_index_LGG = index_LGG[9:10]
105 | tr_index_LGG = index_LGG[0:7]
106 |
107 | survival_id_dev_HGG = [survival_id_from_HGG[i] for i in dev_index_HGG]
108 | survival_id_test_HGG = [survival_id_from_HGG[i] for i in test_index_HGG]
109 | survival_id_tr_HGG = [survival_id_from_HGG[i] for i in tr_index_HGG]
110 |
111 | survival_id_dev_LGG = [LGG_name_list[i] for i in dev_index_LGG]
112 | survival_id_test_LGG = [LGG_name_list[i] for i in test_index_LGG]
113 | survival_id_tr_LGG = [LGG_name_list[i] for i in tr_index_LGG]
114 |
115 | survival_age_dev = [survival_age_list[survival_id_list.index(i)] for i in survival_id_dev_HGG]
116 | survival_age_test = [survival_age_list[survival_id_list.index(i)] for i in survival_id_test_HGG]
117 | survival_age_tr = [survival_age_list[survival_id_list.index(i)] for i in survival_id_tr_HGG]
118 |
119 | survival_period_dev = [survival_period_list[survival_id_list.index(i)] for i in survival_id_dev_HGG]
120 | survival_period_test = [survival_period_list[survival_id_list.index(i)] for i in survival_id_test_HGG]
121 | survival_period_tr = [survival_period_list[survival_id_list.index(i)] for i in survival_id_tr_HGG]
122 |
123 | data_types = ['flair', 't1', 't1ce', 't2']
124 | data_types_mean_std_dict = {i: {'mean': 0.0, 'std': 1.0} for i in data_types}
125 |
126 | # calculate mean and std for all data types
127 |
128 | # preserving_ratio = 0.0
129 | # preserving_ratio = 0.01 # 0.118 removed
130 | # preserving_ratio = 0.05 # 0.213 removed
131 | # preserving_ratio = 0.10 # 0.359 removed
132 |
133 | #==================== LOAD ALL IMAGES' PATH AND COMPUTE MEAN/ STD
134 | for i in data_types:
135 | data_temp_list = []
136 | for j in HGG_name_list:
137 | img_path = os.path.join(HGG_data_path, j, j + '_' + i + '.nii.gz')
138 | img = nib.load(img_path).get_data()
139 | data_temp_list.append(img)
140 |
141 | for j in LGG_name_list:
142 | img_path = os.path.join(LGG_data_path, j, j + '_' + i + '.nii.gz')
143 | img = nib.load(img_path).get_data()
144 | data_temp_list.append(img)
145 |
146 | data_temp_list = np.asarray(data_temp_list)
147 | m = np.mean(data_temp_list)
148 | s = np.std(data_temp_list)
149 | data_types_mean_std_dict[i]['mean'] = m
150 | data_types_mean_std_dict[i]['std'] = s
151 | del data_temp_list
152 | print(data_types_mean_std_dict)
153 |
154 | with open(save_dir + 'mean_std_dict.pickle', 'wb') as f:
155 | pickle.dump(data_types_mean_std_dict, f, protocol=4)
156 |
157 |
158 | ##==================== GET NORMALIZED IMAGES
159 | X_train_input = []
160 | X_train_target = []
161 | # X_train_target_whole = [] # 1 2 4
162 | # X_train_target_core = [] # 1 4
163 | # X_train_target_enhance = [] # 4
164 |
165 | X_dev_input = []
166 | X_dev_target = []
167 | # X_dev_target_whole = [] # 1 2 4
168 | # X_dev_target_core = [] # 1 4
169 | # X_dev_target_enhance = [] # 4
170 |
171 | print(" HGG Validation")
172 | for i in survival_id_dev_HGG:
173 | all_3d_data = []
174 | for j in data_types:
175 | img_path = os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz')
176 | img = nib.load(img_path).get_data()
177 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']
178 | img = img.astype(np.float32)
179 | all_3d_data.append(img)
180 |
181 | seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz')
182 | seg_img = nib.load(seg_path).get_data()
183 | seg_img = np.transpose(seg_img, (1, 0, 2))
184 | for j in range(all_3d_data[0].shape[2]):
185 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)
186 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()
187 | combined_array.astype(np.float32)
188 | X_dev_input.append(combined_array)
189 |
190 | seg_2d = seg_img[:, :, j]
191 | # whole = np.zeros_like(seg_2d)
192 | # core = np.zeros_like(seg_2d)
193 | # enhance = np.zeros_like(seg_2d)
194 | # for index, x in np.ndenumerate(seg_2d):
195 | # if x == 1:
196 | # whole[index] = 1
197 | # core[index] = 1
198 | # if x == 2:
199 | # whole[index] = 1
200 | # if x == 4:
201 | # whole[index] = 1
202 | # core[index] = 1
203 | # enhance[index] = 1
204 | # X_dev_target_whole.append(whole)
205 | # X_dev_target_core.append(core)
206 | # X_dev_target_enhance.append(enhance)
207 | seg_2d.astype(int)
208 | X_dev_target.append(seg_2d)
209 | del all_3d_data
210 | gc.collect()
211 | print("finished {}".format(i))
212 |
213 | print(" LGG Validation")
214 | for i in survival_id_dev_LGG:
215 | all_3d_data = []
216 | for j in data_types:
217 | img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz')
218 | img = nib.load(img_path).get_data()
219 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']
220 | img = img.astype(np.float32)
221 | all_3d_data.append(img)
222 |
223 | seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz')
224 | seg_img = nib.load(seg_path).get_data()
225 | seg_img = np.transpose(seg_img, (1, 0, 2))
226 | for j in range(all_3d_data[0].shape[2]):
227 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)
228 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()
229 | combined_array.astype(np.float32)
230 | X_dev_input.append(combined_array)
231 |
232 | seg_2d = seg_img[:, :, j]
233 | # whole = np.zeros_like(seg_2d)
234 | # core = np.zeros_like(seg_2d)
235 | # enhance = np.zeros_like(seg_2d)
236 | # for index, x in np.ndenumerate(seg_2d):
237 | # if x == 1:
238 | # whole[index] = 1
239 | # core[index] = 1
240 | # if x == 2:
241 | # whole[index] = 1
242 | # if x == 4:
243 | # whole[index] = 1
244 | # core[index] = 1
245 | # enhance[index] = 1
246 | # X_dev_target_whole.append(whole)
247 | # X_dev_target_core.append(core)
248 | # X_dev_target_enhance.append(enhance)
249 | seg_2d.astype(int)
250 | X_dev_target.append(seg_2d)
251 | del all_3d_data
252 | gc.collect()
253 | print("finished {}".format(i))
254 |
255 | X_dev_input = np.asarray(X_dev_input, dtype=np.float32)
256 | X_dev_target = np.asarray(X_dev_target)#, dtype=np.float32)
257 | # print(X_dev_input.shape)
258 | # print(X_dev_target.shape)
259 |
260 | # with open(save_dir + 'dev_input.pickle', 'wb') as f:
261 | # pickle.dump(X_dev_input, f, protocol=4)
262 | # with open(save_dir + 'dev_target.pickle', 'wb') as f:
263 | # pickle.dump(X_dev_target, f, protocol=4)
264 |
265 | # del X_dev_input, X_dev_target
266 |
267 | print(" HGG Train")
268 | for i in survival_id_tr_HGG:
269 | all_3d_data = []
270 | for j in data_types:
271 | img_path = os.path.join(HGG_data_path, i, i + '_' + j + '.nii.gz')
272 | img = nib.load(img_path).get_data()
273 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']
274 | img = img.astype(np.float32)
275 | all_3d_data.append(img)
276 |
277 | seg_path = os.path.join(HGG_data_path, i, i + '_seg.nii.gz')
278 | seg_img = nib.load(seg_path).get_data()
279 | seg_img = np.transpose(seg_img, (1, 0, 2))
280 | for j in range(all_3d_data[0].shape[2]):
281 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)
282 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()
283 | combined_array.astype(np.float32)
284 | X_train_input.append(combined_array)
285 |
286 | seg_2d = seg_img[:, :, j]
287 | # whole = np.zeros_like(seg_2d)
288 | # core = np.zeros_like(seg_2d)
289 | # enhance = np.zeros_like(seg_2d)
290 | # for index, x in np.ndenumerate(seg_2d):
291 | # if x == 1:
292 | # whole[index] = 1
293 | # core[index] = 1
294 | # if x == 2:
295 | # whole[index] = 1
296 | # if x == 4:
297 | # whole[index] = 1
298 | # core[index] = 1
299 | # enhance[index] = 1
300 | # X_train_target_whole.append(whole)
301 | # X_train_target_core.append(core)
302 | # X_train_target_enhance.append(enhance)
303 | seg_2d.astype(int)
304 | X_train_target.append(seg_2d)
305 | del all_3d_data
306 | print("finished {}".format(i))
307 | # print(len(X_train_target))
308 |
309 |
310 | print(" LGG Train")
311 | for i in survival_id_tr_LGG:
312 | all_3d_data = []
313 | for j in data_types:
314 | img_path = os.path.join(LGG_data_path, i, i + '_' + j + '.nii.gz')
315 | img = nib.load(img_path).get_data()
316 | img = (img - data_types_mean_std_dict[j]['mean']) / data_types_mean_std_dict[j]['std']
317 | img = img.astype(np.float32)
318 | all_3d_data.append(img)
319 |
320 | seg_path = os.path.join(LGG_data_path, i, i + '_seg.nii.gz')
321 | seg_img = nib.load(seg_path).get_data()
322 | seg_img = np.transpose(seg_img, (1, 0, 2))
323 | for j in range(all_3d_data[0].shape[2]):
324 | combined_array = np.stack((all_3d_data[0][:, :, j], all_3d_data[1][:, :, j], all_3d_data[2][:, :, j], all_3d_data[3][:, :, j]), axis=2)
325 | combined_array = np.transpose(combined_array, (1, 0, 2))#.tolist()
326 | combined_array.astype(np.float32)
327 | X_train_input.append(combined_array)
328 |
329 | seg_2d = seg_img[:, :, j]
330 | # whole = np.zeros_like(seg_2d)
331 | # core = np.zeros_like(seg_2d)
332 | # enhance = np.zeros_like(seg_2d)
333 | # for index, x in np.ndenumerate(seg_2d):
334 | # if x == 1:
335 | # whole[index] = 1
336 | # core[index] = 1
337 | # if x == 2:
338 | # whole[index] = 1
339 | # if x == 4:
340 | # whole[index] = 1
341 | # core[index] = 1
342 | # enhance[index] = 1
343 | # X_train_target_whole.append(whole)
344 | # X_train_target_core.append(core)
345 | # X_train_target_enhance.append(enhance)
346 | seg_2d.astype(int)
347 | X_train_target.append(seg_2d)
348 | del all_3d_data
349 | print("finished {}".format(i))
350 |
351 | X_train_input = np.asarray(X_train_input, dtype=np.float32)
352 | X_train_target = np.asarray(X_train_target)#, dtype=np.float32)
353 | # print(X_train_input.shape)
354 | # print(X_train_target.shape)
355 |
356 | # with open(save_dir + 'train_input.pickle', 'wb') as f:
357 | # pickle.dump(X_train_input, f, protocol=4)
358 | # with open(save_dir + 'train_target.pickle', 'wb') as f:
359 | # pickle.dump(X_train_target, f, protocol=4)
360 |
361 |
362 |
363 | # X_train_target_whole = np.asarray(X_train_target_whole)
364 | # X_train_target_core = np.asarray(X_train_target_core)
365 | # X_train_target_enhance = np.asarray(X_train_target_enhance)
366 |
367 |
368 | # X_dev_target_whole = np.asarray(X_dev_target_whole)
369 | # X_dev_target_core = np.asarray(X_dev_target_core)
370 | # X_dev_target_enhance = np.asarray(X_dev_target_enhance)
371 |
372 |
373 | # print(X_train_target_whole.shape)
374 | # print(X_train_target_core.shape)
375 | # print(X_train_target_enhance.shape)
376 |
377 | # print(X_dev_target_whole.shape)
378 | # print(X_dev_target_core.shape)
379 | # print(X_dev_target_enhance.shape)
380 |
381 |
382 |
383 | # with open(save_dir + 'train_target_whole.pickle', 'wb') as f:
384 | # pickle.dump(X_train_target_whole, f, protocol=4)
385 |
386 | # with open(save_dir + 'train_target_core.pickle', 'wb') as f:
387 | # pickle.dump(X_train_target_core, f, protocol=4)
388 |
389 | # with open(save_dir + 'train_target_enhance.pickle', 'wb') as f:
390 | # pickle.dump(X_train_target_enhance, f, protocol=4)
391 |
392 | # with open(save_dir + 'dev_target_whole.pickle', 'wb') as f:
393 | # pickle.dump(X_dev_target_whole, f, protocol=4)
394 |
395 | # with open(save_dir + 'dev_target_core.pickle', 'wb') as f:
396 | # pickle.dump(X_dev_target_core, f, protocol=4)
397 |
398 | # with open(save_dir + 'dev_target_enhance.pickle', 'wb') as f:
399 | # pickle.dump(X_dev_target_enhance, f, protocol=4)
400 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorlayer as tl
3 | from tensorlayer.layers import *
4 | import numpy as np
5 |
6 |
7 |
8 | def u_net(x, is_train=False, reuse=False, n_out=1):
9 | _, nx, ny, nz = x.get_shape().as_list()
10 | with tf.variable_scope("u_net", reuse=reuse):
11 | tl.layers.set_name_reuse(reuse)
12 | inputs = InputLayer(x, name='inputs')
13 | conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, name='conv1_1')
14 | conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='conv1_2')
15 | pool1 = MaxPool2d(conv1, (2, 2), name='pool1')
16 | conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, name='conv2_1')
17 | conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='conv2_2')
18 | pool2 = MaxPool2d(conv2, (2, 2), name='pool2')
19 | conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, name='conv3_1')
20 | conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='conv3_2')
21 | pool3 = MaxPool2d(conv3, (2, 2), name='pool3')
22 | conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, name='conv4_1')
23 | conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='conv4_2')
24 | pool4 = MaxPool2d(conv4, (2, 2), name='pool4')
25 | conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, name='conv5_1')
26 | conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, name='conv5_2')
27 |
28 |         up4 = DeConv2d(conv5, 512, (3, 3), (nx//8, ny//8), (2, 2), name='deconv4')  # integer division keeps out_size an int under Python 3
29 | up4 = ConcatLayer([up4, conv4], 3, name='concat4')
30 | conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, name='uconv4_1')
31 | conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, name='uconv4_2')
32 |         up3 = DeConv2d(conv4, 256, (3, 3), (nx//4, ny//4), (2, 2), name='deconv3')
33 | up3 = ConcatLayer([up3, conv3], 3, name='concat3')
34 | conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, name='uconv3_1')
35 | conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, name='uconv3_2')
36 |         up2 = DeConv2d(conv3, 128, (3, 3), (nx//2, ny//2), (2, 2), name='deconv2')
37 | up2 = ConcatLayer([up2, conv2], 3, name='concat2')
38 | conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, name='uconv2_1')
39 | conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, name='uconv2_2')
40 |         up1 = DeConv2d(conv2, 64, (3, 3), (nx, ny), (2, 2), name='deconv1')
41 | up1 = ConcatLayer([up1, conv1] , 3, name='concat1')
42 | conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, name='uconv1_1')
43 | conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, name='uconv1_2')
44 | conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')
45 | return conv1
46 |
47 | # def u_net(x, is_train=False, reuse=False, pad='SAME', n_out=2):
48 | # """ Original U-Net for cell segmentataion
49 | # http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
50 | # Original x is [batch_size, 572, 572, ?], pad is VALID
51 | # """
52 | # from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
53 | # nx = int(x._shape[1])
54 | # ny = int(x._shape[2])
55 | # nz = int(x._shape[3])
56 | # print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
57 | #
58 | # w_init = tf.truncated_normal_initializer(stddev=0.01)
59 | # b_init = tf.constant_initializer(value=0.0)
60 | # with tf.variable_scope("u_net", reuse=reuse):
61 | # tl.layers.set_name_reuse(reuse)
62 | # inputs = InputLayer(x, name='inputs')
63 | #
64 | # conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding=pad,
65 | # W_init=w_init, b_init=b_init, name='conv1_1')
66 | # conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,
67 | # W_init=w_init, b_init=b_init, name='conv1_2')
68 | # pool1 = MaxPool2d(conv1, (2, 2), padding=pad, name='pool1')
69 | #
70 | # conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding=pad,
71 | # W_init=w_init, b_init=b_init, name='conv2_1')
72 | # conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding=pad,
73 | # W_init=w_init, b_init=b_init, name='conv2_2')
74 | # pool2 = MaxPool2d(conv2, (2, 2), padding=pad, name='pool2')
75 | #
76 | # conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding=pad,
77 | # W_init=w_init, b_init=b_init, name='conv3_1')
78 | # conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,
79 | # W_init=w_init, b_init=b_init, name='conv3_2')
80 | # pool3 = MaxPool2d(conv3, (2, 2), padding=pad, name='pool3')
81 | #
82 | # conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding=pad,
83 | # W_init=w_init, b_init=b_init, name='conv4_1')
84 | # conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,
85 | # W_init=w_init, b_init=b_init, name='conv4_2')
86 | # pool4 = MaxPool2d(conv4, (2, 2), padding=pad, name='pool4')
87 | #
88 | # conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding=pad,
89 | # W_init=w_init, b_init=b_init, name='conv5_1')
90 | # conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding=pad,
91 | # W_init=w_init, b_init=b_init, name='conv5_2')
92 | #
93 | # print(" * After conv: %s" % conv5.outputs)
94 | #
95 | # up4 = DeConv2d(conv5, 512, (3, 3), out_size = (nx/8, ny/8),
96 | # strides=(2, 2), padding=pad, act=None,
97 | # W_init=w_init, b_init=b_init, name='deconv4')
98 | # up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')
99 | # conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding=pad,
100 | # W_init=w_init, b_init=b_init, name='uconv4_1')
101 | # conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding=pad,
102 | # W_init=w_init, b_init=b_init, name='uconv4_2')
103 | #
104 | # up3 = DeConv2d(conv4, 256, (3, 3), out_size = (nx/4, ny/4),
105 | # strides=(2, 2), padding=pad, act=None,
106 | # W_init=w_init, b_init=b_init, name='deconv3')
107 | # up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')
108 | # conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding=pad,
109 | # W_init=w_init, b_init=b_init, name='uconv3_1')
110 | # conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding=pad,
111 | # W_init=w_init, b_init=b_init, name='uconv3_2')
112 | #
113 | # up2 = DeConv2d(conv3, 128, (3, 3), out_size=(nx/2, ny/2),
114 | # strides=(2, 2), padding=pad, act=None,
115 | # W_init=w_init, b_init=b_init, name='deconv2')
116 | # up2 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat2')
117 | # conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding=pad,
118 | # W_init=w_init, b_init=b_init, name='uconv2_1')
119 | # conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding=pad,
120 | # W_init=w_init, b_init=b_init, name='uconv2_2')
121 | #
122 | # up1 = DeConv2d(conv2, 64, (3, 3), out_size=(nx/1, ny/1),
123 | # strides=(2, 2), padding=pad, act=None,
124 | # W_init=w_init, b_init=b_init, name='deconv1')
125 | # up1 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat1')
126 | # conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding=pad,
127 | # W_init=w_init, b_init=b_init, name='uconv1_1')
128 | # conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding=pad,
129 | # W_init=w_init, b_init=b_init, name='uconv1_2')
130 | #
131 | # conv1 = Conv2d(conv1, n_out, (1, 1), act=tf.nn.sigmoid, name='uconv1')
132 | # print(" * Output: %s" % conv1.outputs)
133 | #
134 | # # logits0 = conv1.outputs[:,:,:,0] # segmentataion
135 | # # logits1 = conv1.outputs[:,:,:,1] # edge
136 | # # logits0 = tf.expand_dims(logits0, axis=3)
137 | # # logits1 = tf.expand_dims(logits1, axis=3)
138 | # return conv1
139 |
140 |
141 | def u_net_bn(x, is_train=False, reuse=False, batch_size=None, pad='SAME', n_out=1):
142 |     """U-Net with batch normalization (structure adapted from an image-to-image translation model)."""
143 | nx = int(x._shape[1])
144 | ny = int(x._shape[2])
145 | nz = int(x._shape[3])
146 | print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
147 |
148 | w_init = tf.truncated_normal_initializer(stddev=0.01)
149 | b_init = tf.constant_initializer(value=0.0)
150 | gamma_init=tf.random_normal_initializer(1., 0.02)
151 | with tf.variable_scope("u_net", reuse=reuse):
152 | tl.layers.set_name_reuse(reuse)
153 | inputs = InputLayer(x, name='inputs')
154 |
155 | conv1 = Conv2d(inputs, 64, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv1')
156 | conv2 = Conv2d(conv1, 128, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv2')
157 | conv2 = BatchNormLayer(conv2, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn2')
158 |
159 | conv3 = Conv2d(conv2, 256, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv3')
160 | conv3 = BatchNormLayer(conv3, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn3')
161 |
162 | conv4 = Conv2d(conv3, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv4')
163 | conv4 = BatchNormLayer(conv4, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn4')
164 |
165 | conv5 = Conv2d(conv4, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv5')
166 | conv5 = BatchNormLayer(conv5, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn5')
167 |
168 | conv6 = Conv2d(conv5, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv6')
169 | conv6 = BatchNormLayer(conv6, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn6')
170 |
171 | conv7 = Conv2d(conv6, 512, (4, 4), (2, 2), act=None, padding=pad, W_init=w_init, b_init=b_init, name='conv7')
172 | conv7 = BatchNormLayer(conv7, act=lambda x: tl.act.lrelu(x, 0.2), is_train=is_train, gamma_init=gamma_init, name='bn7')
173 |
174 | conv8 = Conv2d(conv7, 512, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding=pad, W_init=w_init, b_init=b_init, name='conv8')
175 | print(" * After conv: %s" % conv8.outputs)
176 | # exit()
177 | # print(nx/8)
178 | up7 = DeConv2d(conv8, 512, (4, 4), out_size=(2, 2), strides=(2, 2),
179 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv7')
180 | up7 = BatchNormLayer(up7, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn7')
181 |
182 | # print(up6.outputs)
183 | up6 = ConcatLayer([up7, conv7], concat_dim=3, name='concat6')
184 | up6 = DeConv2d(up6, 1024, (4, 4), out_size=(4, 4), strides=(2, 2),
185 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv6')
186 | up6 = BatchNormLayer(up6, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn6')
187 | # print(up6.outputs)
188 | # exit()
189 |
190 | up5 = ConcatLayer([up6, conv6], concat_dim=3, name='concat5')
191 | up5 = DeConv2d(up5, 1024, (4, 4), out_size=(8, 8), strides=(2, 2),
192 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv5')
193 | up5 = BatchNormLayer(up5, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn5')
194 | # print(up5.outputs)
195 | # exit()
196 |
197 | up4 = ConcatLayer([up5, conv5] ,concat_dim=3, name='concat4')
198 | up4 = DeConv2d(up4, 1024, (4, 4), out_size=(15, 15), strides=(2, 2),
199 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv4')
200 | up4 = BatchNormLayer(up4, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn4')
201 |
202 | up3 = ConcatLayer([up4, conv4] ,concat_dim=3, name='concat3')
203 | up3 = DeConv2d(up3, 256, (4, 4), out_size=(30, 30), strides=(2, 2),
204 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv3')
205 | up3 = BatchNormLayer(up3, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn3')
206 |
207 | up2 = ConcatLayer([up3, conv3] ,concat_dim=3, name='concat2')
208 | up2 = DeConv2d(up2, 128, (4, 4), out_size=(60, 60), strides=(2, 2),
209 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv2')
210 | up2 = BatchNormLayer(up2, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn2')
211 |
212 | up1 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat1')
213 | up1 = DeConv2d(up1, 64, (4, 4), out_size=(120, 120), strides=(2, 2),
214 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv1')
215 | up1 = BatchNormLayer(up1, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn1')
216 |
217 | up0 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat0')
218 | up0 = DeConv2d(up0, 64, (4, 4), out_size=(240, 240), strides=(2, 2),
219 | padding=pad, act=None, batch_size=batch_size, W_init=w_init, b_init=b_init, name='deconv0')
220 | up0 = BatchNormLayer(up0, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn0')
221 | # print(up0.outputs)
222 | # exit()
223 |
224 | out = Conv2d(up0, n_out, (1, 1), act=tf.nn.sigmoid, name='out')
225 |
226 | print(" * Output: %s" % out.outputs)
227 | # exit()
228 |
229 | return out
230 |
231 | ## old implementation
232 | # def u_net_2d_64_1024_deconv(x, n_out=2):
233 | # from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
234 | # nx = int(x._shape[1])
235 | # ny = int(x._shape[2])
236 | # nz = int(x._shape[3])
237 | # print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
238 | #
239 | # w_init = tf.truncated_normal_initializer(stddev=0.01)
240 | # b_init = tf.constant_initializer(value=0.0)
241 | # inputs = InputLayer(x, name='inputs')
242 | #
243 | # conv1 = Conv2d(inputs, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
244 | # conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
245 | # pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
246 | #
247 | # conv2 = Conv2d(pool1, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
248 | # conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
249 | # pool2 = MaxPool2d(conv2, (2, 2), padding='SAME', name='pool2')
250 | #
251 | # conv3 = Conv2d(pool2, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
252 | # conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
253 | # pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
254 | #
255 | # conv4 = Conv2d(pool3, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
256 | # conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
257 | # pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
258 | #
259 | # conv5 = Conv2d(pool4, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
260 | # conv5 = Conv2d(conv5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
261 | #
262 | # print(" * After conv: %s" % conv5.outputs)
263 | #
264 | # up4 = DeConv2d(conv5, 512, (3, 3), out_size = (nx/8, ny/8), strides = (2, 2),
265 | # padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv4')
266 | # up4 = ConcatLayer([up4, conv4], concat_dim=3, name='concat4')
267 | # conv4 = Conv2d(up4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_1')
268 | # conv4 = Conv2d(conv4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv4_2')
269 | #
270 | # up3 = DeConv2d(conv4, 256, (3, 3), out_size = (nx/4, ny/4), strides = (2, 2),
271 | # padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv3')
272 | # up3 = ConcatLayer([up3, conv3], concat_dim=3, name='concat3')
273 | # conv3 = Conv2d(up3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_1')
274 | # conv3 = Conv2d(conv3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv3_2')
275 | #
276 | # up2 = DeConv2d(conv3, 128, (3, 3), out_size = (nx/2, ny/2), strides = (2, 2),
277 | # padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv2')
278 | # up2 = ConcatLayer([up2, conv2] ,concat_dim=3, name='concat2')
279 | # conv2 = Conv2d(up2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_1')
280 | # conv2 = Conv2d(conv2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv2_2')
281 | #
282 | # up1 = DeConv2d(conv2, 64, (3, 3), out_size = (nx/1, ny/1), strides = (2, 2),
283 | # padding = 'SAME', act=None, W_init=w_init, b_init=b_init, name='deconv1')
284 | # up1 = ConcatLayer([up1, conv1] ,concat_dim=3, name='concat1')
285 | # conv1 = Conv2d(up1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_1')
286 | # conv1 = Conv2d(conv1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='uconv1_2')
287 | #
288 | # conv1 = Conv2d(conv1, n_out, (1, 1), act=None, name='uconv1')
289 | # print(" * Output: %s" % conv1.outputs)
290 | # outputs = tl.act.pixel_wise_softmax(conv1.outputs)
291 | # return conv1, outputs
292 | #
293 | #
294 | # def u_net_2d_32_1024_upsam(x, n_out=2):
295 | # """
296 | # https://github.com/jocicmarko/ultrasound-nerve-segmentation
297 | # """
298 | # from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
299 | # batch_size = int(x._shape[0])
300 | # nx = int(x._shape[1])
301 | # ny = int(x._shape[2])
302 | # nz = int(x._shape[3])
303 | # print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
304 | # ## define initializer
305 | # w_init = tf.truncated_normal_initializer(stddev=0.01)
306 | # b_init = tf.constant_initializer(value=0.0)
307 | # inputs = InputLayer(x, name='inputs')
308 | #
309 | # conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
310 | # conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
311 | # pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
312 | #
313 | # conv2 = Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
314 | # conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
315 | # pool2 = MaxPool2d(conv2, (2,2), padding='SAME', name='pool2')
316 | #
317 | # conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
318 | # conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
319 | # pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
320 | #
321 | # conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
322 | # conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
323 | # pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
324 | #
325 | # conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
326 | # conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
327 | # pool5 = MaxPool2d(conv5, (2, 2), padding='SAME', name='pool6')
328 | #
329 | # # hao add
330 | # conv6 = Conv2d(pool5, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')
331 | # conv6 = Conv2d(conv6, 1024, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')
332 | #
333 | # print(" * After conv: %s" % conv6.outputs)
334 | #
335 | # # hao add
336 | # up7 = UpSampling2dLayer(conv6, (15, 15), is_scale=False, method=1, name='up7')
337 | # up7 = ConcatLayer([up7, conv5], concat_dim=3, name='concat7')
338 | # conv7 = Conv2d(up7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')
339 | # conv7 = Conv2d(conv7, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')
340 | #
341 | # # print(nx/8,ny/8) # 30 30
342 | # up8 = UpSampling2dLayer(conv7, (2, 2), method=1, name='up8')
343 | # up8 = ConcatLayer([up8, conv4], concat_dim=3, name='concat8')
344 | # conv8 = Conv2d(up8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')
345 | # conv8 = Conv2d(conv8, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')
346 | #
347 | # up9 = UpSampling2dLayer(conv8, (2, 2), method=1, name='up9')
348 | # up9 = ConcatLayer([up9, conv3] ,concat_dim=3, name='concat9')
349 | # conv9 = Conv2d(up9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')
350 | # conv9 = Conv2d(conv9, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')
351 | #
352 | # up10 = UpSampling2dLayer(conv9, (2, 2), method=1, name='up10')
353 | # up10 = ConcatLayer([up10, conv2] ,concat_dim=3, name='concat10')
354 | # conv10 = Conv2d(up10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_1')
355 | # conv10 = Conv2d(conv10, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv10_2')
356 | #
357 | # up11 = UpSampling2dLayer(conv10, (2, 2), method=1, name='up11')
358 | # up11 = ConcatLayer([up11, conv1] ,concat_dim=3, name='concat11')
359 | # conv11 = Conv2d(up11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_1')
360 | # conv11 = Conv2d(conv11, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv11_2')
361 | #
362 | # conv12 = Conv2d(conv11, n_out, (1, 1), act=None, name='conv12')
363 | # print(" * Output: %s" % conv12.outputs)
364 | # outputs = tl.act.pixel_wise_softmax(conv12.outputs)
365 | # return conv10, outputs
366 | #
367 | #
368 | # def u_net_2d_32_512_upsam(x, n_out=2):
369 | # """
370 | # https://github.com/jocicmarko/ultrasound-nerve-segmentation
371 | # """
372 | # from tensorlayer.layers import InputLayer, Conv2d, MaxPool2d, DeConv2d, ConcatLayer
373 | # batch_size = int(x._shape[0])
374 | # nx = int(x._shape[1])
375 | # ny = int(x._shape[2])
376 | # nz = int(x._shape[3])
377 | # print(" * Input: size of image: %d %d %d" % (nx, ny, nz))
378 | # ## define initializer
379 | # w_init = tf.truncated_normal_initializer(stddev=0.01)
380 | # b_init = tf.constant_initializer(value=0.0)
381 | # inputs = InputLayer(x, name='inputs')
382 | # # inputs = Input((1, img_rows, img_cols))
383 | # conv1 = Conv2d(inputs, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_1')
384 | # # print(conv1.outputs) # (10, 240, 240, 32)
385 | # # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
386 | # conv1 = Conv2d(conv1, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv1_2')
387 | # # print(conv1.outputs) # (10, 240, 240, 32)
388 | # # conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
389 | # pool1 = MaxPool2d(conv1, (2, 2), padding='SAME', name='pool1')
390 | # # pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
391 | # # print(pool1.outputs) # (10, 120, 120, 32)
392 | # # exit()
393 | # conv2 = Conv2d(pool1, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_1')
394 | # # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
395 | # conv2 = Conv2d(conv2, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv2_2')
396 | # # conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
397 | # pool2 = MaxPool2d(conv2, (2,2), padding='SAME', name='pool2')
398 | # # pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
399 | #
400 | # conv3 = Conv2d(pool2, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_1')
401 | # # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
402 | # conv3 = Conv2d(conv3, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv3_2')
403 | # # conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
404 | # pool3 = MaxPool2d(conv3, (2, 2), padding='SAME', name='pool3')
405 | # # pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
406 | # # print(pool3.outputs) # (10, 30, 30, 64)
407 | #
408 | # conv4 = Conv2d(pool3, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_1')
409 | # # print(conv4.outputs) # (10, 30, 30, 256)
410 | # # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)
411 | # conv4 = Conv2d(conv4, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv4_2')
412 | # # print(conv4.outputs) # (10, 30, 30, 256) != (10, 30, 30, 512)
413 | # # conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)
414 | # pool4 = MaxPool2d(conv4, (2, 2), padding='SAME', name='pool4')
415 | # # pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
416 | #
417 | # conv5 = Conv2d(pool4, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_1')
418 | # # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)
419 | # conv5 = Conv2d(conv5, 512, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv5_2')
420 | # # conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)
421 | # # print(conv5.outputs) # (10, 15, 15, 512)
422 | # print(" * After conv: %s" % conv5.outputs)
423 | # # print(nx/8,ny/8) # 30 30
424 | # up6 = UpSampling2dLayer(conv5, (2, 2), name='up6')
425 | # # print(up6.outputs) # (10, 30, 30, 512) == (10, 30, 30, 512)
426 | # up6 = ConcatLayer([up6, conv4], concat_dim=3, name='concat6')
427 | # # print(up6.outputs) # (10, 30, 30, 768)
428 | # # up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
429 | # conv6 = Conv2d(up6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_1')
430 | # # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)
431 | # conv6 = Conv2d(conv6, 256, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv6_2')
432 | # # conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)
433 | #
434 | # up7 = UpSampling2dLayer(conv6, (2, 2), name='up7')
435 | # up7 = ConcatLayer([up7, conv3] ,concat_dim=3, name='concat7')
436 | # # up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
437 | # conv7 = Conv2d(up7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_1')
438 | # # conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)
439 | # conv7 = Conv2d(conv7, 128, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv7_2')
440 | # # conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)
441 | #
442 | # up8 = UpSampling2dLayer(conv7, (2, 2), name='up8')
443 | # up8 = ConcatLayer([up8, conv2] ,concat_dim=3, name='concat8')
444 | # # up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
445 | # conv8 = Conv2d(up8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_1')
446 | # # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)
447 | # conv8 = Conv2d(conv8, 64, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv8_2')
448 | # # conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)
449 | #
450 | # up9 = UpSampling2dLayer(conv8, (2, 2), name='up9')
451 | # up9 = ConcatLayer([up9, conv1] ,concat_dim=3, name='concat9')
452 | # # up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
453 | # conv9 = Conv2d(up9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_1')
454 | # # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)
455 | # conv9 = Conv2d(conv9, 32, (3, 3), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init, name='conv9_2')
456 | # # conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)
457 | #
458 | # conv10 = Conv2d(conv9, n_out, (1, 1), act=None, name='conv9')
459 | # # conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)
460 | # print(" * Output: %s" % conv10.outputs)
461 | # outputs = tl.act.pixel_wise_softmax(conv10.outputs)
462 | # return conv10, outputs
463 |
464 |
465 | if __name__ == "__main__":
466 | pass
467 | # main()
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 | #
488 |
--------------------------------------------------------------------------------