├── figure
├── saml.png
└── protocol.png
├── README.md
├── layer.py
├── main.py
├── train.py
├── data_generator.py
├── utils.py
└── saml_func.py
/figure/saml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuquande/SAML/HEAD/figure/saml.png
--------------------------------------------------------------------------------
/figure/protocol.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuquande/SAML/HEAD/figure/protocol.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SAML & A Multi-site Dataset for Prostate MRI Segmentation
2 | by [Quande Liu](https://github.com/liuquande), [Qi Dou](http://www.cse.cuhk.edu.hk/~qdou/), [Pheng-Ann Heng](http://www.cse.cuhk.edu.hk/~pheng/).
3 |
4 | ### Introduction
5 |
6 | * The Tensorflow implementation for our MICCAI 2020 paper '[Shape-aware Meta-learning for Generalizing Prostate MRI Segmentation to Unseen Domains](https://arxiv.org/pdf/2007.02035.pdf)'.
7 |
8 |
9 |
10 |
11 |
12 | * A well-organized multi-site dataset (from six data sources) for prostate MRI segmentation, that can support research in various problem settings with need of multi-site data, such as Domain Generalization, Multi-site Learning and Life-long Learning, etc. For more details and downloading link of the dataset, please [Find Here](https://liuquande.github.io/SAML/).
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ### Setup & Usage for the Code
21 |
22 | 1. Check dependencies:
23 | ```shell
24 | python==2.7.17
25 | numpy==1.16.6
26 | scipy==1.2.1
27 | tensorflow-gpu==1.12.0
28 | tensorboard==1.12.2
29 | SimpleITK==1.2.0
30 | ```
31 | 2. To train the model, you need to specify the training configurations (can simply use the default setting) in main.py, then run:
32 | ```shell
33 | python main.py --phase=train
34 | ```
35 |
36 | 3. To evaluate the model, run:
37 | ```shell
38 | python main.py --phase=test --test_model='/path/to/test_model.cpkt'
39 | ```
40 | You will see the output results in the folder `./output/`.
41 |
42 | ### Citation
43 | If this repository is useful for your research, please cite:
44 |
45 | ```
46 | @article{liu2020shape,
47 | title={Shape-aware Meta-learning for Generalizing Prostate MRI Segmentation to Unseen Domains},
48 | author={Liu, Quande and Dou, Qi and Heng, Pheng-Ann},
49 | journal={International Conference on Medical Image Computing and Computer Assisted Intervention},
50 | year={2020}
51 | }
52 | ```
53 |
54 | ### Questions
55 |
56 | For further questions about the code or dataset, please contact 'qdliu@cse.cuhk.edu.hk'.
57 |
--------------------------------------------------------------------------------
/layer.py:
--------------------------------------------------------------------------------
1 | from tensorflow.contrib.layers.python import layers as tf_layers
2 | from tensorflow.python.platform import flags
3 | import tensorflow as tf
4 | import tensorflow.contrib.slim as slim
5 |
def concat2d(x1, x2):
    """Concatenate two 4-D feature maps along axis 3 (the channel axis).

    All leading dimensions (everything except the last one) must agree.
    The check uses the statically known shapes; dimensions that are unknown
    at graph-construction time are skipped and left to ``tf.concat``.

    Args:
        x1: first tensor.
        x2: second tensor.

    Returns:
        The tensor ``tf.concat([x1, x2], 3)``.

    Raises:
        ValueError: if a statically known non-channel dimension differs.
    """
    static1 = x1.get_shape()
    static2 = x2.get_shape()
    # The previous implementation wrapped tf.equal() in try/except; tf.equal
    # only builds a graph op and never raises, so nothing was ever checked.
    if static1.ndims is not None and static2.ndims is not None:
        lead1 = static1.as_list()[:-1]
        lead2 = static2.as_list()[:-1]
        mismatch = len(lead1) != len(lead2) or any(
            a is not None and b is not None and a != b
            for a, b in zip(lead1, lead2))
        if mismatch:
            raise ValueError(
                "Cannot concatenate tensors with different shape, ignoring "
                "feature map depth: x1_shape=%s, x2_shape=%s"
                % (static1.as_list(), static2.as_list()))
    return tf.concat([x1, x2], 3)
17 |
def normalize(inp, activation, reuse=tf.AUTO_REUSE, scope='', form='batch_norm', is_training=True):
    """Apply the requested normalization followed by ``activation``.

    Args:
        inp: input tensor.
        activation: callable applied after normalization, or None.
        reuse: variable-reuse flag forwarded to the normalization layer.
        scope: variable scope forwarded to the normalization layer.
        form: 'batch_norm', 'layer_norm', or 'None' (no normalization;
            the Python value ``None`` is accepted as well).
        is_training: forwarded to batch norm only.

    Returns:
        The normalized (and activated) tensor.

    Raises:
        ValueError: if ``form`` is not one of the supported values.
    """
    if form == 'batch_norm':
        return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope, is_training=is_training)
    if form == 'layer_norm':
        return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
    if form == 'None' or form is None:
        return inp if activation is None else activation(inp)
    # Previously an unknown form fell through and returned None silently.
    raise ValueError("Unsupported normalization form '%s'." % (form,))
28 |
def conv_block(inp, cweight, bweight, scope='', bn=True, is_training=True):
    """Stride-1 'SAME' convolution plus bias, optionally followed by normalize() with ReLU.

    When ``bn`` is not equal to True the biased convolution output is
    returned as is, without normalization or activation.
    """
    biased = tf.nn.bias_add(
        tf.nn.conv2d(inp, cweight, strides=[1, 1, 1, 1], padding='SAME'),
        bweight)

    if not (bn == True):  # noqa: E712 - keep the original equality semantics
        return biased
    return normalize(biased, tf.nn.relu, scope=scope, is_training=is_training)
39 | # relu = tf.nn.leaky_relu(normed)
40 | # normalize = batch_norm(relu, True)
41 |
42 |
def deconv_block(inp, cweight, bweight, scope='', is_training=True):
    """Stride-2 transposed convolution plus bias, then normalize() with ReLU.

    The requested output shape doubles dimensions 1 and 2 of the static
    input shape and halves dimension 3.
    """
    in_shape = inp.get_shape().as_list()
    target_shape = tf.stack([
        in_shape[0],
        in_shape[1] * 2,
        in_shape[2] * 2,
        in_shape[3] // 2,
    ])
    upsampled = tf.nn.conv2d_transpose(inp, cweight, target_shape, strides=[1, 2, 2, 1], padding='SAME')
    upsampled = tf.nn.bias_add(upsampled, bweight)
    return normalize(upsampled, tf.nn.relu, scope=scope, is_training=is_training)
52 |
53 |
def max_pool(x, filter_height, filter_width, stride_y, stride_x, padding='SAME'):
    """Max-pool ``x`` with the given window size and strides."""
    window = [1, filter_height, filter_width, 1]
    step = [1, stride_y, stride_x, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding=padding)
57 |
def lrn(x, radius, alpha, beta, bias=1.0):
    """Apply local response normalization to ``x``."""
    lrn_kwargs = dict(depth_radius=radius, alpha=alpha, beta=beta, bias=bias)
    return tf.nn.local_response_normalization(x, **lrn_kwargs)
61 |
def dropout(x, keep_prob):
    """Apply dropout to ``x``; ``keep_prob`` is passed straight to tf.nn.dropout."""
    dropped = tf.nn.dropout(x, keep_prob)
    return dropped
65 |
def fc(x, wweight, bweight, activation=None):
    """Create a fully connected layer: ``x * wweight + bweight``.

    Args:
        x: input tensor.
        wweight: weight matrix.
        bweight: bias vector.
        activation: None, 'relu' or 'leaky_relu'.

    Returns:
        The (optionally activated) layer output.

    Raises:
        NotImplementedError: if ``activation`` is not a supported value.
    """
    # Validate before building any op so a bad argument fails fast.
    if activation not in (None, 'relu', 'leaky_relu'):
        raise NotImplementedError("Unsupported activation: %r" % (activation,))

    act = tf.nn.xw_plus_b(x, wweight, bweight)

    # Strings are compared with ==; `is` checks identity and only works by
    # accident when the interpreter happens to intern the literal.
    if activation == 'relu':
        return tf.nn.relu(act)
    if activation == 'leaky_relu':
        return tf.nn.leaky_relu(act)
    return act
79 |
80 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.python.platform import flags
6 | from data_generator import ImageDataGenerator
7 | from saml_func import SAML
8 | from train import train
9 | from train import test
10 | import datetime
11 | import argparse
12 | from utils import check_folder, show_all_variables
13 | import logging
14 |
15 | currtime = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
16 | tf.set_random_seed(2)
17 |
def parse_args(train_date):
    """Build the command-line parser, parse ``sys.argv`` and create output folders.

    Args:
        train_date: run identifier used in the default output paths
            ``../output/<train_date>/{checkpoints,results,logs,samples}/``.

    Returns:
        The parsed ``argparse.Namespace`` after it has passed through
        ``check_args``.
    """
    desc = "Tensorflow implementation of DenseUNet for prostate segmentation"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--gpu', type=str, default='0', help='value assigned to CUDA_VISIBLE_DEVICES')
    parser.add_argument('--phase', type=str, default='train', help='train or test')
    parser.add_argument('--n_class', type=int, default=2, help='number of segmentation classes')

    ## Training operations
    parser.add_argument('--target_domain', type=str, default='ISBI', help='name of the held-out (unseen) domain')
    # type=list would split a command-line value into single characters;
    # nargs='+' with type=int parses e.g. "--volume_size 384 384 3".
    parser.add_argument('--volume_size', type=int, nargs='+', default=[384, 384, 3], help='The size of input data')
    parser.add_argument('--label_size', type=int, nargs='+', default=[384, 384, 1], help='The size of label')
    parser.add_argument('--epoch', type=int, default=1, help='The number of epochs to run')
    parser.add_argument('--train_iterations', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--meta_batch_size', type=int, default=5, help='number of images sampled per source domain')
    parser.add_argument('--test_batch_size', type=int, default=1, help='number of slices per batch at test time')
    parser.add_argument('--inner_lr', type=float, default=1e-4, help='inner-loop learning rate')
    parser.add_argument('--outer_lr', type=float, default=1e-3, help='outer-loop learning rate')
    parser.add_argument('--metric_lr', type=float, default=1e-3, help='learning rate of the metric-learning update')
    parser.add_argument('--margin', type=float, default=10.0, help='margin of the metric-learning loss')
    parser.add_argument('--compactness_loss_weight', type=float, default=1.0, help='weight of the compactness loss')
    parser.add_argument('--smoothness_loss_weight', type=float, default=0.005, help='weight of the smoothness loss')
    parser.add_argument('--clipNorm', type=int, default=True, help='1 to clip gradients, 0 to disable')
    parser.add_argument('--gradients_clip_value', type=float, default=10.0, help='gradient clipping threshold')

    # Logging, saving, and testing options
    parser.add_argument('--resume', type=int, default=False, help='1 to resume from the latest checkpoint in checkpoint_dir')
    parser.add_argument('--log', type=int, default=True, help='write tensorboard')
    parser.add_argument('--decay_step', type=float, default=500, help='learning-rate decay step')
    parser.add_argument('--decay_rate', type=float, default=0.95, help='learning-rate decay rate')
    parser.add_argument('--test_freq', type=int, default=200, help='evaluate on the unseen domain every N iterations')
    parser.add_argument('--save_freq', type=int, default=200, help='save a checkpoint every N iterations')
    parser.add_argument('--print_interval', type=int, default=5, help='log the losses every N iterations')
    parser.add_argument('--summary_interval', type=int, default=20, help='The frequency to write tensorboard')
    parser.add_argument('--restored_model', type=str, default=None, help='Model to restore')
    parser.add_argument('--test_model', type=str, default=None, help='checkpoint to evaluate when --phase=test')

    parser.add_argument('--checkpoint_dir', type=str, default='../output/' + train_date + '/checkpoints/',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='../output/' + train_date + '/results/',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='../output/' + train_date + '/logs/',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='../output/' + train_date + '/samples/',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())
67 |
68 | """checking arguments"""
def check_args(args):
    """Run check_folder on every output directory in ``args`` and return ``args``."""
    output_dirs = (
        args.checkpoint_dir,  # --checkpoint_dir
        args.result_dir,      # --result_dir
        args.log_dir,         # --log_dir
        args.sample_dir,      # --sample_dir
    )
    for folder in output_dirs:
        check_folder(folder)

    return args
80 |
def main():
    """Entry point: parse arguments, build the SAML graph, then train or test.

    In the 'train' phase the model is trained on every domain except
    ``args.target_domain``. In any other phase the checkpoint given by
    ``--test_model`` (or ``--restored_model`` as a fallback) is restored
    and evaluated on the target domain; results are appended to
    ``<log_dir>/test.txt``.

    Raises:
        ValueError: if the target domain is unknown, or if no checkpoint
            was supplied for the test phase.
    """
    import re
    import shutil

    train_date = 'xxx'  # run identifier used in the default output paths
    args = parse_args(train_date)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # define logger
    logging.basicConfig(filename=os.path.join(args.log_dir, args.phase + '_log.txt'),
                        level=logging.DEBUG, format='%(asctime)s %(message)s')
    logging.getLogger().addHandler(logging.StreamHandler())

    # print all parameters
    logging.info("Usage:")
    logging.info(" {0}".format(" ".join(sys.argv)))
    logging.debug("All settings used:")

    # Back up the training scripts next to the logs (best effort, like the
    # former `cp` shell calls, but portable and without spawning a shell).
    for script in ('main.py', 'saml_func.py', 'train.py', 'utils.py', 'data_generator.py'):
        try:
            shutil.copy(script, args.log_dir)
        except (IOError, OSError) as err:
            logging.warning("could not back up %s: %s", script, err)

    filelist_root = '../dataset'
    source_list = ['HK', 'ISBI', 'ISBI_1.5', 'I2CVB', 'UCL', 'BIDMC']
    if args.target_domain not in source_list:
        raise ValueError("Unknown target domain '%s'; expected one of %s"
                         % (args.target_domain, source_list))
    source_list.remove(args.target_domain)

    # Constructing model
    model = SAML(args)
    model.construct_model_train()
    model.construct_model_test()

    model.summ_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    sess = tf.InteractiveSession()

    tf.global_variables_initializer().run()
    show_all_variables()

    # restore model ----
    resume_itr = 0
    model_file = None
    if args.resume:
        model_file = tf.train.latest_checkpoint(args.checkpoint_dir)
        if model_file:
            # Checkpoints written by train() are named
            # 'epoch_<e>_itr_<i>.model.cpkt'; 'model<i>' is the older layout.
            base = os.path.basename(model_file)
            match = re.search(r'_itr_(\d+)', base) or re.search(r'model(\d+)$', base)
            if match:
                resume_itr = int(match.group(1))
            else:
                logging.warning("could not parse iteration from %s; resuming at 0", model_file)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    train_file_list = [os.path.join(filelist_root, source_domain + '_train_list') for source_domain in source_list]
    test_file_list = [os.path.join(filelist_root, args.target_domain + '_train_list')]

    # start training ----
    if args.phase == 'train':
        train(model, saver, sess, train_file_list, test_file_list[0], args, resume_itr)
    else:
        # Honor the command-line checkpoint instead of overwriting it with
        # a hard-coded placeholder.
        test_model = args.test_model or args.restored_model
        if not test_model:
            raise ValueError("--test_model must point to the checkpoint to evaluate")
        args.test_model = test_model
        saver.restore(sess, args.test_model)
        logging.info("testing model restored %s" % args.test_model)

        test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file_list[0], model, args)
        with open(os.path.join(args.log_dir, 'test.txt'), 'a') as f:
            f.write('testing model %s :\n' % (args.test_model))
            f.write(' Unseen domain testing results: Dice: %f %s\n' % (test_dice, test_dice_arr))
            f.write(' Unseen domain testing results: Haus: %f %s\n' % (test_haus, test_haus_arr))
146 |
147 | if __name__ == "__main__":
148 | main()
149 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.python.platform import flags
4 | from data_generator import ImageDataGenerator
5 | import logging
6 | from utils import _eval_dice, _connectivity_region_analysis, parse_fn, _crop_object_region, _get_coutour_sample, parse_fn_haus,_eval_haus
7 | import time
8 | import os
9 | import SimpleITK as sitk
10 |
def train(model, saver, sess, train_file_list, test_file, args, resume_itr=0):
    """Run the meta-training loop.

    Every iteration draws two meta-train domains and one meta-test domain
    at random from ``train_file_list``, samples a batch from each, and runs
    the task, meta and metric train ops of ``model`` in a single step.
    Losses are logged, summaries written, checkpoints saved and the model
    evaluated on ``test_file`` at the intervals configured in ``args``.

    Args:
        model: object exposing the placeholders, losses and train ops used below.
        saver: ``tf.train.Saver`` used to write checkpoints.
        sess: active TensorFlow session.
        train_file_list: one file-list path per source domain.
        test_file: file-list path of the unseen domain.
        args: parsed command-line arguments.
        resume_itr: iteration to start from in the first epoch.
    """
    # Only create a writer when logging is enabled; all later uses are guarded.
    train_writer = None
    if args.log:
        train_writer = tf.summary.FileWriter(args.log_dir + '/' + args.phase + '/', sess.graph)

    # Data loaders
    with tf.device('/cpu:0'):
        tr_data_list, train_iterator_list, train_next_list = [], [], []
        for train_file in train_file_list:
            tr_data = ImageDataGenerator(train_file, mode='training',
                                         batch_size=args.meta_batch_size,
                                         num_classes=args.n_class, shuffle=True)
            iterator = tf.data.Iterator.from_structure(tr_data.data.output_types,
                                                       tr_data.data.output_shapes)
            tr_data_list.append(tr_data)
            train_iterator_list.append(iterator)
            train_next_list.append(iterator.get_next())

    # Initialize every training sample generator at itr=0
    for iterator, tr_data in zip(train_iterator_list, tr_data_list):
        sess.run(iterator.make_initializer(tr_data.data))

    num_training_tasks = len(train_file_list)
    num_meta_train = 2
    num_meta_test = 1
    if num_training_tasks < num_meta_train + num_meta_test:
        logging.warning("only %d source domains: meta-train and meta-test domains will overlap",
                        num_training_tasks)

    # Training begins
    best_test_dice = 0
    best_test_haus = 0  # never updated: test() currently reports 0 for Hausdorff
    for epoch in range(0, args.epoch):
        # The resume offset applies to the first epoch only.
        start_itr = resume_itr if epoch == 0 else 0
        for itr in range(start_itr, args.train_iterations):
            start = time.time()

            # Randomly choosing meta train and meta test domains
            task_list = np.random.permutation(num_training_tasks)
            meta_train_index_list = task_list[:num_meta_train]
            meta_test_index_list = task_list[-num_meta_test:]

            # Sampling meta-train, meta-test data
            inputa, labela = sess.run(train_next_list[meta_train_index_list[0]])
            inputa1, labela1 = sess.run(train_next_list[meta_train_index_list[1]])
            inputb, labelb = sess.run(train_next_list[meta_test_index_list[0]])

            input_group = np.concatenate((inputa[:2], inputa1[:1], inputb[:2]), axis=0)
            label_group = np.concatenate((labela[:2], labela1[:1], labelb[:2]), axis=0)

            contour_group, metric_label_group = _get_coutour_sample(label_group)

            feed_dict = {
                model.inputa: inputa, model.labela: labela,
                model.inputa1: inputa1, model.labela1: labela1,
                model.inputb: inputb, model.labelb: labelb,
                model.input_group: input_group,
                model.label_group: label_group,
                model.contour_group: contour_group,
                model.metric_label_group: metric_label_group,
                model.KEEP_PROB: 1.0,
            }

            output_tensors = [model.task_train_op, model.meta_train_op, model.metric_train_op,
                              model.summ_op, model.seg_loss_b, model.compactness_loss_b,
                              model.smoothness_loss_b, model.target_loss, model.source_loss]
            (_, _, _, summ_writer, seg_loss_b, compactness_loss_b,
             smoothness_loss_b, target_loss, source_loss) = sess.run(output_tensors, feed_dict)

            if itr % args.print_interval == 0:
                logging.info("Epoch: [%2d] [%6d/%6d] time: %4.4f inner lr:%.8f outer lr:%.8f" % (
                    epoch, itr, args.train_iterations, (time.time() - start),
                    model.inner_lr.eval(session=sess), model.outer_lr.eval(session=sess)))
                logging.info('sou_loss: %.7f, tar_loss: %.7f, tar_seg_loss: %.7f, tar_compactness_loss: %.7f, tar_smoothness_loss: %.7f' % (
                    source_loss, target_loss, seg_loss_b, compactness_loss_b, smoothness_loss_b))

            if train_writer is not None and itr % args.summary_interval == 0:
                train_writer.add_summary(summ_writer, itr)
                train_writer.flush()

            if (itr != 0) and itr % args.save_freq == 0:
                saver.save(sess, args.checkpoint_dir + '/epoch_' + str(epoch) + '_itr_' + str(itr) + ".model.cpkt")

            # Testing periodically
            if (itr != 0) and itr % args.test_freq == 0:
                test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args)

                if test_dice > best_test_dice:
                    best_test_dice = test_dice

                with open(os.path.join(args.log_dir, 'eva.txt'), 'a') as f:
                    f.write('Iteration %d :\n' % (itr))
                    f.write(' Unseen domain testing results: Dice: %f %s\n' % (test_dice, test_dice_arr))
                    f.write(' Current best accuracy %f\n' % (best_test_dice))
                    f.write(' Unseen domain testing results: Haus: %f %s\n' % (test_haus, test_haus_arr))
                    f.write(' Current best accuracy %f\n' % (best_test_haus))

    if train_writer is not None:
        train_writer.close()
