├── README.md
├── flowreg_a
│   ├── README.md
│   ├── data_generator.py
│   ├── loss.py
│   ├── model.py
│   ├── neuron
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── dataproc.cpython-35.pyc
│   │   │   ├── layers.cpython-35.pyc
│   │   │   ├── layers.cpython-36.pyc
│   │   │   ├── utils.cpython-35.pyc
│   │   │   └── utils.cpython-36.pyc
│   │   ├── callbacks.py
│   │   ├── dataproc.py
│   │   ├── generators.py
│   │   ├── inits.py
│   │   ├── layers.py
│   │   ├── metrics.py
│   │   ├── models.py
│   │   ├── plot.py
│   │   ├── regularizers.py
│   │   ├── utils.py
│   │   └── vae_tools.py
│   ├── register.py
│   ├── train.py
│   └── utils.py
└── flowreg_o
    ├── README.md
    ├── __pycache__
    │   ├── data_generator.cpython-35.pyc
    │   ├── loss.cpython-35.pyc
    │   ├── model.cpython-35.pyc
    │   └── utils.cpython-35.pyc
    ├── data_generator.py
    ├── loss.py
    ├── model.py
    ├── register.py
    ├── train.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | # FlowReg: Fast Deformable Unsupervised Medical Image Registration using Optical Flow
 2 | 
 3 | FlowReg is a deep-learning-based medical image registration framework. The framework is divided into a 3D affine component (FlowReg-A) and a 2D fine-tuning network using optical flow (FlowReg-O).
 4 | 
 5 | ## Usage
 6 | Please visit [FlowReg-A](flowreg_a) and [FlowReg-O](flowreg_o) for the corresponding usage instructions.
 7 | 
 8 | ## Citation
 9 | If you use any portion of our work, please cite our paper.
10 | ```
11 | S. Mocanu, A. Moody, and A. Khademi, “FlowReg: Fast Deformable Unsupervised Medical Image Registration using Optical Flow,”
12 | Machine Learning for Biomedical Imaging, pp. 1–40, Sep. 2021.
13 | ```
14 | Available at: https://www.melba-journal.org/article/27657-flowreg-fast-deformable-unsupervised-medical-image-registration-using-optical-flow
--------------------------------------------------------------------------------
/flowreg_a/README.md:
--------------------------------------------------------------------------------
 1 | # FlowReg-Affine (FlowReg-A)
 2 | A fully unsupervised framework for affine registration using deep learning.
 3 | 
 4 | ## Training
 5 | To train on your own data: the framework currently supports `.mat` files; however, you may adapt the data loading to your specific needs.
 6 | 
 7 | The application is command-line based and can be run in two easy steps using your terminal of choice.
 8 | 1. Navigate to the `flowreg_a` directory.
 9 | 2. Run `python train.py` with the appropriate arguments explained below.
10 | 
11 | The command-line arguments are as follows. The text in *italics* is the expected data type:
12 | - `-t` or `--train` the directory of volumes used in training (*string*)
13 | - `-v` or `--validation` the directory of volumes used for validation (*string*)
14 | - `-f` or `--fixed` the fixed volume path (*string*)
15 | - `-b` or `--batch` the batch size used in training, default = 4 (*integer*)
16 | - `-e` or `--epochs` number of training epochs, default = 100 (*integer*)
17 | - `-c` or `--checkpoint` the interval (in epochs) at which to save checkpoints, default = 0 (*integer*)
18 | - `-l` or `--save_loss` save loss values to a CSV file during training, default = True (*boolean*)
19 | - `-m` or `--model_save` directory to save the final model (*string*)
20 | 
21 | Note: the checkpoints and loss CSV are saved to corresponding folders within the `flowreg_a` folder; this behaviour can be easily modified in `train.py`.
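For reference, the `.mat` layout the loader expects (see `data_generator.py`) is: the fixed volume stored under the variable `atlasFinal`, and each moving volume as a MATLAB struct `im` with a field `vol`. A quick sanity check of your own files, with placeholder paths, could look like:
```
import scipy.io as sio

fixed = sio.loadmat("path/to/fixed/volume.mat")["atlasFinal"]
moving = sio.loadmat("path/to/moving/volume.mat")["im"]["vol"][0][0]
print(fixed.shape, moving.shape)  # moving volumes are resized to (256, 256, 55) at load time
```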
22 | 
23 | An example command could look something like:
24 | ```
25 | python train.py \
26 |     --train "path/to/train/directory" \
27 |     --validation "path/to/validation/directory" \
28 |     --fixed "path/to/fixed/volume.mat" \
29 |     --batch 4 \
30 |     --checkpoint 1 \
31 |     --epochs 100 \
32 |     --save_loss True \
33 |     --model_save "path/to/model/save/directory"
34 | ```
35 | 
36 | ## Registration
37 | If you have a trained model, the script to register volumes can be found in `register.py`.
38 | 
39 | Similar to training, registration is done via a command-line interface with the following arguments:
40 | - `-r` or `--register` directory of the volumes to be registered (*string*)
41 | - `-f` or `--fixed` the fixed volume path (*string*)
42 | - `-s` or `--save` directory in which to save the registered volumes (*string*)
43 | - `-m` or `--model` path to the model weights, an `.h5` file (*string*)
44 | 
45 | (OPTIONAL) Binary masks can be passed as additional arguments and will be warped with the calculated affine matrix. The masks do not have to be the 'brain', 'ventricles', or 'wml' (white matter lesion) masks implied by the argument names; any binary mask can be used as long as it corresponds to the orientation and dimensions of the moving volume.
46 | - `-b` or `--brain` brain masks directory (*string*)
47 | - `-v` or `--vent` ventricle masks directory (*string*)
48 | - `-w` or `--wml` WML masks directory (*string*)
49 | 
50 | The output `.mat` file will contain the registered volume and the corresponding flattened affine matrix. If masks are used, they will also be saved, under `brainMask`, `ventMask`, or `wmlMask`.
51 | 
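An example registration command, mirroring the training example above (the three mask arguments are optional), could look something like:
```
python register.py \
    --register "path/to/volumes/to/register" \
    --fixed "path/to/fixed/volume.mat" \
    --save "path/to/save/directory" \
    --model "path/to/model/weights.h5" \
    --brain "path/to/brain/masks" \
    --vent "path/to/ventricle/masks" \
    --wml "path/to/wml/masks"
```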
52 | ## Citation
53 | If you use any portion of our work, please cite our paper.
54 | ```
55 | S. Mocanu, A. Moody, and A. Khademi, “FlowReg: Fast Deformable Unsupervised Medical Image Registration using Optical Flow,”
56 | Machine Learning for Biomedical Imaging, pp. 1–40, Sep. 2021.
57 | ```
58 | Available at: https://www.melba-journal.org/article/27657-flowreg-fast-deformable-unsupervised-medical-image-registration-using-optical-flow
--------------------------------------------------------------------------------
/flowreg_a/data_generator.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import scipy.io as sio
 3 | import glob, os
 4 | from tensorflow.python.keras.utils.data_utils import Sequence
 5 | from skimage import transform
 6 | 
 7 | from utils import normalize
 8 | 
 9 | class DataGenerator(Sequence):
10 | 
11 |     def __init__(self, vols, mvol_dir, fvol_dir, batch_size=4, shuffle=True, dim=(256, 256, 55), n_channels=1):
12 |         self.vols = vols
13 |         self.mvol_dir = mvol_dir
14 |         self.fvol_dir = fvol_dir
15 |         self.batch_size = batch_size
16 |         self.shuffle = shuffle
17 |         self.dim = dim
18 |         self.n_channels = n_channels
19 |         self.on_epoch_end()
20 | 
21 |     def __len__(self):
22 |         """Denotes the number of batches per epoch"""
23 |         return int(np.floor(len(self.vols) / self.batch_size))
24 | 
25 |     def __getitem__(self, index):
26 |         """Generates one batch of data"""
27 |         # generate indexes of the batch
28 |         indexes = self.indexes[index * self.batch_size:(index+1) * self.batch_size]
29 |         # find list of IDs
30 |         list_IDs_temp = [self.vols[k] for k in indexes]
31 |         # generate data
32 |         X, y = self.__data_generation(list_IDs_temp)
33 |         return X, y
34 | 
35 |     def on_epoch_end(self):
36 |         """Updates indexes after each epoch"""
37 |         self.indexes = np.arange(len(self.vols))
38 |         if self.shuffle:
39 |             np.random.shuffle(self.indexes)
40 | 
41 |     def __data_generation(self, list_IDs_temp):
42 |         """Generates data containing batch_size samples"""  # X: (n_samples, *dim, n_channels)
43 |         fixed = np.empty((self.batch_size, *self.dim, self.n_channels))
44 |         moving = np.empty((self.batch_size, *self.dim, self.n_channels))
45 | 
46 |         fixed_vol = normalize(sio.loadmat(self.fvol_dir)['atlasFinal'])
47 | 
48 |         if os.path.isdir(self.mvol_dir):
49 |             moving_vols = glob.glob(os.path.join(self.mvol_dir, '*.mat'))
50 |         elif os.path.isfile(self.mvol_dir):
51 |             moving_vols = [line.rstrip('\n') for line in open(self.mvol_dir)]
52 |         else:
53 |             raise ValueError("Invalid training data: should be a .txt file containing the (training/validation) set locations or a directory of (training/validation) volumes")
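        # Note: each path in moving_vols is loaded below as a MATLAB struct `im`
        # with a field `vol` (hence the ['im']['vol'][0][0] indexing), and the
        # volume is resized to the fixed 256 x 256 x 55 grid before batching.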
54 | 
55 | 
56 |         for i, ID in enumerate(list_IDs_temp):
57 |             moving_vol = normalize(transform.resize(sio.loadmat(moving_vols[ID])['im']['vol'][0][0], (256, 256, 55)))
58 | 
59 |             moving[i, :, :, :, 0] = moving_vol
60 |             fixed[i, :, :, :, 0] = fixed_vol
61 | 
62 |         X = [fixed, moving]
63 |         y = fixed
64 |         return X, y
65 | 
--------------------------------------------------------------------------------
/flowreg_a/loss.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | 
 3 | def correlation(true, pred):
 4 |     true = tf.cast(true, tf.float32)
 5 |     pred = tf.cast(pred, tf.float32)
 6 | 
 7 |     mux = tf.reduce_mean(true)
 8 |     muy = tf.reduce_mean(pred)
 9 |     n = tf.cast(tf.size(true), tf.float32)
10 | 
11 |     varx = tf.reduce_sum(tf.square(true - mux))/n
12 |     vary = tf.reduce_sum(tf.square(pred - muy))/n
13 | 
14 |     corr = 1/n * tf.reduce_sum((true - mux) * (pred - muy)) / tf.math.sqrt(varx * vary)  # Pearson correlation coefficient
15 | 
16 |     return 1-corr
--------------------------------------------------------------------------------
/flowreg_a/model.py:
--------------------------------------------------------------------------------
 1 | from tensorflow.keras.layers import Input, Conv3D, Concatenate, Dense, Flatten
 2 | from tensorflow.keras import Model
 3 | from tensorflow.keras.optimizers import Adam
 4 | 
 5 | from loss import correlation
 6 | 
 7 | # the Spatial Transformer layer used for the affine transform and resampling was borrowed from
 8 | # https://github.com/adalca/neurite, originally named neuron
 9 | from neuron.layers import SpatialTransformer
10 | 
11 | def affmodel(shape):
12 |     fixedInput = Input(shape=shape, name='fixed')
13 |     movingInput = Input(shape=shape, name='moving')
14 | 
15 |     inputs = Concatenate(axis=4, name='inputs')([fixedInput, movingInput])
16 | 
17 |     conv1 = Conv3D(filters=16, kernel_size=7, strides=(2, 2, 1), padding='SAME', name='conv1', activation='relu')(inputs)
18 | 
19 |     conv2 = Conv3D(filters=32, kernel_size=5, strides=(2, 2, 1), padding='SAME', name='conv2', activation='relu')(conv1)
20 | 
21 |     conv3 = Conv3D(filters=64, kernel_size=3, strides=2, padding='SAME', name='conv3', activation='relu')(conv2)
22 | 
23 |     conv4 = Conv3D(filters=128, kernel_size=3, strides=2, padding='SAME', name='conv4', activation='relu')(conv3)
24 | 
25 |     conv5 = Conv3D(filters=256, kernel_size=3, strides=2, padding='SAME', name='conv5', activation='relu')(conv4)
26 | 
27 |     conv6 = Conv3D(filters=512, kernel_size=3, strides=2, padding='SAME', name='conv6', activation='relu')(conv5)
28 | 
29 |     flat = Flatten()(conv6)
30 |     fc6 = Dense(12, activation='linear', name='Dense')(flat)  # 12 = parameters of a 3D affine (3x4) matrix
31 | 
32 |     out = SpatialTransformer(interp_method='linear', indexing='ij')([movingInput, fc6])
33 | 
34 |     adam = Adam(lr=0.0001)
35 |     model = Model(inputs=[fixedInput, movingInput], outputs=[out])
36 |     model.compile(optimizer=adam, loss=correlation)
37 |     return model
--------------------------------------------------------------------------------
/flowreg_a/neuron/__init__.py:
--------------------------------------------------------------------------------
 1 | # import various
 2 | # from . import dataproc
 3 | # from . import generators
 4 | # from . import callbacks
 5 | # from . import plot
 6 | # from . import metrics
 7 | # from . import inits
 8 | # from . import models
 9 | # from . import utils
10 | # from . import layers
11 | # from . import vae_tools
12 | # from . 
import regularizers -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/dataproc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/dataproc.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/layers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/layers.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_a/neuron/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /flowreg_a/neuron/callbacks.py: -------------------------------------------------------------------------------- 1 | ''' callbacks for the neuron project ''' 2 | 3 | ''' 4 | We'd like the following callback actions for neuron: 5 | 6 | - print metrics on the test and validation, especially surface-specific dice 7 | --- Perhaps doable with CSVLogger? 8 | - output graph up to current iteration for each metric 9 | --- Perhaps call CSVLogger or some metric computing callback? 10 | - save dice plots on validation 11 | --- again, expand CSVLogger or similar 12 | - save screenshots of a single test subject [Perhaps just do this as a separate callback?] 
13 | --- new callback, PlotSlices
14 | 
15 | '''
16 | import sys
17 | import warnings
18 | from imp import reload
19 | 
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | from tensorflow import keras
23 | 
24 | # the neuron folder should be on the path
25 | import neuron.utils as nrn_utils
26 | import pytools.timer as timer
27 | 
28 | 
29 | class ModelWeightCheck(keras.callbacks.Callback):
30 |     """
31 |     check model weights for nan and infinite entries
32 |     """
33 | 
34 |     def __init__(self,
35 |                  weight_diff=False,
36 |                  at_batch_end=False,
37 |                  at_epoch_end=True):
38 |         """
39 |         Params:
40 |             at_batch_end: None or a number indicating when to execute
41 |                 (i.e. at_batch_end = 10 means execute every 10 batches)
42 |             at_epoch_end: logical, whether to execute at epoch end
43 |         """
44 |         super(ModelWeightCheck, self).__init__()
45 |         self.at_batch_end = at_batch_end
46 |         self.at_epoch_end = at_epoch_end
47 |         self.current_epoch = 0
48 |         self.weight_diff = weight_diff
49 |         self.wts = None
50 | 
51 |     def on_batch_end(self, batch, logs=None):
52 |         if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
53 |             self.on_model_check(self.current_epoch, batch + 1, logs=logs)
54 | 
55 |     def on_epoch_end(self, epoch, logs=None):
56 |         if self.at_epoch_end:
57 |             self.on_model_check(epoch, 0, logs=logs)
58 |         self.current_epoch = epoch
59 | 
60 |     def on_model_check(self, epoch, iter, logs=None):
61 |         for layer in self.model.layers:
62 |             for wt in layer.get_weights():
63 |                 assert ~np.any(np.isnan(wt)), 'Found nan weights in model layer %s' % layer.name
64 |                 assert np.all(np.isfinite(wt)), 'Found infinite weights in model layer %s' % layer.name
65 | 
66 |         # compute max change
67 |         if self.weight_diff:
68 |             wts = self.model.get_weights()
69 |             diff = -np.inf
70 | 
71 |             if self.wts is not None:
72 |                 for wi, w in enumerate(wts):
73 |                     if len(w) > 0:
74 |                         for si, sw in enumerate(w):
75 |                             diff = np.maximum(diff, np.max(np.abs(sw - self.wts[wi][si])))
76 | 
77 |             self.wts = wts
78 |             logs['max_diff'] = diff
79 |             # print("max diff", diff)
80 | 
81 | 
82 | class CheckLossTrend(keras.callbacks.Callback):
83 |     """
84 |     check that the training loss does not spike relative to its recent history
85 |     """
86 | 
87 |     def __init__(self,
88 |                  at_batch_end=True,
89 |                  at_epoch_end=False,
90 |                  nb_std_err=2,
91 |                  loss_window=10):
92 |         """
93 |         Params:
94 |             at_batch_end: None or a number indicating when to execute
95 |                 (i.e. at_batch_end = 10 means execute every 10 batches)
96 |             at_epoch_end: logical, whether to execute at epoch end
97 |         """
98 |         super(CheckLossTrend, self).__init__()
99 |         self.at_batch_end = at_batch_end
100 |         self.at_epoch_end = at_epoch_end
101 |         self.current_epoch = 0
102 |         self.loss_window = loss_window
103 |         self.nb_std_err = nb_std_err
104 |         self.losses = []
105 | 
106 |     def on_batch_end(self, batch, logs=None):
107 |         if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0:
108 |             self.on_model_check(self.current_epoch, batch + 1, logs=logs)
109 | 
110 |     def on_epoch_end(self, epoch, logs=None):
111 |         if self.at_epoch_end:
112 |             self.on_model_check(epoch, 0, logs=logs)
113 |         self.current_epoch = epoch
114 | 
115 |     def on_model_check(self, epoch, iter, logs=None):
116 |         if len(self.losses) < self.loss_window:
117 |             self.losses = [*self.losses, logs['loss']]
118 |         else:
119 |             losses_mean = np.mean(self.losses)
120 |             losses_std = np.std(self.losses)
121 |             this_loss = logs['loss']
122 | 
123 |             if (this_loss) > (losses_mean + self.nb_std_err * losses_std):
124 |                 print(logs)
125 |                 err = "Found loss %f, which is much higher than %f + %f " % (this_loss, losses_mean, losses_std)
126 |                 # raise ValueError(err)
127 |                 print(err, file=sys.stderr)
128 | 
129 |             if (this_loss - losses_mean) > (losses_mean * 100):
130 |                 err = "Found loss %f, which is much higher than %f * 100 " % (this_loss, losses_mean)
131 |                 raise ValueError(err)
132 | 
133 |             # drop the oldest loss and append the latest loss.
134 |             self.losses = [*self.losses[1:], logs['loss']]
135 | 
136 | 
137 | class PlotTestSlices(keras.callbacks.Callback):
138 |     '''
139 |     plot slices of a test subject from several directions
140 |     '''
141 | 
142 |     def __init__(self,
143 |                  savefilepath,
144 |                  generator,
145 |                  vol_size,
146 |                  run,  # object with fields: patch_size, patch_stride, grid_size
147 |                  data,  # object with fields:
148 |                  at_batch_end=None,  # None or a number indicating when to execute (i.e. at_batch_end = 10 means execute every 10 batches)
149 |                  at_epoch_end=True,  # logical, whether to execute at epoch end
150 |                  verbose=False,
151 |                  period=1,
152 |                  prior=None):
153 |         """
154 |         Parameters:
155 |             savefilepath,
156 |             generator,
157 |             vol_size,
158 |             run: object with fields: patch_size, patch_stride, grid_size
159 |             data: object with fields:
160 |             at_batch_end=None: None or a number indicating when to execute (i.e. 
at_batch_end = 10 means execute every 10 batches) 161 | at_epoch_end=True: logical, whether to execute at epoch end 162 | verbose=False: 163 | period=1 164 | prior=None 165 | """ 166 | 167 | super().__init__() 168 | 169 | # save some parameters 170 | self.savefilepath = savefilepath 171 | self.generator = generator 172 | self.vol_size = vol_size 173 | 174 | self.run = run 175 | self.data = data 176 | 177 | self.at_batch_end = at_batch_end 178 | self.at_epoch_end = at_epoch_end 179 | self.current_epoch = 0 180 | self.period = period 181 | 182 | self.verbose = verbose 183 | 184 | # prepare prior 185 | self.prior = None 186 | if prior is not None: 187 | data = np.load(prior) 188 | loc_vol = data['prior'] 189 | self.prior = np.expand_dims(loc_vol, axis=0) # reshape for model 190 | 191 | def on_batch_end(self, batch, logs={}): 192 | if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0: 193 | self.on_plot_save(self.current_epoch, batch + 1, logs=logs) 194 | 195 | def on_epoch_end(self, epoch, logs={}): 196 | if self.at_epoch_end and np.mod(epoch + 1, self.period) == 0: 197 | self.on_plot_save(epoch, 0, logs=logs) 198 | self.current_epoch = epoch 199 | 200 | def on_plot_save(self, epoch, iter, logs={}): 201 | # import neuron sandbox 202 | # has to be here, can't be at the top, due to cyclical imports (??) 203 | # TODO: should just pass the function to compute the figures given the model and generator 204 | import neuron.sandbox as nrn_sandbox 205 | reload(nrn_sandbox) 206 | 207 | with timer.Timer('plot callback', self.verbose): 208 | if len(self.run.grid_size) == 3: 209 | collapse_2d = [0, 1, 2] 210 | else: 211 | collapse_2d = [2] 212 | 213 | exampl = nrn_sandbox.show_example_prediction_result(self.model, 214 | self.generator, 215 | self.run, 216 | self.data, 217 | test_batch_size=1, 218 | test_model_names=None, 219 | test_grid_size=self.run.grid_size, 220 | ccmap=None, 221 | collapse_2d=collapse_2d, 222 | slice_nr=None, 223 | plt_width=17, 224 | verbose=self.verbose) 225 | 226 | # save, then close 227 | figs = exampl[1:] 228 | for idx, fig in enumerate(figs): 229 | dirn = "dirn_%d" % idx 230 | slice_nr = 0 231 | filename = self.savefilepath.format(epoch=epoch, iter=iter, axis=dirn, slice_nr=slice_nr) 232 | fig.savefig(filename) 233 | plt.close() 234 | 235 | 236 | class PredictMetrics(keras.callbacks.Callback): 237 | ''' 238 | Compute metrics, like Dice, and save to CSV/log 239 | 240 | ''' 241 | 242 | def __init__(self, 243 | filepath, 244 | metrics, 245 | data_generator, 246 | nb_samples, 247 | nb_labels, 248 | batch_size, 249 | label_ids=None, 250 | vol_params=None, 251 | at_batch_end=None, 252 | at_epoch_end=True, 253 | period=1, 254 | verbose=False): 255 | """ 256 | Parameters: 257 | filepath: filepath with epoch and metric 258 | metrics: list of metrics (functions) 259 | data_generator: validation generator 260 | nb_samples: number of validation samples - volumes or batches 261 | depending on whether vol_params is passed or not 262 | nb_labels: number of labels 263 | batch_size: 264 | label_ids=None: 265 | vol_params=None: 266 | at_batch_end=None: None or number indicate when to execute 267 | (i.e. 
at_batch_end = 10 means execute every 10 batches) 268 | at_epoch_end=True: logical, whether to execute at epoch end 269 | verbose=False 270 | """ 271 | 272 | # pass in the parameters to object variables 273 | self.metrics = metrics 274 | self.data_generator = data_generator 275 | self.nb_samples = nb_samples 276 | self.filepath = filepath 277 | self.nb_labels = nb_labels 278 | if label_ids is None: 279 | self.label_ids = list(range(nb_labels)) 280 | else: 281 | self.label_ids = label_ids 282 | self.vol_params = vol_params 283 | 284 | self.current_epoch = 1 285 | self.at_batch_end = at_batch_end 286 | self.at_epoch_end = at_epoch_end 287 | self.batch_size = batch_size 288 | self.period = period 289 | 290 | self.verbose = verbose 291 | 292 | def on_batch_end(self, batch, logs={}): 293 | if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0: 294 | self.on_metric_call(self.current_epoch, batch + 1, logs=logs) 295 | 296 | def on_epoch_end(self, epoch, logs={}): 297 | if self.at_epoch_end and np.mod(epoch + 1, self.period) == 0: 298 | self.on_metric_call(epoch, 0, logs=logs) 299 | self.current_epoch = epoch 300 | 301 | def on_metric_call(self, epoch, iter, logs={}): 302 | """ compute metrics on several predictions """ 303 | with timer.Timer('predict metrics callback', self.verbose): 304 | 305 | # prepare metric 306 | met = np.zeros((self.nb_samples, self.nb_labels, len(self.metrics))) 307 | 308 | # generate predictions 309 | # the idea is to predict either a full volume or just a slice, 310 | # depending on what we need 311 | gen = _generate_predictions(self.model, 312 | self.data_generator, 313 | self.batch_size, 314 | self.nb_samples, 315 | self.vol_params) 316 | batch_idx = 0 317 | for (vol_true, vol_pred) in gen: 318 | for idx, metric in enumerate(self.metrics): 319 | met[batch_idx, :, idx] = metric(vol_true, vol_pred) 320 | batch_idx += 1 321 | 322 | # write metric to csv file 323 | if self.filepath is not None: 324 | for idx, metric in enumerate(self.metrics): 325 | filen = self.filepath.format(epoch=epoch, iter=iter, metric=metric.__name__) 326 | np.savetxt(filen, met[:, :, idx], fmt='%f', delimiter=',') 327 | else: 328 | meanmet = np.nanmean(met, axis=0) 329 | for midx, metric in enumerate(self.metrics): 330 | for idx in range(self.nb_labels): 331 | varname = '%s_label_%d' % (metric.__name__, self.label_ids[idx]) 332 | logs[varname] = meanmet[idx, midx] 333 | 334 | 335 | class ModelCheckpoint(keras.callbacks.Callback): 336 | """ 337 | A modification of keras' ModelCheckpoint, but allow for saving on_batch_end 338 | changes include: 339 | - optional at_batch_end, at_epoch_end arguments, 340 | - filename now must includes 'iter' 341 | 342 | Save the model after every epoch. 343 | `filepath` can contain named formatting options, 344 | which will be filled the value of `epoch` and 345 | keys in `logs` (passed in `on_epoch_end`). 346 | For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, 347 | then the model checkpoints will be saved with the epoch number and 348 | the validation loss in the filename. 349 | # Arguments 350 | filepath: string, path to save the model file. 351 | monitor: quantity to monitor. 352 | verbose: verbosity mode, 0 or 1. 353 | save_best_only: if `save_best_only=True`, 354 | the latest best model according to 355 | the quantity monitored will not be overwritten. 356 | mode: one of {auto, min, max}. 
357 | If `save_best_only=True`, the decision 358 | to overwrite the current save file is made 359 | based on either the maximization or the 360 | minimization of the monitored quantity. For `val_acc`, 361 | this should be `max`, for `val_loss` this should 362 | be `min`, etc. In `auto` mode, the direction is 363 | automatically inferred from the name of the monitored quantity. 364 | save_weights_only: if True, then only the model's weights will be 365 | saved (`model.save_weights(filepath)`), else the full model 366 | is saved (`model.save(filepath)`). 367 | period: Interval (number of epochs) between checkpoints. 368 | """ 369 | 370 | def __init__(self, filepath, 371 | monitor='val_loss', 372 | save_best_only=False, 373 | save_weights_only=False, 374 | at_batch_end=None, 375 | at_epoch_end=True, 376 | mode='auto', period=1, 377 | verbose=False): 378 | """ 379 | Parameters: 380 | ... 381 | at_batch_end=None: None or number indicate when to execute 382 | (i.e. at_batch_end = 10 means execute every 10 batches) 383 | at_epoch_end=True: logical, whether to execute at epoch end 384 | """ 385 | super(ModelCheckpoint, self).__init__() 386 | self.monitor = monitor 387 | self.verbose = verbose 388 | self.filepath = filepath 389 | self.save_best_only = save_best_only 390 | self.save_weights_only = save_weights_only 391 | self.period = period 392 | self.steps_since_last_save = 0 393 | 394 | if mode not in ['auto', 'min', 'max']: 395 | warnings.warn('ModelCheckpoint mode %s is unknown, ' 396 | 'fallback to auto mode.' % (mode), 397 | RuntimeWarning) 398 | mode = 'auto' 399 | 400 | if mode == 'min': 401 | self.monitor_op = np.less 402 | self.best = np.Inf 403 | elif mode == 'max': 404 | self.monitor_op = np.greater 405 | self.best = -np.Inf 406 | else: 407 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 408 | self.monitor_op = np.greater 409 | self.best = -np.Inf 410 | else: 411 | self.monitor_op = np.less 412 | self.best = np.Inf 413 | 414 | self.at_batch_end = at_batch_end 415 | self.at_epoch_end = at_epoch_end 416 | self.current_epoch = 0 417 | 418 | def on_epoch_begin(self, epoch, logs=None): 419 | self.current_epoch = epoch 420 | 421 | def on_batch_end(self, batch, logs=None): 422 | if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0: 423 | print("Saving model at batch end!") 424 | self.on_model_save(self.current_epoch, batch + 1, logs=logs) 425 | 426 | def on_epoch_end(self, epoch, logs=None): 427 | if self.at_epoch_end: 428 | self.on_model_save(epoch, 0, logs=logs) 429 | self.current_epoch = epoch + 1 430 | 431 | def on_model_save(self, epoch, iter, logs=None): 432 | """ save the model to hdf5. Code mostly from keras core """ 433 | 434 | with timer.Timer('model save callback', self.verbose): 435 | logs = logs or {} 436 | self.steps_since_last_save += 1 437 | if self.steps_since_last_save >= self.period: 438 | self.steps_since_last_save = 0 439 | filepath = self.filepath.format(epoch=epoch, iter=iter, **logs) 440 | if self.save_best_only: 441 | current = logs.get(self.monitor) 442 | if current is None: 443 | warnings.warn('Can save best model only with %s available, ' 444 | 'skipping.' 
% (self.monitor), RuntimeWarning) 445 | else: 446 | if self.monitor_op(current, self.best): 447 | if self.verbose > 0: 448 | print('Epoch %05d Iter%05d: %s improved from %0.5f to %0.5f,' 449 | ' saving model to %s' 450 | % (epoch, iter, self.monitor, self.best, 451 | current, filepath)) 452 | self.best = current 453 | if self.save_weights_only: 454 | self.model.save_weights(filepath, overwrite=True) 455 | else: 456 | self.model.save(filepath, overwrite=True) 457 | else: 458 | if self.verbose > 0: 459 | print('Epoch %05d Iter%05d: %s did not improve' % 460 | (epoch, iter, self.monitor)) 461 | else: 462 | if self.verbose > 0: 463 | print('Epoch %05d: saving model to %s' % (epoch, filepath)) 464 | if self.save_weights_only: 465 | self.model.save_weights(filepath, overwrite=True) 466 | else: 467 | self.model.save(filepath, overwrite=True) 468 | 469 | 470 | class ModelCheckpointParallel(keras.callbacks.Callback): 471 | """ 472 | 473 | borrow from: https://github.com/rmkemker/main/blob/master/machine_learning/model_checkpoint_parallel.py 474 | 475 | Save the model after every epoch. 476 | `filepath` can contain named formatting options, 477 | which will be filled the value of `epoch` and 478 | keys in `logs` (passed in `on_epoch_end`). 479 | For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, 480 | then the model checkpoints will be saved with the epoch number and 481 | the validation loss in the filename. 482 | # Arguments 483 | filepath: string, path to save the model file. 484 | monitor: quantity to monitor. 485 | verbose: verbosity mode, 0 or 1. 486 | save_best_only: if `save_best_only=True`, 487 | the latest best model according to 488 | the quantity monitored will not be overwritten. 489 | mode: one of {auto, min, max}. 490 | If `save_best_only=True`, the decision 491 | to overwrite the current save file is made 492 | based on either the maximization or the 493 | minimization of the monitored quantity. For `val_acc`, 494 | this should be `max`, for `val_loss` this should 495 | be `min`, etc. In `auto` mode, the direction is 496 | automatically inferred from the name of the monitored quantity. 497 | save_weights_only: if True, then only the model's weights will be 498 | saved (`model.save_weights(filepath)`), else the full model 499 | is saved (`model.save(filepath)`). 500 | period: Interval (number of epochs) between checkpoints. 501 | """ 502 | 503 | def __init__(self, filepath, monitor='val_loss', verbose=0, 504 | save_best_only=False, save_weights_only=False, 505 | at_batch_end=None, 506 | at_epoch_end=True, 507 | mode='auto', period=1): 508 | super(ModelCheckpointParallel, self).__init__() 509 | self.monitor = monitor 510 | self.verbose = verbose 511 | self.filepath = filepath 512 | self.save_best_only = save_best_only 513 | self.save_weights_only = save_weights_only 514 | self.period = period 515 | self.epochs_since_last_save = 0 516 | 517 | if mode not in ['auto', 'min', 'max']: 518 | warnings.warn('ModelCheckpointParallel mode %s is unknown, ' 519 | 'fallback to auto mode.' 
% (mode), 520 | RuntimeWarning) 521 | mode = 'auto' 522 | 523 | if mode == 'min': 524 | self.monitor_op = np.less 525 | self.best = np.Inf 526 | elif mode == 'max': 527 | self.monitor_op = np.greater 528 | self.best = -np.Inf 529 | else: 530 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 531 | self.monitor_op = np.greater 532 | self.best = -np.Inf 533 | else: 534 | self.monitor_op = np.less 535 | self.best = np.Inf 536 | 537 | self.at_batch_end = at_batch_end 538 | self.at_epoch_end = at_epoch_end 539 | self.current_epoch = 0 540 | 541 | def on_epoch_begin(self, epoch, logs=None): 542 | self.current_epoch = epoch 543 | 544 | def on_batch_end(self, batch, logs=None): 545 | if self.at_batch_end is not None and np.mod(batch + 1, self.at_batch_end) == 0: 546 | print("Saving model at batch end!") 547 | self.on_model_save(self.current_epoch, batch + 1, logs=logs) 548 | 549 | def on_epoch_end(self, epoch, logs=None): 550 | if self.at_epoch_end: 551 | self.on_model_save(epoch, 0, logs=logs) 552 | self.current_epoch = epoch + 1 553 | 554 | def on_model_save(self, epoch, iter, logs=None): 555 | """ save the model to hdf5. Code mostly from keras core """ 556 | 557 | with timer.Timer('model save callback', self.verbose): 558 | logs = logs or {} 559 | num_outputs = len(self.model.outputs) 560 | self.epochs_since_last_save += 1 561 | if self.epochs_since_last_save >= self.period: 562 | self.epochs_since_last_save = 0 563 | filepath = self.filepath.format(epoch=epoch, iter=iter, **logs) 564 | if self.save_best_only: 565 | current = logs.get(self.monitor) 566 | if current is None: 567 | warnings.warn('Can save best model only with %s available, ' 568 | 'skipping.' % (self.monitor), RuntimeWarning) 569 | else: 570 | if self.monitor_op(current, self.best): 571 | if self.verbose > 0: 572 | print('Epoch %05d: Iter%05d: %s improved from %0.5f to %0.5f,' 573 | ' saving model to %s' 574 | % (epoch, iter, self.monitor, self.best, 575 | current, filepath)) 576 | self.best = current 577 | if self.save_weights_only: 578 | self.model.layers[-(num_outputs+1)].save_weights(filepath, overwrite=True) 579 | else: 580 | self.model.layers[-(num_outputs+1)].save(filepath, overwrite=True) 581 | else: 582 | if self.verbose > 0: 583 | print('Epoch %05d Iter%05d: %s did not improve' % 584 | (epoch, iter, self.monitor)) 585 | else: 586 | if self.verbose > 0: 587 | print('Epoch %05d: saving model to %s' % (epoch, filepath)) 588 | if self.save_weights_only: 589 | self.model.layers[-(num_outputs+1)].save_weights(filepath, overwrite=True) 590 | else: 591 | self.model.layers[-(num_outputs+1)].save(filepath, overwrite=True) 592 | 593 | 594 | 595 | ################################################################################################## 596 | # helper functions 597 | ################################################################################################## 598 | 599 | def _generate_predictions(model, data_generator, batch_size, nb_samples, vol_params): 600 | # whole volumes 601 | if vol_params is not None: 602 | for _ in range(nb_samples): # assumes nr volume 603 | vols = nrn_utils.predict_volumes(model, 604 | data_generator, 605 | batch_size, 606 | vol_params["patch_size"], 607 | vol_params["patch_stride"], 608 | vol_params["grid_size"]) 609 | vol_true, vol_pred = vols[0], vols[1] 610 | yield (vol_true, vol_pred) 611 | 612 | # just one batch 613 | else: 614 | for _ in range(nb_samples): # assumes nr batches 615 | vol_pred, vol_true = nrn_utils.next_label(model, data_generator) 616 | yield (vol_true, 
vol_pred)
617 | 
618 | import collections
619 | def _flatten(l):
620 |     # https://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists
621 |     for el in l:
622 |         if isinstance(el, collections.Iterable) and not isinstance(el, (str, bytes)):
623 |             yield from _flatten(el)
624 |         else:
625 |             yield el
626 | 
--------------------------------------------------------------------------------
/flowreg_a/neuron/dataproc.py:
--------------------------------------------------------------------------------
 1 | ''' data processing for neuron project '''
 2 | 
 3 | import os
 4 | import shutil
 5 | # built-in
 6 | import sys
 7 | 
 8 | import matplotlib.pyplot as plt
 9 | # third party
10 | import nibabel as nib
11 | import numpy as np
12 | import scipy.ndimage.interpolation
13 | import six
14 | from PIL import Image
15 | 
16 | # not sure if tqdm_notebook reverts back to plain tqdm outside notebooks
17 | try:
18 |     get_ipython
19 |     from tqdm import tqdm_notebook as tqdm
20 | except:
21 |     from tqdm import tqdm as tqdm
22 | 
23 | # import local ndutils
24 | import pynd.ndutils as nd
25 | import re
26 | 
27 | from imp import reload
28 | reload(nd)
29 | 
30 | # from imp import reload  # for re-loading modules, since some of the modules are still in development
31 | # reload(nd)
32 | 
33 | 
34 | def proc_mgh_vols(inpath,
35 |                   outpath,
36 |                   ext='.mgz',
37 |                   label_idx=None,
38 |                   **kwargs):
39 |     ''' process mgh data from mgz format and save to numpy format
40 | 
41 |     1. load file
42 |     2. normalize intensity
43 |     3. resize
44 |     4. save as python block
45 | 
46 |     TODO: check header info and such.
47 |     '''
48 | 
49 |     # get files in input directory
50 |     files = [f for f in os.listdir(inpath) if f.endswith(ext)]
51 | 
52 |     # go through each file
53 |     list_skipped_files = ()
54 |     for fileidx in tqdm(range(len(files)), ncols=80):
55 | 
56 |         # load nifti volume
57 |         volnii = nib.load(os.path.join(inpath, files[fileidx]))
58 | 
59 |         # get the data out
60 |         vol_data = volnii.get_data().astype(float)
61 | 
62 |         if ('dim' in volnii.header) and volnii.header['dim'][4] > 1:
63 |             vol_data = vol_data[:, :, :, -1]
64 | 
65 |         # process volume
66 |         try:
67 |             vol_data = vol_proc(vol_data, **kwargs)
68 |         except Exception as e:
69 |             list_skipped_files += (files[fileidx], )
70 |             print("Skipping %s\nError: %s" % (files[fileidx], str(e)), file=sys.stderr)
71 |             continue
72 | 
73 |         if label_idx is not None:
74 |             vol_data = (vol_data == label_idx).astype(int)
75 | 
76 |         # save numpy file
77 |         outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '.npz'
78 |         np.savez_compressed(outname, vol_data=vol_data)
79 | 
80 |     for file in list_skipped_files:
81 |         print("Skipped: %s" % file, file=sys.stderr)
82 | 
83 | 
84 | def scans_to_slices(inpath, outpath, slice_nrs,
85 |                     ext='.mgz',
86 |                     label_idx=None,
87 |                     dim_idx=2,
88 |                     out_ext='.png',
89 |                     slice_pad=0,
90 |                     vol_inner_pad_for_slice_nrs=0,
91 |                     **kwargs):  # vol_proc args
92 | 
93 |     # get files in input directory
94 |     files = [f for f in os.listdir(inpath) if f.endswith(ext)]
95 | 
96 |     # go through each file
97 |     list_skipped_files = ()
98 |     for fileidx in tqdm(range(len(files)), ncols=80):
99 | 
100 |         # load nifti volume
101 |         volnii = nib.load(os.path.join(inpath, files[fileidx]))
102 | 
103 |         # get the data out
104 |         vol_data = volnii.get_data().astype(float)
105 | 
106 |         if ('dim' in volnii.header) and volnii.header['dim'][4] > 1:
107 |             vol_data = vol_data[:, :, :, -1]
108 | 
109 |         if slice_pad > 0:
110 |             assert (out_ext != '.png'), "slice pad can only be used with volumes"
111 | 
112 |         # process volume
113 |         try:
114 |             vol_data = vol_proc(vol_data, **kwargs)
115 |         except Exception as e:
116 |             list_skipped_files += (files[fileidx], )
117 |             print("Skipping %s\nError: %s" % (files[fileidx], str(e)), file=sys.stderr)
118 |             continue
119 | 
120 |         mult_fact = 255
121 |         if label_idx is not None:
122 |             vol_data = (vol_data == label_idx).astype(int)
123 |             mult_fact = 1
124 | 
125 |         # extract slice
126 |         if slice_nrs is None:
127 |             slice_nrs_sel = range(vol_inner_pad_for_slice_nrs+slice_pad, vol_data.shape[dim_idx]-slice_pad-vol_inner_pad_for_slice_nrs)
128 |         else:
129 |             slice_nrs_sel = slice_nrs
130 | 
131 |         for slice_nr in slice_nrs_sel:
132 |             slice_nr_out = range(slice_nr - slice_pad, slice_nr + slice_pad + 1)
133 |             if dim_idx == 2:  # TODO: fix in one line
134 |                 vol_img = np.squeeze(vol_data[:, :, slice_nr_out])
135 |             elif dim_idx == 1:
136 |                 vol_img = np.squeeze(vol_data[:, slice_nr_out, :])
137 |             else:
138 |                 vol_img = np.squeeze(vol_data[slice_nr_out, :, :])
139 | 
140 |             # save file
141 |             if out_ext == '.png':
142 |                 # save png file
143 |                 img = (vol_img*mult_fact).astype('uint8')
144 |                 outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '_slice%d.png' % slice_nr
145 |                 Image.fromarray(img).convert('RGB').save(outname)
146 |             else:
147 |                 if slice_pad == 0:  # dimension has collapsed
148 |                     assert vol_img.ndim == 2
149 |                     vol_img = np.expand_dims(vol_img, dim_idx)
150 |                 # assuming nibabel saving image
151 |                 nii = nib.Nifti1Image(vol_img, np.diag([1,1,1,1]))
152 |                 outname = os.path.splitext(os.path.join(outpath, files[fileidx]))[0] + '_slice%d.nii.gz' % slice_nr
153 |                 nib.save(nii, outname)
154 | 
155 | 
156 | def vol_proc(vol_data,
157 |              crop=None,
158 |              resize_shape=None,  # None (to not resize), or vector. If vector, third entry can be None
159 |              interp_order=None,
160 |              rescale=None,
161 |              rescale_prctle=None,
162 |              resize_slices=None,
163 |              resize_slices_dim=None,
164 |              offset=None,
165 |              clip=None,
166 |              extract_nd=None,  # extracts a particular section
167 |              force_binary=None,  # forces anything > 0 to be 1
168 |              permute=None):
169 |     ''' process a volume with a series of intensity rescale, resize and crop rescale'''
170 | 
171 |     if offset is not None:
172 |         vol_data = vol_data + offset
173 | 
174 |     # intensity normalize data .* rescale
175 |     if rescale is not None:
176 |         vol_data = np.multiply(vol_data, rescale)
177 | 
178 |     if rescale_prctle is not None:
179 |         # print("max:", np.max(vol_data.flat))
180 |         # print("test")
181 |         rescale = np.percentile(vol_data.flat, rescale_prctle)
182 |         # print("rescaling by 1/%f" % (rescale))
183 |         vol_data = np.multiply(vol_data.astype(float), 1/rescale)
184 | 
185 |     if resize_slices is not None:
186 |         resize_slices = [*resize_slices]
187 |         assert resize_shape is None, "if resize_slices is given, resize_shape has to be None"
188 |         resize_shape = resize_slices
189 |         if resize_slices_dim is None:
190 |             resize_slices_dim = np.where([f is None for f in resize_slices])[0]
191 |             assert len(resize_slices_dim) == 1, "Could not find dimension or slice resize"
192 |             resize_slices_dim = resize_slices_dim[0]
193 |         resize_shape[resize_slices_dim] = vol_data.shape[resize_slices_dim]
194 | 
195 |     # resize (downsample) matrices
196 |     if resize_shape is not None and resize_shape != vol_data.shape:
197 |         resize_shape = [*resize_shape]
198 |         # allow for the last entry to be None
199 |         if resize_shape[-1] is None:
200 |             resize_ratio = np.divide(resize_shape[0], vol_data.shape[0])
201 |             resize_shape[-1] = np.round(resize_ratio * vol_data.shape[-1]).astype('int')
202 |         resize_ratio = np.divide(resize_shape, vol_data.shape)
203 |         vol_data = scipy.ndimage.interpolation.zoom(vol_data, resize_ratio, order=interp_order)
204 | 
205 |     # crop data if necessary
206 |     if crop is not None:
207 |         vol_data = nd.volcrop(vol_data, crop=crop)
208 | 
209 |     # needs to be last to guarantee clip limits.
210 |     # For e.g., resize might screw this up due to bicubic interpolation if it was done after.
211 |     if clip is not None:
212 |         vol_data = np.clip(vol_data, clip[0], clip[1])
213 | 
214 |     if extract_nd is not None:
215 |         vol_data = vol_data[np.ix_(*extract_nd)]
216 | 
217 |     if force_binary:
218 |         vol_data = (vol_data > 0).astype(float)
219 | 
220 |     # return with checks. this check should be right at the end before return
221 |     if clip is not None:
222 |         assert np.max(vol_data) <= clip[1], "clip failed"
223 |         assert np.min(vol_data) >= clip[0], "clip failed"
224 |     return vol_data
225 | 
226 | 
227 | def prior_to_weights(prior_filename, nargout=1, min_freq=0, force_binary=False, verbose=False):
228 | 
229 |     ''' transform a 4D prior (3D + nb_labels) into a class weight vector '''
230 | 
231 |     # load prior
232 |     if isinstance(prior_filename, six.string_types):
233 |         prior = np.load(prior_filename)['prior']
234 |     else:
235 |         prior = prior_filename
236 | 
237 |     # assumes prior is 4D.
238 |     assert np.ndim(prior) == 4 or np.ndim(prior) == 3, "prior is the wrong number of dimensions"
239 |     prior_flat = np.reshape(prior, (np.prod(prior.shape[0:(np.ndim(prior)-1)]), prior.shape[-1]))
240 | 
241 |     if force_binary:
242 |         nb_labels = prior_flat.shape[-1]
243 |         prior_flat[:, 1] = np.sum(prior_flat[:, 1:nb_labels], 1)
244 |         prior_flat = np.delete(prior_flat, range(2, nb_labels), 1)
245 | 
246 |     # sum total class votes
247 |     class_count = np.sum(prior_flat, 0)
248 |     class_prior = class_count / np.sum(class_count)
249 | 
250 |     # adding minimum frequency
251 |     class_prior[class_prior < min_freq] = min_freq
252 |     class_prior = class_prior / np.sum(class_prior)
253 | 
254 |     if np.any(class_prior == 0):
255 |         print("Warning, found a label with 0 support. Setting its weight to 0!", file=sys.stderr)
256 |         class_prior[class_prior == 0] = np.inf
257 | 
258 |     # compute weights from class frequencies
259 |     weights = 1/class_prior
260 |     weights = weights / np.sum(weights)
261 |     # weights[0] = 0  # explicitly don't care about bg
262 | 
263 |     # a bit of verbosity
264 |     if verbose:
265 |         f, (ax1, ax2, ax3) = plt.subplots(1, 3)
266 |         ax1.bar(range(prior.size), np.log(prior))
267 |         ax1.set_title('log class freq')
268 |         ax2.bar(range(weights.size), weights)
269 |         ax2.set_title('weights')
270 |         ax3.bar(range(weights.size), np.log((weights))-np.min(np.log((weights))))
271 |         ax3.set_title('log(weights)-minlog')
272 |         f.set_size_inches(12, 3)
273 |         plt.show()
274 |         np.set_printoptions(precision=3)
275 | 
276 |     # return
277 |     if nargout == 1:
278 |         return weights
279 |     else:
280 |         return (weights, prior)
281 | 
282 | 
283 | 
284 | 
285 | def filestruct_change(in_path, out_path, re_map,
286 |                       mode='subj_to_type',
287 |                       use_symlinks=False, name=""):
288 |     """
289 |     change from independent subjects in a folder to breakdown structure
290 | 
291 |     example: filestruct_change('/../in_path', '/../out_path', {'asegs.nii.gz':'asegs', 'norm.nii.gz':'vols'})
292 | 
293 | 
294 |     input structure:
295 |         /.../in_path/subj_1 --> with files that match regular expressions defined in re_map.keys()
296 |         /.../in_path/subj_2 --> with files that match regular expressions defined in re_map.keys()
297 |         ...
298 | output structure: 299 | /.../out_path/asegs/subj_1.nii.gz, subj_2.nii.gz 300 | /.../out_path/vols/subj_1.nii.gz, subj_2.nii.gz 301 | 302 | Parameters: 303 | in_path (string): input path 304 | out_path (string): output path 305 | re_map (dictionary): keys are reg-exs that match files in the input folders. 306 | values are the folders to put those files in the new structure. 307 | values can also be tuples, in which case values[0] is the dst folder, 308 | and values[1] is the extension of the output file 309 | mode (optional) 310 | use_symlinks (bool): whether to just use symlinks rather than copy files 311 | default:True 312 | """ 313 | 314 | 315 | if not os.path.isdir(out_path): 316 | os.mkdir(out_path) 317 | 318 | # go through folders 319 | for subj in tqdm(os.listdir(in_path), desc=name): 320 | 321 | # go through files in a folder 322 | files = os.listdir(os.path.join(in_path, subj)) 323 | for file in files: 324 | 325 | # see which key matches. Make sure only one does. 326 | matches = [re.match(k, file) for k in re_map.keys()] 327 | nb_matches = sum([f is not None for f in matches]) 328 | assert nb_matches == 1, "Found %d matches for file %s/%s" %(nb_matches, file, subj) 329 | 330 | # get the matches key 331 | match_idx = [i for i,f in enumerate(matches) if f is not None][0] 332 | matched_dst = re_map[list(re_map.keys())[match_idx]] 333 | _, ext = os.path.splitext(file) 334 | if isinstance(matched_dst, tuple): 335 | ext = matched_dst[1] 336 | matched_dst = matched_dst[0] 337 | 338 | # prepare source and destination file 339 | src_file = os.path.join(in_path, subj, file) 340 | dst_path = os.path.join(out_path, matched_dst) 341 | if not os.path.isdir(dst_path): 342 | os.mkdir(dst_path) 343 | dst_file = os.path.join(dst_path, subj + ext) 344 | 345 | if use_symlinks: 346 | # on windows there are permission problems. 347 | # Can try : call(['mklink', 'LINK', 'TARGET'], shell=True) 348 | # or note https://stackoverflow.com/questions/6260149/os-symlink-support-in-windows 349 | os.symlink(src_file, dst_file) 350 | 351 | else: 352 | shutil.copyfile(src_file, dst_file) 353 | 354 | 355 | def ml_split(in_path, out_path, 356 | cat_titles=['train', 'validate', 'test'], 357 | cat_prop=[0.5, 0.3, 0.2], 358 | use_symlinks=False, 359 | seed=None, 360 | tqdm=tqdm): 361 | """ 362 | split dataset 363 | """ 364 | 365 | if seed is not None: 366 | np.random.seed(seed) 367 | 368 | if not os.path.isdir(out_path): 369 | os.makedirs(out_path) 370 | 371 | # get subjects and randomize their order 372 | subjs = sorted(os.listdir(in_path)) 373 | nb_subj = len(subjs) 374 | subj_order = np.random.permutation(nb_subj) 375 | 376 | # prepare split 377 | cat_tot = np.cumsum(cat_prop) 378 | if not cat_tot[-1] == 1: 379 | print("split_prop sums to %f, re-normalizing" % cat_tot) 380 | cat_tot = np.array(cat_tot) / cat_tot[-1] 381 | nb_cat_subj = np.round(cat_tot * nb_subj).astype(int) 382 | cat_subj_start = [0, *nb_cat_subj[:-1]] 383 | 384 | # go through each category 385 | for cat_idx, cat in enumerate(cat_titles): 386 | if not os.path.isdir(os.path.join(out_path, cat)): 387 | os.mkdir(os.path.join(out_path, cat)) 388 | 389 | cat_subj_idx = subj_order[cat_subj_start[cat_idx]:nb_cat_subj[cat_idx]] 390 | for subj_idx in tqdm(cat_subj_idx, desc=cat): 391 | src_folder = os.path.join(in_path, subjs[subj_idx]) 392 | dst_folder = os.path.join(out_path, cat, subjs[subj_idx]) 393 | 394 | if use_symlinks: 395 | # on windows there are permission problems. 
396 | # Can try : call(['mklink', 'LINK', 'TARGET'], shell=True) 397 | # or note https://stackoverflow.com/questions/6260149/os-symlink-support-in-windows 398 | os.symlink(src_folder, dst_folder) 399 | 400 | else: 401 | if os.path.isdir(src_folder): 402 | shutil.copytree(src_folder, dst_folder) 403 | else: 404 | shutil.copyfile(src_folder, dst_folder) 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | -------------------------------------------------------------------------------- /flowreg_a/neuron/inits.py: -------------------------------------------------------------------------------- 1 | ''' initializations for the neuron project ''' 2 | 3 | # general imports 4 | import numpy as np 5 | import tensorflow.keras.backend as K 6 | 7 | 8 | def output_init(shape, name=None, dim_ordering=None): 9 | ''' initialization for output weights''' 10 | size = (shape[0], shape[1], shape[2] - shape[3], shape[3]) 11 | 12 | # initialize output weights with random and identity 13 | rpart = np.random.random(size) 14 | # idpart_ = np.eye(size[3]) 15 | idpart_ = np.ones((size[3], size[3])) 16 | idpart = np.expand_dims(np.expand_dims(idpart_, 0), 0) 17 | value = np.concatenate((rpart, idpart), axis=2) 18 | return K.variable(value, name=name) 19 | -------------------------------------------------------------------------------- /flowreg_a/neuron/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | tensorflow/keras utilities for the neuron project 3 | 4 | If you use this code, please cite 5 | Dalca AV, Guttag J, Sabuncu MR 6 | Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 7 | CVPR 2018 8 | 9 | Contact: adalca [at] csail [dot] mit [dot] edu 10 | License: GPLv3 11 | """ 12 | 13 | # third party 14 | import numpy as np 15 | import tensorflow as tf 16 | import tensorflow.keras.backend as K 17 | from tensorflow.keras import losses 18 | 19 | # local 20 | from . import utils 21 | 22 | 23 | class CategoricalCrossentropy(object): 24 | """ 25 | Categorical crossentropy with optional categorical weights and spatial prior 26 | 27 | Adapted from weighted categorical crossentropy via wassname: 28 | https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d 29 | 30 | Variables: 31 | weights: numpy array of shape (C,) where C is the number of classes 32 | 33 | Usage: 34 | loss = CategoricalCrossentropy().loss # or 35 | loss = CategoricalCrossentropy(weights=weights).loss # or 36 | loss = CategoricalCrossentropy(..., prior=prior).loss 37 | model.compile(loss=loss, optimizer='adam') 38 | """ 39 | 40 | def __init__(self, weights=None, use_float16=False, vox_weights=None, crop_indices=None): 41 | """ 42 | Parameters: 43 | vox_weights is either a numpy array the same size as y_true, 44 | or a string: 'y_true' or 'expy_true' 45 | crop_indices: indices to crop each element of the batch 46 | if each element is N-D (so y_true is N+1 dimensional) 47 | then crop_indices is a Tensor of crop ranges (indices) 48 | of size <= N-D. If it's < N-D, then it acts as a slice 49 | for the last few dimensions. 
50 |                 See Also: tf.gather_nd
51 |         """
52 | 
53 |         self.weights = weights if (weights is not None) else None
54 |         self.use_float16 = use_float16
55 |         self.vox_weights = vox_weights
56 |         self.crop_indices = crop_indices
57 | 
58 |         if self.crop_indices is not None and vox_weights is not None:
59 |             self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices)
60 | 
61 |     def loss(self, y_true, y_pred):
62 |         """ categorical crossentropy loss """
63 | 
64 |         if self.crop_indices is not None:
65 |             y_true = utils.batch_gather(y_true, self.crop_indices)
66 |             y_pred = utils.batch_gather(y_pred, self.crop_indices)
67 | 
68 |         if self.use_float16:
69 |             y_true = K.cast(y_true, 'float16')
70 |             y_pred = K.cast(y_pred, 'float16')
71 | 
72 |         # scale and clip probabilities
73 |         # this should not be necessary for softmax output.
74 |         y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
75 |         y_pred = K.clip(y_pred, K.epsilon(), 1)
76 | 
77 |         # compute log probability
78 |         log_post = K.log(y_pred)  # likelihood
79 | 
80 |         # loss
81 |         loss = - y_true * log_post
82 | 
83 |         # weighted loss
84 |         if self.weights is not None:
85 |             loss *= self.weights
86 | 
87 |         if self.vox_weights is not None:
88 |             loss *= self.vox_weights
89 | 
90 |         # take the total loss
91 |         # loss = K.batch_flatten(loss)
92 |         mloss = K.mean(K.sum(K.cast(loss, 'float32'), -1))
93 |         tf.verify_tensor_all_finite(mloss, 'Loss not finite')
94 |         return mloss
95 | 
96 | 
97 | class Dice(object):
98 |     """
99 |     Dice of two Tensors.
100 | 
101 |     Tensors should either be:
102 |     - probabilistic for each label
103 |         i.e. [batch_size, *vol_size, nb_labels], where vol_size is the size of the volume (n-dims)
104 |         e.g. for a 2D vol, y has 4 dimensions, where each entry is a prob for that voxel
105 |     - max_label
106 |         i.e. [batch_size, *vol_size], where vol_size is the size of the volume (n-dims).
107 |         e.g. for a 2D vol, y has 3 dimensions, where each entry is the max label of that voxel
108 | 
109 |     Variables:
110 |         nb_labels: optional numpy array of shape (L,) where L is the number of labels
111 |             if not provided, all non-background (0) labels are computed and averaged
112 |         weights: optional numpy array of shape (L,) giving relative weights of each label
113 |         input_type is 'prob', or 'max_label'
114 |         dice_type is hard or soft
115 | 
116 |     Usage:
117 |         diceloss = metrics.dice(weights=[1, 2, 3])
118 |         model.compile(diceloss, ...)
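    For reference, the soft Dice that `dice()` computes for each label is
        2 * sum(y_true * y_pred) / (sum(y_true^2) + sum(y_pred^2)),
    with the denominator floored at `area_reg` to guard against empty labels.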
119 | 120 | Test: 121 | import keras.utils as nd_utils 122 | reload(nrn_metrics) 123 | weights = [0.1, 0.2, 0.3, 0.4, 0.5] 124 | nb_labels = len(weights) 125 | vol_size = [10, 20] 126 | batch_size = 7 127 | 128 | dice_loss = metrics.Dice(nb_labels=nb_labels).loss 129 | dice = metrics.Dice(nb_labels=nb_labels).dice 130 | dice_wloss = metrics.Dice(nb_labels=nb_labels, weights=weights).loss 131 | 132 | # vectors 133 | lab_size = [batch_size, *vol_size] 134 | r = nd_utils.to_categorical(np.random.randint(0, nb_labels, lab_size), nb_labels) 135 | vec_1 = np.reshape(r, [*lab_size, nb_labels]) 136 | r = nd_utils.to_categorical(np.random.randint(0, nb_labels, lab_size), nb_labels) 137 | vec_2 = np.reshape(r, [*lab_size, nb_labels]) 138 | 139 | # get some standard vectors 140 | tf_vec_1 = tf.constant(vec_1, dtype=tf.float32) 141 | tf_vec_2 = tf.constant(vec_2, dtype=tf.float32) 142 | 143 | # compute some metrics 144 | res = [f(tf_vec_1, tf_vec_2) for f in [dice, dice_loss, dice_wloss]] 145 | res_same = [f(tf_vec_1, tf_vec_1) for f in [dice, dice_loss, dice_wloss]] 146 | 147 | # tf run 148 | init_op = tf.global_variables_initializer() 149 | with tf.Session() as sess: 150 | sess.run(init_op) 151 | sess.run(res) 152 | sess.run(res_same) 153 | print(res[2].eval()) 154 | print(res_same[2].eval()) 155 | """ 156 | 157 | def __init__(self, nb_labels, 158 | weights=None, 159 | input_type='prob', 160 | dice_type='soft', 161 | approx_hard_max=True, 162 | vox_weights=None, 163 | crop_indices=None, 164 | area_reg=0.1): # regularization for bottom of Dice coeff 165 | """ 166 | input_type is 'prob', or 'max_label' 167 | dice_type is hard or soft 168 | approx_hard_max - see note below 169 | 170 | Note: for hard dice, we grab the most likely label and then compute a 171 | one-hot encoding for each voxel with respect to possible labels. To grab the most 172 | likely labels, argmax() can be used, but only when Dice is used as a metric 173 | For a Dice *loss*, argmax is not differentiable, and so we can't use it 174 | Instead, we approximate the prob->one_hot translation when approx_hard_max is True. 
175 | """ 176 | 177 | self.nb_labels = nb_labels 178 | self.weights = None if weights is None else K.variable(weights) 179 | self.vox_weights = None if vox_weights is None else K.variable(vox_weights) 180 | self.input_type = input_type 181 | self.dice_type = dice_type 182 | self.approx_hard_max = approx_hard_max 183 | self.area_reg = area_reg 184 | self.crop_indices = crop_indices 185 | 186 | if self.crop_indices is not None and vox_weights is not None: 187 | self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices) 188 | 189 | def dice(self, y_true, y_pred): 190 | """ 191 | compute dice for given Tensors 192 | 193 | """ 194 | if self.crop_indices is not None: 195 | y_true = utils.batch_gather(y_true, self.crop_indices) 196 | y_pred = utils.batch_gather(y_pred, self.crop_indices) 197 | 198 | if self.input_type == 'prob': 199 | # We assume that y_true is probabilistic, but just in case: 200 | y_true /= K.sum(y_true, axis=-1, keepdims=True) 201 | y_true = K.clip(y_true, K.epsilon(), 1) 202 | 203 | # make sure pred is a probability 204 | y_pred /= K.sum(y_pred, axis=-1, keepdims=True) 205 | y_pred = K.clip(y_pred, K.epsilon(), 1) 206 | 207 | # Prepare the volumes to operate on 208 | # If we're doing 'hard' Dice, then we will prepare one-hot-based matrices of size 209 | # [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry, 210 | # the entries are either 0 or 1 211 | if self.dice_type == 'hard': 212 | 213 | # if given predicted probability, transform to "hard max"" 214 | if self.input_type == 'prob': 215 | if self.approx_hard_max: 216 | y_pred_op = _hard_max(y_pred, axis=-1) 217 | y_true_op = _hard_max(y_true, axis=-1) 218 | else: 219 | y_pred_op = _label_to_one_hot(K.argmax(y_pred, axis=-1), self.nb_labels) 220 | y_true_op = _label_to_one_hot(K.argmax(y_true, axis=-1), self.nb_labels) 221 | 222 | # if given predicted label, transform to one hot notation 223 | else: 224 | assert self.input_type == 'max_label' 225 | y_pred_op = _label_to_one_hot(y_pred, self.nb_labels) 226 | y_true_op = _label_to_one_hot(y_true, self.nb_labels) 227 | 228 | # If we're doing soft Dice, require prob output, and the data already is as we need it 229 | # [batch_size, nb_voxels, nb_labels] 230 | else: 231 | assert self.input_type == 'prob', "cannot do soft dice with max_label input" 232 | y_pred_op = y_pred 233 | y_true_op = y_true 234 | 235 | # reshape data to [batch_size, nb_voxels, nb_labels] 236 | flat_shape = tf.stack([-1, K.prod(K.shape(y_true_op)[1:-1]), K.shape(y_true_op)[-1]]) 237 | y_true_op = K.reshape(y_true_op, flat_shape) 238 | y_pred_op = K.reshape(y_pred_op, flat_shape) 239 | 240 | # compute dice for each entry in batch. 241 | # dice will now be [batch_size, nb_labels] 242 | top = 2 * K.sum(y_true_op * y_pred_op, 1) 243 | bottom = K.sum(K.square(y_true_op), 1) + K.sum(K.square(y_pred_op), 1) 244 | # make sure we have no 0s on the bottom. 
K.epsilon() 245 | bottom = K.maximum(bottom, self.area_reg) 246 | return top / bottom 247 | 248 | def mean_dice(self, y_true, y_pred): 249 | """ weighted mean dice across all patches and labels """ 250 | 251 | # compute dice, which will now be [batch_size, nb_labels] 252 | dice_metric = self.dice(y_true, y_pred) 253 | 254 | # weigh the entries in the dice matrix: 255 | if self.weights is not None: 256 | dice_metric *= self.weights 257 | if self.vox_weights is not None: 258 | dice_metric *= self.vox_weights 259 | 260 | # return the mean dice (this is a metric, not a loss, so no "1 - dice" here) 261 | mean_dice_metric = K.mean(dice_metric) 262 | tf.verify_tensor_all_finite(mean_dice_metric, 'metric not finite') 263 | return mean_dice_metric 264 | 265 | 266 | def loss(self, y_true, y_pred): 267 | """ the loss. Assumes y_pred is prob (in [0,1] and sum_row = 1) """ 268 | 269 | # compute dice, which will now be [batch_size, nb_labels] 270 | dice_metric = self.dice(y_true, y_pred) 271 | 272 | # loss 273 | dice_loss = 1 - dice_metric 274 | 275 | # weigh the entries in the dice matrix: 276 | if self.weights is not None: 277 | dice_loss *= self.weights 278 | 279 | # return one minus mean dice as loss 280 | mean_dice_loss = K.mean(dice_loss) 281 | tf.verify_tensor_all_finite(mean_dice_loss, 'Loss not finite') 282 | return mean_dice_loss 283 | 284 | 285 | class MeanSquaredError(): 286 | """ 287 | MSE with several weighting options 288 | """ 289 | 290 | 291 | def __init__(self, weights=None, vox_weights=None, crop_indices=None): 292 | """ 293 | Parameters: 294 | vox_weights is either a numpy array the same size as y_true, 295 | or a string: 'y_true' or 'expy_true' 296 | crop_indices: indices to crop each element of the batch 297 | if each element is N-D (so y_true is N+1 dimensional) 298 | then crop_indices is a Tensor of crop ranges (indices) 299 | of size <= N-D. If it's < N-D, then it acts as a slice 300 | for the last few dimensions.
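Example (editor's illustration): MeanSquaredError(vox_weights='y_true').loss weights each
squared error by the corresponding y_true intensity; 'expy_true' uses exp(y_true) instead.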
301 | See Also: tf.gather_nd 302 | """ 303 | self.weights = weights 304 | self.vox_weights = vox_weights 305 | self.crop_indices = crop_indices 306 | 307 | if self.crop_indices is not None and vox_weights is not None: 308 | self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices) 309 | 310 | def loss(self, y_true, y_pred): 311 | 312 | if self.crop_indices is not None: 313 | y_true = utils.batch_gather(y_true, self.crop_indices) 314 | y_pred = utils.batch_gather(y_pred, self.crop_indices) 315 | 316 | ksq = K.square(y_pred - y_true) 317 | 318 | if self.vox_weights is not None: 319 | if self.vox_weights == 'y_true': 320 | ksq *= y_true 321 | elif self.vox_weights == 'expy_true': 322 | ksq *= tf.exp(y_true) 323 | else: 324 | ksq *= self.vox_weights 325 | 326 | if self.weights is not None: 327 | ksq *= self.weights 328 | 329 | return K.mean(ksq) 330 | 331 | 332 | class Mix(): 333 | """ a mix of several losses """ 334 | 335 | def __init__(self, losses, loss_weights=None): 336 | self.losses = losses 337 | self.loss_weights = loss_weights 338 | if loss_weights is None: 339 | self.loss_weights = np.ones(len(losses)) 340 | 341 | def loss(self, y_true, y_pred): 342 | total_loss = K.variable(0) 343 | for idx, loss in enumerate(self.losses): 344 | total_loss += self.loss_weights[idx] * loss(y_true, y_pred) 345 | return total_loss 346 | 347 | 348 | class WGAN_GP(object): 349 | """ 350 | based on https://github.com/rarilurelo/keras_improved_wgan/blob/master/wgan_gp.py 351 | """ 352 | 353 | def __init__(self, disc, batch_size=1, lambda_gp=10): 354 | self.disc = disc 355 | self.lambda_gp = lambda_gp 356 | self.batch_size = batch_size 357 | 358 | def loss(self, y_true, y_pred): 359 | 360 | # get the value for the true and fake images 361 | disc_true = self.disc(y_true) 362 | disc_pred = self.disc(y_pred) 363 | 364 | # sample an x_hat by sampling along the line between true and pred 365 | # z = tf.placeholder(tf.float32, shape=[None, 1]) 366 | # shp = y_true.get_shape()[0] 367 | # WARNING: SHOULD REALLY BE shape=[batch_size, 1] !!! 368 | # self.batch_size does not work, since it's not None!!!
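# (editor's note) gradient-penalty recap: interp = x_true + alpha * (x_pred - x_true)
# with alpha ~ U(0, 1) drawn per batch element; the penalty below pushes
# ||grad_D(interp)||_2 toward 1, as in Gulrajani et al., "Improved Training of
# Wasserstein GANs" (2017). The rank-4 alpha shape assumes 2D image batches [B, H, W, C].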
369 | alpha = K.random_uniform(shape=[K.shape(y_pred)[0], 1, 1, 1]) 370 | diff = y_pred - y_true 371 | interp = y_true + alpha * diff 372 | 373 | # take gradient of D(x_hat) 374 | gradients = K.gradients(self.disc(interp), [interp])[0] 375 | grad_pen = K.mean(K.square(K.sqrt(K.sum(K.square(gradients), axis=1))-1)) 376 | 377 | # compute loss 378 | return (K.mean(disc_pred) - K.mean(disc_true)) + self.lambda_gp * grad_pen 379 | 380 | 381 | class Nonbg(object): 382 | """ UNTESTED 383 | class to modify output on operating only on the non-bg class 384 | 385 | All data is aggregated and the (passed) metric is called on flattened true and 386 | predicted outputs in all (true) non-bg regions 387 | 388 | Usage: 389 | loss = metrics.dice 390 | nonbgloss = nonbg(loss).loss 391 | """ 392 | 393 | def __init__(self, metric): 394 | self.metric = metric 395 | 396 | def loss(self, y_true, y_pred): 397 | """ prepare a loss of the given metric/loss operating on non-bg data """ 398 | yt = y_true #.eval() 399 | ytbg = np.where(yt == 0) 400 | y_true_fix = K.variable(yt[ytbg]) 401 | y_pred_fix = K.variable(y_pred[ytbg]) 402 | return self.metric(y_true_fix, y_pred_fix) 403 | 404 | 405 | def l1(y_true, y_pred): 406 | """ L1 metric (MAE) """ 407 | return losses.mean_absolute_error(y_true, y_pred) 408 | 409 | 410 | def l2(y_true, y_pred): 411 | """ L2 metric (MSE) """ 412 | return losses.mean_squared_error(y_true, y_pred) 413 | 414 | 415 | ############################################################################### 416 | # Helper Functions 417 | ############################################################################### 418 | 419 | def _label_to_one_hot(tens, nb_labels): 420 | """ 421 | Transform a label nD Tensor to a one-hot 3D Tensor. The input tensor is first 422 | batch-flattened, and then each batch and each voxel gets a one-hot representation 423 | """ 424 | y = K.batch_flatten(tens) 425 | return K.one_hot(y, nb_labels) 426 | 427 | 428 | def _hard_max(tens, axis): 429 | """ 430 | we can't use the argmax function in a loss, as it's not differentiable 431 | We can use it in a metric, but not in a loss function 432 | therefore, we replace the 'hard max' operation (i.e. argmax + onehot) 433 | with this approximation 434 | """ 435 | tensmax = K.max(tens, axis=axis, keepdims=True) 436 | eps_hot = K.maximum(tens - tensmax + K.epsilon(), 0) 437 | one_hot = eps_hot / K.epsilon() 438 | return one_hot 439 | -------------------------------------------------------------------------------- /flowreg_a/neuron/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | tensorflow/keras utilities for the neuron project 3 | 4 | If you use this code, please cite 5 | Dalca AV, Guttag J, Sabuncu MR 6 | Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 7 | CVPR 2018 8 | 9 | Contact: adalca [at] csail [dot] mit [dot] edu 10 | License: GPLv3 11 | """ 12 | 13 | import sys 14 | import tensorflow  # needed below for tensorflow.keras.activations 15 | # third party 16 | import numpy as np 17 | import tensorflow.keras.backend as K 18 | import tensorflow.keras.layers as KL 19 | from tensorflow.keras.models import Model 20 | from tensorflow.python.keras.constraints import maxnorm 21 | 22 | from . import layers 23 | 24 | 25 | ############################################################################### 26 | # Roughly volume preserving (e.g.
high dim to high dim) models 27 | ############################################################################### 28 | 29 | 30 | def dilation_net(nb_features, 31 | input_shape, # input layer shape, vector of size ndims + 1(nb_channels) 32 | nb_levels, 33 | conv_size, 34 | nb_labels, 35 | name='dilation_net', 36 | prefix=None, 37 | feat_mult=1, 38 | pool_size=2, 39 | use_logp=True, 40 | padding='same', 41 | dilation_rate_mult=1, 42 | activation='elu', 43 | use_residuals=False, 44 | final_pred_activation='softmax', 45 | nb_conv_per_level=1, 46 | add_prior_layer=False, 47 | add_prior_layer_reg=0, 48 | layer_nb_feats=None, 49 | batch_norm=None): 50 | 51 | return unet(nb_features, 52 | input_shape, # input layer shape, vector of size ndims + 1(nb_channels) 53 | nb_levels, 54 | conv_size, 55 | nb_labels, 56 | name='unet', 57 | prefix=None, 58 | feat_mult=1, 59 | pool_size=2, 60 | use_logp=True, 61 | padding='same', 62 | activation='elu', 63 | use_residuals=False, 64 | dilation_rate_mult=dilation_rate_mult, 65 | final_pred_activation='softmax', 66 | nb_conv_per_level=1, 67 | add_prior_layer=False, 68 | add_prior_layer_reg=0, 69 | layer_nb_feats=None, 70 | batch_norm=None) 71 | 72 | 73 | def unet(nb_features, 74 | input_shape, 75 | nb_levels, 76 | conv_size, 77 | nb_labels, 78 | name='unet', 79 | prefix=None, 80 | feat_mult=1, 81 | pool_size=2, 82 | use_logp=True, 83 | padding='same', 84 | dilation_rate_mult=1, 85 | activation='elu', 86 | use_residuals=False, 87 | final_pred_activation='softmax', 88 | nb_conv_per_level=1, 89 | add_prior_layer=False, 90 | add_prior_layer_reg=0, 91 | layer_nb_feats=None, 92 | conv_dropout=0, 93 | batch_norm=None): 94 | """ 95 | unet-style keras model with an overdose of parametrization. 96 | 97 | downsampling: 98 | 99 | for U-net like architecture, we need to use Deconvolution3D. 100 | However, this is not yet available (maybe soon, it's on a dev branch in github I believe) 101 | Until then, we'll upsample and convolve. 102 | TODO: Need to check that UpSampling3D actually does NN-upsampling! 103 | 104 | Parameters: 105 | nb_features: the number of features at each convolutional level 106 | see below for `feat_mult` and `layer_nb_feats` for modifiers to this number 107 | input_shape: input layer shape, vector of size ndims + 1 (nb_channels) 108 | conv_size: the convolution kernel size 109 | nb_levels: the number of Unet levels (number of downsamples) in the "encoder" 110 | (e.g. 4 would give you 4 levels in encoder, 4 in decoder) 111 | nb_labels: number of output channels 112 | name (default: 'unet'): the name of the network 113 | prefix (default: `name` value): prefix to be added to layer names 114 | feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. 115 | e.g. 
feat_mult of 2 and nb_features of 16 would yield 32 features in the 116 | second layer, 64 features in the third layer, etc 117 | pool_size (default: 2): max pooling size (integer or list if specifying per dimension) 118 | use_logp: 119 | padding: 120 | dilation_rate_mult: 121 | activation: 122 | use_residuals: 123 | final_pred_activation: 124 | nb_conv_per_level: 125 | add_prior_layer: 126 | add_prior_layer_reg: 127 | layer_nb_feats: 128 | conv_dropout: 129 | batch_norm: 130 | """ 131 | 132 | # naming 133 | model_name = name 134 | if prefix is None: 135 | prefix = model_name 136 | 137 | # volume size data 138 | ndims = len(input_shape) - 1 139 | if isinstance(pool_size, int): 140 | pool_size = (pool_size,) * ndims 141 | 142 | # get encoding model 143 | enc_model = conv_enc(nb_features, 144 | input_shape, 145 | nb_levels, 146 | conv_size, 147 | name=model_name, 148 | prefix=prefix, 149 | feat_mult=feat_mult, 150 | pool_size=pool_size, 151 | padding=padding, 152 | dilation_rate_mult=dilation_rate_mult, 153 | activation=activation, 154 | use_residuals=use_residuals, 155 | nb_conv_per_level=nb_conv_per_level, 156 | layer_nb_feats=layer_nb_feats, 157 | conv_dropout=conv_dropout, 158 | batch_norm=batch_norm) 159 | 160 | # get decoder 161 | # use_skip_connections=1 makes it a u-net 162 | lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None 163 | dec_model = conv_dec(nb_features, 164 | None, 165 | nb_levels, 166 | conv_size, 167 | nb_labels, 168 | name=model_name, 169 | prefix=prefix, 170 | feat_mult=feat_mult, 171 | pool_size=pool_size, 172 | use_skip_connections=1, 173 | padding=padding, 174 | dilation_rate_mult=dilation_rate_mult, 175 | activation=activation, 176 | use_residuals=use_residuals, 177 | final_pred_activation='linear' if add_prior_layer else final_pred_activation, 178 | nb_conv_per_level=nb_conv_per_level, 179 | batch_norm=batch_norm, 180 | layer_nb_feats=lnf, 181 | conv_dropout=conv_dropout, 182 | input_model=enc_model) 183 | 184 | final_model = dec_model 185 | if add_prior_layer: 186 | final_model = add_prior(dec_model, 187 | [*input_shape[:-1], nb_labels], 188 | name=model_name + '_prior', 189 | use_logp=use_logp, 190 | final_pred_activation=final_pred_activation, 191 | add_prior_layer_reg=add_prior_layer_reg) 192 | 193 | return final_model 194 | 195 | 196 | def ae(nb_features, 197 | input_shape, 198 | nb_levels, 199 | conv_size, 200 | nb_labels, 201 | enc_size, 202 | name='ae', 203 | prefix=None, 204 | feat_mult=1, 205 | pool_size=2, 206 | padding='same', 207 | activation='elu', 208 | use_residuals=False, 209 | nb_conv_per_level=1, 210 | batch_norm=None, 211 | enc_batch_norm=None, 212 | ae_type='conv', # 'dense', or 'conv' 213 | enc_lambda_layers=None, 214 | add_prior_layer=False, 215 | add_prior_layer_reg=0, 216 | use_logp=True, 217 | conv_dropout=0, 218 | include_mu_shift_layer=False, 219 | single_model=False, # whether to return a single model, or a tuple of models that can be stacked. 220 | final_pred_activation='softmax', 221 | do_vae=False): 222 | """ 223 | Convolutional Auto-Encoder. 224 | Optionally Variational. 225 | Optionally Dense middle layer 226 | 227 | "Mostly" in that the inner encoding can be (optionally) constructed via dense features. 228 | 229 | Parameters: 230 | do_vae (bool): whether to do a variational auto-encoder or not. 
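Example (editor's sketch; all shapes and sizes here are hypothetical):
        dec, mid, enc = ae(nb_features=16, input_shape=(32, 32, 1), nb_levels=3,
                           conv_size=3, nb_labels=1, enc_size=[10],
                           ae_type='dense', do_vae=True)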
231 | 232 | enc_lambda_layers functions to try: 233 | K.softsign 234 | 235 | a = 1 236 | longtanh = lambda x: K.tanh(x) * K.log(2 + a * abs(x)) 237 | """ 238 | 239 | # naming 240 | model_name = name 241 | 242 | # volume size data 243 | ndims = len(input_shape) - 1 244 | if isinstance(pool_size, int): 245 | pool_size = (pool_size,) * ndims 246 | 247 | # get encoding model 248 | enc_model = conv_enc(nb_features, 249 | input_shape, 250 | nb_levels, 251 | conv_size, 252 | name=model_name, 253 | feat_mult=feat_mult, 254 | pool_size=pool_size, 255 | padding=padding, 256 | activation=activation, 257 | use_residuals=use_residuals, 258 | nb_conv_per_level=nb_conv_per_level, 259 | conv_dropout=conv_dropout, 260 | batch_norm=batch_norm) 261 | 262 | # middle AE structure 263 | if single_model: 264 | in_input_shape = None 265 | in_model = enc_model 266 | else: 267 | in_input_shape = enc_model.output.shape.as_list()[1:] 268 | in_model = None 269 | mid_ae_model = single_ae(enc_size, 270 | in_input_shape, 271 | conv_size=conv_size, 272 | name=model_name, 273 | ae_type=ae_type, 274 | input_model=in_model, 275 | batch_norm=enc_batch_norm, 276 | enc_lambda_layers=enc_lambda_layers, 277 | include_mu_shift_layer=include_mu_shift_layer, 278 | do_vae=do_vae) 279 | 280 | # decoder 281 | if single_model: 282 | in_input_shape = None 283 | in_model = mid_ae_model 284 | else: 285 | in_input_shape = mid_ae_model.output.shape.as_list()[1:] 286 | in_model = None 287 | dec_model = conv_dec(nb_features, 288 | in_input_shape, 289 | nb_levels, 290 | conv_size, 291 | nb_labels, 292 | name=model_name, 293 | feat_mult=feat_mult, 294 | pool_size=pool_size, 295 | use_skip_connections=False, 296 | padding=padding, 297 | activation=activation, 298 | use_residuals=use_residuals, 299 | final_pred_activation='linear', 300 | nb_conv_per_level=nb_conv_per_level, 301 | batch_norm=batch_norm, 302 | conv_dropout=conv_dropout, 303 | input_model=in_model) 304 | 305 | if add_prior_layer: 306 | dec_model = add_prior(dec_model, 307 | [*input_shape[:-1],nb_labels], 308 | name=model_name, 309 | prefix=model_name + '_prior', 310 | use_logp=use_logp, 311 | final_pred_activation=final_pred_activation, 312 | add_prior_layer_reg=add_prior_layer_reg) 313 | 314 | if single_model: 315 | return dec_model 316 | else: 317 | return (dec_model, mid_ae_model, enc_model) 318 | 319 | 320 | def add_prior(input_model, 321 | prior_shape, 322 | name='prior_model', 323 | prefix=None, 324 | use_logp=True, 325 | final_pred_activation='softmax', 326 | add_prior_layer_reg=0): 327 | """ 328 | Append post-prior layer to a given model 329 | """ 330 | 331 | # naming 332 | model_name = name 333 | if prefix is None: 334 | prefix = model_name 335 | 336 | # prior input layer 337 | prior_input_name = '%s-input' % prefix 338 | prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) 339 | prior_tensor_input = prior_tensor 340 | like_tensor = input_model.output 341 | 342 | # operation varies depending on whether we log() prior or not. 343 | if use_logp: 344 | # name = '%s-log' % prefix 345 | # prior_tensor = KL.Lambda(_log_layer_wrap(add_prior_layer_reg), name=name)(prior_tensor) 346 | print("Breaking change: use_logp option now requires log input!", file=sys.stderr) 347 | merge_op = KL.add 348 | 349 | else: 350 | # using sigmoid to get the likelihood values between 0 and 1 351 | # note: they won't add up to 1. 
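# (editor's summary of the two merge modes) with use_logp the model expects a
# log-prior input and forms posterior = log_prior + likelihood via KL.add;
# otherwise the likelihood is squashed with a sigmoid and merged with the prior
# multiplicatively (KL.multiply) below.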
352 | name = '%s_likelihood_sigmoid' % prefix 353 | like_tensor = KL.Activation('sigmoid', name=name)(like_tensor) 354 | merge_op = KL.multiply 355 | 356 | # merge the likelihood and prior layers into posterior layer 357 | name = '%s_posterior' % prefix 358 | post_tensor = merge_op([prior_tensor, like_tensor], name=name) 359 | 360 | # output prediction layer 361 | # we use a softmax to compute P(L_x|I) where x is each location 362 | pred_name = '%s_prediction' % prefix 363 | if final_pred_activation == 'softmax': 364 | assert use_logp, 'cannot do softmax when adding prior via P()' 365 | print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) 366 | softmax_lambda_fcn = lambda x: tensorflow.keras.activations.softmax(x, axis=-1) 367 | pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) 368 | 369 | else: 370 | pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor) 371 | 372 | # create the model 373 | model_inputs = [*input_model.inputs, prior_tensor_input] 374 | model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name) 375 | 376 | # compile 377 | return model 378 | 379 | 380 | def single_ae(enc_size, 381 | input_shape, 382 | name='single_ae', 383 | prefix=None, 384 | ae_type='dense', # 'dense', or 'conv' 385 | conv_size=None, 386 | input_model=None, 387 | enc_lambda_layers=None, 388 | batch_norm=True, 389 | padding='same', 390 | activation=None, 391 | include_mu_shift_layer=False, 392 | do_vae=False): 393 | """ 394 | single-layer Autoencoder (i.e. input - encoding - output) 395 | """ 396 | 397 | # naming 398 | model_name = name 399 | if prefix is None: 400 | prefix = model_name 401 | 402 | if enc_lambda_layers is None: 403 | enc_lambda_layers = [] 404 | 405 | # prepare input 406 | input_name = '%s_input' % prefix 407 | if input_model is None: 408 | assert input_shape is not None, 'input_shape of input_model is necessary' 409 | input_tensor = KL.Input(shape=input_shape, name=input_name) 410 | last_tensor = input_tensor 411 | else: 412 | input_tensor = input_model.input 413 | last_tensor = input_model.output 414 | input_shape = last_tensor.shape.as_list()[1:] 415 | input_nb_feats = last_tensor.shape.as_list()[-1] 416 | 417 | # prepare conv type based on input 418 | if ae_type == 'conv': 419 | ndims = len(input_shape) - 1 420 | convL = getattr(KL, 'Conv%dD' % ndims) 421 | assert conv_size is not None, 'with conv ae, need conv_size' 422 | conv_kwargs = {'padding': padding, 'activation': activation} 423 | 424 | 425 | 426 | # if want to go through a dense layer in the middle of the U, need to: 427 | # - flatten last layer if not flat 428 | # - do dense encoding and decoding 429 | # - unflatten (rehsape spatially) at end 430 | if ae_type == 'dense' and len(input_shape) > 1: 431 | name = '%s_ae_%s_down_flat' % (prefix, ae_type) 432 | last_tensor = KL.Flatten(name=name)(last_tensor) 433 | 434 | # recall this layer 435 | pre_enc_layer = last_tensor 436 | 437 | # encoding layer 438 | if ae_type == 'dense': 439 | assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" 440 | 441 | enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1] 442 | name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str) 443 | last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) 444 | 445 | else: # convolution 446 | # convolve then resize. 
enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] 447 | assert len(enc_size) == len(input_shape), \ 448 | "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape)) 449 | 450 | if list(enc_size)[:-1] != list(input_shape)[:-1] and \ 451 | all([f is not None for f in input_shape[:-1]]) and \ 452 | all([f is not None for f in enc_size[:-1]]): 453 | 454 | # assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing -- need to check out interpn!" 455 | name = '%s_ae_mu_enc_conv' % (prefix) 456 | last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) 457 | 458 | name = '%s_ae_mu_enc' % (prefix) 459 | zf = [enc_size[:-1][f]/last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size)-1)] 460 | last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) 461 | # resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) 462 | # last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) 463 | 464 | elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck 465 | name = '%s_ae_mu_enc' % (prefix) 466 | last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) 467 | 468 | else: 469 | name = '%s_ae_mu_enc' % (prefix) 470 | last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) 471 | 472 | if include_mu_shift_layer: 473 | # shift 474 | name = '%s_ae_mu_shift' % (prefix) 475 | last_tensor = layers.LocalBias(name=name)(last_tensor) 476 | 477 | # encoding clean-up layers 478 | for layer_fcn in enc_lambda_layers: 479 | lambda_name = layer_fcn.__name__ 480 | name = '%s_ae_mu_%s' % (prefix, lambda_name) 481 | last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) 482 | 483 | if batch_norm is not None: 484 | name = '%s_ae_mu_bn' % (prefix) 485 | last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 486 | 487 | # have a simple layer that does nothing to have a clear name before sampling 488 | name = '%s_ae_mu' % (prefix) 489 | last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) 490 | 491 | 492 | # if doing variational AE, will need the sigma layer as well. 493 | if do_vae: 494 | mu_tensor = last_tensor 495 | 496 | # encoding layer 497 | if ae_type == 'dense': 498 | name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str) 499 | last_tensor = KL.Dense(enc_size[0], name=name, 500 | # kernel_initializer=tensorflow.keras.initializers.RandomNormal(mean=0.0, stddev=1e-5), 501 | # bias_initializer=tensorflow.keras.initializers.RandomNormal(mean=-5.0, stddev=1e-5) 502 | )(pre_enc_layer) 503 | 504 | else: 505 | if list(enc_size)[:-1] != list(input_shape)[:-1] and \ 506 | all([f is not None for f in input_shape[:-1]]) and \ 507 | all([f is not None for f in enc_size[:-1]]): 508 | 509 | # assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..." 
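# (editor's note) this mirrors the mu branch above: convolve down to
# enc_size[-1] features, then Resize by per-axis zoom factors
# enc_size[i] / current_shape[i].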
510 | name = '%s_ae_sigma_enc_conv' % (prefix) 511 | last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) 512 | 513 | name = '%s_ae_sigma_enc' % (prefix) 514 | zf = [enc_size[:-1][f]/last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size)-1)] 515 | last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) 516 | # resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) 517 | # last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) 518 | 519 | elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck 520 | name = '%s_ae_sigma_enc' % (prefix) 521 | last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) 522 | # cannot use lambda, then mu and sigma will be same layer. 523 | # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) 524 | 525 | else: 526 | name = '%s_ae_sigma_enc' % (prefix) 527 | last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) 528 | 529 | # encoding clean-up layers 530 | for layer_fcn in enc_lambda_layers: 531 | lambda_name = layer_fcn.__name__ 532 | name = '%s_ae_sigma_%s' % (prefix, lambda_name) 533 | last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) 534 | 535 | if batch_norm is not None: 536 | name = '%s_ae_sigma_bn' % (prefix) 537 | last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 538 | 539 | # have a simple layer that does nothing to have a clear name before sampling 540 | name = '%s_ae_sigma' % (prefix) 541 | last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) 542 | 543 | logvar_tensor = last_tensor 544 | 545 | # VAE sampling 546 | name = '%s_ae_sample' % (prefix) 547 | last_tensor = layers.SampleNormalLogVar(name=name)([mu_tensor, logvar_tensor]) 548 | 549 | if include_mu_shift_layer: 550 | # shift 551 | name = '%s_ae_sample_shift' % (prefix) 552 | last_tensor = layers.LocalBias(name=name)(last_tensor) 553 | 554 | # decoding layer 555 | if ae_type == 'dense': 556 | name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str) 557 | last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) 558 | 559 | # unflatten if dense method 560 | if len(input_shape) > 1: 561 | name = '%s_ae_%s_dec' % (prefix, ae_type) 562 | last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) 563 | 564 | else: 565 | 566 | if list(enc_size)[:-1] != list(input_shape)[:-1] and \ 567 | all([f is not None for f in input_shape[:-1]]) and \ 568 | all([f is not None for f in enc_size[:-1]]): 569 | 570 | name = '%s_ae_mu_dec' % (prefix) 571 | zf = [input_shape[:-1][f]/enc_size[:-1][f] for f in range(len(enc_size)-1)] 572 | last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) 573 | # resize_fn = lambda x: tf.image.resize_bilinear(x, input_shape[:-1]) 574 | # last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) 575 | 576 | name = '%s_ae_%s_dec' % (prefix, ae_type) 577 | last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor) 578 | 579 | 580 | if batch_norm is not None: 581 | name = '%s_bn_ae_%s_dec' % (prefix, ae_type) 582 | last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 583 | 584 | # create the model and retun 585 | model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) 586 | return model 587 | 588 | 589 | 590 | 591 | ############################################################################### 592 | # Encoders, decoders, etc. 
593 | ############################################################################### 594 | 595 | 596 | def conv_enc(nb_features, 597 | input_shape, 598 | nb_levels, 599 | conv_size, 600 | name=None, 601 | prefix=None, 602 | feat_mult=1, 603 | pool_size=2, 604 | dilation_rate_mult=1, 605 | padding='same', 606 | activation='elu', 607 | layer_nb_feats=None, 608 | use_residuals=False, 609 | nb_conv_per_level=2, 610 | conv_dropout=0, 611 | batch_norm=None): 612 | """ 613 | Fully Convolutional Encoder 614 | """ 615 | 616 | # naming 617 | model_name = name 618 | if prefix is None: 619 | prefix = model_name 620 | 621 | # volume size data 622 | ndims = len(input_shape) - 1 623 | input_shape = tuple(input_shape) 624 | if isinstance(pool_size, int): 625 | pool_size = (pool_size,) * ndims 626 | 627 | # prepare layers 628 | convL = getattr(KL, 'Conv%dD' % ndims) 629 | conv_kwargs = {'padding': padding, 'activation': activation} 630 | maxpool = getattr(KL, 'MaxPooling%dD' % ndims) 631 | 632 | # first layer: input 633 | name = '%s_input' % prefix 634 | last_tensor = KL.Input(shape=input_shape, name=name) 635 | input_tensor = last_tensor 636 | 637 | # down arm: 638 | # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers 639 | lfidx = 0 640 | for level in range(nb_levels): 641 | lvl_first_tensor = last_tensor 642 | nb_lvl_feats = np.round(nb_features*feat_mult**level).astype(int) 643 | conv_kwargs['dilation_rate'] = dilation_rate_mult**level 644 | 645 | for conv in range(nb_conv_per_level): 646 | if layer_nb_feats is not None: 647 | nb_lvl_feats = layer_nb_feats[lfidx] 648 | lfidx += 1 649 | 650 | name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) 651 | if conv < (nb_conv_per_level-1) or (not use_residuals): 652 | last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) 653 | else: # no activation 654 | last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) 655 | 656 | if conv_dropout > 0: 657 | # conv dropout along feature space only 658 | name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) 659 | noise_shape = [None, *[1]*ndims, nb_lvl_feats] 660 | last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) 661 | 662 | if use_residuals: 663 | convarm_layer = last_tensor 664 | 665 | # the "add" layer is the original input 666 | # However, it may not have the right number of features to be added 667 | nb_feats_in = lvl_first_tensor.get_shape()[-1] 668 | nb_feats_out = convarm_layer.get_shape()[-1] 669 | add_layer = lvl_first_tensor 670 | if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): 671 | name = '%s_expand_down_merge_%d' % (prefix, level) 672 | last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) 673 | add_layer = last_tensor 674 | 675 | if conv_dropout > 0: 676 | name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv) 677 | noise_shape = [None, *[1]*ndims, nb_lvl_feats] 678 | last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) 679 | 680 | name = '%s_res_down_merge_%d' % (prefix, level) 681 | last_tensor = KL.add([add_layer, convarm_layer], name=name) 682 | 683 | name = '%s_res_down_merge_act_%d' % (prefix, level) 684 | last_tensor = KL.Activation(activation, name=name)(last_tensor) 685 | 686 | if batch_norm is not None: 687 | name = '%s_bn_down_%d' % (prefix, level) 688 | last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 689 | 690 | # max pool if we're not at the 
last level 691 | if level < (nb_levels - 1): 692 | name = '%s_maxpool_%d' % (prefix, level) 693 | last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) 694 | 695 | # create the model and return 696 | model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) 697 | return model 698 | 699 | 700 | def conv_dec(nb_features, 701 | input_shape, 702 | nb_levels, 703 | conv_size, 704 | nb_labels, 705 | name=None, 706 | prefix=None, 707 | feat_mult=1, 708 | pool_size=2, 709 | use_skip_connections=False, 710 | padding='same', 711 | dilation_rate_mult=1, 712 | activation='elu', 713 | use_residuals=False, 714 | final_pred_activation='softmax', 715 | nb_conv_per_level=2, 716 | layer_nb_feats=None, 717 | batch_norm=None, 718 | conv_dropout=0, 719 | input_model=None): 720 | """ 721 | Fully Convolutional Decoder 722 | 723 | Parameters: 724 | ... 725 | use_skip_connections (bool): if true, turns an Enc-Dec to a U-Net. 726 | If true, input_tensor and tensors are required. 727 | It assumes a particular naming of layers. conv_enc... 728 | """ 729 | 730 | # naming 731 | model_name = name 732 | if prefix is None: 733 | prefix = model_name 734 | 735 | # if using skip connections, make sure need to use them. 736 | if use_skip_connections: 737 | assert input_model is not None, "is using skip connections, tensors dictionary is required" 738 | 739 | # first layer: input 740 | input_name = '%s_input' % prefix 741 | if input_model is None: 742 | input_tensor = KL.Input(shape=input_shape, name=input_name) 743 | last_tensor = input_tensor 744 | else: 745 | input_tensor = input_model.input 746 | last_tensor = input_model.output 747 | input_shape = last_tensor.shape.as_list()[1:] 748 | 749 | # vol size info 750 | ndims = len(input_shape) - 1 751 | input_shape = tuple(input_shape) 752 | if isinstance(pool_size, int): 753 | if ndims > 1: 754 | pool_size = (pool_size,) * ndims 755 | 756 | # prepare layers 757 | convL = getattr(KL, 'Conv%dD' % ndims) 758 | conv_kwargs = {'padding': padding, 'activation': activation} 759 | upsample = getattr(KL, 'UpSampling%dD' % ndims) 760 | 761 | # up arm: 762 | # nb_levels - 1 layers of Deconvolution3D 763 | # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu 764 | lfidx = 0 765 | for level in range(nb_levels - 1): 766 | nb_lvl_feats = np.round(nb_features*feat_mult**(nb_levels-2-level)).astype(int) 767 | conv_kwargs['dilation_rate'] = dilation_rate_mult**(nb_levels-2-level) 768 | 769 | # upsample matching the max pooling layers size 770 | name = '%s_up_%d' % (prefix, nb_levels + level) 771 | last_tensor = upsample(size=pool_size, name=name)(last_tensor) 772 | up_tensor = last_tensor 773 | 774 | # merge layers combining previous layer 775 | # TODO: add Cropping3D or Cropping2D if 'valid' padding 776 | if use_skip_connections: 777 | conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) 778 | cat_tensor = input_model.get_layer(conv_name).output 779 | name = '%s_merge_%d' % (prefix, nb_levels + level) 780 | last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims+1, name=name) 781 | 782 | # convolution layers 783 | for conv in range(nb_conv_per_level): 784 | if layer_nb_feats is not None: 785 | nb_lvl_feats = layer_nb_feats[lfidx] 786 | lfidx += 1 787 | 788 | name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) 789 | if conv < (nb_conv_per_level-1) or (not use_residuals): 790 | last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) 791 | else: 
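# (editor's note) the last conv of a residual block is linear here; the
# activation is applied after the residual add ('res_up_merge_act') below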
792 | last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) 793 | 794 | if conv_dropout > 0: 795 | name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) 796 | noise_shape = [None, *[1]*ndims, nb_lvl_feats] 797 | last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) 798 | 799 | # residual block 800 | if use_residuals: 801 | 802 | # the "add" layer is the original input 803 | # However, it may not have the right number of features to be added 804 | add_layer = up_tensor 805 | nb_feats_in = add_layer.get_shape()[-1] 806 | nb_feats_out = last_tensor.get_shape()[-1] 807 | if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): 808 | name = '%s_expand_up_merge_%d' % (prefix, level) 809 | add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) 810 | 811 | if conv_dropout > 0: 812 | name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv) 813 | noise_shape = [None, *[1]*ndims, nb_lvl_feats] 814 | last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) 815 | 816 | name = '%s_res_up_merge_%d' % (prefix, level) 817 | last_tensor = KL.add([last_tensor, add_layer], name=name) 818 | 819 | name = '%s_res_up_merge_act_%d' % (prefix, level) 820 | last_tensor = KL.Activation(activation, name=name)(last_tensor) 821 | 822 | if batch_norm is not None: 823 | name = '%s_bn_up_%d' % (prefix, level) 824 | last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 825 | 826 | # Compute likelihood prediction (no activation yet) 827 | name = '%s_likelihood' % prefix 828 | last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) 829 | like_tensor = last_tensor 830 | 831 | # output prediction layer 832 | # we use a softmax to compute P(L_x|I) where x is each location 833 | if final_pred_activation == 'softmax': 834 | print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) 835 | name = '%s_prediction' % prefix 836 | softmax_lambda_fcn = lambda x: tensorflow.keras.activations.softmax(x, axis=ndims + 1) 837 | pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) 838 | 839 | # otherwise create a layer that does nothing. 840 | else: 841 | name = '%s_prediction' % prefix 842 | pred_tensor = KL.Activation('linear', name=name)(like_tensor) 843 | 844 | # create the model and return 845 | model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) 846 | return model 847 | 848 | 849 | def design_dnn(nb_features, input_shape, nb_levels, conv_size, nb_labels, 850 | feat_mult=1, 851 | pool_size=2, 852 | padding='same', 853 | activation='elu', 854 | final_layer='dense-sigmoid', 855 | conv_dropout=0, 856 | conv_maxnorm=0, 857 | nb_input_features=1, 858 | batch_norm=False, 859 | name=None, 860 | prefix=None, 861 | use_strided_convolution_maxpool=True, 862 | nb_conv_per_level=2): 863 | """ 864 | "deep" cnn with dense or global max pooling layer @ end... 865 | 866 | Could use sequential...
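Example (editor's sketch; parameter values are hypothetical):
        model = design_dnn(nb_features=8, input_shape=(64, 64, 64), nb_levels=3,
                           conv_size=3, nb_labels=2, final_layer='dense-softmax')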
867 | """ 868 | 869 | 870 | def _global_max_nd(xtens): 871 | ytens = K.batch_flatten(xtens) 872 | return K.max(ytens, 1, keepdims=True) 873 | 874 | 875 | model_name = name 876 | if model_name is None: 877 | model_name = 'model_1' 878 | if prefix is None: 879 | prefix = model_name 880 | 881 | ndims = len(input_shape) 882 | input_shape = tuple(input_shape) 883 | 884 | convL = getattr(KL, 'Conv%dD' % ndims) 885 | maxpool = KL.MaxPooling3D if len(input_shape) == 3 else KL.MaxPooling2D 886 | if isinstance(pool_size, int): 887 | pool_size = (pool_size,) * ndims 888 | 889 | # kwargs for the convolution layer 890 | conv_kwargs = {'padding': padding, 'activation': activation} 891 | if conv_maxnorm > 0: 892 | conv_kwargs['kernel_constraint'] = maxnorm(conv_maxnorm) 893 | 894 | # initialize a dictionary 895 | enc_tensors = {} 896 | 897 | # first layer: input 898 | name = '%s_input' % prefix 899 | enc_tensors[name] = KL.Input(shape=input_shape + (nb_input_features,), name=name) 900 | last_tensor = enc_tensors[name] 901 | 902 | # down arm: 903 | # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers 904 | for level in range(nb_levels): 905 | for conv in range(nb_conv_per_level): 906 | if conv_dropout > 0: 907 | name = '%s_dropout_%d_%d' % (prefix, level, conv) 908 | enc_tensors[name] = KL.Dropout(conv_dropout)(last_tensor) 909 | last_tensor = enc_tensors[name] 910 | 911 | name = '%s_conv_%d_%d' % (prefix, level, conv) 912 | nb_lvl_feats = np.round(nb_features*feat_mult**level).astype(int) 913 | enc_tensors[name] = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) 914 | last_tensor = enc_tensors[name] 915 | 916 | # max pool 917 | if use_strided_convolution_maxpool: 918 | name = '%s_strided_conv_%d' % (prefix, level) 919 | enc_tensors[name] = convL(nb_lvl_feats, pool_size, **conv_kwargs, name=name)(last_tensor) 920 | last_tensor = enc_tensors[name] 921 | else: 922 | name = '%s_maxpool_%d' % (prefix, level) 923 | enc_tensors[name] = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) 924 | last_tensor = enc_tensors[name] 925 | 926 | # dense layer 927 | if final_layer == 'dense-sigmoid': 928 | 929 | name = "%s_flatten" % prefix 930 | enc_tensors[name] = KL.Flatten(name=name)(last_tensor) 931 | last_tensor = enc_tensors[name] 932 | 933 | name = '%s_dense' % prefix 934 | enc_tensors[name] = KL.Dense(1, name=name, activation="sigmoid")(last_tensor) 935 | 936 | elif final_layer == 'dense-tanh': 937 | 938 | name = "%s_flatten" % prefix 939 | enc_tensors[name] = KL.Flatten(name=name)(last_tensor) 940 | last_tensor = enc_tensors[name] 941 | 942 | name = '%s_dense' % prefix 943 | enc_tensors[name] = KL.Dense(1, name=name)(last_tensor) 944 | last_tensor = enc_tensors[name] 945 | 946 | # Omittting BatchNorm for now, it seems to have a cpu vs gpu problem 947 | # https://github.com/tensorflow/tensorflow/pull/8906 948 | # https://github.com/fchollet/keras/issues/5802 949 | # name = '%s_%s_bn' % prefix 950 | # enc_tensors[name] = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 951 | # last_tensor = enc_tensors[name] 952 | 953 | name = '%s_%s_tanh' % prefix 954 | enc_tensors[name] = KL.Activation(activation="tanh", name=name)(last_tensor) 955 | 956 | elif final_layer == 'dense-softmax': 957 | 958 | name = "%s_flatten" % prefix 959 | enc_tensors[name] = KL.Flatten(name=name)(last_tensor) 960 | last_tensor = enc_tensors[name] 961 | 962 | name = '%s_dense' % prefix 963 | enc_tensors[name] = KL.Dense(nb_labels, name=name, 
activation="softmax")(last_tensor) 964 | 965 | # global max pooling layer 966 | elif final_layer == 'myglobalmaxpooling': 967 | 968 | name = '%s_batch_norm' % prefix 969 | enc_tensors[name] = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) 970 | last_tensor = enc_tensors[name] 971 | 972 | name = '%s_global_max_pool' % prefix 973 | enc_tensors[name] = KL.Lambda(_global_max_nd, name=name)(last_tensor) 974 | last_tensor = enc_tensors[name] 975 | 976 | name = '%s_global_max_pool_reshape' % prefix 977 | enc_tensors[name] = KL.Reshape((1, 1), name=name)(last_tensor) 978 | last_tensor = enc_tensors[name] 979 | 980 | # cannot do activation in lambda layer. Could code inside, but will do extra lyaer 981 | name = '%s_global_max_pool_sigmoid' % prefix 982 | enc_tensors[name] = KL.Conv1D(1, 1, name=name, activation="sigmoid", use_bias=True)(last_tensor) 983 | 984 | elif final_layer == 'globalmaxpooling': 985 | 986 | name = '%s_conv_to_featmaps' % prefix 987 | enc_tensors[name] = KL.Conv3D(2, 1, name=name, activation="relu")(last_tensor) 988 | last_tensor = enc_tensors[name] 989 | 990 | name = '%s_global_max_pool' % prefix 991 | enc_tensors[name] = KL.GlobalMaxPooling3D(name=name)(last_tensor) 992 | last_tensor = enc_tensors[name] 993 | 994 | # cannot do activation in lambda layer. Could code inside, but will do extra lyaer 995 | name = '%s_global_max_pool_softmax' % prefix 996 | enc_tensors[name] = KL.Activation('softmax', name=name)(last_tensor) 997 | 998 | last_tensor = enc_tensors[name] 999 | 1000 | # create the model 1001 | model = Model(inputs=[enc_tensors['%s_input' % prefix]], outputs=[last_tensor], name=model_name) 1002 | return model 1003 | 1004 | 1005 | 1006 | ############################################################################### 1007 | # Helper function 1008 | ############################################################################### 1009 | -------------------------------------------------------------------------------- /flowreg_a/neuron/plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | plot utilities for the neuron project 3 | 4 | If you use this code, please cite the first paper this was built for: 5 | Dalca AV, Guttag J, Sabuncu MR 6 | Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 7 | CVPR 2018 8 | 9 | Contact: adalca [at] csail [dot] mit [dot] edu 10 | License: GPLv3 11 | """ 12 | 13 | import matplotlib.cm as cm 14 | import matplotlib.pyplot as plt 15 | # third party 16 | import numpy as np 17 | from matplotlib.colors import Normalize 18 | from mpl_toolkits.axes_grid1 import make_axes_locatable # plotting 19 | 20 | 21 | def slices(slices_in, # the 2D slices 22 | titles=None, # list of titles 23 | cmaps=None, # list of colormaps 24 | norms=None, # list of normalizations 25 | do_colorbars=False, # option to show colorbars on each slice 26 | grid=False, # option to plot the images in a grid or a single row 27 | width=15, # width in in 28 | show=True, # option to actually show the plot (plt.show()) 29 | axes_off=True, 30 | imshow_args=None): 31 | ''' 32 | plot a grid of slices (2d images) 33 | ''' 34 | 35 | # input processing 36 | if type(slices_in) == np.ndarray: 37 | slices_in = [slices_in] 38 | nb_plots = len(slices_in) 39 | for si, slice_in in enumerate(slices_in): 40 | if len(slice_in.shape) != 2: 41 | assert len(slice_in.shape) == 3 and slice_in.shape[-1] == 3, 'each slice has to be 2d or RGB (3 channels)' 42 | slices_in[si] = slice_in.astype('float') 43 | 44 | 45 | 
def input_check(inputs, nb_plots, name): 46 | ''' change input from None/single-link ''' 47 | assert (inputs is None) or (len(inputs) == nb_plots) or (len(inputs) == 1), \ 48 | 'number of %s is incorrect' % name 49 | if inputs is None: 50 | inputs = [None] 51 | if len(inputs) == 1: 52 | inputs = [inputs[0] for i in range(nb_plots)] 53 | return inputs 54 | 55 | titles = input_check(titles, nb_plots, 'titles') 56 | cmaps = input_check(cmaps, nb_plots, 'cmaps') 57 | norms = input_check(norms, nb_plots, 'norms') 58 | imshow_args = input_check(imshow_args, nb_plots, 'imshow_args') 59 | for idx, ia in enumerate(imshow_args): 60 | imshow_args[idx] = {} if ia is None else ia 61 | 62 | # figure out the number of rows and columns 63 | if grid: 64 | if isinstance(grid, bool): 65 | rows = np.floor(np.sqrt(nb_plots)).astype(int) 66 | cols = np.ceil(nb_plots/rows).astype(int) 67 | else: 68 | assert isinstance(grid, (list, tuple)), \ 69 | "grid should either be bool or [rows,cols]" 70 | rows, cols = grid 71 | else: 72 | rows = 1 73 | cols = nb_plots 74 | 75 | # prepare the subplot 76 | fig, axs = plt.subplots(rows, cols) 77 | if rows == 1 and cols == 1: 78 | axs = [axs] 79 | 80 | for i in range(nb_plots): 81 | col = np.remainder(i, cols) 82 | row = np.floor(i/cols).astype(int) 83 | 84 | # get row and column axes 85 | row_axs = axs if rows == 1 else axs[row] 86 | ax = row_axs[col] 87 | 88 | # turn off axis 89 | ax.axis('off') 90 | 91 | # add titles 92 | if titles is not None and titles[i] is not None: 93 | ax.title.set_text(titles[i]) 94 | 95 | # show figure 96 | im_ax = ax.imshow(slices_in[i], cmap=cmaps[i], interpolation="nearest", norm=norms[i], **imshow_args[i]) 97 | 98 | # colorbars 99 | # http://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph 100 | if do_colorbars and cmaps[i] is not None: 101 | divider = make_axes_locatable(ax) 102 | cax = divider.append_axes("right", size="5%", pad=0.05) 103 | fig.colorbar(im_ax, cax=cax) 104 | 105 | # clear axes that are unnecessary 106 | for i in range(nb_plots, col*row): 107 | col = np.remainder(i, cols) 108 | row = np.floor(i/cols).astype(int) 109 | 110 | # get row and column axes 111 | row_axs = axs if rows == 1 else axs[row] 112 | ax = row_axs[col] 113 | 114 | if axes_off: 115 | ax.axis('off') 116 | 117 | # show the plots 118 | fig.set_size_inches(width, rows/cols*width) 119 | 120 | 121 | if show: 122 | plt.tight_layout() 123 | plt.show() 124 | 125 | return (fig, axs) 126 | 127 | 128 | def flow_legend(): 129 | """ 130 | show quiver plot to indicate how arrows are colored in the flow() method. 131 | https://stackoverflow.com/questions/40026718/different-colours-for-arrows-in-quiver-plot 132 | """ 133 | ph = np.linspace(0, 2*np.pi, 13) 134 | x = np.cos(ph) 135 | y = np.sin(ph) 136 | u = np.cos(ph) 137 | v = np.sin(ph) 138 | colors = np.arctan2(u, v) 139 | 140 | norm = Normalize() 141 | norm.autoscale(colors) 142 | # we need to normalize our colors array to match it colormap domain 143 | # which is [0, 1] 144 | 145 | colormap = cm.winter 146 | 147 | plt.figure(figsize=(6, 6)) 148 | plt.xlim(-2, 2) 149 | plt.ylim(-2, 2) 150 | plt.quiver(x, y, u, v, color=colormap(norm(colors)), angles='xy', scale_units='xy', scale=1) 151 | plt.show() 152 | 153 | 154 | def flow(slices_in, # the 2D slices 155 | titles=None, # list of titles 156 | cmaps=None, # list of colormaps 157 | width=15, # width in in 158 | img_indexing=True, # whether to match the image view, i.e. 
flip y axis 159 | grid=False, # option to plot the images in a grid or a single row 160 | show=True, # option to actually show the plot (plt.show()) 161 | quiver_width=None, 162 | scale=1): # note quiver essentially draws quiver length = 1/scale 163 | ''' 164 | plot a grid of flows (2d+2 images) 165 | ''' 166 | 167 | # input processing 168 | nb_plots = len(slices_in) 169 | for slice_in in slices_in: 170 | assert len(slice_in.shape) == 3, 'each slice has to be 3d: 2d+2 channels' 171 | assert slice_in.shape[-1] == 2, 'each slice has to be 3d: 2d+2 channels' 172 | 173 | def input_check(inputs, nb_plots, name): 174 | ''' change input from None/single-link ''' 175 | if not isinstance(inputs, (list, tuple)): 176 | inputs = [inputs] 177 | assert (inputs is None) or (len(inputs) == nb_plots) or (len(inputs) == 1), \ 178 | 'number of %s is incorrect' % name 179 | if inputs is None: 180 | inputs = [None] 181 | if len(inputs) == 1: 182 | inputs = [inputs[0] for i in range(nb_plots)] 183 | return inputs 184 | 185 | if img_indexing: 186 | for si, slc in enumerate(slices_in): 187 | slices_in[si] = np.flipud(slc) 188 | 189 | titles = input_check(titles, nb_plots, 'titles') 190 | cmaps = input_check(cmaps, nb_plots, 'cmaps') 191 | scale = input_check(scale, nb_plots, 'scale') 192 | 193 | # figure out the number of rows and columns 194 | if grid: 195 | if isinstance(grid, bool): 196 | rows = np.floor(np.sqrt(nb_plots)).astype(int) 197 | cols = np.ceil(nb_plots/rows).astype(int) 198 | else: 199 | assert isinstance(grid, (list, tuple)), \ 200 | "grid should either be bool or [rows,cols]" 201 | rows, cols = grid 202 | else: 203 | rows = 1 204 | cols = nb_plots 205 | 206 | # prepare the subplot 207 | fig, axs = plt.subplots(rows, cols) 208 | if rows == 1 and cols == 1: 209 | axs = [axs] 210 | 211 | for i in range(nb_plots): 212 | col = np.remainder(i, cols) 213 | row = np.floor(i/cols).astype(int) 214 | 215 | # get row and column axes 216 | row_axs = axs if rows == 1 else axs[row] 217 | ax = row_axs[col] 218 | 219 | # turn off axis 220 | ax.axis('off') 221 | 222 | # add titles 223 | if titles is not None and titles[i] is not None: 224 | ax.title.set_text(titles[i]) 225 | 226 | u, v = slices_in[i][...,0], slices_in[i][...,1] 227 | colors = np.arctan2(u, v) 228 | colors[np.isnan(colors)] = 0 229 | norm = Normalize() 230 | norm.autoscale(colors) 231 | if cmaps[i] is None: 232 | colormap = cm.winter 233 | else: 234 | raise Exception("custom cmaps not currently implemented for plt.flow()") 235 | 236 | # show figure 237 | ax.quiver(u, v, 238 | color=colormap(norm(colors).flatten()), 239 | angles='xy', 240 | units='xy', 241 | width=quiver_width, 242 | scale=scale[i]) 243 | ax.axis('equal') 244 | 245 | # clear axes that are unnecessary 246 | for i in range(nb_plots, col*row): 247 | col = np.remainder(i, cols) 248 | row = np.floor(i/cols).astype(int) 249 | 250 | # get row and column axes 251 | row_axs = axs if rows == 1 else axs[row] 252 | ax = row_axs[col] 253 | 254 | ax.axis('off') 255 | 256 | # show the plots 257 | fig.set_size_inches(width, rows/cols*width) 258 | plt.tight_layout() 259 | 260 | if show: 261 | plt.show() 262 | 263 | return (fig, axs) 264 | 265 | 266 | def pca(pca, x, y): 267 | x_mean = np.mean(x, 0) 268 | x_std = np.std(x, 0) 269 | 270 | W = pca.components_ 271 | x_mu = W @ pca.mean_ # pca.mean_ is y_mean 272 | y_hat = x @ W + pca.mean_ 273 | 274 | y_err = y_hat - y 275 | y_rel_err = y_err / np.maximum(0.5*(np.abs(y)+np.abs(y_hat)), np.finfo('float').eps) 276 | 277 | plt.figure(figsize=(15, 7)) 278 | 
plt.subplot(2, 3, 1) 279 | plt.plot(pca.explained_variance_ratio_) 280 | plt.title('var %% explained') 281 | plt.subplot(2, 3, 2) 282 | plt.plot(np.cumsum(pca.explained_variance_ratio_)) 283 | plt.ylim([0, 1.01]) 284 | plt.grid() 285 | plt.title('cumvar explained') 286 | plt.subplot(2, 3, 3) 287 | plt.plot(np.cumsum(pca.explained_variance_ratio_)) 288 | plt.ylim([0.8, 1.01]) 289 | plt.grid() 290 | plt.title('cumvar explained') 291 | 292 | plt.subplot(2, 3, 4) 293 | plt.plot(x_mean) 294 | plt.plot(x_mean + x_std, 'k') 295 | plt.plot(x_mean - x_std, 'k') 296 | plt.title('x mean across dims (sorted)') 297 | plt.subplot(2, 3, 5) 298 | plt.hist(y_rel_err.flat, 100) 299 | plt.title('y rel err histogram') 300 | plt.subplot(2, 3, 6) 301 | plt.imshow(W @ np.transpose(W), cmap=plt.get_cmap('gray')) 302 | plt.colorbar() 303 | plt.title('W * W\'') 304 | plt.show() 305 | -------------------------------------------------------------------------------- /flowreg_a/neuron/regularizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | tensorflow/keras regularizers for the neuron project 3 | 4 | If you use this code, please cite 5 | Dalca AV, Guttag J, Sabuncu MR 6 | Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 7 | CVPR 2018 8 | 9 | or for the transformation/interpolation related functions: 10 | 11 | Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration 12 | Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu 13 | MICCAI 2018. 14 | 15 | Contact: adalca [at] csail [dot] mit [dot] edu 16 | License: GPLv3 17 | """ 18 | 19 | import tensorflow as tf 20 | import tensorflow.keras.backend as K 21 | 22 | from .utils import soft_delta 23 | 24 | 25 | def soft_l0_wrap(wt = 1.): 26 | 27 | def soft_l0(x): 28 | """ 29 | maximize the number of 0 weights 30 | """ 31 | nb_weights = tf.cast(tf.size(x), tf.float32) 32 | nb_zero_wts = tf.reduce_sum(soft_delta(K.flatten(x))) 33 | return wt * (nb_weights - nb_zero_wts) / nb_weights 34 | 35 | return soft_l0 36 | -------------------------------------------------------------------------------- /flowreg_a/neuron/vae_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | tools for (v)ae processing, debugging, and exploration 3 | """ 4 | from tempfile import NamedTemporaryFile 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | # third party imports 9 | import tensorflow as tf 10 | from IPython.display import Image 11 | from sklearn.decomposition import PCA 12 | from tensorflow import keras 13 | from tensorflow.keras import backend as K 14 | from tensorflow.keras import layers as KL 15 | from tensorflow.keras.utils import plot_model 16 | from tqdm import tqdm as tqdm 17 | 18 | # project imports 19 | import neuron as ne 20 | 21 | 22 | def extract_z_dec(model, sample_layer_name, vis=False, wt_chk=False): 23 | """ 24 | extract the z_decoder [z = p(x)] and return it as a keras model 25 | 26 | Example Layer name: 27 | sample_layer_name = 'img-img-dense-vae_ae_dense_sample' 28 | """ 29 | 30 | # need to make new model to avoid mu, sigma outputs 31 | tmp_model = keras.models.Model(model.inputs, model.outputs[0]) 32 | 33 | # get new input 34 | sample_layer = model.get_layer(sample_layer_name) 35 | enc_size = sample_layer.get_output_at(0).get_shape().as_list()[1:] 36 | new_z_input = KL.Input(enc_size, name='z_input') 37 | 38 | # prepare outputs 39 | # assumes z was first input. 
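# (editor's example; the layer name is the hypothetical one from the docstring above)
#   z_dec = extract_z_dec(vae, 'img-img-dense-vae_ae_dense_sample')
#   x_mu = z_dec.predict(np.zeros((1, *enc_size)))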
40 | new_inputs = [new_z_input, *model.inputs[1:]] 41 | input_layers = [sample_layer_name, *model.input_layers[1:]] 42 | z_dec_model_outs = ne.utils.mod_submodel(tmp_model, 43 | new_input_nodes=new_inputs, 44 | input_layers=input_layers) 45 | 46 | # get new model 47 | z_dec_model = keras.models.Model(new_inputs, z_dec_model_outs) 48 | 49 | if vis: 50 | outfile = NamedTemporaryFile().name + '.png' 51 | plot_model(z_dec_model, to_file=outfile, show_shapes=True) 52 | Image(outfile, width=100) 53 | 54 | # check model weights: 55 | if wt_chk: 56 | for layer in z_dec_model.layers: 57 | wts1 = layer.get_weights() 58 | if layer.name not in [l.name for l in model.layers]: 59 | continue 60 | wts2 = model.get_layer(layer.name).get_weights() 61 | if len(wts1) > 0: 62 | assert np.all([np.mean(wts1[i] - wts2[i]) < 1e-9 for i, 63 | _ in enumerate(wts1)]), "model copy failed" 64 | 65 | return z_dec_model 66 | 67 | 68 | def z_effect(model, gen, z_layer_name, nb_samples=100, do_plot=False, tqdm=tqdm): 69 | """ 70 | compute the effect of each z dimension on the final outcome via derivatives 71 | we attempt this by taking gradients as in 72 | https://stackoverflow.com/questions/39561560/getting-gradient-of-model-output-w-r-t-weights-using-keras 73 | 74 | e.g. layer name: 'img-img-dense-vae_ae_dense_sample' 75 | """ 76 | 77 | outputTensor = model.outputs[0] 78 | inner = model.get_layer(z_layer_name).get_output_at(1) 79 | 80 | # compute gradients 81 | gradients = K.gradients(outputTensor, inner) 82 | assert len(gradients) == 1, "wrong gradients" 83 | 84 | # would be nice to be able to do this with K.eval() as opposed to explicit tensorflow sessions. 85 | with tf.Session() as sess: 86 | sess.run(tf.initialize_all_variables()) 87 | 88 | evaluated_gradients = [None] * nb_samples 89 | for i in tqdm(range(nb_samples)): 90 | sample = next(gen) 91 | fdct = {model.get_input_at(0): sample[0]} 92 | evaluated_gradients[i] = sess.run(gradients, feed_dict=fdct)[0] 93 | 94 | all_gradients = np.mean(np.abs(np.vstack(evaluated_gradients)), 0) 95 | 96 | if do_plot: 97 | plt.figure() 98 | plt.plot(np.sort(all_gradients)) 99 | plt.xlabel('sorted z index') 100 | plt.ylabel('mean(|grad|)') 101 | plt.show() 102 | 103 | return all_gradients 104 | 105 | 106 | def sample_dec(z_dec_model, 107 | z_mu=None, 108 | z_logvar=None, 109 | nb_samples=5, 110 | tqdm=tqdm, 111 | z_id=None, 112 | do_sweep=False, 113 | nb_sweep_stds=3, 114 | extra_inputs=[], 115 | nargout=1): 116 | """ 117 | sample from the decoder (i.e. sample z, compute x_mu|z) 118 | 119 | use z_id if you want to vary only a specific z index 120 | 121 | use sweep parameters if you want to sweep around mu from one end to another. 
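e.g. with do_sweep=True and z_id=k, only latent dimension k is swept linearly from mu_k - nb_sweep_stds*std_k to mu_k + nb_sweep_stds*std_k; with z_id=None the whole vector is swept between mu -/+ nb_sweep_stds*std.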
122 | """ 123 | 124 | input_shape = z_dec_model.inputs[0].get_shape()[1:].as_list() 125 | if z_mu is None: 126 | z_mu = np.zeros([1, *input_shape]) 127 | else: 128 | z_mu = np.reshape(z_mu, [1, *input_shape]) 129 | 130 | if z_logvar is None: 131 | z_logvar = np.zeros([1, *input_shape]) 132 | else: 133 | z_logvar = np.reshape(z_logvar, [1, *input_shape]) 134 | 135 | # get standard deviation 136 | z_std = np.exp(z_logvar/2) 137 | 138 | # get samples 139 | if do_sweep: 140 | if z_id is not None: 141 | low = np.copy(z_mu) 142 | high = np.copy(z_mu) 143 | low[0, z_id] = z_mu[0, z_id] - nb_sweep_stds * z_std[0, z_id] 144 | high[0, z_id] = z_mu[0, z_id] + nb_sweep_stds * z_std[0, z_id] 145 | else: 146 | low = z_mu - nb_sweep_stds * z_std 147 | high = z_mu + nb_sweep_stds * z_std 148 | 149 | x_sweep = np.linspace(0, 1, nb_samples) 150 | z_samples = [x * high + (1-x) * low for x in x_sweep] 151 | 152 | else: 153 | std = np.copy(z_std) 154 | if z_id is not None: # only perturb the requested z index 155 | std = np.ones(z_std.shape) * np.finfo('float').eps 156 | std[0, z_id] = z_std[0, z_id] 157 | z_samples = [np.random.normal(loc=z_mu, scale=std) 158 | for _ in range(nb_samples)] 159 | 160 | # propagate 161 | outs = [None] * nb_samples 162 | for zi, z_sample in enumerate(tqdm(z_samples)): 163 | outs[zi] = z_dec_model.predict([z_sample, *extra_inputs]) 164 | 165 | if nargout == 1: 166 | return outs 167 | else: 168 | return (outs, z_samples) 169 | 170 | 171 | def sweep_dec_given_x(full_model, z_dec_model, sample1, sample2, sample_layer_name, 172 | sweep_z_samples=False, 173 | nb_samples=10, 174 | nargout=1, 175 | tqdm=tqdm): 176 | """ 177 | sweep the latent space given two samples in the original space 178 | specifically, get z_mu = enc(x) for both samples, and sweep between those z_mus 179 | 180 | "sweep_z_samples" does a sweep between two samples, rather than between two z_mus. 181 | 182 | Example: 183 | sample_layer_name='img-img-dense-vae_ae_dense_sample' 184 | """ 185 | 186 | # get a model that also outputs the samples z 187 | full_output = [*full_model.outputs, 188 | full_model.get_layer(sample_layer_name).get_output_at(1)] 189 | full_model_plus = keras.models.Model(full_model.inputs, full_output) 190 | 191 | # get full predictions for these samples 192 | pred1 = full_model_plus.predict(sample1[0]) 193 | pred2 = full_model_plus.predict(sample2[0]) 194 | img1 = sample1[0] 195 | img2 = sample2[0] 196 | 197 | # sweep range 198 | x_range = np.linspace(0, 1, nb_samples) 199 | 200 | # prepare outputs 201 | outs = [None] * nb_samples 202 | for xi, x in enumerate(tqdm(x_range)): 203 | if sweep_z_samples: 204 | z = x * pred1[3] + (1-x) * pred2[3] 205 | else: 206 | z = x * pred1[1] + (1-x) * pred2[1] 207 | 208 | if isinstance(sample1[0], (list, tuple)): # assuming prior or something like that 209 | outs[xi] = z_dec_model.predict([z, *sample1[0][1:]]) 210 | else: 211 | outs[xi] = z_dec_model.predict(z) 212 | 213 | if nargout == 1: 214 | return outs 215 | else: 216 | return (outs, [pred1, pred2]) 217 | 218 | 219 | def pca_init_dense(model, mu_dense_layer_name, undense_layer_name, generator, 220 | input_len=None, 221 | do_vae=True, 222 | logvar_dense_layer_name=None, 223 | nb_samples=None, 224 | tqdm=tqdm, 225 | vis=False): 226 | """ 227 | initialize the (V)AE middle *dens*e layer with PCA 228 | Warning: this modifies the weights in your model! 
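(both the mu dense layer and the matching un-dense decoder layer are overwritten in place with the PCA projection.)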
229 | 230 | model should take input the same as the normal (V)AE, and output a flat layer before the mu dense layer 231 | if nb_samples is None, we will compute at least as many as there are initial dimensions (which might be a lot) 232 | 233 | assumes mu_dense_layer_name is of input size [None, pre_mu_len] and output size [None, enc_len] 234 | 235 | example 236 | mu_dense_layer_name = 'img-img-dense-ae_ae_mu_enc_1000' 237 | undense_layer_name = 'img-img-dense-ae_ae_dense_dec_flat_1000' 238 | """ 239 | 240 | # extract important layer 241 | mu_dense_layer = model.get_layer(mu_dense_layer_name) 242 | mu_undense_layer = model.get_layer(undense_layer_name) 243 | 244 | # prepare model that outputs the pre_mu flat 245 | nb_inbound_nodes = len(mu_dense_layer._inbound_nodes) 246 | for i in range(nb_inbound_nodes): 247 | try: 248 | out_tensor = mu_dense_layer.get_input_at(i) 249 | pre_mu_model = keras.models.Model(model.inputs, out_tensor) 250 | 251 | # save the node index 252 | node_idx = i 253 | break 254 | 255 | except Exception: 256 | if i == nb_inbound_nodes - 1: 257 | raise Exception( 258 | 'Could not initialize pre_mu model. Something went wrong :(') 259 | 260 | # extract PCA sizes 261 | if input_len is None: 262 | input_len = mu_dense_layer.get_input_at( 263 | node_idx).get_shape().as_list()[1:] 264 | assert len(input_len) == 1, 'layer input is not flat (1d)' 265 | input_len = input_len[0] 266 | if input_len is None: 267 | input_len = mu_dense_layer.get_weights()[0].shape[0] 268 | assert input_len is not None, "could not figure out input len" 269 | 270 | enc_size = mu_dense_layer.get_output_at(node_idx).get_shape().as_list()[1:] 271 | assert len(enc_size) == 1, 'encoding is not flat (1d)' 272 | enc_len = enc_size[0] 273 | 274 | # number of samples 275 | if nb_samples is None: 276 | nb_samples = np.maximum(enc_len, input_len) 277 | 278 | # mu pca 279 | pca_mu, x, y = model_output_pca( 280 | pre_mu_model, generator, nb_samples, enc_len, vis=vis, tqdm=tqdm) 281 | W_mu = pca_mu.components_ # enc_size * input_len 282 | 283 | # fix pca 284 | # y = x @ W + y_mean = (x + x_mu) @ W 285 | # x = y @ np.transpose(W) - x_mu 286 | mu_dense_layer.set_weights([np.transpose(W_mu), - (W_mu @ pca_mu.mean_)]) 287 | mu_undense_layer.set_weights([W_mu, + pca_mu.mean_]) 288 | 289 | # set var components with mu pca as well. 290 | if do_vae: 291 | model.get_layer(logvar_dense_layer_name).set_weights( 292 | [np.transpose(W_mu), - (W_mu @ pca_mu.mean_)]) 293 | 294 | # return pca data at least for debugging 295 | return (pca_mu, x, y) 296 | 297 | 298 | def model_output_pca(pre_mu_model, generator, nb_samples, nb_components, 299 | vis=False, 300 | tqdm=tqdm): 301 | """ 302 | compute PCA of model outputs 303 | """ 304 | 305 | # go through 306 | sample = next(generator) 307 | nb_batch_samples = _sample_batch_size(sample) 308 | if nb_batch_samples == 1: 309 | zs = [None] * nb_samples 310 | zs[0] = pre_mu_model.predict(sample[0]) 311 | for i in tqdm(range(1, nb_samples)): 312 | sample = next(generator) 313 | zs[i] = pre_mu_model.predict(sample[0]) 314 | y = np.vstack(zs) 315 | 316 | else: 317 | assert nb_batch_samples == nb_samples, \ 318 | "generator should either give us 1 sample or %d samples at once. 
got: %d" % (nb_samples, nb_batch_samples) 319 | y = pre_mu_model.predict(sample[0]) 320 | 321 | # pca 322 | pca = PCA(n_components=nb_components) 323 | x = pca.fit_transform(y) 324 | 325 | # make sure we can recover 326 | if vis: 327 | ne.plt.pca(pca, x, y) 328 | 329 | """ 330 | Test pca model assignment: 331 | # make input, then dense, then dense, then output, and see if input is output for y samples (with W = pca.components_ and x_mu = W @ pca.mean_). 332 | inp = KL.Input(pca.mean_.shape) 333 | den = KL.Dense(x_mu.shape[0]) 334 | den_o = den(inp) 335 | unden = KL.Dense(pca.mean_.shape[0]) 336 | unden_o = unden(den_o) 337 | test_ae = keras.models.Model(inp, [den_o, unden_o]) 338 | 339 | den.set_weights([np.transpose(W), - x_mu]) 340 | unden.set_weights([W, + pca.mean_]) 341 | 342 | x_pred, y_pred = test_ae.predict(y) 343 | x_pred - x 344 | y_pred - y 345 | """ 346 | 347 | return (pca, x, y) 348 | 349 | 350 | def latent_stats(model, gen, nb_reps=100, tqdm=tqdm): 351 | """ 352 | Gather several latent_space statistics (mu, var) 353 | 354 | Parameters: 355 | gen: generator (will call next() on this a few times) 356 | model: model (will predict from generator samples) 357 | """ 358 | 359 | mu_data = [None] * nb_reps 360 | logvar_data = [None] * nb_reps 361 | for i in tqdm(range(nb_reps)): 362 | sample = next(gen) 363 | p = model.predict(sample[0]) 364 | mu_data[i] = p[1] 365 | logvar_data[i] = p[2] 366 | 367 | mu_data = np.vstack(mu_data) 368 | mu_data = np.reshape(mu_data, (mu_data.shape[0], -1)) 369 | 370 | logvar_data = np.vstack(logvar_data) 371 | logvar_data = np.reshape(logvar_data, (logvar_data.shape[0], -1)) 372 | 373 | data = {'mu': mu_data, 'logvar': logvar_data} 374 | return data 375 | 376 | 377 | def latent_stats_plots(model, gen, nb_reps=100, dim_1=0, dim_2=1, figsize=(15, 7), tqdm=tqdm): 378 | """ 379 | Make some debug/info (mostly latent-stats-related) plots 380 | 381 | Parameters: 382 | gen: generator (will call next() on this a few times) 383 | model: model (will predict from generator samples) 384 | """ 385 | 386 | data = latent_stats(model, gen, nb_reps=nb_reps, tqdm=tqdm) 387 | mu_data = data['mu'] 388 | logvar_data = data['logvar'] 389 | 390 | z = mu_data.shape[0] # number of samples 391 | colors = np.linspace(0, 1, z) # one color per sample 392 | x = np.arange(mu_data.shape[1]); dim_colors = np.linspace(0, 1, mu_data.shape[1]) # one color per latent dimension 393 | print('VAE plots: colors represent sample index') 394 | 395 | 396 | print('Sample plots (colors represent sample index)') 397 | datapoints = np.zeros(data['mu'].shape) 398 | for di, mu in tqdm(enumerate(data['mu']), leave=False): 399 | logvar = data['logvar'][di,...] 400 | eps = np.random.normal(loc=0, scale=1, size=(data['mu'].shape[-1])) 401 | datapoints[di, ...] = mu + np.exp(logvar / 2) * eps 402 | plt.figure(figsize=figsize) 403 | plt.subplot(1, 2, 1) 404 | plt.scatter(datapoints[:, dim_1], datapoints[:, dim_2], c=np.linspace(0, 1, datapoints.shape[0])) 405 | plt.title('sample dist. nb_reps=%d. colors = sample idx.' % nb_reps) 406 | plt.xlabel('dim %d' % dim_1) 407 | plt.ylabel('dim %d' % dim_2) 408 | 409 | plt.subplot(1, 2, 2) 410 | d_mean = np.mean(datapoints, 0) 411 | d_idx = np.argsort(d_mean) 412 | d_mean_sort = d_mean[d_idx] 413 | d_std_sort = np.std(datapoints, 0)[d_idx] 414 | plt.scatter(x, d_mean_sort, c=dim_colors[d_idx]) 415 | plt.plot(x, d_mean_sort + d_std_sort, 'k') 416 | plt.plot(x, d_mean_sort - d_std_sort, 'k') 417 | plt.title('mean sample z. nb_reps=%d. colors = sorted dim.' 
% nb_reps) 418 | plt.xlabel('sorted dims') 419 | plt.ylabel('mean sample z') 420 | 421 | 422 | 423 | 424 | 425 | # plot 426 | plt.figure(figsize=figsize) 427 | plt.subplot(1, 2, 1) 428 | plt.scatter(mu_data[:, dim_1], mu_data[:, dim_2], c=colors) 429 | plt.title('mu dist. nb_reps=%d. colors = sample idx.' % nb_reps) 430 | plt.xlabel('dim %d' % dim_1) 431 | plt.ylabel('dim %d' % dim_2) 432 | plt.subplot(1, 2, 2) 433 | plt.scatter(logvar_data[:, dim_1], logvar_data[:, dim_2], c=colors) 434 | plt.title('logvar_data dist. nb_reps=%d. colors = sample idx.' % nb_reps) 435 | plt.xlabel('dim %d' % dim_1) 436 | plt.ylabel('dim %d' % dim_2) 437 | plt.show() 438 | 439 | # plot means and variances 440 | z = mu_data.shape[1] 441 | colors = np.linspace(0, 1, z) # now one color per latent dimension 442 | 443 | plt.figure(figsize=figsize) 444 | plt.subplot(1, 2, 1) 445 | mu_mean = np.mean(mu_data, 0) 446 | mu_idx = np.argsort(mu_mean) 447 | mu_mean_sort = mu_mean[mu_idx] 448 | mu_std_sort = np.std(mu_data, 0)[mu_idx] 449 | plt.scatter(x, mu_mean_sort, c=colors[mu_idx]) 450 | plt.plot(x, mu_mean_sort + mu_std_sort, 'k') 451 | plt.plot(x, mu_mean_sort - mu_std_sort, 'k') 452 | plt.title('mean mu. nb_reps=%d. colors = sorted dim.' % nb_reps) 453 | plt.xlabel('sorted dims') 454 | plt.ylabel('mean mu') 455 | 456 | plt.subplot(1, 2, 2) 457 | logvar_mean = np.mean(logvar_data, 0) 458 | logvar_mean_sort = logvar_mean[mu_idx] 459 | logvar_std_sort = np.std(logvar_data, 0)[mu_idx] 460 | plt.scatter(x, logvar_mean_sort, c=colors[mu_idx]) 461 | plt.plot(x, logvar_mean_sort + logvar_std_sort, 'k') 462 | plt.plot(x, logvar_mean_sort - logvar_std_sort, 'k') 463 | plt.title('mean logvar. nb_reps=%d' % nb_reps) 464 | plt.xlabel('sorted dims (mu ordering)') 465 | plt.ylabel('mean logvar') 466 | plt.show() 467 | 468 | 469 | 470 | 471 | return data 472 | 473 | 474 | 475 | ############################################################################### 476 | # helper functions 477 | ############################################################################### 478 | 479 | def _sample_batch_size(sample): 480 | """ 481 | get the batch size of a sample, while not knowing how many lists are in the input object. 
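e.g. for sample = ([x, cond], y) with x of shape (8, 256, 256, 1), the recursion descends into the first element until it reaches an array and returns 8.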
482 | """ 483 | if isinstance(sample[0], (list, tuple)): 484 | return _sample_batch_size(sample[0]) 485 | else: 486 | return sample[0].shape[0] 487 | -------------------------------------------------------------------------------- /flowreg_a/register.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | import os 4 | import glob 5 | import argparse 6 | from tensorflow.keras import Model 7 | import tensorflow as tf 8 | from skimage import transform 9 | 10 | # the Spatial Transformer layer used for the affine transform and resampling was borrowed from 11 | # https://github.com/adalca/neurite (originally named neuron) 12 | from neuron.layers import SpatialTransformer 13 | 14 | from utils import normalize 15 | from model import affmodel 16 | 17 | config = tf.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | session = tf.Session(config=config) 20 | 21 | 22 | def warpMask(affmat, mask): 23 | affmat_tensor = tf.convert_to_tensor(affmat, dtype=tf.float32) 24 | mask_tensor = tf.convert_to_tensor(mask.reshape(1, 256, 256, 55, 1), dtype=tf.float32) 25 | regMask = SpatialTransformer(interp_method='linear', indexing='ij')([mask_tensor, affmat_tensor]) 26 | regMask = np.squeeze(regMask.eval(session=session)) 27 | regMask = np.where(regMask > 0.1, 1, 0) 28 | return regMask 29 | 30 | 31 | def register(fixedDir, movingDir, saveDir, brainDir, ventDir, wmlDir, modelDir): 32 | 33 | fixedVol = sio.loadmat(fixedDir)['atlasFinal'] 34 | fixedVol = np.reshape(normalize(fixedVol), [1, 256, 256, 55, 1]) 35 | 36 | if os.path.isfile(movingDir): 37 | movingVols = [line.rstrip('\n') for line in open(movingDir)] 38 | elif os.path.isdir(movingDir): 39 | movingVols = glob.glob(movingDir + '/*.mat') 40 | movingVols.sort() 41 | 42 | model = affmodel([256, 256, 55, 1]) 43 | modelh5path = modelDir 44 | model.load_weights(modelh5path) 45 | 46 | for i, movingVol in enumerate(movingVols): 47 | 48 | name = os.path.basename(movingVol) 49 | print('Registering volume ', i, name) 50 | 51 | movingVol = normalize(transform.resize(sio.loadmat(movingVol)['im']['vol'][0][0], (256, 256, 55))) 52 | movingVol = np.reshape(normalize(movingVol), [1, 256, 256, 55, 1]) 53 | 54 | layer_name = 'Dense6' 55 | intermediate_layer_model = Model(inputs=model.input, 56 | outputs=model.get_layer(layer_name).output) 57 | affinemat = intermediate_layer_model.predict([fixedVol, movingVol]) 58 | regvol = model.predict([fixedVol, movingVol]) 59 | regvol = np.squeeze(regvol) 60 | 61 | if brainDir and ventDir and wmlDir: 62 | brainVol = sio.loadmat(brainDir + '/' + name)['brainMask'].astype('float32') 63 | ventVol = sio.loadmat(ventDir + '/' + name)['ventMask'].astype('float32') 64 | wmlVol = sio.loadmat(wmlDir + '/' + name)['wmlMask'].astype('float32') 65 | 66 | regBrain = warpMask(affinemat, brainVol) 67 | regVent = warpMask(affinemat, ventVol) 68 | regWml = warpMask(affinemat, wmlVol) 69 | sio.savemat(os.path.join(saveDir, name), {'regvol': regvol, 70 | 'brainMask': regBrain, 'ventMask': regVent, 'wmlMask': regWml, 'affine': affinemat}) 71 | 72 | else: 73 | sio.savemat(os.path.join(saveDir, name), {'regvol': regvol, 'affine': affinemat}) 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser(description="""FlowReg-Affine (FlowReg-A) register""") 78 | 79 | parser.add_argument('-r', '--register', help=' register volumes directory', type=str, dest='moving') 80 | parser.add_argument('-f', '--fixed', help=' fixed volume directory', type=str, 
dest='fixed') 81 | parser.add_argument('-s', '--save', help=' results save directory', type=str, dest='save_dir') 82 | parser.add_argument('-b', '--brain', help=' brain masks directory', type=str, dest='brain_dir') 83 | parser.add_argument('-v', '--vent', help=' ventricle masks directory', type=str, dest='vent_dir') 84 | parser.add_argument('-w', '--wml', help=' wml masks directory', type=str, dest='wml_dir') 85 | parser.add_argument('-m', '--model', help=' trained model weights directory', dest='model') 86 | 87 | args = parser.parse_args() 88 | 89 | register(args.fixed, args.moving, args.save_dir, args.brain_dir, args.vent_dir, args.wml_dir, args.model) -------------------------------------------------------------------------------- /flowreg_a/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import tensorflow as tf 4 | from keras.callbacks import ModelCheckpoint, CSVLogger 5 | import os 6 | import datetime 7 | import argparse 8 | import glob 9 | 10 | from model import affmodel 11 | from data_generator import DataGenerator 12 | 13 | config = tf.ConfigProto() 14 | config.gpu_options.allow_growth = True 15 | session = tf.Session(config=config) 16 | 17 | def train(batch_size, train, validation, fixed, checkpoint, epochs, save_loss, model_save_dir, weights): 18 | 19 | # to check if the train volumes are a directory or a text file 20 | if os.path.isfile(train): 21 | train_vol_names = [line.rstrip('\n') for line in open(train)] 22 | elif os.path.isdir(train): 23 | train_vol_names = glob.glob(train + '/*.mat') 24 | train_vol_names.sort() 25 | 26 | num_vols = len(train_vol_names) 27 | idx = np.arange(num_vols) 28 | np.random.shuffle(idx) 29 | train_ids = idx 30 | 31 | if os.path.isfile(validation): 32 | validation_vol_names = [line.rstrip('\n') for line in open(validation)] 33 | elif os.path.isdir(validation): 34 | validation_vol_names = glob.glob(validation + '/*.mat') 35 | validation_vol_names.sort() 36 | v_num_vols = len(validation_vol_names) 37 | idx = np.arange(v_num_vols) 38 | np.random.shuffle(idx) 39 | validation_ids = idx 40 | 41 | params = {'batch_size': batch_size, 42 | 'dim': (256, 256, 55), 43 | 'shuffle': True, 44 | 'n_channels': 1 45 | } 46 | train_gen = DataGenerator(vols=train_ids, fvol_dir=fixed, mvol_dir=train, **params) 47 | valid_gen = DataGenerator(vols=validation_ids, fvol_dir=fixed, mvol_dir=validation, **params) 48 | 49 | timestr = time.strftime('%Y%m%d-%H%M%S') 50 | os.mkdir('../checkpoints/' + timestr) 51 | checkpoint = ModelCheckpoint(filepath='../checkpoints/' + timestr + '/weights-{epoch:02d}.h5', 52 | verbose=1, period=checkpoint, 53 | save_weights_only=True) 54 | 55 | if save_loss: 56 | csv_logger = CSVLogger('../losses/' + str(datetime.datetime.now().strftime('%Y-%m-%d')) + '.csv', separator=',') 57 | callbacks = [checkpoint, csv_logger] 58 | else: 59 | callbacks = [checkpoint] 60 | 61 | model = affmodel([256, 256, 55, 1]) 62 | if os.path.isfile(weights): 63 | print("----Loading checkpoint weights----") 64 | model.load_weights(weights) 65 | 66 | model.fit_generator(train_gen, steps_per_epoch=len(train_ids)//batch_size, 67 | validation_data=valid_gen, validation_steps=len(validation_ids)//batch_size, 68 | verbose=1, epochs=epochs, 69 | callbacks=callbacks) 70 | model.save(model_save_dir + timestr +'.h5') 71 | print("------------Model Saved---------------") 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser(description="""FlowReg-Affine (FlowReg-A) 
training""") 75 | 76 | parser.add_argument('-t', '--train', help=' training volumes directory', type=str, dest='train') 77 | parser.add_argument('-v', '--validation', help=' validation volumes directory', type=str, 78 | dest='validation') 79 | parser.add_argument('-f', '--fixed', help=' fixed volume directory', type=str, dest='fixed') 80 | parser.add_argument('-b', '--batch', help=' batch size, default=4', type=int, dest='batch', default=4) 81 | parser.add_argument('-c', '--checkpoint', help=' weights save checkpoint, default=0', type=int, dest='checkpoint', default=0) 82 | parser.add_argument('-e', '--epochs', help=' number of training epochs, default=100', type=int, dest='epochs', default=100) 83 | parser.add_argument('-l', '--save_loss', help=' save loss across all epochs, default=TRUE', type=bool, dest='save_loss', default=True) 84 | parser.add_argument('-m', '--model_save', help=' model save directory', type=str, dest='model_save') 85 | parser.add_argument('-w', '--load_weights', help=' load previously saved weights to resume training', type=str, dest='load_weights') 86 | 87 | args = parser.parse_args() 88 | 89 | train(args.batch, args.train, args.validation, args.fixed, args.checkpoint, args.epochs, args.save_loss, args.model_save, args.load_weights) -------------------------------------------------------------------------------- /flowreg_a/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | def normalize(input): 5 | input = np.float32(input) 6 | xmin = np.amin(input) 7 | xmax = np.amax(input) 8 | b = 1. # max value 9 | a = 0. # min value 10 | if (xmax - xmin) == 0: 11 | out = input 12 | else: 13 | out = a+(b-a)*(input-xmin)/(xmax-xmin) 14 | return out -------------------------------------------------------------------------------- /flowreg_o/README.md: -------------------------------------------------------------------------------- 1 | # FlowReg-Optical Flow (FlowReg-O) 2 | A fully unsupervised framework for deformable registration using deep learning. 3 | 4 | ## Training 5 | To train using your own data: the framework currently supports `.mat` files; however, you may adapt it to your specific needs. 6 | 7 | The application is command-line based and can be run in two easy steps by using your terminal of choice. 8 | 1. Navigate to the `flowreg_o` directory. 9 | 2. Run `python train.py` with the appropriate arguments explained below. 10 | 11 | The command-line arguments are as follows. 
The text in *italics* is the expected data-type input: 12 | - `-t` or `--train` the directory of volumes used in training (*string*) 13 | - `-v` or `--validation` the directory of volumes used for validation (*string*) 14 | - `-f` or `--fixed` the fixed volume path (*string*) 15 | - `-b` or `--batch` the batch size used in training, default = 64 (*integer*) 16 | - `-e` or `--epochs` number of training epochs, default = 100 (*integer*) 17 | - `-c` or `--checkpoint` at which interval to save, default = 0 (*integer*) 18 | - `-l` or `--save_loss` save loss value to a csv during training, default = True (*boolean*) 19 | - `-m` or `--model_save` directory to save the final model (*string*) 20 | - `-w` or `--load_weights` directory to load previous weights, useful if training crashes after a few epochs (*string*) 21 | - `-a` or `--alpha` alpha value used for the loss during training, default = 0.20 (please see article for choosing an optimal value) (*string*) 22 | 23 | Note: the `checkpoint` and `save_loss` outputs are saved in the appropriate folders within the `flowreg_o` folder; these locations can be easily modified in `train.py`. 24 | 25 | An example command could look something like: 26 | ``` 27 | python train.py \ 28 | --train "path/to/train/directory" \ 29 | --validation "path/to/validation/directory" \ 30 | --fixed "path/to/fixed/volume.mat" \ 31 | --batch 32 \ 32 | --checkpoint 1 \ 33 | --epochs 100 \ 34 | --save_loss True \ 35 | --model_save "path/to/model/save/directory" 36 | ``` 37 | 38 | ## Registration 39 | If you have a trained model, the script to register volumes can be found in `register.py`. 40 | 41 | Similar to training, registration is done via a command-line interface with the following arguments: 42 | - `-r` or `--register` directory of the volumes to be registered (*string*) 43 | - `-f` or `--fixed` directory of the fixed volume (*string*) 44 | - `-s` or `--save` directory where to save the registered volumes (*string*) 45 | - `-m` or `--model` directory of the model weights, a .h5 file (*string*) 46 | 47 | (OPTIONAL) Binary masks can be passed as additional arguments that will be warped with the calculated 2D deformation field. These masks do not have to be the 'brain', 'ventricles', or 'wml' (white matter lesions) masks as specified in the argument name. Any binary mask can be used as long as it corresponds to the orientation and dimensions of the moving volume. 48 | - `-b` or `--brain` brain masks directory (*string*) 49 | - `-v` or `--vent` ventricle masks directory (*string*) 50 | - `-w` or `--wml` WML masks directory (*string*) 51 | 52 | The output `.mat` file will be the registered volume and the corresponding 2D deformation field for each slice in the moving volume. If masks are used, they will also be saved with `brainMask`, `ventMask`, or `wmlMask`. 53 | 54 | ## Citation 55 | If you use any portion of our work, please cite our paper. 56 | ``` 57 | S. Mocanu, A. Moody, and A. Khademi, “FlowReg: Fast Deformable Unsupervised Medical Image Registration using Optical Flow,” 58 | Machine Learning for Biomedical Imaging, pp. 1–40, Sep. 2021. 
59 | ``` 60 | Available at: https://www.melba-journal.org/article/27657-flowreg-fast-deformable-unsupervised-medical-image-registration-using-optical-flow -------------------------------------------------------------------------------- /flowreg_o/__pycache__/data_generator.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_o/__pycache__/data_generator.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_o/__pycache__/loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_o/__pycache__/loss.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_o/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_o/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_o/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAMLAB-Ryerson/FlowReg/b3f613f8fc36175b0fd832041c75c7d37508a976/flowreg_o/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /flowreg_o/data_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import scipy.io as sio 5 | 6 | from utils import rescale_imgs, rescale_img 7 | 8 | def generatedata(moving_vols, fixed_vol, batch_size): 9 | 10 | batchcount = 0 11 | 12 | if os.path.isdir(moving_vols): 13 | vols = glob.glob(moving_vols + '/*.mat') 14 | elif os.path.isfile(moving_vols): 15 | vols = [line.rstrip('\n') for line in open(moving_vols)] 16 | else: 17 | print( 18 | "Invalid training data. 
Should be .txt file containing (training/validation) set location or directory of (training/validation) volumes") 19 | 20 | # loading and resizing the fixed volume, one slice at a time 21 | fixed = sio.loadmat(fixed_vol)['atlasFinal'] 22 | fixed_vol_shape = fixed.shape 23 | fixed_reshaped_slices = [] 24 | for i in range(fixed_vol_shape[2]): 25 | fixed_reshaped_slices.append(rescale_imgs(fixed[:, :, i], (256, 256))) 26 | 27 | f0 = np.zeros((batch_size+1, 256, 256, 1)) 28 | f1 = np.zeros((batch_size+1, 128, 128, 1)) 29 | f2 = np.zeros((batch_size+1, 64, 64, 1)) 30 | f3 = np.zeros((batch_size+1, 32, 32, 1)) 31 | f4 = np.zeros((batch_size+1, 16, 16, 1)) 32 | f5 = np.zeros((batch_size+1, 8, 8, 1)) 33 | f6 = np.zeros((batch_size+1, 4, 4, 1)) 34 | 35 | z0 = np.zeros((batch_size+1, 256, 256, 1)) 36 | z1 = np.zeros((batch_size+1, 128, 128, 1)) 37 | z2 = np.zeros((batch_size+1, 64, 64, 1)) 38 | z3 = np.zeros((batch_size+1, 32, 32, 1)) 39 | z4 = np.zeros((batch_size+1, 16, 16, 1)) 40 | z5 = np.zeros((batch_size+1, 8, 8, 1)) 41 | z6 = np.zeros((batch_size+1, 4, 4, 1)) 42 | 43 | m = np.zeros((batch_size+1, 256, 256, 1)) 44 | 45 | while True: 46 | for s in range(55): 47 | for vol in vols: 48 | moving = sio.loadmat(vol)['regvol'] 49 | moving_img = rescale_img(moving[:, :, s], (256, 256)) 50 | 51 | 52 | f0[batchcount, :, :, :] = fixed_reshaped_slices[s][0] 53 | f1[batchcount, :, :, :] = fixed_reshaped_slices[s][1] 54 | f2[batchcount, :, :, :] = fixed_reshaped_slices[s][2] 55 | f3[batchcount, :, :, :] = fixed_reshaped_slices[s][3] 56 | f4[batchcount, :, :, :] = fixed_reshaped_slices[s][4] 57 | f5[batchcount, :, :, :] = fixed_reshaped_slices[s][5] 58 | f6[batchcount, :, :, :] = fixed_reshaped_slices[s][6] 59 | 60 | z0[batchcount, :, :, :] = fixed_reshaped_slices[s][7] 61 | z1[batchcount, :, :, :] = fixed_reshaped_slices[s][8] 62 | z2[batchcount, :, :, :] = fixed_reshaped_slices[s][9] 63 | z3[batchcount, :, :, :] = fixed_reshaped_slices[s][10] 64 | z4[batchcount, :, :, :] = fixed_reshaped_slices[s][11] 65 | z5[batchcount, :, :, :] = fixed_reshaped_slices[s][12] 66 | z6[batchcount, :, :, :] = fixed_reshaped_slices[s][13] 67 | 68 | m[batchcount, :, :, :] = moving_img 69 | 70 | batchcount += 1 71 | if batchcount > batch_size: 72 | # print('f0 shape', f0.shape, 'batch_size', batch_size) 73 | f0 = f0[0:batch_size, :, :, :] 74 | f1 = f1[0:batch_size, :, :, :] 75 | f2 = f2[0:batch_size, :, :, :] 76 | f3 = f3[0:batch_size, :, :, :] 77 | f4 = f4[0:batch_size, :, :, :] 78 | f5 = f5[0:batch_size, :, :, :] 79 | f6 = f6[0:batch_size, :, :, :] 80 | 81 | z0 = z0[0:batch_size, :, :, :] 82 | z1 = z1[0:batch_size, :, :, :] 83 | z2 = z2[0:batch_size, :, :, :] 84 | z3 = z3[0:batch_size, :, :, :] 85 | z4 = z4[0:batch_size, :, :, :] 86 | z5 = z5[0:batch_size, :, :, :] 87 | z6 = z6[0:batch_size, :, :, :] 88 | 89 | m = m[0:batch_size, :, :, :] 90 | 91 | X = [f0, m] 92 | y = [f0, f1, f2, f3, f4, f5, f6, 93 | z0, z1, z2, z3, z4, z5, z6] 94 | # print('input shape', X[0].shape, X[1].shape, 'output shape', y[0].shape, y[1].shape) 95 | yield (X, y) 96 | batchcount = 0 97 | f0 = np.zeros((batch_size + 1, 256, 256, 1)) 98 | f1 = np.zeros((batch_size + 1, 128, 128, 1)) 99 | f2 = np.zeros((batch_size + 1, 64, 64, 1)) 100 | f3 = np.zeros((batch_size + 1, 32, 32, 1)) 101 | f4 = np.zeros((batch_size + 1, 16, 16, 1)) 102 | f5 = np.zeros((batch_size + 1, 8, 8, 1)) 103 | f6 = np.zeros((batch_size + 1, 4, 4, 1)) 104 | 105 | z0 = np.zeros((batch_size + 1, 256, 256, 1)) 106 | z1 = np.zeros((batch_size + 1, 128, 128, 1)) 107 | z2 = 
np.zeros((batch_size + 1, 64, 64, 1)) 108 | z3 = np.zeros((batch_size + 1, 32, 32, 1)) 109 | z4 = np.zeros((batch_size + 1, 16, 16, 1)) 110 | z5 = np.zeros((batch_size + 1, 8, 8, 1)) 111 | z6 = np.zeros((batch_size + 1, 4, 4, 1)) 112 | 113 | m = np.zeros((batch_size + 1, 256, 256, 1)) -------------------------------------------------------------------------------- /flowreg_o/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | alphaval = 0.2 4 | 5 | def photometric_loss(true, pred): 6 | size = tf.cast(tf.size(pred), tf.float32) 7 | alpha = alphaval 8 | beta = 1.0 9 | diff = tf.abs(true-pred) 10 | dist = tf.reduce_sum(diff, axis=3, keep_dims=True) 11 | loss = charbonnier(dist, alpha, beta, 0.001) 12 | photo_loss = tf.reduce_sum(loss)/size 13 | 14 | corr_loss = correlation_loss(true, pred) 15 | lam_bda = 0.5 16 | return (lam_bda*photo_loss) + corr_loss 17 | # return lam_bda*photo_loss 18 | 19 | 20 | def smoothness_loss(zflow, flow): 21 | size = tf.cast(tf.size(flow), tf.float32) 22 | x, y = tf.unstack(flow, axis=3) 23 | x = tf.expand_dims(x, axis=3) 24 | y = tf.expand_dims(y, axis=3) 25 | 26 | u = tf.constant([[0., 0., 0.], 27 | [0., 1., -1.], 28 | [0., 0., 0.]]) 29 | u = tf.expand_dims(u, axis=2) 30 | u = tf.expand_dims(u, axis=3) 31 | v = tf.constant([[0., 0., 0.], 32 | [0., 1., 0.], 33 | [0., -1., 0.]]) 34 | v = tf.expand_dims(v, axis=2) 35 | v = tf.expand_dims(v, axis=3) 36 | 37 | u_diff = tf.nn.conv2d(x, u, strides=[1, 1, 1, 1], padding='SAME') 38 | v_diff = tf.nn.conv2d(y, v, strides=[1, 1, 1, 1], padding='SAME') 39 | all_diff = tf.concat([u_diff, v_diff], axis=3) 40 | dists = tf.reduce_sum(tf.abs(all_diff), axis=3, keep_dims=True) 41 | alpha = alphaval 42 | beta = 1.0 43 | lam_bda = 1. 
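# lam_bda weights the smoothness term; charbonnier below computes ((x*beta)**2 + epsilon**2)**alpha, a differentiable surrogate for a robust L1-style penalty on the flow gradients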
44 | loss = charbonnier(dists, alpha, beta, 0.001) 45 | 46 | return lam_bda*(tf.reduce_sum(loss)/size) 47 | 48 | def charbonnier(x, alpha, beta, epsilon): 49 | x = x*beta 50 | out = tf.pow((tf.square(x)+tf.square(epsilon)), alpha) 51 | return out 52 | 53 | def correlation_loss(true, pred): 54 | true = tf.cast(true, tf.float32) 55 | pred = tf.cast(pred, tf.float32) 56 | 57 | mux = tf.reduce_mean(true) 58 | muy = tf.reduce_mean(pred) 59 | n = tf.cast(tf.size(true), tf.float32) 60 | 61 | varx = tf.reduce_sum(tf.square(true - mux))/n 62 | vary = tf.reduce_sum(tf.square(pred - muy))/n 63 | 64 | corr = 1/n * tf.reduce_sum((true - mux) * (pred - muy)) / tf.math.sqrt(varx * vary) 65 | 66 | return 1-corr 67 | -------------------------------------------------------------------------------- /flowreg_o/model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input, Conv2D, Conv2DTranspose, Lambda, LeakyReLU, concatenate 2 | from keras import optimizers 3 | from keras.models import Model 4 | 5 | from loss import photometric_loss, smoothness_loss 6 | from utils import rescale_tensors, warp_tensors 7 | 8 | def flowmodelS(shape, batch_size): 9 | fixedinput = Input(shape=shape, batch_shape=(batch_size, 256, 256, 1), name='fixedinput') 10 | movinginput = Input(shape=shape, batch_shape=(batch_size, 256, 256, 1), name='movinginput') 11 | 12 | 13 | # encoder 14 | inputs = concatenate([fixedinput, movinginput], axis=3, name='inputs') 15 | conv1 = Conv2D(filters=64, kernel_size=7, strides=2, padding='same', name='conv1')(inputs) 16 | conv1 = LeakyReLU(alpha=0.1, name='LeakyReLu1')(conv1) 17 | conv2 = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', name='conv2')(conv1) 18 | conv2 = LeakyReLU(alpha=0.1, name='LeakyReLu2')(conv2) 19 | conv3 = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', name='conv3')(conv2) 20 | conv3 = LeakyReLU(alpha=0.1, name='LeakyReLu3')(conv3) 21 | conv3_1 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same', name='conv3_1')(conv3) 22 | conv3_1 = LeakyReLU(alpha=0.1, name='LeakyReLu4')(conv3_1) 23 | conv4 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', name='conv4')(conv3_1) 24 | conv4 = LeakyReLU(alpha=0.1, name='LeakyReLu5')(conv4) 25 | conv4_1 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same', name='conv4_1')(conv4) 26 | conv4_1 = LeakyReLU(alpha=0.1, name='LeakyReLu6')(conv4_1) 27 | conv5 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', name='conv5')(conv4_1) 28 | conv5 = LeakyReLU(alpha=0.1, name='LeakyReLu7')(conv5) 29 | conv5_1 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same', name='conv5_1')(conv5) 30 | conv5_1 = LeakyReLU(alpha=0.1, name='LeakyReLu8')(conv5_1) 31 | conv6 = Conv2D(filters=1024, kernel_size=3, strides=2, padding='same', name='conv6')(conv5_1) 32 | conv6 = LeakyReLU(alpha=0.1, name='LeakyReLu9')(conv6) 33 | conv6_1 = Conv2D(filters=1024, kernel_size=3, strides=1, padding='same', name='conv6_1')(conv6) 34 | conv6_1 = LeakyReLU(alpha=0.1, name='LeakyReLu10')(conv6_1) 35 | 36 | # decoder 37 | flow6 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow6')(conv6_1) 38 | flow6_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow6_up')(flow6) 39 | 40 | upconv5 = Conv2DTranspose(filters=512, kernel_size=4, strides=2, padding='same', name='upconv5')(conv6_1) 41 | upconv5 = LeakyReLU(alpha=0.1, name='LeakyReLu11')(upconv5) 42 | concat5 = concatenate([upconv5, conv5_1, 
flow6_up], axis=3, name='concat5') 43 | flow5 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow5')(concat5) 44 | flow5_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow5_up')(flow5) 45 | 46 | upconv4 = Conv2DTranspose(filters=256, kernel_size=4, strides=2, padding='same', name='upconv4')(concat5) 47 | upconv4 = LeakyReLU(alpha=0.1, name='LeakyReLu12')(upconv4) 48 | concat4 = concatenate([upconv4, conv4_1, flow5_up], axis=3, name='concat4') 49 | flow4 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow4')(concat4) 50 | flow4_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow4_up')(flow4) 51 | 52 | upconv3 = Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding='same', name='upconv3')(concat4) 53 | upconv3 = LeakyReLU(alpha=0.1, name='LeakyReLu13')(upconv3) 54 | concat3 = concatenate([upconv3, conv3_1, flow4_up], axis=3, name='concat3') 55 | flow3 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow3')(concat3) 56 | flow3_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow3_up')(flow3) 57 | 58 | upconv2 = Conv2DTranspose(filters=64, kernel_size=4, strides=2, padding='same', name='upconv2')(concat3) 59 | upconv2 = LeakyReLU(alpha=0.1, name='LeakyReLu14')(upconv2) 60 | concat2 = concatenate([upconv2, conv2, flow3_up], axis=3, name='concat2') 61 | flow2 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow2')(concat2) 62 | flow2_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow2_up')(flow2) 63 | 64 | upconv1 = Conv2DTranspose(filters=32, kernel_size=4, strides=2, padding='same', name='upconv1')(concat2) 65 | upconv1 = LeakyReLU(alpha=0.1, name='LeakyReLu15')(upconv1) 66 | concat1 = concatenate([upconv1, conv1, flow2_up], axis=3, name='concat1') 67 | flow1 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow1')(concat1) 68 | flow1_up = Conv2DTranspose(filters=2, kernel_size=4, strides=2, padding='same', name='flow1_up')(flow1) 69 | 70 | upconv0 = Conv2DTranspose(filters=16, kernel_size=4, strides=2, padding='same', name='upconv0')(concat1) 71 | upconv0 = LeakyReLU(alpha=0.1, name='LeakyReLu16')(upconv0) 72 | concat0 = concatenate([upconv0, inputs, flow1_up], axis=3, name='concat0') 73 | flow0 = Conv2D(filters=2, kernel_size=3, strides=1, padding='same', name='flow0')(concat0) 74 | 75 | 76 | rescaled_moving_input = Lambda(rescale_tensors, name='rescaling')(movinginput) 77 | out0 = Lambda(lambda x: warp_tensors(*x), name='out0')([rescaled_moving_input[0], flow0]) 78 | out1 = Lambda(lambda x: warp_tensors(*x), name='out1')([rescaled_moving_input[1], flow1]) 79 | out2 = Lambda(lambda x: warp_tensors(*x), name='out2')([rescaled_moving_input[2], flow2]) 80 | out3 = Lambda(lambda x: warp_tensors(*x), name='out3')([rescaled_moving_input[3], flow3]) 81 | out4 = Lambda(lambda x: warp_tensors(*x), name='out4')([rescaled_moving_input[4], flow4]) 82 | out5 = Lambda(lambda x: warp_tensors(*x), name='out5')([rescaled_moving_input[5], flow5]) 83 | out6 = Lambda(lambda x: warp_tensors(*x), name='out6')([rescaled_moving_input[6], flow6]) 84 | 85 | outputs = [out0, out1, out2, out3, out4, out5, out6, 86 | flow0, flow1, flow2, flow3, flow4, flow5, flow6] 87 | loss = { 88 | 'out0': photometric_loss, 89 | 'out1': photometric_loss, 90 | 'out2': photometric_loss, 91 | 'out3': photometric_loss, 92 | 'out4': photometric_loss, 93 | 'out5': photometric_loss, 94 | 'out6': 
photometric_loss, 95 | 'flow0': smoothness_loss, 96 | 'flow1': smoothness_loss, 97 | 'flow2': smoothness_loss, 98 | 'flow3': smoothness_loss, 99 | 'flow4': smoothness_loss, 100 | 'flow5': smoothness_loss, 101 | 'flow6': smoothness_loss, 102 | } 103 | 104 | adam = optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, decay=0.0005) 105 | model = Model(inputs=[fixedinput, movinginput], outputs=outputs) 106 | model.compile(optimizer=adam, loss=loss) 107 | return model -------------------------------------------------------------------------------- /flowreg_o/register.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | import tensorflow as tf 6 | from skimage import transform 7 | import argparse 8 | 9 | from utils import normalize, rescale_img, rescale_imgs 10 | from model import flowmodelS 11 | 12 | img_size = (256, 256) 13 | 14 | config = tf.ConfigProto() 15 | config.gpu_options.allow_growth = True 16 | session = tf.Session(config=config) 17 | 18 | def warpMask(flow, mask): 19 | mask_tensor = mask.transpose(2, 0, 1) 20 | mask_tensor = tf.convert_to_tensor(mask_tensor.reshape(55, 256, 256, 1)) 21 | regMask = tf.contrib.image.dense_image_warp(mask_tensor, flow, name='warpingmask') 22 | regMask = np.squeeze(regMask.eval(session=session)).transpose(1, 2, 0) 23 | regMask = np.where(regMask > 0.1, 1, 0) 24 | return regMask 25 | 26 | def register(modelweights, fixed_vol, moving, brain_dir, vent_dir, wml_dir, save_dir): # mask directories are accepted for CLI symmetry; masks are read from the volume .mat itself 27 | fixed_vol = normalize(sio.loadmat(fixed_vol)['atlasFinal']) 28 | x, y, z = fixed_vol.shape 29 | 30 | vols = glob.glob(moving + "/*.mat") 31 | 32 | model = flowmodelS(shape=[256, 256, 1], batch_size=1) 33 | model.load_weights(modelweights) 34 | 35 | for j, vol in enumerate(vols): 36 | 37 | config = tf.ConfigProto() 38 | config.gpu_options.allow_growth = True 39 | session = tf.Session(config=config) 40 | 41 | name = os.path.splitext(os.path.basename(vol))[0] 42 | 43 | if os.path.isfile(os.path.join(save_dir, name + '.mat')): # savemat appends .mat to the saved name 44 | print('Skip-------') 45 | print(name) 46 | continue 47 | 48 | volume = sio.loadmat(vol) 49 | moving_vol = volume['resized'] 50 | moving_vol = normalize(transform.resize(moving_vol, (x, y, z))) 51 | 52 | regvol = np.empty((x, y, z)) 53 | flowvol = np.empty((256, 256, z, 2)) 54 | # flowvol128 = np.empty((128, 128, z, 2)) 55 | # flowvol64 = np.empty((64, 64, z, 2)) 56 | # flowvol32 = np.empty((32, 32, z, 2)) 57 | # flowvol16 = np.empty((16, 16, z, 2)) 58 | # flowvol8 = np.empty((8, 8, z, 2)) 59 | # flowvol4 = np.empty((4, 4, z, 2)) 60 | 61 | 62 | if 'brainMask' in volume and 'ventMask' in volume and 'wmlMask' in volume: 63 | brainMask = volume['brainMask'].astype(np.float32) 64 | ventMask = volume['ventMask'].astype(np.float32) 65 | wmlMask = volume['wmlMask'].astype(np.float32) 66 | regBrain = np.empty((x, y, z)) 67 | regVent = np.empty((x, y, z)) 68 | regWML = np.empty((x, y, z)) 69 | 70 | 71 | for i in range(z): 72 | print('Registering Volume', j, name, ' Slice: ', i) 73 | 74 | fixed_img = rescale_imgs(fixed_vol[:, :, i], img_size=img_size) 75 | moving_img = rescale_img(moving_vol[:, :, i], img_size=img_size).reshape(1, 256, 256, 1) 76 | 77 | out = model.predict(x=[fixed_img[0], moving_img]) 78 | reg_img = np.squeeze(out[0]) 79 | flow = np.squeeze(out[7]) 80 | # flow128 = np.squeeze(out[8]) 81 | # flow64 = np.squeeze(out[9]) 82 | # flow32 = np.squeeze(out[10]) 83 | # flow16 = np.squeeze(out[11]) 84 | # flow8 = np.squeeze(out[12]) 85 | # flow4 = 
np.squeeze(out[13]) 86 | 87 | regvol[:, :, i] = reg_img 88 | flowvol[:, :, i, :] = flow 89 | # flowvol256[:, :, i, :] = flow256 90 | # flowvol128[:, :, i, :] = flow128 91 | # flowvol64[:, :, i, :] = flow64 92 | # flowvol32[:, :, i, :] = flow32 93 | # flowvol16[:, :, i, :] = flow16 94 | # flowvol8[:, :, i, :] = flow8 95 | # flowvol4[:, :, i, :] = flow4 96 | 97 | 98 | if 'brainMask' in volume and 'ventMask' in volume and 'wmlMask' in volume: 99 | 100 | brainMask = volume['brainMask'].astype(np.float32) 101 | ventMask = volume['ventMask'].astype(np.float32) 102 | wmlMask = volume['wmlMask'].astype(np.float32) 103 | regBrain = np.empty((x, y, z)) 104 | regVent = np.empty((x, y, z)) 105 | regWML = np.empty((x, y, z)) 106 | 107 | ft = tf.convert_to_tensor(flowvol.transpose(2, 0, 1, 3)) 108 | 109 | regBrain = warpMask(ft, brainMask) 110 | regVent = warpMask(ft, ventMask) 111 | regWML = warpMask(ft, wmlMask) 112 | 113 | tf.keras.backend.clear_session() 114 | 115 | sio.savemat(os.path.join(save_dir, name), {'regvol': regvol, 'flow': flowvol, 116 | 'brainMask': regBrain, 'ventMask': regVent, 'wmlMask': regWML}) 117 | else: 118 | sio.savemat(os.path.join(save_dir, name), {'regvol': regvol, 'flow': flowvol}) 119 | 120 | print("Registered Volume Saved Successfully") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser(description="""FlowReg-OpticalFlow (FlowReg-O) register""") 125 | 126 | parser.add_argument('-r', '--register', help=' register volumes directory', type=str, dest='moving') 127 | parser.add_argument('-f', '--fixed', help=' fixed volume directory', type=str, dest='fixed') 128 | parser.add_argument('-s', '--save', help=' results save directory', type=str, dest='save_dir') 129 | parser.add_argument('-b', '--brain', help=' brain masks directory', type=str, dest='brain_dir') 130 | parser.add_argument('-v', '--vent', help=' ventricle masks directory', type=str, dest='vent_dir') 131 | parser.add_argument('-w', '--wml', help=' wml masks directory', type=str, dest='wml_dir') 132 | parser.add_argument('-m', '--model', help=' trained model weights directory', dest='model') 133 | 134 | args = parser.parse_args() 135 | 136 | register(args.model, args.fixed, args.moving, args.brain_dir, args.vent_dir, args.wml_dir, args.save_dir) -------------------------------------------------------------------------------- /flowreg_o/train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import datetime 4 | import time 5 | import argparse 6 | 7 | from keras.callbacks import ModelCheckpoint, CSVLogger 8 | 9 | from data_generator import generatedata 10 | from model import flowmodelS 11 | 12 | def train(fixed, train, validation, batch_size, epochs, checkpoint, save_path, save_loss, alpha, weights): # note: alpha only tags the checkpoint folder; the loss weight itself is set by alphaval in loss.py 13 | 14 | # get data from generator 15 | print('generating data') 16 | generate_train = generatedata(train, fixed, batch_size) 17 | generate_validation = generatedata(validation, fixed, batch_size) 18 | 19 | if os.path.isdir(train): 20 | train_vols = glob.glob(train + '/*.mat') 21 | elif os.path.isfile(train): 22 | train_vols = [line.rstrip('\n') for line in open(train)] 23 | else: 24 | print( 25 | "Invalid training data. 
Should be .txt file containing (training/validation) set location or directory of (training/validation) volumes") 26 | 27 | if os.path.isdir(validation): 28 | validation_vols = glob.glob(validation + '/*.mat') 29 | elif os.path.isfile(validation): 30 | validation_vols = [line.rstrip('\n') for line in open(validation)] 31 | else: 32 | print( 33 | "Invalid validation data. Should be .txt file containing (training/validation) set location or directory of (training/validation) volumes") 34 | 35 | timestr = time.strftime('%Y%m%d-%H%M%S') 36 | checkpointdir = 'G:/My Drive/MASc/Code/python/flowReg/checkpoint/' + timestr + 'alpha' + alpha # hard-coded output location; adjust for your environment 37 | os.mkdir(checkpointdir) 38 | checkpoint = ModelCheckpoint(filepath=checkpointdir + '/weights-{epoch:02d}.h5', 39 | verbose=1, period=checkpoint, 40 | save_weights_only=True) 41 | datestr = str( 42 | datetime.datetime.now().strftime('%Y-%m-%d')) 43 | if save_loss: 44 | csv_logger = CSVLogger('G:/My Drive/MASc/Code/python/flowreg2d/losses/' + datestr + '.csv', separator=',') # hard-coded path; adjust for your environment 45 | callbacks = [checkpoint, csv_logger] 46 | else: 47 | callbacks = [checkpoint] 48 | 49 | model = flowmodelS(shape=[256, 256, 1], batch_size=batch_size) 50 | model.summary() 51 | if weights: 52 | print("loading previously trained weights------------") 53 | model.load_weights(weights) 54 | model.fit_generator(generate_train, steps_per_epoch=len(train_vols) * 55 // batch_size, 55 | validation_data=generate_validation, validation_steps=len(validation_vols) * 55 // batch_size, 56 | verbose=1, epochs=epochs, 57 | callbacks=callbacks) 58 | 59 | model.save(save_path + timestr + '.h5') 60 | print("------------Model Saved---------------") 61 | 62 | return 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser(description="""FlowReg-OpticalFlow (FlowReg-O) training""") 67 | 68 | parser.add_argument('-t', '--train', help=' training volumes directory', type=str, dest='train') 69 | parser.add_argument('-v', '--validation', help=' validation volumes directory', type=str, 70 | dest='validation') 71 | parser.add_argument('-f', '--fixed', help=' fixed volume directory', type=str, dest='fixed') 72 | parser.add_argument('-b', '--batch', help=' batch size, default=64', type=int, dest='batch_size', default=64) 73 | parser.add_argument('-c', '--checkpoint', help=' weights save checkpoint, default=0', type=int, dest='checkpoint', default=0) 74 | parser.add_argument('-e', '--epochs', help=' number of training epochs, default=100', type=int, dest='epochs', default=100) 75 | parser.add_argument('-l', '--save_loss', help=' save loss across all epochs, default=TRUE', type=bool, dest='save_loss', default=True) 76 | parser.add_argument('-m', '--model_save', help=' model save directory', type=str, dest='model_save') 77 | parser.add_argument('-a', '--alpha', help=' alpha value for loss function during training, default = 0.20', type=str, dest='alpha', default='0.20') 78 | parser.add_argument('-w', '--load_weights', help=' location of weights to load', type=str, dest='load_weights') 79 | 80 | args = parser.parse_args() 81 | 82 | train(args.fixed, args.train, args.validation, args.batch_size, args.epochs, args.checkpoint, args.model_save, args.save_loss, args.alpha, args.load_weights) 83 | -------------------------------------------------------------------------------- /flowreg_o/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import transform 3 | import tensorflow as tf 4 | 5 | def normalize(input): 6 | input = 
np.float32(input) 7 | xmin = np.amin(input) 8 | xmax = np.amax(input) 9 | b = 1. # max value (17375) 10 | a = 0. # min value (0) 11 | if (xmax - xmin) == 0: 12 | out = input 13 | else: 14 | out = a+(b-a)*(input-xmin)/(xmax-xmin) 15 | return out 16 | 17 | 18 | def rescale_img(img, img_size): 19 | contrast = np.random.uniform(low=0.7, high=1.3) 20 | brightness = np.random.normal(0, 0.1, 1) 21 | img = img*contrast + brightness 22 | r_img = transform.resize(img, img_size, anti_aliasing=True) 23 | return normalize(r_img).reshape(1, 256, 256, 1) 24 | 25 | 26 | def rescale_imgs(img, img_size): 27 | noise = np.random.normal(0, 0.1, (256, 256)) 28 | contrast = np.random.uniform(low=0.7, high=1.3) 29 | brightness = np.random.normal(0, 0.1, 1) 30 | img = img*contrast + brightness 31 | r_img0 = transform.resize(img, img_size, anti_aliasing=True) 32 | r_img0 = normalize(r_img0.reshape(1, r_img0.shape[0], r_img0.shape[1], 1)) 33 | r_img1 = transform.resize(img, (img_size[0]//2, img_size[1]//2), anti_aliasing=True) 34 | r_img1 = normalize(r_img1.reshape(1, r_img1.shape[0], r_img1.shape[1], 1)) 35 | r_img2 = transform.resize(img, (img_size[0]//4, img_size[1]//4), anti_aliasing=True) 36 | r_img2 = normalize(r_img2.reshape(1, r_img2.shape[0], r_img2.shape[1], 1)) 37 | r_img3 = transform.resize(img, (img_size[0]//8, img_size[1]//8), anti_aliasing=True) 38 | r_img3 = normalize(r_img3.reshape(1, r_img3.shape[0], r_img3.shape[1], 1)) 39 | r_img4 = transform.resize(img, (img_size[0]//16, img_size[1]//16), anti_aliasing=True) 40 | r_img4 = normalize(r_img4.reshape(1, r_img4.shape[0], r_img4.shape[1], 1)) 41 | r_img5 = transform.resize(img, (img_size[0]//32, img_size[1]//32), anti_aliasing=True) 42 | r_img5 = normalize(r_img5.reshape(1, r_img5.shape[0], r_img5.shape[1], 1)) 43 | r_img6 = transform.resize(img, (img_size[0]//64, img_size[1]//64), anti_aliasing=True) 44 | r_img6 = normalize(r_img6.reshape(1, r_img6.shape[0], r_img6.shape[1], 1)) 45 | 46 | zero_flow0 = np.float32(np.zeros(r_img0.shape)) 47 | zero_flow1 = np.float32(np.zeros(r_img1.shape)) 48 | zero_flow2 = np.float32(np.zeros(r_img2.shape)) 49 | zero_flow3 = np.float32(np.zeros(r_img3.shape)) 50 | zero_flow4 = np.float32(np.zeros(r_img4.shape)) 51 | zero_flow5 = np.float32(np.zeros(r_img5.shape)) 52 | zero_flow6 = np.float32(np.zeros(r_img6.shape)) 53 | return [r_img0, r_img1, r_img2, r_img3, r_img4, r_img5, r_img6, 54 | zero_flow0, zero_flow1, zero_flow2, zero_flow3, zero_flow4, zero_flow5, zero_flow6] 55 | 56 | 57 | def rescale_tensors(img): 58 | img0 = img 59 | img1 = tf.image.resize_bicubic(img, [128, 128]) 60 | img2 = tf.image.resize_bicubic(img, [64, 64]) 61 | img3 = tf.image.resize_bicubic(img, [32, 32]) 62 | img4 = tf.image.resize_bicubic(img, [16, 16]) 63 | img5 = tf.image.resize_bicubic(img, [8, 8]) 64 | img6 = tf.image.resize_bicubic(img, [4, 4]) 65 | return [img0, img1, img2, img3, img4, img5, img6] 66 | 67 | def warp_tensors(img, flow): 68 | warped = tf.contrib.image.dense_image_warp(img, flow, name='dense_image_warp') 69 | return warped --------------------------------------------------------------------------------
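For reference, a minimal sketch of how the multi-scale helpers in `flowreg_o/utils.py` fit together, assuming a TensorFlow 1.x environment (where `tf.contrib.image.dense_image_warp` is available) and that the script is run from the `flowreg_o` directory; the random arrays are placeholders for real slices:

```python
import numpy as np
import tensorflow as tf
from utils import rescale_img, rescale_imgs, warp_tensors

# one fixed slice at all 7 pyramid scales (256 down to 4), plus matching zero-flow targets
fixed_pyramid = rescale_imgs(np.random.rand(256, 256), (256, 256))
assert len(fixed_pyramid) == 14  # 7 resized images + 7 zero flows

# one augmented (random contrast/brightness) moving slice at full resolution
moving = rescale_img(np.random.rand(256, 256), (256, 256))  # shape (1, 256, 256, 1)

# warping with an all-zero flow is the identity, up to interpolation
img = tf.convert_to_tensor(moving, dtype=tf.float32)
zero_flow = tf.zeros((1, 256, 256, 2), dtype=tf.float32)
warped = warp_tensors(img, zero_flow)

with tf.Session() as sess:
    out = sess.run(warped)
print(out.shape)  # (1, 256, 256, 1)
```

This is the same `warp_tensors` call that `model.py` wraps in `Lambda` layers to warp the moving image at each pyramid level.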