├── Figure
│   └── merge_scheme.png
├── LICENSE
├── README.md
├── data_util
│   ├── __init__.py
│   ├── brain.py
│   ├── data.py
│   └── liver.py
├── datasets
│   └── brain.json
├── evaluate
│   ├── Dec09-1849-model-31200.txt
│   └── Dec09-1849-model-31200.xls
├── network
│   ├── __init__.py
│   ├── base_networks.py
│   ├── framework.py
│   ├── losses.py
│   ├── recursive_cascaded_networks.py
│   ├── spatial_transformer.py
│   ├── transform.py
│   ├── trilinear_sampler.py
│   └── utils.py
├── predict.py
└── train.py

/Figure/merge_scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JinxLv/reimplemention-of-Dual-PRNet/83f2e42ac6316fac49708a7bc42b303a8ff16af2/Figure/merge_scheme.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Jinx
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reimplemention-of-Dual-PRNet
2 | 
3 | This is a TensorFlow reimplementation of [Dual-Stream Pyramid Registration Network](https://arxiv.org/abs/1909.11966).
4 | 
5 | ## Install
6 | The packages and their corresponding versions used in this repository are listed below.
7 | 
8 | - Tensorflow==1.15.4
9 | - Keras==2.3.1
10 | - tflearn==0.5.0
11 | 
12 | ## Training
13 | After configuring the environment, use the following command to train the model.
14 | 
15 | ```sh
16 | python train.py -g 0 --batch 1 -d datasets/brain.json -b DUAL -n 1 --round 10000 --epoch 10
17 | ```
18 | 
19 | ## Testing
20 | Use the following command to obtain the testing results.
21 | ```sh
22 | python predict.py -g 0 --batch 1 -d datasets/brain.json -c weights/Dec09-1849
23 | ```
24 | 
25 | ## LPBA dataset
26 | We use the same training and testing data as [RCN](https://github.com/microsoft/Recursive-Cascaded-Networks); please refer to their repository to download the pre-processed data.
27 | 
28 | ## Results
29 | 
30 | Method | Dice | HD | ASSD | Jacobian Std. | Folding (%) |
31 | ---|:-:|:-:|:-:|:-:|:-:|
32 | Original [Dual-PRNet](https://arxiv.org/abs/1909.11966) | 0.778 | - | - | - | - |
33 | Re-implemented Dual-PRNet | 0.831±0.008 | 3.457±0.297 | 0.811±0.046 | 0.906±0.059 | 1.6e-1±2.4e-2 |
34 | [VoxelMorph](https://arxiv.org/pdf/1809.05231.pdf) | 0.820±0.008 | 3.648±0.284 | 0.892±0.047 | 0.247±0.057 | 5.2e-3±6.8e-3 |
35 | [VTN](https://arxiv.org/pdf/1902.05020.pdf) | 0.825±0.008 | 3.584±0.265 | 0.925±0.047 | 0.179±0.024 | 0.0±0.0 |
36 | [2×10-cascade VTN](https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhao_Recursive_Cascaded_Networks_for_Unsupervised_Medical_Image_Registration_ICCV_2019_paper.pdf) | 0.831±0.009 | 3.551±0.328 | 0.810±0.046 | 0.355±0.068 | 1.2e-6±7.5e-6 |
37 | 
38 | Following [Dual-PRNet](https://arxiv.org/abs/1909.11966), we merge the 56 labeled regions into 7; the 7 merged regions and the corresponding [label IDs](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2757616/) of the functional areas in each merged region are shown in the figure below, with a minimal sketch of the merging step after it.
39 | 
40 | ![merge](./Figure/merge_scheme.png)
41 | 
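A minimal sketch of this label-merging step, assuming a plain integer label volume; the `merge_map` below is a hypothetical placeholder, and the real grouping is the one shown in the figure:

```python
import numpy as np

# Hypothetical mapping from original label IDs to merged region IDs (1-7);
# the actual assignment is the one shown in the merge-scheme figure above.
merge_map = {21: 1, 22: 1, 23: 2, 24: 2}

def merge_labels(seg, merge_map):
    # Labels not listed in merge_map fall back to 0 (background).
    merged = np.zeros_like(seg)
    for label, region in merge_map.items():
        merged[seg == label] = region
    return merged

print(merge_labels(np.array([[21, 22], [23, 0]]), merge_map))
# [[1 1]
#  [2 0]]
```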
42 | ## Acknowledgment
43 | 
44 | Some code is modified from [RCN](https://github.com/microsoft/Recursive-Cascaded-Networks) and [VoxelMorph](https://github.com/voxelmorph/voxelmorph).
45 | Thanks a lot for their great contributions.
--------------------------------------------------------------------------------
/data_util/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | class Path:
4 |     def __init__(self, path):
5 |         self.path = path
6 |         if not os.path.exists(path):
7 |             os.mkdir(path)
8 |     def __call__(self, *names):
9 |         return os.path.join(*((self.path, ) + names))
--------------------------------------------------------------------------------
/data_util/brain.py:
--------------------------------------------------------------------------------
1 | 
2 | import json
3 | import copy
4 | import numpy as np
5 | import collections
6 | 
7 | from .liver import FileManager
8 | from .liver import Dataset as BaseDataset
9 | 
10 | 
11 | class Dataset(BaseDataset):
12 |     def __init__(self, split_path, paired=False, task=None, batch_size=None):
13 |         with open(split_path, 'r') as f:
14 |             config = json.load(f)
15 |         self.files = FileManager(config['files'])
16 |         self.subset = {}
17 | 
18 |         for k, v in config['subsets'].items():
19 |             self.subset[k] = {}
20 |             for entry in v:
21 |                 self.subset[k][entry] = self.files[entry]
22 | 
23 |         self.paired = paired
24 | 
25 |         def convert_int(key):
26 |             try:
27 |                 return int(key)
28 |             except ValueError:
29 |                 return key
30 |         self.schemes = dict([(convert_int(k), v)
31 |                              for k, v in config['schemes'].items()])
32 | 
33 |         for k, v in self.subset.items():
34 |             print('Number of samples in {}: {}'.format(k, len(v)))
35 | 
36 |         self.task = task
37 |         if self.task is None:
38 |             self.task = config.get("task", "registration")
39 |         if not isinstance(self.task, list):
40 |             self.task = [self.task]
41 | 
42 |         self.image_size = config.get("image_size", [128, 128, 128])
43 |         self.segmentation_class_value = config.get(
44 |             'segmentation_class_value', None)
45 | 
46 |         if 'atlas' in config:
47 |             self.atlas = self.files[config['atlas']]
48 |         else:
49 |             self.atlas = None
50 | 
51 |         self.batch_size = batch_size
52 | 
53 |     def center_crop(self, volume):
54 |         slices = [slice((os - ts) // 2, (os - ts) // 2 + ts) if ts < os else slice(None, None)
55 |                   for ts, os in zip(self.image_size, volume.shape)]
56 |         volume = volume[slices]
57 | 
58 |         ret = np.zeros(self.image_size,
dtype=volume.dtype) 59 | slices = [slice((ts - os) // 2, (ts - os) // 2 + os) if ts > os else slice(None, None) 60 | for ts, os in zip(self.image_size, volume.shape)] 61 | ret[slices] = volume 62 | 63 | return ret 64 | 65 | @staticmethod 66 | def generate_atlas(atlas, sets, loop=False): 67 | sets = copy.copy(sets) 68 | while True: 69 | if loop: 70 | np.random.shuffle(sets) 71 | for d in sets: 72 | yield atlas, d 73 | if not loop: 74 | break 75 | 76 | def generator(self, subset, batch_size=None, loop=False): 77 | if batch_size is None: 78 | batch_size = self.batch_size 79 | scheme = self.schemes[subset] 80 | if 'registration' in self.task: 81 | if self.atlas is not None: 82 | generators, fractions = zip(*[(self.generate_atlas(self.atlas, list( 83 | self.subset[k].values()), loop), fraction) for k, fraction in scheme.items()]) 84 | else: 85 | generators, fractions = zip( 86 | *[(self.generate_pairs(list(self.subset[k].values()), loop), fraction) for k, fraction in scheme.items()]) 87 | 88 | while True: 89 | imgs = [batch_size] + self.image_size + [1] 90 | ret = dict() 91 | ret['voxel1'] = np.zeros(imgs, dtype=np.float32) 92 | ret['voxel2'] = np.zeros(imgs, dtype=np.float32) 93 | ret['seg1'] = np.zeros(imgs, dtype=np.float32) 94 | ret['seg2'] = np.zeros(imgs, dtype=np.float32) 95 | ret['point1'] = np.ones( 96 | (batch_size, 6, 3), dtype=np.float32) * (-1) 97 | ret['point2'] = np.ones( 98 | (batch_size, 6, 3), dtype=np.float32) * (-1) 99 | ret['id1'] = np.empty((batch_size), dtype='= 0.5, tf.float32)*v) 69 | return tf.add_n(warped_segs) 70 | return tf.cond(tflearn.get_training_mode(), lambda: aug(segmentation,flow),lambda: segmentation) 71 | 72 | def augmenetation_pts(incoming,flow): 73 | def aug(incoming,flow): 74 | aug_pt = tf.cast(transform.warp_points( 75 | flow, incoming), tf.float32) 76 | pt_mask = tf.cast(tf.reduce_all( 77 | incoming >= 0, axis=-1, keep_dims=True), tf.float32) 78 | return aug_pt * pt_mask - (1 - pt_mask) 79 | return tf.cond(tflearn.get_training_mode(), lambda: aug(incoming,flow), lambda: incoming) 80 | 81 | augImg2 = augmentation(preprocessedImg2,augFlow) 82 | augSeg2 = augmentation(seg2,augFlow) 83 | augPt2 = augmenetation_pts(point2,augFlow) 84 | elif aug == 'identity': 85 | augFlow = tf.zeros( 86 | tf.stack([tf.shape(img1)[0], 128, 128, 128, 3]), dtype=tf.float32) 87 | augImg2 = preprocessedImg2 88 | augSeg2 = seg2 89 | augPt2 = point2 90 | else: 91 | raise NotImplementedError('Augmentation {}'.format(aug)) 92 | 93 | learningRate = tf.placeholder(tf.float32, [], 'learningRate') 94 | if not validation: 95 | adamOptimizer = tf.train.AdamOptimizer(learningRate) 96 | 97 | self.segmentation_class_value = segmentation_class_value 98 | self.network = network_class( 99 | self.framework_name, framework=self, fast_reconstruction=fast_reconstruction, **self.net_args) 100 | net_pls = [augImg1, augImg2, seg1, augSeg2, point1, augPt2] 101 | if devices == 0: 102 | with tf.device("/cpu:0"): 103 | self.predictions = self.network(*net_pls) 104 | if not validation: 105 | self.adamOpt = adamOptimizer.minimize( 106 | self.predictions["loss"]) 107 | else: 108 | gpus = MultiGPUs(devices) 109 | if validation: 110 | self.predictions = gpus(self.network, net_pls) 111 | else: 112 | self.predictions, self.adamOpt = gpus( 113 | self.network, net_pls, opt=adamOptimizer) 114 | self.build_summary(self.predictions) 115 | 116 | @property 117 | def data_args(self): 118 | return self.network.data_args 119 | 120 | def build_summary(self, predictions): 121 | self.loss = 
tf.reduce_mean(predictions['loss']) 122 | for k in predictions: 123 | if k.find('loss') != -1: 124 | tf.summary.scalar(k, tf.reduce_mean(predictions[k])) 125 | self.summaryOp = tf.summary.merge_all() 126 | 127 | if self.summaryType == 'full': 128 | tf.summary.scalar('dice_score', tf.reduce_mean( 129 | self.predictions['dice_score'])) 130 | tf.summary.scalar('landmark_dist', masked_mean( 131 | self.predictions['landmark_dist'], self.predictions['pt_mask'])) 132 | preds = tf.reduce_sum( 133 | tf.cast(self.predictions['jacc_score'] > 0, tf.float32)) 134 | tf.summary.scalar('jacc_score', tf.reduce_sum( 135 | self.predictions['jacc_score']) / (preds + 1e-8)) 136 | self.summaryExtra = tf.summary.merge_all() 137 | else: 138 | self.summaryExtra = self.summaryOp 139 | self.summaryImages1 = tf.summary.image('fixed_img', tf.reshape(self.predictions['image_fixed'][0,:,64,:,0], (1,128,128,1))) 140 | self.summaryImages2 = tf.summary.image('warped_moving_img', tf.reshape(self.predictions['warped_moving'][0,:,64,:,0], (1,128,128,1))) 141 | self.summaryImages3 = tf.summary.image('image_float', tf.reshape(self.predictions['moving_img'][0,:,64,:,0], (1,128,128,1))) 142 | self.summaryImages = tf.summary.merge([self.summaryImages1,self.summaryImages2,self.summaryImages3]) 143 | 144 | def get_predictions(self, *keys): 145 | return dict([(k, self.predictions[k]) for k in keys]) 146 | 147 | def validate_clean(self, sess, generator, keys=None): 148 | for fd in generator: 149 | _ = fd.pop('id1') 150 | _ = fd.pop('id2') 151 | _ = sess.run(self.get_predictions(*keys), 152 | feed_dict=set_tf_keys(fd)) 153 | 154 | def validate(self, sess, generator, keys=None, summary=False, predict=False, show_tqdm=False): 155 | if keys is None: 156 | keys = ['dice_score', 'landmark_dist', 'pt_mask', 'jacc_score','total_ncc','total_mse']# 157 | 158 | full_results = dict([(k, list()) for k in keys]) 159 | if not summary: 160 | full_results['id1'] = [] 161 | full_results['id2'] = [] 162 | if predict: 163 | full_results['seg1'] = [] 164 | full_results['seg2'] = [] 165 | full_results['img1'] = [] 166 | full_results['img2'] = [] 167 | tflearn.is_training(False, sess) 168 | if show_tqdm: 169 | generator = tqdm(generator) 170 | for fd in generator: 171 | id1 = fd.pop('id1') 172 | id2 = fd.pop('id2') 173 | results = sess.run(self.get_predictions( 174 | *keys), feed_dict=set_tf_keys(fd)) 175 | if not summary: 176 | results['id1'] = id1 177 | results['id2'] = id2 178 | if predict: 179 | results['seg1'] = fd['seg1'] 180 | results['seg2'] = fd['seg2'] 181 | results['img1'] = fd['voxel1'] 182 | results['img2'] = fd['voxel2'] 183 | mask = np.where([i and j for i, j in zip(id1, id2)]) 184 | for k, v in results.items(): 185 | full_results[k].append(v[mask]) 186 | if 'landmark_dist' in full_results and 'pt_mask' in full_results: 187 | pt_mask = full_results.pop('pt_mask') 188 | full_results['landmark_dist'] = [arr * mask for arr, 189 | mask in zip(full_results['landmark_dist'], pt_mask)] 190 | for k in full_results: 191 | full_results[k] = np.concatenate(full_results[k], axis=0) 192 | if summary: 193 | full_results[k] = full_results[k].mean() 194 | 195 | return full_results 196 | -------------------------------------------------------------------------------- /network/losses.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.keras.layers as KL 5 | import tensorflow.keras.backend as K 6 | 7 | 8 | class NCC: 9 | """ 10 | Local (over 
window) normalized cross correlation loss. 11 | """ 12 | 13 | def __init__(self, win=None, eps=1e-5): 14 | self.win = win 15 | self.eps = eps 16 | 17 | def ncc(self, I, J): 18 | # get dimension of volume 19 | # assumes I, J are sized [batch_size, *vol_shape, nb_feats] 20 | ndims = len(I.get_shape().as_list()) - 2 21 | assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims 22 | 23 | # set window size 24 | if self.win is None: 25 | self.win = [9] * ndims 26 | 27 | # get convolution function 28 | conv_fn = getattr(tf.nn, 'conv%dd' % ndims) 29 | 30 | # compute CC squares 31 | I2 = I * I 32 | J2 = J * J 33 | IJ = I * J 34 | 35 | # compute filters 36 | in_ch = J.get_shape().as_list()[-1] 37 | sum_filt = tf.ones([*self.win, in_ch, 1]) 38 | strides = 1 39 | if ndims > 1: 40 | strides = [1] * (ndims + 2) 41 | 42 | # compute local sums via convolution 43 | padding = 'SAME' 44 | I_sum = conv_fn(I, sum_filt, strides, padding) 45 | J_sum = conv_fn(J, sum_filt, strides, padding) 46 | I2_sum = conv_fn(I2, sum_filt, strides, padding) 47 | J2_sum = conv_fn(J2, sum_filt, strides, padding) 48 | IJ_sum = conv_fn(IJ, sum_filt, strides, padding) 49 | 50 | # compute cross correlation 51 | win_size = np.prod(self.win) * in_ch 52 | u_I = I_sum / win_size 53 | u_J = J_sum / win_size 54 | 55 | cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size # TODO: simplify this 56 | I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size 57 | J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size 58 | 59 | cc = cross * cross / (I_var * J_var + self.eps) 60 | 61 | # return mean cc for each entry in batch 62 | return tf.reduce_mean(K.batch_flatten(cc), axis=-1) 63 | 64 | def loss(self, y_true, y_pred): 65 | return - self.ncc(y_true, y_pred) 66 | 67 | 68 | class MSE: 69 | """ 70 | Sigma-weighted mean squared error for image reconstruction. 71 | """ 72 | 73 | def __init__(self, image_sigma=1.0): 74 | self.image_sigma = image_sigma 75 | 76 | def loss(self, y_true, y_pred): 77 | return 1.0 / (self.image_sigma**2) * K.mean(K.square(y_true - y_pred)) 78 | 79 | 80 | class TukeyBiweight: 81 | """ 82 | Tukey-Biweight loss. 83 | 84 | The single parameter c represents the threshold above which voxel 85 | differences are cropped and have no further effect (that is, they are 86 | treated as outliers and automatically discounted). 87 | 88 | See: DOI: 10.1016/j.neuroimage.2010.07.020 89 | Reuter, Rosas and Fischl, 2010. Highly accurate inverse consistent registration: 90 | a robust approach. NeuroImage, 53(4):1181-96. 
91 | """ 92 | 93 | def __init__(self, c=0.5): 94 | self.csq = c * c # squared error threshold 95 | 96 | def loss(self, y_true, y_pred): 97 | error_sq = (y_true - y_pred) ** 2 98 | ind_below = tf.where(error_sq <= self.csq) 99 | rho_below = (self.csq / 2) * (1 - (1 - (tf.gather_nd(error_sq, ind_below)/self.csq)) ** 3) 100 | rho_above = self.csq / 2 101 | w_below = tf.cast(tf.shape(ind_below)[0], tf.float32) 102 | w_above = tf.cast(tf.reduce_prod(tf.shape(y_pred)), tf.float32) - w_below 103 | return (w_below * tf.reduce_mean(rho_below) + w_above * rho_above) / (w_below + w_above) 104 | 105 | 106 | class Dice: 107 | """ 108 | N-D dice for segmentation 109 | """ 110 | 111 | def loss(self, y_true, y_pred): 112 | ndims = len(y_pred.get_shape().as_list()) - 2 113 | vol_axes = list(range(1, ndims+1)) 114 | 115 | top = 2 * tf.reduce_sum(y_true * y_pred, vol_axes) 116 | bottom = tf.reduce_sum(y_true + y_pred, vol_axes) 117 | 118 | div_no_nan = tf.math.divide_no_nan if hasattr(tf.math, 'divide_no_nan') else tf.div_no_nan 119 | dice = tf.reduce_mean(div_no_nan(top, bottom)) 120 | return -dice 121 | 122 | 123 | class Grad: 124 | """ 125 | N-D gradient loss. 126 | loss_mult can be used to scale the loss value - this is recommended if 127 | the gradient is computed on a downsampled vector field (where loss_mult 128 | is equal to the downsample factor). 129 | """ 130 | 131 | def __init__(self, penalty='l1', loss_mult=None): 132 | self.penalty = penalty 133 | self.loss_mult = loss_mult 134 | 135 | def _diffs(self, y): 136 | vol_shape = y.get_shape().as_list()[1:-1] 137 | ndims = len(vol_shape) 138 | 139 | df = [None] * ndims 140 | for i in range(ndims): 141 | d = i + 1 142 | # permute dimensions to put the ith dimension first 143 | r = [d, *range(d), *range(d + 1, ndims + 2)] 144 | y = K.permute_dimensions(y, r) 145 | dfi = y[1:, ...] - y[:-1, ...] 146 | 147 | # permute back 148 | # note: this might not be necessary for this loss specifically, 149 | # since the results are just summed over anyway. 150 | r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)] 151 | r = [d, *range(1, d), 0, *range(d + 1, ndims + 2)] 152 | df[i] = K.permute_dimensions(dfi, r) 153 | 154 | return df 155 | 156 | def loss(self, _, y_pred): 157 | 158 | if self.penalty == 'l1': 159 | dif = [tf.abs(f) for f in self._diffs(y_pred)] 160 | else: 161 | assert self.penalty == 'l2', 'penalty can only be l1 or l2. Got: %s' % self.penalty 162 | dif = [f * f for f in self._diffs(y_pred)] 163 | 164 | df = [tf.reduce_mean(K.batch_flatten(f), axis=-1) for f in dif] 165 | grad = tf.add_n(df) / len(df) 166 | 167 | if self.loss_mult is not None: 168 | grad *= self.loss_mult 169 | 170 | return grad 171 | 172 | 173 | class KL: 174 | """ 175 | Kullback–Leibler divergence for probabilistic flows. 176 | """ 177 | 178 | def __init__(self, prior_lambda, flow_vol_shape): 179 | self.prior_lambda = prior_lambda 180 | self.flow_vol_shape = flow_vol_shape 181 | self.D = None 182 | 183 | def _adj_filt(self, ndims): 184 | """ 185 | compute an adjacency filter that, for each feature independently, 186 | has a '1' in the immediate neighbor, and 0 elsewhere. 187 | so for each filter, the filter has 2^ndims 1s. 188 | the filter is then setup such that feature i outputs only to feature i 189 | """ 190 | 191 | # inner filter, that is 3x3x... 
192 | filt_inner = np.zeros([3] * ndims) 193 | for j in range(ndims): 194 | o = [[1]] * ndims 195 | o[j] = [0, 2] 196 | filt_inner[np.ix_(*o)] = 1 197 | 198 | # full filter, that makes sure the inner filter is applied 199 | # ith feature to ith feature 200 | filt = np.zeros([3] * ndims + [ndims, ndims]) 201 | for i in range(ndims): 202 | filt[..., i, i] = filt_inner 203 | 204 | return filt 205 | 206 | def _degree_matrix(self, vol_shape): 207 | # get shape stats 208 | ndims = len(vol_shape) 209 | sz = [*vol_shape, ndims] 210 | 211 | # prepare conv kernel 212 | conv_fn = getattr(tf.nn, 'conv%dd' % ndims) 213 | 214 | # prepare tf filter 215 | z = K.ones([1] + sz) 216 | filt_tf = tf.convert_to_tensor(self._adj_filt(ndims), dtype=tf.float32) 217 | strides = [1] * (ndims + 2) 218 | return conv_fn(z, filt_tf, strides, "SAME") 219 | 220 | def prec_loss(self, y_pred): 221 | """ 222 | a more manual implementation of the precision matrix term 223 | mu * P * mu where P = D - A 224 | where D is the degree matrix and A is the adjacency matrix 225 | mu * P * mu = 0.5 * sum_i mu_i sum_j (mu_i - mu_j) = 0.5 * sum_i,j (mu_i - mu_j) ^ 2 226 | where j are neighbors of i 227 | 228 | Note: could probably do with a difference filter, 229 | but the edges would be complicated unless tensorflow allowed for edge copying 230 | """ 231 | vol_shape = y_pred.get_shape().as_list()[1:-1] 232 | ndims = len(vol_shape) 233 | 234 | sm = 0 235 | for i in range(ndims): 236 | d = i + 1 237 | # permute dimensions to put the ith dimension first 238 | r = [d, *range(d), *range(d + 1, ndims + 2)] 239 | y = K.permute_dimensions(y_pred, r) 240 | df = y[1:, ...] - y[:-1, ...] 241 | sm += K.mean(df * df) 242 | 243 | return 0.5 * sm / ndims 244 | 245 | def loss(self, y_true, y_pred): 246 | """ 247 | KL loss 248 | y_pred is assumed to be D*2 channels: first D for mean, next D for logsigma 249 | D (number of dimensions) should be 1, 2 or 3 250 | 251 | y_true is only used to get the shape 252 | """ 253 | 254 | # prepare inputs 255 | ndims = len(y_pred.get_shape()) - 2 256 | mean = y_pred[..., 0:ndims] 257 | log_sigma = y_pred[..., ndims:] 258 | 259 | # compute the degree matrix (only needs to be done once) 260 | # we usually can't compute this until we know the ndims, 261 | # which is a function of the data 262 | if self.D is None: 263 | self.D = self._degree_matrix(self.flow_vol_shape) 264 | 265 | # sigma terms 266 | sigma_term = self.prior_lambda * self.D * tf.exp(log_sigma) - log_sigma 267 | sigma_term = K.mean(sigma_term) 268 | 269 | # precision terms 270 | # note needs 0.5 twice, one here (inside self.prec_loss), one below 271 | prec_term = self.prior_lambda * self.prec_loss(mean) 272 | 273 | # combine terms 274 | return 0.5 * ndims * (sigma_term + prec_term) # ndims because we averaged over dimensions as well 275 | 276 | 277 | class NMI: 278 | 279 | def __init__(self, bin_centers, vol_size, sigma_ratio=0.5, max_clip=1, local=False, crop_background=False, patch_size=1): 280 | """ 281 | Mutual information loss for image-image pairs. 282 | Author: Courtney Guo 283 | 284 | If you use this loss function, please cite the following: 285 | 286 | Guo, Courtney K. Multi-modal image registration with unsupervised deep learning. MEng. Thesis 287 | 288 | Unsupervised Learning of Probabilistic Diffeomorphic Registration for Images and Surfaces 289 | Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu 290 | MedIA: Medial Image Analysis. 2019. 
eprint arXiv:1903.03545 291 | """ 292 | print("vxm info: mutual information loss is experimental", file=sys.stderr) 293 | self.vol_size = vol_size 294 | self.max_clip = max_clip 295 | self.patch_size = patch_size 296 | self.crop_background = crop_background 297 | self.mi = self.local_mi if local else self.global_mi 298 | self.vol_bin_centers = K.variable(bin_centers) 299 | self.num_bins = len(bin_centers) 300 | self.sigma = np.mean(np.diff(bin_centers)) * sigma_ratio 301 | self.preterm = K.variable(1 / (2 * np.square(self.sigma))) 302 | 303 | def local_mi(self, y_true, y_pred): 304 | # reshape bin centers to be (1, 1, B) 305 | o = [1, 1, 1, 1, self.num_bins] 306 | vbc = K.reshape(self.vol_bin_centers, o) 307 | 308 | # compute padding sizes 309 | patch_size = self.patch_size 310 | x, y, z = self.vol_size 311 | x_r = -x % patch_size 312 | y_r = -y % patch_size 313 | z_r = -z % patch_size 314 | pad_dims = [[0,0]] 315 | pad_dims.append([x_r//2, x_r - x_r//2]) 316 | pad_dims.append([y_r//2, y_r - y_r//2]) 317 | pad_dims.append([z_r//2, z_r - z_r//2]) 318 | pad_dims.append([0,0]) 319 | padding = tf.constant(pad_dims) 320 | print(padding,'&'*30) 321 | # compute image terms 322 | # num channels of y_true and y_pred must be 1 323 | I_a = K.exp(- self.preterm * K.square(tf.pad(y_true, padding, 'CONSTANT') - vbc)) 324 | I_a /= K.sum(I_a, -1, keepdims=True) 325 | print(I_a.shape,'*'*30) 326 | I_b = K.exp(- self.preterm * K.square(tf.pad(y_pred, padding, 'CONSTANT') - vbc)) 327 | I_b /= K.sum(I_b, -1, keepdims=True) 328 | 329 | I_a_patch = tf.reshape(I_a, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, self.num_bins]) 330 | I_a_patch = tf.transpose(I_a_patch, [0, 2, 4, 1, 3, 5, 6]) 331 | I_a_patch = tf.reshape(I_a_patch, [-1, patch_size**3, self.num_bins]) 332 | 333 | I_b_patch = tf.reshape(I_b, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, self.num_bins]) 334 | I_b_patch = tf.transpose(I_b_patch, [0, 2, 4, 1, 3, 5, 6]) 335 | I_b_patch = tf.reshape(I_b_patch, [-1, patch_size**3, self.num_bins]) 336 | 337 | # compute probabilities 338 | I_a_permute = K.permute_dimensions(I_a_patch, (0,2,1)) 339 | pab = K.batch_dot(I_a_permute, I_b_patch) # should be the right size now, nb_labels x nb_bins 340 | pab /= patch_size**3 341 | pa = tf.reduce_mean(I_a_patch, 1, keepdims=True) 342 | pb = tf.reduce_mean(I_b_patch, 1, keepdims=True) 343 | 344 | papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon() 345 | return K.mean(K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1)) 346 | 347 | def global_mi(self, y_true, y_pred): 348 | if self.crop_background: 349 | # does not support variable batch size 350 | thresh = 0.0001 351 | padding_size = 20 352 | filt = tf.ones([padding_size, padding_size, padding_size, 1, 1]) 353 | 354 | smooth = tf.nn.conv3d(y_true, filt, [1, 1, 1, 1, 1], "SAME") 355 | mask = smooth > thresh 356 | # mask = K.any(K.stack([y_true > thresh, y_pred > thresh], axis=0), axis=0) 357 | y_pred = tf.boolean_mask(y_pred, mask) 358 | y_true = tf.boolean_mask(y_true, mask) 359 | y_pred = K.expand_dims(K.expand_dims(y_pred, 0), 2) 360 | y_true = K.expand_dims(K.expand_dims(y_true, 0), 2) 361 | 362 | else: 363 | # reshape: flatten images into shape (batch_size, heightxwidthxdepthxchan, 1) 364 | y_true = K.reshape(y_true, (-1, K.prod(K.shape(y_true)[1:]))) 365 | y_true = K.expand_dims(y_true, 2) 366 | y_pred = K.reshape(y_pred, (-1, K.prod(K.shape(y_pred)[1:]))) 367 | y_pred = 
K.expand_dims(y_pred, 2) 368 | 369 | nb_voxels = tf.cast(K.shape(y_pred)[1], tf.float32) 370 | 371 | # reshape bin centers to be (1, 1, B) 372 | o = [1, 1, np.prod(self.vol_bin_centers.get_shape().as_list())] 373 | vbc = K.reshape(self.vol_bin_centers, o) 374 | 375 | # compute image terms 376 | I_a = K.exp(- self.preterm * K.square(y_true - vbc)) 377 | I_a /= K.sum(I_a, -1, keepdims=True) 378 | 379 | I_b = K.exp(- self.preterm * K.square(y_pred - vbc)) 380 | I_b /= K.sum(I_b, -1, keepdims=True) 381 | 382 | # compute probabilities 383 | I_a_permute = K.permute_dimensions(I_a, (0,2,1)) 384 | pab = K.batch_dot(I_a_permute, I_b) # should be the right size now, nb_labels x nb_bins 385 | pab /= nb_voxels 386 | pa = tf.reduce_mean(I_a, 1, keepdims=True) 387 | pb = tf.reduce_mean(I_b, 1, keepdims=True) 388 | 389 | papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon() 390 | return K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1) 391 | 392 | def loss(self, y_true, y_pred): 393 | y_pred = K.clip(y_pred, 0, self.max_clip) 394 | y_true = K.clip(y_true, 0, self.max_clip) 395 | return 1-self.mi(y_true, y_pred) 396 | 397 | 398 | class LossTuner: 399 | """ 400 | Simple utility to apply a tuning weight to a loss tensor. 401 | """ 402 | 403 | def __init__(self, loss_func, weight_tensor): 404 | self.weight = weight_tensor 405 | self.loss_func = loss_func 406 | 407 | def loss(self, y_true, y_pred): 408 | return self.weight * self.loss_func(y_true, y_pred) 409 | 410 | -------------------------------------------------------------------------------- /network/recursive_cascaded_networks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tflearn 3 | import numpy as np 4 | import tensorflow.keras.backend as K 5 | from .utils import Network 6 | from .base_networks import VTNAffineStem,DUAL 7 | from .spatial_transformer import Dense3DSpatialTransformer, Fast3DTransformer 8 | from .trilinear_sampler import TrilinearSampler 9 | from .losses import NCC 10 | from keras.layers.convolutional import AveragePooling3D 11 | from keras.layers import UpSampling3D 12 | 13 | #MutualInformation,Dice, 14 | def mask_metrics(seg1, seg2): 15 | ''' Given two segmentation seg1, seg2, 0 for background 255 for foreground. 
16 | Calculate the Dice score 17 | $ 2 * | seg1 \cap seg2 | / (|seg1| + |seg2|) $ 18 | and the Jacc score 19 | $ | seg1 \cap seg2 | / (|seg1 \cup seg2|) $ 20 | ''' 21 | sizes = np.prod(seg1.shape.as_list()[1:]) 22 | seg1 = tf.reshape(seg1, [-1, sizes]) 23 | seg2 = tf.reshape(seg2, [-1, sizes]) 24 | seg1 = tf.cast(seg1 > 128, tf.float32) 25 | seg2 = tf.cast(seg2 > 128, tf.float32) 26 | dice_score = 2.0 * tf.reduce_sum(seg1 * seg2, axis=-1) / ( 27 | tf.reduce_sum(seg1, axis=-1) + tf.reduce_sum(seg2, axis=-1)) 28 | union = tf.reduce_sum(tf.maximum(seg1, seg2), axis=-1) 29 | return (dice_score, tf.reduce_sum(tf.minimum(seg1, seg2), axis=-1) / tf.maximum(0.01, union)) 30 | 31 | 32 | class RecursiveCascadedNetworks(Network): 33 | default_params = { 34 | 'weight': 1, 35 | 'raw_weight': 1, 36 | 'reg_weight': 1, 37 | } 38 | 39 | def __init__(self, name, framework, 40 | base_network, n_cascades, rep=1, 41 | det_factor=0.1, ortho_factor=0.1, reg_factor=1.0, 42 | extra_losses={}, warp_gradient=True, 43 | fast_reconstruction=False, warp_padding=False, 44 | **kwargs): 45 | super().__init__(name) 46 | self.det_factor = det_factor 47 | self.ortho_factor = ortho_factor 48 | self.reg_factor = reg_factor 49 | 50 | self.base_network = eval(base_network) 51 | self.stems = [(VTNAffineStem('affine_stem', trainable=True), {'raw_weight': 0, 'reg_weight': 0})] + sum([ 52 | [(self.base_network("deform_stem_" + str(i), 53 | flow_multiplier=1.0 / n_cascades), {'raw_weight': 0})] * rep 54 | for i in range(n_cascades)], []) 55 | self.stems[-1][1]['raw_weight'] = 1 56 | 57 | for _, param in self.stems: 58 | for k, v in self.default_params.items(): 59 | if k not in param: 60 | param[k] = v 61 | print(self.stems) 62 | 63 | self.framework = framework 64 | self.warp_gradient = warp_gradient 65 | self.fast_reconstruction = fast_reconstruction 66 | 67 | self.reconstruction = Fast3DTransformer( 68 | warp_padding) if fast_reconstruction else Dense3DSpatialTransformer(warp_padding) 69 | self.trilinear_sampler = TrilinearSampler() 70 | 71 | @property 72 | def trainable_variables(self): 73 | return list(set(sum([stem.trainable_variables for stem, params in self.stems], []))) 74 | 75 | @property 76 | def data_args(self): 77 | return dict() 78 | 79 | def build(self, img1, img2, seg1, seg2, pt1, pt2): 80 | stem_results = [] 81 | 82 | stem_result = self.stems[0][0](img1, img2) 83 | stem_result['warped'] = self.reconstruction( 84 | [img2, stem_result['flow']]) 85 | stem_result['agg_flow'] = stem_result['flow'] 86 | 87 | 88 | stem_results.append(stem_result) 89 | 90 | for stem, params in self.stems[1:]: 91 | if self.warp_gradient: 92 | stem_result = stem(img1, stem_results[-1]['warped']) 93 | else: 94 | stem_result = stem(img1, tf.stop_gradient( 95 | stem_results[-1]['warped'])) 96 | 97 | if len(stem_results) == 1 and 'W' in stem_results[-1]: 98 | I = tf.constant([1, 0, 0, 0, 1, 0, 0, 0, 1], 99 | tf.float32, [1, 3, 3]) 100 | stem_result['agg_flow'] = tf.einsum( 101 | 'bij,bxyzj->bxyzi', stem_results[-1]['W'] + I, stem_result['flow']) + stem_results[-1]['flow'] 102 | 103 | else: 104 | stem_result['agg_flow'] = self.reconstruction( 105 | [stem_results[-1]['agg_flow'], stem_result['flow']]) + stem_result['flow'] 106 | 107 | stem_result['warped'] = self.reconstruction( 108 | [img2, stem_result['agg_flow']]) 109 | 110 | 111 | stem_results.append(stem_result) 112 | 113 | 114 | def GetSimilarityLoss(img_fixed,img_moving): 115 | NCC_loss = NCC(win = [8,8,8]) 116 | ncc_loss = NCC_loss.loss(img_fixed,img_moving) 117 | return ncc_loss 118 | 119 | # 
unsupervised learning with simlarity loss and regularization loss 120 | for stem_result, (stem, params) in zip(stem_results, self.stems): 121 | if 'W' in stem_result: 122 | stem_result['loss'] = stem_result['det_loss'] * \ 123 | self.det_factor + \ 124 | stem_result['ortho_loss'] * self.ortho_factor 125 | if params['raw_weight'] > 0: 126 | stem_result['raw_loss'] = self.similarity_loss( 127 | img1, stem_result['warped']) 128 | stem_result['loss'] = stem_result['loss'] + \ 129 | stem_result['raw_loss'] * params['raw_weight'] 130 | else: 131 | if params['raw_weight'] > 0: 132 | stem_result['raw_loss'] = self.similarity_loss(img1,stem_result['warped']) 133 | 134 | if params['reg_weight'] > 0: 135 | stem_result['reg_loss'] = self.regularize_loss( 136 | stem_result['flow']) * self.reg_factor 137 | 138 | stem_result['loss'] = sum( 139 | [stem_result[k] * params[k.replace('loss', 'weight')] for k in stem_result if k.endswith('loss')]) 140 | 141 | ret = {} 142 | 143 | flow = stem_results[-1]['agg_flow'] 144 | warped = stem_results[-1]['warped'] 145 | mean_mse = tf.reduce_mean(tf.abs(warped*255-img1*255)) 146 | mean_ncc = self.cal_ncc(img1,warped) 147 | jacobian_det = self.jacobian_det(flow) 148 | loss = sum([r['loss'] * params['weight'] 149 | for r, (stem, params) in zip(stem_results, self.stems)]) 150 | 151 | pt_mask1 = tf.reduce_any(tf.reduce_any(pt1 >= 0, -1), -1) 152 | pt_mask2 = tf.reduce_any(tf.reduce_any(pt2 >= 0, -1), -1) 153 | pt1 = tf.maximum(pt1, 0.0) 154 | 155 | moving_pt1 = pt1 + self.trilinear_sampler([flow, pt1]) 156 | 157 | pt_mask = tf.cast(pt_mask1, tf.float32) * tf.cast(pt_mask2, tf.float32) 158 | landmark_dists = tf.sqrt(tf.reduce_sum( 159 | (moving_pt1 - pt2) ** 2, axis=-1)) * tf.expand_dims(pt_mask, axis=-1) 160 | landmark_dist = tf.reduce_mean(landmark_dists, axis=-1) 161 | 162 | if self.framework.segmentation_class_value is None: 163 | seg_fixed = seg1 164 | warped_seg_moving = self.reconstruction([seg2, flow]) 165 | dice_score, jacc_score = mask_metrics(seg_fixed, warped_seg_moving) 166 | jaccs = [jacc_score] 167 | dices = [dice_score] 168 | else: 169 | def mask_class(seg, value): 170 | return tf.cast(tf.abs(seg - value) < 0.5, tf.float32) * 255 171 | jaccs = [] 172 | dices = [] 173 | fixed_segs = [] 174 | warped_segs = [] 175 | for k, v in self.framework.segmentation_class_value.items(): 176 | 177 | fixed_seg_class = mask_class(seg1, v) 178 | warped_seg_class = self.reconstruction( 179 | [mask_class(seg2, v), flow]) 180 | class_dice, class_jacc = mask_metrics( 181 | fixed_seg_class, warped_seg_class) 182 | ret['jacc_{}'.format(k)] = class_jacc 183 | jaccs.append(class_jacc) 184 | dices.append(class_dice) 185 | fixed_segs.append(fixed_seg_class[...,0]) 186 | warped_segs.append(warped_seg_class[...,0]) 187 | seg_fixed = tf.stack(fixed_segs, axis=-1) 188 | warped_seg_moving = tf.stack(warped_segs, axis=-1) 189 | dice_score, jacc_score = tf.add_n( 190 | dices) / len(dices), tf.add_n(jaccs) / len(jaccs) 191 | 192 | ret.update({'loss': tf.reshape(loss, (1, )), 193 | 'dice_score': dice_score, 194 | 'jacc_score': jacc_score, 195 | 'dices': tf.stack(dices, axis=-1), 196 | 'jaccs': tf.stack(jaccs, axis=-1), 197 | 'landmark_dist': landmark_dist, 198 | 'landmark_dists': landmark_dists, 199 | 'total_ncc':mean_ncc, 200 | 'total_mse':mean_mse, 201 | 'real_flow': flow, 202 | 'pt_mask': pt_mask, 203 | 'moving_img':img2*255, 204 | 'reconstruction': warped * 255.0, 205 | 'image_reconstruct': warped, 206 | 'warped_moving': warped * 255.0, 207 | 'seg_fixed': seg_fixed, 208 | 
'warped_seg_moving': warped_seg_moving, 209 | 'image_fixed': img1*255, 210 | 'moving_pt': moving_pt1, 211 | 'jacobian_det': jacobian_det}}) 212 | 213 | for i, r in enumerate(stem_results): 214 | for k in r: 215 | if k.endswith('loss'): 216 | ret['{}_{}'.format(i, k)] = r[k] 217 | ret['warped_seg_moving_%d' % 218 | i] = self.reconstruction([seg2, r['agg_flow']]) 219 | ret['warped_moving_%d' % i] = r['warped'] 220 | ret['flow_%d' % i] = r['flow'] 221 | ret['real_flow_%d' % i] = r['agg_flow'] 222 | 223 | return ret 224 | 225 | def similarity_loss(self, img1, warped_img2): 226 | sizes = np.prod(img1.shape.as_list()[1:]) 227 | flatten1 = tf.reshape(img1, [-1, sizes]) 228 | flatten2 = tf.reshape(warped_img2, [-1, sizes]) 229 | 230 | if self.fast_reconstruction: 231 | _, pearson_r, _ = tf.user_ops.linear_similarity(flatten1, flatten2) 232 | else: 233 | mean1 = tf.reshape(tf.reduce_mean(flatten1, axis=-1), [-1, 1]) 234 | mean2 = tf.reshape(tf.reduce_mean(flatten2, axis=-1), [-1, 1]) 235 | var1 = tf.reduce_mean(tf.square(flatten1 - mean1), axis=-1) 236 | var2 = tf.reduce_mean(tf.square(flatten2 - mean2), axis=-1) 237 | cov12 = tf.reduce_mean( 238 | (flatten1 - mean1) * (flatten2 - mean2), axis=-1) 239 | pearson_r = cov12 / tf.sqrt((var1 + 1e-6) * (var2 + 1e-6)) 240 | 241 | raw_loss = 1 - pearson_r 242 | raw_loss = tf.reduce_sum(raw_loss) 243 | return raw_loss 244 | 245 | def cal_ncc(self, img1, warped_img2): 246 | sizes = np.prod(img1.shape.as_list()[1:]) 247 | flatten1 = tf.reshape(img1, [-1, sizes]) 248 | flatten2 = tf.reshape(warped_img2, [-1, sizes]) 249 | 250 | if self.fast_reconstruction: 251 | _, pearson_r, _ = tf.user_ops.linear_similarity(flatten1, flatten2) 252 | else: 253 | mean1 = tf.reshape(tf.reduce_mean(flatten1, axis=-1), [-1, 1]) 254 | mean2 = tf.reshape(tf.reduce_mean(flatten2, axis=-1), [-1, 1]) 255 | var1 = tf.reduce_mean(tf.square(flatten1 - mean1), axis=-1) 256 | var2 = tf.reduce_mean(tf.square(flatten2 - mean2), axis=-1) 257 | cov12 = tf.reduce_mean( 258 | (flatten1 - mean1) * (flatten2 - mean2), axis=-1) 259 | pearson_r = cov12 / tf.sqrt((var1 + 1e-6) * (var2 + 1e-6)) 260 | ncc = tf.reduce_sum(pearson_r) 261 | return ncc 262 | 263 | def regularize_loss(self, flow): 264 | ret = ((tf.nn.l2_loss(flow[:, 1:, :, :] - flow[:, :-1, :, :]) + 265 | tf.nn.l2_loss(flow[:, :, 1:, :] - flow[:, :, :-1, :]) + 266 | tf.nn.l2_loss(flow[:, :, :, 1:] - flow[:, :, :, :-1])) / np.prod(flow.shape.as_list()[1:5])) 267 | return ret 268 | 269 | def jacobian_det(self, flow): 270 | _, var = tf.nn.moments(tf.linalg.det(tf.stack([ 271 | flow[:, 1:, :-1, :-1] - flow[:, :-1, :-1, :-1] + 272 | tf.constant([1, 0, 0], dtype=tf.float32), 273 | flow[:, :-1, 1:, :-1] - flow[:, :-1, :-1, :-1] + 274 | tf.constant([0, 1, 0], dtype=tf.float32), 275 | flow[:, :-1, :-1, 1:] - flow[:, :-1, :-1, :-1] + 276 | tf.constant([0, 0, 1], dtype=tf.float32) 277 | ], axis=-1)), axes=[1, 2, 3]) 278 | return tf.sqrt(var) #return the std of the jacb det 279 | 280 | -------------------------------------------------------------------------------- /network/spatial_transformer.py: -------------------------------------------------------------------------------- 1 | from keras.layers.core import Layer 2 | import tensorflow as tf 3 | 4 | 5 | class Fast3DTransformer(Layer): 6 | def __init__(self, padding = False, **kwargs): 7 | super().__init__(**kwargs) 8 | self.padding = padding 9 | 10 | def build(self, input_shape): 11 | if len(input_shape) > 3: 12 | raise Exception('Spatial Transformer must be called on a list of length 2 or 3. 
' 13 | 'First argument is the image, second is the offset field.') 14 | 15 | if len(input_shape[1]) != 5 or input_shape[1][4] != 3: 16 | raise Exception('Offset field must be one 5D tensor with 3 channels. ' 17 | 'Got: ' + str(input_shape[1])) 18 | 19 | self.built = True 20 | 21 | def call(self, inputs): 22 | im, flow = inputs 23 | if self.padding: 24 | im = tf.pad(im, [[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]], "CONSTANT") 25 | flow = tf.pad(flow, [[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]], "CONSTANT") 26 | warped = tf.user_ops.reconstruction(im, flow) 27 | if self.padding: 28 | warped = warped[:, 1: -1, 1: -1, 1: -1, :] 29 | return warped 30 | 31 | def compute_output_shape(self, input_shape): 32 | return input_shape[0] 33 | 34 | 35 | # We acknowledge VoxelMorph for providing an alternative implementation of the 3D spatial trasnformer. 36 | # Modified from https://github.com/voxelmorph/voxelmorph 37 | 38 | class Dense3DSpatialTransformer(Layer): 39 | def __init__(self, padding = False, **kwargs): 40 | self.padding = padding 41 | super(Dense3DSpatialTransformer, self).__init__(**kwargs) 42 | 43 | def build(self, input_shape): 44 | if len(input_shape) > 3: 45 | raise Exception('Spatial Transformer must be called on a list of length 2 or 3. ' 46 | 'First argument is the image, second is the offset field.') 47 | 48 | if len(input_shape[1]) != 5 or input_shape[1][4] != 3: 49 | raise Exception('Offset field must be one 5D tensor with 3 channels. ' 50 | 'Got: ' + str(input_shape[1])) 51 | 52 | self.built = True 53 | 54 | def call(self, inputs): 55 | return self._transform(inputs[0], inputs[1][:, :, :, :, 1], 56 | inputs[1][:, :, :, :, 0], inputs[1][:, :, :, :, 2]) 57 | 58 | def compute_output_shape(self, input_shape): 59 | return input_shape[0] 60 | 61 | def _transform(self, I, dx, dy, dz): 62 | 63 | batch_size = tf.shape(dx)[0] 64 | height = tf.shape(dx)[1] 65 | width = tf.shape(dx)[2] 66 | depth = tf.shape(dx)[3] 67 | 68 | # Convert dx and dy to absolute locations 69 | x_mesh, y_mesh, z_mesh = self._meshgrid(height, width, depth) 70 | x_mesh = tf.expand_dims(x_mesh, 0) 71 | y_mesh = tf.expand_dims(y_mesh, 0) 72 | z_mesh = tf.expand_dims(z_mesh, 0) 73 | 74 | x_mesh = tf.tile(x_mesh, [batch_size, 1, 1, 1]) 75 | y_mesh = tf.tile(y_mesh, [batch_size, 1, 1, 1]) 76 | z_mesh = tf.tile(z_mesh, [batch_size, 1, 1, 1]) 77 | x_new = dx + x_mesh 78 | y_new = dy + y_mesh 79 | z_new = dz + z_mesh 80 | 81 | return self._interpolate(I, x_new, y_new, z_new) 82 | 83 | def _repeat(self, x, n_repeats): 84 | rep = tf.transpose( 85 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0]) 86 | rep = tf.cast(rep, dtype='int32') 87 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 88 | return tf.reshape(x, [-1]) 89 | 90 | def _meshgrid(self, height, width, depth): 91 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])), 92 | tf.transpose(tf.expand_dims(tf.linspace(0.0, 93 | tf.cast(width, tf.float32)-1.0, width), 1), [1, 0])) 94 | y_t = tf.matmul(tf.expand_dims(tf.linspace(0.0, 95 | tf.cast(height, tf.float32)-1.0, height), 1), 96 | tf.ones(shape=tf.stack([1, width]))) 97 | 98 | x_t = tf.tile(tf.expand_dims(x_t, 2), [1, 1, depth]) 99 | y_t = tf.tile(tf.expand_dims(y_t, 2), [1, 1, depth]) 100 | 101 | z_t = tf.linspace(0.0, tf.cast(depth, tf.float32)-1.0, depth) 102 | z_t = tf.expand_dims(tf.expand_dims(z_t, 0), 0) 103 | z_t = tf.tile(z_t, [height, width, 1]) 104 | 105 | return x_t, y_t, z_t 106 | 107 | def _interpolate(self, im, x, y, z): 108 | if self.padding: 109 | im = tf.pad(im, [[0, 0], [1, 1], [1, 1], [1, 1], 
[0, 0]], "CONSTANT") 110 | 111 | num_batch = tf.shape(im)[0] 112 | height = tf.shape(im)[1] 113 | width = tf.shape(im)[2] 114 | depth = tf.shape(im)[3] 115 | channels = im.get_shape().as_list()[4] 116 | 117 | out_height = tf.shape(x)[1] 118 | out_width = tf.shape(x)[2] 119 | out_depth = tf.shape(x)[3] 120 | 121 | x = tf.reshape(x, [-1]) 122 | y = tf.reshape(y, [-1]) 123 | z = tf.reshape(z, [-1]) 124 | 125 | padding_constant = 1 if self.padding else 0 126 | x = tf.cast(x, 'float32') + padding_constant 127 | y = tf.cast(y, 'float32') + padding_constant 128 | z = tf.cast(z, 'float32') + padding_constant 129 | 130 | max_x = tf.cast(width - 1, 'int32') 131 | max_y = tf.cast(height - 1, 'int32') 132 | max_z = tf.cast(depth - 1, 'int32') 133 | 134 | x0 = tf.cast(tf.floor(x), 'int32') 135 | x1 = x0 + 1 136 | y0 = tf.cast(tf.floor(y), 'int32') 137 | y1 = y0 + 1 138 | z0 = tf.cast(tf.floor(z), 'int32') 139 | z1 = z0 + 1 140 | 141 | x0 = tf.clip_by_value(x0, 0, max_x) 142 | x1 = tf.clip_by_value(x1, 0, max_x) 143 | y0 = tf.clip_by_value(y0, 0, max_y) 144 | y1 = tf.clip_by_value(y1, 0, max_y) 145 | z0 = tf.clip_by_value(z0, 0, max_z) 146 | z1 = tf.clip_by_value(z1, 0, max_z) 147 | 148 | dim3 = depth 149 | dim2 = depth*width 150 | dim1 = depth*width*height 151 | base = self._repeat(tf.range(num_batch)*dim1, 152 | out_height*out_width*out_depth) 153 | 154 | base_y0 = base + y0*dim2 155 | base_y1 = base + y1*dim2 156 | 157 | idx_a = base_y0 + x0*dim3 + z0 158 | idx_b = base_y1 + x0*dim3 + z0 159 | idx_c = base_y0 + x1*dim3 + z0 160 | idx_d = base_y1 + x1*dim3 + z0 161 | idx_e = base_y0 + x0*dim3 + z1 162 | idx_f = base_y1 + x0*dim3 + z1 163 | idx_g = base_y0 + x1*dim3 + z1 164 | idx_h = base_y1 + x1*dim3 + z1 165 | 166 | # use indices to lookup pixels in the flat image and restore 167 | # channels dim 168 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 169 | im_flat = tf.cast(im_flat, 'float32') 170 | 171 | Ia = tf.gather(im_flat, idx_a) 172 | Ib = tf.gather(im_flat, idx_b) 173 | Ic = tf.gather(im_flat, idx_c) 174 | Id = tf.gather(im_flat, idx_d) 175 | Ie = tf.gather(im_flat, idx_e) 176 | If = tf.gather(im_flat, idx_f) 177 | Ig = tf.gather(im_flat, idx_g) 178 | Ih = tf.gather(im_flat, idx_h) 179 | 180 | # and finally calculate interpolated values 181 | x1_f = tf.cast(x1, 'float32') 182 | y1_f = tf.cast(y1, 'float32') 183 | z1_f = tf.cast(z1, 'float32') 184 | 185 | dx = x1_f - x 186 | dy = y1_f - y 187 | dz = z1_f - z 188 | 189 | wa = tf.expand_dims((dz * dx * dy), 1) 190 | wb = tf.expand_dims((dz * dx * (1-dy)), 1) 191 | wc = tf.expand_dims((dz * (1-dx) * dy), 1) 192 | wd = tf.expand_dims((dz * (1-dx) * (1-dy)), 1) 193 | we = tf.expand_dims(((1-dz) * dx * dy), 1) 194 | wf = tf.expand_dims(((1-dz) * dx * (1-dy)), 1) 195 | wg = tf.expand_dims(((1-dz) * (1-dx) * dy), 1) 196 | wh = tf.expand_dims(((1-dz) * (1-dx) * (1-dy)), 1) 197 | 198 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id, 199 | we*Ie, wf*If, wg*Ig, wh*Ih]) 200 | output = tf.reshape(output, tf.stack( 201 | [-1, out_height, out_width, out_depth, channels])) 202 | return output -------------------------------------------------------------------------------- /network/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | import tflearn 4 | import numpy as np 5 | 6 | 7 | def get_coef(u): 8 | return tf.stack([((1 - u) ** 3) / 6, (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6, 9 | (-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6, (u ** 3) / 6], axis=1) 
10 | 11 | 12 | def sample_power(lo, hi, k, size=None): 13 | r = (hi - lo) / 2 14 | center = (hi + lo) / 2 15 | r = r ** (1 / k) 16 | points = (tf.random_uniform(size, dtype=tf.float32) - 0.5) * 2 * r 17 | points = (tf.abs(points) ** k) * tf.sign(points) 18 | return points + center 19 | 20 | 21 | def pad_3d(mat, pad): 22 | return tf.pad(mat, [[0, 0], [pad, pad], [pad, pad], [pad, pad], [0, 0]], "CONSTANT") 23 | 24 | 25 | def resize_linear(target_shape, control_fields): 26 | _, n, m, t, _ = control_fields.shape.as_list() 27 | assert n == m and m == t 28 | assert target_shape % n == 0 29 | factor = target_shape // n 30 | ret_n = target_shape 31 | 32 | def interpolate_axis(arr): 33 | ret = arr 34 | expand_shape = ret.shape.as_list() 35 | expand_shape[3] = (expand_shape[3] + 1) * factor 36 | expand_points_l = tf.reshape( 37 | tf.tile(tf.expand_dims( 38 | tf.concat([ret[:, :, :, 0:1, :], ret[:, :, :, :-1, :], ret[:, :, :, -2:-1, :]], axis=3), 4), 39 | [1, 1, 1, 1, factor, 1]), [-1] + list(expand_shape[1:])) 40 | expand_points_r = tf.reshape( 41 | tf.tile(tf.expand_dims( 42 | tf.concat([ret[:, :, :, 1:2, :], ret[:, :, :, 1:, :], ret[:, :, :, -1:, :]], axis=3), 4), 43 | [1, 1, 1, 1, factor, 1]), [-1] + list(expand_shape[1:])) 44 | d = factor // 2 45 | u = np.zeros(ret_n) 46 | for i in range(ret_n): 47 | p = min(max(0, (i - d) // factor), n - 2) 48 | u[i] = (p + 1.5) - (i + 0.5) / factor 49 | u = u.reshape((1, 1, 1, ret_n, 1)) 50 | ret = u * expand_points_l[:, :, :, d: d + ret_n] + \ 51 | (1 - u) * expand_points_r[:, :, :, d: d + ret_n] 52 | return ret 53 | 54 | ret = interpolate_axis(tf.transpose(control_fields, [0, 1, 3, 2, 4])) 55 | ret = interpolate_axis(tf.transpose(ret, [0, 3, 2, 1, 4])) 56 | ret = interpolate_axis(tf.transpose(ret, [0, 3, 1, 2, 4])) 57 | return ret 58 | 59 | 60 | def meshgrids(shape, flatten=True, name=None): 61 | with tf.name_scope(name, "meshgrid", [shape]): 62 | indices_x = tf.range(0, shape[1]) 63 | indices_y = tf.range(0, shape[2]) 64 | indices_z = tf.range(0, shape[3]) 65 | indices = tf.stack(tf.meshgrid(indices_x, indices_y, 66 | indices_z, indexing='ij'), axis=-1) 67 | indices = tf.tile(tf.expand_dims(indices, axis=0), 68 | tf.stack([shape[0], 1, 1, 1, 1])) 69 | indices = tf.cast(indices, tf.float32) 70 | if flatten: 71 | return tf.reshape(indices, tf.stack([shape[0], shape[1] * shape[2] * shape[3], 3])) 72 | else: 73 | return indices 74 | 75 | 76 | def meshgrids_like(tensor, flatten=True, name=None): 77 | return meshgrids(tf.shape(tensor), flatten, name) 78 | 79 | 80 | def warp_points(flow_fields, pts): 81 | ''' 82 | Arguments 83 | ---------------- 84 | flow_fields : [batch, X, Y, Z, 3] 85 | pts: [batch, 6, 3] 86 | ''' 87 | moving_pts = meshgrids_like(flow_fields, flatten=False) + flow_fields 88 | shape = tf.shape(flow_fields) 89 | moving_pts = tf.reshape(moving_pts, tf.stack( 90 | [shape[0], shape[1] * shape[2] * shape[3], 1, 3])) 91 | distance = tf.sqrt(tf.reduce_sum( 92 | (moving_pts - tf.expand_dims(pts, axis=1)) ** 2, axis=-1)) 93 | closest = tf.cast(tf.argmin(distance, axis=1), tf.int32) 94 | fixed_pts = tf.stack([tf.div(closest, shape[2] * shape[3]), tf.mod( 95 | tf.div(closest, shape[3]), shape[2]), tf.mod(closest, shape[3])], axis=2) 96 | return fixed_pts 97 | 98 | 99 | def free_form_fields(shape, control_fields, padding='same'): 100 | '''Calculate a flow fields based on 3-order B-Spline interpolation of control points. 
101 | 102 | Arguments 103 | -------------- 104 | shape : list of 3 integers, flow field shape `(x, y, z)` 105 | control_fields : 5d tensor with 3 channels `(batch_size, n, m, t, 3)` 106 | 107 | Output 108 | -------------- 109 | 5d tensor with 3 channels `(batch_size, x, y, z, 3)` 110 | ''' 111 | interpolate_range = 4 112 | 113 | control_fields = tf.convert_to_tensor(control_fields, dtype=tf.float32) 114 | _, n, m, t, _ = control_fields.shape.as_list() 115 | if padding == 'same': 116 | control_fields = pad_3d(control_fields, 1) 117 | elif padding == 'valid': 118 | n -= 2 119 | m -= 2 120 | t -= 2 121 | control_fields = tf.reshape(tf.transpose( 122 | control_fields, (1, 2, 3, 0, 4)), [n + 2, m + 2, t + 2, -1]) 123 | 124 | assert shape[0] % (n - 1) == 0 125 | s_x = shape[0] // (n - 1) 126 | u_x = (tf.range(0, s_x, dtype=tf.float32) + 0.5) / s_x # s_x 127 | coef_x = get_coef(u_x) # (s_x, 4) 128 | 129 | shape_cf = control_fields.shape.as_list() 130 | flow = tf.concat([tf.matmul(coef_x, 131 | tf.reshape(control_fields[i: i + interpolate_range], [interpolate_range, -1])) 132 | for i in range(0, n - 1)], 133 | axis=0) 134 | 135 | assert shape[1] % (m - 1) == 0 136 | s_y = shape[1] // (m - 1) 137 | u_y = (tf.range(0, s_y, dtype=tf.float32) + 0.5) / s_y # s_y 138 | coef_y = get_coef(u_y) # (s_y, 4) 139 | 140 | flow = tf.reshape(tf.transpose(flow), [shape_cf[1], -1]) 141 | flow = tf.concat([tf.matmul(coef_y, 142 | tf.reshape(flow[i: i + interpolate_range], [interpolate_range, -1])) 143 | for i in range(0, m - 1)], 144 | axis=0) 145 | # print(flow.shape) 146 | assert shape[2] % (t - 1) == 0 147 | s_z = shape[2] // (t - 1) 148 | u_z = (tf.range(0, s_z, dtype=tf.float32) + 0.5) / s_z # s_y 149 | coef_z = get_coef(u_z) # (s_y, 4) 150 | 151 | flow = tf.reshape(tf.transpose(flow), [shape_cf[2], -1]) 152 | flow = tf.concat([tf.matmul(coef_z, 153 | tf.reshape(flow[i: i + interpolate_range], [interpolate_range, -1])) 154 | for i in range(0, t - 1)], 155 | axis=0) 156 | # print(flow.shape) 157 | flow = tf.reshape(flow, [shape[2], -1, 3, shape[1], shape[0]]) 158 | flow = tf.transpose(flow, [1, 4, 3, 0, 2]) 159 | return flow 160 | -------------------------------------------------------------------------------- /network/trilinear_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/voxelmorph/voxelmorph 2 | 3 | from keras.layers.core import Layer 4 | import tensorflow as tf 5 | 6 | 7 | class TrilinearSampler(Layer): 8 | def __init__(self, **kwargs): 9 | super(TrilinearSampler, self).__init__(**kwargs) 10 | 11 | def build(self, input_shape): 12 | if len(input_shape) > 3: 13 | raise Exception('Spatial Transformer must be called on a list of length 2 or 3. ' 14 | 'First argument is the image, second is the offset field.') 15 | 16 | if len(input_shape[1]) != 3 or input_shape[1][2] != 3: 17 | raise Exception('Offset field must be one 3D tensor with 3 channels. 
' 18 | 'Got: ' + str(input_shape[1])) 19 | 20 | self.built = True 21 | 22 | def call(self, inputs): 23 | return self._interpolate(inputs[0], inputs[1][:, :, 1], inputs[1][:, :, 0], inputs[1][:, :, 2]) 24 | 25 | def compute_output_shape(self, input_shape): 26 | return (input_shape[0][0], input_shape[1][1], input_shape[0][4]) 27 | 28 | def _repeat(self, x, n_repeats): 29 | rep = tf.transpose( 30 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0]) 31 | rep = tf.cast(rep, dtype='int32') 32 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 33 | return tf.reshape(x, [-1]) 34 | 35 | def _interpolate(self, im, x, y, z): 36 | 37 | im = tf.pad(im, [[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]], "CONSTANT") 38 | 39 | num_batch = tf.shape(im)[0] 40 | height = tf.shape(im)[1] 41 | width = tf.shape(im)[2] 42 | depth = tf.shape(im)[3] 43 | channels = tf.shape(im)[4] 44 | 45 | out_size = tf.shape(x)[1] 46 | 47 | x = tf.reshape(x, [-1]) 48 | y = tf.reshape(y, [-1]) 49 | z = tf.reshape(z, [-1]) 50 | 51 | x = tf.cast(x, 'float32')+1 52 | y = tf.cast(y, 'float32')+1 53 | z = tf.cast(z, 'float32')+1 54 | 55 | max_x = tf.cast(width - 1, 'int32') 56 | max_y = tf.cast(height - 1, 'int32') 57 | max_z = tf.cast(depth - 1, 'int32') 58 | 59 | x0 = tf.cast(tf.floor(x), 'int32') 60 | x1 = x0 + 1 61 | y0 = tf.cast(tf.floor(y), 'int32') 62 | y1 = y0 + 1 63 | z0 = tf.cast(tf.floor(z), 'int32') 64 | z1 = z0 + 1 65 | 66 | x0 = tf.clip_by_value(x0, 0, max_x) 67 | x1 = tf.clip_by_value(x1, 0, max_x) 68 | y0 = tf.clip_by_value(y0, 0, max_y) 69 | y1 = tf.clip_by_value(y1, 0, max_y) 70 | z0 = tf.clip_by_value(z0, 0, max_z) 71 | z1 = tf.clip_by_value(z1, 0, max_z) 72 | 73 | dim3 = depth 74 | dim2 = depth*width 75 | dim1 = depth*width*height 76 | base = self._repeat(tf.range(num_batch)*dim1, out_size) 77 | 78 | base_y0 = base + y0*dim2 79 | base_y1 = base + y1*dim2 80 | 81 | idx_a = base_y0 + x0*dim3 + z0 82 | idx_b = base_y1 + x0*dim3 + z0 83 | idx_c = base_y0 + x1*dim3 + z0 84 | idx_d = base_y1 + x1*dim3 + z0 85 | idx_e = base_y0 + x0*dim3 + z1 86 | idx_f = base_y1 + x0*dim3 + z1 87 | idx_g = base_y0 + x1*dim3 + z1 88 | idx_h = base_y1 + x1*dim3 + z1 89 | 90 | # use indices to lookup pixels in the flat image and restore 91 | # channels dim 92 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 93 | im_flat = tf.cast(im_flat, 'float32') 94 | 95 | Ia = tf.gather(im_flat, idx_a) 96 | Ib = tf.gather(im_flat, idx_b) 97 | Ic = tf.gather(im_flat, idx_c) 98 | Id = tf.gather(im_flat, idx_d) 99 | Ie = tf.gather(im_flat, idx_e) 100 | If = tf.gather(im_flat, idx_f) 101 | Ig = tf.gather(im_flat, idx_g) 102 | Ih = tf.gather(im_flat, idx_h) 103 | 104 | # and finally calculate interpolated values 105 | x1_f = tf.cast(x1, 'float32') 106 | y1_f = tf.cast(y1, 'float32') 107 | z1_f = tf.cast(z1, 'float32') 108 | 109 | dx = x1_f - x 110 | dy = y1_f - y 111 | dz = z1_f - z 112 | 113 | wa = tf.expand_dims((dz * dx * dy), 1) 114 | wb = tf.expand_dims((dz * dx * (1-dy)), 1) 115 | wc = tf.expand_dims((dz * (1-dx) * dy), 1) 116 | wd = tf.expand_dims((dz * (1-dx) * (1-dy)), 1) 117 | we = tf.expand_dims(((1-dz) * dx * dy), 1) 118 | wf = tf.expand_dims(((1-dz) * dx * (1-dy)), 1) 119 | wg = tf.expand_dims(((1-dz) * (1-dx) * dy), 1) 120 | wh = tf.expand_dims(((1-dz) * (1-dx) * (1-dy)), 1) 121 | 122 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id, 123 | we*Ie, wf*If, wg*Ig, wh*Ih]) 124 | output = tf.reshape(output, tf.stack( 125 | [-1, out_size, channels])) 126 | return output 127 | 
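Editor's note: `TrilinearSampler` gathers the eight voxel neighbors of each (possibly fractional) point and blends them with trilinear weights; `recursive_cascaded_networks.py` uses it to move landmarks via `pt1 + trilinear_sampler([flow, pt1])`, i.e., it samples the flow field at each landmark location. A minimal single-volume NumPy analogue (a sketch assuming in-bounds points, not part of the repository):

```python
import numpy as np

def trilinear_sample(vol, pts):
    # vol: [H, W, D, C] volume; pts: [N, 3] fractional voxel coords in
    # (y, x, z) order, mirroring the channel order the layer unpacks.
    H, W, D = vol.shape[:3]
    y, x, z = pts[:, 0], pts[:, 1], pts[:, 2]
    y0 = np.clip(np.floor(y).astype(int), 0, H - 2)
    x0 = np.clip(np.floor(x).astype(int), 0, W - 2)
    z0 = np.clip(np.floor(z).astype(int), 0, D - 2)
    dy, dx, dz = y - y0, x - x0, z - z0
    out = 0.0
    for iy in (0, 1):          # blend the 8 corner voxels
        for ix in (0, 1):
            for iz in (0, 1):
                w = ((dy if iy else 1 - dy) *
                     (dx if ix else 1 - dx) *
                     (dz if iz else 1 - dz))
                out = out + w[:, None] * vol[y0 + iy, x0 + ix, z0 + iz]
    return out  # [N, C]

flow = np.random.rand(8, 8, 8, 3).astype(np.float32)
print(trilinear_sample(flow, np.array([[2.5, 3.25, 4.0]])).shape)  # (1, 3)
```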
-------------------------------------------------------------------------------- /network/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from tensorflow.python.platform import flags 5 | from tensorflow.python.platform import app 6 | from tensorflow.python import pywrap_tensorflow 7 | import numpy as np 8 | import tensorflow as tf 9 | import tflearn 10 | import re 11 | import sys 12 | from keras import backend as K 13 | from keras.layers import GlobalAveragePooling3D, GlobalMaxPooling3D, Reshape, Dense, Add, Activation 14 | 15 | 16 | def ReLU(target, name=None): 17 | return tflearn.activations.relu(target) 18 | 19 | 20 | def LeakyReLU(target, alpha=0.1, name=None): 21 | return tflearn.activations.leaky_relu(target, alpha=alpha, name=name) 22 | 23 | def Softmax(target, name=None): 24 | return tflearn.activations.softmax(target) 25 | 26 | def Sigmoid(target, name=None): 27 | return tflearn.activations.sigmoid(target) 28 | 29 | def convolve(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, stddev=1e-2): 30 | return tflearn.layers.conv_2d(inputLayer, outputChannel, kernelSize, strides=stride, 31 | padding='same', activation='linear', bias=True, scope=opName) 32 | 33 | 34 | def convolveReLU(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, stddev=1e-2): 35 | return ReLU(convolve(opName, inputLayer, 36 | inputChannel, outputChannel, 37 | kernelSize, stride, stddev), 38 | opName+'_rectified') 39 | 40 | 41 | def convolveLeakyReLU(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, alpha=0.1, stddev=1e-2): 42 | return LeakyReLU(convolve(opName, inputLayer, 43 | inputChannel, outputChannel, 44 | kernelSize, stride, stddev), 45 | alpha, opName+'_leakilyrectified') 46 | 47 | 48 | def upconvolve(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, targetH, targetW, stddev=1e-2): 49 | return tflearn.layers.conv.conv_2d_transpose(inputLayer, outputChannel, kernelSize, [targetH, targetW], strides=stride, 50 | padding='same', activation='linear', bias=False, scope=opName) 51 | 52 | 53 | def upconvolveReLU(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, targetH, targetW, stddev=1e-2): 54 | return ReLU(upconvolve(opName, inputLayer, 55 | inputChannel, outputChannel, 56 | kernelSize, stride, 57 | targetH, targetW, stddev), 58 | opName+'_rectified') 59 | 60 | 61 | def upconvolveLeakyReLU(opName, inputLayer, inputChannel, outputChannel, kernelSize, stride, targetH, targetW, alpha=0.1, stddev=1e-2): 62 | return LeakyReLU(upconvolve(opName, inputLayer, 63 | inputChannel, outputChannel, 64 | kernelSize, stride, 65 | targetH, targetW, stddev), 66 | alpha, opName+'_rectified') 67 | 68 | 69 | def set_tf_keys(feed_dict, **kwargs): 70 | ret = dict([(k + ':0', v) for k, v in feed_dict.items()]) 71 | ret.update([(k + ':0', v) for k, v in kwargs.items()]) 72 | return ret 73 | 74 | 75 | class Network: 76 | def __init__(self, name, trainable=True, reuse=None): 77 | self._built = reuse 78 | self._name = name 79 | self.trainable = trainable 80 | 81 | @property 82 | def name(self): 83 | return self._name 84 | 85 | def __call__(self, *args, **kwargs): 86 | with tf.variable_scope(self.name, reuse=self._built) as self.scope: 87 | self._built = True 88 | return self.build(*args, **kwargs) 89 | 90 | @property 91 | def trainable_variables(self): 92 | if 
isinstance(self.trainable, str):
93 |             var_list = tf.get_collection(
94 |                 tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope.name)
95 |             return [var for var in var_list if re.fullmatch(self.trainable, var.name)]
96 |         elif self.trainable:
97 |             return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope.name)
98 |         else:
99 |             return []
100 | 
101 |     @property
102 |     def data_args(self):
103 |         return dict()
104 | 
105 | 
106 | class ParallelLayer:
107 |     inputs = {}
108 |     replicated_inputs = None
109 | 
110 | 
111 | class MultiGPUs:
112 |     def __init__(self, num):
113 |         self.num = num
114 | 
115 |     def __call__(self, net, args, opt=None):
116 |         args = [self.reshape(arg) for arg in args]
117 |         results = []
118 |         grads = []
119 |         self.current_device = None
120 |         for i in range(self.num):
121 |             def auto_gpu(opr):
122 |                 # ops without efficient GPU kernels are pinned to the CPU;
123 |                 # everything else is placed on this replica's GPU
124 |                 if opr.type.startswith('Gather') or opr.type in ('L2Loss', 'Pack', 'Gather', 'Tile', 'ReconstructionWrtImageGradient', 'Softmax', 'FloorMod', 'MatMul'):
125 |                     return '/cpu:0'
126 |                 else:
127 |                     return '/gpu:%d' % i
128 |             with tf.device(auto_gpu):
129 |                 self.current_device = i
130 |                 net.controller = self
131 |                 result = net(*[arg[i] for arg in args])
132 |                 results.append(result)
133 |                 if opt is not None:
134 |                     grads.append(opt.compute_gradients(
135 |                         result['loss'], var_list=net.trainable_variables))
136 | 
137 |         # gather the per-replica outputs on a single device
138 |         with tf.device('/gpu:0'):
139 |             concat_result = {}
140 |             for k in results[0]:
141 |                 if len(results[0][k].shape) == 0:
142 |                     concat_result[k] = tf.stack(
143 |                         [result[k] for result in results])
144 |                 else:
145 |                     concat_result[k] = tf.concat(
146 |                         [result[k] for result in results], axis=0)
147 | 
148 |         if grads:
149 |             op = opt.apply_gradients(self.average_gradients(grads))
150 |             return concat_result, op
151 |         else:
152 |             return concat_result
153 | 
154 |     def call(self, net, kwargs):
155 |         if net.replicated_inputs is None:
156 |             with tf.device('/gpu:0'):
157 |                 net.replicated_inputs = dict(
158 |                     [(k, self.reshape(v)) for k, v in net.inputs.items()])
159 |         for k, v in net.replicated_inputs.items():
160 |             kwargs[k] = v[self.current_device]
161 |         return net(**kwargs)
162 | 
163 |     @staticmethod
164 |     def average_gradients(grads):
165 |         ret = []
166 |         for grad_list in zip(*grads):
167 |             grad, var = grad_list[0]
168 |             if grad is None:
169 |                 ret.append((None, var))
170 |             else:
171 |                 # average this variable's gradient over all replicas
172 |                 ret.append(
173 |                     (tf.add_n([grad for grad, _ in grad_list]) / len(grad_list), var))
174 |         return ret
175 | 
176 |     def reshape(self, tensor):
177 |         return tf.reshape(tensor, tf.concat([tf.stack([self.num, -1]), tf.shape(tensor)[1:]], axis=0))
178 | 
179 | 
180 | class FileRestorer:
181 |     def __init__(self, rules=[(r'(.*)', r'\1')]):
182 |         self.rules = rules
183 | 
184 |     def get_targets(self, key):
185 |         targets = []
186 |         for r in self.rules:
187 |             if re.match(r[0], key):
188 |                 targets.append(re.sub(r[0], r[1], key))
189 |         return targets
190 | 
191 |     def restore(self, sess, file_name):
192 |         try:
193 |             reader = pywrap_tensorflow.NewCheckpointReader(file_name)
194 |             var_to_shape_map = reader.get_variable_to_shape_map()
195 |             assign_ops = []
196 |             g = sess.graph
197 |             for key in sorted(var_to_shape_map):
198 |                 for target in self.get_targets(key):
199 |                     var = None
200 |                     try:
201 |                         var = g.get_tensor_by_name(target + ':0')
202 |                         print("restoring: {} ---> {}".format(key, target))
203 |                     except KeyError:
204 |                         print("Ignoring: {} ---> {}".format(key, target))
205 |                     if var is not None:
206 | 
assign_ops.append(
207 |                             tf.assign(var, reader.get_tensor(key)))
208 |             sess.run(assign_ops)
209 |         except Exception as e:  # pylint: disable=broad-except
210 |             print(str(e))  # print the diagnostics below before re-raising
211 |             if "corrupted compressed block contents" in str(e):
212 |                 print("It's likely that your checkpoint file has been compressed "
213 |                       "with SNAPPY.")
214 |             if ("Data loss" in str(e) and
215 |                     any([ext in file_name for ext in [".index", ".meta", ".data"]])):
216 |                 proposed_file = ".".join(file_name.split(".")[0:-1])
217 |                 v2_file_error_template = """
218 | It's likely that this is a V2 checkpoint and you need to provide the filename
219 | *prefix*. Try removing the '.' and extension. Try:
220 | inspect checkpoint --file_name = {}"""
221 |                 print(v2_file_error_template.format(proposed_file))
222 |             raise
223 | 
224 | 
225 | def restore_exists(sess, file_name, show=False):
226 |     """Restores checkpoint variables that also exist in the session graph.
227 | 
228 |     Checkpoint entries without a same-named tensor in the graph are skipped,
229 |     so a checkpoint from a related model can be restored partially.
230 |     Args:
231 |         sess: Session whose graph receives the restored values.
232 |         file_name: Name (prefix) of the checkpoint file.
233 |         show: If True, only print the name, dtype and shape of every tensor
234 |             in the checkpoint instead of restoring anything.
235 |     """
236 |     try:
237 |         reader = pywrap_tensorflow.NewCheckpointReader(file_name)
238 |         var_to_shape_map = reader.get_variable_to_shape_map()
239 |         assign_ops = []
240 |         if show:
241 |             for key in sorted(var_to_shape_map):
242 |                 w = reader.get_tensor(key)
243 |                 print(key, w.dtype, w.shape)
244 |         else:
245 |             g = sess.graph
246 |             for key in sorted(var_to_shape_map):
247 |                 try:
248 |                     var = g.get_tensor_by_name(key + ':0')
249 |                     print("restoring: ", key)
250 |                 except KeyError:
251 |                     print("Ignoring: " + key)
252 |                     continue  # no tensor of this name in the graph; skip it
253 |                 assign_ops.append(tf.assign(var, reader.get_tensor(key)))
254 |         sess.run(assign_ops)
255 |     except Exception as e:  # pylint: disable=broad-except
256 |         print(str(e))
257 |         if "corrupted compressed block contents" in str(e):
258 |             print("It's likely that your checkpoint file has been compressed "
259 |                   "with SNAPPY.")
260 |         if ("Data loss" in str(e) and
261 |                 any([ext in file_name for ext in [".index", ".meta", ".data"]])):
262 |             proposed_file = ".".join(file_name.split(".")[0:-1])
263 |             v2_file_error_template = """
264 | It's likely that this is a V2 checkpoint and you need to provide the filename
265 | *prefix*. Try removing the '.' and extension. 
Try: 266 | inspect checkpoint --file_name = {}""" 267 | print(v2_file_error_template.format(proposed_file)) 268 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import re 5 | import numpy as np 6 | from tqdm import tqdm 7 | import SimpleITK as sitk 8 | from PIL import Image 9 | import math 10 | import scipy.misc 11 | import xlwt 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-c', '--checkpoint', type=str, default=None, 15 | help='Specifies a previous checkpoint to load') 16 | parser.add_argument('-r', '--rep', type=int, default=1, 17 | help='Number of times of shared-weight cascading') 18 | parser.add_argument('-g', '--gpu', type=str, default='0', 19 | help='Specifies gpu device(s)') 20 | parser.add_argument('-d', '--dataset', type=str, default=None, 21 | help='Specifies a data config') 22 | parser.add_argument('-v', '--val_subset', type=str, default=None) 23 | parser.add_argument('--batch', type=int, default=4, help='Size of minibatch') 24 | parser.add_argument('--fast_reconstruction', action='store_true') 25 | parser.add_argument('--paired', action='store_true') 26 | parser.add_argument('--data_args', type=str, default=None) 27 | parser.add_argument('--net_args', type=str, default=None) 28 | parser.add_argument('--name', type=str, default=None) 29 | args = parser.parse_args() 30 | 31 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 32 | 33 | import tensorflow as tf 34 | import tflearn 35 | 36 | import network 37 | import data_util.liver 38 | import data_util.brain 39 | 40 | 41 | def main(): 42 | if args.checkpoint is None: 43 | print('Checkpoint must be specified!') 44 | return 45 | if ':' in args.checkpoint: 46 | args.checkpoint, steps = args.checkpoint.split(':') 47 | steps = int(steps) 48 | print(steps) 49 | else: 50 | steps = None 51 | args.checkpoint = find_checkpoint_step(args.checkpoint, steps) 52 | print(args.checkpoint) 53 | model_dir = os.path.dirname(args.checkpoint) 54 | try: 55 | with open(os.path.join(model_dir, 'args.json'), 'r') as f: 56 | model_args = json.load(f) 57 | print(model_args) 58 | except Exception as e: 59 | print(e) 60 | model_args = {} 61 | 62 | if args.dataset is None: 63 | args.dataset = model_args['dataset'] 64 | if args.data_args is None: 65 | args.data_args = model_args['data_args'] 66 | 67 | Framework = network.FrameworkUnsupervised 68 | Framework.net_args['base_network'] = model_args['base_network'] 69 | Framework.net_args['n_cascades'] = model_args['n_cascades'] 70 | Framework.net_args['rep'] = args.rep 71 | Framework.net_args['augmentation'] = 'identity' 72 | Framework.net_args.update(eval('dict({})'.format(model_args['net_args']))) 73 | if args.net_args is not None: 74 | Framework.net_args.update(eval('dict({})'.format(args.net_args))) 75 | with open(os.path.join(args.dataset), 'r') as f: 76 | cfg = json.load(f) 77 | image_size = cfg.get('image_size', [160, 160, 160]) 78 | image_type = cfg.get('image_type') 79 | gpus = 0 if args.gpu == '-1' else len(args.gpu.split(',')) 80 | framework = Framework(devices=gpus, image_size=image_size, segmentation_class_value=cfg.get( 81 | 'segmentation_class_value', None), fast_reconstruction=args.fast_reconstruction, validation=True) 82 | print('Graph built') 83 | 84 | Dataset = eval('data_util.{}.Dataset'.format(image_type)) 85 | ds = Dataset(args.dataset, batch_size=args.batch, paired=args.paired, ** 86 | 
eval('dict({})'.format(args.data_args))
87 | 
88 |     config = tf.ConfigProto()
89 |     config.gpu_options.allow_growth = True
90 |     sess = tf.Session(config=config)
91 |     tf.global_variables_initializer().run(session=sess)
92 | 
93 |     checkpoint = args.checkpoint
94 |     saver = tf.train.Saver(tf.get_collection(
95 |         tf.GraphKeys.GLOBAL_VARIABLES))
96 |     saver.restore(sess, checkpoint)
97 |     tflearn.is_training(False, session=sess)
98 | 
99 |     val_subsets = [data_util.liver.Split.VALID]
100 |     if args.val_subset is not None:
101 |         val_subsets = args.val_subset.split(',')
102 | 
103 |     writebook = xlwt.Workbook()
104 |     testSheet1 = writebook.add_sheet('dice')
105 |     keys = ['image_fixed', 'moving_img', 'warped_moving', 'seg_fixed', 'warped_seg_moving',
106 |             'real_flow', 'total_mse', 'total_ncc', 'jaccs', 'landmark_dists', 'jacobian_det', 'dices']
107 |     if not os.path.exists('evaluate'):
108 |         os.mkdir('evaluate')
109 |     path_prefix = os.path.join('evaluate', short_name(checkpoint))
110 |     if args.rep > 1:
111 |         path_prefix = path_prefix + '-rep' + str(args.rep)
112 |     if args.name is not None:
113 |         path_prefix = path_prefix + '-' + args.name
114 |     for val_subset in val_subsets:
115 |         if args.val_subset is not None:
116 |             output_fname = path_prefix + '-' + str(val_subset) + '_atlas2.txt'
117 |         else:
118 |             output_fname = path_prefix + '_atlas2.txt'
119 |         output_xls = path_prefix + '_atlas2.xls'
120 |         with open(output_fname, 'w') as fo:
121 |             print("Validation subset {}".format(val_subset))
122 |             gen = ds.generator(val_subset, loop=False)
123 |             results = framework.validate(sess, gen, keys=keys, summary=False, predict=True, show_tqdm=True)
124 |             # save the fixed image, warped moving image, flow field and segmentations
125 |             image_save_path = os.path.join('./test_images', short_name(checkpoint) + '_atlas2')
126 |             if not os.path.isdir(image_save_path):
127 |                 os.makedirs(image_save_path)
128 | 
129 |             for i in range(len(results['image_fixed'])):
130 |                 print(results['id2'][i])
131 |                 # each volume is written as a compressed .mhd file (the final True enables compression)
132 | 
133 |                 sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(results['image_fixed'][i][:, :, :, 0])), image_save_path + '/' + results['id2'][i] + '_fixed.mhd', True)
134 |                 sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(results['warped_moving'][i][:, :, :, 0])), image_save_path + '/' + results['id2'][i] + '_moving.mhd', True)
135 |                 sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(results['moving_img'][i][:, :, :, 0])), image_save_path + '/' + results['id2'][i] + '_moving_raw.mhd', True)
136 |                 sitk.WriteImage(sitk.GetImageFromArray(np.squeeze(results['real_flow'][i][:, :, :, :])), image_save_path + '/' + results['id2'][i] + '_flow.mhd', True)
137 | 
138 |                 warped_seg = np.squeeze(np.zeros(results['warped_seg_moving'][i][:, :, :, 0].shape))
139 |                 seg_fixed = np.squeeze(np.zeros(results['warped_seg_moving'][i][:, :, :, 0].shape))
140 | 
141 |                 for seg in range(results['warped_seg_moving'][i].shape[-1]):
142 |                     sub_warp = np.squeeze((results['warped_seg_moving'][i])[:, :, :, seg])
143 |                     sub_warp = np.where(sub_warp > 127.5, seg + 1, 0)
144 |                     sub_seg = np.squeeze((results['seg_fixed'][i])[:, :, :, seg])
145 |                     sub_seg = np.where(sub_seg > 127.5, seg + 1, 0)
146 |                     warped_seg += sub_warp
147 |                     seg_fixed += sub_seg
148 |                 sitk.WriteImage(sitk.GetImageFromArray(warped_seg), image_save_path + '/' + results['id2'][i] + '_warped_seg.mhd', True)
149 |                 sitk.WriteImage(sitk.GetImageFromArray(seg_fixed), image_save_path + '/' + results['id2'][i] + '_seg_fixed.mhd', True)
150 | 
151 | for i in range(len(results['dices'])): 152 | print(results['id1'][i],results['dices'][i],results['id2'][i], np.mean(results['jaccs'][i]), np.mean(results['landmark_dists'][i]), results['jacobian_det'][i], file=fo)# 153 | for j in range(len(results['dices'][i])): 154 | testSheet1.write(i,j,float('%.4f'%results['dices'][i][j])) 155 | writebook.save(output_xls) 156 | print('Summary', file=fo) 157 | jaccs, dices, landmarks,ncc,mse = results['jaccs'], results['dices'], results['landmark_dists'],results['total_ncc'],results['total_mse'] 158 | jacobian_det = results['jacobian_det'] 159 | print("Dice score: {} ({})".format(np.mean(dices), np.std( 160 | np.mean(dices, axis=-1))), file=fo) 161 | print("Jacc score: {} ({})".format(np.mean(jaccs), np.std( 162 | np.mean(jaccs, axis=-1))), file=fo) 163 | print("ncc score: {} ({})".format(np.mean(ncc), np.std( 164 | np.mean(ncc, axis=-1))), file=fo) 165 | print("mse score: {} ({})".format(np.mean(mse), np.std( 166 | np.mean(mse, axis=-1))), file=fo) 167 | print("Landmark distance: {} ({})".format(np.mean(landmarks), np.std( 168 | np.mean(landmarks, axis=-1))), file=fo) 169 | print("Jacobian determinant: {} ({})".format(np.mean( 170 | jacobian_det), np.std(jacobian_det)), file=fo) 171 | for seg in range(results['dices'].shape[1]): 172 | print("dice score for seg {}: {}".format(seg, np.mean( 173 | results['dices'][:,seg])), file=fo) 174 | 175 | 176 | 177 | def short_name(checkpoint): 178 | cpath, steps = os.path.split(checkpoint) 179 | _, exp = os.path.split(cpath) 180 | return exp + '-' + steps 181 | 182 | def find_checkpoint_step(checkpoint_path, target_steps=None): 183 | pattern = re.compile(r'model-(\d+).index') 184 | checkpoints = [] 185 | for f in os.listdir(checkpoint_path): 186 | m = pattern.match(f) 187 | if m: 188 | steps = int(m.group(1)) 189 | checkpoints.append((-steps if target_steps is None else abs( 190 | target_steps - steps), os.path.join(checkpoint_path, f.replace('.index', '')))) 191 | return min(checkpoints, key=lambda x: x[0])[1] 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import json 5 | import h5py 6 | import copy 7 | import collections 8 | import re 9 | import datetime 10 | import hashlib 11 | import time 12 | from timeit import default_timer 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-b', '--base_network', type=str, default='VTN', 16 | help='Specifies the base network (either VTN or VoxelMorph)') 17 | parser.add_argument('-n', '--n_cascades', type=int, default=1, 18 | help='Number of cascades') 19 | parser.add_argument('-r', '--rep', type=int, default=1, 20 | help='Number of times of shared-weight cascading') 21 | parser.add_argument('-g', '--gpu', type=str, default='0', 22 | help='Specifies gpu device(s)') 23 | parser.add_argument('-c', '--checkpoint', type=str, default=None, 24 | help='Specifies a previous checkpoint to start with') 25 | parser.add_argument('-d', '--dataset', type=str, default="datasets/liver.json", 26 | help='Specifies a data config') 27 | parser.add_argument('--batch', type=int, default=4, 28 | help='Number of image pairs per batch') 29 | parser.add_argument('--round', type=int, default=20000, 30 | help='Number of batches per epoch') 31 | parser.add_argument('--epochs', type=float, default=5, 32 | help='Number of epochs') 33 | 
parser.add_argument('--fast_reconstruction', action='store_true') 34 | parser.add_argument('--debug', action='store_true') 35 | parser.add_argument('--val_steps', type=int, default=100) 36 | parser.add_argument('--net_args', type=str, default='') 37 | parser.add_argument('--data_args', type=str, default='') 38 | parser.add_argument('--lr', type=float, default=1e-4) 39 | parser.add_argument('--clear_steps', action='store_true') 40 | parser.add_argument('--finetune', type=str, default=None) 41 | parser.add_argument('--name', type=str, default=None) 42 | parser.add_argument('--logs', type=str, default='') 43 | args = parser.parse_args() 44 | 45 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 46 | 47 | import tensorflow as tf 48 | import tflearn 49 | import keras 50 | 51 | import network 52 | import data_util.liver 53 | import data_util.brain 54 | from data_util.data import Split 55 | 56 | def main(): 57 | repoRoot = os.path.dirname(os.path.realpath(__file__)) 58 | 59 | if args.finetune is not None: 60 | args.clear_steps = True 61 | 62 | batchSize = args.batch 63 | iterationSize = args.round 64 | 65 | gpus = 0 if args.gpu == '-1' else len(args.gpu.split(',')) 66 | 67 | Framework = network.FrameworkUnsupervised 68 | Framework.net_args['base_network'] = args.base_network 69 | Framework.net_args['n_cascades'] = args.n_cascades 70 | Framework.net_args['rep'] = args.rep 71 | Framework.net_args.update(eval('dict({})'.format(args.net_args))) 72 | with open(os.path.join(args.dataset), 'r') as f: 73 | cfg = json.load(f) 74 | image_size = cfg.get('image_size', [128, 128, 128]) 75 | image_type = cfg.get('image_type') 76 | framework = Framework(devices=gpus, image_size=image_size, segmentation_class_value=cfg.get('segmentation_class_value', None), fast_reconstruction = args.fast_reconstruction) 77 | Dataset = eval('data_util.{}.Dataset'.format(image_type)) 78 | print('Graph built.') 79 | 80 | # load training set and validation set 81 | 82 | def set_tf_keys(feed_dict, **kwargs): 83 | ret = dict([(k + ':0', v) for k, v in feed_dict.items()]) 84 | ret.update([(k + ':0', v) for k, v in kwargs.items()]) 85 | return ret 86 | config = tf.ConfigProto(allow_soft_placement = True) 87 | config.gpu_options.allow_growth = True 88 | 89 | 90 | with tf.Session(config=config) as sess: 91 | saver = tf.train.Saver(tf.get_collection( 92 | tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=5, keep_checkpoint_every_n_hours=5) 93 | if args.checkpoint is None: 94 | steps = 0 95 | tf.global_variables_initializer().run() 96 | else: 97 | if '\\' not in args.checkpoint and '/' not in args.checkpoint: 98 | args.checkpoint = os.path.join( 99 | repoRoot, 'weights', args.checkpoint) 100 | if os.path.isdir(args.checkpoint): 101 | args.checkpoint = tf.train.latest_checkpoint(args.checkpoint) 102 | 103 | tf.global_variables_initializer().run() 104 | checkpoints = args.checkpoint.split(';') 105 | if args.clear_steps: 106 | steps = 0 107 | else: 108 | steps = int(re.search('model-(\d+)', checkpoints[0]).group(1)) 109 | def optimistic_restore(session, save_file): 110 | reader = tf.train.NewCheckpointReader(save_file) 111 | saved_shapes = reader.get_variable_to_shape_map()#get the saving model var 112 | var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables() 113 | if var.name.split(':')[0] in saved_shapes]) 114 | restore_vars = [] 115 | name2var = dict(zip(map(lambda x:x.name.split(':')[0], tf.global_variables()), tf.global_variables())) 116 | with tf.variable_scope('', reuse=True): 117 | for var_name, saved_var_name 
in var_names:
118 |                 curr_var = name2var[saved_var_name]
119 |                 var_shape = curr_var.get_shape().as_list()
120 |                 if var_shape == saved_shapes[saved_var_name]:
121 |                     restore_vars.append(curr_var)
122 |             saver_list = tf.train.Saver(restore_vars)
123 |             print('restoring {} matching variables from {}'.format(len(restore_vars), save_file))
124 |             saver_list.restore(session, save_file)
125 |         for cp in checkpoints:
126 |             optimistic_restore(sess, cp)
127 | 
128 |     data_args = eval('dict({})'.format(args.data_args))
129 |     data_args.update(framework.data_args)
130 |     print('data_args', data_args)
131 |     dataset = Dataset(args.dataset, **data_args)
132 |     if args.finetune is not None:
133 |         if 'finetune-train-%s' % args.finetune in dataset.schemes:
134 |             dataset.schemes[Split.TRAIN] = dataset.schemes['finetune-train-%s' %
135 |                                                            args.finetune]
136 |         if 'finetune-val-%s' % args.finetune in dataset.schemes:
137 |             dataset.schemes[Split.VALID] = dataset.schemes['finetune-val-%s' %
138 |                                                            args.finetune]
139 |         print('train', dataset.schemes[Split.TRAIN])
140 |         print('val', dataset.schemes[Split.VALID])
141 |     generator = dataset.generator(Split.TRAIN, batch_size=batchSize, loop=True)
142 | 
143 |     if not args.debug:
144 |         if args.finetune is not None:
145 |             run_id = os.path.basename(os.path.dirname(args.checkpoint))
146 |             if not run_id.endswith('_ft' + args.finetune):
147 |                 run_id = run_id + '_ft' + args.finetune
148 |         else:
149 |             pad = ''
150 |             retry = 1
151 |             while True:
152 |                 dt = datetime.datetime.now(
153 |                     tz=datetime.timezone(datetime.timedelta(hours=8)))
154 |                 run_id = dt.strftime('%b%d-%H%M') + pad
155 |                 modelPrefix = os.path.join(repoRoot, 'weights', run_id)
156 |                 try:
157 |                     os.makedirs(modelPrefix)
158 |                     break
159 |                 except Exception:
160 |                     print('Conflict with {}! Retry...'.format(run_id))
161 |                     pad = '_{}'.format(retry)
162 |                     retry += 1
163 |         modelPrefix = os.path.join(repoRoot, 'weights', run_id)
164 |         if not os.path.exists(modelPrefix):
165 |             os.makedirs(modelPrefix)
166 |         if args.name is not None:
167 |             run_id += '_' + args.name
168 |         if not args.logs:
169 |             log_dir = 'logs'
170 |         else:
171 |             log_dir = os.path.join('logs', args.logs)
172 |         summary_path = os.path.join(repoRoot, log_dir, run_id)
173 |         if not os.path.exists(summary_path):
174 |             os.makedirs(summary_path)
175 |         summaryWriter = tf.summary.FileWriter(summary_path, sess.graph)
176 |         with open(os.path.join(modelPrefix, 'args.json'), 'w') as fo:
177 |             json.dump(vars(args), fo)
178 |     else:
179 |         run_id = 'debug'
180 | 
181 |     if args.finetune is not None:
182 |         learningRates = [1e-5 / 2, 1e-5 / 2, 1e-5 / 2, 1e-5 / 4, 1e-5 / 8]
183 |     else:
184 |         learningRates = [1e-4, 1e-4, 1e-4, 1e-4, 1e-4 / 2, 1e-4 / 2,
185 |                          1e-4 / 2, 1e-4 / 4, 1e-4 / 4, 1e-4 / 8]
186 |     # Training
187 | 
188 |     def get_lr(steps):
189 |         # staircase schedule: one entry of learningRates per training round,
190 |         # rescaled so that the first entry equals args.lr
191 |         m = args.lr / learningRates[0]
192 |         return m * learningRates[steps // iterationSize]
193 | 
194 |     while True:
195 |         if hasattr(framework, 'get_lr'):
196 |             lr = framework.get_lr(steps, batchSize)
197 |         else:
198 |             lr = get_lr(steps)
199 |         t0 = default_timer()
200 |         fd = next(generator)
201 |         fd.pop('mask', [])
202 |         fd.pop('id1', [])
203 |         fd.pop('id2', [])
204 |         t1 = default_timer()
205 |         tflearn.is_training(True, session=sess)
206 |         summ, _ = sess.run([framework.summaryExtra, framework.adamOpt],
207 |                            set_tf_keys(fd, learningRate=lr))
208 | 
209 |         for v in tf.Summary().FromString(summ).value:
210 |             if v.tag == 'loss':
211 |                 loss = v.simple_value
212 | 
213 |         steps += 1
214 |         if args.debug or steps % 10 == 0:
215 |             if steps >= args.epochs * iterationSize:
216 |                 break
217 | 
218 |             if not args.debug:
219 |                 summaryWriter.add_summary(summ, steps)
220 | 
221 |                 if steps % 50 == 0:
222 |                     if hasattr(framework, 'summaryImages'):
223 |                         summ, = sess.run([framework.summaryImages],
224 |                                          set_tf_keys(fd))
225 |                         summaryWriter.add_summary(summ, steps)
226 | 
227 |             if steps % 50 == 0:
228 |                 print('*%s* ' % run_id,
229 |                       time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
230 |                       'Steps %d, Total time %.2f, data %.2f%%. Loss %.3e lr %.3e' % (steps,
231 |                                                                                      default_timer() - t0,
232 |                                                                                      100 * (t1 - t0) / (
233 |                                                                                          default_timer() - t0),
234 |                                                                                      loss,
235 |                                                                                      lr))
236 | 
237 |             if args.debug or steps % args.val_steps == 0:
238 |                 try:
239 |                     val_gen = dataset.generator(
240 |                         Split.VALID, loop=False, batch_size=1)
241 |                     metrics = framework.validate(
242 |                         sess, val_gen, summary=True)
243 |                     val_summ = tf.Summary(value=[
244 |                         tf.Summary.Value(tag='val_' + k, simple_value=v) for k, v in metrics.items()
245 |                     ])
246 |                     print('dice:', metrics['dice_score'])
247 |                     print('ncc:', metrics['total_ncc'])
248 |                     print('mse:', metrics['total_mse'])
249 |                     if not args.debug:
250 |                         # log the validation summary and save a checkpoint;
251 |                         # predict.py looks for model-<steps> files under weights/<run_id>
252 |                         summaryWriter.add_summary(val_summ, steps)
253 |                         saver.save(sess, os.path.join(modelPrefix, 'model'),
254 |                                    global_step=steps, write_meta_graph=False)
255 |                 except Exception:
256 |                     if steps == args.val_steps:
257 |                         print('Step {}, validation failed!'.format(steps))
258 |     print('Finished.')
259 | 
260 | 
261 | if __name__ == '__main__':
262 |     main()
263 | 
--------------------------------------------------------------------------------
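The staircase learning-rate schedule implemented by `get_lr` in train.py can be checked in isolation. The sketch below is not part of the repository; it replays the default (non-finetune) schedule with `--round 10000` and `--lr 1e-4`, where the list index is `steps // round`, so each entry holds for one round of 10000 batches.

```python
# Standalone replay of train.py's get_lr() staircase (default schedule).
learning_rates = [1e-4, 1e-4, 1e-4, 1e-4, 5e-5, 5e-5, 5e-5, 2.5e-5, 2.5e-5, 1.25e-5]
round_size = 10000  # corresponds to --round

def get_lr(steps, base_lr=1e-4):  # base_lr corresponds to --lr
    m = base_lr / learning_rates[0]
    return m * learning_rates[steps // round_size]

for s in (0, 39999, 40000, 69999, 70000, 90000):
    print(s, get_lr(s))
# holds at 1e-4 for rounds 1-4, halves for rounds 5-7,
# quarters for rounds 8-9, and ends at 1e-4 / 8 in round 10
```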