112 |
def test(sess, test_list, model, args):
    """Evaluate ``model`` on every volume listed in ``test_list``.

    Each volume is predicted slice by slice: the slice and its two
    neighbours form the input, so the first and last slices are never
    predicted and stay zero. The prediction is post-processed with
    ``_connectivity_region_analysis`` and scored with ``_eval_dice``.

    Args:
        sess: active TensorFlow session.
        test_list: path to a text file with one sample entry per line.
        model: object exposing ``test_input``, ``outputs``, ``KEEP_PROB``,
            ``training_mode``, ``test_batch_size`` and ``volume_size``.
        args: parsed command-line arguments (unused here).

    Returns:
        ``(dice_avg, dice, 0, 0)``. The last two values stand in for the
        Hausdorff results, which are not computed.

    Raises:
        ValueError: if the list file contains no entries.
    """
    with open(test_list, 'r') as fp:
        # Skip blank lines so a trailing newline cannot produce an empty path.
        filenames = [row.rstrip('\n') for row in fp if row.strip()]
    if not filenames:
        raise ValueError("No test samples listed in %s" % test_list)

    batch_size = model.test_batch_size
    dice = []

    for filename in filenames:
        image, mask, spacing = parse_fn_haus(filename)
        pred_y = np.zeros(mask.shape)

        frame_list = list(range(1, image.shape[2] - 1))

        # Step through the frame list itself. Deriving the loop count from
        # image.shape[2] dropped trailing frames for batch sizes >= 4 and
        # ran the network on empty batches otherwise.
        for first in range(0, len(frame_list), batch_size):
            frames = frame_list[first:first + batch_size]
            vol = np.zeros([batch_size, model.volume_size[0], model.volume_size[1], model.volume_size[2]])

            for idx, jj in enumerate(frames):
                vol[idx, ...] = image[..., jj - 1: jj + 2].copy()

            pred_student = sess.run(model.outputs, feed_dict={model.test_input: vol,
                                                              model.KEEP_PROB: 1.0,
                                                              model.training_mode: True})

            for idx, jj in enumerate(frames):
                pred_y[..., jj] = pred_student[idx, ...].copy()

        processed_pred_y = _connectivity_region_analysis(pred_y)
        dice.append(_eval_dice(mask, processed_pred_y))

    dice_avg = np.mean(dice, axis=0).tolist()[0]

    logging.info("dice_avg %.4f" % (dice_avg))

    return dice_avg, dice, 0, 0
