├── figure ├── saml.png └── protocol.png ├── README.md ├── layer.py ├── main.py ├── train.py ├── data_generator.py ├── utils.py └── saml_func.py /figure/saml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuquande/SAML/HEAD/figure/saml.png -------------------------------------------------------------------------------- /figure/protocol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuquande/SAML/HEAD/figure/protocol.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAML & A Multi-site Dataset for Prostate MRI Segmentation 2 | by [Quande Liu](https://github.com/liuquande), [Qi Dou](http://www.cse.cuhk.edu.hk/~qdou/), [Pheng-Ann Heng](http://www.cse.cuhk.edu.hk/~pheng/). 3 | 4 | ### Introduction 5 | 6 | * The Tensorflow implementation for our MICCAI 2020 paper '[Shape-aware Meta-learning for Generalizing Prostate MRI Segmentation to Unseen Domains](https://arxiv.org/pdf/2007.02035.pdf)'. 7 | 8 |

9 | 10 |

11 | 12 | * A well-organized multi-site dataset (from six data sources) for prostate MRI segmentation, that can support research in various problem settings with need of multi-site data, such as Domain Generalization, Multi-site Learning and Life-long Learning, etc. For more details and downloading link of the dataset, please [Find Here](https://liuquande.github.io/SAML/). 13 | 14 | 15 |

16 | 17 |

18 | 19 | 20 | ### Setup & Usage for the Code 21 | 22 | 1. Check dependencies: 23 | ```shell 24 | python==2.7.17 25 | numpy==1.16.6 26 | scipy==1.2.1 27 | tensorflow-gpu==1.12.0 28 | tensorboard==1.12.2 29 | SimpleITK==1.2.0 30 | ``` 31 | 2. To train the model, you need to specify the training configurations (can simply use the default setting) in main.py, then run: 32 | ```shell 33 | python main.py --phase=train 34 | ``` 35 | 36 | 2. To evaluate the model, run: 37 | ```shell 38 | python main.py --phase=test --restore_model='/path/to/test_model.cpkt' 39 | ``` 40 | You will see the output results in the folder `./output/`. 41 | 42 | ### Citation 43 | If this repository is useful for your research, please cite: 44 | 45 | ``` 46 | @article{liu2020shape, 47 | title={Shape-aware Meta-learning for Generalizing Prostate MRI Segmentation to Unseen Domains}, 48 | author={Liu, Quande and Dou, Qi and Heng, Pheng-Ann}, 49 | journal={International Conference on Medical Image Computing and Computer Assisted Intervention}, 50 | year={2020} 51 | } 52 | ``` 53 | 54 | ### Questions 55 | 56 | For further question about the code or dataset, please contact 'qdliu@cse.cuhk.edu.hk' 57 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.layers.python import layers as tf_layers 2 | from tensorflow.python.platform import flags 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | 6 | def concat2d(x1,x2): 7 | """ concatenation without offset check""" 8 | x1_shape = tf.shape(x1) 9 | x2_shape = tf.shape(x2) 10 | try: 11 | tf.equal(x1_shape[0:-2], x2_shape[0: -2]) 12 | except: 13 | print("x1_shape: %s"%str(x1.get_shape().as_list())) 14 | print("x2_shape: %s"%str(x2.get_shape().as_list())) 15 | raise ValueError("Cannot concatenate tensors with different shape, igonoring feature map depth") 16 | return tf.concat([x1, x2], 3) 17 | 18 | def normalize(inp, activation, reuse=tf.AUTO_REUSE, scope='', form='batch_norm', is_training=True): 19 | if form == 'batch_norm': 20 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope, is_training=is_training) 21 | elif form == 'layer_norm': 22 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 23 | elif form == 'None': 24 | if activation is not None: 25 | return activation(inp) 26 | else: 27 | return inp 28 | 29 | def conv_block(inp, cweight, bweight, scope='', bn=True, is_training=True): 30 | """ Perform, conv, batch norm, nonlinearity, and max pool """ 31 | conv = tf.nn.conv2d(inp, cweight, strides=[1, 1, 1, 1], padding='SAME') 32 | conv = tf.nn.bias_add(conv, bweight) 33 | 34 | if bn == True: 35 | normed = normalize(conv, tf.nn.relu, scope=scope, is_training=is_training) 36 | return normed 37 | else: 38 | return conv 39 | # relu = tf.nn.leaky_relu(normed) 40 | # normalize = batch_norm(relu, True) 41 | 42 | 43 | def deconv_block(inp, cweight, bweight, scope='', is_training=True): 44 | # x_shape = tf.shape(inp) 45 | x_shape = inp.get_shape().as_list() 46 | output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2]) 47 | deconv = tf.nn.conv2d_transpose(inp, cweight, output_shape, strides=[1,2,2,1], padding='SAME') 48 | deconv = tf.nn.bias_add(deconv, bweight) 49 | 50 | normed = normalize(deconv, tf.nn.relu, scope=scope, is_training=is_training) 51 | return normed 52 | 53 | 54 | def max_pool(x, filter_height, filter_width, stride_y, stride_x, padding='SAME'): 55 | """Create a max pooling layer.""" 56 | return tf.nn.max_pool(x, ksize=[1, filter_height, filter_width, 1], strides=[1, stride_y, stride_x, 1], padding=padding) 57 | 58 | def lrn(x, radius, alpha, beta, bias=1.0): 59 | """Create a local response normalization layer.""" 60 | return tf.nn.local_response_normalization(x, depth_radius=radius, alpha=alpha, beta=beta, bias=bias) 61 | 62 | def dropout(x, keep_prob): 63 | """Create a dropout layer.""" 64 | return tf.nn.dropout(x, keep_prob) 65 | 66 | def fc(x, wweight, bweight, activation=None): 67 | """Create a fully connected layer.""" 68 | 69 | act = tf.nn.xw_plus_b(x, wweight, bweight) 70 | 71 | if activation is 'relu': 72 | return tf.nn.relu(act) 73 | elif activation is 'leaky_relu': 74 | return tf.nn.leaky_relu(act) 75 | elif activation is None: 76 | return act 77 | else: 78 | raise NotImplementedError 79 | 80 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.python.platform import flags 6 | from data_generator import ImageDataGenerator 7 | from saml_func import SAML 8 | from train import train 9 | from train import test 10 | import datetime 11 | import argparse 12 | from utils import check_folder, show_all_variables 13 | import logging 14 | 15 | currtime = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 16 | tf.set_random_seed(2) 17 | 18 | def parse_args(train_date): 19 | desc = "Tensorflow implementation of DenseUNet for prostate segmentation" 20 | parser = argparse.ArgumentParser(description=desc) 21 | parser.add_argument('--gpu', type=str, default='0', help='train or test or guide') 22 | parser.add_argument('--phase', type=str, default='train', help='train or test or guide') 23 | parser.add_argument('--n_class', type=int, default=2, help='The size of class') 24 | 25 | ## Training operations 26 | parser.add_argument('--target_domain', type=str, default='ISBI', help='dataset_name') 27 | parser.add_argument('--volume_size', type=list, default=[384, 384, 3], help='The size of input data') 28 | parser.add_argument('--label_size', type=list, default=[384, 384, 1], help='The size of label') 29 | parser.add_argument('--epoch', type=int, default=1, help='The number of epochs to run') 30 | parser.add_argument('--train_iterations', type=int, default=10000, help='The number of training iterations') 31 | parser.add_argument('--meta_batch_size', type=int, default=5, help='number of images sampled per source domain') 32 | parser.add_argument('--test_batch_size', type=int, default=1, help='number of images sampled per source domain') 33 | parser.add_argument('--inner_lr', type=float, default=1e-4, help='The learning rate') 34 | parser.add_argument('--outer_lr', type=float, default=1e-3, help='The learning rate') 35 | parser.add_argument('--metric_lr', type=float, default=1e-3, help='The learning rate') 36 | parser.add_argument('--margin', type=float, default=10.0, help='The learning rate') 37 | parser.add_argument('--compactness_loss_weight', type=float, default=1.0, help='The learning rate') 38 | parser.add_argument('--smoothness_loss_weight', type=float, default=0.005, help='The learning rate') 39 | parser.add_argument('--clipNorm', type=int, default=True, help='number of images sampled per source domain') 40 | parser.add_argument('--gradients_clip_value', type=float, default=10.0, help='The learning rate') 41 | 42 | # Logging, saving, and testing options 43 | parser.add_argument('--resume', type=int, default=False, help='number of images sampled per source domain') 44 | parser.add_argument('--log', type=int, default=True, help='write tensorboard') 45 | parser.add_argument('--decay_step', type=float, default=500, help='The learning rate') 46 | parser.add_argument('--decay_rate', type=float, default=0.95, help='The learning rate') 47 | parser.add_argument('--test_freq', type=int, default=200, help='The number of ckpt_save_freq') 48 | parser.add_argument('--save_freq', type=int, default=200, help='The number of ckpt_save_freq') 49 | parser.add_argument('--print_interval', type=int, default=5, help='The frequency to write tensorboard') 50 | parser.add_argument('--summary_interval', type=int, default=20, help='The frequency to write tensorboard') 51 | parser.add_argument('--restored_model', type=str, default=None, help='Model to restore') 52 | parser.add_argument('--test_model', type=str, default=None, help='Model to restore') 53 | # parser.add_argument('--dropout', type=str, default=1, help='dropout rate') 54 | # parser.add_argument('--cost_kwargs', type=str, default=1, help='cost_kwargs') 55 | # parser.add_argument('--opt_kwargs', type=str, default=1, help='opt_kwargs') 56 | 57 | parser.add_argument('--checkpoint_dir', type=str, default='../output/' + train_date + '/checkpoints/' , 58 | help='Directory name to save the checkpoints') 59 | parser.add_argument('--result_dir', type=str, default='../output/' + train_date + '/results/', 60 | help='Directory name to save the generated images') 61 | parser.add_argument('--log_dir', type=str, default='../output/' + train_date + '/logs/', 62 | help='Directory name to save training logs') 63 | parser.add_argument('--sample_dir', type=str, default='../output/' + train_date + '/samples/', 64 | help='Directory name to save the samples on training') 65 | 66 | return check_args(parser.parse_args()) 67 | 68 | """checking arguments""" 69 | def check_args(args): 70 | # --checkpoint_dir 71 | check_folder(args.checkpoint_dir) 72 | # --result_dir 73 | check_folder(args.result_dir) 74 | # --result_dir 75 | check_folder(args.log_dir) 76 | # --sample_dir 77 | check_folder(args.sample_dir) 78 | 79 | return args 80 | 81 | def main(): 82 | train_date = 'xxx' 83 | args = parse_args(train_date) 84 | 85 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 86 | 87 | # define logger 88 | logging.basicConfig(filename=args.log_dir+"/"+args.phase+'_log.txt', level=logging.DEBUG, format='%(asctime)s %(message)s') 89 | logging.getLogger().addHandler(logging.StreamHandler()) 90 | 91 | # print all parameters 92 | logging.info("Usage:") 93 | logging.info(" {0}".format(" ".join([x for x in sys.argv]))) 94 | logging.debug("All settings used:") 95 | 96 | os.system('cp main.py %s' % (args.log_dir)) # bkp of train procedure 97 | os.system('cp saml_func.py %s' % (args.log_dir)) # bkp of train procedure 98 | os.system('cp train.py %s' % (args.log_dir)) # bkp of train procedure 99 | os.system('cp utils.py %s' % (args.log_dir)) # bkp of train procedure 100 | os.system('cp data_generator.py %s' % (args.log_dir)) 101 | 102 | 103 | filelist_root = '../dataset' 104 | source_list = ['HK', 'ISBI', 'ISBI_1.5', 'I2CVB','UCL', 'BIDMC']#'ISBI_1.5', 'I2CVB', 'UCL','BIDMC']#, 'I2CVB', 'ISBI_1.5', 'UCL', 'BIDMC']#'I2CVB', 'UCL', 'BIDMC', 'HK'] 105 | source_list.remove(args.target_domain) 106 | 107 | # Constructing model 108 | model = SAML(args) 109 | model.construct_model_train() 110 | model.construct_model_test() 111 | 112 | model.summ_op = tf.summary.merge_all() 113 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)) 114 | sess = tf.InteractiveSession() 115 | 116 | tf.global_variables_initializer().run() 117 | show_all_variables() 118 | 119 | # restore model ---- 120 | resume_itr = 0 121 | model_file = None 122 | if args.resume: 123 | model_file = tf.train.latest_checkpoint(args.checkpoint_dir) 124 | if model_file: 125 | ind1 = model_file.index('model') 126 | resume_itr = int(model_file[ind1+5:]) 127 | print("Restoring model weights from " + model_file) 128 | saver.restore(sess, model_file) 129 | 130 | train_file_list = [os.path.join(filelist_root, source_domain+'_train_list') for source_domain in source_list] 131 | test_file_list = [os.path.join(filelist_root, args.target_domain+'_train_list')] 132 | 133 | # start training ---- 134 | if args.phase == 'train': 135 | train(model, saver, sess, train_file_list, test_file_list[0], args, resume_itr) 136 | else: 137 | args.test_model = 'xxx' 138 | saver.restore(sess, args.test_model) 139 | logging.info("testing model restored %s" % args.test_model) 140 | 141 | test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file_list[0], model, args) 142 | with open((os.path.join(args.log_dir,'test.txt')), 'a') as f: 143 | print >> f, 'testing model %s :' % (args.test_model) 144 | print >> f, ' Unseen domain testing results: Dice: %f' %(test_dice), test_dice_arr 145 | print >> f, ' Unseen domain testing results: Haus: %f' %(test_haus), test_haus_arr 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.platform import flags 4 | from data_generator import ImageDataGenerator 5 | import logging 6 | from utils import _eval_dice, _connectivity_region_analysis, parse_fn, _crop_object_region, _get_coutour_sample, parse_fn_haus,_eval_haus 7 | import time 8 | import os 9 | import SimpleITK as sitk 10 | 11 | def train(model, saver, sess, train_file_list, test_file, args, resume_itr=0): 12 | 13 | if args.log: 14 | train_writer = tf.summary.FileWriter(args.log_dir + '/' + args.phase + '/', sess.graph) 15 | 16 | # Data loaders 17 | with tf.device('/cpu:0'): 18 | tr_data_list, train_iterator_list, train_next_list = [],[],[] 19 | for i in range(len(train_file_list)): 20 | tr_data = ImageDataGenerator(train_file_list[i], mode='training', \ 21 | batch_size=args.meta_batch_size, num_classes=args.n_class, shuffle=True) 22 | tr_data_list.append(tr_data) 23 | train_iterator_list.append(tf.data.Iterator.from_structure(tr_data.data.output_types,tr_data.data.output_shapes)) 24 | train_next_list.append(train_iterator_list[i].get_next()) 25 | 26 | # Ops for initializing different iterators 27 | training_init_op = [] 28 | train_batches_per_epoch = [] 29 | for i in range(len(train_file_list)): 30 | training_init_op.append(train_iterator_list[i].make_initializer(tr_data_list[i].data)) 31 | sess.run(training_init_op[i]) # initialize training sample generator at itr=0 32 | 33 | # Training begins 34 | best_test_dice = 0 35 | best_test_haus = 0 36 | for epoch in xrange(0, args.epoch): 37 | for itr in range(resume_itr, args.train_iterations): 38 | start = time.time() 39 | # Sampling training and test tasks 40 | num_training_tasks = len(train_file_list) 41 | num_meta_train = 2#num_training_tasks-1 42 | num_meta_test = 1#num_training_tasks-num_meta_train # as setting num_meta_test = 1 43 | 44 | # Randomly choosing meta train and meta test domains 45 | task_list = np.random.permutation(num_training_tasks) 46 | meta_train_index_list = task_list[:2] 47 | meta_test_index_list = task_list[-1:] 48 | 49 | # Sampling meta-train, meta-test data 50 | for i in range(num_meta_train): 51 | task_ind = meta_train_index_list[i] 52 | if i == 0: 53 | inputa, labela = sess.run(train_next_list[task_ind]) 54 | elif i == 1: 55 | inputa1, labela1 = sess.run(train_next_list[task_ind]) 56 | else: 57 | raise RuntimeError('check number of meta-train domains.') 58 | 59 | for i in range(num_meta_test): 60 | task_ind = meta_test_index_list[i] 61 | if i == 0: 62 | inputb, labelb = sess.run(train_next_list[task_ind]) 63 | else: 64 | raise RuntimeError('check number of meta-test domains.') 65 | 66 | input_group = np.concatenate((inputa[:2],inputa1[:1],inputb[:2]), axis=0) 67 | label_group = np.concatenate((labela[:2],labela1[:1],labelb[:2]), axis=0) 68 | 69 | contour_group, metric_label_group = _get_coutour_sample(label_group) 70 | 71 | feed_dict = {model.inputa: inputa, model.labela: labela, \ 72 | model.inputa1: inputa1, model.labela1: labela1, \ 73 | model.inputb: inputb, model.labelb: labelb, \ 74 | model.input_group:input_group, \ 75 | model.label_group:label_group, \ 76 | model.contour_group:contour_group, \ 77 | model.metric_label_group:metric_label_group, \ 78 | model.KEEP_PROB: 1.0} 79 | 80 | output_tensors = [model.task_train_op, model.meta_train_op, model.metric_train_op] 81 | output_tensors.extend([model.summ_op, model.seg_loss_b, model.compactness_loss_b, model.smoothness_loss_b, model.target_loss, model.source_loss]) 82 | _, _, _, summ_writer, seg_loss_b, compactness_loss_b, smoothness_loss_b, target_loss, source_loss = sess.run(output_tensors, feed_dict) 83 | # output_tensors = [model.task_train_op] 84 | # output_tensors.extend([model.source_loss]) 85 | # _, source_loss = sess.run(output_tensors, feed_dict) 86 | 87 | if itr % args.print_interval == 0: 88 | logging.info("Epoch: [%2d] [%6d/%6d] time: %4.4f inner lr:%.8f outer lr:%.8f" % (epoch, itr, args.train_iterations, (time.time()-start), model.inner_lr.eval(), model.outer_lr.eval())) 89 | logging.info('sou_loss: %.7f, tar_loss: %.7f, tar_seg_loss: %.7f, tar_compactness_loss: %.7f, tar_smoothness_loss: %.7f' % (source_loss, target_loss, seg_loss_b, compactness_loss_b, smoothness_loss_b)) 90 | 91 | if itr % args.summary_interval == 0: 92 | train_writer.add_summary(summ_writer, itr) 93 | train_writer.flush() 94 | 95 | if (itr!=0) and itr % args.save_freq == 0: 96 | saver.save(sess, args.checkpoint_dir + '/epoch_' + str(epoch) + '_itr_'+str(itr) + ".model.cpkt") 97 | 98 | # Testing periodically 99 | if (itr!=0) and itr % args.test_freq == 0: 100 | test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args) 101 | 102 | if test_dice > best_test_dice: 103 | best_test_dice = test_dice 104 | 105 | with open((os.path.join(args.log_dir,'eva.txt')), 'a') as f: 106 | print >> f, 'Iteration %d :' % (itr) 107 | print >> f, ' Unseen domain testing results: Dice: %f' %(test_dice), test_dice_arr 108 | print >> f, ' Current best accuracy %f' %(best_test_dice) 109 | print >> f, ' Unseen domain testing results: Haus: %f' %(test_haus), test_haus_arr 110 | print >> f, ' Current best accuracy %f' %(best_test_haus) 111 | # Save model 112 | 113 | def test(sess, test_list, model, args): 114 | 115 | dice = [] 116 | haus = [] 117 | start = time.time() 118 | 119 | with open(test_list, 'r') as fp: 120 | rows = fp.readlines() 121 | test_list = [row[:-1] if row[-1] == '\n' else row for row in rows] 122 | 123 | for fid, filename in enumerate(test_list): 124 | image, mask, spacing = parse_fn_haus(filename) 125 | pred_y = np.zeros(mask.shape) 126 | 127 | frame_list = [kk for kk in range(1, image.shape[2] - 1)] 128 | 129 | for ii in xrange(int(np.floor(image.shape[2] // model.test_batch_size))): 130 | vol = np.zeros([model.test_batch_size, model.volume_size[0], model.volume_size[1], model.volume_size[2]]) 131 | 132 | for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]): 133 | vol[idx, ...] = image[..., jj - 1: jj + 2].copy() 134 | 135 | pred_student = sess.run((model.outputs), feed_dict={model.test_input: vol, \ 136 | model.KEEP_PROB: 1.0,\ 137 | model.training_mode: True}) 138 | 139 | for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]): 140 | pred_y[..., jj] = pred_student[idx, ...].copy() 141 | 142 | processed_pred_y = _connectivity_region_analysis(pred_y) 143 | 144 | dice_subject = _eval_dice(mask, processed_pred_y) 145 | 146 | # print spacing 147 | dice.append(dice_subject) 148 | # haus.append(haus_subject) 149 | # _save_nii_prediction(mask, processed_pred_y, pred_y, args.result_dir, '_' + filename[-26:-20]) 150 | dice_avg = np.mean(dice, axis=0).tolist()[0] 151 | # haus_avg = np.mean(haus, axis=0).tolist()[0] 152 | 153 | logging.info("dice_avg %.4f" % (dice_avg)) 154 | # logging.info("haus_avg %.4f" % (haus_avg)) 155 | 156 | return dice_avg, dice, 0, 0 157 | # return dice_avg, dice, haus_avg, haus 158 | 159 | def _save_nii_prediction(gth, comp_pred, pre_pred, out_folder, out_bname): 160 | sitk.WriteImage(sitk.GetImageFromArray(gth), out_folder + out_bname + 'gth.nii.gz') 161 | sitk.WriteImage(sitk.GetImageFromArray(pre_pred), out_folder + out_bname + 'premask.nii.gz') 162 | sitk.WriteImage(sitk.GetImageFromArray(comp_pred), out_folder + out_bname + 'mask.nii.gz') 163 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | # from matplotlib import pyplot as plt 5 | from tensorflow.python.framework import dtypes 6 | from tensorflow.python.framework.ops import convert_to_tensor 7 | import skimage as sk 8 | from skimage import transform 9 | import SimpleITK as sitk 10 | 11 | IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32) 12 | 13 | 14 | class ImageDataGenerator(object): 15 | 16 | def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, buffer_size=5): 17 | 18 | """Create a new ImageDataGenerator. 19 | Receives a path string to a text file, where each line has a path string to an image and 20 | separated by a space, then with an integer referring to the class number. 21 | 22 | Args: 23 | txt_file: path to the text file. 24 | mode: either 'training' or 'validation'. Depending on this value, different parsing functions will be used. 25 | batch_size: number of images per batch. 26 | num_classes: number of classes in the dataset. 27 | shuffle: wether or not to shuffle the data in the dataset and the initial file list. 28 | buffer_size: number of images used as buffer for TensorFlows shuffling of the dataset. 29 | 30 | Raises: 31 | ValueError: If an invalid mode is passed. 32 | """ 33 | 34 | self.txt_file = txt_file 35 | self.num_classes = num_classes 36 | 37 | # retrieve the data from the text file 38 | self._read_txt_file() 39 | 40 | # number of samples in the dataset 41 | self.data_size = len(self.img_paths) 42 | 43 | # initial shuffling of the file and label lists together 44 | if shuffle: 45 | self._shuffle_lists() 46 | 47 | # convert lists to TF tensor 48 | self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string) 49 | 50 | # create dataset 51 | data = tf.data.Dataset.from_tensor_slices((self.img_paths)) 52 | 53 | # repeat indefinitely (train.py will count the epochs) 54 | data = data.repeat() 55 | 56 | # distinguish between train/infer. when calling the parsing functions 57 | self.get_patches_fn = lambda filename: tf.py_func(self.extract_patch, [filename, [384,384,3], 2], [tf.float32, tf.float32]) 58 | 59 | if mode == 'training': 60 | data = data.map(self.get_patches_fn, num_parallel_calls=8) 61 | 62 | elif mode == 'inference': 63 | data = data.map(self._parse_function_inference, num_parallel_calls=8) 64 | 65 | else: 66 | raise ValueError("Invalid mode '%s'." % (mode)) 67 | 68 | # shuffle the first `buffer_size` elements of the dataset 69 | if shuffle: 70 | data = data.shuffle(buffer_size=buffer_size) 71 | 72 | # create a new dataset with batches of images 73 | data = data.batch(batch_size) 74 | 75 | self.data = data 76 | 77 | def _read_txt_file(self): 78 | """Read the content of the text file and store it into lists.""" 79 | with open(self.txt_file, 'r') as f: 80 | rows = f.readlines() 81 | self.img_paths = [row[:-1] for row in rows] 82 | 83 | def _shuffle_lists(self): 84 | """Conjoined shuffling of the list of paths and labels.""" 85 | path = self.img_paths 86 | permutation = np.random.permutation(self.data_size) 87 | self.img_paths = [] 88 | for i in permutation: 89 | self.img_paths.append(path[i]) 90 | 91 | def extract_patch(self, filename, patch_size, num_class, num_patches=1): 92 | """Input parser for samples of the training set.""" 93 | # convert label number into one-hot-encoding 94 | 95 | image, mask = self.parse_fn(filename) # get the image and its mask 96 | image_patches = [] 97 | mask_patches = [] 98 | num_patches_now = 0 99 | 100 | while num_patches_now < num_patches: 101 | # z = np.random.randint(1, mask.shape[2]-1) 102 | z = self.random_patch_center_z(mask, patch_size=patch_size) # define the centre of current patch 103 | image_patch = image[:, :, z-1:z+2] 104 | mask_patch = mask[:, :, z] 105 | 106 | image_patches.append(image_patch) 107 | mask_patches.append(mask_patch) 108 | num_patches_now += 1 109 | image_patches = np.stack(image_patches) # make into 4D (batch_size, patch_size[0], patch_size[1], patch_size[2]) 110 | mask_patches = np.stack(mask_patches) # make into 4D (batch_size, patch_size[0], patch_size[1], patch_size[2]) 111 | 112 | mask_patches = self._label_decomp(mask_patches, num_cls=num_class) # make into 5D (batch_size, patch_size[0], patch_size[1], patch_size[2], num_classes) 113 | #print image_patches.shape 114 | return image_patches[0,...].astype(np.float32), mask_patches[0,...].astype(np.float32) 115 | 116 | def random_patch_center_z(self, mask, patch_size): 117 | # bounded within the brain mask region 118 | limX, limY, limZ = np.where(mask>0) 119 | if (np.min(limZ) + patch_size[2] // 2 + 1) < (np.max(limZ) - patch_size[2] // 2): 120 | z = np.random.randint(low = np.min(limZ) + patch_size[2] // 2 + 1, high = np.max(limZ) - patch_size[2] // 2) 121 | else: 122 | z = np.random.randint(low = patchsize[2]//2, high = mask.shape[2] - patchsize[2]//2) 123 | 124 | limX, limY, limZ = np.where(mask>0) 125 | 126 | z = np.random.randint(low = max(1, np.min(limZ)), high = min(np.max(limZ), mask.shape[2] - 2)) 127 | # z = np.random.randint(low = max(1, np.min(limZ)), high = min(np.max(limZ), mask.shape[2] - 2)) 128 | 129 | return z 130 | 131 | def parse_fn(self, data_path): 132 | ''' 133 | :param image_path: path to a folder of a patient 134 | :return: normalized entire image with its corresponding label 135 | In an image, the air region is 0, so we only calculate the mean and std within the brain area 136 | For any image-level normalization, do it here 137 | ''' 138 | path = data_path.split(",") 139 | image_path = path[0] 140 | label_path = path[1] 141 | #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 142 | #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 143 | itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 144 | itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 145 | # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz')) 146 | 147 | image = sitk.GetArrayFromImage(itk_image) 148 | mask = sitk.GetArrayFromImage(itk_mask) 149 | #image[image >= 1000] = 1000 150 | binary_mask = np.ones(mask.shape) 151 | mean = np.sum(image * binary_mask) / np.sum(binary_mask) 152 | std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask)) 153 | image = (image - mean) / std # normalize per image, using statistics within the brain, but apply to whole image 154 | 155 | mask[mask==2] = 1 156 | 157 | return image.transpose([1,2,0]), mask.transpose([1,2,0]) # transpose the orientation of the 158 | 159 | 160 | def _label_decomp(self, label_vol, num_cls): 161 | """ 162 | decompose label for softmax classifier 163 | original labels are batchsize * W * H * 1, with label values 0,1,2,3... 164 | this function decompse it to one hot, e.g.: 0,0,0,1,0,0 in channel dimension 165 | numpy version of tf.one_hot 166 | """ 167 | one_hot = [] 168 | for i in xrange(num_cls): 169 | _vol = np.zeros(label_vol.shape) 170 | _vol[label_vol == i] = 1 171 | one_hot.append(_vol) 172 | 173 | return np.stack(one_hot, axis=-1) 174 | # def augment(self, x): 175 | # # add more types of augmentations here 176 | # augmentations = [self.flip] 177 | # for f in augmentations: 178 | # x = tf.cond(tf.random_uniform([], 0, 1) < 0.25, lambda: f(x), lambda: x) 179 | 180 | # return x 181 | 182 | # def flip(self, x): 183 | # """Flip augmentation 184 | # Args: 185 | # x: Image to flip 186 | # Returns: 187 | # Augmented image 188 | # """ 189 | # x = tf.image.random_flip_left_right(x) 190 | # # x = tf.image.random_flip_up_down(x) 191 | 192 | # return x 193 | 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions. """ 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib.layers.python import layers as tf_layers 8 | from tensorflow.python.platform import flags 9 | import SimpleITK as sitk 10 | from scipy import ndimage 11 | import itertools 12 | from tensorflow.contrib import slim 13 | from scipy.ndimage import _ni_support 14 | from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\ 15 | generate_binary_structure 16 | FLAGS = flags.FLAGS 17 | 18 | ## Image reader 19 | def get_images(paths, labels, nb_samples=None, shuffle=True): 20 | if nb_samples is not None: 21 | sampler = lambda x: random.sample(x, nb_samples) 22 | else: 23 | sampler = lambda x: x 24 | images = [(i, os.path.join(path, image)) \ 25 | for i, path in zip(labels, paths) \ 26 | for image in sampler(os.listdir(path))] 27 | if shuffle: 28 | random.shuffle(images) 29 | return images 30 | 31 | ## Loss functions 32 | def mse(pred, label): 33 | pred = tf.reshape(pred, [-1]) 34 | label = tf.reshape(label, [-1]) 35 | return tf.reduce_mean(tf.square(pred-label)) 36 | 37 | def xent(pred, label): 38 | return tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=label) 39 | 40 | def kd(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0): 41 | 42 | kd_loss = 0.0 43 | eps = 1e-16 44 | 45 | prob1s = [] 46 | prob2s = [] 47 | 48 | for cls in range(n_class): 49 | mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class]) 50 | logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0) 51 | num1 = tf.reduce_sum(label1[:, cls]) 52 | activations1 = logits_sum1 * 1.0 / (num1 + eps) # add eps for prevent un-sampled class resulting in NAN 53 | prob1 = tf.nn.softmax(activations1 / temperature) 54 | prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0) # for preventing prob=0 resulting in NAN 55 | 56 | mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class]) 57 | logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0) 58 | num2 = tf.reduce_sum(label2[:, cls]) 59 | activations2 = logits_sum2 * 1.0 / (num2 + eps) 60 | prob2 = tf.nn.softmax(activations2 / temperature) 61 | prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0) 62 | 63 | KL_div = (tf.reduce_sum(prob1 * tf.log(prob1 / prob2)) + tf.reduce_sum(prob2 * tf.log(prob2 / prob1))) / 2.0 64 | kd_loss += KL_div * bool_indicator[cls] 65 | 66 | prob1s.append(prob1) 67 | prob2s.append(prob2) 68 | 69 | kd_loss = kd_loss / n_class 70 | 71 | return kd_loss, prob1s, prob2s 72 | 73 | def JS(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0): 74 | 75 | kd_loss = 0.0 76 | eps = 1e-16 77 | 78 | prob1s = [] 79 | prob2s = [] 80 | 81 | for cls in range(n_class): 82 | mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class]) 83 | logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0) 84 | num1 = tf.reduce_sum(label1[:, cls]) 85 | activations1 = logits_sum1 * 1.0 / (num1 + eps) # add eps for prevent un-sampled class resulting in NAN 86 | prob1 = tf.nn.softmax(activations1 / temperature) 87 | prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0) # for preventing prob=0 resulting in NAN 88 | 89 | mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class]) 90 | logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0) 91 | num2 = tf.reduce_sum(label2[:, cls]) 92 | activations2 = logits_sum2 * 1.0 / (num2 + eps) 93 | prob2 = tf.nn.softmax(activations2 / temperature) 94 | prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0) 95 | 96 | mean_prob = (prob1 + prob2) / 2 97 | 98 | JS_div = (tf.reduce_sum(prob1 * tf.log(prob1 / mean_prob)) + tf.reduce_sum(prob2 * tf.log(prob2 / mean_prob))) / 2.0 99 | kd_loss += JS_div * bool_indicator[cls] 100 | 101 | prob1s.append(prob1) 102 | prob2s.append(prob2) 103 | 104 | kd_loss = kd_loss / n_class 105 | 106 | return kd_loss, prob1s, prob2s 107 | 108 | def contrastive(feature1, label1, feature2, label2, bool_indicator=None, margin=50): 109 | 110 | l1 = tf.argmax(label1, axis=1) 111 | l2 = tf.argmax(label2, axis=1) 112 | pair = tf.to_float(tf.equal(l1,l2)) 113 | 114 | delta = tf.reduce_sum(tf.square(feature1-feature2), 1) + 1e-10 115 | match_loss = delta 116 | 117 | delta_sqrt = tf.sqrt(delta + 1e-10) 118 | mismatch_loss = tf.square(tf.nn.relu(margin - delta_sqrt)) 119 | 120 | if bool_indicator is None: 121 | loss = tf.reduce_mean(0.5 * (pair * match_loss + (1-pair) * mismatch_loss)) 122 | else: 123 | loss = 0.5 * tf.reduce_sum(match_loss*pair)/tf.reduce_sum(pair) 124 | 125 | debug_dist_positive = tf.reduce_sum(delta_sqrt * pair)/tf.reduce_sum(pair) 126 | debug_dist_negative = tf.reduce_sum(delta_sqrt * (1-pair))/tf.reduce_sum(1-pair) 127 | 128 | return loss, pair, delta, debug_dist_positive, debug_dist_negative 129 | 130 | def compute_distance(feature1, label1, feature2, label2): 131 | l1 = tf.argmax(label1, axis=1) 132 | l2 = tf.argmax(label2, axis=1) 133 | pair = tf.to_float(tf.equal(l1,l2)) 134 | 135 | delta = tf.reduce_sum(tf.square(feature1-feature2), 1) 136 | delta_sqrt = tf.sqrt(delta + 1e-16) 137 | 138 | dist_positive_pair = tf.reduce_sum(delta_sqrt * pair)/tf.reduce_sum(pair) 139 | dist_negative_pair = tf.reduce_sum(delta_sqrt * (1-pair))/tf.reduce_sum(1-pair) 140 | 141 | return dist_positive_pair, dist_negative_pair 142 | 143 | def _get_segmentation_cost(softmaxpred, seg_gt, n_class=2): 144 | """ 145 | calculate the loss for segmentation prediction 146 | :param seg_logits: probability segmentation from the segmentation network 147 | :param seg_gt: ground truth segmentaiton mask 148 | :return: segmentation loss, according to the cost_kwards setting, cross-entropy weighted loss and dice loss 149 | """ 150 | dice = 0 151 | 152 | for i in xrange(n_class): 153 | #inse = tf.reduce_sum(softmaxpred[:, :, :, i]*seg_gt[:, :, :, i]) 154 | inse = tf.reduce_sum(softmaxpred[:, :, :, i]*seg_gt[:, :, :, i]) 155 | l = tf.reduce_sum(softmaxpred[:, :, :, i]) 156 | r = tf.reduce_sum(seg_gt[:, :, :, i]) 157 | dice += 2.0 * inse/(l+r+1e-7) # here 1e-7 is relaxation eps 158 | dice_loss = 1 - 1.0 * dice / n_class 159 | 160 | # ce_weighted = 0 161 | # for i in xrange(n_class): 162 | # gti = seg_gt[:,:,:,i] 163 | # predi = softmaxpred[:,:,:,i] 164 | # ce_weighted += -1.0 * gti * tf.log(tf.clip_by_value(predi, 0.005, 1)) 165 | # ce_weighted_loss = tf.reduce_mean(ce_weighted) 166 | 167 | # total_loss = dice_loss 168 | 169 | 170 | return dice_loss#, dice_loss, ce_weighted_loss 171 | 172 | def _get_compactness_cost(y_pred, y_true): 173 | 174 | """ 175 | y_pred: BxHxWxC 176 | """ 177 | """ 178 | lenth term 179 | """ 180 | 181 | # y_pred = tf.one_hot(y_pred, depth=2) 182 | # print (y_true.shape) 183 | # print (y_pred.shape) 184 | y_pred = y_pred[..., 1] 185 | y_true = y_pred[..., 1] 186 | 187 | x = y_pred[:,1:,:] - y_pred[:,:-1,:] # horizontal and vertical directions 188 | y = y_pred[:,:,1:] - y_pred[:,:,:-1] 189 | 190 | delta_x = x[:,:,1:]**2 191 | delta_y = y[:,1:,:]**2 192 | 193 | delta_u = tf.abs(delta_x + delta_y) 194 | 195 | epsilon = 0.00000001 # where is a parameter to avoid square root is zero in practice. 196 | w = 0.01 197 | length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2]) 198 | 199 | area = tf.reduce_sum(y_pred, [1,2]) 200 | 201 | compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926)) 202 | 203 | return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area), delta_u 204 | 205 | # def _get_sample_masf(y_true): 206 | # """ 207 | # y_pred: BxHxWx2 208 | # """ 209 | # positive_mask = np.expand_dims(y_true[..., 1], axis=3) 210 | # metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis = 1) 211 | # # print (positive_mask.shape) 212 | # coutour_group = np.zeros(positive_mask.shape) 213 | 214 | # for i in range(positive_mask.shape[0]): 215 | # slice_i = positive_mask[i] 216 | 217 | # if metrix_label_group[i] == 1: 218 | # sample = (slice_i == 1) 219 | # elif metrix_label_group[i] == 0: 220 | # sample = (slice_i == 0) 221 | 222 | # coutour_group[i] = sample 223 | 224 | # return coutour_group, metrix_label_group 225 | 226 | def _get_coutour_sample(y_true): 227 | """ 228 | y_true: BxHxWx2 229 | """ 230 | positive_mask = np.expand_dims(y_true[..., 1], axis=3) 231 | metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis = 1) 232 | coutour_group = np.zeros(positive_mask.shape) 233 | 234 | for i in range(positive_mask.shape[0]): 235 | slice_i = positive_mask[i] 236 | 237 | if metrix_label_group[i] == 1: 238 | # generate coutour mask 239 | erosion = ndimage.binary_erosion(slice_i[..., 0], iterations=1).astype(slice_i.dtype) 240 | sample = np.expand_dims(slice_i[..., 0] - erosion, axis = 2) 241 | 242 | elif metrix_label_group[i] == 0: 243 | # generate background mask 244 | dilation = ndimage.binary_dilation(slice_i, iterations=5).astype(slice_i.dtype) 245 | sample = dilation - slice_i 246 | 247 | coutour_group[i] = sample 248 | return coutour_group, metrix_label_group 249 | 250 | # def _get_negative(y_true): 251 | def _get_boundary_cost(y_pred, y_true): 252 | 253 | """ 254 | y_pred: BxHxWxC 255 | """ 256 | """ 257 | lenth term 258 | """ 259 | 260 | # y_pred = tf.one_hot(y_pred, depth=2) 261 | # print (y_true.shape) 262 | # print (y_pred.shape) 263 | y_pred = y_pred[..., 1] 264 | y_true = y_pred[..., 1] 265 | 266 | x = y_pred[:,1:,:] - y_pred[:,:-1,:] # horizontal and vertical directions 267 | y = y_pred[:,:,1:] - y_pred[:,:,:-1] 268 | 269 | delta_x = x[:,:,1:]**2 270 | delta_y = y[:,1:,:]**2 271 | 272 | delta_u = tf.abs(delta_x + delta_y) 273 | 274 | epsilon = 0.00000001 # where is a parameter to avoid square root is zero in practice. 275 | w = 0.01 276 | length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2]) # equ.(11) in the paper 277 | 278 | area = tf.reduce_sum(y_pred, [1,2]) 279 | 280 | compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926)) 281 | 282 | return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area) 283 | 284 | def check_folder(log_dir): 285 | if not os.path.exists(log_dir): 286 | print ("Allocating '{:}'".format(log_dir)) 287 | os.makedirs(log_dir) 288 | return log_dir 289 | 290 | def _eval_dice(gt_y, pred_y, detail=False): 291 | 292 | class_map = { # a map used for mapping label value to its name, used for output 293 | "0": "bg", 294 | "1": "CZ", 295 | "2": "prostate" 296 | } 297 | 298 | dice = [] 299 | 300 | for cls in xrange(1,2): 301 | 302 | gt = np.zeros(gt_y.shape) 303 | pred = np.zeros(pred_y.shape) 304 | 305 | gt[gt_y == cls] = 1 306 | pred[pred_y == cls] = 1 307 | 308 | dice_this = 2*np.sum(gt*pred)/(np.sum(gt)+np.sum(pred)) 309 | dice.append(dice_this) 310 | 311 | if detail is True: 312 | #print ("class {}, dice is {:2f}".format(class_map[str(cls)], dice_this)) 313 | logging.info("class {}, dice is {:2f}".format(class_map[str(cls)], dice_this)) 314 | return dice 315 | 316 | def __surface_distances(result, reference, voxelspacing=None, connectivity=1): 317 | """ 318 | The distances between the surface voxel of binary objects in result and their 319 | nearest partner surface voxel of a binary object in reference. 320 | """ 321 | result = np.atleast_1d(result.astype(np.bool)) 322 | reference = np.atleast_1d(reference.astype(np.bool)) 323 | if voxelspacing is not None: 324 | voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) 325 | voxelspacing = np.asarray(voxelspacing, dtype=np.float64) 326 | if not voxelspacing.flags.contiguous: 327 | voxelspacing = voxelspacing.copy() 328 | 329 | # binary structure 330 | footprint = generate_binary_structure(result.ndim, connectivity) 331 | 332 | # test for emptiness 333 | if 0 == np.count_nonzero(result): 334 | raise RuntimeError('The first supplied array does not contain any binary object.') 335 | if 0 == np.count_nonzero(reference): 336 | raise RuntimeError('The second supplied array does not contain any binary object.') 337 | 338 | # extract only 1-pixel border line of objects 339 | result_border = result ^ binary_erosion(result, structure=footprint, iterations=1) 340 | reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1) 341 | 342 | # compute average surface distance 343 | # Note: scipys distance transform is calculated only inside the borders of the 344 | # foreground objects, therefore the input has to be reversed 345 | dt = distance_transform_edt(~reference_border, sampling=voxelspacing) 346 | sds = dt[result_border] 347 | 348 | return sds 349 | 350 | def asd(result, reference, voxelspacing=None, connectivity=1): 351 | 352 | sds = __surface_distances(result, reference, voxelspacing, connectivity) 353 | asd = sds.mean() 354 | return asd 355 | 356 | def calculate_hausdorff(lP,lT,spacing): 357 | 358 | return asd(lP, lT, spacing) 359 | 360 | def _eval_haus(pred, gt, spacing, detail=False): 361 | ''' 362 | :param pred: whole brain prediction 363 | :param gt: whole 364 | :param detail: 365 | :return: a list, indicating Dice of each class for one case 366 | ''' 367 | haus = [] 368 | 369 | for cls in range(1,2): 370 | pred_i = np.zeros(pred.shape) 371 | pred_i[pred == cls] = 1 372 | gt_i = np.zeros(gt.shape) 373 | gt_i[gt == cls] = 1 374 | 375 | # hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter() 376 | # hausdorff_distance_filter.Execute(gt_i, pred_i) 377 | 378 | haus_cls = calculate_hausdorff(gt_i, (pred_i), spacing) 379 | 380 | haus.append(haus_cls) 381 | 382 | if detail is True: 383 | logging.info("class {}, haus is {:4f}".format(class_map[str(cls)], haus_cls)) 384 | # logging.info("4 class average haus is {:4f}".format(np.mean(haus))) 385 | 386 | return haus 387 | 388 | def _connectivity_region_analysis(mask): 389 | s = [[0,1,0], 390 | [1,1,1], 391 | [0,1,0]] 392 | label_im, nb_labels = ndimage.label(mask)#, structure=s) 393 | 394 | sizes = ndimage.sum(mask, label_im, range(nb_labels + 1)) 395 | 396 | # plt.imshow(label_im) 397 | label_im[label_im != np.argmax(sizes)] = 0 398 | label_im[label_im == np.argmax(sizes)] = 1 399 | 400 | return label_im 401 | 402 | def _crop_object_region(mask, prediction): 403 | 404 | limX, limY, limZ = np.where(mask>0) 405 | min_z = np.min(limZ) 406 | max_z = np.max(limZ) 407 | 408 | prediction[..., :np.min(limZ)] = 0 409 | prediction[..., np.max(limZ)+1:] = 0 410 | 411 | return prediction 412 | 413 | def parse_fn(data_path): 414 | ''' 415 | :param image_path: path to a folder of a patient 416 | :return: normalized entire image with its corresponding label 417 | In an image, the air region is 0, so we only calculate the mean and std within the brain area 418 | For any image-level normalization, do it here 419 | ''' 420 | path = data_path.split(",") 421 | image_path = path[0] 422 | label_path = path[1] 423 | #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 424 | #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 425 | itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 426 | itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 427 | # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz')) 428 | 429 | image = sitk.GetArrayFromImage(itk_image) 430 | mask = sitk.GetArrayFromImage(itk_mask) 431 | #image[image >= 1000] = 1000 432 | binary_mask = np.ones(mask.shape) 433 | mean = np.sum(image * binary_mask) / np.sum(binary_mask) 434 | std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask)) 435 | image = (image - mean) / std # normalize per image, using statistics within the brain, but apply to whole image 436 | 437 | mask[mask==2] = 1 438 | 439 | return image.transpose([1,2,0]), mask.transpose([1,2,0]) # transpose the orientation of the 440 | 441 | 442 | def parse_fn_haus(data_path): 443 | ''' 444 | :param image_path: path to a folder of a patient 445 | :return: normalized entire image with its corresponding label 446 | In an image, the air region is 0, so we only calculate the mean and std within the brain area 447 | For any image-level normalization, do it here 448 | ''' 449 | path = data_path.split(",") 450 | image_path = path[0] 451 | label_path = path[1] 452 | #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 453 | #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 454 | itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')) 455 | itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')) 456 | # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz')) 457 | spacing = itk_mask.GetSpacing() 458 | 459 | image = sitk.GetArrayFromImage(itk_image) 460 | mask = sitk.GetArrayFromImage(itk_mask) 461 | #image[image >= 1000] = 1000 462 | binary_mask = np.ones(mask.shape) 463 | mean = np.sum(image * binary_mask) / np.sum(binary_mask) 464 | std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask)) 465 | image = (image - mean) / std # normalize per image, using statistics within the brain, but apply to whole image 466 | 467 | mask[mask==2] = 1 468 | 469 | return image.transpose([1,2,0]), mask.transpose([1,2,0]), spacing 470 | 471 | def show_all_variables(): 472 | model_vars = tf.trainable_variables() 473 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 474 | 475 | -------------------------------------------------------------------------------- /saml_func.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import sys 4 | import tensorflow as tf 5 | from tensorflow.image import resize_images 6 | # try: 7 | # import special_grads 8 | # except KeyError as e: 9 | # print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e, file=sys.stderr) 10 | 11 | from tensorflow.python.platform import flags 12 | from layer import conv_block, deconv_block, fc, max_pool, concat2d 13 | from utils import xent, kd, _get_segmentation_cost, _get_compactness_cost 14 | 15 | class SAML: 16 | def __init__(self, args): 17 | """ Call construct_model_*() after initializing MASF""" 18 | self.args = args 19 | 20 | self.batch_size = args.meta_batch_size 21 | self.test_batch_size = args.test_batch_size 22 | self.volume_size = args.volume_size 23 | self.n_class = args.n_class 24 | self.compactness_loss_weight = args.compactness_loss_weight 25 | self.smoothness_loss_weight = args.smoothness_loss_weight 26 | self.margin = args.margin 27 | 28 | self.forward = self.forward_unet 29 | self.construct_weights = self.construct_unet_weights 30 | self.seg_loss = _get_segmentation_cost 31 | self.get_compactness_cost = _get_compactness_cost 32 | 33 | def construct_model_train(self, prefix='metatrain_'): 34 | # a: meta-train for inner update, b: meta-test for meta loss 35 | self.inputa = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) 36 | self.labela = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) 37 | self.inputa1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) 38 | self.labela1= tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) 39 | self.inputb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) 40 | self.labelb = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) 41 | self.input_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) 42 | self.label_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) 43 | self.contour_group = tf.placeholder("float", shape=[self.batch_size, self.volume_size[0], self.volume_size[1], 1]) 44 | self.metric_label_group = tf.placeholder(tf.int32, shape=[self.batch_size, 1]) 45 | self.training_mode = tf.placeholder_with_default(True, shape = None, name = "training_mode_for_bn_moving") 46 | 47 | 48 | self.clip_value = self.args.gradients_clip_value 49 | self.KEEP_PROB = tf.placeholder(tf.float32) 50 | 51 | with tf.variable_scope('model', reuse=None) as training_scope: 52 | if 'weights' in dir(self): 53 | print('weights already defined') 54 | training_scope.reuse_variables() 55 | weights = self.weights 56 | else: 57 | # Define the weights 58 | self.weights = weights = self.construct_weights() 59 | 60 | def task_metalearn(inp, reuse=True): 61 | # Function to perform meta learning update """ 62 | inputa, inputa1, inputb, labela, labela1, labelb, input_group, contour_group, metric_label_group = inp 63 | 64 | # Obtaining the conventional task loss on meta-train 65 | task_outputa, _, _ = self.forward(inputa, weights, is_training=self.training_mode) 66 | task_lossa = self.seg_loss(task_outputa, labela) 67 | task_outputa1, _, _ = self.forward(inputa1, weights, is_training=self.training_mode) 68 | task_lossa1 = self.seg_loss(task_outputa1, labela1) 69 | 70 | ## perform inner update with plain gradient descent on meta-train 71 | grads = tf.gradients((task_lossa + task_lossa1)/2.0, list(weights.values())) 72 | grads = [tf.stop_gradient(grad) for grad in grads] # first-order gradients approximation 73 | gradients = dict(zip(weights.keys(), grads)) 74 | # fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * gradients[key] for key in weights.keys()])) 75 | fast_weights = dict(zip(weights.keys(), [weights[key] - self.inner_lr * tf.clip_by_norm(gradients[key], clip_norm=self.clip_value) for key in weights.keys()])) 76 | 77 | ## compute compactness loss 78 | task_outputb, task_predmaskb, _ = self.forward(inputb, fast_weights, is_training=self.training_mode) 79 | task_lossb = self.seg_loss(task_outputb, labelb) 80 | compactness_loss_b, length, area, boundary_b = self.get_compactness_cost(task_outputb, labelb) 81 | compactness_loss_b = self.compactness_loss_weight * compactness_loss_b 82 | 83 | # compute smoothness loss 84 | _, _, embeddings = self.forward(input_group, fast_weights, is_training=self.training_mode) 85 | coutour_embeddings = self.extract_coutour_embedding(contour_group, embeddings) 86 | metric_embeddings = self.forward_metric_net(coutour_embeddings) 87 | 88 | print (metric_label_group.shape) 89 | print (metric_embeddings.shape) 90 | smoothness_loss_b = tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=metric_label_group[..., 0], embeddings=metric_embeddings, margin=self.margin) 91 | smoothness_loss_b = self.smoothness_loss_weight * smoothness_loss_b 92 | task_output = [task_lossb, compactness_loss_b, smoothness_loss_b, task_predmaskb, boundary_b, length, area, task_lossa, task_lossa1] 93 | 94 | return task_output 95 | 96 | self.global_step = tf.Variable(0, trainable=False) 97 | # self.inner_lr = tf.train.exponential_decay(learning_rate=self.args.inner_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate) 98 | # self.outer_lr = tf.train.exponential_decay(learning_rate=self.args.outer_lr, global_step=self.global_step, decay_steps=self.args.decay_step, decay_rate=self.args.decay_rate) 99 | self.inner_lr = tf.Variable(self.args.inner_lr, trainable=False) 100 | self.outer_lr = tf.Variable(self.args.outer_lr, trainable=False) 101 | self.metric_lr = tf.Variable(self.args.metric_lr, trainable=False) 102 | 103 | input_tensors = (self.inputa, self.inputa1, self.inputb, self.labela, self.labela1, self.labelb, self.input_group, self.contour_group, self.metric_label_group) 104 | result = task_metalearn(inp=input_tensors) 105 | self.seg_loss_b, self.compactness_loss_b, self.smoothness_loss_b, self.task_predmaskb, self.boundary_b, self.length, self.area, self.seg_loss_a, self.seg_loss_a1= result 106 | 107 | ## Performance & Optimization 108 | if 'train' in prefix: 109 | self.source_loss = (self.seg_loss_a + self.seg_loss_a1) / 2.0 110 | self.target_loss = self.seg_loss_b + self.compactness_loss_b + self.smoothness_loss_b 111 | 112 | var_list_segmentor = [v for v in tf.trainable_variables() if 'metric' not in v.name.split('/')] 113 | var_list_metric = [v for v in tf.trainable_variables() if 'metric' in v.name.split('/')] 114 | 115 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 116 | with tf.control_dependencies(update_ops): 117 | self.task_train_op = tf.train.AdamOptimizer(learning_rate=self.inner_lr).minimize(self.source_loss, global_step=self.global_step) 118 | 119 | optimizer = tf.train.AdamOptimizer(self.outer_lr) 120 | gvs = optimizer.compute_gradients(self.target_loss, var_list=var_list_segmentor) 121 | 122 | # observe stability of gradients for meta loss 123 | # l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2))) 124 | # for grad, var in gvs: 125 | # tf.summary.histogram("gradients_norm/" + var.name, l2_norm(grad)) 126 | # tf.summary.histogram("feature_extractor_var_norm/" + var.name, l2_norm(var)) 127 | # tf.summary.histogram('gradients/' + var.name, var) 128 | # tf.summary.histogram("feature_extractor_var/" + var.name, var) 129 | 130 | # gvs = [(grad, var) for grad, var in gvs] 131 | gvs = [(tf.clip_by_norm(grad, clip_norm=self.clip_value), var) for grad, var in gvs] 132 | self.meta_train_op = optimizer.apply_gradients(gvs) 133 | 134 | # for grad, var in gvs: 135 | # tf.summary.histogram("gradients_norm_clipped/" + var.name, l2_norm(grad)) 136 | # tf.summary.histogram('gradients_clipped/' + var.name, var) 137 | 138 | self.metric_train_op = tf.train.AdamOptimizer(self.metric_lr).minimize(self.smoothness_loss_b, var_list=var_list_metric) 139 | 140 | ## Summaries 141 | # scalar_summaries = [] 142 | # train_images = [] 143 | # val_images = [] 144 | 145 | tf.summary.scalar(prefix+'source_1 loss', self.seg_loss_a) 146 | tf.summary.scalar(prefix+'source_2 loss', self.seg_loss_a1) 147 | tf.summary.scalar(prefix+'target_loss', self.seg_loss_b) 148 | tf.summary.scalar(prefix+'target_coutour_loss', self.compactness_loss_b) 149 | tf.summary.scalar(prefix+'target_length', self.length) 150 | tf.summary.scalar(prefix+'target_area', self.area) 151 | tf.summary.image("meta_test_mask", tf.expand_dims(tf.cast(self.task_predmaskb, tf.float32), 3)) 152 | tf.summary.image("meta_test_gth", tf.expand_dims(tf.cast(self.labelb[:,:,:,1], tf.float32), 3)) 153 | tf.summary.image("meta_test_image", tf.expand_dims(tf.cast(self.inputb[:,:,:,1], tf.float32), 3)) 154 | tf.summary.image("meta_test_boundary", tf.expand_dims(tf.cast(self.boundary_b[:,:,:], tf.float32), 3)) 155 | tf.summary.image("meta_test_ct_bg_sample", tf.expand_dims(tf.cast(self.contour_group[:,:,:, 0], tf.float32), 3)) 156 | tf.summary.image("meta_input_group", tf.expand_dims(tf.cast(self.input_group[:,:,:, 1], tf.float32), 3)) 157 | tf.summary.image("label_group", tf.expand_dims(tf.cast(self.label_group[:,:,:, 1], tf.float32), 3)) 158 | 159 | def extract_coutour_embedding(self, coutour, embeddings): 160 | 161 | coutour_embeddings = coutour * embeddings 162 | average_embeddings = tf.reduce_sum(coutour_embeddings, [1,2])/tf.reduce_sum(coutour, [1,2]) 163 | # print (coutour.shape) 164 | # print (embeddings.shape) 165 | # print (coutour_embeddings.shape) 166 | # print (average_embeddings.shape) 167 | return average_embeddings 168 | 169 | def construct_model_test(self, prefix='test'): 170 | self.test_input = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.volume_size[2]]) 171 | self.test_label = tf.placeholder("float", shape=[self.test_batch_size, self.volume_size[0], self.volume_size[1], self.n_class]) 172 | 173 | with tf.variable_scope('model', reuse=None) as testing_scope: 174 | if 'weights' in dir(self): 175 | testing_scope.reuse_variables() 176 | weights = self.weights 177 | else: 178 | raise ValueError('Weights not initilized. Create training model before testing model') 179 | 180 | outputs, mask, _ = self.forward(self.test_input, weights) 181 | losses = self.seg_loss(outputs, self.test_label) 182 | # self.pred_prob = tf.nn.softmax(outputs) 183 | self.outputs = mask 184 | 185 | self.test_loss = losses 186 | # self.test_acc = accuracies 187 | 188 | def forward_metric_net(self, x): 189 | 190 | with tf.variable_scope('metric', reuse=tf.AUTO_REUSE) as scope: 191 | 192 | w1 = tf.get_variable('w1', shape=[48,24]) 193 | b1 = tf.get_variable('b1', shape=[24]) 194 | out = fc(x, w1, b1, activation='leaky_relu') 195 | w2 = tf.get_variable('w2', shape=[24,16]) 196 | b2 = tf.get_variable('b2', shape=[16]) 197 | out = fc(out, w2, b2, activation='leaky_relu') 198 | 199 | return out 200 | 201 | def construct_unet_weights(self): 202 | 203 | weights = {} 204 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32) 205 | 206 | with tf.variable_scope('conv1') as scope: 207 | weights['conv11_weights'] = tf.get_variable('weights', shape=[5, 5, 3, 16], initializer=conv_initializer) 208 | weights['conv11_biases'] = tf.get_variable('biases', [16]) 209 | weights['conv12_weights'] = tf.get_variable('weights2', shape=[5, 5, 16, 16], initializer=conv_initializer) 210 | weights['conv12_biases'] = tf.get_variable('biases2', [16]) 211 | 212 | with tf.variable_scope('conv2') as scope: 213 | weights['conv21_weights'] = tf.get_variable('weights', shape=[5, 5, 16, 32], initializer=conv_initializer) 214 | weights['conv21_biases'] = tf.get_variable('biases', [32]) 215 | weights['conv22_weights'] = tf.get_variable('weights2', shape=[5, 5, 32, 32], initializer=conv_initializer) 216 | weights['conv22_biases'] = tf.get_variable('biases2', [32]) 217 | ## Network has downsample here 218 | 219 | with tf.variable_scope('conv3') as scope: 220 | weights['conv31_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 64], initializer=conv_initializer) 221 | weights['conv31_biases'] = tf.get_variable('biases', [64]) 222 | weights['conv32_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer) 223 | weights['conv32_biases'] = tf.get_variable('biases2', [64]) 224 | 225 | with tf.variable_scope('conv4') as scope: 226 | weights['conv41_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 128], initializer=conv_initializer) 227 | weights['conv41_biases'] = tf.get_variable('biases', [128]) 228 | weights['conv42_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer) 229 | weights['conv42_biases'] = tf.get_variable('biases2', [128]) 230 | ## Network has downsample here 231 | 232 | with tf.variable_scope('conv5') as scope: 233 | weights['conv51_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 256], initializer=conv_initializer) 234 | weights['conv51_biases'] = tf.get_variable('biases', [256]) 235 | weights['conv52_weights'] = tf.get_variable('weights2', shape=[3, 3, 256, 256], initializer=conv_initializer) 236 | weights['conv52_biases'] = tf.get_variable('biases2', [256]) 237 | 238 | with tf.variable_scope('deconv6') as scope: 239 | weights['deconv6_weights'] = tf.get_variable('weights0', shape=[3, 3, 128, 256], initializer=conv_initializer) 240 | weights['deconv6_biases'] = tf.get_variable('biases0', shape=[128], initializer=conv_initializer) 241 | weights['conv61_weights'] = tf.get_variable('weights', shape=[3, 3, 256, 128], initializer=conv_initializer) 242 | weights['conv61_biases'] = tf.get_variable('biases', [128]) 243 | weights['conv62_weights'] = tf.get_variable('weights2', shape=[3, 3, 128, 128], initializer=conv_initializer) 244 | weights['conv62_biases'] = tf.get_variable('biases2', [128]) 245 | 246 | with tf.variable_scope('deconv7') as scope: 247 | weights['deconv7_weights'] = tf.get_variable('weights0', shape=[3, 3, 64, 128], initializer=conv_initializer) 248 | weights['deconv7_biases'] = tf.get_variable('biases0', shape=[64], initializer=conv_initializer) 249 | weights['conv71_weights'] = tf.get_variable('weights', shape=[3, 3, 128, 64], initializer=conv_initializer) 250 | weights['conv71_biases'] = tf.get_variable('biases', [64]) 251 | weights['conv72_weights'] = tf.get_variable('weights2', shape=[3, 3, 64, 64], initializer=conv_initializer) 252 | weights['conv72_biases'] = tf.get_variable('biases2', [64]) 253 | 254 | with tf.variable_scope('deconv8') as scope: 255 | weights['deconv8_weights'] = tf.get_variable('weights0', shape=[3, 3, 32, 64], initializer=conv_initializer) 256 | weights['deconv8_biases'] = tf.get_variable('biases0', shape=[32], initializer=conv_initializer) 257 | weights['conv81_weights'] = tf.get_variable('weights', shape=[3, 3, 64, 32], initializer=conv_initializer) 258 | weights['conv81_biases'] = tf.get_variable('biases', [32]) 259 | weights['conv82_weights'] = tf.get_variable('weights2', shape=[3, 3, 32, 32], initializer=conv_initializer) 260 | weights['conv82_biases'] = tf.get_variable('biases2', [32]) 261 | 262 | with tf.variable_scope('deconv9') as scope: 263 | weights['deconv9_weights'] = tf.get_variable('weights0', shape=[3, 3, 16, 32], initializer=conv_initializer) 264 | weights['deconv9_biases'] = tf.get_variable('biases0', shape=[16], initializer=conv_initializer) 265 | weights['conv91_weights'] = tf.get_variable('weights', shape=[3, 3, 32, 16], initializer=conv_initializer) 266 | weights['conv91_biases'] = tf.get_variable('biases', [16]) 267 | weights['conv92_weights'] = tf.get_variable('weights2', shape=[3, 3, 16, 16], initializer=conv_initializer) 268 | weights['conv92_biases'] = tf.get_variable('biases2', [16]) 269 | 270 | with tf.variable_scope('output') as scope: 271 | weights['output_weights'] = tf.get_variable('weights', shape=[3, 3, 16, 2], initializer=conv_initializer) 272 | weights['output_biases'] = tf.get_variable('biases', [2]) 273 | 274 | return weights 275 | 276 | def forward_unet(self, inp, weights, is_training=True): 277 | 278 | self.conv11 = conv_block(inp, weights['conv11_weights'], weights['conv11_biases'], scope='conv1/bn1', bn=False, is_training=is_training) 279 | self.conv12 = conv_block(self.conv11, weights['conv12_weights'], weights['conv12_biases'], scope='conv1/bn2', is_training=is_training) 280 | self.pool11 = max_pool(self.conv12, 2, 2, 2, 2, padding='VALID') 281 | # 192x192x16 282 | self.conv21 = conv_block(self.pool11, weights['conv21_weights'], weights['conv21_biases'], scope='conv2/bn1', is_training=is_training) 283 | self.conv22 = conv_block(self.conv21, weights['conv22_weights'], weights['conv22_biases'], scope='conv2/bn2', is_training=is_training) 284 | self.pool21 = max_pool(self.conv22, 2, 2, 2, 2, padding='VALID') 285 | # 96x96x32 286 | self.conv31 = conv_block(self.pool21, weights['conv31_weights'], weights['conv31_biases'], scope='conv3/bn1', is_training=is_training) 287 | self.conv32 = conv_block(self.conv31, weights['conv32_weights'], weights['conv32_biases'], scope='conv3/bn2', is_training=is_training) 288 | self.pool31 = max_pool(self.conv32, 2, 2, 2, 2, padding='VALID') 289 | # 48x48x64 290 | self.conv41 = conv_block(self.pool31, weights['conv41_weights'], weights['conv41_biases'], scope='conv4/bn1', is_training=is_training) 291 | self.conv42 = conv_block(self.conv41, weights['conv42_weights'], weights['conv42_biases'], scope='conv4/bn2', is_training=is_training) 292 | self.pool41 = max_pool(self.conv42, 2, 2, 2, 2, padding='VALID') 293 | # 24x24x128 294 | self.conv51 = conv_block(self.pool41, weights['conv51_weights'], weights['conv51_biases'], scope='conv5/bn1', is_training=is_training) 295 | self.conv52 = conv_block(self.conv51, weights['conv52_weights'], weights['conv52_biases'], scope='conv5/bn2', is_training=is_training) 296 | # 24x24x256 297 | 298 | ## add upsampling, meanwhile, channel number is reduced to half 299 | self.deconv6 = deconv_block(self.conv52, weights['deconv6_weights'], weights['deconv6_biases'], scope='deconv/bn6', is_training=is_training) 300 | # 48x48x128 301 | self.sum6 = concat2d(self.deconv6, self.deconv6) 302 | self.conv61 = conv_block(self.sum6, weights['conv61_weights'], weights['conv61_biases'], scope='conv6/bn1', is_training=is_training) 303 | self.conv62 = conv_block(self.conv61, weights['conv62_weights'], weights['conv62_biases'], scope='conv6/bn2', is_training=is_training) 304 | # 48x48x128 305 | 306 | self.deconv7 = deconv_block(self.conv62, weights['deconv7_weights'], weights['deconv7_biases'], scope='deconv/bn7', is_training=is_training) 307 | # 96x96x64 308 | self.sum7 = concat2d(self.deconv7, self.deconv7) 309 | self.conv71 = conv_block(self.sum7, weights['conv71_weights'], weights['conv71_biases'], scope='conv7/bn1', is_training=is_training) 310 | self.conv72 = conv_block(self.conv71, weights['conv72_weights'], weights['conv72_biases'], scope='conv7/bn2', is_training=is_training) 311 | # 96x96x64 312 | 313 | self.deconv8 = deconv_block(self.conv72, weights['deconv8_weights'], weights['deconv8_biases'], scope='deconv/bn8', is_training=is_training) 314 | # 192x192x32 315 | self.sum8 = concat2d(self.deconv8, self.deconv8) 316 | self.conv81 = conv_block(self.sum8, weights['conv81_weights'], weights['conv81_biases'], scope='conv8/bn1', is_training=is_training) 317 | self.conv82 = conv_block(self.conv81, weights['conv82_weights'], weights['conv82_biases'], scope='conv8/bn2', is_training=is_training) 318 | self.conv82_resize = tf.image.resize_images(self.conv82, [384, 384], method=tf.image.ResizeMethod.BILINEAR, align_corners=False) 319 | # 192x192x32 320 | 321 | self.deconv9 = deconv_block(self.conv82, weights['deconv9_weights'], weights['deconv9_biases'], scope='deconv/bn9', is_training=is_training) 322 | # 384x384x16 323 | self.sum9 = concat2d(self.deconv9, self.deconv9) 324 | self.conv91 = conv_block(self.sum9, weights['conv91_weights'], weights['conv91_biases'], scope='conv9/bn1', is_training=is_training) 325 | self.conv92 = conv_block(self.conv91, weights['conv92_weights'], weights['conv92_biases'], scope='conv9/bn2', is_training=is_training) 326 | # 384x384x16 327 | 328 | self.logits = conv_block(self.conv92, weights['output_weights'], weights['output_biases'], scope='outpu/bn', bn=False, is_training=is_training) 329 | #384x384x2 330 | 331 | self.pred_prob = tf.nn.softmax(self.logits) # shape [batch, w, h, num_classes] 332 | self.pred_compact = tf.argmax(self.pred_prob, axis=-1) # shape [batch, w, h] 333 | 334 | self.embeddings = concat2d(self.conv82_resize, self.conv92) 335 | 336 | return self.pred_prob, self.pred_compact, self.embeddings 337 | --------------------------------------------------------------------------------