├── LICENSE
├── README.md
├── callbacks
│   ├── __init__.py
│   ├── dafnet_image_callback.py
│   ├── image_callback.py
│   ├── loss_callback.py
│   └── swa.py
├── configuration
│   ├── __init__.py
│   ├── dafnet_config_chaos.py
│   ├── dafnet_spade_config_chaos.py
│   └── mmsdnet_config_chaos.py
├── costs.py
├── experiment.py
├── layers
│   ├── __init__.py
│   ├── film.py
│   ├── interpolate_spline.py
│   ├── rounding.py
│   ├── spade.py
│   ├── spectralnorm.py
│   └── stn_spline.py
├── loaders
│   ├── MultimodalPairedData.py
│   ├── __init__.py
│   ├── base_loader.py
│   ├── chaos.py
│   ├── data.py
│   ├── dcm_contour_utils.py
│   └── loader_factory.py
├── model_components
│   ├── __init__.py
│   ├── anatomy_encoder.py
│   ├── anatomy_fuser.py
│   ├── balancer.py
│   ├── decoder.py
│   ├── modality_encoder.py
│   └── segmentor.py
├── model_executors
│   ├── __init__.py
│   ├── base_executor.py
│   ├── dafnet_executor.py
│   └── mmsdnet_executor.py
├── model_tester.py
├── models
│   ├── __init__.py
│   ├── basenet.py
│   ├── dafnet.py
│   ├── discriminator.py
│   ├── mmsdnet.py
│   └── unet.py
├── pseudocode.pdf
├── requirements.txt
└── utils
    ├── __init__.py
    ├── data_utils.py
    ├── distributions.py
    ├── image_utils.py
    ├── model_utils.py
    └── sdnet_utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Agis Chartsias
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal segmentation with disentangled representations
2 |
3 |
4 | Implementations of the **MMSDNet** and **DAFNet** models that perform multimodal image segmentation using a disentangled representation of anatomy and modality factors. For further details, please see our paper [Multimodal Cardiac Segmentation Using Disentangled Representation Learning], presented at STACOM 2019, and the pre-print [Disentangle, align and fuse for multimodal and semi-supervised image segmentation].
5 |
6 | Pseudocode of the training process is provided in [pseudocode.pdf](pseudocode.pdf).
7 |
8 |
9 | The code is written in Python 3.6 with [Keras] 2.1.6 and [tensorflow] 1.4.0, and
10 | experiments were run on Titan-X and Titan-V GPUs. The `requirements.txt` file contains all Python library versions.
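To install the pinned dependencies (a minimal sketch of the environment setup; the exact versions are listed in `requirements.txt`):
```
pip install -r requirements.txt
```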
11 | 12 | This project is structured in different packages as follows: 13 | 14 | * **callbacks**: contains Keras callbacks for printing images and losses during training 15 | * **configuration**: contains configuration files for running an experiment 16 | * **layers**: contains custom Keras layers, e.g. the STN layer 17 | * **loaders**: contains definitions of data loaders 18 | * **model_components**: contains implementations of individual components, e.g. the encoders and decoders 19 | * **model_executors**: contains code for loading data and training models 20 | * **models**: contains Keras implementations of MMSDNet, DAFNet 21 | * **utils**: package with utility functions used in the project 22 | 23 | To define a new data loader, extend class `base_loader.Loader`, and register the loader in `loader_factory.py`. The datapaths are specified in `base_loader.py`. 24 | 25 | To run an experiment, execute `experiment.py`, passing the configuration filename and the split number as runtime parameters: 26 | ``` 27 | python experiment.py --config myconfiguration --split 0 28 | ``` 29 | Optional parameters include the proportion of labels in modality 2 data for unsupervised learning, e.g. 30 | ``` 31 | python experiment.py --config myconfiguration --split 0 --l_mix 0.5 32 | ``` 33 | 34 | Sample config files for MMSDNet and DAFNet are placed in the **configuration** package for [CHAOS] data. 35 | 36 | ## Citation 37 | 38 | If you use this code for your research, please cite our papers: 39 | 40 | ``` 41 | @InProceedings{chartsias2020multimodal, 42 | author="Chartsias, Agisilaos and Papanastasiou, Giorgos and Wang, Chengjia and Stirrat, Colin and Semple, Scott and Newby, David and Dharmakumar, Rohan and Tsaftaris, Sotirios A.", 43 | title="Multimodal Cardiac Segmentation Using Disentangled Representation Learning", 44 | booktitle="Statistical Atlases and Computational Models of the Heart. 
Multi-Sequence CMR Segmentation, CRT-EPiggy and LV Full Quantification Challenges",
45 | year="2020",
46 | publisher="Springer International Publishing",
47 | address="Cham",
48 | pages="128--137",
49 | isbn="978-3-030-39074-7"
50 | }
51 | ```
52 |
53 | ```
54 | @article{chartsias2020disentangle,
55 | title={Disentangle, align and fuse for multimodal and semi-supervised image segmentation},
56 | author={Chartsias, Agisilaos and Papanastasiou, Giorgos and Wang, Chengjia and Semple, Scott and Newby, David and Dharmakumar, Rohan and Tsaftaris, Sotirios A},
57 | journal={arXiv preprint arXiv:1911.04417},
58 | year={2019}
59 | }
60 |
61 | ```
62 |
63 | [Multimodal Cardiac Segmentation Using Disentangled Representation Learning]: https://link.springer.com/chapter/10.1007/978-3-030-39074-7_14
64 | [Disentangle, align and fuse for multimodal and semi-supervised image segmentation]: https://arxiv.org/abs/1911.04417
65 | [Keras]: https://keras.io/
66 | [tensorflow]: https://www.tensorflow.org/
67 | [CHAOS]: http://doi.org/10.5281/zenodo.3362844
68 |
--------------------------------------------------------------------------------
/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/callbacks/__init__.py
--------------------------------------------------------------------------------
/callbacks/dafnet_image_callback.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import scipy.misc
7 | from keras import Model
8 |
9 | import utils
10 | import utils.image_utils
11 | import utils.data_utils
12 | from callbacks.image_callback import BaseSaveImage, get_s0chn
13 | from utils.distributions import NormalDistribution
14 | from utils.sdnet_utils import get_net
15 |
16 | log = logging.getLogger('callback')
17 |
18 |
19 | class DAFNetImageCallback(BaseSaveImage):
20 | """
21 | Image callback for saving images during DAFNet training.
22 | Images are saved in a subfolder named training_images, created inside the experiment folder.
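Four subfolders are created: images_lr (latent representations), images_segm (segmentations),
images_rec (reconstructions) and images_discr (discriminator outputs).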
23 | """ 24 | def __init__(self, conf, model, data_gen_lb): 25 | """ 26 | :param conf: configuration object 27 | :param model: a DAFNet model 28 | :param data_gen_lb: a python iterator of images+masks 29 | """ 30 | self.conf = conf 31 | super(DAFNetImageCallback, self).__init__(conf.folder, model) 32 | self._make_dirs(self.folder) 33 | self.data_gen_lb = data_gen_lb 34 | self.init_models() 35 | 36 | def _make_dirs(self, folder): 37 | self.lr_folder = folder + '/images_lr' 38 | if not os.path.exists(self.lr_folder): 39 | os.makedirs(self.lr_folder) 40 | 41 | self.segm_folder = folder + '/images_segm' 42 | if not os.path.exists(self.segm_folder): 43 | os.makedirs(self.segm_folder) 44 | 45 | self.rec_folder = folder + '/images_rec' 46 | if not os.path.exists(self.rec_folder): 47 | os.makedirs(self.rec_folder) 48 | 49 | self.discr_folder = folder + '/images_discr' 50 | if not os.path.exists(self.discr_folder): 51 | os.makedirs(self.discr_folder) 52 | 53 | def init_models(self): 54 | self.encoders_anatomy = self.model.Encoders_Anatomy 55 | self.reconstructor = self.model.Decoder 56 | self.segmentor = self.model.Segmentor 57 | self.discr_mask = self.model.D_Mask 58 | self.enc_modality = self.model.Enc_Modality 59 | self.fuser = self.model.Anatomy_Fuser 60 | self.discr_mask = self.model.D_Mask 61 | 62 | mean = get_net(self.enc_modality, 'z_mean') 63 | var = get_net(self.enc_modality, 'z_log_var') 64 | self.z_mean = Model(self.enc_modality.inputs, mean.output) 65 | self.z_var = Model(self.enc_modality.inputs, var.output) 66 | 67 | def on_epoch_end(self, epoch=None, logs=None): 68 | ''' 69 | Plot training images from the real_pool. For SDNet the real_pools will contain images paired with masks, 70 | and also unlabelled images. 71 | :param epoch: current training epoch 72 | :param logs: 73 | ''' 74 | x_mod1, x_mod2, m_mod1, m_mod2 = next(self.data_gen_lb) 75 | image_list = [x_mod1[..., 0:1], x_mod2[..., 0:1]] 76 | masks_list = [m_mod1[..., 0:self.conf.num_masks], m_mod2[..., 0:self.conf.num_masks]] 77 | 78 | # we usually plot 4 image-rows. If we have less, it means we've reached the end of the data, so iterate from 79 | # the beginning 80 | while len(image_list[0]) < 4: 81 | x_mod1, x_mod2 = image_list 82 | m_mod1, m_mod2 = masks_list 83 | 84 | x_mod1_2, x_mod2_2, m_mod1_2, m_mod2_2 = next(self.data_gen_lb) 85 | image_list = [np.concatenate([x_mod1[..., 0:1], x_mod1_2[..., 0:1]], axis=0), 86 | np.concatenate([x_mod2[..., 0:1], x_mod2_2[..., 0:1]], axis=0)] 87 | masks_list = [np.concatenate([m_mod1[..., 0:self.conf.num_masks], m_mod1_2[..., 0:self.conf.num_masks]], axis=0), 88 | np.concatenate([m_mod2[..., 0:self.conf.num_masks], m_mod2_2[..., 0:self.conf.num_masks]], axis=0)] 89 | 90 | self.plot_latent_representation(image_list, epoch) 91 | self.plot_segmentations(image_list, masks_list, epoch) 92 | self.plot_reconstructions(image_list, epoch) 93 | self.plot_discriminator_outputs(image_list, masks_list, epoch) 94 | 95 | def plot_latent_representation(self, image_list, epoch): 96 | """ 97 | Plot a 4-row image, where the first column shows the input image and the following columns 98 | each of the 8 channels of the spatial latent representation. 
99 | :param image_list: a list of 4-dim arrays of images, one for each modality 100 | :param epoch : the epoch number 101 | """ 102 | 103 | x_list, s_list = [], [] 104 | for mod_i in range(len(image_list)): 105 | images = image_list[mod_i] 106 | 107 | x = utils.data_utils.sample(images, nb_samples=4, seed=self.conf.seed) 108 | x_list.append(x) 109 | 110 | # plot S 111 | s = self.encoders_anatomy[mod_i].predict(x) 112 | s_list.append(s) 113 | 114 | rows = [np.concatenate([x[i, :, :, 0]] + [s[i, :, :, s_chn] for s_chn in range(s.shape[-1])], axis=1) 115 | for i in range(x.shape[0])] 116 | im_plot = np.concatenate(rows, axis=0) 117 | scipy.misc.imsave(self.lr_folder + '/mod_%d_s_lr_epoch_%d.png' % (mod_i, epoch), im_plot) 118 | 119 | # plot Z 120 | enc_modality_inputs = [self.encoders_anatomy[mod_i].predict(images), images] 121 | z, _ = self.enc_modality.predict(enc_modality_inputs) 122 | 123 | means = self.z_mean.predict(enc_modality_inputs) 124 | variances = self.z_var.predict(enc_modality_inputs) 125 | means = np.var(means, axis=0) 126 | variances = np.mean(np.exp(variances), axis=0) 127 | with open(self.lr_folder + '/z_means.csv', 'a+') as f: 128 | f.writelines(', '.join([str(means[i]) for i in range(means.shape[0])]) + '\n') 129 | with open(self.lr_folder + '/z_vars.csv', 'a+') as f: 130 | f.writelines(', '.join([str(variances[i]) for i in range(variances.shape[0])]) + '\n') 131 | 132 | # plot deformed anatomies 133 | new_anatomies = self.fuser.predict(s_list) 134 | 135 | s1_def = new_anatomies[0] 136 | rows = [np.concatenate([x_list[0][i, :, :, 0], x_list[1][i, :, :, 0]] + 137 | [s1_def[i, :, :, s_chn] for s_chn in range(s1_def.shape[-1])], axis=1) 138 | for i in range(x_list[0].shape[0])] 139 | im_plot = np.concatenate(rows, axis=0) 140 | scipy.misc.imsave(self.lr_folder + '/s1def_lr_epoch_%d.png' % (epoch), im_plot) 141 | 142 | def plot_segmentations(self, image_list, mask_list, epoch): 143 | ''' 144 | Plot an image for every sample, where every row contains a channel of the spatial LR and a channel of the 145 | predicted mask. 
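Predictions on the fused anatomies of both modalities are also saved in a separate
fused_segmentations image.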
146 | :param image_list: a list of 4-dim arrays of images, one for each modality 147 | :param masks_list: a list of 4-dim arrays of masks, one for each modality 148 | :param epoch: the epoch number 149 | ''' 150 | 151 | x_list, s_list, m_list2 = [], [], [] 152 | for mod_i in range(len(image_list)): 153 | images = image_list[mod_i] 154 | masks = mask_list[mod_i] 155 | 156 | x = utils.data_utils.sample(images, 4, seed=self.conf.seed) 157 | m = utils.data_utils.sample(masks, 4, seed=self.conf.seed) 158 | 159 | x_list.append(x) 160 | m_list2.append(m) 161 | 162 | assert x.shape[:-1] == m.shape[:-1], 'Incompatible shapes: %s vs %s' % (str(x.shape), str(m.shape)) 163 | 164 | s = self.encoders_anatomy[mod_i].predict(x) 165 | y = self.segmentor.predict(s) 166 | 167 | s_list.append(s) 168 | 169 | rows = [] 170 | for i in range(x.shape[0]): 171 | y_list = [y[i, :, :, chn] for chn in range(y.shape[-1])] 172 | m_list = [m[i, :, :, chn] for chn in range(m.shape[-1])] 173 | if m.shape[-1] < y.shape[-1]: 174 | m_list += [np.zeros(shape=(m.shape[1], m.shape[2]))] * (y.shape[-1] - m.shape[-1]) 175 | assert len(y_list) == len(m_list), 'Incompatible sizes: %d vs %d' % (len(y_list), len(m_list)) 176 | rows += [np.concatenate([x[i, :, :, 0]] + y_list + m_list, axis=1)] 177 | im_plot = np.concatenate(rows, axis=0) 178 | scipy.misc.imsave(self.segm_folder + '/mod_%d_segmentations_epoch_%d.png' % (mod_i, epoch), im_plot) 179 | 180 | new_anatomies = self.fuser.predict(s_list) 181 | pred_masks = [self.segmentor.predict(s) for s in new_anatomies] 182 | 183 | rows = [] 184 | for i in range(x_list[0].shape[0]): 185 | for y in pred_masks: 186 | y_list = [y[i, :, :, chn] for chn in range(self.conf.num_masks)] 187 | m_list = [m_list2[1][i, :, :, chn] for chn in range(self.conf.num_masks)] 188 | assert len(y_list) == len(m_list), 'Incompatible sizes: %d vs %d' % (len(y_list), len(m_list)) 189 | rows += [np.concatenate([x_list[0][i, :, :, 0], x_list[1][i, :, :, 0]] + y_list + m_list, axis=1)] 190 | im_plot = np.concatenate(rows, axis=0) 191 | scipy.misc.imsave(self.segm_folder + '/fused_segmentations_epoch_%d.png' % (epoch), im_plot) 192 | 193 | def plot_discriminator_outputs(self, image_list, mask_list, epoch): 194 | ''' 195 | Plot a histogram of predicted values by the discriminator 196 | :param image_list: a list of 4-dim arrays of images, one for each modality 197 | :param masks_list: a list of 4-dim arrays of masks, one for each modality 198 | :param epoch: the epoch number 199 | ''' 200 | 201 | s_list = [enc.predict(x) for enc, x in zip(self.encoders_anatomy, image_list)] 202 | 203 | 204 | s_list += self.fuser.predict(s_list) 205 | # s2_def, fused_s2 = self.fuser.predict(reversed(s_list)) 206 | # s_list += [s1_def, fused_s1] 207 | 208 | s = np.concatenate(s_list, axis=0) 209 | m = np.concatenate(mask_list, axis=0) 210 | pred_m = self.segmentor.predict(s) 211 | 212 | m = m[..., 0:self.discr_mask.input_shape[-1]] 213 | pred_m = pred_m[..., 0:self.discr_mask.input_shape[-1]] 214 | 215 | m = utils.data_utils.sample(m, nb_samples=4) 216 | pred_m = utils.data_utils.sample(pred_m, nb_samples=4) 217 | 218 | plt.figure() 219 | for i in range(4): 220 | plt.subplot(4, 2, 2 * i + 1) 221 | m_allchn = np.concatenate([m[i, :, :, chn] for chn in range(m.shape[-1])], axis=1) 222 | plt.imshow(m_allchn, cmap='gray') 223 | plt.xticks([]) 224 | plt.yticks([]) 225 | plt.title('Pred: %.3f' % self.discr_mask.predict(m[i:i + 1]).reshape(1, -1).mean(axis=1)) 226 | 227 | plt.subplot(4, 2, 2 * i + 2) 228 | pred_m_allchn_img = 
np.concatenate([pred_m[i, :, :, chn] for chn in range(pred_m.shape[-1])], axis=1) 229 | plt.imshow(pred_m_allchn_img, cmap='gray') 230 | plt.xticks([]) 231 | plt.yticks([]) 232 | plt.title('Pred: %.3f' % self.discr_mask.predict(pred_m).reshape(1, -1).mean(axis=1)) 233 | plt.tight_layout() 234 | plt.savefig(self.discr_folder + '/discriminator_mask_epoch_%d.png' % epoch) 235 | plt.close() 236 | 237 | def plot_reconstructions(self, image_list, epoch): 238 | """ 239 | Plot two images showing the combination of the spatial and modality LR to generate an image. The first 240 | image uses the predicted S and Z and the second samples Z from a Gaussian. 241 | :param image_list: a list of 2 4-dim arrays of images 242 | :param epoch: the epoch number 243 | """ 244 | x_list, s_list = [], [] 245 | for mod_i in range(len(image_list)): 246 | images = image_list[mod_i] 247 | x = utils.data_utils.sample(images, nb_samples=4, seed=self.conf.seed) 248 | x_list.append(x) 249 | 250 | # S + Z -> Image 251 | s = self.encoders_anatomy[mod_i].predict(x) 252 | s_list.append(s) 253 | 254 | im_plot = self.get_rec_image(x, s) 255 | scipy.misc.imsave(self.rec_folder + '/mod_%d_rec_epoch_%d.png' % (mod_i, epoch), im_plot) 256 | 257 | new_anatomies = self.fuser.predict(s_list) 258 | s1_def = new_anatomies[0] 259 | 260 | im_plot = self.get_rec_image(x_list[1], s1_def) 261 | scipy.misc.imsave(self.rec_folder + '/s1def_rec_epoch_%d.png' % (epoch), im_plot) 262 | 263 | def get_rec_image(self, x, s): 264 | z, _ = self.enc_modality.predict([s, x]) 265 | gaussian = NormalDistribution() 266 | 267 | y = self.reconstructor.predict([s, z]) 268 | y_s0 = self.reconstructor.predict([s, np.zeros(z.shape)]) 269 | all_bkg = np.concatenate([np.zeros(s.shape[:-1] + (s.shape[-1] - 1,)), np.ones(s.shape[:-1] + (1,))], axis=-1) 270 | y_0z = self.reconstructor.predict([all_bkg, z]) 271 | y_00 = self.reconstructor.predict([all_bkg, np.zeros(z.shape)]) 272 | z_random = gaussian.sample(z.shape) 273 | y_random = self.reconstructor.predict([s, z_random]) 274 | rows = [np.concatenate([x[i, :, :, 0], y[i, :, :, 0], y_random[i, :, :, 0], y_s0[i, :, :, 0]] + 275 | [self.reconstructor.predict([get_s0chn(k, s), z])[i, :, :, 0] for k in 276 | range(s.shape[-1] - 1)] + 277 | [y_0z[i, :, :, 0], y_00[i, :, :, 0]], axis=1) for i in range(x.shape[0])] 278 | header = utils.image_utils.makeTextHeaderImage(x.shape[2], ['X', 'rec(s,z)', 'rec(s,~z)', 'rec(s,0)'] + 279 | ['rec(s0_%d, z)' % k for k in range(s.shape[-1] - 1)] + [ 280 | 'rec(0, z)', 'rec(0,0)']) 281 | im_plot = np.concatenate([header] + rows, axis=0) 282 | im_plot = np.clip(im_plot, -1, 1) 283 | return im_plot -------------------------------------------------------------------------------- /callbacks/image_callback.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from abc import abstractmethod 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from keras.callbacks import Callback 8 | 9 | from costs import dice 10 | from utils.image_utils import save_segmentation, intensity_augmentation 11 | 12 | log = logging.getLogger('BaseSaveImage') 13 | 14 | 15 | class BaseSaveImage(Callback): 16 | """ 17 | Abstract base class for saving training images 18 | """ 19 | 20 | def __init__(self, folder, model): 21 | super(BaseSaveImage, self).__init__() 22 | self.folder = os.path.join(folder, 'training_images') 23 | if not os.path.exists(self.folder): 24 | os.makedirs(self.folder) 25 | self.model = model 26 | 27 | @abstractmethod 28 
| def on_epoch_end(self, epoch=None, logs=None): 29 | pass 30 | 31 | 32 | class SaveImage(Callback): 33 | """ 34 | Simple callback that saves segmentation masks and dice error. 35 | """ 36 | def __init__(self, folder, test_data, test_masks=None, input_len=None, comet_experiment=None): 37 | super(SaveImage, self).__init__() 38 | self.folder = folder 39 | self.test_data = test_data # this can be a list of images of different spatial dimensions 40 | self.test_masks = test_masks 41 | self.input_len = input_len 42 | self.comet_experiment = comet_experiment 43 | 44 | def on_epoch_end(self, epoch, logs=None): 45 | if not os.path.exists(self.folder): 46 | os.makedirs(self.folder) 47 | 48 | all_dice = [] 49 | for i in range(len(self.test_data)): 50 | d, m = self.test_data[i], self.test_masks[i] 51 | s, im = save_segmentation(self.folder, self.model, d, m, 'slc_%d' % i) 52 | all_dice.append(-dice(self.test_masks[i:i+1], s)) 53 | 54 | if self.comet_experiment is not None: 55 | plt.figure() 56 | plt.plot(0, 0) # fake a line plot to upload to comet 57 | plt.imshow(im, cmap='gray') 58 | plt.xticks([]) 59 | plt.yticks([]) 60 | plt.tight_layout() 61 | self.comet_experiment.log_figure(figure_name='segmentation', figure=plt) 62 | plt.close() 63 | 64 | f = open(os.path.join(self.folder, 'test_error.txt'), 'a+') 65 | f.writelines("%d, %.3f\n" % (epoch, np.mean(all_dice))) 66 | f.close() 67 | 68 | 69 | class SaveEpochImages(Callback): 70 | def __init__(self, conf, model, img_gen, comet_experiment=None): 71 | super(SaveEpochImages, self).__init__() 72 | self.folder = conf.folder + '/training' 73 | self.conf = conf 74 | self.model = model 75 | self.gen = img_gen 76 | self.comet_experiment = comet_experiment 77 | if not os.path.exists(self.folder): 78 | os.makedirs(self.folder) 79 | 80 | def on_epoch_end(self, epoch, logs=None): 81 | x, m = next(self.gen) 82 | x = intensity_augmentation(x) 83 | 84 | y = self.model.predict(x) 85 | im1, im2 = save_multiimage_segmentation(x, m, y, self.folder, epoch) 86 | if self.comet_experiment is not None: 87 | plt.figure() 88 | plt.plot(0, 0) # fake a line plot to upload to comet 89 | plt.imshow(im1, cmap='gray') 90 | plt.imshow(im2, cmap='gray', alpha=0.5) 91 | plt.xticks([]) 92 | plt.yticks([]) 93 | plt.tight_layout() 94 | self.comet_experiment.log_figure(figure_name='segmentation', figure=plt) 95 | plt.close() 96 | 97 | 98 | def save_multiimage_segmentation(x, m, y, folder, epoch): 99 | rows_img, rows_msk = [], [] 100 | for i in range(x.shape[0]): 101 | if i == 4: 102 | break 103 | # y_list = [y[i, :, :, chn] for chn in range(y.shape[-1])] 104 | # m_list = [m[i, :, :, chn] for chn in range(m.shape[-1])] 105 | # if m.shape[-1] < y.shape[-1]: 106 | # m_list += [np.zeros(shape=(m.shape[1], m.shape[2]))] * (y.shape[-1] - m.shape[-1]) 107 | # assert len(y_list) == len(m_list), 'Incompatible sizes: %d vs %d' % (len(y_list), len(m_list)) 108 | 109 | rows_img += [np.concatenate([x[i, :, :, 0], x[i, :, :, 0], x[i, :, :, 0]], axis=1)] 110 | rows_msk += [np.concatenate([np.zeros(x[i, :, :, 0].shape)] + 111 | [sum([m[i, :, :, j] * (j + 1) * (1.0 / m.shape[-1]) for j in range(m.shape[-1])])] + 112 | [sum([y[i, :, :, j] * (j + 1) * (1.0 / m.shape[-1]) for j in range(m.shape[-1])])], axis=1)] 113 | 114 | rows_img = np.concatenate(rows_img, axis=0) 115 | rows_msk = np.concatenate(rows_msk, axis=0) 116 | 117 | plt.figure() 118 | plt.imshow(rows_img, cmap='gray') 119 | plt.imshow(rows_msk, alpha=0.5) 120 | plt.savefig(folder + '/segmentations_epoch_%d.png' % (epoch)) 121 | # 
scipy.misc.imsave(folder + '/segmentations_epoch_%d.png' % (epoch), im_plot) 122 | # return im_plot 123 | return rows_img, rows_msk 124 | 125 | 126 | def get_s0chn(k, s): 127 | s_res = s.copy() 128 | chnk = s_res[..., k] 129 | # move channel k 1s to the background 130 | s_res[..., -1][chnk == 1] = 1 131 | s_res[..., k] = 0 132 | return s_res -------------------------------------------------------------------------------- /callbacks/loss_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | from keras.callbacks import Callback 5 | 6 | 7 | class SaveLoss(Callback): 8 | """ 9 | Save the training loss in a csv file and plot a figure. 10 | """ 11 | def __init__(self, folder, scale='linear'): 12 | super(SaveLoss, self).__init__() 13 | self.folder = folder 14 | self.values = dict() 15 | self.scale = scale 16 | 17 | def on_epoch_end(self, epoch, logs=None): 18 | if logs is None: return 19 | 20 | if len(self.values) == 0: 21 | for k in logs: 22 | self.values[k] = [] 23 | 24 | for k in logs: 25 | self.values[k].append(logs[k]) 26 | 27 | plt.figure() 28 | plt.suptitle('Training loss', fontsize=16) 29 | for k in self.values: 30 | if 'dis' in k or 'adv' in k: 31 | continue 32 | 33 | epochs = range(len(self.values[k])) 34 | if self.scale == 'linear': 35 | plt.plot(epochs, self.values[k], label=k) 36 | elif self.scale == 'log': 37 | plt.semilogy(epochs, self.values[k], label=k) 38 | plt.xlabel('Epochs') 39 | plt.ylabel('Loss') 40 | plt.legend(loc='best') 41 | plt.savefig(os.path.join(self.folder, 'training_loss.png')) 42 | 43 | plt.figure() 44 | plt.suptitle('Training loss', fontsize=16) 45 | for k in self.values: 46 | if not ('dis' in k or 'adv' in k): 47 | continue 48 | 49 | epochs = range(len(self.values[k])) 50 | plt.plot(epochs, self.values[k], label=k) 51 | plt.xlabel('Epochs') 52 | plt.ylabel('Loss') 53 | plt.legend(loc='best') 54 | plt.savefig(os.path.join(self.folder, 'training_discr_loss.png')) 55 | 56 | plt.close() -------------------------------------------------------------------------------- /callbacks/swa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Stochastic Weight Averaging: https://arxiv.org/abs/1803.05407 5 | Keras implementation adapted from https://github.com/kristpapadopoulos/keras-stochastic-weight-averaging. 6 | """ 7 | 8 | import logging 9 | import keras 10 | 11 | log = logging.getLogger('swa') 12 | 13 | 14 | class SWA(keras.callbacks.Callback): 15 | def __init__(self, swa_epoch, model_build_fnc, build_params): 16 | super(SWA, self).__init__() 17 | self.swa_epoch = swa_epoch 18 | self.model_build_fnc = model_build_fnc 19 | self.build_params = build_params 20 | self.clone = None 21 | 22 | def on_train_begin(self, logs=None): 23 | self.nb_epoch = self.params['epochs'] 24 | print('Stochastic weight averaging selected for last {} epochs.' 
25 | .format(self.nb_epoch - self.swa_epoch)) 26 | 27 | def on_epoch_end(self, epoch, logs=None): 28 | if epoch <= self.swa_epoch: 29 | self.swa_weights = self.model.get_weights() 30 | 31 | elif epoch > self.swa_epoch: 32 | for i, layer in enumerate(self.swa_weights): 33 | self.swa_weights[i] = (self.swa_weights[i] * (epoch - self.swa_epoch) + self.model.get_weights()[i])\ 34 | / ((epoch - self.swa_epoch) + 1) 35 | 36 | def on_train_end(self, logs=None): 37 | self.model.set_weights(self.swa_weights) 38 | log.debug('Final model parameters set to stochastic weight average.') 39 | 40 | def get_clone_model(self): 41 | if self.clone is None: 42 | if self.build_params is not None: 43 | self.clone = self.model_build_fnc(self.build_params) 44 | else: 45 | self.clone = self.model_build_fnc() 46 | self.clone.set_weights(self.swa_weights) 47 | return self.clone 48 | -------------------------------------------------------------------------------- /configuration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/configuration/__init__.py -------------------------------------------------------------------------------- /configuration/dafnet_config_chaos.py: -------------------------------------------------------------------------------- 1 | from loaders import chaos 2 | 3 | params = { 4 | 'seed': 10, 5 | 'folder': 'dafnet_chaos', 6 | 'epochs': 500, 7 | 'batch_size': 6, 8 | 'split': 0, 9 | 'dataset_name': 'chaos', 10 | 'test_dataset': 'chaos', 11 | 'input_shape': chaos.ChaosLoader().input_shape, 12 | 'image_downsample': 1, # downsample image size: used for testing 13 | 'modality': ['t1', 't2'], # list of [source, target] modalities 14 | 'model': 'dafnet.DAFNet', # model to load 15 | 'executor': 'dafnet_executor.DAFNetExecutor', # model trainer 16 | 'l_mix': 1, # amount of supervision for target modality 17 | 'decoder_type': 'film', # decoder type - can be film or spade 18 | 'num_z': 8, # dimensions of the modality factor 19 | 'w_sup_M': 10, 20 | 'w_adv_M': 1, 21 | 'w_rec_X': 1, 22 | 'w_adv_X': 1, 23 | 'w_rec_Z': 1, 24 | 'w_kl': 0.1, 25 | 'lr': 0.0001, 26 | 'randomise': False, 27 | 'automatedpairing': False, 28 | } 29 | 30 | # discriminator configs 31 | d_mask_params = {'filters': 64, 'lr': 0.0001, 'name': 'D_Mask'} 32 | d_image_params = {'filters': 64, 'lr': 0.0001, 'name': 'D_Image'} 33 | 34 | anatomy_encoder_params = { 35 | 'normalise' : 'batch', # normalisation layer - can be batch or instance 36 | 'downsample' : 4, # number of downsample layers of UNet encoder 37 | 'filters' : 64, # number of filters in the first convolutional layer 38 | 'out_channels': 8, # number of output channels - dimensions of the anatomy factor 39 | 'rounding' : True 40 | } 41 | 42 | 43 | def get(): 44 | shp = params['input_shape'] 45 | ratio = params['image_downsample'] 46 | shp = (int(shp[0] / ratio), int(shp[1] / ratio), shp[2]) 47 | 48 | params['input_shape'] = shp 49 | params['num_masks'] = chaos.ChaosLoader().num_masks 50 | 51 | d_mask_params['input_shape'] = (shp[:-1]) + (chaos.ChaosLoader().num_masks,) 52 | d_image_params['input_shape'] = shp 53 | 54 | anatomy_encoder_params['input_shape'] = shp 55 | anatomy_encoder_params['output_shape'] = (shp[:-1]) + (anatomy_encoder_params['out_channels'],) 56 | 57 | params.update({'anatomy_encoder': anatomy_encoder_params, 58 | 'd_mask_params': d_mask_params, 'd_image_params': d_image_params}) 59 | return params 60 | 61 | 62 | 
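# A minimal sketch of how this configuration module is consumed (experiment.py does the
# same via importlib): the get() function resolves input shapes and returns the params dict.
#
#   from configuration import dafnet_config_chaos
#   params = dafnet_config_chaos.get()
#   print(params['input_shape'], params['num_masks'])
#
# which corresponds to running:
#   python experiment.py --config dafnet_config_chaos --split 0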
-------------------------------------------------------------------------------- /configuration/dafnet_spade_config_chaos.py: -------------------------------------------------------------------------------- 1 | from loaders import chaos 2 | 3 | params = { 4 | 'seed': 10, 5 | 'folder': 'dafnet_spade_chaos', 6 | 'epochs': 500, 7 | 'batch_size': 6, 8 | 'split': 0, 9 | 'dataset_name': 'chaos', 10 | 'test_dataset': 'chaos', 11 | 'input_shape': chaos.ChaosLoader().input_shape, 12 | 'image_downsample': 1, # downsample image size: used for testing 13 | 'modality': ['t1', 't2'], # list of [source, target] modalities 14 | 'model': 'dafnet.DAFNet', # model to load 15 | 'executor': 'dafnet_executor.DAFNetExecutor', # model trainer 16 | 'l_mix': 1, # amount of supervision for target modality 17 | 'decoder_type': 'spade', # decoder type - can be film or spade 18 | 'num_z': 8, 19 | 'w_sup_M': 10, 20 | 'w_adv_M': 1, 21 | 'w_rec_X': 1, 22 | 'w_adv_X': 1, 23 | 'w_rec_Z': 1, 24 | 'w_kl': 0.1, 25 | 'lr': 0.0001, 26 | 'randomise': False, 27 | 'automatedpairing': False, 28 | } 29 | 30 | # discriminator configs 31 | d_mask_params = {'filters': 64, 'lr': 0.0001, 'name': 'D_Mask'} 32 | d_image_params = {'filters': 64, 'lr': 0.0001, 'name': 'D_Image'} 33 | 34 | anatomy_encoder_params = { 35 | 'normalise' : 'batch', # normalisation layer - can be batch or instance 36 | 'downsample' : 4, # number of downsample layers of UNet encoder 37 | 'filters' : 64, # number of filters in the first convolutional layer 38 | 'out_channels': 8, # number of output channels - dimensions of the anatomy factor 39 | 'rounding' : True 40 | } 41 | 42 | 43 | def get(): 44 | shp = params['input_shape'] 45 | ratio = params['image_downsample'] 46 | shp = (int(shp[0] / ratio), int(shp[1] / ratio), shp[2]) 47 | 48 | params['input_shape'] = shp 49 | params['num_masks'] = chaos.ChaosLoader().num_masks 50 | 51 | d_mask_params['input_shape'] = (shp[:-1]) + (chaos.ChaosLoader().num_masks,) 52 | d_image_params['input_shape'] = shp 53 | 54 | anatomy_encoder_params['input_shape'] = shp 55 | anatomy_encoder_params['output_shape'] = (shp[:-1]) + (anatomy_encoder_params['out_channels'],) 56 | 57 | params.update({'anatomy_encoder': anatomy_encoder_params, 'd_mask_params': d_mask_params, 58 | 'd_image_params': d_image_params}) 59 | return params 60 | 61 | 62 | -------------------------------------------------------------------------------- /configuration/mmsdnet_config_chaos.py: -------------------------------------------------------------------------------- 1 | from loaders import chaos 2 | 3 | params = { 4 | 'seed': 10, 5 | 'folder': 'mmsdnet_chaos', 6 | 'epochs': 500, 7 | 'batch_size': 6, 8 | 'split': 0, 9 | 'dataset_name': 'chaos', 10 | 'test_dataset': 'chaos', 11 | 'input_shape': chaos.ChaosLoader().input_shape, 12 | 'image_downsample': 1, # downsample image size: used for testing 13 | 'modality': ['t1', 't2'], # list of [source, target] modalities 14 | 'model': 'mmsdnet.MMSDNet', # model to load 15 | 'executor': 'mmsdnet_executor.MMSDNetExecutor', # model trainer 16 | 'l_mix': 1, # amount of supervision for target modality 17 | 'decoder_type': 'film', # decoder type - can be film or spade 18 | 'num_z': 8, # dimensions of the modality factor 19 | 'w_sup_M': 10, 20 | 'w_adv_M': 1, 21 | 'w_rec_X': 10, 22 | 'w_adv_X': 1, 23 | 'w_rec_Z': 1, 24 | 'w_kl': 0.1, 25 | 'lr': 0.0001, 26 | } 27 | 28 | # discriminator configs 29 | d_mask_params = {'filters': 4, 'lr': 0.0001, 'name': 'D_Mask'} 30 | 31 | anatomy_encoder_params = { 32 | 'normalise' : 'batch', # 
normalisation layer - can be batch or instance 33 | 'downsample' : 4, # number of downsample layers of UNet encoder 34 | 'filters' : 64, # number of filters in the first convolutional layer 35 | 'out_channels': 8, # number of output channels - dimensions of the anatomy factor 36 | 'rounding' : True 37 | } 38 | 39 | 40 | def get(): 41 | shp = params['input_shape'] 42 | ratio = params['image_downsample'] 43 | shp = (int(shp[0] / ratio), int(shp[1] / ratio), shp[2]) 44 | 45 | params['input_shape'] = shp 46 | params['num_masks'] = chaos.ChaosLoader().num_masks 47 | 48 | d_mask_params['input_shape'] = (shp[:-1]) + (chaos.ChaosLoader().num_masks,) 49 | 50 | anatomy_encoder_params['input_shape'] = shp 51 | anatomy_encoder_params['output_shape'] = (shp[:-1]) + (anatomy_encoder_params['out_channels'],) 52 | params.update({'anatomy_encoder': anatomy_encoder_params, 'd_mask_params': d_mask_params}) 53 | return params 54 | -------------------------------------------------------------------------------- /costs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from keras import backend as K 5 | import tensorflow as tf 6 | from scipy.spatial.distance import pdist, squareform 7 | 8 | log = logging.getLogger() 9 | 10 | lambda_bce = 0.01 11 | 12 | ########## RECONSTRUCTION LOSSES ########## 13 | 14 | def make_similarity_weighted_mae(weights): 15 | def similarity_weighted_mae(y_true, y_pred): 16 | shape = K.int_shape(y_pred) 17 | w_reshaped = tf.expand_dims(tf.expand_dims(weights, axis=1), axis=1) 18 | mae = tf.multiply(K.abs(y_true - y_pred), tf.tile(w_reshaped, (1, shape[1], shape[2], 1))) 19 | return K.mean(mae) 20 | 21 | return similarity_weighted_mae 22 | 23 | 24 | def mae_single_input(y): 25 | y1, y2 = y 26 | return K.mean(K.abs(y1-y2), axis=(1, 2)) 27 | 28 | 29 | ########## SEGMENTATION LOSSES ########## 30 | 31 | def dice(y_true, y_pred, binarise=False, smooth=1e-12): 32 | y_pred = y_pred[..., 0:y_true.shape[-1]] 33 | 34 | # Cast the prediction to binary 0 or 1 35 | if binarise: 36 | y_pred = np.round(y_pred) 37 | 38 | # Symbolically compute the intersection 39 | y_int = y_true * y_pred 40 | return np.mean((2 * np.sum(y_int, axis=(1, 2, 3)) + smooth) 41 | / (np.sum(y_true, axis=(1, 2, 3)) + np.sum(y_pred, axis=(1, 2, 3)) + smooth)) 42 | 43 | def dice_coef_perbatch(y_true, y_pred): 44 | # Symbolically compute the intersection 45 | intersection = K.sum(y_true * y_pred, axis=(1, 2, 3)) 46 | union = K.sum(y_true, axis=(1, 2, 3)) + K.sum(y_pred, axis=(1, 2, 3)) 47 | dice = (2 * intersection + 1e-12) / (union + 1e-12) 48 | return 1 - dice 49 | 50 | def dice_coef_loss(y_true, y_pred): 51 | ''' 52 | DICE Loss. 53 | :param y_true: a tensor of ground truth data 54 | :param y_pred: a tensor of predicted data 55 | ''' 56 | return K.mean(dice_coef_perbatch(y_true, y_pred), axis=0) 57 | 58 | 59 | def make_dice_loss_fnc(restrict_chn=1): 60 | log.debug('Making DICE loss function for the first %d channels' % restrict_chn) 61 | 62 | def dice_fnc(y_true, y_pred): 63 | y_pred_new = y_pred[..., 0:restrict_chn] + 0. 64 | y_true_new = y_true[..., 0:restrict_chn] + 0. 65 | return dice_coef_loss(y_true_new, y_pred_new) 66 | 67 | return dice_fnc 68 | 69 | 70 | def weighted_cross_entropy_loss(y_pred, y_true): 71 | """ 72 | Define weighted cross - entropy function for classification tasks. 
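Each class c is weighted by n_total / n_c, computed from y_true, so that
under-represented classes contribute more to the loss.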
73 | :param y_pred: tensor[None, width, height, n_classes] 74 | :param y_true: tensor[None, width, height, n_classes] 75 | """ 76 | num_classes = K.int_shape(y_true)[-1] 77 | n = [tf.reduce_sum(tf.cast(y_true[..., c], tf.float32)) for c in range(num_classes)] 78 | n_tot = tf.reduce_sum(n) 79 | weights = [n_tot / (n[c] + 1e-12) for c in range(num_classes)] 80 | y_pred = tf.reshape(y_pred, (-1, num_classes)) 81 | y_true = tf.to_float(tf.reshape(y_true, (-1, num_classes))) 82 | w_cross_entropy = tf.multiply(y_true * tf.log(y_pred + 1e-12), weights) 83 | w_cross_entropy = -tf.reduce_sum(w_cross_entropy, reduction_indices=[1]) 84 | loss = tf.reduce_mean(w_cross_entropy, name='weighted_cross_entropy') 85 | return loss 86 | 87 | 88 | def weighted_cross_entropy_perbatch(y_pred, y_true): 89 | """ 90 | Define weighted cross - entropy function for classification tasks. 91 | :param y_pred: tensor[None, width, height, n_classes] 92 | :param y_true: tensor[None, width, height, n_classes] 93 | """ 94 | shape = K.int_shape(y_true) 95 | restrict_chn = shape[-1] 96 | 97 | n = tf.reduce_sum(y_true, axis=[0, 1, 2]) 98 | n_tot = tf.reduce_sum(n, axis=0) 99 | weights = n_tot / (n + 1e-12) 100 | 101 | y_pred = tf.reshape(y_pred, (-1, shape[1] * shape[2], restrict_chn)) 102 | y_true2 = tf.to_float(tf.reshape(y_true, (-1, shape[1] * shape[2], restrict_chn))) 103 | softmax = tf.nn.softmax(y_pred) 104 | 105 | w_cross_entropy = -tf.reduce_sum(y_true2 * tf.log(softmax + 1e-12) * weights, reduction_indices=[2]) 106 | # w_cross_entropy = tf.multiply(w_cross_entropy, tf.tile(tf.expand_dims(contributions, axis=-1), (1, shape[1] * shape[2]))) 107 | loss = tf.reduce_mean(w_cross_entropy, axis=1, name='softmax_weighted_cross_entropy') 108 | return loss 109 | 110 | 111 | def similarity_weighted_dice(weights, restrict_chn): 112 | log.debug('Making similarity weighted DICE loss function for the first %d channels' % restrict_chn) 113 | 114 | def weighted_dice_fnc(y_true): 115 | y_pred_new, y_true_new = y_true 116 | # assert K.int_shape(y_pred)[-1] == K.int_shape(y_true)[-1] + 1, 'y_pred does not contain similarity weights' 117 | 118 | y_pred_new = y_pred_new[..., 0:restrict_chn] + 0. 119 | y_true_new = y_true_new[..., 0:restrict_chn] + 0. 120 | 121 | intersection = K.sum(y_true_new * y_pred_new, axis=(1, 2, 3)) 122 | union = K.sum(y_true_new, axis=(1, 2, 3)) + K.sum(y_pred_new, axis=(1, 2, 3)) 123 | dice = (2 * intersection + 1e-5) / (union + 1e-5) 124 | return K.mean(weights * (1 - dice)) 125 | 126 | return weighted_dice_fnc 127 | 128 | 129 | def make_combined_dice_bce(num_classes): 130 | dice = make_dice_loss_fnc(num_classes) 131 | bce = weighted_cross_entropy_loss 132 | 133 | def combined_dice_bce(y_true, y_pred): 134 | return dice(y_true, y_pred) + lambda_bce * bce(y_true, y_pred) 135 | 136 | return combined_dice_bce 137 | 138 | def make_combined_dice_bce_perbatch(num_classes): 139 | def fnc(y_true, y_pred): 140 | y_pred_new = y_pred[..., 0:num_classes] + 0. 141 | y_true_new = y_true[..., 0:num_classes] + 0. 142 | return dice_coef_perbatch(y_true_new, y_pred_new) + lambda_bce * weighted_cross_entropy_perbatch(y_true, y_pred) 143 | return fnc 144 | 145 | def similarity_weighted_dice_bce(contributions, restrict_chn, eps=1e-5): 146 | log.debug('Making similarity weighted DICE loss function for the first %d channels' % restrict_chn) 147 | 148 | def weighted_dice_fnc(y_true, y_pred): 149 | y_pred_new = y_pred[..., 0:restrict_chn] + 0. 150 | y_true_new = y_true[..., 0:restrict_chn] + 0. 
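        # Soft Dice per sample over the first restrict_chn channels; each sample's
        # (1 - Dice) term is scaled by its similarity weight in `contributions`
        # before averaging over the batch.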
151 | 152 | intersection = K.sum(y_true_new * y_pred_new, axis=(1, 2, 3)) 153 | union = K.sum(y_true_new, axis=(1, 2, 3)) + K.sum(y_pred_new, axis=(1, 2, 3)) 154 | dice = (2 * intersection + eps) / (union + eps) 155 | return K.mean(contributions * (1 - dice)) 156 | 157 | def weighted_cross_entropy(y_pred, y_true): 158 | """ 159 | Define weighted cross - entropy function for classification tasks. 160 | :param y_pred: tensor[None, width, height, n_classes] 161 | :param y_true: tensor[None, width, height, n_classes] 162 | """ 163 | shape = K.int_shape(y_true) 164 | num_chn = shape[-1] 165 | 166 | n = tf.reduce_sum(y_true, axis=[0, 1, 2]) 167 | n_tot = tf.reduce_sum(n, axis=0) 168 | weights = n_tot / (n + eps) 169 | 170 | y_pred = tf.reshape(y_pred, (-1, shape[1] * shape[2], num_chn)) 171 | y_true2 = tf.to_float(tf.reshape(y_true, (-1, shape[1] * shape[2], num_chn))) 172 | 173 | w_cross_entropy = -tf.reduce_sum(y_true2 * tf.log(y_pred + eps) * weights, reduction_indices=[2]) 174 | w_cross_entropy = tf.multiply(w_cross_entropy, tf.tile(contributions, (1, shape[1] * shape[2]))) 175 | loss = tf.reduce_mean(w_cross_entropy, name='weighted_cross_entropy') 176 | 177 | return loss 178 | 179 | def combined_fnc(y_true, y_pred): 180 | return weighted_dice_fnc(y_true, y_pred) + lambda_bce * weighted_cross_entropy(y_true, y_pred) 181 | 182 | return combined_fnc 183 | 184 | ########## VAE LOSSES ########## 185 | 186 | def kl(args): 187 | mean, log_var = args 188 | kl_loss = -0.5 * K.sum(1 + log_var - K.square(mean) - K.exp(log_var), axis=-1) 189 | return K.reshape(kl_loss, (-1, 1)) 190 | 191 | 192 | ########## OTHER LOSSES ########## 193 | 194 | def ypred(y_true, y_pred): 195 | return y_pred 196 | 197 | 198 | def distance_correlation(A, B): 199 | ''' 200 | Calculate the Distance Correlation between the two vectors. https://en.wikipedia.org/wiki/Distance_correlation 201 | Value of 0 implies independence. A and B can be vectors of different length. 
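It is computed as dCor = dCov(A, B) / sqrt(dVar(A) * dVar(B)), where the distance
covariances are estimated from doubly-centred pairwise distance matrices.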
202 | :param A: vector A of shape (num_samples, sizeA) 203 | :param B: vector B of shape (num_samples, sizeB) 204 | :return: the distance correlation between A and B 205 | ''' 206 | n = A.shape[0] 207 | if B.shape[0] != A.shape[0]: 208 | raise ValueError('Number of samples must match') 209 | a = squareform(pdist(A)) 210 | b = squareform(pdist(B)) 211 | A = a - a.mean(axis=0)[None, :] - a.mean(axis=1)[:, None] + a.mean() 212 | B = b - b.mean(axis=0)[None, :] - b.mean(axis=1)[:, None] + b.mean() 213 | 214 | dcov2_xy = (A * B).sum() / float(n * n) 215 | dcov2_xx = (A * A).sum() / float(n * n) 216 | dcov2_yy = (B * B).sum() / float(n * n) 217 | dcor = np.sqrt(dcov2_xy) / np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy)) 218 | return dcor 219 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import importlib 4 | import json 5 | import logging 6 | import os 7 | 8 | import git 9 | import matplotlib 10 | import numpy 11 | import comet_ml 12 | matplotlib.use('Agg') # environment for non-interactive environments 13 | 14 | from easydict import EasyDict 15 | 16 | 17 | class Experiment(object): 18 | def __init__(self): 19 | self.log = None 20 | 21 | def init_logging(self, config): 22 | if not os.path.exists(config.folder): 23 | os.makedirs(config.folder) 24 | logging.basicConfig(filename=config.folder + '/logfile.log', level=logging.DEBUG, format='%(asctime)s %(message)s') 25 | logging.getLogger().addHandler(logging.StreamHandler()) 26 | 27 | self.log = logging.getLogger() 28 | self.log.debug(config.items()) 29 | self.log.info('---- Setting up experiment at ' + config.folder + '----') 30 | 31 | def get_config(self, split, args): 32 | """ 33 | Read a config file and convert it into an object for easy processing. 34 | :param split: the cross-validation split id 35 | :param args: the command arguments 36 | :return: config object(namespace) 37 | """ 38 | config_script = args.config 39 | 40 | config_dict = importlib.import_module('configuration.' 
+ config_script).get() 41 | config = EasyDict(config_dict) 42 | config.split = split 43 | 44 | if (hasattr(config, 'randomise') and config.randomise) or (hasattr(args, 'randomise') and args.randomise): 45 | config.randomise = True 46 | config.folder += '_randomise' 47 | 48 | config.n_pairs = 1 49 | if (hasattr(config, 'automatedpairing') and config.automatedpairing) or \ 50 | (hasattr(args, 'automatedpairing') and args.automatedpairing): 51 | config.automatedpairing = True 52 | config.folder += '_automatedpairing' 53 | config.n_pairs = 3 54 | 55 | l_mix = config.l_mix 56 | if hasattr(args, 'l_mix'): 57 | config.l_mix = float(args.l_mix) 58 | l_mix = args.l_mix 59 | config.folder += '_l%s' % l_mix 60 | 61 | config.folder += '_' + str(config.modality) 62 | config.folder += '_split%s' % split 63 | config.folder = config.folder.replace('.', '') 64 | 65 | if args.test_dataset: 66 | print('Overriding default test dataset') 67 | config.test_dataset = args.test_dataset 68 | 69 | config.githash = git.Repo(search_parent_directories=True).head.object.hexsha 70 | 71 | self.save_config(config) 72 | return config 73 | 74 | def save_config(self, config): 75 | if not os.path.exists(config.folder): 76 | os.makedirs(config.folder) 77 | with open(config.folder + '/experiment_configuration.json', 'w') as outfile: 78 | json.dump(dict(config.items()), outfile) 79 | 80 | def run(self): 81 | args = Experiment.read_console_parameters() 82 | configuration = self.get_config(int(args.split), args) 83 | self.init_logging(configuration) 84 | self.run_experiment(configuration, args.test) 85 | 86 | def run_experiment(self, configuration, test): 87 | executor = self.get_executor(configuration, test) 88 | 89 | if test: 90 | executor.test() 91 | else: 92 | executor.train() 93 | def default(o): 94 | if isinstance(o, numpy.int64): return int(o) 95 | raise TypeError 96 | with open(configuration.folder + '/experiment_configuration.json', 'w') as outfile: 97 | json.dump(vars(configuration), outfile, default=default) 98 | executor.test() 99 | 100 | @staticmethod 101 | def read_console_parameters(): 102 | parser = argparse.ArgumentParser(description='') 103 | parser.add_argument('--config', default='', help='The experiment configuration file', required=True) 104 | parser.add_argument('--test', help='Evaluate the model on test data', type=bool) 105 | parser.add_argument('--test_dataset', help='Override default test dataset', choices=['chaos']) 106 | parser.add_argument('--split', help='Data split to run.', required=True) 107 | parser.add_argument('--l_mix', help='Percentage of labelled data') 108 | parser.add_argument('--automatedpairing', help='Use weighted cost for training', type=bool) 109 | parser.add_argument('--randomise', help='Randomise multimodal pairs', type=bool) 110 | 111 | return parser.parse_args() 112 | 113 | def get_executor(self, config, test): 114 | # Initialise model 115 | module_name = config.model.split('.')[0] 116 | model_name = config.model.split('.')[1] 117 | model = getattr(importlib.import_module('models.' + module_name), model_name)(config) 118 | model.build() 119 | 120 | # Initialise executor 121 | module_name = config.executor.split('.')[0] 122 | model_name = config.executor.split('.')[1] 123 | executor = getattr(importlib.import_module('model_executors.' 
+ module_name), model_name)(config, model) 124 | return executor 125 | 126 | 127 | if __name__ == '__main__': 128 | exp = Experiment() 129 | exp.run() 130 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/layers/__init__.py -------------------------------------------------------------------------------- /layers/film.py: -------------------------------------------------------------------------------- 1 | 2 | import keras.backend as K 3 | from keras.engine import Layer 4 | 5 | 6 | class FiLM(Layer): 7 | ''' 8 | The FiLM Normalization defined in ? 9 | 10 | used like this: 11 | 12 | h = FiLM()([h, gamma, beta]) 13 | 14 | where: 15 | h is the multi channel image with shape (?, H, W, C) 16 | gamma has shape (?, C), and scales the channels (values in range -inf,inf) 17 | beta has shape (?, C), and offsets the channels (values in range -inf,inf) 18 | ''' 19 | 20 | def __init__(self, **kwargs): 21 | super(FiLM, self).__init__(**kwargs) 22 | 23 | def build(self, input_shape): 24 | super(FiLM, self).build(input_shape) 25 | 26 | def call(self, x, **kwargs): 27 | x, gamma, beta = x 28 | 29 | # print('FILM: ', K.int_shape(x), K.int_shape(gamma)) 30 | 31 | gamma = K.tile(K.reshape(gamma, (K.shape(gamma)[0], 1, 1, K.shape(gamma)[-1])), 32 | (1, K.shape(x)[1], K.shape(x)[2], 1)) 33 | beta = K.tile(K.reshape(beta, (K.shape(beta)[0], 1, 1, K.shape(beta)[-1])), 34 | (1, K.shape(x)[1], K.shape(x)[2], 1)) 35 | 36 | return x * gamma + beta 37 | 38 | def compute_output_shape(self, input_shape): 39 | return input_shape -------------------------------------------------------------------------------- /layers/interpolate_spline.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Polyharmonic spline interpolation.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.python.framework import ops 22 | from tensorflow.python.framework import tensor_shape 23 | from tensorflow.python.ops import array_ops 24 | from tensorflow.python.ops import linalg_ops 25 | from tensorflow.python.ops import math_ops 26 | 27 | EPSILON = 0.0000000001 28 | 29 | 30 | def _cross_squared_distance_matrix(x, y): 31 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim). 
32 | Computes the pairwise distances between rows of x and rows of y 33 | Args: 34 | x: [batch_size, n, d] float `Tensor` 35 | y: [batch_size, m, d] float `Tensor` 36 | Returns: 37 | squared_dists: [batch_size, n, m] float `Tensor`, where 38 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 39 | """ 40 | x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2) 41 | y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2) 42 | 43 | # Expand so that we can broadcast. 44 | x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) 45 | y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1) 46 | 47 | x_y_transpose = math_ops.matmul(x, y, adjoint_b=True) 48 | 49 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj 50 | squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile 51 | 52 | return squared_dists 53 | 54 | 55 | def _pairwise_squared_distance_matrix(x): 56 | """Pairwise squared distance among a (batch) matrix's rows (2nd dim). 57 | This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x) 58 | Args: 59 | x: `[batch_size, n, d]` float `Tensor` 60 | Returns: 61 | squared_dists: `[batch_size, n, n]` float `Tensor`, where 62 | squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 63 | """ 64 | 65 | x_x_transpose = math_ops.matmul(x, x, adjoint_b=True) 66 | x_norm_squared = array_ops.matrix_diag_part(x_x_transpose) 67 | x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) 68 | 69 | # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj 70 | squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose( 71 | x_norm_squared_tile, [0, 2, 1]) 72 | 73 | return squared_dists 74 | 75 | 76 | def _solve_interpolation(train_points, train_values, order, 77 | regularization_weight): 78 | """Solve for interpolation coefficients. 79 | Computes the coefficients of the polyharmonic interpolant for the 'training' 80 | data defined by (train_points, train_values) using the kernel phi. 81 | Args: 82 | train_points: `[b, n, d]` interpolation centers 83 | train_values: `[b, n, k]` function values 84 | order: order of the interpolation 85 | regularization_weight: weight to place on smoothness regularization term 86 | Returns: 87 | w: `[b, n, k]` weights on each interpolation center 88 | v: `[b, d, k]` weights on each input dimension 89 | Raises: 90 | ValueError: if d or k is not fully specified. 91 | """ 92 | 93 | # These dimensions are set dynamically at runtime. 94 | b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3) 95 | 96 | d = train_points.shape[-1] 97 | # if tensor_shape.dimension_value(d) is None: 98 | # raise ValueError('The dimensionality of the input points (d) must be ' 99 | # 'statically-inferrable.') 100 | 101 | k = train_values.shape[-1] 102 | # if tensor_shape.dimension_value(k) is None: 103 | # raise ValueError('The dimensionality of the output values (k) must be ' 104 | # 'statically-inferrable.') 105 | 106 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.) 107 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. 108 | # To account for python style guidelines we use 109 | # matrix_a for A and matrix_b for B. 110 | 111 | c = train_points 112 | f = train_values 113 | 114 | # Next, construct the linear system. 
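  # The linear system assembled below has the standard polyharmonic-spline block form
  #     | A   B | | w |   | f |
  #     | B^T 0 | | v | = | 0 |
  # where A contains the kernel values phi applied to pairwise squared distances of the
  # centers, B the train points padded with a ones column for the bias, w the RBF weights
  # and v the linear-term weights.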
115 | with ops.name_scope('construct_linear_system'): 116 | 117 | matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] 118 | if regularization_weight > 0: 119 | batch_identity_matrix = array_ops.expand_dims( 120 | linalg_ops.eye(n, dtype=c.dtype), 0) 121 | matrix_a += regularization_weight * batch_identity_matrix 122 | 123 | # Append ones to the feature values for the bias term in the linear model. 124 | ones = array_ops.ones_like(c[..., :1], dtype=c.dtype) 125 | matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] 126 | 127 | # [b, n + d + 1, n] 128 | left_block = array_ops.concat( 129 | [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1) 130 | 131 | num_b_cols = matrix_b.get_shape()[2] # d + 1 132 | lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype) 133 | right_block = array_ops.concat([matrix_b, lhs_zeros], 134 | 1) # [b, n + d + 1, d + 1] 135 | lhs = array_ops.concat([left_block, right_block], 136 | 2) # [b, n + d + 1, n + d + 1] 137 | 138 | rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype) 139 | rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] 140 | 141 | # Then, solve the linear system and unpack the results. 142 | with ops.name_scope('solve_linear_system'): 143 | w_v = linalg_ops.matrix_solve(lhs, rhs) 144 | w = w_v[:, :n, :] 145 | v = w_v[:, n:, :] 146 | 147 | return w, v 148 | 149 | 150 | def _apply_interpolation(query_points, train_points, w, v, order): 151 | """Apply polyharmonic interpolation model to data. 152 | Given coefficients w and v for the interpolation model, we evaluate 153 | interpolated function values at query_points. 154 | Args: 155 | query_points: `[b, m, d]` x values to evaluate the interpolation at 156 | train_points: `[b, n, d]` x values that act as the interpolation centers 157 | ( the c variables in the wikipedia article) 158 | w: `[b, n, k]` weights on each interpolation center 159 | v: `[b, d, k]` weights on each input dimension 160 | order: order of the interpolation 161 | Returns: 162 | Polyharmonic interpolation evaluated at points defined in query_points. 163 | """ 164 | 165 | # First, compute the contribution from the rbf term. 166 | pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) 167 | phi_pairwise_dists = _phi(pairwise_dists, order) 168 | 169 | rbf_term = math_ops.matmul(phi_pairwise_dists, w) 170 | 171 | # Then, compute the contribution from the linear term. 172 | # Pad query_points with ones, for the bias term in the linear model. 173 | query_points_pad = array_ops.concat([ 174 | query_points, 175 | array_ops.ones_like(query_points[..., :1], train_points.dtype) 176 | ], 2) 177 | linear_term = math_ops.matmul(query_points_pad, v) 178 | 179 | return rbf_term + linear_term 180 | 181 | 182 | def _phi(r, order): 183 | """Coordinate-wise nonlinearity used to define the order of the interpolation. 184 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. 185 | Args: 186 | r: input op 187 | order: interpolation order 188 | Returns: 189 | phi_k evaluated coordinate-wise on r, for k = r 190 | """ 191 | 192 | # using EPSILON prevents log(0), sqrt0), etc. 
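  # (Note: r holds *squared* distances, so with d = sqrt(r) the branches below compute
  # phi_k(d) = d^k for odd k, implemented as r^(k/2), and phi_k(d) = d^k * log(d) for
  # even k, implemented as 0.5 * r^(k/2) * log(r).)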
193 | # sqrt(0) is well-defined, but its gradient is not 194 | with ops.name_scope('phi'): 195 | if order == 1: 196 | r = math_ops.maximum(r, EPSILON) 197 | r = math_ops.sqrt(r) 198 | return r 199 | elif order == 2: 200 | return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON)) 201 | elif order == 4: 202 | return 0.5 * math_ops.square(r) * math_ops.log( 203 | math_ops.maximum(r, EPSILON)) 204 | elif order % 2 == 0: 205 | r = math_ops.maximum(r, EPSILON) 206 | return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r) 207 | else: 208 | r = math_ops.maximum(r, EPSILON) 209 | return math_ops.pow(r, 0.5 * order) 210 | 211 | 212 | def interpolate_spline(train_points, 213 | train_values, 214 | query_points, 215 | order, 216 | regularization_weight=0.0, 217 | name='interpolate_spline'): 218 | r"""Interpolate signal using polyharmonic interpolation. 219 | The interpolant has the form 220 | $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ 221 | This is a sum of two terms: (1) a weighted sum of radial basis function (RBF) 222 | terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias. 223 | The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v 224 | by appending 1 as a final dimension to x. The coefficients w and v are 225 | estimated such that the interpolant exactly fits the value of the function at 226 | the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the 227 | vector w sums to 0. With these constraints, the coefficients can be obtained 228 | by solving a linear system. 229 | \\(\phi\\) is an RBF, parametrized by an interpolation 230 | order. Using order=2 produces the well-known thin-plate spline. 231 | We also provide the option to perform regularized interpolation. Here, the 232 | interpolant is selected to trade off between the squared loss on the training 233 | data and a certain measure of its curvature 234 | ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). 235 | Using a regularization weight greater than zero has the effect that the 236 | interpolant will no longer exactly fit the training data. However, it may be 237 | less vulnerable to overfitting, particularly for high-order interpolation. 238 | Note the interpolation procedure is differentiable with respect to all inputs 239 | besides the order parameter. 240 | We support dynamically-shaped inputs, where batch_size, n, and m are None 241 | at graph construction time. However, d and k must be known. 242 | Args: 243 | train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional 244 | locations. These do not need to be regularly-spaced. 245 | train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional values 246 | evaluated at train_points. 247 | query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations 248 | where we will output the interpolant's values. 249 | order: order of the interpolation. Common values are 1 for 250 | \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline), 251 | or 3 for \\(\phi(r) = r^3\\). 252 | regularization_weight: weight placed on the regularization term. 253 | This will depend substantially on the problem, and it should always be 254 | tuned. For many problems, it is reasonable to use no regularization. 255 | If using a non-zero value, we recommend a small value like 0.001. 256 | name: name prefix for ops created by this function 257 | Returns: 258 | `[b, m, k]` float `Tensor` of query values. 
We use train_points and 259 | train_values to perform polyharmonic interpolation. The query values are 260 | the values of the interpolant evaluated at the locations specified in 261 | query_points. 262 | """ 263 | with ops.name_scope(name): 264 | train_points = ops.convert_to_tensor(train_points) 265 | train_values = ops.convert_to_tensor(train_values) 266 | query_points = ops.convert_to_tensor(query_points) 267 | 268 | # First, fit the spline to the observed data. 269 | with ops.name_scope('solve'): 270 | w, v = _solve_interpolation(train_points, train_values, order, 271 | regularization_weight) 272 | 273 | # Then, evaluate the spline at the query locations. 274 | with ops.name_scope('predict'): 275 | query_values = _apply_interpolation(query_points, train_points, w, v, 276 | order) 277 | 278 | return query_values 279 | -------------------------------------------------------------------------------- /layers/rounding.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import tensorflow as tf 4 | from keras.engine.topology import Layer 5 | from tensorflow.python.framework import ops 6 | 7 | 8 | class Rounding(Layer): 9 | def __init__(self, **kwargs): 10 | super(Rounding, self).__init__(**kwargs) 11 | 12 | def build(self, input_shape): 13 | super(Rounding, self).build(input_shape) 14 | 15 | def call(self, x, **kwargs): 16 | return roundWithGrad(x) 17 | 18 | def compute_output_shape(self, input_shape): 19 | return input_shape 20 | 21 | 22 | # Define custom py_func which takes also a grad op as argument: 23 | def py_func(func, inp, Tout, stateful=True, name=None, grad=None): 24 | rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8)) # generate a unique name to avoid duplicates 25 | tf.RegisterGradient(rnd_name)(grad) 26 | g = tf.get_default_graph() 27 | with g.gradient_override_map({"PyFunc": rnd_name}): 28 | res = tf.py_func(func, inp, Tout, stateful=stateful, name=name) 29 | res[0].set_shape(inp[0].get_shape()) 30 | return res 31 | 32 | 33 | def roundWithGrad(x, name=None): 34 | with ops.name_scope(name, "roundWithGrad", [x]) as name: 35 | round_x = py_func(lambda x: np.round(x).astype('float32'), [x], [tf.float32], name=name, 36 | grad=_roundWithGrad_grad) # <-- here's the call to the gradient 37 | return round_x[0] 38 | 39 | 40 | def _roundWithGrad_grad(op, grad): 41 | x = op.inputs[0] 42 | return grad * 1 # do whatever with gradient here (e.g. 
could return grad * 2 * x if op was f(x)=x**2). Passing grad through unchanged makes Rounding a straight-through estimator: round in the forward pass, identity in the backward pass.
43 | 
-------------------------------------------------------------------------------- /layers/spade.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from keras.engine import Layer
3 | from keras.layers import LeakyReLU, Conv2D, Add, Lambda
4 | from keras_contrib.layers import InstanceNormalization
5 | 
6 | 
7 | def spade_block(conf, anatomy_input, layer, fin, fout):
8 | learn_shortcut = (fin != fout)
9 | fmiddle = min(fin, fout)
10 | 
11 | l1 = _spade(conf, anatomy_input, layer, fin)
12 | l2 = LeakyReLU(0.2)(l1)
13 | l3 = Conv2D(fmiddle, 3, padding='same')(l2)
14 | 
15 | l4 = _spade(conf, anatomy_input, l3, fmiddle)
16 | l5 = LeakyReLU(0.2)(l4)
17 | l6 = Conv2D(fout, 3, padding='same')(l5)
18 | 
19 | if learn_shortcut:
20 | layer = _spade(conf, anatomy_input, layer, fin)
21 | layer = Conv2D(fout, 1, padding='same', use_bias=False)(layer)
22 | 
23 | return Add()([layer, l6])
24 | 
25 | 
26 | def _spade(conf, anatomy_input, layer, f):
27 | layer = InstanceNormalization(scale=False, center=False)(layer)
28 | anatomy = Lambda(resize_like, arguments={'ref_tensor': layer})(anatomy_input)
29 | anatomy = Conv2D(128, 3, padding='same', activation='relu')(anatomy)
30 | gamma = Conv2D(f, 3, padding='same')(anatomy)
31 | beta = Conv2D(f, 3, padding='same')(anatomy)
32 | return SPADE_COND()([layer, gamma, beta])
33 | # return Add()([Multiply()([layer, gamma]), beta])
34 | 
35 | 
36 | def resize_like(input_tensor, ref_tensor): # resize input_tensor to the spatial size of ref_tensor
37 | H, W = ref_tensor.get_shape()[1], ref_tensor.get_shape()[2]
38 | return tf.image.resize_nearest_neighbor(input_tensor, [H.value, W.value])
39 | 
40 | 
41 | class SPADE_COND(Layer):
42 | '''
43 | The SPADE conditioning layer: computes x * (1 + gamma) + beta, where gamma and beta
44 | are predicted from the anatomy input.
45 | '''
46 | 
47 | def __init__(self, **kwargs):
48 | super(SPADE_COND, self).__init__(**kwargs)
49 | 
50 | def build(self, input_shape):
51 | super(SPADE_COND, self).build(input_shape)
52 | 
53 | def call(self, x, **kwargs):
54 | x, gamma, beta = x
55 | 
56 | return x * (1 + gamma) + beta
57 | 
58 | def compute_output_shape(self, input_shape):
59 | return input_shape
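A minimal sketch of how `spade_block` might be wired up, with hypothetical feature and anatomy shapes (`conf` is not used inside `spade_block`, so `None` is passed here):

```python
from keras import Input, Model

features = Input(shape=(48, 48, 64))   # e.g. a decoder activation
anatomy = Input(shape=(192, 192, 8))   # anatomy factor, resized internally to 48x48

out = spade_block(None, anatomy, features, fin=64, fout=32)  # fin != fout -> learned shortcut
model = Model(inputs=[features, anatomy], outputs=out)
model.summary()
```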
-------------------------------------------------------------------------------- /layers/spectralnorm.py: --------------------------------------------------------------------------------
1 | '''
2 | This file contains a keras Regularizer that can be used to encourage the largest singular
3 | value (lsv) of either a Dense or a Conv2D layer to be <= 1.
4 | 
5 | For Dense layers the linear transformation they perform is a multiplication by the
6 | weight matrix W (and then the addition of a bias term, but this does not affect the gradient
7 | so we can ignore it). So, for a Dense layer the lsv is just: lsv(W)
8 | 
9 | For Conv2D layers things are more complicated. The weight matrix W does define the
10 | transformation, but it isn't defined simply as multiplication of the input by W; rather,
11 | the input is multiplied by a matrix that is some function of W (and the size of the
12 | input image, which we will call im_sz). Let's call the suitable function frmr, then for
13 | a Conv2D layer the lsv is: lsv(frmr(W,im_sz)).
14 | 
15 | (As an aside, in the spectral normalization paper they seem to directly use lsv(W) even
16 | in the Conv2D case. I think this is technically wrong, but has a sort of similar result,
17 | in that it does provide some pressure for lsv(frmr(W,im_sz)) to be low-ish, but making
18 | lsv(W) <= 1 doesn't necessarily imply lsv(frmr(W,im_sz)) <= 1.)
19 | 
20 | We need to define both lsv and frmr in a differentiable way so they can be used to
21 | train a network with backprop. The spectral normalization paper proposes a method to
22 | approximate lsv in a differentiable way (which we will come back to later), so we just
23 | need to work out a method for frmr.
24 | 
25 | Let's assume that the Conv2D layer has padding and a stride of 1, and also assume the
26 | input is a single channel image, and we only have 1 filter. This means that the layer
27 | maps (im_sz, im_sz, 1) --> (im_sz, im_sz, 1). Let's define n = im_sz * im_sz, then
28 | the layer defines a linear map from n-dimensional space to n-dimensional space. So, we
29 | can flatten the image to an n-dimensional vector, then multiply it by some n by n
30 | matrix, M, to get the n-dimensional vector output, which can then be reshaped into the
31 | output image. Thus M = frmr(W,im_sz) (and so lsv(M) = lsv(frmr(W,im_sz)) ).
32 | 
33 | Specifically, M is made by arranging (and duplicating) the values of W into an n by n
34 | matrix. We do this in the function my_im2col() below.
35 | '''
36 | 
37 | from keras import backend as K
38 | from keras.regularizers import Regularizer
39 | import numpy as np
40 | 
41 | 
42 | def my_im2col(img, W, pad=True, stride=1):
43 | # matrix will map from w*h*num_channels to (w/stride)*(h/stride)*num_filters
44 | 
45 | w, h, num_channels = img.shape
46 | filter_width, filter_height, _, num_filters = W.shape
47 | 
48 | if pad:
49 | w_pad, h_pad = 0, 0
50 | else:
51 | w_pad, h_pad = filter_width // 2, filter_height // 2
52 | 
53 | M = np.zeros((((w - w_pad * 2) // stride) * ((h - h_pad * 2) // stride) * num_filters, w * h * num_channels))
54 | 
55 | row = 0
56 | for filter in range(num_filters):
57 | for y_pos in range(h_pad, h - h_pad, stride):
58 | for x_pos in range(w_pad, w - w_pad, stride):
59 | for channel in range(num_channels):
60 | ind = x_pos + w * y_pos + w * h * channel
61 | for fx in range(-(filter_width // 2), (filter_width + 1) // 2):
62 | for fy in range(-(filter_height // 2), (filter_height + 1) // 2):
63 | if (0 <= x_pos + fx < w) and (0 <= y_pos + fy < h):
64 | M[row, ind + fx + fy * w] = W[
65 | (filter_width // 2) + fx, (filter_height // 2) + fy, channel, filter]
66 | row += 1
67 | 
68 | return M
69 | 
70 | 
71 | def conv2d(a, f):
72 | s = f.shape + tuple(np.subtract(a.shape, f.shape) + 1)
73 | strd = np.lib.stride_tricks.as_strided
74 | subM = strd(a, shape=s, strides=a.strides * 2)
75 | return np.einsum('ij,ijkl->kl', f, subM)
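The differentiable lsv approximation mentioned above is power iteration, which the `Spectral` regularizer below implements with Keras backend ops; a plain-NumPy sketch of the same idea, for illustration:

```python
import numpy as np

# Estimate the largest singular value of W by power iteration.
W = np.random.randn(64, 32)
u = np.random.randn(64, 1)
for _ in range(10):
    v = W.T @ u
    v /= np.linalg.norm(v)
    u = W @ v
    u /= np.linalg.norm(u)

sigma_est = float(u.T @ W @ v)
print(sigma_est, np.linalg.svd(W)[1][0])  # the two values should closely agree
```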
76 | 
77 | 
78 | def my_im2col_np(img, W, pad=True, stride=1):
79 | # Vectorised construction is not implemented; fall back to the loop-based my_im2col.
80 | return my_im2col(img, W, pad=pad, stride=stride)
81 | 
82 | 
83 | def largestSingularValues(l, imsz=3):
84 | '''
85 | This function takes a keras model as an input and returns a list of the
86 | largest singular values of each of its weight matrices.
87 | 
88 | Useful for sanity checking.
89 | '''
90 | 
91 | layer_type = str(type(l)).split('.')[-1][:-2]
92 | 
93 | if layer_type == 'Model':
94 | SVs = []
95 | for sub_l in l.layers:
96 | if len(sub_l.get_weights()):
97 | SVs = SVs + largestSingularValues(sub_l, imsz)
98 | return SVs
99 | 
100 | elif layer_type == 'Dense':
101 | W = l.get_weights()[0]
102 | _, s, _ = np.linalg.svd(W)
103 | return [s[0]]
104 | 
105 | elif layer_type == 'Conv2D':
106 | # note, this is only an approximation. I think it's a lower bound?
107 | W = l.get_weights()[0]
108 | img = np.zeros((imsz, imsz, W.shape[2]))
109 | M = my_im2col(img, W)
110 | _, s, _ = np.linalg.svd(M)
111 | return [s[0]]
112 | 
113 | else:
114 | return []
115 | 
116 | 
117 | def largestSingularValues_old(l):
118 | '''
119 | This function takes a keras model as an input and returns a list of the
120 | largest singular values of each of its weight matrices.
121 | 
122 | Useful for sanity checking.
123 | '''
124 | 
125 | layer_type = str(type(l)).split('.')[-1][:-2]
126 | 
127 | if layer_type == 'Model':
128 | SVs = []
129 | for sub_l in l.layers:
130 | if len(sub_l.get_weights()):
131 | SVs = SVs + largestSingularValues_old(sub_l)
132 | return SVs
133 | 
134 | elif layer_type == 'Dense' or layer_type == 'Conv2D':
135 | W = l.get_weights()[0]
136 | W = np.reshape(W, (-1, W.shape[-1]))
137 | _, s, _ = np.linalg.svd(W)
138 | return [s[0]]
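For example, `largestSingularValues` can be run on a small functional model to monitor the per-layer values (illustrative only; the printed numbers depend on the random initialisation):

```python
from keras import Input, Model
from keras.layers import Conv2D, Dense, Flatten

inp = Input(shape=(8, 8, 1))
h = Conv2D(4, 3, padding='same')(inp)
out = Dense(2)(Flatten()(h))
model = Model(inp, out)

print(largestSingularValues(model))  # one value per Dense/Conv2D layer
```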
139 | 
140 | 
141 | class Spectral(Regularizer):
142 | ''' Spectral normalization regularizer
143 | # Arguments
144 | alpha = weight for regularization penalty
145 | '''
146 | 
147 | def __init__(self, dim, alpha=K.variable(10.)):
148 | '''
149 | in a Conv2D layer dim needs to be num_channels in the previous layer times the filter_size^2
150 | in a Dense layer dim needs to be num_channels in the previous layer
151 | '''
152 | 
153 | self.dim = dim
154 | self.alpha = alpha
155 | self.u = K.variable(np.random.random((dim, 1)) * 2 - 1.)
156 | 
157 | def __call__(self, x):
158 | x_shape = K.shape(x)
159 | x = K.reshape(x, (-1, x_shape[-1])) # flattens convolution kernels to a 2-D matrix
160 | 
161 | # Approximate the largest singular value of x with a few steps of power iteration.
162 | for _ in range(3):
163 | WTu = K.dot(K.transpose(x), self.u)
164 | v = WTu / K.sqrt(K.sum(K.square(WTu)))
165 | 
166 | Wv = K.dot(x, v)
167 | self.u = Wv / K.sqrt(K.sum(K.square(Wv)))
168 | 
169 | spectral_norm = K.dot(K.dot(K.transpose(self.u), x), v)
170 | 
171 | # Penalise the distance of x from its spectrally-normalised version.
172 | target_x = K.stop_gradient(x / spectral_norm)
173 | return self.alpha * K.mean(K.abs(target_x - x))
174 | 
175 | # Alternative penalties, kept for reference:
176 | # return self.alpha * K.switch(K.greater(spectral_norm, 1), spectral_norm, 0*spectral_norm)
177 | # return self.alpha * K.abs(1 - spectral_norm) # + 0.3 * K.sum(K.abs(x))
178 | 
179 | def get_config(self):
180 | return {'alpha': float(K.get_value(self.alpha))}
181 | 
182 | 
183 | if __name__ == '__main__':
184 | 
185 | from matplotlib import pyplot as plt
186 | 
187 | # Empirically check how the largest singular value of the induced matrix M
188 | # varies with the input image size, for random 3x3 single-channel filters.
189 | im_sz = 3
190 | channels = 1
191 | filter_sz = 3
192 | 
193 | W = np.random.randn(filter_sz, filter_sz, channels, 1)
194 | 
195 | img = np.random.randn(1, im_sz, im_sz, channels)
196 | M = my_im2col(img[0], W)
197 | base = np.linalg.svd(M)[1][0]
198 | 
199 | img3 = np.random.randn(1, 3, 3, channels)
200 | img7 = np.random.randn(1, 7, 7, channels)
201 | 
202 | X, Y = [], []
203 | for i in range(100):
204 | W = np.random.randn(filter_sz, filter_sz, channels, 1)
205 | 
206 | M = my_im2col(img3[0], W)
207 | lsv_3 = np.linalg.svd(M)[1][0]
208 | 
209 | M = my_im2col(img7[0], W)
210 | lsv_7 = np.linalg.svd(M)[1][0]
211 | 
212 | X.append(lsv_3)
213 | Y.append(lsv_7 / lsv_3)
214 | 
215 | plt.scatter(X, Y)
216 | plt.show()
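A sketch of attaching `Spectral` as a kernel regularizer (hypothetical layer sizes; per the `__init__` docstring, `dim` is the number of input channels for a Dense layer and input_channels * filter_size^2 for a Conv2D layer):

```python
from keras import Input, Model
from keras import backend as K
from keras.layers import Conv2D, Dense, Flatten

inp = Input(shape=(32, 32, 3))
# Conv2D with 3x3 kernels on a 3-channel input: dim = 3 * 3**2 = 27
h = Conv2D(16, 3, padding='same', kernel_regularizer=Spectral(dim=27))(inp)
h = Flatten()(h)
# Dense layer: dim = number of input units
out = Dense(1, kernel_regularizer=Spectral(dim=K.int_shape(h)[-1]))(h)
model = Model(inp, out)
```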
-------------------------------------------------------------------------------- /layers/stn_spline.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from keras import Input, Model
3 | from keras.engine import Layer
4 | from keras.layers import Concatenate, MaxPooling2D, Conv2D, Flatten, Dense, Reshape, LeakyReLU
5 | 
6 | from layers.interpolate_spline import interpolate_spline
7 | 
8 | bilinear_interpolation = tf.contrib.resampler.resampler
9 | import numpy as np
10 | import logging
11 | log = logging.getLogger('stn_spline')
12 | 
13 | 
14 | class ThinPlateSpline2D(Layer):
15 | """
16 | Keras layer for Thin Plate Spline interpolation.
17 | """
18 | def __init__(self, input_volume_shape, cp_dims, num_channels, inverse=False, order=2, **kwargs):
19 | 
20 | self.vol_shape = input_volume_shape
21 | self.data_dimensionality = len(input_volume_shape)
22 | self.cp_dims = cp_dims
23 | self.num_channels = num_channels
24 | self.initial_cp_grid = None
25 | self.flt_grid = None
26 | self.inverse = inverse
27 | self.order = order
28 | super(ThinPlateSpline2D, self).__init__(**kwargs)
29 | 
30 | def build(self, input_shape):
31 | 
32 | self.initial_cp_grid = nDgrid(self.cp_dims)
33 | self.flt_grid = nDgrid(self.vol_shape)
34 | 
35 | super(ThinPlateSpline2D, self).build(input_shape)
36 | 
37 | # @tf.contrib.eager.defun
38 | def interpolate_spline_batch(self, cp_offsets_single_batch):
39 | 
40 | warped_cp_grid = self.initial_cp_grid + cp_offsets_single_batch
41 | 
42 | if self.inverse:
43 | interpolated_sample_locations = interpolate_spline(train_points=warped_cp_grid,
44 | train_values=self.initial_cp_grid,
45 | query_points=self.flt_grid,
46 | order=self.order)
47 | else:
48 | interpolated_sample_locations = interpolate_spline(train_points=self.initial_cp_grid,
49 | train_values=warped_cp_grid,
50 | query_points=self.flt_grid,
51 | order=self.order)
52 | 
53 | return interpolated_sample_locations
54 | 
55 | def call(self, args):
56 | 
57 | vol, cp_offsets = args
58 | 
59 | interpolated_sample_locations = tf.map_fn(self.interpolate_spline_batch, cp_offsets)[:, 0]
60 | 
61 | interpolated_sample_locations = tf.reverse(interpolated_sample_locations, axis=[-1])
62 | 
63 | interpolated_sample_locations = tf.multiply(interpolated_sample_locations,
64 | [self.vol_shape[1] - 1, self.vol_shape[0] - 1])
65 | warped_volume = bilinear_interpolation(vol, interpolated_sample_locations)
66 | warped_volume = tf.reshape(warped_volume, (-1,) + tuple(self.vol_shape) + (self.num_channels,))
67 | return warped_volume
68 | 
69 | 
70 | def nDgrid(dims, normalise=True, center=False, dtype='float32'):
71 | '''
72 | Returns the co-ordinates of an n-dimensional grid as a (1, num_points, n) shaped tensor,
73 | e.g. dims=[3,3] would return:
74 | [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[2,0],[2,1],[2,2]]
75 | if normalise == False, or:
76 | [[0,0],[0,0.5],[0,1.],[0.5,0],[0.5,0.5],[0.5,1.],[1.,0],[1.,0.5],[1.,1.]]
77 | if normalise == True.
78 | '''
79 | if len(dims) == 2:
80 | grid = np.expand_dims(np.mgrid[:dims[0], :dims[1]].reshape((2, -1)).T, 0)
81 | 
82 | if len(dims) == 3:
83 | grid = np.expand_dims(np.mgrid[:dims[0], :dims[1], :dims[2]].reshape((3, -1)).T, 0)
84 | 
85 | if normalise:
86 | grid = grid / (1. * (np.array([[dims]]) - 1))
87 | 
88 | if center:
89 | grid = (grid - 1) * 2
90 | 
91 | return tf.cast(grid, dtype=dtype)
92 | 
93 | 
94 | def build_locnet(input_shape1, input_shape2, output_shape):
95 | """
96 | Build a localisation network that predicts the spline control-point parameters for the STN.
97 | :param input_shape1: shape of input tensor 1 98 | :param input_shape2: shape of input tensor 2 99 | :param output_shape: number of control points to predict 100 | :return: a Keras model 101 | """ 102 | input1 = Input(shape=input_shape1) 103 | input2 = Input(shape=input_shape2) 104 | stacked = Concatenate()([input1, input2]) 105 | 106 | l = Conv2D(20, 5)(stacked) 107 | l = LeakyReLU()(l) 108 | l = MaxPooling2D(pool_size=(2, 2))(l) 109 | l = Conv2D(20, 5)(l) 110 | l = LeakyReLU()(l) 111 | l = MaxPooling2D(pool_size=(2, 2))(l) 112 | l = Conv2D(20, 5)(l) 113 | l = LeakyReLU()(l) 114 | l = Flatten()(l) 115 | l = Dense(100, activation='tanh')(l) 116 | theta = Dense(output_shape, kernel_initializer='zeros', bias_initializer='zeros')(l) 117 | theta = Reshape((int(output_shape / 2), 2))(theta) 118 | m = Model(input=[input1, input2], output=theta, name='stn_locnet') 119 | m.summary(print_fn=log.info) 120 | return m 121 | -------------------------------------------------------------------------------- /loaders/MultimodalPairedData.py: -------------------------------------------------------------------------------- 1 | from loaders.data import Data 2 | import utils.data_utils 3 | import numpy as np 4 | import logging 5 | log = logging.getLogger('MultimodalPairedData') 6 | 7 | 8 | class MultimodalPairedData(Data): 9 | """ 10 | Container for multimodal data of image pairs. These are concatenated at the channel dimension 11 | """ 12 | def __init__(self, images, masks, index, downsample=1): 13 | super(MultimodalPairedData, self).__init__(images, masks, index, downsample) 14 | self.num_modalities = self.images.shape[-1] 15 | self.masks_per_mod = self.masks.shape[-1] // 2 16 | 17 | images_mod1 = self.images[..., 0:1] 18 | images_mod2 = self.images[..., 1:2] 19 | masks_mod1 = self.masks[..., 0:self.masks_per_mod] 20 | masks_mod2 = self.masks[..., self.masks_per_mod:] 21 | 22 | self.image_dict = {0: images_mod1, 1: images_mod2} 23 | self.masks_dict = {0: masks_mod1, 1: masks_mod2} 24 | 25 | del self.images 26 | del self.masks 27 | 28 | def get_images_modi(self, mod_i): 29 | return self.image_dict[mod_i] 30 | 31 | def get_masks_modi(self, mod_i): 32 | return self.masks_dict[mod_i] 33 | 34 | def set_images_modi(self, mod_i, images): 35 | self.image_dict[mod_i] = images 36 | 37 | def set_masks_modi(self, mod_i, masks): 38 | self.masks_dict[mod_i] = masks 39 | 40 | def get_volume_images_modi(self, mod_i, vol): 41 | return self.get_images_modi(mod_i)[self.index == vol] 42 | 43 | def get_volume_masks_modi(self, mod_i, vol): 44 | return self.get_masks_modi(mod_i)[self.index == vol] 45 | 46 | def filter_volumes(self, volumes): 47 | if len(volumes) == 0: 48 | for modi in range(self.num_modalities): 49 | self.set_images_modi(modi, np.array((0,) + self.image_shape)) 50 | self.set_masks_modi(modi, np.array((0,) + self.mask_shape)) 51 | self.index = np.array((0,) + self.index.shape[1:]) 52 | self.num_volumes = 0 53 | return 54 | 55 | for modi in range(self.num_modalities): 56 | self.set_images_modi(modi, np.concatenate([self.get_volume_images_modi(modi, v) for v in volumes], axis=0)) 57 | self.set_masks_modi(modi, np.concatenate([self.get_volume_masks_modi(modi, v) for v in volumes], axis=0)) 58 | 59 | self.index = np.concatenate([self.index.copy()[self.index == v] for v in volumes], axis=0) 60 | self.num_volumes = len(volumes) 61 | 62 | log.info('Filtered volumes: %s of total %d images' % (str(volumes), self.size())) 63 | 64 | def crop(self, shape): 65 | log.debug('Cropping images and masks to shape ' + str(shape)) 
66 | for modi in range(self.num_modalities):
67 | [images], [masks] = utils.data_utils.crop_same([self.get_images_modi(modi)], [self.get_masks_modi(modi)],
68 | size=shape, pad_mode='constant')
69 | self.set_images_modi(modi, images)
70 | self.set_masks_modi(modi, masks)
71 | assert images.shape[1:-1] == masks.shape[1:-1] == tuple(shape), \
72 | 'Invalid shapes: ' + str(images.shape[1:-1]) + ' ' + str(masks.shape[1:-1]) + ' ' + str(shape)
73 | 
74 | def size(self):
75 | return np.max([self.get_images_modi(modi).shape[0] for modi in range(self.num_modalities)])
76 | 
77 | def sample_images(self, num, seed=-1):
78 | log.info('Sampling %d images out of total %d' % (num, self.size()))
79 | if seed > -1:
80 | np.random.seed(seed)
81 | 
82 | idx = np.random.choice(self.size(), size=num, replace=False)
83 | for modi in range(self.num_modalities):
84 | images = self.get_images_modi(modi)
85 | masks = self.get_masks_modi(modi)
86 | 
87 | self.set_images_modi(modi, np.array([images[i] for i in idx]))
88 | self.set_masks_modi(modi, np.array([masks[i] for i in idx]))
89 | self.index = np.array([self.index[i] for i in idx])
90 | 
91 | def expand_pairs(self, offsets, mod_i, neighborhood=2):
92 | """
93 | Create more pairs by considering neighbour images. Changes the object in-place.
94 | :param offsets: number of neighbouring slices to pair with
95 | :param mod_i: which modality is enlarged
96 | :param neighborhood: the number of candidates
97 | """
98 | assert mod_i in [0, 1], 'mod_i can be in [0, 1]. It defines the neighborhood of which modality to enlarge'
99 | log.debug('Enlarge neighborhood with %d pairs' % offsets)
100 | 
101 | all_images, all_labels, all_index = [], [], []
102 | for vol in self.volumes():
103 | img_mod1 = self.get_volume_images_modi(mod_i, vol)
104 | img_mod2 = self.get_volume_images_modi(1 - mod_i, vol)
105 | 
106 | num_images = img_mod2.shape[0]
107 | vol_img_mod1 = []
108 | for i in range(num_images):
109 | if img_mod1.shape[0] < 2 * offsets + 1:
110 | value_range = list(range(0, img_mod1.shape[0])) + [0] * (2 * offsets + 1 - img_mod1.shape[0])
111 | elif i < offsets:
112 | value_range = list(range(0, 2 * offsets + 1))
113 | elif i + offsets >= num_images:
114 | value_range = list(range(num_images - (2 * offsets + 1), num_images))
115 | else:
116 | value_range = list(range(i - offsets, i + offsets + 1))
117 | 
118 | # rearrange values, such that the first value is the expertly paired one.
119 | value_range.insert(0, value_range.pop(value_range.index(i))) 120 | assert len(list(value_range)) == 2 * offsets + 1, \ 121 | 'Invalid length: %d vs %d' % (2 * offsets + 1, len(list(value_range))) 122 | 123 | if len(value_range) > neighborhood: 124 | new_value_range = [value_range[0]] 125 | new_value_range += list(np.random.choice(value_range[1:], size=neighborhood - 1, replace=False)) 126 | value_range = new_value_range 127 | assert len(value_range) <= neighborhood, "Exceeded maximum neighborhood size" 128 | 129 | neighbour_imgs = np.concatenate([img_mod1[index:index+1] for index in value_range], axis=-1) 130 | vol_img_mod1.append(neighbour_imgs) 131 | 132 | all_images.append(np.concatenate(vol_img_mod1, axis=0)) 133 | 134 | all_images = np.concatenate(all_images, axis=0) 135 | 136 | assert all_images.shape[-1] == neighborhood, '%s vs %s' % (all_images.shape[-1], neighborhood) 137 | 138 | if mod_i == 0: 139 | self.set_images_modi(0, all_images) 140 | elif mod_i == 1: 141 | self.set_images_modi(1, all_images) 142 | 143 | def randomise_pairs(self, length=3, seed=None): 144 | if seed is not None: 145 | np.random.seed(seed) 146 | log.debug('Randomising pairs within a volume') 147 | 148 | new_images, new_masks = [], [] 149 | for vol in self.volumes(): 150 | images = self.get_volume_images_modi(0, vol) 151 | masks = self.get_volume_masks_modi(0, vol) 152 | 153 | offsets = np.random.randint(-length, length, size=images.shape[0]) 154 | for off in range(length): 155 | if offsets[off] + off < 0: 156 | offsets[off] = np.random.randint(-off, length, size=1) 157 | 158 | for i in range(1, length): 159 | if offsets[-i] + range(images.shape[0])[-i] >= images.shape[0]: 160 | offsets[-i] = np.random.randint(-length, i, size=1) 161 | new_pair_index = np.array(range(images.shape[0])) + offsets 162 | 163 | new_images.append(images[new_pair_index]) 164 | new_masks.append(masks[new_pair_index]) 165 | 166 | self.set_images_modi(0, np.concatenate(new_images, axis=0)) 167 | self.set_masks_modi(0, np.concatenate(new_masks, axis=0)) 168 | 169 | def merge(self, other): 170 | log.info('Merging Data object of %d to this Data object of size %d' % (other.size(), self.size())) 171 | 172 | for mod in range(self.num_modalities): 173 | cur_img_mod = self.get_images_modi(mod) 174 | oth_img_mod = other.get_images_modi(mod) 175 | 176 | cur_msk_mod = self.get_masks_modi(mod) 177 | oth_msk_mod = other.get_masks_modi(mod) 178 | 179 | img_mod = np.concatenate([cur_img_mod, oth_img_mod], axis=0) 180 | msk_mod = np.concatenate([cur_msk_mod, oth_msk_mod], axis=0) 181 | 182 | self.set_images_modi(mod, img_mod) 183 | self.set_masks_modi(mod, msk_mod) 184 | 185 | self.index = np.concatenate([self.index, other.index], axis=0) 186 | assert self.get_images_modi(0).shape[0] == self.index.shape[0] 187 | 188 | self.num_volumes = len(self.volumes()) -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/loaders/__init__.py -------------------------------------------------------------------------------- /loaders/base_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import abstractmethod 3 | 4 | 5 | data_conf = { 6 | 'chaos': '../../data/Chaos/MR', 7 | } 8 | 9 | 10 | class Loader(object): 11 | """ 12 | Abstract class defining the 
behaviour of loaders for different datasets.
13 | """
14 | def __init__(self, volumes=None):
15 | self.num_masks = 0
16 | self.num_volumes = 0
17 | self.input_shape = (None, None, 1)
18 | self.processed_folder = None
19 | if volumes is not None:
20 | self.volumes = volumes
21 | else:
22 | all_volumes = self.splits()[0]['training'] + self.splits()[0]['validation'] + self.splits()[0]['test']
23 | self.volumes = sorted(all_volumes)
24 | self.log = logging.getLogger('loader')
25 | 
26 | @abstractmethod
27 | def splits(self):
28 | """
29 | :return: an array of splits into validation, test and train indices
30 | """
31 | pass
32 | 
33 | @abstractmethod
34 | def load_all_modalities_concatenated(self, split, split_type, downsample):
35 | """
36 | Load multimodal data, and concatenate the images of the same volume/slice
37 | :param split: the split number
38 | :param split_type: training/validation/test
39 | :return: a Data object of multimodal images
40 | """
41 | pass
42 | 
43 | @abstractmethod
44 | def load_labelled_data(self, split, split_type, modality, normalise=True, downsample=1, root_folder=None):
45 | """
46 | Load labelled data.
47 | :param split: the split number, e.g. 0, 1
48 | :param split_type: the split type, e.g. training, validation, test, all (for all data)
49 | :param modality: modality to load if the dataset has multimodal data
50 | :param normalise: True/False: normalise images to [-1, 1]
51 | :param downsample: downsample image ratio - used for testing
52 | :param root_folder: root data folder
53 | :return: a Data object containing the loaded data
54 | """
55 | pass
56 | 
57 | @abstractmethod
58 | def load_unlabelled_data(self, split, split_type, modality, normalise=True, downsample=1):
59 | """
60 | Load unlabelled data.
61 | :param split: the split number, e.g. 0, 1
62 | :param split_type: the split type, e.g. training, validation, test, all (for all data)
63 | :param modality: modality to load if the dataset has multimodal data
64 | :param normalise: True/False: normalise images to [-1, 1]
65 | :return: a Data object containing the loaded data
66 | """
67 | pass
68 | 
69 | @abstractmethod
70 | def load_all_data(self, split, split_type, modality, normalise=True, downsample=1):
71 | """
72 | Load all images (labelled and unlabelled).
73 | :param split: the split number, e.g. 0, 1
74 | :param split_type: the split type, e.g.
training, validation, test, all (for all data) 75 | :param modality: modality to load if the dataset has multimodal data 76 | :param normalise: True/False: normalise images to [-1, 1] 77 | :return: a Data object containing the loaded data 78 | """ 79 | pass 80 | 81 | def get_volumes_for_split(self, split, split_type): 82 | assert split_type in ['training', 'validation', 'test', 'all'], 'Unknown split_type: ' + split_type 83 | 84 | if split_type == 'all': 85 | volumes = sorted(self.splits()[split]['training'] + self.splits()[split]['validation'] + 86 | self.splits()[split]['test']) 87 | else: 88 | volumes = self.splits()[split][split_type] 89 | return volumes 90 | -------------------------------------------------------------------------------- /loaders/chaos.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | from scipy.ndimage import imread 7 | from skimage import transform 8 | 9 | sys.path.append('loaders') 10 | sys.path.append('.') 11 | from loaders.MultimodalPairedData import MultimodalPairedData 12 | from loaders.base_loader import Loader, data_conf 13 | from loaders.data import Data 14 | from loaders.dcm_contour_utils import DicomImage 15 | from utils import data_utils 16 | import nibabel as nib 17 | log = logging.getLogger('chaos') 18 | 19 | 20 | class ChaosLoader(Loader): 21 | # average resolution 1.61 22 | def __init__(self): 23 | self.volumes = [1,2,3,5,8,10,13,15,19,20,21,22,31,32,33,34,36,37,38,39] 24 | super(ChaosLoader, self).__init__(self.volumes) 25 | self.num_masks = 4 # liver, right kidney, left kidney, spleen 26 | self.input_shape = (192, 192, 1) 27 | self.data_folder = data_conf['chaos'] 28 | self.num_volumes = len(self.volumes) 29 | self.log = logging.getLogger('chaos') 30 | self.modalities = ['t1', 't2'] 31 | 32 | def splits(self): 33 | return [ 34 | {'validation': [31, 36, 13], 35 | 'test': [10, 22, 34], 36 | 'training': [5, 3, 1, 15, 19, 2, 20, 37, 32, 38, 8, 39, 21, 33] 37 | }, 38 | { 39 | 'validation': [13, 3, 20], 40 | 'test': [5, 15, 39], 41 | 'training': [33, 8, 38, 34, 36, 31, 32, 37, 22, 2, 1, 10, 19, 21] 42 | }, 43 | { 44 | 'validation': [37, 13, 33], 45 | 'test': [1, 19, 32], 46 | 'training': [5, 20, 31, 2, 38, 3, 8, 15, 22, 10, 34, 39, 36, 21] 47 | } 48 | ] 49 | 50 | def load_all_data(self, split, split_type, modality, normalise=True, downsample=1): 51 | return self.load_labelled_data(split, split_type, modality, normalise, downsample) 52 | 53 | def load_unlabelled_data(self, split, split_type, modality, normalise=True, downsample=1): 54 | return self.load_labelled_data(split, split_type, modality, normalise, downsample) 55 | 56 | def load_labelled_data(self, split, split_type, modality, normalise=True, downsample=1, root_folder=None): 57 | data = self.load_all_modalities_concatenated(split, split_type, downsample) 58 | images_t1 = data.get_images_modi(0) 59 | images_t2 = data.get_images_modi(1) 60 | labels_t1 = data.get_masks_modi(0) 61 | labels_t2 = data.get_masks_modi(1) 62 | 63 | if modality == 'all': 64 | all_images = np.concatenate([images_t1, images_t2], axis=0) 65 | all_labels = np.concatenate([labels_t1, labels_t2], axis=0) 66 | all_index = np.concatenate([data.index, data.index.copy()], axis=0) 67 | elif modality == 't1': 68 | all_images = images_t1 69 | all_labels = labels_t1 70 | all_index = data.index 71 | elif modality == 't2': 72 | all_images = images_t2 73 | all_labels = labels_t2 74 | all_index = data.index 75 | else: 76 | raise 
Exception('Unknown modality: %s' % modality) 77 | 78 | assert split_type in ['training', 'validation', 'test', 'all'], split_type 79 | assert all_images.max() - 1 < 0.01 and all_images.min() + 1 < 0.01, \ 80 | 'max: %.3f, min: %.3f' % (all_images.max(), all_images.min()) 81 | 82 | self.log.debug('Loaded compressed data of shape: ' + str(all_images.shape) + ' ' + str(all_index.shape)) 83 | 84 | if split_type == 'all': 85 | return Data(all_images, all_labels, all_index, 1) 86 | 87 | volumes = self.splits()[split][split_type] 88 | all_images = np.concatenate([all_images[all_index == v] for v in volumes]) 89 | 90 | assert all_labels.max() == 1 and all_labels.min() == 0, \ 91 | 'max: %d - min: %d' % (all_labels.max(), all_labels.min()) 92 | 93 | all_masks = np.concatenate([all_labels[all_index == v] for v in volumes]) 94 | assert all_images.shape[0] == all_masks.shape[0] 95 | all_index = np.concatenate([all_index[all_index == v] for v in volumes]) 96 | assert all_images.shape[0] == all_index.shape[0] 97 | 98 | self.log.debug(split_type + ' set: ' + str(all_images.shape)) 99 | return Data(all_images, all_masks, all_index, 1) 100 | 101 | def load_all_modalities_concatenated(self, split, split_type, downsample=1): 102 | all_images_t1, all_labels_t1, all_images_t2, all_labels_t2, all_index = [], [], [], [], [] 103 | volumes = self.get_volumes_for_split(split, split_type) 104 | for v in volumes: 105 | images_t1, labels_t1 = self._load_volume(v, 't1') 106 | images_t2, labels_t2 = self._load_volume(v, 't2') 107 | 108 | # for each CHAOS subject, create pairs of T1 and T2 slices that approximately correspond to the same 109 | # position in the 3D volume, i.e. contain the same anatomical parts 110 | if v == 1: 111 | images_t2 = images_t2[1:] 112 | labels_t2 = labels_t2[1:] 113 | 114 | images_t1 = images_t1[0:26] 115 | labels_t1 = labels_t1[0:26] 116 | images_t2 = images_t2[4:24] 117 | labels_t2 = labels_t2[4:24] 118 | 119 | images_t1 = np.concatenate([images_t1[0:5], images_t1[7:10], images_t1[13:17], images_t1[18:]], axis=0) 120 | labels_t1 = np.concatenate([labels_t1[0:5], labels_t1[7:10], labels_t1[13:17], labels_t1[18:]], axis=0) 121 | if v == 2: 122 | images_t1 = np.concatenate([images_t1[4:7], images_t1[8:23]], axis=0) 123 | labels_t1 = np.concatenate([labels_t1[4:7], labels_t1[8:23]], axis=0) 124 | images_t2 = images_t2[3:22] 125 | labels_t2 = labels_t2[3:22] 126 | 127 | images_t1 = np.concatenate([images_t1[0:11], images_t1[12:18]], axis=0) 128 | labels_t1 = np.concatenate([labels_t1[0:11], labels_t1[12:18]], axis=0) 129 | images_t2 = np.concatenate([images_t2[0:11], images_t2[12:18]], axis=0) 130 | labels_t2 = np.concatenate([labels_t2[0:11], labels_t2[12:18]], axis=0) 131 | if v == 3: 132 | images_t1 = np.concatenate([images_t1[11:14], images_t1[15:26]], axis=0) 133 | labels_t1 = np.concatenate([labels_t1[11:14], labels_t1[15:26]], axis=0) 134 | images_t2 = images_t2[9:23] 135 | labels_t2 = labels_t2[9:23] 136 | if v == 5: 137 | images_t1 = np.concatenate([images_t1[4:5], images_t1[8:24]], axis=0) 138 | labels_t1 = np.concatenate([labels_t1[4:5], labels_t1[8:24]], axis=0) 139 | images_t2 = images_t2[2:22] 140 | labels_t2 = labels_t2[2:22] 141 | 142 | images_t2 = np.concatenate([images_t2[0:6], images_t2[9:]], axis=0) 143 | labels_t2 = np.concatenate([labels_t2[0:6], labels_t2[9:]], axis=0) 144 | 145 | images_t1 = np.concatenate([images_t1[0:8], images_t1[9:]], axis=0) 146 | labels_t1 = np.concatenate([labels_t1[0:8], labels_t1[9:]], axis=0) 147 | images_t2 = np.concatenate([images_t2[0:8], 
images_t2[9:]], axis=0) 148 | labels_t2 = np.concatenate([labels_t2[0:8], labels_t2[9:]], axis=0) 149 | if v == 8: 150 | images_t1 = images_t1[2:-2] 151 | labels_t1 = labels_t1[2:-2] 152 | 153 | images_t1 = np.concatenate([images_t1[5:11], images_t1[12:27]], axis=0) 154 | labels_t1 = np.concatenate([labels_t1[5:11], labels_t1[12:27]], axis=0) 155 | images_t2 = images_t2[6:27] 156 | labels_t2 = labels_t2[6:27] 157 | if v == 10: 158 | images_t1 = images_t1[14:38] 159 | labels_t1 = labels_t1[14:38] 160 | images_t2 = images_t2[5:24] 161 | labels_t2 = labels_t2[5:24] 162 | 163 | images_t1 = np.concatenate([images_t1[0:8], images_t1[12:18], images_t1[19:]], axis=0) 164 | labels_t1 = np.concatenate([labels_t1[0:8], labels_t1[12:18], labels_t1[19:]], axis=0) 165 | if v == 13: 166 | images_t1 = images_t1[4:29] 167 | labels_t1 = labels_t1[4:29] 168 | images_t2 = images_t2[3:28] 169 | labels_t2 = labels_t2[3:28] 170 | if v == 15: 171 | images_t1 = images_t1[:22] 172 | labels_t1 = labels_t1[:22] 173 | images_t2 = images_t2[:22] 174 | labels_t2 = labels_t2[:22] 175 | if v == 19: 176 | images_t1 = images_t1[8:27] 177 | labels_t1 = labels_t1[8:27] 178 | images_t2 = images_t2[5:24] 179 | labels_t2 = labels_t2[5:24] 180 | if v == 20: 181 | images_t1 = images_t1[2:21] 182 | labels_t1 = labels_t1[2:21] 183 | images_t2 = images_t2[2:21] 184 | labels_t2 = labels_t2[2:21] 185 | if v == 21: 186 | images_t1 = images_t1[3:19] 187 | labels_t1 = labels_t1[3:19] 188 | images_t2 = images_t2[5:21] 189 | labels_t2 = labels_t2[5:21] 190 | if v == 22: 191 | images_t1 = images_t1[:-2] 192 | labels_t1 = labels_t1[:-2] 193 | 194 | images_t1 = np.concatenate([images_t1[8:17], images_t1[18:26]], axis=0) 195 | labels_t1 = np.concatenate([labels_t1[8:17], labels_t1[18:26]], axis=0) 196 | images_t2 = np.concatenate([images_t2[3:12], images_t2[15:23]], axis=0) 197 | labels_t2 = np.concatenate([labels_t2[3:12], labels_t2[15:23]], axis=0) 198 | if v == 31: 199 | images_t1 = images_t1[7:23] 200 | labels_t1 = labels_t1[7:23] 201 | images_t2 = np.concatenate([images_t2[5:12], images_t2[13:22]], axis=0) 202 | labels_t2 = np.concatenate([labels_t2[5:12], labels_t2[13:22]], axis=0) 203 | if v == 32: 204 | images_t1 = images_t1[5:32] 205 | labels_t1 = labels_t1[5:32] 206 | 207 | images_t2 = images_t2[3:30] 208 | labels_t2 = labels_t2[3:30] 209 | if v == 33: 210 | images_t1 = images_t1[7:-5] 211 | labels_t1 = labels_t1[7:-5] 212 | images_t2 = np.concatenate([images_t2[3:12], images_t2[15:-2]], axis=0) 213 | labels_t2 = np.concatenate([labels_t2[3:12], labels_t2[15:-2]], axis=0) 214 | if v == 34: 215 | images_t1 = np.concatenate([images_t1[1:2], images_t1[3:4], images_t1[5:6], images_t1[7:27]], axis=0) 216 | labels_t1 = np.concatenate([labels_t1[1:2], labels_t1[3:4], labels_t1[5:6], labels_t1[7:27]], axis=0) 217 | images_t1 = np.concatenate([images_t1[0:14], images_t1[15:16], images_t1[17:18], images_t1[19:22], images_t1[23:24]], axis=0) 218 | labels_t1 = np.concatenate([labels_t1[0:14], labels_t1[15:16], labels_t1[17:18], labels_t1[19:22], labels_t1[23:24]], axis=0) 219 | images_t2 = images_t2[2:21] 220 | labels_t2 = labels_t2[2:21] 221 | if v == 36: 222 | images_t1 = images_t1[8:25] 223 | labels_t1 = labels_t1[8:25] 224 | images_t2 = np.concatenate([images_t2[4:6], images_t2[7:22]], axis=0) 225 | labels_t2 = np.concatenate([labels_t2[4:6], labels_t2[7:22]], axis=0) 226 | if v == 37: 227 | images_t1 = np.concatenate([images_t1[9:23], images_t1[24:-1]], axis=0) 228 | labels_t1 = np.concatenate([labels_t1[9:23], labels_t1[24:-1]], axis=0) 229 
| images_t2 = np.concatenate([images_t2[4:6], images_t2[7:21], images_t2[22:-7]], axis=0) 230 | labels_t2 = np.concatenate([labels_t2[4:6], labels_t2[7:21], labels_t2[22:-7]], axis=0) 231 | if v == 38: 232 | images_t1 = images_t1[9:24] 233 | labels_t1 = labels_t1[9:24] 234 | images_t2 = images_t2[9:24] 235 | labels_t2 = labels_t2[9:24] 236 | if v == 39: 237 | images_t1 = images_t1[3:22] 238 | labels_t1 = labels_t1[3:22] 239 | images_t2 = images_t2[3:22] 240 | labels_t2 = labels_t2[3:22] 241 | 242 | images_t1 = np.concatenate([data_utils.rescale(images_t1[i:i + 1], -1, 1) for i in range(images_t1.shape[0])]) 243 | images_t2 = np.concatenate([data_utils.rescale(images_t2[i:i + 1], -1, 1) for i in range(images_t2.shape[0])]) 244 | 245 | assert images_t1.max() == 1 and images_t1.min() == -1, '%.3f to %.3f' % (images_t1.max(), images_t1.min()) 246 | assert images_t2.max() == 1 and images_t2.min() == -1, '%.3f to %.3f' % (images_t2.max(), images_t2.min()) 247 | 248 | all_images_t1.append(images_t1) 249 | all_labels_t1.append(labels_t1) 250 | all_images_t2.append(images_t2) 251 | all_labels_t2.append(labels_t2) 252 | 253 | all_index.append(np.array([v] * images_t1.shape[0])) 254 | 255 | all_images_t1, all_labels_t1 = data_utils.crop_same(all_images_t1, all_labels_t1, self.input_shape[:-1]) 256 | all_images_t2, all_labels_t2 = data_utils.crop_same(all_images_t2, all_labels_t2, self.input_shape[:-1]) 257 | 258 | all_images_t1 = np.concatenate(all_images_t1, axis=0) 259 | all_labels_t1 = np.concatenate(all_labels_t1, axis=0) 260 | all_images_t2 = np.concatenate(all_images_t2, axis=0) 261 | all_labels_t2 = np.concatenate(all_labels_t2, axis=0) 262 | 263 | if self.modalities == ['t1', 't2']: 264 | all_images = np.concatenate([all_images_t1, all_images_t2], axis=-1) 265 | all_labels = np.concatenate([all_labels_t1, all_labels_t2], axis=-1) 266 | elif self.modalities == ['t2', 't1']: 267 | all_images = np.concatenate([all_images_t2, all_images_t1], axis=-1) 268 | all_labels = np.concatenate([all_labels_t2, all_labels_t1], axis=-1) 269 | else: 270 | raise ValueError('invalid self.modalities', self.modalities) 271 | all_index = np.concatenate(all_index, axis=0) 272 | 273 | assert all_labels.max() == 1 and all_labels.min() == 0, '%.3f to %.3f' % (all_labels.max(), all_labels.min()) 274 | return MultimodalPairedData(all_images, all_labels, all_index, downsample=downsample) 275 | 276 | def _load_volume(self, volume, modality): 277 | if modality == 't1': 278 | folder = self.data_folder + '/%d/T1DUAL' % volume 279 | image_folder = folder + '/DICOM_anon/OutPhase' 280 | elif modality == 't2': 281 | folder = self.data_folder + '/%d/T2SPIR' % volume 282 | image_folder = folder + '/DICOM_anon' 283 | else: 284 | raise Exception('Unknown modality') 285 | labels_folder = folder + '/Ground' 286 | 287 | image_files = list(os.listdir(image_folder)) 288 | image_files.sort(key=lambda x: x.split('-')[-1], reverse=True) 289 | images_dcm = [DicomImage(image_folder + '/' + f) for f in image_files] 290 | images = np.concatenate([np.expand_dims(np.expand_dims(dcm.image, 0), -1) for dcm in images_dcm], axis=0) 291 | 292 | label_files = list(os.listdir(labels_folder)) 293 | label_files.sort(key=lambda x: x.split('-')[-1], reverse=True) 294 | labels = [imread(labels_folder + '/' + f) for f in label_files] 295 | labels = np.concatenate([np.expand_dims(np.expand_dims(l, 0), -1) for l in labels], axis=0) 296 | 297 | res = images_dcm[0].resolution[0:2] 298 | images = np.concatenate([np.expand_dims(resample(images[i], res), axis=0) for 
i in range(images.shape[0])], 299 | axis=0) 300 | labels = np.concatenate([np.expand_dims(resample(labels[i], res, binary=True), axis=0) 301 | for i in range(labels.shape[0])], axis=0) 302 | 303 | labels_l1 = labels.copy() 304 | labels_l1[labels != 63] = 0 305 | labels_l1[labels == 63] = 1 306 | 307 | labels_l2 = labels.copy() 308 | labels_l2[labels != 126] = 0 309 | labels_l2[labels == 126] = 1 310 | 311 | labels_l3 = labels.copy() 312 | labels_l3[labels != 189] = 0 313 | labels_l3[labels == 189] = 1 314 | 315 | labels_l4 = labels.copy() 316 | labels_l4[labels != 252] = 0 317 | labels_l4[labels == 252] = 1 318 | 319 | labels = np.concatenate([labels_l1, labels_l2, labels_l3, labels_l4], axis=-1) 320 | 321 | return images, labels 322 | 323 | 324 | def resample(image, old_res, binary=False): 325 | """ 326 | Resample all volumes to the same resolution 327 | :param image: an image slice 328 | :param old_res: the original image resolution 329 | :param binary: flag to denote a segmentation mask 330 | :return: a resampled image 331 | """ 332 | new_res = (1.89, 1.89) 333 | scale_vector = (old_res[0] / new_res[0], old_res[1] / new_res[1]) 334 | order = 0 if binary else 1 335 | 336 | assert len(image.shape) == 3 337 | 338 | result = [] 339 | for i in range(image.shape[-1]): 340 | im = image[..., i] 341 | rescaled = transform.rescale(im, scale_vector, order=order, preserve_range=True, mode='constant') 342 | result.append(np.expand_dims(rescaled, axis=-1)) 343 | return np.concatenate(result, axis=-1) 344 | -------------------------------------------------------------------------------- /loaders/data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | from skimage.measure import block_reduce 6 | 7 | import utils.data_utils 8 | import utils.image_utils 9 | 10 | log = logging.getLogger('data') 11 | 12 | 13 | class Data(object): 14 | def __init__(self, images, masks, index, downsample=1): 15 | """ 16 | Data constructor. 17 | :param images: a 4-D numpy array of images. Expected shape: (N, H, W, 1) 18 | :param masks: a 4-D numpy array of segmentation masks. Expected shape: (N, H, W, L) 19 | :param index: a 1-D numpy array indicating the volume each image/mask belongs to. Used for data selection. 20 | :param downsample: factor to downsample images. 
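Example (illustrative shapes only): ten 192x192 single-channel slices with 4-channel masks, spread over two volumes::

    images = np.zeros((10, 192, 192, 1))
    masks = np.zeros((10, 192, 192, 4))
    index = np.array([1] * 5 + [2] * 5)
    data = Data(images, masks, index)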
21 | """ 22 | assert images.shape[:-1] == masks.shape[:-1] 23 | assert images.shape[0] == index.shape[0] 24 | 25 | self.image_shape = images.shape[1:] 26 | self.mask_shape = masks.shape[1:] 27 | 28 | self.images = images 29 | self.masks = masks 30 | self.index = index 31 | self.num_volumes = len(self.volumes()) 32 | 33 | self.downsample(downsample) 34 | 35 | log.info( 36 | 'Creating Data object with images of shape %s and %d volumes' % (str(self.images.shape), self.num_volumes)) 37 | log.info('Images value range [%.1f, %.1f]' % (images.min(), images.max())) 38 | log.info('Masks value range [%.1f, %.1f]' % (masks.min(), masks.max())) 39 | 40 | def copy(self): 41 | return Data(np.copy(self.images), np.copy(self.masks), np.copy(self.index)) 42 | 43 | def merge(self, other): 44 | assert self.images.shape[1:] == other.images.shape[1:], str(self.images.shape) + ' vs ' + str( 45 | other.images.shape) 46 | assert self.masks.shape[1:] == other.masks.shape[1:], str(self.masks.shape) + ' vs ' + str(other.masks.shape) 47 | 48 | log.info('Merging Data object of %d to this Data object of size %d' % (other.size(), self.size())) 49 | 50 | self.images = np.concatenate([self.images, other.images], axis=0) 51 | self.masks = np.concatenate([self.masks, other.masks], axis=0) 52 | self.index = np.concatenate([self.index, other.index], axis=0) 53 | self.num_volumes = len(self.volumes()) 54 | 55 | def shuffle(self): 56 | idx = np.array(range(self.images.shape[0])) 57 | np.random.shuffle(idx) 58 | self.images = self.images[idx] 59 | self.masks = self.masks[idx] 60 | self.index = self.index[idx] 61 | 62 | def crop(self, shape): 63 | log.debug('Cropping images and masks to shape ' + str(shape)) 64 | [images], [masks] = utils.data_utils.crop_same([self.images], [self.masks], size=shape, pad_mode='constant') 65 | self.images = images 66 | self.masks = masks 67 | assert self.images.shape[1:-1] == self.masks.shape[1:-1] == tuple(shape), \ 68 | 'Invalid shapes: ' + str(self.images.shape[1:-1]) + ' ' + str(self.masks.shape[1:-1]) + ' ' + str(shape) 69 | 70 | def volumes(self): 71 | return sorted(set(self.index)) 72 | 73 | def get_images(self, vol): 74 | return self.images[self.index == vol] 75 | 76 | def get_masks(self, vol): 77 | return self.masks[self.index == vol] 78 | 79 | def size(self): 80 | return len(self.images) 81 | 82 | def sample_per_volume(self, num, seed=-1): 83 | log.info('Sampling %d from each volume' % num) 84 | if seed > -1: 85 | np.random.seed(seed) 86 | 87 | new_images, new_masks, new_scanner, new_index = [], [], [], [] 88 | for vol in self.volumes(): 89 | images = self.get_images(vol) 90 | masks = self.get_masks(vol) 91 | 92 | if images.shape[0] < num: 93 | log.debug('Volume %s contains less images: %d < %d. Sampling %d images.' % 94 | (str(vol), images.shape[0], num, images.shape[0])) 95 | idx = range(len(images)) 96 | else: 97 | idx = np.random.choice(images.shape[0], size=num, replace=False) 98 | 99 | images = np.array([images[i] for i in idx]) 100 | masks = np.array([masks[i] for i in idx]) 101 | index = np.array([vol] * num) 102 | 103 | new_images.append(images) 104 | new_masks.append(masks) 105 | new_index.append(index) 106 | 107 | self.images = np.concatenate(new_images, axis=0) 108 | self.masks = np.concatenate(new_masks, axis=0) 109 | self.index = np.concatenate(new_index, axis=0) 110 | 111 | log.info('Sampled %d images.' 
% len(self.images))
112 | 
113 | def sample_images(self, num, seed=-1):
114 | log.info('Sampling %d images out of total %d' % (num, self.size()))
115 | if seed > -1:
116 | np.random.seed(seed)
117 | 
118 | idx = np.random.choice(self.size(), size=num, replace=False)
119 | self.images = np.array([self.images[i] for i in idx])
120 | self.masks = np.array([self.masks[i] for i in idx])
121 | self.index = np.array([self.index[i] for i in idx])
122 | 
123 | def get_sample_volumes(self, num, seed=-1):
124 | log.info('Sampling %d volumes out of total %d' % (num, self.num_volumes))
125 | if seed > -1:
126 | np.random.seed(seed)
127 | 
128 | volumes = np.random.choice(self.volumes(), size=num, replace=False)
129 | return volumes
130 | 
131 | def sample(self, num, seed=-1):
132 | if num == self.num_volumes:
133 | return
134 | 
135 | volumes = self.get_sample_volumes(num, seed)
136 | self.filter_volumes(volumes)
137 | 
138 | def filter_volumes(self, volumes):
139 | if len(volumes) == 0:
140 | self.images = np.zeros((0,) + self.images.shape[1:])
141 | self.masks = np.zeros((0,) + self.masks.shape[1:])
142 | self.index = np.zeros((0,) + self.index.shape[1:])
143 | self.num_volumes = 0
144 | return
145 | 
146 | self.images = np.concatenate([self.get_images(v) for v in volumes], axis=0)
147 | self.masks = np.concatenate([self.get_masks(v) for v in volumes], axis=0)
148 | self.index = np.concatenate([self.index.copy()[self.index == v] for v in volumes], axis=0)
149 | self.num_volumes = len(volumes)
150 | 
151 | log.info('Filtered volumes: %s of total %d images' % (str(volumes), self.size()))
152 | 
153 | def shape(self):
154 | return self.image_shape
155 | 
156 | def downsample(self, ratio=2):
157 | if ratio == 1: return
158 | 
159 | self.images = block_reduce(self.images, block_size=(1, ratio, ratio, 1), func=np.mean)
160 | if self.masks is not None:
161 | self.masks = block_reduce(self.masks, block_size=(1, ratio, ratio, 1), func=np.mean)
162 | 
163 | log.info('Downsampled data by %d to shape %s' % (ratio, str(self.images.shape)))
164 | 
165 | def save(self, folder):
166 | if not os.path.exists(folder):
167 | os.makedirs(folder)
168 | 
169 | for i in range(self.images.shape[0]):
170 | np.savez_compressed(folder + '/images_%d' % i, self.images[i:i+1])
171 | np.savez_compressed(folder + '/masks_%d' % i, self.masks[i:i + 1])
172 | 
-------------------------------------------------------------------------------- /loaders/dcm_contour_utils.py: --------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | 
5 | import dicom
6 | import scipy
7 | import numpy as np
8 | 
9 | class DicomImage(object):
10 | """
11 | Object for basic information of a DICOM image.
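Example (hypothetical file path)::

    dcm = DicomImage('/data/Chaos/MR/1/T2SPIR/DICOM_anon/IMG-0001.dcm')
    print(dcm.patient, dcm.resolution, dcm.image.shape)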
13 |     def __init__(self, dcm_image_file):
14 |         print('Reading ' + dcm_image_file)
15 |         assert os.path.exists(dcm_image_file)
16 |         dcm_image = dicom.read_file(dcm_image_file)
17 | 
18 |         # this is unique for all images of the same patient
19 |         self.patient = str(dcm_image.PatientName) if hasattr(dcm_image, 'PatientName') else None
20 | 
21 |         # this is a series identifier
22 |         self.series = int(dcm_image.SeriesNumber) if hasattr(dcm_image, 'SeriesNumber') else None
23 | 
24 |         # the instance within the series
25 |         self.instance = int(dcm_image.InstanceNumber) if hasattr(dcm_image, 'InstanceNumber') else None
26 |         self.id = None  # a unique id of the image
27 |         self.resolution = [float(i) for i in dcm_image.PixelSpacing] + [float(dcm_image.SpacingBetweenSlices)]
28 |         self.image = dcm_image.pixel_array
29 |         self.age = dcm_image.PatientAge if hasattr(dcm_image, 'PatientAge') else None
30 | 
31 |     def save(self, folder):
32 |         scipy.misc.imsave(folder + '/image_%s.png' % str(self.id), self.image)
33 |         np.savez_compressed(folder + '/image_%s' % str(self.id), self.image)
34 | 
35 | 
36 | class Coordinates(object):
37 |     """
38 |     Class for storing contours for a cardiac phase
39 |     """
40 |     def __init__(self):
41 |         self.endo = None
42 |         self.epi = None
43 | 
44 | class Contour(object):
45 |     """
46 |     Simple object for basic information of a contour
47 |     """
48 |     def __init__(self, contour_file):
49 |         self.contour_file = contour_file
50 |         self.patient_name = None  # patient name. Also the parent folder of the images
51 |         self.series = None
52 |         self.series_description = None  # the common identifier of the folders that contain the images
53 |         self.coordinates = defaultdict(lambda: defaultdict(lambda: Coordinates()))  # a dictionary of Coordinates keyed by slice id and phase
54 |         self.gender = None
55 |         self.birth_date = None
56 |         self.study_date = None
57 |         self.weight = None
58 |         self.height = None
59 |         self.age = None
60 |         self.es = None
61 |         self.ed = None
62 |         self.read_file()
63 | 
64 | 
65 |     def read_file(self):
66 |         with open(self.contour_file, 'r') as fd:
67 |             # each [XYCONTOUR] block refers to a (slice, phase) pair
68 |             while True:
69 |                 l = fd.readline()
70 |                 if l == '': break
71 | 
72 |                 if 'Patient_name=' in l:
73 |                     self.patient_name = l.split('Patient_name=')[1].split('\n')[0]
74 | 
75 |                 if 'Series=' in l:
76 |                     self.series = l.split('Series=')[1].split('\n')[0]
77 | 
78 |                 if 'Series_description=' in l:
79 |                     self.series_description = l.split('Series_description=')[1].split('/')[0]\
80 |                         .strip().replace(' ', '_').replace('.', '_')
81 | 
82 |                 if 'Patient_gender' in l:
83 |                     self.gender = l.split('Patient_gender=')[1].split('\n')[0]
84 | 
85 |                 if 'Birth_date' in l:  # match the case of the key split below
86 |                     self.birth_date = l.split('Birth_date=')[1].split('\n')[0]
87 | 
88 |                 if 'Study_date' in l:
89 |                     self.study_date = l.split('Study_date=')[1].split('\n')[0]
90 | 
91 |                 if 'Patient_weight' in l:
92 |                     self.weight = l.split('Patient_weight=')[1].split('\n')[0]
93 | 
94 |                 if 'Patient_height' in l:
95 |                     self.height = l.split('Patient_height=')[1].split('\n')[0]
96 | 
97 |                 if 'manual_lv_es_phase' in l:
98 |                     self.es = int(l.split('manual_lv_es_phase=')[1].split('\n')[0]) + 1  # images are 1-indexed
99 | 
100 |                 if 'manual_lv_ed_phase' in l:
101 |                     self.ed = int(l.split('manual_lv_ed_phase=')[1].split('\n')[0]) + 1  # images are 1-indexed
102 | 
103 |                 # create a tuple of floats from an x-y coord pair.
104 |                 if '[XYCONTOUR]' in l:
105 |                     header = fd.readline().split(' ')  # e.g. 1 0 0 1.0
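                    # A contour block, as inferred from the parsing below, looks
                    # like (the trailing header field is not used here):
                    #
                    #   [XYCONTOUR]
                    #   1 0 0 1.0     <- slice, phase, contour type, scale
                    #   3             <- number of x-y coordinate pairs
                    #   12.5 40.25
                    #   ...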
106 |                     slice_id = int(header[0])
107 |                     phase = int(header[1])  # usually 0 -> ED, 7 -> ES
108 |                     contour_type = int(header[2])  # 0 -> endo, 1 -> epi
109 | 
110 |                     # Initialise ES and ED if not explicitly defined in the contour file
111 |                     if phase < 2 and self.ed is None:
112 |                         self.ed = phase
113 |                     if phase > 2 and self.es is None:
114 |                         self.es = phase
115 | 
116 |                     num_coords = int(fd.readline())
117 | 
118 |                     parse_coord = lambda x: (float(x.split(' ')[0]), float(x.split(' ')[1]))
119 |                     coords = [parse_coord(fd.readline()) for _ in range(num_coords)]  # read coordinates
120 | 
121 |                     cc = self.coordinates[slice_id][phase]
122 |                     if contour_type == 0:  # Coordinates for endo
123 |                         cc.endo = coords
124 |                     elif contour_type == 1:  # Coordinates for epi
125 |                         cc.epi = coords
126 |                     self.coordinates[slice_id][phase] = cc
127 | 
128 |     def save(self, folder):
129 |         with open(folder + '/contour.json', 'w') as outfile:
130 |             d = self.__dict__.copy()
131 |             d['coordinates'] = None
132 |             json.dump(d, outfile)
133 | 
--------------------------------------------------------------------------------
/loaders/loader_factory.py:
--------------------------------------------------------------------------------
1 | from loaders.chaos import ChaosLoader
2 | 
3 | 
4 | def init_loader(dataset):
5 |     """
6 |     Factory method for initialising data loaders by name.
7 |     """
8 |     if dataset == 'chaos':
9 |         return ChaosLoader()
10 |     raise ValueError('Unknown dataset: %s' % str(dataset))
--------------------------------------------------------------------------------
/model_components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/model_components/__init__.py
--------------------------------------------------------------------------------
/model_components/anatomy_encoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from keras import Input, Model
4 | from keras.layers import Conv2D, Activation, UpSampling2D, Concatenate
5 | 
6 | from layers.rounding import Rounding
7 | from models.unet import UNet
8 | from utils.model_utils import normalise
9 | 
10 | log = logging.getLogger('anatomy_encoder')
11 | 
12 | 
13 | def build(conf, name='Enc_Anatomy'):
14 |     """
15 |     Build a UNet-based encoder to extract anatomical information from the image.
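    The output is a multi-channel spatial map: a 1x1 convolution with a softmax
    over conf.out_channels makes the channels mutually exclusive per pixel, and
    the optional Rounding layer binarises them into discrete anatomy factors.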
16 |     """
17 |     spatial_encoder = UNet(conf)
18 |     spatial_encoder.input = Input(shape=conf.input_shape)
19 |     l1_down = spatial_encoder.unet_downsample(spatial_encoder.input, spatial_encoder.normalise)  # downsample
20 |     spatial_encoder.unet_bottleneck(l1_down, spatial_encoder.normalise)  # bottleneck
21 |     l2_up = spatial_encoder.unet_upsample(spatial_encoder.bottleneck, spatial_encoder.normalise)  # upsample
22 | 
23 |     anatomy = Conv2D(conf.out_channels, 1, padding='same', activation='softmax', name='conv_anatomy')(l2_up)
24 |     if conf.rounding:
25 |         anatomy = Rounding()(anatomy)
26 | 
27 |     model = Model(inputs=spatial_encoder.input, outputs=anatomy, name=name)
28 |     log.info('Enc_Anatomy')
29 |     model.summary(print_fn=log.info)
30 |     return model
31 | 
32 | class AnatomyEncoders(object):
33 | 
34 |     def __init__(self, modalities):
35 |         self.modalities = modalities
36 | 
37 |     def build(self, conf):
38 |         # build encoder1
39 |         encoder1 = UNet(conf)
40 |         encoder1.input = Input(shape=conf.input_shape)
41 |         l1 = encoder1.unet_downsample(encoder1.input, encoder1.normalise)
42 | 
43 |         # build encoder2
44 |         encoder2 = UNet(conf)
45 |         encoder2.input = Input(shape=conf.input_shape)
46 |         l2 = encoder2.unet_downsample(encoder2.input, encoder2.normalise)
47 | 
48 |         self.build_decoder(conf)
49 | 
50 |         d1_l3 = encoder1.d_l3 if conf.downsample > 3 else None
51 |         d2_l3 = encoder2.d_l3 if conf.downsample > 3 else None
52 |         anatomy_output1 = self.evaluate_decoder(conf,
53 |                                                 l1, d1_l3, encoder1.d_l2, encoder1.d_l1, encoder1.d_l0)
54 |         anatomy_output2 = self.evaluate_decoder(conf,
55 |                                                 l2, d2_l3, encoder2.d_l2, encoder2.d_l1, encoder2.d_l0)
56 | 
57 |         # build shared layer
58 |         shr_lay4 = Conv2D(conf.out_channels, 1, padding='same', activation='softmax', name='conv_anatomy')
59 | 
60 |         # connect models
61 |         encoder1_output = shr_lay4(anatomy_output1)
62 |         encoder2_output = shr_lay4(anatomy_output2)
63 | 
64 |         if conf.rounding:
65 |             encoder1_output = Rounding()(encoder1_output)
66 |             encoder2_output = Rounding()(encoder2_output)
67 | 
68 |         encoder1 = Model(inputs=encoder1.input, outputs=encoder1_output,
69 |                          name='Enc_Anatomy_%s' % self.modalities[0])
70 |         encoder2 = Model(inputs=encoder2.input, outputs=encoder2_output,
71 |                          name='Enc_Anatomy_%s' % self.modalities[1])
72 | 
73 |         return [encoder1, encoder2]
74 | 
75 |     def evaluate_decoder(self, conf, decoder_input, decoder_l3, decoder_l2, decoder_l1, decoder_l0):
76 |         l0_out = self.l0_6(self.l0_5(self.l0_4(self.l0_3(self.l0_2(self.l0_1(decoder_input))))))
77 | 
78 |         if conf.downsample > 3:
79 |             l3_out = self.l3(self.l2(self.l1(l0_out)))
80 |             l4_out = self.l4([l3_out, decoder_l3])  # concatenate with skip connection
81 | 
82 |             l10_out = self.l10(self.l9(self.l8(self.l7(self.l6(self.l5(l4_out))))))
83 |         else:
84 |             l10_out = l0_out
85 | 
86 |         l13_out = self.l13(self.l12(self.l11(l10_out)))
87 |         l14_out = self.l14([l13_out, decoder_l2])
88 | 
89 |         l20_out = self.l20(self.l19(self.l18(self.l17(self.l16(self.l15(l14_out))))))
90 |         l23_out = self.l23(self.l22(self.l21(l20_out)))
91 |         l24_out = self.l24([l23_out, decoder_l1])
92 | 
93 |         l30_out = self.l30(self.l29(self.l28(self.l27(self.l26(self.l25(l24_out))))))
94 |         l33_out = self.l33(self.l32(self.l31(l30_out)))
95 |         l34_out = self.l34([l33_out, decoder_l0])
96 | 
97 |         l40_out = self.l40(self.l39(self.l38(self.l37(self.l36(self.l35(l34_out))))))
98 |         return l40_out
99 | 
100 |     def build_decoder(self, conf):
101 |         f0 = conf.filters * 16 if conf.downsample > 3 else conf.filters * 8
102 |         self.l0_1 = Conv2D(f0, 3, padding='same', kernel_initializer='he_normal')
103 |         self.l0_2 = 
normalise(conf.normalise) 104 | self.l0_3 = Activation('relu') 105 | self.l0_4 = Conv2D(f0, 3, padding='same', kernel_initializer='he_normal') 106 | self.l0_5 = normalise(conf.normalise) 107 | self.l0_6 = Activation('relu') 108 | 109 | if conf.downsample > 3: 110 | self.l1 = UpSampling2D(size=2) 111 | self.l2 = Conv2D(conf.filters * 8, 3, padding='same', kernel_initializer='he_normal') 112 | self.l3 = normalise(conf.normalise) 113 | 114 | self.l4 = Concatenate() 115 | self.l5 = Conv2D(conf.filters * 8, 3, strides=1, padding='same', kernel_initializer='he_normal') 116 | self.l6 = normalise(conf.normalise) 117 | self.l7 = Activation('relu') 118 | self.l8 = Conv2D(conf.filters * 8, 3, strides=1, padding='same', kernel_initializer='he_normal') 119 | self.l9 = normalise(conf.normalise) 120 | self.l10 = Activation('relu') 121 | 122 | self.l11 = UpSampling2D(size=2) 123 | self.l12 = Conv2D(conf.filters * 4, 3, padding='same', kernel_initializer='he_normal') 124 | self.l13 = normalise(conf.normalise) 125 | 126 | self.l14 = Concatenate() 127 | self.l15 = Conv2D(conf.filters * 4, 3, strides=1, padding='same', kernel_initializer='he_normal') 128 | self.l16 = normalise(conf.normalise) 129 | self.l17 = Activation('relu') 130 | self.l18 = Conv2D(conf.filters * 4, 3, strides=1, padding='same', kernel_initializer='he_normal') 131 | self.l19 = normalise(conf.normalise) 132 | self.l20 = Activation('relu') 133 | 134 | self.l21 = UpSampling2D(size=2) 135 | self.l22 = Conv2D(conf.filters * 2, 3, padding='same', kernel_initializer='he_normal') 136 | self.l23 = normalise(conf.normalise) 137 | 138 | self.l24 = Concatenate() 139 | self.l25 = Conv2D(conf.filters * 2, 3, strides=1, padding='same', kernel_initializer='he_normal') 140 | self.l26 = normalise(conf.normalise) 141 | self.l27 = Activation('relu') 142 | self.l28 = Conv2D(conf.filters * 2, 3, strides=1, padding='same', kernel_initializer='he_normal') 143 | self.l29 = normalise(conf.normalise) 144 | self.l30 = Activation('relu') 145 | 146 | self.l31 = UpSampling2D(size=2) 147 | self.l32 = Conv2D(conf.filters, 3, padding='same', kernel_initializer='he_normal') 148 | self.l33 = normalise(conf.normalise) 149 | 150 | self.l34 = Concatenate() 151 | self.l35 = Conv2D(conf.filters, 3, strides=1, padding='same', kernel_initializer='he_normal') 152 | self.l36 = normalise(conf.normalise) 153 | self.l37 = Activation('relu') 154 | self.l38 = Conv2D(conf.filters, 3, strides=1, padding='same', kernel_initializer='he_normal') 155 | self.l39 = normalise(conf.normalise) 156 | self.l40 = Activation('relu') -------------------------------------------------------------------------------- /model_components/anatomy_fuser.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import keras.layers 4 | from keras import Input, Model 5 | 6 | from layers import stn_spline 7 | from layers.stn_spline import ThinPlateSpline2D 8 | 9 | log = logging.getLogger('anatomy_fuser') 10 | 11 | 12 | def build(conf): 13 | """ 14 | Build a model that deforms and fuses anatomies, used to combine multimodal information. 15 | Two anatomies are assumed: the first anatomy is deformed to match the second. 16 | Deformation model uses a STN. 
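    The locnet regresses 2 offsets for each control point of a 5x5 grid (the
    cp[0] * cp[1] * 2 = 50 values below), and ThinPlateSpline2D warps anatomy1
    with a thin-plate-spline interpolation of these control-point offsets.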
17 | :param conf: a configuration object 18 | """ 19 | anatomy1 = Input(conf.anatomy_encoder.output_shape) # anatomy from modality1 20 | anatomy2 = Input(conf.anatomy_encoder.output_shape) # anatomy from modality2 21 | 22 | output_shape = conf.anatomy_encoder.output_shape 23 | 24 | dims = conf.anatomy_encoder.output_shape[:-1] 25 | cp = [5, 5] 26 | channels = conf.anatomy_encoder.out_channels 27 | 28 | locnet = stn_spline.build_locnet(output_shape, output_shape, cp[0] * cp[1] * 2) 29 | theta = locnet([anatomy1, anatomy2]) 30 | anatomy1_deformed = ThinPlateSpline2D(dims, cp, channels)([anatomy1, theta]) 31 | 32 | # Fusion step 33 | anatomy_fused = keras.layers.Maximum()([anatomy1_deformed, anatomy2]) 34 | 35 | model = Model(inputs=[anatomy1, anatomy2], outputs=[anatomy1_deformed, anatomy_fused], name='Anatomy_Fuser') 36 | log.info('Anatomy fuser') 37 | model.summary(print_fn=log.info) 38 | return model 39 | -------------------------------------------------------------------------------- /model_components/balancer.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | import keras.backend as K 5 | import tensorflow as tf 6 | from keras import Input, Model 7 | from keras.layers import Concatenate, Dense, Lambda 8 | 9 | log = logging.getLogger('pair_selector') 10 | 11 | def build(conf): 12 | """ 13 | Build a model that predicts similarity weights between anatomies. These are based on the overlap of the anatomies 14 | calculated with Dice 15 | :param conf: a configuration object 16 | """ 17 | x1 = Input(shape=conf.anatomy_encoder.output_shape) 18 | x2 = Input(shape=conf.anatomy_encoder.output_shape) 19 | x3 = Input(shape=conf.anatomy_encoder.output_shape) 20 | x4 = Input(shape=conf.anatomy_encoder.output_shape) 21 | 22 | overlap = [Lambda(dice)([x1, x]) for x in [x2, x3, x4]] 23 | x = Concatenate()(overlap) 24 | l = Dense(5, activation='relu')(x) 25 | w = Dense(conf.n_pairs, name='beta')(l) 26 | w = Lambda(lambda x: tf.nn.softmax(x, dim=-1))(w) 27 | 28 | m = Model(inputs=[x1, x2, x3, x4], outputs=w, name='Balancer') 29 | log.info('Balancer') 30 | m.summary(print_fn=log.info) 31 | return m 32 | 33 | def dice(y): 34 | y_true, y_pred = y 35 | intersection = K.sum(y_true * y_pred, axis=(1, 2, 3)) 36 | union = K.sum(y_true, axis=(1, 2, 3)) + K.sum(y_pred, axis=(1, 2, 3)) 37 | dice = (2 * intersection + 1e-12) / (union + 1e-12) 38 | return K.expand_dims(dice, axis=1) -------------------------------------------------------------------------------- /model_components/decoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from keras import Input, Model 4 | from keras.layers import Dense, LeakyReLU, Reshape, Conv2D, Add, UpSampling2D 5 | 6 | from layers.film import FiLM 7 | from layers.spade import spade_block 8 | 9 | log = logging.getLogger('decoder') 10 | 11 | 12 | def build(conf): 13 | """ 14 | Build a decoder that generates an image by combining an anatomical and a modality 15 | representation. Two decoders are considered based on FiLM or SPADE conditioning. 
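    With FiLM, each residual block applies a per-channel affine transform
    FiLM(h) = gamma * h + beta, where gamma and beta are predicted from the
    modality vector by the dense layers in _gamma_beta_pred below.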
16 | :param conf: a configuration object 17 | """ 18 | anatomy_input = Input(shape=conf.anatomy_encoder.output_shape) 19 | modality_input = Input((conf.num_z,)) 20 | 21 | if conf.decoder_type == 'film': 22 | l = _film_decoder(anatomy_input, modality_input) 23 | elif conf.decoder_type == 'spade': 24 | l = _spade_decoder(conf, anatomy_input, modality_input) 25 | else: 26 | raise ValueError('Unknown decoder_type value: ' + str(conf.decoder_type)) 27 | 28 | l = Conv2D(1, 1, activation='tanh', padding='same', kernel_initializer='glorot_normal')(l) 29 | log.info('Decoder') 30 | 31 | model = Model(inputs=[anatomy_input, modality_input], outputs=l, name='Decoder') 32 | model.summary(print_fn=log.info) 33 | return model 34 | 35 | 36 | def _gamma_beta_pred(inp, num_chn): 37 | gamma = Dense(num_chn)(inp) 38 | gamma = LeakyReLU()(gamma) 39 | beta = Dense(num_chn)(inp) 40 | beta = LeakyReLU()(beta) 41 | return gamma, beta 42 | 43 | 44 | def _film_layer(anatomy_input, modality_input): 45 | l1 = Conv2D(8, 3, padding='same')(anatomy_input) 46 | l1 = LeakyReLU()(l1) 47 | 48 | l2 = Conv2D(8, 3, strides=1, padding='same')(l1) 49 | gamma_l2, beta_l2 = _gamma_beta_pred(modality_input, 8) 50 | l2 = FiLM()([l2, gamma_l2, beta_l2]) 51 | l2 = LeakyReLU()(l2) 52 | 53 | l = Add()([l1, l2]) 54 | return l 55 | 56 | 57 | def _film_decoder(anatomy_input, modality_input): 58 | l = Conv2D(8, 3, padding='same')(anatomy_input) 59 | l = LeakyReLU()(l) 60 | l1 = _film_layer(l, modality_input) 61 | l2 = _film_layer(l1, modality_input) 62 | l3 = _film_layer(l2, modality_input) 63 | l4 = _film_layer(l3, modality_input) 64 | return l4 65 | 66 | 67 | def _spade_decoder(conf, anatomy_input, modality_input): 68 | modality = Dense(conf.input_shape[0] * conf.input_shape[1] * 128 // 1024)(modality_input) 69 | l1 = Reshape((conf.input_shape[0] // 32, conf.input_shape[1] // 32, 128))(modality) 70 | l1 = spade_block(conf, anatomy_input, l1, 128, 128) 71 | l1 = UpSampling2D(size=2)(l1) 72 | l1 = spade_block(conf, anatomy_input, l1, 128, 128) 73 | l1 = UpSampling2D(size=2)(l1) 74 | l2 = spade_block(conf, anatomy_input, l1, 128, 128) 75 | l3 = UpSampling2D(size=2)(l2) 76 | l4 = spade_block(conf, anatomy_input, l3, 128, 64) 77 | l5 = UpSampling2D(size=2)(l4) 78 | l6 = spade_block(conf, anatomy_input, l5, 64, 32) 79 | l7 = UpSampling2D(size=2)(l6) 80 | l8 = spade_block(conf, anatomy_input, l7, 32, 16) 81 | return l8 82 | -------------------------------------------------------------------------------- /model_components/modality_encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from keras import Input, Model 4 | from keras.layers import Concatenate, Conv2D, LeakyReLU, Flatten, Dense, \ 5 | Lambda 6 | 7 | import costs 8 | from utils.sdnet_utils import sampling 9 | 10 | log = logging.getLogger('modality_encoder') 11 | 12 | 13 | def build(conf): 14 | """ 15 | Build an encoder to extract intensity information from the image. 
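    The encoder is variational: it predicts (z_mean, z_log_var) and draws z
    with the reparameterisation trick (in the usual formulation,
    z = z_mean + exp(z_log_var / 2) * eps with eps ~ N(0, I)), also returning
    the KL divergence from the N(0, I) prior as an output to be minimised.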
16 | :param conf: a configuration object 17 | """ 18 | anatomy = Input(conf.anatomy_encoder.output_shape) 19 | image = Input(conf.input_shape) 20 | 21 | z_mean, z_log_var = build_simple_encoder(conf, anatomy, image) 22 | 23 | # use reparameterization trick to push the sampling out as input 24 | # note that "output_shape" isn't necessary with the TensorFlow backend 25 | z = Lambda(sampling, name='z')([z_mean, z_log_var]) 26 | divergence = Lambda(costs.kl, name='divergence')([z_mean, z_log_var]) 27 | 28 | model = Model(inputs=[anatomy, image], outputs=[z, divergence], name='Enc_Modality') 29 | log.info('Enc_Modality') 30 | model.summary(print_fn=log.info) 31 | return model 32 | 33 | 34 | def build_simple_encoder(conf, anatomy, image): 35 | l = Concatenate(axis=-1)([anatomy, image]) 36 | l = Conv2D(16, 3, strides=2, kernel_initializer='he_normal')(l) 37 | l = LeakyReLU()(l) 38 | l = Conv2D(32, 3, strides=2, kernel_initializer='he_normal')(l) 39 | l = LeakyReLU()(l) 40 | l = Conv2D(64, 3, strides=2, kernel_initializer='he_normal')(l) 41 | l = LeakyReLU()(l) 42 | l = Conv2D(128, 3, strides=2, kernel_initializer='he_normal')(l) 43 | l = LeakyReLU()(l) 44 | 45 | l = Flatten()(l) 46 | l = Dense(32, kernel_initializer='he_normal')(l) 47 | l = LeakyReLU()(l) 48 | 49 | z_mean = Dense(conf.num_z, name='z_mean')(l) 50 | z_log_var = Dense(conf.num_z, name='z_log_var')(l) 51 | 52 | return z_mean, z_log_var 53 | -------------------------------------------------------------------------------- /model_components/segmentor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from keras import Input, Model 4 | from keras.layers import Conv2D, BatchNormalization, Activation 5 | 6 | log = logging.getLogger('segmentor') 7 | 8 | 9 | def build(conf): 10 | """ 11 | Build a segmentation network that converts anatomical maps to segmentation masks. 12 | :param conf: a configuration object 13 | """ 14 | inp = Input(conf.anatomy_encoder.output_shape) 15 | 16 | l = Conv2D(64, 3, padding='same', kernel_initializer='he_normal')(inp) 17 | l = BatchNormalization()(l) 18 | l = Activation('relu')(l) 19 | l = Conv2D(64, 3, padding='same', kernel_initializer='he_normal')(l) 20 | l = BatchNormalization()(l) 21 | l = Activation('relu')(l) 22 | 23 | # +1 output for background 24 | output = Conv2D(conf.num_masks + 1, 1, padding='same', activation='softmax')(l) 25 | 26 | model = Model(inputs=inp, outputs=output, name='Segmentor') 27 | log.info('Segmentor') 28 | model.summary(print_fn=log.info) 29 | return model 30 | -------------------------------------------------------------------------------- /model_executors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/model_executors/__init__.py -------------------------------------------------------------------------------- /model_executors/base_executor.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from abc import abstractmethod 4 | 5 | import numpy as np 6 | from keras.preprocessing.image import ImageDataGenerator 7 | 8 | from loaders import loader_factory 9 | from model_tester import ModelTester 10 | 11 | log = logging.getLogger('executor') 12 | 13 | 14 | class Executor(object): 15 | """ 16 | Base class for executor objects. 
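    Subclasses implement init_train_data(), get_loss_names() and train();
    test() then evaluates the trained model on the test split via ModelTester.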
17 | """ 18 | def __init__(self, conf, model): 19 | self.conf = conf 20 | self.model = model 21 | self.loader = loader_factory.init_loader(self.conf.dataset_name) 22 | self.batch = 0 23 | self.epoch = 0 24 | 25 | @abstractmethod 26 | def init_train_data(self): 27 | pass 28 | 29 | @abstractmethod 30 | def get_loss_names(self): 31 | pass 32 | 33 | @abstractmethod 34 | def train(self): 35 | pass 36 | 37 | def get_data_generator(self, train_images=None, train_labels=None): 38 | """ 39 | Create a data generator that also augments the data. 40 | :param train_images: input data 41 | :param train_labels: target data 42 | :return an iterator that gives a tuple of (input, output) data. 43 | """ 44 | image_dict = self.get_datagen_params() 45 | mask_dict = self.get_datagen_params() 46 | 47 | img_gens = [] 48 | if train_images is not None: 49 | if type(train_images) != list: 50 | train_images = [train_images] 51 | 52 | for img_array in train_images: 53 | img_gens.append(ImageDataGenerator(**image_dict).flow(img_array, batch_size=self.conf.batch_size, 54 | seed=self.conf.seed)) 55 | 56 | msk_gens = [] 57 | if train_labels is not None: 58 | if type(train_labels) != list: 59 | train_labels = [train_labels] 60 | 61 | for msk_array in train_labels: 62 | msk_gens.append(ImageDataGenerator(**mask_dict).flow(msk_array, batch_size=self.conf.batch_size, 63 | seed=self.conf.seed)) 64 | 65 | if len(img_gens) > 0 and len(msk_gens) > 0: 66 | all_data = img_gens + msk_gens 67 | gen = itertools.zip_longest(*all_data) 68 | return gen 69 | elif len(img_gens) > 0 and len(msk_gens) == 0: 70 | if len(img_gens) == 1: 71 | return img_gens[0] 72 | return itertools.zip_longest(*img_gens) 73 | elif len(img_gens) == 0 and len(msk_gens) > 0: 74 | if len(msk_gens) == 1: 75 | return msk_gens[0] 76 | return itertools.zip_longest(*msk_gens) 77 | else: 78 | raise Exception("No data to iterate.") 79 | 80 | def validate(self, epoch_loss): 81 | pass 82 | 83 | def add_residual(self, data): 84 | residual = np.ones(data.shape[:-1] + (1,)) 85 | for i in range(data.shape[-1]): 86 | residual[data[..., i:i+1] == 1] = 0 87 | return np.concatenate([data, residual], axis=-1) 88 | 89 | @abstractmethod 90 | def test(self): 91 | """ 92 | Evaluate a model on the test data. 93 | """ 94 | log.info('Evaluating model on test data') 95 | tester = ModelTester(self.model, self.conf) 96 | tester.run() 97 | 98 | def stop_criterion(self, es, logs): 99 | es.on_epoch_end(self.epoch, logs) 100 | if es.stopped_epoch > 0: 101 | return True 102 | 103 | def get_datagen_params(self): 104 | """ 105 | Construct a dictionary of augmentations. 106 | :return: a dictionary of augmentation parameters to use with a keras image processor 107 | """ 108 | d = dict(horizontal_flip=False, vertical_flip=False, rotation_range=20., 109 | width_shift_range=0, height_shift_range=0, zoom_range=0) 110 | return d 111 | 112 | def align_batches(self, array_list): 113 | """ 114 | Align the arrays of the input list, based on batch size. 115 | :param array_list: list of 4-d arrays to align 116 | """ 117 | mn = np.min([x.shape[0] for x in array_list]) 118 | new_list = [x[0:mn] + 0. 
for x in array_list] 119 | return new_list 120 | -------------------------------------------------------------------------------- /model_executors/mmsdnet_executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from keras.callbacks import CSVLogger, EarlyStopping 5 | from keras.utils import Progbar 6 | 7 | import costs 8 | import utils.data_utils 9 | from callbacks.dafnet_image_callback import DAFNetImageCallback 10 | from callbacks.loss_callback import SaveLoss 11 | from model_executors.base_executor import Executor 12 | from utils.distributions import NormalDistribution 13 | 14 | log = logging.getLogger('mmsdnet_executor') 15 | 16 | 17 | class MMSDNetExecutor(Executor): 18 | """ 19 | Train a DAFNet or MMSDNet model using parameters stored in the configuration. 20 | """ 21 | def __init__(self, conf, model): 22 | super(MMSDNetExecutor, self).__init__(conf, model) 23 | self.exp_clb = None 24 | self.loader.modalities = self.conf.modality 25 | 26 | self.gen_labelled = None # iterator for labelled data (supervised learning) 27 | self.gen_unlabelled = None # iterator for unlabelled data (unsupervised learning) 28 | self.discriminator_masks = None # iterator for real masks to train discriminators 29 | self.discriminator_image = None # iterator for images to train discriminators 30 | self.img_callback = None # callback to save images 31 | self.data = None # labelled data container of type MultimodalPairedData 32 | self.ul_data = None # unlabelled data container of type MultimodalPairedData 33 | 34 | self.gen_unlabelled_lge = None 35 | self.gen_unlabelled_cine = None 36 | 37 | self.img_callback = None 38 | self.conf.batches_lge = 0 39 | 40 | def init_train_data(self): 41 | self.gen_labelled = self._init_labelled_data_generator() 42 | self.gen_unlabelled = self._init_unlabelled_data_generator() 43 | self.discriminator_masks = self._init_disciminator_mask_generator() 44 | self.discriminator_image = [self._init_discriminator_image_generator(mod) 45 | for mod in self.model.modalities] 46 | 47 | self.batches = int(np.ceil(self.data_len / self.conf.batch_size)) 48 | 49 | def _init_labelled_data_generator(self): 50 | """ 51 | Initialise a data generator (image, mask) for labelled data 52 | """ 53 | if self.conf.l_mix == 0: 54 | return 55 | 56 | log.info('Initialising labelled datagen. 
Loading %s data' % self.conf.dataset_name) 57 | self.data = self.loader.load_all_modalities_concatenated(self.conf.split, 'training', self.conf.image_downsample) 58 | self.data.sample(int(np.round(self.conf.l_mix * self.data.num_volumes)), seed=self.conf.seed) 59 | log.info('labelled data size: ' + str(self.data.size())) 60 | self.data_len = self.data.size() 61 | 62 | return self.get_data_generator(train_images=[self.data.get_images_modi(i) for i in range(2)], 63 | train_labels=[self.data.get_masks_modi(i) for i in range(2)]) 64 | 65 | def _init_unlabelled_data_generator(self): 66 | """ 67 | Initialise a data generator (image) for unlabelled data 68 | """ 69 | if self.conf.l_mix == 1: 70 | return 71 | 72 | self.ul_data = self._load_unlabelled_data('training', 'ul', None) 73 | self.conf.unlabelled_image_num = self.ul_data.size() 74 | if self.data is None or self.ul_data.size() > self.data.size(): 75 | self.data_len = self.ul_data.size() 76 | 77 | return self.get_data_generator(train_images=[self.ul_data.get_images_modi(i) for i in range(2)], 78 | train_labels=[self.ul_data.get_masks_modi(0)]) 79 | 80 | def _load_unlabelled_data(self, split_type, data_type, modality): 81 | ''' 82 | Create a Data object with unlabelled data. This will be used to train the unlabelled path of the 83 | generators and produce fake masks for training the discriminator 84 | :param split_type: the split defining which volumes to load 85 | :param data_type: can be one ['ul', 'all']. The second includes images that have masks. 86 | :return: a data object 87 | ''' 88 | log.info('Initialising unlabelled datagen. Loading %s data of type %s' % (self.conf.dataset_name, data_type)) 89 | if data_type == 'ul': 90 | log.info('Estimating number of unlabelled images from %s data' % self.conf.dataset_name) 91 | ul_data = self.loader.load_all_modalities_concatenated(self.conf.split, split_type, 92 | self.conf.image_downsample) 93 | self.conf.num_ul_volumes = ul_data.num_volumes 94 | 95 | if self.conf.l_mix > 0: 96 | num_lb_vols = int(np.round(self.conf.l_mix * ul_data.num_volumes)) 97 | volumes = ul_data.get_sample_volumes(num_lb_vols, seed=self.conf.seed) 98 | ul_volumes = [v for v in ul_data.volumes() if v not in volumes] # ul volumes are the remaining from lbl 99 | ul_data.filter_volumes(ul_volumes) 100 | 101 | log.info('unlabelled data size: ' + str(ul_data.size())) 102 | elif data_type == 'all': 103 | ul_data = self.loader.load_all_data(self.conf.split, split_type, modality=modality, 104 | downsample=self.conf.image_downsample) 105 | else: 106 | raise Exception('Invalid data_type: %s' % str(data_type)) 107 | 108 | return ul_data 109 | 110 | def _init_disciminator_mask_generator(self): 111 | """ 112 | Init a generator for masks to use in the discriminator. 
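        The pool of real masks combines the labelled masks of both modalities
        with the modality-1 masks of the unlabelled split, as gathered in
        _load_discriminator_masks below.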
113 | """ 114 | log.info('Initialising discriminator maskgen.') 115 | masks = self._load_discriminator_masks() 116 | return self.get_data_generator(train_images=None, train_labels=[masks]) 117 | 118 | def _load_discriminator_masks(self): 119 | masks = [] 120 | if self.data is not None: 121 | masks.append(np.concatenate([self.data.get_masks_modi(0), self.data.get_masks_modi(1)], axis=0)) 122 | if self.ul_data is not None: 123 | masks.append(self.ul_data.get_masks_modi(0)) 124 | 125 | if len(masks) == 0: 126 | masks = np.empty(shape=([0] + self.conf.input_shape[:-1] + [self.loader.num_masks])) 127 | else: 128 | masks = np.concatenate(masks, axis=0) 129 | 130 | im_shape = self.conf.input_shape[:2] 131 | assert masks.shape[1] == im_shape[0] and masks.shape[2] == im_shape[1], masks.shape 132 | 133 | return masks 134 | 135 | def _init_discriminator_image_generator(self, modality): 136 | """ 137 | Init a generator for images to train a discriminator (for fake masks) 138 | """ 139 | log.info('Initialising discriminator imagegen.') 140 | data = self._load_unlabelled_data('training', 'all', modality) 141 | return self.get_data_generator(train_images=[data.images], train_labels=None) 142 | 143 | def init_image_callback(self): 144 | log.info('Initialising a data generator to use for printing.') 145 | 146 | if self.data is None: 147 | data = self.loader.load_all_modalities_concatenated(self.conf.split, 'training', self.conf.image_downsample) 148 | else: 149 | data = self.data 150 | 151 | gen = self.get_data_generator(train_images=[data.get_images_modi(i) for i in range(2)], 152 | train_labels=[data.get_masks_modi(i) for i in range(2)]) 153 | self.img_callback = DAFNetImageCallback(self.conf, self.model, gen) 154 | 155 | def get_loss_names(self): 156 | return ['adv_M', 'rec_X', 'dis_M', 'val_loss', 'val_loss_mod1', 'val_loss_mod2', 157 | 'val_loss_mod2_s1def', 'val_loss_mod2_fused', 'supervised_Mask', 'loss', 'KL', 'rec_Z'] 158 | 159 | def train(self): 160 | log.info('Training Model') 161 | 162 | self.init_train_data() 163 | 164 | self.init_image_callback() 165 | sl = SaveLoss(self.conf.folder) 166 | cl = CSVLogger(self.conf.folder + '/training.csv') 167 | cl.on_train_begin() 168 | 169 | es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60) 170 | es.model = self.model.Segmentor 171 | es.on_train_begin() 172 | 173 | loss_names = self.get_loss_names() 174 | total_loss = {n: [] for n in loss_names} 175 | 176 | progress_bar = Progbar(target=self.batches * self.conf.batch_size) 177 | for self.epoch in range(self.conf.epochs): 178 | log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs)) 179 | 180 | epoch_loss = {n: [] for n in loss_names} 181 | epoch_loss_list = [] 182 | 183 | for self.batch in range(self.batches): 184 | self.train_batch(epoch_loss) 185 | progress_bar.update((self.batch + 1) * self.conf.batch_size) 186 | 187 | self.validate(epoch_loss) 188 | 189 | for n in loss_names: 190 | epoch_loss_list.append((n, np.mean(epoch_loss[n]))) 191 | total_loss[n].append(np.mean(epoch_loss[n])) 192 | log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) % 193 | ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names))) 194 | logs = {l: total_loss[l][-1] for l in loss_names} 195 | 196 | cl.model = self.model.D_Mask 197 | cl.model.stop_training = False 198 | cl.on_epoch_end(self.epoch, logs) 199 | sl.on_epoch_end(self.epoch, logs) 200 | 201 | # Plot some example images 202 | self.img_callback.on_epoch_end(self.epoch) 203 | 204 | 
self.model.save_models() 205 | 206 | if self.stop_criterion(es, logs): 207 | log.info('Finished training from early stopping criterion') 208 | break 209 | 210 | def validate(self, epoch_loss): 211 | # Report validation error 212 | valid_data = self.loader.load_all_modalities_concatenated(self.conf.split, 'validation', self.conf.image_downsample) 213 | valid_data.crop(self.conf.input_shape[:2]) 214 | 215 | images0 = valid_data.get_images_modi(0) 216 | images1 = valid_data.get_images_modi(1) 217 | real_mask0 = valid_data.get_masks_modi(0) 218 | real_mask1 = valid_data.get_masks_modi(1) 219 | 220 | s1 = self.model.Encoders_Anatomy[0].predict(images0) 221 | s2 = self.model.Encoders_Anatomy[1].predict(images1) 222 | s1_deformed, s_fused = self.model.Anatomy_Fuser.predict([s1, s2]) 223 | mask1 = self.model.Segmentor.predict(s1) 224 | mask2 = self.model.Segmentor.predict(s2) 225 | mask3 = self.model.Segmentor.predict(s1_deformed) 226 | mask4 = self.model.Segmentor.predict(s_fused) 227 | 228 | l_mod1 = (1 - costs.dice(real_mask0, mask1, binarise=True)) 229 | l_mod2 = (1 - costs.dice(real_mask1, mask2, binarise=True)) 230 | l_mod2_s1def = (1 - costs.dice(real_mask1, mask3, binarise=True)) 231 | l_mod2_fused = (1 - costs.dice(real_mask1, mask4, binarise=True)) 232 | epoch_loss['val_loss_mod2'].append(l_mod2) 233 | epoch_loss['val_loss_mod2_s1def'].append(l_mod2_s1def) 234 | epoch_loss['val_loss_mod2_fused'].append(l_mod2_fused) 235 | epoch_loss['val_loss_mod1'].append(l_mod1) 236 | epoch_loss['val_loss'].append(np.mean([l_mod1, l_mod2, l_mod2_s1def, l_mod2_fused])) 237 | 238 | def train_batch(self, epoch_loss): 239 | self.train_batch_generators(epoch_loss) 240 | self.train_batch_mask_discriminator(epoch_loss) 241 | 242 | def train_batch_generators(self, epoch_loss): 243 | """ 244 | Train generator/segmentation networks. 245 | :param epoch_loss: Dictionary of losses for the epoch 246 | """ 247 | num_mod = len(self.model.modalities) 248 | 249 | if self.conf.l_mix > 0: 250 | x1, x2, m1, m2 = next(self.gen_labelled) 251 | [x1, x2, m1, m2] = self.align_batches([x1, x2, m1, m2]) 252 | batch_size = x1.shape[0] # maybe this differs from conf.batch_size at the last batch. 
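            # The target order mirrors the trainer outputs: ground-truth masks
            # for each segmentation output, all-ones maps as LS-GAN "real"
            # labels for each mask-discriminator output, the input images as
            # reconstruction targets, and zeros for the KL divergence outputs.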
253 | dm_shape = (batch_size,) + self.model.D_Mask.output_shape[1:] 254 | ones_m = np.ones(shape=dm_shape) 255 | 256 | # Train labelled path (supervised_model) 257 | all_outputs = [m1, m2, m2, m2, m1, m1] + \ 258 | [ones_m for _ in range(num_mod * 3)] + \ 259 | [x1, x2, x2, x2, x1, x1] + \ 260 | [np.zeros(batch_size) for _ in range(num_mod * 3)] 261 | h = self.model.supervised_trainer.fit([x1, x2], all_outputs, epochs=1, verbose=0) 262 | epoch_loss['supervised_Mask'].append(np.mean(h.history['Segmentor_loss'])) 263 | epoch_loss['adv_M'].append(np.mean(h.history['D_Mask_loss'])) 264 | epoch_loss['rec_X'].append(np.mean(h.history['Decoder_loss'])) 265 | epoch_loss['KL'].append(np.mean(h.history['Enc_Modality_loss'])) 266 | 267 | # Train Z Regressor 268 | norm = NormalDistribution() 269 | s_list = [self.model.Encoders_Anatomy[i].predict(x) for i, x in enumerate([x1, x2])] 270 | s1_def, s1_fused = self.model.Anatomy_Fuser.predict(s_list) 271 | s2_def, s2_fused = self.model.Anatomy_Fuser.predict(list(reversed(s_list))) 272 | s_list += [s1_def, s1_fused] 273 | s_list += [s2_def, s2_fused] 274 | z_list = [norm.sample((batch_size, self.conf.num_z)) for _ in range(num_mod * 3)] 275 | h = self.model.Z_Regressor.fit(s_list + z_list, z_list, epochs=1, verbose=0) 276 | epoch_loss['rec_Z'].append(np.mean(h.history['loss'])) 277 | 278 | # Train unlabelled path 279 | if self.conf.l_mix < 1: 280 | x1, x2, m1 = next(self.gen_unlabelled) 281 | [x1, x2, m1] = self.align_batches([x1, x2, m1]) 282 | batch_size = x1.shape[0] # maybe this differs from conf.batch_size at the last batch. 283 | dm_shape = (batch_size,) + self.model.D_Mask.output_shape[1:] 284 | ones_m = np.ones(shape=dm_shape) 285 | 286 | # Train unlabelled path (G_model) 287 | all_outputs = [m1, m1, m1] + \ 288 | [ones_m for _ in range(num_mod * 3)] + \ 289 | [x1, x2, x2, x2, x1, x1] + \ 290 | [np.zeros(batch_size) for _ in range(num_mod * 3)] 291 | h = self.model.unsupervised_trainer.fit([x1, x2], all_outputs, epochs=1, verbose=0) 292 | epoch_loss['supervised_Mask'].append(np.mean(h.history['Segmentor_loss'])) 293 | epoch_loss['adv_M'].append(np.mean(h.history['D_Mask_loss'])) 294 | epoch_loss['rec_X'].append(np.mean(h.history['Decoder_loss'])) 295 | epoch_loss['KL'].append(np.mean(h.history['Enc_Modality_loss'])) 296 | 297 | # Train Z Regressor 298 | norm = NormalDistribution() 299 | s_list = [self.model.Encoders_Anatomy[i].predict(x) for i, x in enumerate([x1, x2])] 300 | s1_def, s1_fused = self.model.Anatomy_Fuser.predict(s_list) 301 | s2_def, s2_fused = self.model.Anatomy_Fuser.predict(list(reversed(s_list))) 302 | s_list += [s1_def, s1_fused] 303 | s_list += [s2_def, s2_fused] 304 | z_list = [norm.sample((batch_size, self.conf.num_z)) for _ in range(num_mod * 3)] 305 | h = self.model.Z_Regressor.fit(s_list + z_list, z_list, epochs=1, verbose=0) 306 | epoch_loss['rec_Z'].append(np.mean(h.history['loss'])) 307 | 308 | def train_batch_mask_discriminator(self, epoch_loss): 309 | """ 310 | Jointly train a discriminator for images and masks. 311 | :param epoch_loss: Dictionary of losses for the epoch 312 | """ 313 | m = next(self.discriminator_masks) 314 | m = m[..., 0:self.conf.num_masks] 315 | x_list = self.align_batches([next(gen) for gen in self.discriminator_image]) 316 | x_0, m = self.align_batches([x_list[0], m]) 317 | x_list = self.align_batches([x_0] + x_list[1:]) 318 | batch_size = m.shape[0] # maybe this differs from conf.batch_size at the last batch. 
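        # Real masks get all-ones targets and generated masks all-zeros; with
        # the least-squares (LS-GAN) objective of the Discriminator class this
        # pushes D_Mask outputs towards 1 for real and 0 for fake. Fake masks
        # are pooled from both modalities and from the deformed/fused anatomies,
        # then subsampled back to batch_size below.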
319 | 320 | num_mod = len(self.model.modalities) 321 | fake_s_list = [self.model.Encoders_Anatomy[i].predict(x_list[i]) for i in range(num_mod)] 322 | fake_m_list = [self.model.Segmentor.predict(fake_s_list[i]) for i in range(num_mod)] 323 | s1_def, s1_fused = self.model.Anatomy_Fuser.predict(fake_s_list) 324 | fake_m_list += [self.model.Segmentor.predict(s) for s in [s1_def, s1_fused]] 325 | fake_m = np.concatenate(fake_m_list, axis=0)[..., 0:self.conf.num_masks] 326 | fake_m = utils.data_utils.sample(fake_m, batch_size) 327 | 328 | # Train Discriminator 329 | m_shape = (batch_size,) + self.model.D_Mask.get_output_shape_at(0)[1:] 330 | h = self.model.D_Mask_trainer.fit([m, fake_m], [np.ones(m_shape), np.zeros(m_shape)], epochs=1, verbose=0) 331 | epoch_loss['dis_M'].append(np.mean(h.history['D_Mask_loss'])) 332 | -------------------------------------------------------------------------------- /model_tester.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import numpy as np 5 | import scipy 6 | 7 | import costs 8 | from loaders import loader_factory 9 | 10 | log = logging.getLogger('model_tester') 11 | 12 | 13 | class ModelTester(object): 14 | 15 | def __init__(self, model, conf): 16 | self.model = model 17 | self.conf = conf 18 | 19 | def run(self): 20 | for modi, mod in enumerate(self.model.modalities): 21 | log.info('Evaluating model on test data for %s' % mod) 22 | self.test_modality(mod, modi) 23 | 24 | def make_test_folder(self, modality, suffix=''): 25 | folder = os.path.join(self.conf.folder, 'test_results_%s_%s_%s' % (self.conf.test_dataset, modality, suffix)) 26 | if not os.path.exists(folder): 27 | os.makedirs(folder) 28 | return folder 29 | 30 | def test_modality(self, modality, modality_index): 31 | """ 32 | Evaluate model on a given modality 33 | :param modality: the modality to load 34 | """ 35 | test_loader = loader_factory.init_loader(self.conf.test_dataset) 36 | test_loader.modalities = self.conf.modality 37 | test_data = test_loader.load_all_modalities_concatenated(self.conf.split, 'test', self.conf.image_downsample) 38 | test_data.crop(self.conf.input_shape[:2]) # crop data to input shape 39 | 40 | for type in ['simple', 'def', 'max']: 41 | folder = self.make_test_folder(modality, suffix=type) 42 | self.test_modality_type(folder, modality_index, type, test_loader, test_data) 43 | 44 | test_data.randomise_pairs(length=2, seed=self.conf.seed) 45 | for type in ['simple', 'def', 'max']: 46 | folder = self.make_test_folder(modality, suffix=type + '_rand') 47 | self.test_modality_type(folder, modality_index, type, test_loader, test_data) 48 | 49 | def test_modality_type(self, folder, modality_index, type, test_loader, test_data): 50 | assert type in ['simple', 'def', 'max', 'maxnostn'] 51 | 52 | samples = os.path.join(folder, 'samples') 53 | if not os.path.exists(samples): 54 | os.makedirs(samples) 55 | 56 | synth = [] 57 | im_dice = {} 58 | 59 | f = open(os.path.join(folder, 'results.csv'), 'w') 60 | f.writelines('Vol, Dice, ' + ', '.join(['Dice%d' % mi for mi in range(test_loader.num_masks)]) + '\n') 61 | for vol_i in test_data.volumes(): 62 | vol_folder = os.path.join(samples, 'vol_%s' % str(vol_i)) 63 | if not os.path.exists(vol_folder): 64 | os.makedirs(vol_folder) 65 | 66 | vol_image_mod1 = test_data.get_volume_images_modi(0, vol_i) 67 | vol_image_mod2 = test_data.get_volume_images_modi(1, vol_i) 68 | assert vol_image_mod1.shape[0] > 0 69 | 70 | vol_mask = 
test_data.get_volume_masks_modi(modality_index, vol_i) 71 | prd_mask = self.model.predict_mask(modality_index, type, [vol_image_mod1, vol_image_mod2]) 72 | 73 | synth.append(prd_mask) 74 | im_dice[vol_i] = costs.dice(vol_mask, prd_mask, binarise=True) 75 | sep_dice = [costs.dice(vol_mask[..., mi:mi + 1], prd_mask[..., mi:mi + 1], binarise=True) 76 | for mi in range(test_loader.num_masks)] 77 | 78 | s = '%s, %.3f, ' + ', '.join(['%.3f'] * test_loader.num_masks) + '\n' 79 | d = (str(vol_i), im_dice[vol_i]) + tuple(sep_dice) 80 | f.writelines(s % d) 81 | 82 | self.plot_images(samples, vol_i, modality_index, prd_mask, vol_mask, [vol_image_mod1, vol_image_mod2]) 83 | 84 | print('%s - Dice score: %.3f' % (type, np.mean(list(im_dice.values())))) 85 | f.close() 86 | 87 | def plot_images(self, samples, vol_i, modality_index, prd_mask, vol_mask, image_list): 88 | vol_image_mod2 = image_list[modality_index] 89 | 90 | for i in range(vol_image_mod2.shape[0]): 91 | vol_folder = os.path.join(samples, 'vol_%s' % str(vol_i)) 92 | if not os.path.exists(vol_folder): 93 | os.makedirs(vol_folder) 94 | 95 | row1 = [vol_image_mod2[i, :, :, 0]] + [prd_mask[i, :, :, j] for j in range(vol_mask.shape[-1])] 96 | row2 = [vol_image_mod2[i, :, :, 0]] + [vol_mask[i, :, :, j] for j in range(vol_mask.shape[-1])] 97 | 98 | row1 = np.concatenate(row1, axis=1) 99 | row2 = np.concatenate(row2, axis=1) 100 | im = np.concatenate([row1, row2], axis=0) 101 | 102 | scipy.misc.imsave(os.path.join(vol_folder, 'test_vol%s_im%d.png' % (str(vol_i), i)), im) 103 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/models/__init__.py -------------------------------------------------------------------------------- /models/basenet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from abc import abstractmethod 4 | 5 | from keras import Model, Input 6 | from keras.callbacks import EarlyStopping, CSVLogger 7 | from keras.layers import Lambda 8 | from keras.optimizers import Adam 9 | from keras.preprocessing.image import ImageDataGenerator 10 | 11 | from callbacks.image_callback import SaveImage 12 | from costs import make_dice_loss_fnc, make_combined_dice_bce, weighted_cross_entropy_loss 13 | from loaders import loader_factory 14 | 15 | log = logging.getLogger('basenet') 16 | 17 | 18 | class BaseNet(object): 19 | """ 20 | Base model for segmentation neural networks 21 | """ 22 | def __init__(self, conf): 23 | self.model = None 24 | self.conf = conf 25 | self.loader = None 26 | if hasattr(self.conf, 'dataset_name') and len(self.conf.dataset_name) > 0: 27 | self.loader = loader_factory.init_loader(self.conf.dataset_name) 28 | 29 | @abstractmethod 30 | def build(self): 31 | pass 32 | 33 | def load_models(self): 34 | pass 35 | 36 | @abstractmethod 37 | def get_segmentor(self, modality): 38 | pass 39 | -------------------------------------------------------------------------------- /models/dafnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | 4 | from keras import Input, Model 5 | from keras.layers import Lambda, Multiply, Add 6 | from keras.optimizers import Adam 7 | 8 | import costs 9 | from model_components import anatomy_fuser, modality_encoder, segmentor, 
decoder, balancer 10 | from model_components.anatomy_encoder import AnatomyEncoders 11 | from models.discriminator import Discriminator 12 | from models.mmsdnet import MMSDNet 13 | from utils.sdnet_utils import make_trainable 14 | 15 | log = logging.getLogger('dafnet') 16 | 17 | 18 | class DAFNet(MMSDNet): 19 | def __init__(self, conf): 20 | super(DAFNet, self).__init__(conf) 21 | 22 | self.D_Mask = None # Mask Discriminator 23 | self.D_Image1 = None # Image Discriminator for modality 1 24 | self.D_Image2 = None # Image Discriminator for modality 2 25 | self.Encoders_Anatomy = None # list of anatomy encoders for every modality 26 | self.Enc_Modality = None # Modality Encoder 27 | self.Enc_Modality_mu = None # The mean value of the Modality Encoder prediction 28 | self.Anatomy_Fuser = None # Anatomy Fuser that deforms and fused anatomies 29 | self.Segmentor = None # Segmentation network 30 | self.Decoder = None # Decoder network 31 | self.Balancer = None # Model that calculates weighs similarity of anatomies 32 | 33 | # Trainers 34 | self.D_Mask_trainer = None # Trainer for mask discriminator 35 | self.D_Image1_trainer = None # Trainer for image modality1 discriminator 36 | self.D_Image2_trainer = None # Trainer for image modality2 discriminator 37 | self.unsupervised_trainer = None # Trainer when having unlabelled data 38 | self.supervised_trainer = None # Trainer when using data with labels. 39 | self.Z_Regressor = None # Trainer for reconstructing a sampled Z 40 | 41 | def build(self): 42 | self.build_mask_discriminator() 43 | self.build_image_discriminator1() 44 | self.build_image_discriminator2() 45 | 46 | self.build_generators() 47 | try: 48 | self.load_models() 49 | except: 50 | log.warning('No models found') 51 | traceback.print_exc() 52 | pass 53 | 54 | def load_models(self): 55 | log.info('Loading trained models from file') 56 | 57 | model_folder = self.conf.folder + '/models/' 58 | self.D_Mask.load_weights(model_folder + '/D_Mask') 59 | self.D_Image1.load_weights(model_folder + '/D_Image1') 60 | self.D_Image2.load_weights(model_folder + '/D_Image2') 61 | 62 | self.Encoders_Anatomy[0].load_weights(model_folder + 'Enc_Anatomy1') 63 | self.Encoders_Anatomy[1].load_weights(model_folder + 'Enc_Anatomy2') 64 | self.Enc_Modality.load_weights(model_folder + 'Enc_Modality') 65 | self.Anatomy_Fuser.load_weights(model_folder + 'Anatomy_Fuser') 66 | self.Segmentor.load_weights(model_folder + 'Segmentor') 67 | self.Decoder.load_weights(model_folder + 'Decoder') 68 | try: 69 | self.Balancer.load_weights(model_folder + 'Balancer') 70 | except: 71 | pass 72 | 73 | self.build_trainers() 74 | 75 | def build_image_discriminator1(self): 76 | """ 77 | Build a discriminator for images 78 | """ 79 | params1 = self.conf.d_image_params 80 | params1['name'] = 'D_Image1' 81 | D = Discriminator(params1) 82 | D.build() 83 | log.info('Image Discriminator D_I') 84 | D.model.summary(print_fn=log.info) 85 | self.D_Image1 = D.model 86 | 87 | real_x = Input(self.conf.d_image_params.input_shape) 88 | fake_x = Input(self.conf.d_image_params.input_shape) 89 | real = self.D_Image1(real_x) 90 | fake = self.D_Image1(fake_x) 91 | 92 | self.D_Image1_trainer = Model([real_x, fake_x], [real, fake], name='D_Image1_trainer') 93 | self.D_Image1_trainer.compile(Adam(lr=self.conf.d_image_params.lr), loss='mse') 94 | self.D_Image1_trainer.summary(print_fn=log.info) 95 | 96 | def build_image_discriminator2(self): 97 | """ 98 | Build a discriminator for images 99 | """ 100 | params2 = self.conf.d_image_params 101 | params2['name'] 
= 'D_Image2' 102 | D = Discriminator(params2) 103 | D.build() 104 | log.info('Image Discriminator D_I2') 105 | D.model.summary(print_fn=log.info) 106 | self.D_Image2 = D.model 107 | 108 | real_x = Input(self.conf.d_image_params.input_shape) 109 | fake_x = Input(self.conf.d_image_params.input_shape) 110 | real = self.D_Image2(real_x) 111 | fake = self.D_Image2(fake_x) 112 | 113 | self.D_Image2_trainer = Model([real_x, fake_x], [real, fake], name='D_Image2_trainer') 114 | self.D_Image2_trainer.compile(Adam(lr=self.conf.d_image_params.lr), loss='mse') 115 | self.D_Image2_trainer.summary(print_fn=log.info) 116 | 117 | def build_generators(self): 118 | assert self.D_Mask is not None, 'Discriminator has not been built yet' 119 | make_trainable(self.D_Mask, False) 120 | make_trainable(self.D_Image1, False) 121 | make_trainable(self.D_Image2, False) 122 | 123 | self.Encoders_Anatomy = AnatomyEncoders(self.modalities).build(self.conf.anatomy_encoder) 124 | self.Anatomy_Fuser = anatomy_fuser.build(self.conf) 125 | self.Enc_Modality = modality_encoder.build(self.conf) 126 | self.Enc_Modality_mu = Model(self.Enc_Modality.inputs, self.Enc_Modality.get_layer('z_mean').output) 127 | self.Segmentor = segmentor.build(self.conf) 128 | self.Decoder = decoder.build(self.conf) 129 | self.Balancer = balancer.build(self.conf) 130 | 131 | self.build_trainers() 132 | 133 | def build_trainers(self): 134 | self.build_z_regressor() 135 | if not self.conf.automatedpairing: 136 | self.build_trainers_expertpairs() 137 | else: 138 | self.build_trainers_automatedpairs() 139 | 140 | def build_trainers_expertpairs(self): 141 | """ 142 | Build trainer models for unsupervised and supervised learning, when multimodal data are expertly paired. 143 | This assumes two modalities 144 | """ 145 | losses = {'Segmentor': costs.make_combined_dice_bce(self.loader.num_masks), 'D_Mask': 'mse', 'Decoder': 'mae', 146 | 'D_Image1': 'mse', 'D_Image2': 'mse', 'Enc_Modality': costs.ypred, 'ZReconstruct': 'mae'} 147 | loss_weights = {'Segmentor': self.conf.w_sup_M, 'D_Mask': self.conf.w_adv_M, 'Decoder': self.conf.w_rec_X, 148 | 'D_Image1': self.conf.w_adv_X, 'D_Image2': self.conf.w_adv_X, 'Enc_Modality': self.conf.w_kl, 149 | 'ZReconstruct': self.conf.w_rec_Z} 150 | 151 | all_inputs, all_outputs = self.get_params_expert_pairing(supervised=False) 152 | self.unsupervised_trainer = Model(inputs=all_inputs, outputs=all_outputs) 153 | log.info('Unsupervised model trainer') 154 | self.unsupervised_trainer.summary(print_fn=log.info) 155 | self.unsupervised_trainer.compile(Adam(self.conf.lr), loss=losses, loss_weights=loss_weights) 156 | 157 | all_inputs, all_outputs = self.get_params_expert_pairing(supervised=True) 158 | self.supervised_trainer = Model(inputs=all_inputs, outputs=all_outputs) 159 | log.info('Supervised model trainer') 160 | self.supervised_trainer.summary(print_fn=log.info) 161 | self.supervised_trainer.compile(Adam(self.conf.lr), loss=losses, loss_weights=loss_weights) 162 | 163 | def get_params_expert_pairing(self, supervised): 164 | """ 165 | Connect the DAFNet components for supervised or unsupervised training 166 | :return: a list of inputs and outputs 167 | """ 168 | # inputs 169 | x1 = Input(shape=self.conf.input_shape) 170 | x2 = Input(shape=self.conf.input_shape) 171 | 172 | # encode 173 | s1 = self.Encoders_Anatomy[0](x1) 174 | s2 = self.Encoders_Anatomy[1](x2) 175 | z1, kl1 = self.Enc_Modality([s1, x1]) 176 | z2, kl2 = self.Enc_Modality([s2, x2]) 177 | 178 | # segment 179 | m1 = self.Segmentor(s1) 180 | m2 = 
self.Segmentor(s2) 181 | 182 | # decoder 183 | y1 = self.Decoder([s1, z1]) 184 | y2 = self.Decoder([s2, z2]) 185 | 186 | # GANs 187 | adv_m1 = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m1)) 188 | adv_m2 = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m2)) 189 | adv_y1 = self.D_Image1(y1) 190 | adv_y2 = self.D_Image2(y2) 191 | 192 | # deform and fuse 193 | s1_def, _ = self.Anatomy_Fuser([s1, s2]) 194 | s2_def, _ = self.Anatomy_Fuser([s2, s1]) 195 | 196 | # segment 197 | m2_s1_def = self.Segmentor(s1_def) 198 | m1_s2_def = self.Segmentor(s2_def) 199 | 200 | # decoder (cross-reconstruction) 201 | y2_s1_def = self.Decoder([s1_def, z2]) 202 | y1_s2_def = self.Decoder([s2_def, z1]) 203 | 204 | # GANs 205 | adv_m2_s1_def = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m2_s1_def)) 206 | adv_m1_s2_def = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m1_s2_def)) 207 | adv_y2_s1_def = self.D_Image2(y2_s1_def) 208 | adv_y1_s2_def = self.D_Image1(y1_s2_def) 209 | 210 | # Z-Regressor 211 | z1_input = Input(shape=(self.conf.num_z,)) 212 | z2_input = Input(shape=(self.conf.num_z,)) 213 | [z1_rec, z2_rec] = self.Z_Regressor([s1, s2, z1_input, z2_input]) 214 | 215 | # inputs / outputs 216 | all_inputs = [x1, x2, z1_input, z2_input] 217 | all_outputs = [m1, m2, m1_s2_def, m2_s1_def] if supervised else [m1, m1_s2_def] 218 | all_outputs += [adv_m1, adv_m2, adv_m1_s2_def, adv_m2_s1_def] + \ 219 | [y1, y2, y1_s2_def, y2_s1_def] + \ 220 | [adv_y1, adv_y2, adv_y1_s2_def, adv_y2_s1_def] + \ 221 | [kl1, kl2, z1_rec, z2_rec] 222 | return all_inputs, all_outputs 223 | 224 | def build_trainers_automatedpairs(self): 225 | """ 226 | Build trainer models for unsupervised and supervised learning, when multimodal data are automatically paired. 
227 |         This assumes two modalities
228 |         """
229 |         losses = {'Segmentor': costs.make_combined_dice_bce(self.loader.num_masks), 'SegmentorDef': costs.ypred,
230 |                   'D_Mask': 'mse', 'Decoder': 'mae', 'DecoderDef': costs.ypred,
231 |                   'D_Image1': 'mse', 'D_Image2': 'mse', 'Enc_Modality': costs.ypred, 'ZReconstruct': 'mae'}
232 |         loss_weights = {'Segmentor': self.conf.w_sup_M, 'SegmentorDef': self.conf.w_sup_M, 'D_Mask': self.conf.w_adv_M,
233 |                         'Decoder': self.conf.w_rec_X, 'DecoderDef': self.conf.w_rec_X, 'D_Image1': self.conf.w_adv_X,
234 |                         'D_Image2': self.conf.w_adv_X, 'Enc_Modality': self.conf.w_kl, 'ZReconstruct': self.conf.w_rec_Z}
235 | 
236 |         all_inputs, all_outputs = self.get_params_automated_pairing(supervised=False)
237 |         self.unsupervised_trainer = Model(inputs=all_inputs, outputs=all_outputs)
238 |         log.info('Unsupervised model trainer')
239 |         self.unsupervised_trainer.summary(print_fn=log.info)
240 |         self.unsupervised_trainer.compile(Adam(self.conf.lr), loss=losses, loss_weights=loss_weights)
241 | 
242 |         all_inputs, all_outputs = self.get_params_automated_pairing(supervised=True)
243 |         self.supervised_trainer = Model(inputs=all_inputs, outputs=all_outputs)
244 |         log.info('Supervised model trainer')
245 |         self.supervised_trainer.summary(print_fn=log.info)
246 |         self.supervised_trainer.compile(Adam(self.conf.lr), loss=losses, loss_weights=loss_weights)
247 | 
248 |     def get_params_automated_pairing(self, supervised):
249 |         """
250 |         Connect the DAFNet components for supervised or unsupervised training with automatically weighted pairs
251 |         :return: a list of inputs and outputs
252 |         """
253 |         # inputs
254 |         x1_lst = [Input(shape=self.conf.input_shape) for _ in range(self.conf.n_pairs)]
255 |         x2_lst = [Input(shape=self.conf.input_shape) for _ in range(self.conf.n_pairs)]
256 |         m1_input = Input(shape=self.conf.input_shape[:-1] + [self.conf.num_masks + 1])
257 |         x1 = x1_lst[0]
258 |         x2 = x2_lst[0]
259 | 
260 |         # encode
261 |         s1_lst = [self.Encoders_Anatomy[0](x) for x in x1_lst]
262 |         s2_lst = [self.Encoders_Anatomy[1](x) for x in x2_lst]
263 |         s1 = s1_lst[0]
264 |         s2 = s2_lst[0]
265 |         z1, kl1 = self.Enc_Modality([s1, x1])
266 |         z2, kl2 = self.Enc_Modality([s2, x2])
267 | 
268 |         # segment
269 |         m1 = self.Segmentor(s1)
270 |         m2 = self.Segmentor(s2)
271 | 
272 |         # decode
273 |         y1 = self.Decoder([s1, z1])
274 |         y2 = self.Decoder([s2, z2])
275 | 
276 |         # GANs
277 |         adv_m1 = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m1))
278 |         adv_m2 = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m2))
279 |         adv_y1 = self.D_Image1(y1)
280 |         adv_y2 = self.D_Image2(y2)
281 | 
282 |         # deform and fuse
283 |         s1_def_lst = [self.Anatomy_Fuser([s1_i, s2])[0] for s1_i in s1_lst]
284 |         w1_def_lst = self.calculate_weights([s2] + s1_def_lst)
285 | 
286 |         s2_def_lst = [self.Anatomy_Fuser([s2_i, s1])[0] for s2_i in s2_lst]
287 |         w2_def_lst = self.calculate_weights([s1] + s2_def_lst)
288 | 
289 |         # decoder (cross-reconstruction)
290 |         DecoderLoss = Lambda(lambda x: costs.mae_single_input(x))
291 |         DecoderDef = Add(name='DecoderDef')
292 | 
293 |         y2_s1_def_lst = [self.Decoder([s1_def, z2]) for s1_def in s1_def_lst]
294 |         y1_s2_def_lst = [self.Decoder([s2_def, z1]) for s2_def in s2_def_lst]
295 |         y2_s1_def = DecoderDef([Multiply()([w, DecoderLoss([x2, y2_s1_def])])
296 |                                 for w, y2_s1_def in zip(w1_def_lst, y2_s1_def_lst)])
297 |         y1_s2_def = DecoderDef([Multiply()([w, DecoderLoss([x1, y1_s2_def])])
298 |                                 for w, y1_s2_def in zip(w2_def_lst, y1_s2_def_lst)])
299 | 
300 |         # segment
301 |         SegmentorDef = Add(name='SegmentorDef')
302 | 
SegmentorLoss = Lambda(lambda x: costs.make_combined_dice_bce_perbatch(self.loader.num_masks)(x[0], x[1])) 303 | 304 | m1_s2_def_lst = [self.Segmentor(s2_def) for s2_def in s2_def_lst] 305 | m1_s2_def = SegmentorDef([Multiply()([w, SegmentorLoss([m1_input, m1_s2_def])]) 306 | for w, m1_s2_def in zip(w2_def_lst, m1_s2_def_lst)]) 307 | m2_s1_def_lst = [self.Segmentor(s1_def) for s1_def in s1_def_lst] 308 | 309 | if supervised: 310 | m2_input = Input(shape=self.conf.input_shape[:-1] + [self.conf.num_masks + 1]) 311 | m2_s1_def = SegmentorDef([Multiply()([w, SegmentorLoss([m2_input, m2_s1_def])]) 312 | for w, m2_s1_def in zip(w1_def_lst, m2_s1_def_lst)]) 313 | 314 | # GANs 315 | adv_m2_s1_def = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m2_s1_def_lst[0])) 316 | adv_m1_s2_def = self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m1_s2_def_lst[0])) 317 | adv_y2_s1_def = self.D_Image2(y2_s1_def_lst[0]) 318 | adv_y1_s2_def = self.D_Image1(y1_s2_def_lst[0]) 319 | 320 | # Z-Regressor 321 | z1_input = Input(shape=(self.conf.num_z,)) 322 | z2_input = Input(shape=(self.conf.num_z,)) 323 | [z1_rec, z2_rec] = self.Z_Regressor([s1, s2, z1_input, z2_input]) 324 | 325 | # inputs / outputs 326 | all_inputs = x1_lst + x2_lst + [m1_input, m2_input, z1_input, z2_input] if supervised \ 327 | else x1_lst + x2_lst + [m1_input, z1_input, z2_input] 328 | all_outputs = [m1, m2, m1_s2_def, m2_s1_def] if supervised else [m1, m1_s2_def] 329 | all_outputs += [adv_m1, adv_m2, adv_m1_s2_def, adv_m2_s1_def] + \ 330 | [y1, y2, y1_s2_def, y2_s1_def] + \ 331 | [adv_y1, adv_y2, adv_y1_s2_def, adv_y2_s1_def] + \ 332 | [kl1, kl2, z1_rec, z2_rec] 333 | 334 | return all_inputs, all_outputs 335 | 336 | def build_z_regressor(self): 337 | """ 338 | Regress the modality factor. Assumes 4 inputs: 2 s-factors and 2 sampled z-factors for the 2 modalities 339 | """ 340 | num_inputs = 2 341 | 342 | s_lst = [Input(self.conf.anatomy_encoder.output_shape) for _ in range(num_inputs)] 343 | z_lst = [Input((self.conf.num_z,)) for _ in range(num_inputs)] 344 | y_lst = [self.Decoder([s, z]) for s, z in zip(s_lst, z_lst)] 345 | 346 | z_rec_lst = [self.Enc_Modality_mu([s, y]) for s, y in zip(s_lst, y_lst)] 347 | 348 | self.Z_Regressor = Model(inputs=s_lst + z_lst, outputs=z_rec_lst, name='ZReconstruct') 349 | self.Z_Regressor.compile(Adam(self.conf.lr), loss=['mae', 'mae'], 350 | loss_weights=[self.conf.w_rec_Z, self.conf.w_rec_Z]) 351 | 352 | def calculate_weights(self, inputs): 353 | s_mod2 = inputs[0] 354 | s_list = inputs[1:] 355 | 356 | if len(s_list) == 1: 357 | return None 358 | 359 | weights = self.Balancer([s_mod2] + s_list) # one weight per deformed anatomy, predicted by the Balancer 360 | weights = [Lambda(lambda x: x[..., j:j+1])(weights) for j in range(self.conf.n_pairs)] 361 | return weights 362 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | from keras import Input, Model 2 | from keras.layers import Conv2D, LeakyReLU, Flatten, Dense 3 | from keras.optimizers import Adam 4 | 5 | from layers.spectralnorm import Spectral 6 | from models.basenet import BaseNet 7 | 8 | 9 | class Discriminator(BaseNet): 10 | ''' 11 | DCGAN Discriminator with Spectral Norm and LS-GAN loss 12 | ''' 13 | def __init__(self, conf): 14 | super(Discriminator, self).__init__(conf) 15 | 16 | def build(self): 17 | inp_shape = self.conf.input_shape 18 | name = self.conf.name 19 | f = self.conf.filters 20 | downsample_blocks = 3 if not hasattr(self.conf, 'downsample_blocks') else
self.conf.downsample_blocks 21 | assert downsample_blocks > 1, downsample_blocks 22 | 23 | d_input = Input(inp_shape) 24 | l = Conv2D(f, 4, strides=2, kernel_initializer="he_normal")(d_input) 25 | l = LeakyReLU(0.2)(l) 26 | 27 | for i in range(downsample_blocks): 28 | s = 1 if i == downsample_blocks - 1 else 2 29 | spectral_params = f * (2 ** i) 30 | l = self._downsample_block(l, f * 2 * (2 ** i), s, spectral_params) 31 | 32 | l = Flatten()(l) 33 | l = Dense(1, activation="linear")(l) 34 | 35 | self.model = Model(d_input, l, name=name) 36 | return self.model 37 | 38 | def _downsample_block(self, l0, f, stride, spectral_params, name=''): 39 | l = Conv2D(f, 4, strides=stride, kernel_initializer="he_normal", 40 | kernel_regularizer=Spectral(spectral_params * 4 * 4, 10.), name=name)(l0) 41 | return LeakyReLU(0.2)(l) 42 | 43 | def compile(self): 44 | assert self.model is not None, 'Model has not been built' 45 | self.model.compile(optimizer=Adam(lr=self.conf.lr), loss='mse') 46 | -------------------------------------------------------------------------------- /models/mmsdnet.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import numpy as np 5 | from keras import Input, Model 6 | from keras.layers import Lambda 7 | from keras.optimizers import Adam 8 | 9 | import costs 10 | from model_components import anatomy_encoder, anatomy_fuser, modality_encoder, segmentor, decoder 11 | from models.basenet import BaseNet 12 | from models.discriminator import Discriminator 13 | from utils.sdnet_utils import make_trainable, get_net 14 | 15 | log = logging.getLogger('mmsdnet') 16 | 17 | 18 | class MMSDNet(BaseNet): 19 | def __init__(self, conf): 20 | super(MMSDNet, self).__init__(conf) 21 | 22 | self.modalities = conf.modality # list of input modalities 23 | 24 | self.D_Mask = None # Mask Discriminator 25 | self.Encoders_Anatomy = None # list of anatomy encoders for every modality 26 | self.Enc_Modality = None # Modality Encoder 27 | self.Enc_Modality_mu = None # The mean value of the Modality Encoder prediction 28 | self.Anatomy_Fuser = None # Anatomy Fuser that deforms and fuses anatomies 29 | self.Segmentor = None # Segmentation network 30 | self.Decoder = None # Decoder network 31 | 32 | self.D_Mask_trainer = None # Trainer for mask discriminator 33 | self.unsupervised_trainer = None # Trainer when having unlabelled data 34 | self.supervised_trainer = None # Trainer when using data with labels.
35 | self.Z_Regressor = None # Trainer for reconstructing a sampled Z 36 | 37 | def build(self): 38 | self.build_mask_discriminator() 39 | self.build_generators() 40 | self.load_models() 41 | 42 | def load_models(self): 43 | if os.path.exists(self.conf.folder + '/supervised_trainer'): 44 | log.info('Loading trained models from file') 45 | 46 | self.supervised_trainer.load_weights(self.conf.folder + '/supervised_trainer') 47 | 48 | self.Encoders_Anatomy = [get_net(self.supervised_trainer, 'Enc_Anatomy_%s' % mod) for mod in self.modalities] 49 | 50 | self.Enc_Modality = get_net(self.supervised_trainer, 'Enc_Modality') 51 | self.Enc_Modality_mu = Model(self.Enc_Modality.inputs, self.Enc_Modality.get_layer('z_mean').output) 52 | self.Anatomy_Fuser = get_net(self.supervised_trainer, 'Anatomy_Fuser') 53 | self.Segmentor = get_net(self.supervised_trainer, 'Segmentor') 54 | self.Decoder = get_net(self.supervised_trainer, 'Decoder') 55 | self.D_Mask = get_net(self.supervised_trainer, 'D_Mask') 56 | self.build_z_regressor() 57 | 58 | def save_models(self): 59 | log.debug('Saving trained models') 60 | self.supervised_trainer.save_weights(self.conf.folder + '/supervised_trainer') 61 | 62 | def build_mask_discriminator(self): 63 | # Build a discriminator for masks. 64 | D = Discriminator(self.conf.d_mask_params) 65 | D.build() 66 | log.info('Mask Discriminator D_M') 67 | D.model.summary(print_fn=log.info) 68 | self.D_Mask = D.model 69 | 70 | real_M = Input(self.conf.d_mask_params.input_shape) 71 | fake_M = Input(self.conf.d_mask_params.input_shape) 72 | real = self.D_Mask(real_M) 73 | fake = self.D_Mask(fake_M) 74 | 75 | self.D_Mask_trainer = Model([real_M, fake_M], [real, fake], name='D_Mask_trainer') 76 | self.D_Mask_trainer.compile(Adam(lr=self.conf.d_mask_params.lr), loss='mse') 77 | self.D_Mask_trainer.summary(print_fn=log.info) 78 | 79 | def build_generators(self): 80 | assert self.D_Mask is not None, 'Discriminator has not been built yet' 81 | make_trainable(self.D_Mask_trainer, False) 82 | 83 | self.Encoders_Anatomy = [anatomy_encoder.build(self.conf.anatomy_encoder, 'Enc_Anatomy_%s' % mod) 84 | for mod in self.modalities] 85 | self.Anatomy_Fuser = anatomy_fuser.build(self.conf) 86 | self.Enc_Modality = modality_encoder.build(self.conf) 87 | self.Enc_Modality_mu = Model(self.Enc_Modality.inputs, self.Enc_Modality.get_layer('z_mean').output) 88 | self.Segmentor = segmentor.build(self.conf) 89 | self.Decoder = decoder.build(self.conf) 90 | 91 | self.build_unsupervised_trainer() # build standard GAN for data with no labels 92 | self.build_supervised_trainer() 93 | self.build_z_regressor() 94 | 95 | def build_unsupervised_trainer(self): 96 | # Model for unsupervised training 97 | 98 | # inputs 99 | x_list = [Input(shape=self.conf.input_shape) for _ in self.modalities] 100 | num_mod = len(self.modalities) 101 | 102 | # encode: x -> s and (s, x) -> z; segment: s -> m; reconstruct: (s, z) -> y 103 | s_list = [self.Encoders_Anatomy[i](x_list[i]) for i in range(num_mod)] 104 | z_list = [self.Enc_Modality([s_list[i], x_list[i]]) for i in range(num_mod)] 105 | m1, m2 = self.Segmentor(s_list[0]), self.Segmentor(s_list[1]) 106 | 107 | m_list = [m1] 108 | adv_m_list = [self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m)) for m in [m1, m2]] 109 | rec_x_list = [self.Decoder([s_list[i], z_list[i][0]]) for i in range(num_mod)] 110 | 111 | # segment deformed and fused 112 | s1_def, s1_fused = self.Anatomy_Fuser(s_list) 113 | s2_def, s2_fused = self.Anatomy_Fuser(list(reversed(s_list))) 114 | 115 | fused_segmentations = [self.Segmentor(s) for s in [s1_def,
s1_fused, s2_def, s2_fused]] 116 | m_list += fused_segmentations[2:] # there are masks only for modality1 117 | adv_m_list += [self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(s)) for s in fused_segmentations] 118 | 119 | # reconstruct deformed and fused 120 | z_list_s1def = [self.Enc_Modality([s, x_list[1]]) for s in [s1_def, s1_fused]] 121 | rec_x_list += [self.Decoder([s, z_list_s1def[i][0]]) for i, s in enumerate([s1_def, s1_fused])] 122 | 123 | z_list_s2def = [self.Enc_Modality([s, x_list[0]]) for s in [s2_def, s2_fused]] 124 | rec_x_list += [self.Decoder([s, z_list_s2def[i][0]]) for i, s in enumerate([s2_def, s2_fused])] 125 | 126 | # list of KL divergences for every modality 127 | diverg_list = [z_list[i][1] for i in range(num_mod)] 128 | diverg_list += [z_list_s1def[i][1] for i in range(num_mod)] 129 | diverg_list += [z_list_s2def[i][1] for i in range(num_mod)] 130 | 131 | all_outputs = m_list + adv_m_list + rec_x_list + diverg_list 132 | self.unsupervised_trainer = Model(inputs=x_list, outputs=all_outputs) 133 | log.info('Unsupervised trainer') 134 | self.unsupervised_trainer.summary(print_fn=log.info) 135 | 136 | loss_list = [costs.make_dice_loss_fnc(self.loader.num_masks) for _ in range(3)] + \ 137 | ['mse'] * (num_mod * 3) + \ 138 | ['mae'] * (num_mod * 3) + \ 139 | [costs.ypred for _ in range(num_mod * 3)] 140 | weights_list = [self.conf.w_sup_M for _ in range(3)] + \ 141 | [self.conf.w_adv_M for _ in range(num_mod * 3)] + \ 142 | [self.conf.w_rec_X for _ in range(num_mod * 3)] + \ 143 | [self.conf.w_kl for _ in range(num_mod * 3)] 144 | self.unsupervised_trainer.compile(Adam(self.conf.lr), loss=loss_list, loss_weights=weights_list) 145 | 146 | def build_supervised_trainer(self): 147 | # Model for supervised training 148 | 149 | # inputs 150 | x_list = [Input(shape=self.conf.input_shape) for _ in self.modalities] 151 | num_mod = len(self.modalities) 152 | 153 | s_list = [self.Encoders_Anatomy[i](x_list[i]) for i in range(num_mod)] 154 | z_list = [self.Enc_Modality([s_list[i], x_list[i]]) for i in range(num_mod)] 155 | m_list = [self.Segmentor(s) for s in s_list] 156 | adv_m_list = [self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(m)) for m in m_list] 157 | rec_x_list = [self.Decoder([s_list[i], z_list[i][0]]) for i in range(num_mod)] 158 | 159 | # segment deformed and fused 160 | s1_def, s1_fused = self.Anatomy_Fuser(s_list) 161 | s2_def, s2_fused = self.Anatomy_Fuser(list(reversed(s_list))) 162 | 163 | fused_segmentations = [self.Segmentor(s) for s in [s1_def, s1_fused, s2_def, s2_fused]] 164 | m_list += fused_segmentations 165 | adv_m_list += [self.D_Mask(Lambda(lambda x: x[..., 0:self.conf.num_masks])(s)) for s in fused_segmentations] 166 | 167 | # reconstruct deformed and fused 168 | z_list_s1def = [self.Enc_Modality([s, x_list[1]]) for s in [s1_def, s1_fused]] 169 | rec_x_list += [self.Decoder([s, z_list_s1def[i][0]]) for i, s in enumerate([s1_def, s1_fused])] 170 | 171 | z_list_s2def = [self.Enc_Modality([s, x_list[0]]) for s in [s2_def, s2_fused]] 172 | rec_x_list += [self.Decoder([s, z_list_s2def[i][0]]) for i, s in enumerate([s2_def, s2_fused])] 173 | 174 | # list of KL divergences for every modality 175 | diverg_list = [z_list[i][1] for i in range(num_mod)] 176 | diverg_list += [z_list_s1def[i][1] for i in range(num_mod)] 177 | diverg_list += [z_list_s2def[i][1] for i in range(num_mod)] 178 | 179 | all_outputs = m_list + adv_m_list + rec_x_list + diverg_list 180 | self.supervised_trainer = Model(inputs=x_list, outputs=all_outputs) 181 |
log.info('Supervised trainer') 182 | self.supervised_trainer.summary(print_fn=log.info) 183 | 184 | loss_list = [costs.make_dice_loss_fnc(self.loader.num_masks) for _ in range(num_mod * 3)] + \ 185 | ['mse'] * (num_mod * 3) + \ 186 | ['mae'] * (num_mod * 3) + \ 187 | [costs.ypred for _ in range(num_mod * 3)] 188 | weights_list = [self.conf.w_sup_M for _ in range(num_mod * 3)] + \ 189 | [self.conf.w_adv_M for _ in range(num_mod * 3)] + \ 190 | [self.conf.w_rec_X for _ in range(num_mod * 3)] + \ 191 | [self.conf.w_kl for _ in range(num_mod * 3)] 192 | self.supervised_trainer.compile(Adam(self.conf.lr), loss=loss_list, loss_weights=weights_list) 193 | 194 | def build_z_regressor(self): 195 | num_inputs = len(self.modalities) + 4 # 2 encoded anatomies plus the 4 deformed/fused ones 196 | s_list = [Input(self.conf.anatomy_encoder.output_shape) for _ in range(num_inputs)] 197 | sample_z_list = [Input((self.conf.num_z,)) for _ in range(num_inputs)] 198 | sample_x_list = [self.Decoder([s_list[i], sample_z_list[i]]) for i in range(num_inputs)] 199 | 200 | rec_Z_list = [self.Enc_Modality_mu([s_list[i], sample_x_list[i]]) for i in range(num_inputs)] 201 | 202 | all_inputs = s_list + sample_z_list 203 | self.Z_Regressor = Model(inputs=all_inputs, outputs=rec_Z_list, name='ZReconstruct') 204 | log.info('Z Regressor') 205 | self.Z_Regressor.summary(print_fn=log.info) 206 | losses = ['mae'] * num_inputs 207 | weights = [self.conf.w_rec_Z for _ in range(num_inputs)] 208 | self.Z_Regressor.compile(Adam(self.conf.lr), loss=losses, loss_weights=weights) 209 | 210 | def predict_mask(self, modality_index, type, image_list): 211 | assert type in ['simple', 'def', 'max', 'maxnostn'] # direct, deformed, fused, or channel-wise max without the STN 212 | 213 | idx2 = modality_index 214 | idx1 = 1 - idx2 215 | 216 | images_mod1 = image_list[idx1] 217 | images_mod2 = image_list[idx2] 218 | 219 | s1 = self.Encoders_Anatomy[idx1].predict(images_mod1) 220 | s2 = self.Encoders_Anatomy[idx2].predict(images_mod2) 221 | 222 | if type == 'simple': 223 | return self.Segmentor.predict(s2) 224 | elif type == 'def': 225 | return self.Segmentor.predict(self.Anatomy_Fuser.predict([s1, s2])[0]) 226 | elif type == 'max': 227 | return self.Segmentor.predict(self.Anatomy_Fuser.predict([s1, s2])[1]) 228 | elif type == 'maxnostn': 229 | s_max_nostn = np.max([s1, s2], axis=0) 230 | return self.Segmentor.predict(s_max_nostn) 231 | 232 | raise ValueError(type) -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | # UNet implementation with 4 downsampling and 4 upsampling blocks. 2 | # Each block has 2 convolutions, batch normalisation and relu, with skip connections between corresponding blocks. 3 | # The number of filters for the 1st layer is 64 and this is doubled at every downsampling block. Each upsampling 4 | # block halves the number of filters. 5 | 6 | 7 | from keras import Input, Model 8 | from keras.layers import Concatenate, Conv2D, MaxPooling2D, Activation 9 | 10 | from models.basenet import BaseNet 11 | from utils.model_utils import upsample_block, normalise 12 | import logging 13 | log = logging.getLogger('unet') 14 | 15 | 16 | class UNet(BaseNet): 17 | def __init__(self, conf): 18 | super(UNet, self).__init__(conf) 19 | self.input_shape = conf.input_shape 20 | self.out_channels = conf.out_channels 21 | 22 | self.normalise = conf.normalise 23 | self.f = conf.filters 24 | self.downsample = conf.downsample 25 | assert self.downsample > 0, 'UNet downsample must be greater than 0.'
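# Shape sketch (illustrative, assuming the defaults in the header comment, i.e. f = 64 and downsample = 4):
# the encoder blocks use 64, 128, 256 and 512 filters, the bottleneck uses 1024, and the decoder mirrors
# these counts back down to 64 before the final 1x1 output convolution.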
26 | 27 | def build(self): 28 | self.input = Input(shape=self.input_shape) 29 | l = self.unet_downsample(self.input, self.normalise) 30 | self.unet_bottleneck(l, self.normalise) 31 | l = self.unet_upsample(self.bottleneck, self.normalise) 32 | out = self.out(l) 33 | self.model = Model(inputs=self.input, outputs=out) 34 | self.model.summary(print_fn=log.info) 35 | self.load_models() 36 | 37 | def unet_downsample(self, inp, normalise): 38 | self.d_l0 = conv_block(inp, self.f, normalise) 39 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l0) 40 | 41 | if self.downsample > 1: 42 | self.d_l1 = conv_block(l, self.f * 2, normalise) 43 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l1) 44 | 45 | if self.downsample > 2: 46 | self.d_l2 = conv_block(l, self.f * 4, normalise) 47 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l2) 48 | 49 | if self.downsample > 3: 50 | self.d_l3 = conv_block(l, self.f * 8, normalise) 51 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l3) 52 | return l 53 | 54 | def unet_bottleneck(self, l, normalise, name=''): 55 | flt = self.f * 2 56 | if self.downsample > 1: 57 | flt *= 2 58 | if self.downsample > 2: 59 | flt *= 2 60 | if self.downsample > 3: 61 | flt *= 2 62 | self.bottleneck = conv_block(l, flt, normalise, name) 63 | return self.bottleneck 64 | 65 | def unet_upsample(self, l, normalise): 66 | if self.downsample > 3: 67 | l = upsample_block(l, self.f * 8, normalise, activation='linear') 68 | l = Concatenate()([l, self.d_l3]) 69 | l = conv_block(l, self.f * 8, normalise) 70 | 71 | if self.downsample > 2: 72 | l = upsample_block(l, self.f * 4, normalise, activation='linear') 73 | l = Concatenate()([l, self.d_l2]) 74 | l = conv_block(l, self.f * 4, normalise) 75 | 76 | if self.downsample > 1: 77 | l = upsample_block(l, self.f * 2, normalise, activation='linear') 78 | l = Concatenate()([l, self.d_l1]) 79 | l = conv_block(l, self.f * 2, normalise) 80 | 81 | if self.downsample > 0: 82 | l = upsample_block(l, self.f, normalise, activation='linear') 83 | l = Concatenate()([l, self.d_l0]) 84 | l = conv_block(l, self.f, normalise) 85 | 86 | return l 87 | 88 | def out(self, l, out_activ=None): 89 | if out_activ is None: 90 | out_activ = 'sigmoid' if self.out_channels == 1 else 'softmax' 91 | return Conv2D(self.out_channels, 1, padding='same', activation=out_activ)(l) 92 | 93 | 94 | def conv_block(l0, f, norm_name, name=''): 95 | l = Conv2D(f, 3, strides=1, padding='same', kernel_initializer='he_normal')(l0) 96 | l = normalise(norm_name)(l) 97 | l = Activation('relu')(l) 98 | 99 | l = Conv2D(f, 3, strides=1, padding='same', kernel_initializer='he_normal')(l) 100 | l = normalise(norm_name)(l) 101 | return Activation('relu', name=name)(l) 102 | -------------------------------------------------------------------------------- /pseudocode.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/pseudocode.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.2.2 2 | albumentations==0.3.1 3 | astor==0.6.2 4 | attrs==19.1.0 5 | backcall==0.1.0 6 | backports.weakref==1.0rc1 7 | bleach==1.5.0 8 | certifi==2019.3.9 9 | chardet==3.0.4 10 | comet-git-pure==0.19.11 11 | configobj==5.0.6 12 | cycler==0.10.0 13 | decorator==4.1.2 14 | dicom==0.9.9.post1 15 | easydict==1.7 16 | enum34==1.1.6 17 | everett==1.0.2 18 | gast==0.2.0 19 | 
gitdb2==2.0.5 20 | GitPython==2.1.11 21 | grpcio==1.12.1 22 | h5py==2.8.0 23 | html5lib==0.9999999 24 | idna==2.8 25 | imageio==2.8.0 26 | imgaug==0.2.6 27 | ipython==6.4.0 28 | ipython-genutils==0.2.0 29 | jedi==0.12.0 30 | jsonschema==3.0.1 31 | Keras==2.1.6 32 | Keras-Applications==1.0.2 33 | keras-contrib==2.0.8 34 | Keras-Preprocessing==1.0.1 35 | Markdown==2.6.9 36 | matplotlib==2.0.2 37 | netifaces==0.10.9 38 | networkx==1.11 39 | nibabel==2.3.0 40 | numexpr==2.6.2 41 | numpy==1.14.4 42 | nvidia-ml-py3==7.352.0 43 | olefile==0.44 44 | opencv-python-headless==4.1.0.25 45 | parso==0.2.1 46 | pexpect==4.6.0 47 | pickleshare==0.7.4 48 | Pillow==4.2.1 49 | prompt-toolkit==1.0.15 50 | protobuf==3.4.0 51 | ptyprocess==0.5.2 52 | pydicom==1.2.2 53 | Pygments==2.2.0 54 | pyparsing==2.2.0 55 | pyrsistent==0.14.11 56 | python-dateutil==2.6.1 57 | pytz==2017.2 58 | PyWavelets==0.5.2 59 | PyYAML==3.12 60 | requests==2.21.0 61 | scikit-image==0.13.0 62 | scikit-learn==0.19.0 63 | scipy==0.19.1 64 | simplegeneric==0.8.1 65 | six==1.12.0 66 | smmap2==2.0.5 67 | tensorboard==1.7.0 68 | tensorflow==1.3.0 69 | tensorflow-gpu==1.4.0 70 | tensorflow-tensorboard==0.4.0 71 | termcolor==1.1.0 72 | traitlets==4.3.2 73 | urllib3==1.24.1 74 | wcwidth==0.1.7 75 | websocket-client==0.55.0 76 | Werkzeug==0.12.2 77 | wurlitzer==1.0.2 78 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/multimodal_segmentation/a4fa1b39830f6c1bc320ff5b5e3fda82b8382e18/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import logging 4 | log = logging.getLogger('data_utils') 5 | 6 | 7 | def rescale(array, min_value=-1, max_value=1): 8 | """ 9 | Rescales the input image between the min and max value. 10 | :param array: a 4D array 11 | :param min_value: the minimum value 12 | :param max_value: the maximum value 13 | :return: the rescaled array 14 | """ 15 | if array.max() == array.min(): 16 | array = (array * 0) + min_value 17 | return array 18 | array = (max_value - min_value) * (array - float(array.min())) / (array.max() - array.min()) + min_value 19 | assert array.max() == max_value and array.min() == min_value, '%d, %d' % (array.max(), array.min()) 20 | return array 21 | 22 | def normalise(image): 23 | """ 24 | Normalise an image using the median and inter-quartile distance. 25 | :param image: a 4D array 26 | :return: the normalised image 27 | """ 28 | array = image.copy() 29 | m = np.percentile(array, 50) 30 | s = np.percentile(array, 75) - np.percentile(array, 25) 31 | array = np.divide((array - m), s + 1e-12) 32 | 33 | assert not np.any(np.isnan(array)), 'NaN values in normalised array' 34 | return array 35 | 36 | 37 | def crop_same(image_list, mask_list, size=(None, None), mode='equal', pad_mode='edge'): 38 | ''' 39 | Crop the data in the image and mask lists, so that they have the same size. 40 | :param image_list: a list of images. Each element should be 4-dimensional, (sl,h,w,chn) 41 | :param mask_list: a list of masks. Each element should be 4-dimensional, (sl,h,w,chn) 42 | :param size: dimensions to crop the images to. 43 | :param mode: can be one of [equal, left, right]. Denotes where to crop pixels from. Defaults to middle. 
44 | :param pad_mode: can be one of ['edge', 'constant']. 'edge' pads using the values of the edge pixels, 45 | 'constant' pads with a constant value 46 | :return: the modified arrays 47 | ''' 48 | min_w = np.min([m.shape[1] for m in mask_list]) if size[0] is None else size[0] 49 | min_h = np.min([m.shape[2] for m in mask_list]) if size[1] is None else size[1] 50 | 51 | # log.debug('Resizing list1 of size %s to size %s' % (str(image_list[0].shape), str((min_w, min_h)))) 52 | # log.debug('Resizing list2 of size %s to size %s' % (str(mask_list[0].shape), str((min_w, min_h)))) 53 | 54 | img_result, msk_result = [], [] 55 | for i in range(len(mask_list)): 56 | im = image_list[i] 57 | m = mask_list[i] 58 | 59 | if m.shape[1] > min_w: 60 | m = _crop(m, 1, min_w, mode) 61 | if im.shape[1] > min_w: 62 | im = _crop(im, 1, min_w, mode) 63 | if m.shape[1] < min_w: 64 | m = _pad(m, 1, min_w, pad_mode) 65 | if im.shape[1] < min_w: 66 | im = _pad(im, 1, min_w, pad_mode) 67 | 68 | if m.shape[2] > min_h: 69 | m = _crop(m, 2, min_h, mode) 70 | if im.shape[2] > min_h: 71 | im = _crop(im, 2, min_h, mode) 72 | if m.shape[2] < min_h: 73 | m = _pad(m, 2, min_h, pad_mode) 74 | if im.shape[2] < min_h: 75 | im = _pad(im, 2, min_h, pad_mode) 76 | 77 | img_result.append(im) 78 | msk_result.append(m) 79 | return img_result, msk_result 80 | 81 | 82 | def _crop(image, dim, nb_pixels, mode): 83 | diff = image.shape[dim] - nb_pixels 84 | if mode == 'equal': 85 | l = int(np.ceil(diff / 2)) 86 | r = image.shape[dim] - l 87 | elif mode == 'right': 88 | l = 0 89 | r = nb_pixels 90 | elif mode == 'left': 91 | l = diff 92 | r = image.shape[dim] 93 | else: 94 | raise ValueError('Unexpected mode: %s. Expected to be one of [equal, left, right].' % mode) 95 | 96 | if dim == 1: 97 | return image[:, l:r, :, :] 98 | elif dim == 2: 99 | return image[:, :, l:r, :] 100 | else: 101 | return None 102 | 103 | 104 | def _pad(image, dim, nb_pixels, mode='edge'): 105 | diff = nb_pixels - image.shape[dim] 106 | l = int(diff / 2) 107 | r = int(diff - l) 108 | if dim == 1: 109 | pad_width = ((0, 0), (l, r), (0, 0), (0, 0)) 110 | elif dim == 2: 111 | pad_width = ((0, 0), (0, 0), (l, r), (0, 0)) 112 | else: 113 | return None 114 | 115 | if mode == 'edge': 116 | new_image = np.pad(image, pad_width, 'edge') 117 | elif mode == 'constant': 118 | new_image = np.pad(image, pad_width, 'constant', constant_values=np.min(image)) 119 | else: 120 | raise Exception('Invalid pad mode: ' + mode) 121 | 122 | return new_image 123 | 124 | 125 | def sample(data, nb_samples, seed=-1): # randomly choose nb_samples elements without replacement 126 | if seed > -1: 127 | np.random.seed(seed) 128 | idx = np.random.choice(len(data), size=nb_samples, replace=False) 129 | return np.array([data[i] for i in idx]) 130 | -------------------------------------------------------------------------------- /utils/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class NormalDistribution(object): 5 | def __init__(self): 6 | self.mu = 0 7 | self.sigma = 1 8 | 9 | def sample(self, N): 10 | samples = np.random.normal(self.mu, self.sigma, N) 11 | return samples 12 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import albumentations.augmentations.transforms 5 | import matplotlib.path as pth 6 | import numpy as np 7 | from PIL import Image, ImageDraw 8 | from scipy.misc import imsave 9 | from
scipy.ndimage.morphology import binary_fill_holes, binary_dilation 10 | 11 | import utils.data_utils 12 | 13 | 14 | def save_segmentation(folder, model, images, masks, name_prefix): 15 | ''' 16 | :param folder: folder to save the image 17 | :param model : segmentation model 18 | :param images: an image of shape [H,W,chn] 19 | :param masks : a mask of shape [H,W,chn] 20 | :return : the predicted segmentation mask 21 | ''' 22 | images = np.expand_dims(images, axis=0) 23 | masks = np.expand_dims(masks, axis=0) 24 | s = model.predict(images) 25 | 26 | # In this case the segmentor is multi-output, with each output corresponding to a mask. 27 | if len(s[0].shape) == 4: 28 | s = np.concatenate(s, axis=-1) 29 | 30 | mask_list_pred = [s[:, :, :, j:j + 1] for j in range(s.shape[-1])] 31 | mask_list_real = [masks[:, :, :, j:j + 1] for j in range(masks.shape[-1])] 32 | if masks.shape[-1] < s.shape[-1]: 33 | mask_list_real += [np.zeros(shape=masks.shape[0:3] + (1,))] * (s.shape[-1] - masks.shape[-1]) 34 | 35 | # if we use rotations, the sizes might differ 36 | m1, m2 = utils.data_utils.crop_same(mask_list_real, mask_list_pred) 37 | images_cropped, _ = utils.data_utils.crop_same([images], [images.copy()], size=(m1[0].shape[1], m1[0].shape[2])) 38 | mask_list_real = [s[0, :, :, 0] for s in m1] 39 | mask_list_pred = [s[0, :, :, 0] for s in m2] 40 | images_cropped = [s[0, :, :, 0] for s in images_cropped] 41 | 42 | row1 = np.concatenate(images_cropped + mask_list_pred, axis=1) 43 | row2 = np.concatenate(images_cropped + mask_list_real, axis=1) 44 | im = np.concatenate([row1, row2], axis=0) 45 | imsave(os.path.join(folder, name_prefix + '.png'), im) 46 | return s, im 47 | 48 | 49 | def makeTextHeaderImage(col_widths, headings, padding=(5, 5)): 50 | im_width = len(headings) * col_widths 51 | im_height = padding[1] * 2 + 11 52 | 53 | img = Image.new('RGB', (im_width, im_height), (0, 0, 0)) 54 | d = ImageDraw.Draw(img) 55 | 56 | for i, txt in enumerate(headings): 57 | 58 | while d.textsize(txt)[0] > col_widths - padding[0]: 59 | txt = txt[:-1] 60 | d.text((col_widths * i + padding[0], + padding[1]), txt, fill=(1, 0, 0)) 61 | 62 | raw_img_data = np.asarray(img, dtype="int32") 63 | 64 | return raw_img_data[:, :, 0] 65 | 66 | 67 | def process_contour(segm_mask, endocardium, epicardium=None): 68 | ''' 69 | in each pixel we sample these 8 points: 70 | _________________ 71 | | * * | 72 | | * * | 73 | | | 74 | | | 75 | | | 76 | | * * | 77 | | * * | 78 | ------------------ 79 | we say a pixel is in the contour if half or more of these 8 points fall within the contour line 80 | ''' 81 | 82 | contour_endo = pth.Path(endocardium, closed=True) 83 | contour_epi = pth.Path(epicardium, closed=True) if epicardium is not None else None 84 | for x in range(segm_mask.shape[1]): 85 | for y in range(segm_mask.shape[0]): 86 | for (dx, dy) in [(-0.25, -0.375), (-0.375, -0.25), (-0.25, 0.375), (-0.375, 0.25), (0.25, 0.375), 87 | (0.375, 0.25), (0.25, -0.375), (0.375, -0.25)]: 88 | 89 | point = (x + dx, y + dy) 90 | if contour_epi is None and contour_endo.contains_point(point): 91 | segm_mask[y, x] += 1 92 | elif contour_epi is not None and \ 93 | contour_epi.contains_point(point) and not contour_endo.contains_point(point): 94 | segm_mask[y, x] += 1 95 | 96 | segm_mask = (segm_mask >= 4) * 1. 
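# i.e. a pixel is kept when at least 4 of the 8 sampled points (half or more) fall inside the contour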
97 | return segm_mask 98 | 99 | 100 | def intensity_augmentation(batch): 101 | """ 102 | Perform a random intensity augmentation 103 | :param batch: an image batch (B,H,W,C) 104 | :return: the intensity augmented batch 105 | """ 106 | aug = albumentations.augmentations.transforms.RandomBrightnessContrast(brightness_limit=0.01, 107 | contrast_limit=(0.99, 1.01)) 108 | batch = utils.data_utils.rescale(batch, 0, 1) 109 | batch = aug(image=batch)['image'] 110 | return utils.data_utils.rescale(batch, -1, 1) 111 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from keras.layers import BatchNormalization, Lambda, Conv2D, LeakyReLU, UpSampling2D, Activation 2 | 3 | from keras_contrib.layers import InstanceNormalization 4 | 5 | 6 | def normalise(norm=None, **kwargs): 7 | if norm == 'instance': 8 | return InstanceNormalization(**kwargs) 9 | elif norm == 'batch': 10 | return BatchNormalization() 11 | else: 12 | return Lambda(lambda x: x) 13 | 14 | 15 | def upsample_block(l0, f, norm_name, activation='relu'): 16 | l = UpSampling2D(size=2)(l0) 17 | l = Conv2D(f, 3, padding='same', kernel_initializer='he_normal')(l) 18 | l = normalise(norm_name)(l) 19 | 20 | if activation == 'leakyrelu': 21 | return LeakyReLU()(l) 22 | return Activation(activation)(l) 23 | 24 | 25 | -------------------------------------------------------------------------------- /utils/sdnet_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import keras.backend as K 4 | 5 | from utils import image_utils 6 | from utils.distributions import NormalDistribution 7 | 8 | 9 | def sampling(args): 10 | """ 11 | Reparameterization trick by sampling from an isotropic unit Gaussian. 12 | Instead of sampling from Q(z|X), sample eps ~ N(0,I): z = z_mean + sqrt(var)*eps 13 | :param args: args (tensor): mean and log of variance of Q(z|X) 14 | :return: z (tensor): sampled latent vector 15 | """ 16 | z_mean, z_log_var = args 17 | batch = K.shape(z_mean)[0] 18 | dim = K.int_shape(z_mean)[1] 19 | # by default, random_normal has mean=0 and std=1.0 20 | epsilon = K.random_normal(shape=(batch, dim)) 21 | return z_mean + K.exp(0.5 * z_log_var) * epsilon 22 | 23 | 24 | def vae_sample(args): # numpy equivalent of sampling(), usable outside the Keras graph 25 | z_mean, z_log_var = args 26 | batch = z_mean.shape[0] 27 | dim = z_mean.shape[1] 28 | # NormalDistribution defaults to mean=0 and std=1.0 29 | gaussian = NormalDistribution() 30 | epsilon = gaussian.sample((batch, dim)) 31 | return z_mean + np.exp(0.5 * z_log_var) * epsilon 32 | 33 | 34 | def get_net(trainer_model, name): 35 | layers = [l for l in trainer_model.layers if l.name == name] 36 | assert len(layers) == 1 37 | return layers[0] 38 | 39 | 40 | def make_trainable(model, val): 41 | model.trainable = val 42 | try: 43 | for l in model.layers: 44 | try: 45 | for k in l.layers: 46 | make_trainable(k, val) 47 | except: 48 | # Layer is not a model, so continue 49 | pass 50 | l.trainable = val 51 | except: 52 | # Layer is not a model, so continue 53 | pass 54 | --------------------------------------------------------------------------------
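A minimal usage sketch for the sampling utilities above (the shapes and values here are illustrative, not taken from the repository):

```python
import numpy as np
from utils.sdnet_utils import vae_sample

# A toy batch of 4 modality vectors with num_z = 8 (illustrative sizes).
z_mean = np.zeros((4, 8))
z_log_var = np.zeros((4, 8))  # log-variance 0, i.e. unit variance

# Reparameterization outside the graph: z = mean + exp(0.5 * log_var) * eps, with eps ~ N(0, I)
z = vae_sample([z_mean, z_log_var])
print(z.shape)  # (4, 8)
```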