158 |
def _save_nii_prediction(gth, comp_pred, pre_pred, out_folder, out_bname):
    """Write ground truth, raw prediction and post-processed prediction as .nii.gz files.

    Output paths are built as ``out_folder + out_bname + <suffix>``.
    """
    outputs = (
        (gth, 'gth.nii.gz'),
        (pre_pred, 'premask.nii.gz'),
        (comp_pred, 'mask.nii.gz'),
    )
    for volume, suffix in outputs:
        sitk.WriteImage(sitk.GetImageFromArray(volume), out_folder + out_bname + suffix)
163 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | # from matplotlib import pyplot as plt
5 | from tensorflow.python.framework import dtypes
6 | from tensorflow.python.framework.ops import convert_to_tensor
7 | import skimage as sk
8 | from skimage import transform
9 | import SimpleITK as sitk
10 |
11 | IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32)
12 |
13 |
class ImageDataGenerator(object):
    """tf.data pipeline that yields (image patch, one-hot mask) training pairs.

    Each line of the list file has the form ``<image_path>,<label_path>``.
    For every entry the volume is loaded and normalized, a random slice
    index ``z`` is drawn, and the three slices ``z-1..z+1`` are returned
    together with the one-hot encoded mask of slice ``z``.
    """

    def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, buffer_size=5):
        """Create a new ImageDataGenerator.

        Args:
            txt_file: path to the text file listing the samples.
            mode: 'training'. 'inference' is recognized but not implemented.
            batch_size: number of samples per batch.
            num_classes: number of classes used for the one-hot masks.
            shuffle: whether to shuffle the file list and the dataset.
            buffer_size: buffer size used by the dataset shuffle.

        Raises:
            NotImplementedError: if ``mode`` is 'inference'.
            ValueError: if an invalid mode is passed.
        """
        self.txt_file = txt_file
        self.num_classes = num_classes

        # retrieve the data from the text file
        self._read_txt_file()

        # number of samples in the dataset
        self.data_size = len(self.img_paths)

        # initial shuffling of the file list
        if shuffle:
            self._shuffle_lists()

        # convert lists to TF tensor
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)

        # create dataset and repeat indefinitely (train.py counts the epochs)
        data = tf.data.Dataset.from_tensor_slices((self.img_paths))
        data = data.repeat()

        # num_classes is forwarded instead of a hard-coded 2, so the masks
        # follow the constructor argument.
        self.get_patches_fn = lambda filename: tf.py_func(
            self.extract_patch, [filename, [384, 384, 3], self.num_classes], [tf.float32, tf.float32])

        if mode == 'training':
            data = data.map(self.get_patches_fn, num_parallel_calls=8)
        elif mode == 'inference':
            # The inference parser was never defined on this class; fail
            # with a clear message instead of an AttributeError.
            raise NotImplementedError("mode 'inference' is not implemented.")
        else:
            raise ValueError("Invalid mode '%s'." % (mode))

        # shuffle the first `buffer_size` elements of the dataset
        if shuffle:
            data = data.shuffle(buffer_size=buffer_size)

        # create a new dataset with batches of samples
        data = data.batch(batch_size)

        self.data = data

    def _read_txt_file(self):
        """Read the list file into ``self.img_paths``, skipping blank lines."""
        with open(self.txt_file, 'r') as f:
            # rstrip('\n') instead of row[:-1]: a last line without a
            # trailing newline must not lose its final character.
            self.img_paths = [row.rstrip('\n') for row in f if row.strip()]

    def _shuffle_lists(self):
        """Shuffle ``self.img_paths`` with a random permutation."""
        paths = self.img_paths
        self.img_paths = [paths[i] for i in np.random.permutation(self.data_size)]

    def extract_patch(self, filename, patch_size, num_class, num_patches=1):
        """Load one sample and cut a random 3-slice patch from it.

        Args:
            filename: ``<image_path>,<label_path>`` entry (str or bytes).
            patch_size: forwarded to ``random_patch_center_z``.
            num_class: number of channels of the one-hot mask.
            num_patches: number of patches to draw; only the first is returned.

        Returns:
            ``(image_patch, mask_patch)`` as float32 arrays: the slices
            ``z-1..z+1`` of the image and the one-hot mask of slice ``z``.
        """
        image, mask = self.parse_fn(filename)  # get the image and its mask
        image_patches = []
        mask_patches = []

        for _ in range(num_patches):
            z = self.random_patch_center_z(mask, patch_size=patch_size)  # centre slice of the patch
            image_patches.append(image[:, :, z - 1:z + 2])
            mask_patches.append(mask[:, :, z])

        image_patches = np.stack(image_patches)  # (num_patches, H, W, 3)
        mask_patches = np.stack(mask_patches)    # (num_patches, H, W)

        mask_patches = self._label_decomp(mask_patches, num_cls=num_class)  # (num_patches, H, W, num_class)
        return image_patches[0, ...].astype(np.float32), mask_patches[0, ...].astype(np.float32)

    def random_patch_center_z(self, mask, patch_size):
        """Draw a slice index inside the foreground extent of ``mask``.

        The index is drawn from ``[max(1, z_min), min(z_max, depth - 2))``,
        where ``z_min``/``z_max`` bound the foreground along axis 2, so that
        the neighbouring slices ``z-1`` and ``z+1`` exist. ``patch_size`` is
        accepted for interface compatibility and is not used.

        If the interval is empty (e.g. a single foreground slice) the lower
        bound, clamped to ``[1, depth - 2]``, is returned. If the mask has no
        foreground at all, any interior slice may be drawn.
        """
        depth = mask.shape[2]
        z_idx = np.where(mask > 0)[2]
        if z_idx.size == 0:
            low, high = 1, depth - 1
        else:
            low = max(1, int(z_idx.min()))
            high = min(int(z_idx.max()), depth - 2)

        if low >= high:
            # np.random.randint would raise on an empty interval.
            return max(1, min(low, depth - 2))
        return np.random.randint(low=low, high=high)

    def parse_fn(self, data_path):
        '''
        :param data_path: "<image_path>,<label_path>" entry (str or bytes)
        :return: normalized entire image with its corresponding label, both
                 transposed with axes (1, 2, 0)
        The image is standardized with the mean and std of the whole volume.
        Label value 2 is merged into label 1.
        '''
        # tf.py_func hands strings over as bytes under Python 3.
        if isinstance(data_path, bytes) and not isinstance(data_path, str):
            data_path = data_path.decode('utf-8')
        path = data_path.split(",")
        image_path = path[0]
        label_path = path[1]

        itk_image = sitk.ReadImage(image_path)
        itk_mask = sitk.ReadImage(label_path)

        image = sitk.GetArrayFromImage(itk_image).astype(np.float64)
        mask = sitk.GetArrayFromImage(itk_mask)

        mean = image.mean()
        std = image.std()
        if std == 0:
            std = 1.0  # constant volume: avoid dividing by zero
        image = (image - mean) / std  # normalize per image

        mask[mask == 2] = 1

        return image.transpose([1, 2, 0]), mask.transpose([1, 2, 0])

    def _label_decomp(self, label_vol, num_cls):
        """
        decompose label for softmax classifier
        original labels hold integer values 0,1,2,3...
        this function decomposes them to one hot along a new last axis,
        e.g.: 0,0,0,1,0,0
        numpy version of tf.one_hot
        """
        one_hot = []
        for i in range(num_cls):
            _vol = np.zeros(label_vol.shape)
            _vol[label_vol == i] = 1
            one_hot.append(_vol)

        return np.stack(one_hot, axis=-1)
174 | # def augment(self, x):
175 | # # add more types of augmentations here
176 | # augmentations = [self.flip]
177 | # for f in augmentations:
178 | # x = tf.cond(tf.random_uniform([], 0, 1) < 0.25, lambda: f(x), lambda: x)
179 |
180 | # return x
181 |
182 | # def flip(self, x):
183 | # """Flip augmentation
184 | # Args:
185 | # x: Image to flip
186 | # Returns:
187 | # Augmented image
188 | # """
189 | # x = tf.image.random_flip_left_right(x)
190 | # # x = tf.image.random_flip_up_down(x)
191 |
192 | # return x
193 |
194 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.contrib.layers.python import layers as tf_layers
8 | from tensorflow.python.platform import flags
9 | import SimpleITK as sitk
10 | from scipy import ndimage
11 | import itertools
12 | from tensorflow.contrib import slim
13 | from scipy.ndimage import _ni_support
14 | from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
15 | generate_binary_structure
16 | FLAGS = flags.FLAGS
17 |
18 | ## Image reader
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """Collect ``(label, file_path)`` pairs for the files found in each folder.

    Folder ``paths[k]`` is paired with ``labels[k]``. When ``nb_samples``
    is given, that many entries are sampled at random from each folder;
    otherwise every entry is used. The result is shuffled in place when
    ``shuffle`` is true.
    """
    def pick(entries):
        if nb_samples is None:
            return entries
        return random.sample(entries, nb_samples)

    images = []
    for label, folder in zip(labels, paths):
        for name in pick(os.listdir(folder)):
            images.append((label, os.path.join(folder, name)))

    if shuffle:
        random.shuffle(images)
    return images
30 |
31 | ## Loss functions
def mse(pred, label):
    """Mean squared error between ``pred`` and ``label`` after flattening both."""
    flat_pred = tf.reshape(pred, [-1])
    flat_label = tf.reshape(label, [-1])
    squared_error = tf.square(flat_pred - flat_label)
    return tf.reduce_mean(squared_error)
36 |
def xent(pred, label):
    """Softmax cross-entropy with ``pred`` as logits and ``label`` as targets."""
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=pred)
    return loss
39 |
def kd(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
    """Symmetric KL divergence between the per-class soft predictions of two batches.

    For every class the rows of ``data`` selected by the matching column of
    ``label`` are averaged, divided by ``temperature``, passed through a
    softmax and clipped to [1e-8, 1]. The two resulting distributions are
    compared with the symmetrized KL divergence, weighted by
    ``bool_indicator[cls]`` and averaged over ``n_class``.

    Returns:
        ``(loss, prob1s, prob2s)`` with one distribution per class in each list.
    """
    eps = 1e-16

    def soft_class_prob(logits, onehot, cls):
        # Mean logits over the samples of this class; eps keeps an
        # un-sampled class from producing NaN.
        class_mask = tf.tile(tf.expand_dims(onehot[:, cls], -1), [1, n_class])
        logits_sum = tf.reduce_sum(tf.multiply(logits, class_mask), axis=0)
        count = tf.reduce_sum(onehot[:, cls])
        activations = logits_sum * 1.0 / (count + eps)
        prob = tf.nn.softmax(activations / temperature)
        # clipping prevents prob=0 from resulting in NaN
        return tf.clip_by_value(prob, clip_value_min=1e-8, clip_value_max=1.0)

    total = 0.0
    prob1s, prob2s = [], []

    for cls in range(n_class):
        p1 = soft_class_prob(data1, label1, cls)
        p2 = soft_class_prob(data2, label2, cls)

        forward = tf.reduce_sum(p1 * tf.log(p1 / p2))
        backward = tf.reduce_sum(p2 * tf.log(p2 / p1))
        total += ((forward + backward) / 2.0) * bool_indicator[cls]

        prob1s.append(p1)
        prob2s.append(p2)

    return total / n_class, prob1s, prob2s
72 |
def JS(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
    """Jensen-Shannon divergence between the per-class soft predictions of two batches.

    For every class, the logits of the samples carrying that class are
    averaged, softened with `temperature` and turned into a probability
    vector; each batch's vector is then compared against the mean of the two.

    :param data1, data2: logits, multiplied element-wise with a
        [rows, n_class] mask (assumes shape [N, n_class] - TODO confirm)
    :param label1, label2: labels indexed as label[:, cls]
        (presumably one-hot - verify against caller)
    :param bool_indicator: per-class weight, indexed with the class id
    :param n_class: number of classes
    :param temperature: softmax temperature
    :return: (loss averaged over n_class, list of prob vectors for batch 1,
        list of prob vectors for batch 2)
    """
    eps = 1e-16

    def _soft_class_prob(logits, onehot, cls):
        # Broadcast the column of class `cls` over all logit columns so that
        # only the rows belonging to that class contribute to the sum.
        mask = tf.tile(tf.expand_dims(onehot[:, cls], -1), [1, n_class])
        logits_sum = tf.reduce_sum(tf.multiply(logits, mask), axis=0)
        count = tf.reduce_sum(onehot[:, cls])
        # eps keeps the division finite when the class is absent from the batch
        mean_logits = logits_sum * 1.0 / (count + eps)
        prob = tf.nn.softmax(mean_logits / temperature)
        # clip so that the log() below never sees an exact zero
        return tf.clip_by_value(prob, clip_value_min=1e-8, clip_value_max=1.0)

    kd_loss = 0.0
    prob1s, prob2s = [], []

    for cls in range(n_class):
        prob1 = _soft_class_prob(data1, label1, cls)
        prob2 = _soft_class_prob(data2, label2, cls)

        mean_prob = (prob1 + prob2) / 2

        kl_1m = tf.reduce_sum(prob1 * tf.log(prob1 / mean_prob))
        kl_2m = tf.reduce_sum(prob2 * tf.log(prob2 / mean_prob))
        kd_loss += (kl_1m + kl_2m) / 2.0 * bool_indicator[cls]

        prob1s.append(prob1)
        prob2s.append(prob2)

    return kd_loss / n_class, prob1s, prob2s
107 |
def contrastive(feature1, label1, feature2, label2, bool_indicator=None, margin=50):
    """Contrastive loss between row-aligned feature pairs.

    A pair is "matching" when the argmax of its two labels agrees.

    :param feature1, feature2: features, compared row by row
    :param label1, label2: labels; the class is taken as argmax over axis 1
    :param bool_indicator: when None, the full contrastive loss (pull matching
        pairs together, push the others beyond `margin`) is averaged over all
        pairs; otherwise only the matching-pair term is used, averaged over
        the matching pairs
    :param margin: distance below which non-matching pairs are penalised
    :return: (loss, pair indicator, squared distances,
        mean distance of matching pairs, mean distance of non-matching pairs)
    """
    pair = tf.to_float(tf.equal(tf.argmax(label1, axis=1), tf.argmax(label2, axis=1)))

    # Squared Euclidean distance; the small constants keep sqrt() and its
    # gradient finite for identical features.
    sq_dist = tf.reduce_sum(tf.square(feature1 - feature2), 1) + 1e-10
    dist = tf.sqrt(sq_dist + 1e-10)

    hinge_sq = tf.square(tf.nn.relu(margin - dist))

    if bool_indicator is not None:
        loss = 0.5 * tf.reduce_sum(sq_dist * pair) / tf.reduce_sum(pair)
    else:
        loss = tf.reduce_mean(0.5 * (pair * sq_dist + (1 - pair) * hinge_sq))

    # Monitoring values only; they do not feed into `loss`.
    debug_dist_positive = tf.reduce_sum(dist * pair) / tf.reduce_sum(pair)
    debug_dist_negative = tf.reduce_sum(dist * (1 - pair)) / tf.reduce_sum(1 - pair)

    return loss, pair, sq_dist, debug_dist_positive, debug_dist_negative
129 |
def compute_distance(feature1, label1, feature2, label2):
    """Mean Euclidean distance of matching and of non-matching feature pairs.

    A pair (row i of both inputs) is matching when the argmax of its two
    labels agrees.

    :return: (mean distance over matching pairs,
        mean distance over non-matching pairs)
    """
    same_class = tf.to_float(tf.equal(tf.argmax(label1, axis=1), tf.argmax(label2, axis=1)))

    sq_dist = tf.reduce_sum(tf.square(feature1 - feature2), 1)
    # 1e-16 keeps sqrt() differentiable for identical features
    dist = tf.sqrt(sq_dist + 1e-16)

    dist_positive_pair = tf.reduce_sum(dist * same_class) / tf.reduce_sum(same_class)
    dist_negative_pair = tf.reduce_sum(dist * (1 - same_class)) / tf.reduce_sum(1 - same_class)

    return dist_positive_pair, dist_negative_pair
142 |
def _get_segmentation_cost(softmaxpred, seg_gt, n_class=2):
    """
    Soft Dice loss of a segmentation prediction.

    :param softmaxpred: per-class prediction maps, indexed as [:, :, :, class]
    :param seg_gt: ground truth segmentation maps, indexed the same way
                   (presumably one-hot - verify against caller)
    :param n_class: number of classes the Dice score is averaged over
    :return: 1 - mean over classes of the soft Dice score
    """
    dice = 0

    # `range` (instead of the Python-2-only `xrange`) keeps this usable under
    # both Python 2 and 3 and matches the other loops in this file.
    for i in range(n_class):
        inse = tf.reduce_sum(softmaxpred[:, :, :, i] * seg_gt[:, :, :, i])
        l = tf.reduce_sum(softmaxpred[:, :, :, i])
        r = tf.reduce_sum(seg_gt[:, :, :, i])
        dice += 2.0 * inse / (l + r + 1e-7)  # here 1e-7 is relaxation eps
    dice_loss = 1 - 1.0 * dice / n_class

    return dice_loss
171 |
def _get_compactness_cost(y_pred, y_true):
    """
    Compactness loss of the predicted foreground: length ** 2 / (4 * pi * area),
    summed over the batch.

    y_pred: BxHxWxC prediction; only channel 1 is used.
    y_true: not used by the computation; kept so that the signature callers
            rely on stays the same.

    Returns (compactness_loss, sum of lengths, sum of areas, delta_u) where
    delta_u is the per-pixel sum of squared finite differences.
    """
    # Only the prediction is sliced. The ground truth is never read, so it is
    # deliberately left untouched (slicing the already-sliced prediction into
    # `y_true` would just add a meaningless op to the graph).
    y_pred = y_pred[..., 1]

    # Finite differences along axis 1 and axis 2 of the BxHxW map.
    x = y_pred[:, 1:, :] - y_pred[:, :-1, :]
    y = y_pred[:, :, 1:] - y_pred[:, :, :-1]

    # Crop both to the common (H-1)x(W-1) support before combining them.
    delta_x = x[:, :, 1:] ** 2
    delta_y = y[:, 1:, :] ** 2

    delta_u = tf.abs(delta_x + delta_y)

    epsilon = 0.00000001  # keeps the square root (and its gradient) away from zero
    w = 0.01
    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2])

    area = tf.reduce_sum(y_pred, [1, 2])

    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))

    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area), delta_u
204 |
205 | # def _get_sample_masf(y_true):
206 | # """
207 | # y_pred: BxHxWx2
208 | # """
209 | # positive_mask = np.expand_dims(y_true[..., 1], axis=3)
210 | # metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis = 1)
211 | # # print (positive_mask.shape)
212 | # coutour_group = np.zeros(positive_mask.shape)
213 |
214 | # for i in range(positive_mask.shape[0]):
215 | # slice_i = positive_mask[i]
216 |
217 | # if metrix_label_group[i] == 1:
218 | # sample = (slice_i == 1)
219 | # elif metrix_label_group[i] == 0:
220 | # sample = (slice_i == 0)
221 |
222 | # coutour_group[i] = sample
223 |
224 | # return coutour_group, metrix_label_group
225 |
def _get_coutour_sample(y_true):
    """
    Build one sampling mask per slice: a contour mask or a nearby-background mask.

    y_true: BxHxWx2 numpy array; channel 1 is taken as the foreground mask.

    The kind of mask per slice follows the fixed pattern [1, 0, 1, 1, 0]
    (1 = contour, 0 = background), so at most five slices are supported.

    Returns (masks of shape BxHxWx1, the 5x1 pattern array).
    """
    foreground = np.expand_dims(y_true[..., 1], axis=3)
    metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis=1)
    coutour_group = np.zeros(foreground.shape)

    for idx, slice_mask in enumerate(foreground):
        kind = metrix_label_group[idx]

        if kind == 1:
            # contour = mask minus its one-step erosion
            plane = slice_mask[..., 0]
            eroded = ndimage.binary_erosion(plane, iterations=1).astype(slice_mask.dtype)
            sample = np.expand_dims(plane - eroded, axis=2)
        elif kind == 0:
            # background ring = five-step dilation minus the mask itself
            dilated = ndimage.binary_dilation(slice_mask, iterations=5).astype(slice_mask.dtype)
            sample = dilated - slice_mask

        coutour_group[idx] = sample

    return coutour_group, metrix_label_group
249 |
250 | # def _get_negative(y_true):
def _get_boundary_cost(y_pred, y_true):
    """
    Compactness loss of the predicted foreground: length ** 2 / (4 * pi * area),
    summed over the batch.

    y_pred: BxHxWxC prediction; only channel 1 is used.
    y_true: not used by the computation; kept so that the signature callers
            rely on stays the same.

    Returns (compactness_loss, sum of lengths, sum of areas).
    """
    # Only the prediction is sliced. The ground truth is never read, so it is
    # deliberately left untouched (slicing the already-sliced prediction into
    # `y_true` would just add a meaningless op to the graph).
    y_pred = y_pred[..., 1]

    # Finite differences along axis 1 and axis 2 of the BxHxW map.
    x = y_pred[:, 1:, :] - y_pred[:, :-1, :]
    y = y_pred[:, :, 1:] - y_pred[:, :, :-1]

    # Crop both to the common (H-1)x(W-1) support before combining them.
    delta_x = x[:, :, 1:] ** 2
    delta_y = y[:, 1:, :] ** 2

    delta_u = tf.abs(delta_x + delta_y)

    epsilon = 0.00000001  # keeps the square root (and its gradient) away from zero
    w = 0.01
    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2])  # equ.(11) in the paper

    area = tf.reduce_sum(y_pred, [1, 2])

    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))

    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area)
283 |
def check_folder(log_dir):
    """Make sure `log_dir` exists (creating it and its parents if needed) and return it."""
    if os.path.exists(log_dir):
        return log_dir

    print("Allocating '{:}'".format(log_dir))
    os.makedirs(log_dir)
    return log_dir
289 |
def _eval_dice(gt_y, pred_y, detail=False):
    """
    Dice score of the foreground class between two label arrays.

    :param gt_y: ground-truth label array (numpy)
    :param pred_y: predicted label array of the same shape
    :param detail: when True, log the score of every evaluated class
    :return: list with one Dice score per evaluated class (only class 1 here);
             the score is NaN when the class is absent from both arrays
    """
    class_map = {  # a map used for mapping label value to its name, used for output
        "0": "bg",
        "1": "CZ",
        "2": "prostate"
    }

    dice = []

    # `range` (instead of the Python-2-only `xrange`) keeps this usable under
    # both Python 2 and 3 and matches the other loops in this file.
    for cls in range(1, 2):
        gt = (gt_y == cls).astype(np.float64)
        pred = (pred_y == cls).astype(np.float64)

        dice_this = 2 * np.sum(gt * pred) / (np.sum(gt) + np.sum(pred))
        dice.append(dice_this)

        if detail is True:
            logging.info("class {}, dice is {:2f}".format(class_map[str(cls)], dice_this))
    return dice
315 |
def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    """
    The distances between the surface voxel of binary objects in result and their
    nearest partner surface voxel of a binary object in reference.

    :param result: array, interpreted as binary (non-zero = object)
    :param reference: array, interpreted as binary (non-zero = object)
    :param voxelspacing: optional per-axis spacing passed to the distance
                         transform; None means unit spacing
    :param connectivity: connectivity of the structuring element used to
                         extract the object borders
    :return: 1-D array with one distance per border voxel of `result`
    :raises RuntimeError: if either array contains no object
    """
    # Builtin `bool` replaces the `np.bool` alias, which is identical on old
    # NumPy versions and no longer exists on recent ones.
    result = np.atleast_1d(result.astype(bool))
    reference = np.atleast_1d(reference.astype(bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = np.asarray(voxelspacing, dtype=np.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # test for emptiness
    if 0 == np.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == np.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    # extract only 1-pixel border line of objects
    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    # compute average surface distance
    # Note: scipys distance transform is calculated only inside the borders of the
    #       foreground objects, therefore the input has to be reversed
    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]

    return sds
349 |
def asd(result, reference, voxelspacing=None, connectivity=1):
    """Average surface distance: mean distance from the border voxels of
    `result` to the nearest border voxel of `reference`.

    The measure is directed (result -> reference), not symmetric.
    """
    surface_dists = __surface_distances(result, reference, voxelspacing, connectivity)
    return surface_dists.mean()
355 |
def calculate_hausdorff(lP, lT, spacing):
    """Surface distance between two binary label arrays.

    Despite its name this returns the average surface distance computed by
    `asd` (from the surface of `lP` to the surface of `lT`), not the
    Hausdorff (maximum) distance.
    """
    distance = asd(lP, lT, spacing)
    return distance
359 |
def _eval_haus(pred, gt, spacing, detail=False):
    '''
    Surface-distance evaluation of the foreground class for one case.

    :param pred: predicted label array (numpy)
    :param gt: ground-truth label array of the same shape
    :param spacing: voxel spacing forwarded to the distance computation
    :param detail: when True, log the value of every evaluated class
    :return: a list with one surface distance per evaluated class (only class 1 here)
    '''
    # Needed by the `detail` log line below; without a local definition that
    # branch fails with a NameError.
    class_map = {  # a map used for mapping label value to its name, used for output
        "0": "bg",
        "1": "CZ",
        "2": "prostate"
    }

    haus = []

    for cls in range(1, 2):
        pred_i = np.zeros(pred.shape)
        pred_i[pred == cls] = 1
        gt_i = np.zeros(gt.shape)
        gt_i[gt == cls] = 1

        # NOTE(review): the ground truth is passed in the first ("lP") slot
        # and the prediction in the second ("lT"), so the directed distance
        # runs from the ground-truth surface to the predicted one - confirm
        # this order is intended.
        haus_cls = calculate_hausdorff(gt_i, (pred_i), spacing)

        haus.append(haus_cls)

        if detail is True:
            logging.info("class {}, haus is {:4f}".format(class_map[str(cls)], haus_cls))

    return haus
387 |
def _connectivity_region_analysis(mask):
    """
    Keep only the largest connected component of `mask`.

    Components are found with scipy's default connectivity and ranked by the
    sum of the mask values they cover.

    :param mask: array whose non-zero entries are foreground
    :return: integer array of the same shape, 1 inside the largest component
             and 0 elsewhere; all zeros when `mask` has no foreground
    """
    label_im, nb_labels = ndimage.label(mask)

    # Without any component the background label (0) would be picked as the
    # "largest" one and the whole output would be turned into foreground.
    if nb_labels == 0:
        return np.zeros_like(label_im)

    sizes = np.asarray(ndimage.sum(mask, label_im, range(nb_labels + 1)))

    # Rank real components only (labels 1..nb_labels); entry 0 is background.
    largest = np.argmax(sizes[1:]) + 1

    return (label_im == largest).astype(label_im.dtype)
401 |
def _crop_object_region(mask, prediction):
    """Zero out `prediction` outside the last-axis extent of `mask`.

    `mask` must be 3-D and contain at least one positive entry. Every slice
    of `prediction` along the last axis that lies before the first or after
    the last slice containing a positive mask value is set to 0.
    `prediction` is modified in place and also returned.
    """
    _, _, z_indices = np.where(mask > 0)
    z_first = np.min(z_indices)
    z_last = np.max(z_indices)

    prediction[..., :z_first] = 0
    prediction[..., z_last + 1:] = 0

    return prediction
412 |
def parse_fn(data_path):
    '''
    Load one image/label pair and normalize the image.

    :param data_path: "<image file>,<label file>" - two paths joined by a comma
    :return: (normalized image, label), both with their first array axis moved
             to the end

    The image is standardized (zero mean, unit std) with statistics taken over
    the whole volume: the weighting mask below is all ones.
    Label value 2 is merged into label 1.
    '''
    fields = data_path.split(",")
    image_path, label_path = fields[0], fields[1]

    itk_image = sitk.ReadImage(image_path)
    itk_mask = sitk.ReadImage(label_path)

    image = sitk.GetArrayFromImage(itk_image)
    mask = sitk.GetArrayFromImage(itk_mask)

    # Every voxel carries weight 1, so these are plain whole-volume statistics.
    voxel_weights = np.ones(mask.shape)
    weight_total = np.sum(voxel_weights)
    mean = np.sum(image * voxel_weights) / weight_total
    std = np.sqrt(np.sum(np.square(image - mean) * voxel_weights) / weight_total)
    image = (image - mean) / std

    mask[mask == 2] = 1

    # Move the first axis to the end (presumably slices-first -> slices-last; confirm).
    axis_order = [1, 2, 0]
    return image.transpose(axis_order), mask.transpose(axis_order)
440 |
441 |
def parse_fn_haus(data_path):
    '''
    Load one image/label pair, normalize the image and report the voxel spacing.

    :param data_path: "<image file>,<label file>" - two paths joined by a comma
    :return: (normalized image, label, spacing); image and label have their
             first array axis moved to the end, spacing is whatever
             GetSpacing() of the label image reports

    The image is standardized (zero mean, unit std) with statistics taken over
    the whole volume: the weighting mask below is all ones.
    Label value 2 is merged into label 1.
    '''
    fields = data_path.split(",")
    image_path, label_path = fields[0], fields[1]

    itk_image = sitk.ReadImage(image_path)
    itk_mask = sitk.ReadImage(label_path)
    spacing = itk_mask.GetSpacing()

    image = sitk.GetArrayFromImage(itk_image)
    mask = sitk.GetArrayFromImage(itk_mask)

    # Every voxel carries weight 1, so these are plain whole-volume statistics.
    voxel_weights = np.ones(mask.shape)
    weight_total = np.sum(voxel_weights)
    mean = np.sum(image * voxel_weights) / weight_total
    std = np.sqrt(np.sum(np.square(image - mean) * voxel_weights) / weight_total)
    image = (image - mean) / std

    mask[mask == 2] = 1

    # Move the first axis to the end (presumably slices-first -> slices-last; confirm).
    axis_order = [1, 2, 0]
    return image.transpose(axis_order), mask.transpose(axis_order), spacing
470 |
def show_all_variables():
    """Print a summary of every trainable TensorFlow variable via slim's model analyzer."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
474 |
475 |
--------------------------------------------------------------------------------
/saml_func.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import sys
4 | import tensorflow as tf
5 | from tensorflow.image import resize_images
6 | # try:
7 | # import special_grads
8 | # except KeyError as e:
9 | # print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr)
10 |
11 | from tensorflow.python.platform import flags
12 | from layer import conv_block, deconv_block, fc, max_pool, concat2d
13 | from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost
14 |
15 | class SAML:
def __init__(self, args):
    """Keep the run configuration and bind the network / loss callables.

    Call construct_model_*() after initializing.
    """
    self.args = args

    # Sizes read from the configuration.
    self.batch_size = args.meta_batch_size
    self.test_batch_size = args.test_batch_size
    self.volume_size = args.volume_size
    self.n_class = args.n_class

    # Loss hyper-parameters.
    self.compactness_loss_weight = args.compactness_loss_weight
    self.smoothness_loss_weight = args.smoothness_loss_weight
    self.margin = args.margin

    # Indirections so the rest of the class does not name the U-Net or the
    # concrete loss functions directly.
    self.forward = self.forward_unet
    self.construct_weights = self.construct_unet_weights
    self.seg_loss = _get_segmentation_cost
    self.get_compactness_cost = _get_compactness_cost
32 |
def construct_model_train(self, prefix='metatrain_'):
    """Build the training graph: placeholders, meta-learning losses, train ops and summaries.

    Three groups of inputs are fed through placeholders:
    `inputa`/`inputa1` (meta-train, used for the inner update), `inputb`
    (meta-test, used for the meta loss) and `input_group` together with
    `contour_group`/`metric_label_group` (used for the smoothness loss).

    The method sets, among others, `self.task_train_op`, `self.meta_train_op`
    and `self.metric_train_op` (only when 'train' occurs in `prefix`) and
    registers scalar/image summaries.

    :param prefix: prepended to the scalar summary names; the optimizers are
        only built when it contains 'train'
    """
    # a: meta-train for inner update, b: meta-test for meta loss
    self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
    self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
    self.inputa1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
    self.labela1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
    self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
    self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
    # Inputs for the smoothness (metric-learning) loss. `label_group` is only
    # consumed by an image summary at the end of this method.
    self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]])
    self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class])
    self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1])
    self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
    # Forwarded as `is_training` to every forward pass below; defaults to True.
    self.training_mode = tf.placeholder_with_default(True, shape = None, name = "training_mode_for_bn_moving")


    self.clip_value = self.args.gradients_clip_value
    # Not referenced anywhere else in this method.
    self.KEEP_PROB = tf.placeholder(tf.float32)

    with tf.variable_scope('model', reuse=None) as training_scope:
        if 'weights' in dir(self):
            print('weights already defined')
            training_scope.reuse_variables()
            weights = self.weights
        else:
            # Define the weights
            self.weights = weights = self.construct_weights()

        def task_metalearn(inp, reuse=True):
            # Function to perform one meta-learning step: an inner gradient
            # step on the meta-train data, then the losses of the updated
            # ("fast") weights on the meta-test data. `reuse` is not used.
            inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp

            # Obtaining the conventional task loss on meta-train
            task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode)
            task_lossa = self.seg_loss(task_outputa, labela)
            task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode)
            task_lossa1 = self.seg_loss(task_outputa1, labela1)

            ## perform inner update with plain gradient descent on meta-train
            # (gradient of the mean of the two meta-train losses)
            grads = tf.gradients((task_lossa + task_lossa1)/2.0, list(weights.values()))
            grads = [tf.stop_gradient(grad) for grad in grads] # first-order gradients approximation
            gradients = dict(zip(weights.keys(), grads))
            # fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()]))
            # Each gradient is norm-clipped on its own before the step is taken.
            fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()]))

            ## compute compactness loss
            # (segmentation loss and compactness loss of the fast weights on meta-test)
            task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode)
            task_lossb = self.seg_loss(task_outputb, labelb)
            compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb)
            compactness_loss_b = self.compactness_loss_weight * compactness_loss_b

            # compute smoothness loss
            # (triplet loss on the mask-averaged embeddings of `input_group`)
            _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode)
            coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings)
            metric_embeddings = self.forward_metric_net(coutour_embeddings)

            # Shapes are printed once, at graph-construction time.
            print (metric_label_group.shape)
            print (metric_embeddings.shape)
            smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin)
            smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b
            task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1]

            return task_output

        self.global_step = tf.Variable(0, trainable=False)
        # self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
        # self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate)
        # Learning rates are non-trainable variables initialised from the arguments.
        self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False)
        self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False)
        self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False)

        # The tuple order must match the unpacking at the top of task_metalearn.
        input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group)
        result = task_metalearn(inp=input_tensors)
        self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1= result

    ## Performance & Optimization
    if 'train' in prefix:
        # source loss: mean meta-train segmentation loss;
        # target loss: meta-test segmentation + compactness + smoothness.
        self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0
        self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b

        # Trainable variables are split by whether 'metric' is one of the
        # '/'-separated components of their name.
        var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')]
        var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')]

        # The source-loss step runs after the ops collected in UPDATE_OPS and
        # is the op that increments global_step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step)

        # Meta step: gradients of the target loss w.r.t. the segmentor variables.
        optimizer = tf.train.AdamOptimizer(self.outer_lr)
        gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor)

        # observe stability of gradients for meta loss
        # l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
        # for grad, var in gvs:
        #     tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad))
        #     tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var))
        #     tf.summary.histogram('gradients/' + var.name, var)
        #     tf.summary.histogram("feature_extractor_var/" + var.name, var)

        # gvs = [(grad, var) for grad, var in gvs]
        # Each gradient is norm-clipped on its own before being applied.
        gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs]
        self.meta_train_op = optimizer.apply_gradients(gvs)

        # for grad, var in gvs:
        #     tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad))
        #     tf.summary.histogram('gradients_clipped/' + var.name, var)

        # The metric variables are trained on the smoothness loss only.
        self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric)

    ## Summaries
    # scalar_summaries = []
    # train_images = []
    # val_images = []

    tf.summary.scalar(prefix+'source_1 loss', self.seg_loss_a)
    tf.summary.scalar(prefix+'source_2 loss', self.seg_loss_a1)
    tf.summary.scalar(prefix+'target_loss', self.seg_loss_b)
    tf.summary.scalar(prefix+'target_coutour_loss', self.compactness_loss_b)
    tf.summary.scalar(prefix+'target_length', self.length)
    tf.summary.scalar(prefix+'target_area', self.area)
    # Image summaries show channel index 1 of inputs/labels and channel 0 of
    # the contour mask; a trailing axis is added to each before logging.
    tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3))
    tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:,:,:,1], tf.float32), 3))
    tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:,:,:,1], tf.float32), 3))
    tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:,:,:], tf.float32), 3))
    tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:,:,:, 0], tf.float32), 3))
    tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:,:,:, 1], tf.float32), 3))
    tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:,:,:, 1], tf.float32), 3))
158 |
def extract_coutour_embedding(self, coutour, embeddings):
    """Average the embedding map over the region selected by `coutour`.

    :param coutour: mask multiplied element-wise with `embeddings`
        (the training graph feeds a BxHxWx1 placeholder)
    :param embeddings: feature map (assumes BxHxWxC - TODO confirm)
    :return: per-sample sum of the masked embeddings over axes 1 and 2,
        divided by the mask sum over the same axes
    """
    # NOTE(review): a mask that sums to zero makes this division produce
    # NaN/inf - confirm the caller never feeds empty masks.
    masked_sum = tf.reduce_sum(coutour * embeddings, [1, 2])
    mask_total = tf.reduce_sum(coutour, [1, 2])
    return masked_sum / mask_total
168 |
def construct_model_test(self, prefix='test'):
    """Build the evaluation graph on top of the weights created for training.

    Sets `self.test_input` / `self.test_label` (placeholders),
    `self.outputs` (the mask returned by the forward pass) and
    `self.test_loss` (segmentation loss). `prefix` is not used.

    :raises ValueError: if the training model (and thus `self.weights`) has
        not been constructed yet
    """
    height, width, channels = self.volume_size[0], self.volume_size[1], self.volume_size[2]
    self.test_input = tf.placeholder("float", shape=[self.test_batch_size, height, width, channels])
    self.test_label = tf.placeholder("float", shape=[self.test_batch_size, height, width, self.n_class])

    with tf.variable_scope('model', reuse=None) as testing_scope:
        if 'weights' not in dir(self):
            raise ValueError('Weights not initilized. Create training model before testing model')
        testing_scope.reuse_variables()
        weights = self.weights

        # NOTE(review): is_training is not passed, so forward() runs with its
        # default (True) - confirm this is intended for evaluation.
        outputs, mask, _ = self.forward(self.test_input, weights)
        losses = self.seg_loss(outputs, self.test_label)
        self.outputs = mask

    self.test_loss = losses
187 |
def forward_metric_net(self, x):
    """Project embeddings through the two-layer metric network (48 -> 24 -> 16).

    The variables live in the 'metric' scope and are shared between calls
    (AUTO_REUSE). Both layers use a leaky-ReLU activation.

    :param x: input whose last dimension must match the 48 rows of `w1`
    :return: output of the second fully connected layer
    """
    with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope:
        # First layer: 48 -> 24
        w1 = tf.get_variable('w1', shape=[48, 24])
        b1 = tf.get_variable('b1', shape=[24])
        hidden = fc(x, w1, b1, activation='leaky_relu')

        # Second layer: 24 -> 16
        w2 = tf.get_variable('w2', shape=[24, 16])
        b2 = tf.get_variable('b2', shape=[16])
        projected = fc(hidden, w2, b2, activation='leaky_relu')

    return projected
200 |
def construct_unet_weights(self):
    """Create every U-Net variable and return them in a dict.

    Keys follow the pattern '<layer>_weights' / '<layer>_biases', e.g.
    'conv11_weights', 'deconv6_biases', 'output_weights'. The network takes
    3 input channels and produces 2 output channels.

    :return: dict mapping key -> tf variable
    """
    weights = {}
    conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)

    def _add_conv_pair(stage, kernel, c_in, c_out):
        # Two stacked convolutions of one stage: c_in -> c_out -> c_out.
        prefix = 'conv%d' % stage
        weights[prefix + '1_weights'] = tf.get_variable('weights', shape=[kernel, kernel, c_in, c_out], initializer=conv_initializer)
        weights[prefix + '1_biases'] = tf.get_variable('biases', [c_out])
        weights[prefix + '2_weights'] = tf.get_variable('weights2', shape=[kernel, kernel, c_out, c_out], initializer=conv_initializer)
        weights[prefix + '2_biases'] = tf.get_variable('biases2', [c_out])

    # Encoder: (stage, kernel size, input channels, output channels)
    encoder_stages = [
        (1, 5, 3, 16),
        (2, 5, 16, 32),
        (3, 3, 32, 64),
        (4, 3, 64, 128),
        (5, 3, 128, 256),
    ]
    for stage, kernel, c_in, c_out in encoder_stages:
        with tf.variable_scope('conv%d' % stage) as scope:
            _add_conv_pair(stage, kernel, c_in, c_out)

    # Decoder: (stage, channels entering the stage, channels leaving it).
    # The first convolution of every stage takes `c_in` channels again.
    decoder_stages = [
        (6, 256, 128),
        (7, 128, 64),
        (8, 64, 32),
        (9, 32, 16),
    ]
    for stage, c_in, c_out in decoder_stages:
        with tf.variable_scope('deconv%d' % stage) as scope:
            weights['deconv%d_weights' % stage] = tf.get_variable('weights0', shape=[3, 3, c_out, c_in], initializer=conv_initializer)
            weights['deconv%d_biases' % stage] = tf.get_variable('biases0', shape=[c_out], initializer=conv_initializer)
            _add_conv_pair(stage, 3, c_in, c_out)

    with tf.variable_scope('output') as scope:
        weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer)
        weights['output_biases'] = tf.get_variable('biases', [2])

    return weights
275 |
276 | def forward_unet(self, inp, weights, is_training=True):
277 |
278 | self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training)
279 | self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training)
280 | self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID')
281 | # 192x192x16
282 | self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training)
283 | self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training)
284 | self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID')
285 | # 96x96x32
286 | self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training)
287 | self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training)
288 | self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID')
289 | # 48x48x64
290 | self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training)
291 | self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training)
292 | self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID')
293 | # 24x24x128
294 | self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training)
295 | self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training)
296 | # 24x24x256
297 |
298 | ## add upsampling, meanwhile, channel number is reduced to half
299 | self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training)
300 | # 48x48x128
301 | self.sum6 = concat2d(self.deconv6, self.deconv6)
302 | self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training)
303 | self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training)
304 | # 48x48x128
305 |
306 | self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training)
307 | # 96x96x64
308 | self.sum7 = concat2d(self.deconv7, self.deconv7)
309 | self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training)
310 | self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training)
311 | # 96x96x64
312 |
313 | self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training)
314 | # 192x192x32
315 | self.sum8 = concat2d(self.deconv8, self.deconv8)
316 | self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training)
317 | self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training)
318 | self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
319 | # 192x192x32
320 |
321 | self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training)
322 | # 384x384x16
323 | self.sum9 = concat2d(self.deconv9, self.deconv9)
324 | self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training)
325 | self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training)
326 | # 384x384x16
327 |
328 | self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='outpu/bn', bn=False, is_training=is_training)
329 | #384x384x2
330 |
331 | self.pred_prob = tf.nn.softmax(self.logits) # shape [batch, w, h, num_classes]
332 | self.pred_compact = tf.argmax(self.pred_prob, axis=-1) # shape [batch, w, h]
333 |
334 | self.embeddings = concat2d(self.conv82_resize, self.conv92)
335 |
336 | return self.pred_prob, self.pred_compact, self.embeddings
337 |
--------------------------------------------------------------------------------