├── LICENSE.txt ├── README.md ├── callbacks ├── __init__.py ├── loss_callback.py └── sdnet_callback.py ├── config.py ├── costs.py ├── layers ├── __init__.py ├── instance_normalization.py └── rounding.py ├── loaders ├── __init__.py ├── acdc.py ├── base_loader.py ├── data.py └── loader_factory.py ├── main.py ├── models ├── __init__.py ├── discriminator.py ├── resnet.py └── unet.py ├── parameters.py ├── sdnet.py ├── sdnet_trainer.py └── utils ├── __init__.py └── data_utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Agis Chartsias 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatial Factorisation 2 | 3 | Code for the paper [Factorised spatial representation learning: application in semi-supervised myocardial segmentation]. 4 | 5 | The main files are: 6 | 7 | * sdnet.py: the model implementation 8 | * sdnet_trainer.py: code related to running an experiment 9 | 10 | Data loaders are stored in the _loaders_ package. To define a new loader, extend class `base_loader.Loader`, and add initialisation in loader_factory.py. The data folder location can be specified in parameters.py. 11 | 12 | The main method is in main.py and arguments can be passed at runtime. For example, an experiment can be run with: 13 | ``` 14 | python main.py --dataset acdc --split 0 --ul_mix 1 --l_mix 0.5 15 | ``` 16 | 17 | `--split` defines the cross validation data split, `--ul_mix` the percentage of unlabelled data, and `--l_mix` the percentage of labelled images. These proportions are calculated by comparing with the total number of labelled images in the dataset. 18 | 19 | ## Citation 20 | 21 | If you use this code for your research, please cite our paper: 22 | ``` 23 | @InProceedings{chartsias2018factorised, 24 | author="Chartsias, Agisilaos 25 | and Joyce, Thomas 26 | and Papanastasiou, Giorgos 27 | and Semple, Scott 28 | and Williams, Michelle 29 | and Newby, David 30 | and Dharmakumar, Rohan 31 | and Tsaftaris, Sotirios A.", 32 | editor="Frangi, Alejandro F. 33 | and Schnabel, Julia A. 34 | and Davatzikos, Christos 35 | and Alberola-L{\'o}pez, Carlos 36 | and Fichtinger, Gabor", 37 | title="Factorised Spatial Representation Learning: Application in Semi-supervised Myocardial Segmentation", 38 | booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2018", 39 | year="2018", 40 | publisher="Springer International Publishing", 41 | address="Cham", 42 | pages="490--498", 43 | isbn="978-3-030-00934-2" 44 | } 45 | ``` 46 | 47 | [Factorised spatial representation learning: application in semi-supervised myocardial segmentation]: https://link.springer.com/chapter/10.1007/978-3-030-00934-2_55 48 | -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/spatial_factorisation/233d72511ffb52f52214a68f1c996555345991d0/callbacks/__init__.py -------------------------------------------------------------------------------- /callbacks/loss_callback.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import matplotlib.pyplot as plt 4 | from keras.callbacks import Callback 5 | 6 | 7 | class SaveLoss(Callback): 8 | """ 9 | Keras callback for saving a graph of the training losses over epochs as a .png image. 10 | It creates two plots, one for the adversarial losses (training_discr_loss.png) and one 11 | for the remaining losses (training_loss.png). 12 | """ 13 | def __init__(self, folder, scale='linear'): 14 | """ 15 | Callback constructor 16 | :param folder: folder to save the images 17 | :param scale: can be 'linear' or 'log', to plot values in y-axis in a linear or logarithmic scale. 18 | """ 19 | super(SaveLoss, self).__init__() 20 | self.folder = folder 21 | self.values = dict() 22 | 23 | if scale not in ['linear', 'log']: 24 | raise ValueError('Invalid value for scale. Allowed values: linear, log. Given value: %s' % str(scale)) 25 | self.scale = scale 26 | 27 | # Overwrite default on_epoch_end implementation 28 | def on_epoch_end(self, epoch, logs=None): 29 | if logs is None: 30 | raise ValueError('Parameter logs cannot be None.') 31 | 32 | # Initialise self.values dictionary the first epoch. 33 | if len(self.values) == 0: 34 | for k in logs: 35 | self.values[k] = [] 36 | 37 | # Update dictionary values. 38 | for k in logs: 39 | self.values[k].append(logs[k]) 40 | 41 | # Save a graph of the training loss values. 42 | plt.figure() 43 | plt.suptitle('Training loss', fontsize=16) 44 | for k in self.values: 45 | epochs = range(len(self.values[k])) 46 | if self.scale == 'linear': 47 | plt.plot(epochs, self.values[k], label=k) 48 | elif self.scale == 'log': 49 | plt.semilogy(epochs, self.values[k], label=k) 50 | plt.xlabel('Epochs') 51 | plt.ylabel('Loss') 52 | plt.legend(loc='best') 53 | plt.savefig(os.path.join(self.folder, 'training_loss.png')) 54 | 55 | # Save a graph of the loss values of adversarial training. 56 | # Convention: Adversarial loss names start with dis or adv 57 | plt.figure() 58 | plt.suptitle('Training loss', fontsize=16) 59 | for k in self.values: 60 | if not ('dis' in k or 'adv' in k): 61 | continue 62 | 63 | epochs = range(len(self.values[k])) 64 | plt.plot(epochs, self.values[k], label=k) 65 | plt.xlabel('Epochs') 66 | plt.ylabel('Loss') 67 | plt.legend(loc='best') 68 | plt.savefig(os.path.join(self.folder, 'training_discr_loss.png')) 69 | 70 | plt.close() 71 | -------------------------------------------------------------------------------- /callbacks/sdnet_callback.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from keras.callbacks import Callback 8 | from scipy.misc import imsave 9 | 10 | log = logging.getLogger('SDNetCallback') 11 | 12 | 13 | class SDNetCallback(Callback): 14 | """ 15 | Image callback for saving images during SDNet training. 16 | Images are saved in a subfolder with name training_images, created inside the experiment folder. 17 | """ 18 | def __init__(self, folder, batch_size, sdnet): 19 | """ 20 | :param folder: experiment folder, where all results are saved 21 | :param batch_size: batch size used for training 22 | """ 23 | super(SDNetCallback, self).__init__() 24 | 25 | # Create results folder 26 | self.folder = os.path.join(folder, 'training_images') 27 | if not os.path.exists(self.folder): 28 | os.makedirs(self.folder) 29 | 30 | self.sdnet = sdnet 31 | self.batch_size = batch_size 32 | 33 | def on_epoch_end(self, epoch, data_labelled=None, images_unlabelled=None, logs=None): 34 | """ 35 | Overwrite default on_epoch_end implementation. 36 | 37 | :param epoch: current training epoch 38 | :param data_labelled: a list of tuples (image, mask) 39 | :param data_unlabelled: a list of images with no corresponding masks 40 | :param logs: a dictionary of losses. Not used here 41 | """ 42 | images_labelled = np.concatenate([data_labelled[i][0] for i in range(len(data_labelled))], axis=0) 43 | masks = np.concatenate([data_labelled[i][1] for i in range(len(data_labelled))], axis=0) 44 | images_unlabelled = np.array(images_unlabelled) 45 | 46 | self.plot_images(epoch, images_labelled, masks, images_unlabelled) 47 | self.plot_discriminator_outputs(epoch, np.concatenate([images_labelled, images_unlabelled]), masks) 48 | 49 | def plot_images(self, epoch, images_labelled, masks, images_unlabelled): 50 | """ 51 | Save segmentation and reconstruction examples. 52 | :param epoch: current training epoch 53 | :param images_labelled: an array of labelled images 54 | :param masks: an array of corresponding masks (to the labelled images) 55 | :param images_unlabelled: an array of images with no masks 56 | """ 57 | rows = [] 58 | # plot 3 labelled examples 59 | for i in range(3): 60 | rows.append(self.get_image_row(images_labelled, masks)) 61 | # plot 3 unlabelled examples 62 | for i in range(3): 63 | rows.append(self.get_image_row(images_unlabelled, np.zeros(images_unlabelled.shape))) 64 | 65 | img = np.concatenate(rows, axis=0) 66 | imsave(self.folder + '/cardiacgan_epoch_%d.png' % epoch, img) 67 | 68 | def get_image_row(self, images, masks): 69 | """ 70 | Create an array of 8 images showing segmentations and reconstructions with different combinations of masks 71 | and residuals 72 | :param images: an array of images 73 | :param masks: an array of masks 74 | :return: a concatenated array of 8 subarrays to be used as one row of the final plotted image 75 | """ 76 | if len(images) == 0: 77 | return [] 78 | 79 | xi = np.random.randint(images.shape[0]) # draw random sample 80 | x = images[xi:xi + 1] 81 | pred_m, z = self.sdnet.Decomposer.predict(x) 82 | 83 | m = masks[xi:xi + 1] 84 | rec_predM_z = self.sdnet.Reconstructor.predict([pred_m, z]) 85 | rec_m_z = self.sdnet.Reconstructor.predict([m, z]) 86 | rec_m0_z = self.sdnet.Reconstructor.predict([np.zeros(m.shape), z]) 87 | rec_m_z0 = self.sdnet.Reconstructor.predict([m, np.zeros(z.shape)]) 88 | rec_m0_z0 = self.sdnet.Reconstructor.predict([np.zeros(m.shape), np.zeros(z.shape)]) 89 | return np.concatenate([np.squeeze(el) for el in 90 | [x, pred_m, rec_predM_z, m, rec_m_z, rec_m0_z, rec_m_z0, rec_m0_z0]], axis=1) 91 | 92 | def plot_discriminator_outputs(self, epoch, images, masks): 93 | if masks.shape[0] == 0: 94 | return 95 | 96 | # number of points used for the histogram 97 | sz = 40 if 40 < np.min([len(images), len(masks)]) else np.min([len(images), len(masks)]) 98 | idx_X = np.random.choice(len(images), size=sz, replace=False) 99 | idx_M = np.random.choice(len(masks), size=sz, replace=False) 100 | 101 | samples_X = np.concatenate([images[i:i + 1] for i in idx_X], axis=0) 102 | samples_M = np.concatenate([masks[i:i + 1] for i in idx_M], axis=0) 103 | samples_pred_M, samples_Z = self.sdnet.Decomposer.predict(samples_X) 104 | samples_pred_X = self.sdnet.Reconstructor.predict([samples_pred_M, samples_Z]) 105 | 106 | dx_true = np.array([np.mean(self.sdnet.ImageDiscriminator.predict(samples_X[i:i + 1])) for i in range(sz)]) 107 | dx_pred = np.array([np.mean(self.sdnet.ImageDiscriminator.predict(samples_pred_X[i:i + 1])) for i in range(sz)]) 108 | dm_true = np.array([np.mean(self.sdnet.MaskDiscriminator.predict(samples_M[i:i + 1])) for i in range(sz)]) 109 | dm_pred = np.array([np.mean(self.sdnet.MaskDiscriminator.predict(samples_pred_M[i:i + 1])) for i in range(sz)]) 110 | 111 | plt.figure() 112 | plt.subplot(1, 2, 1) 113 | plt.hist([dx_true, dx_pred], stacked=True, normed=True) 114 | plt.subplot(1, 2, 2) 115 | plt.hist([dm_true, dm_pred], stacked=True, normed=True) 116 | plt.savefig(self.folder + '/discriminator_hist_epoch_%d.png' % epoch) 117 | plt.close() 118 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | class Configuration(object): 6 | """ 7 | Configuration object with experiment parameters. 8 | """ 9 | def __init__(self, folder, data_len, input_shape): 10 | self.seed = 0 11 | self.folder = folder 12 | self.data_len = data_len 13 | self.epochs = 500 14 | self.batch_size = 4 15 | self.batches = data_len / self.batch_size 16 | self.pool_size = 50 17 | self.input_shape = input_shape 18 | self.split = None 19 | self.description = '' 20 | self.dataset_name = '' 21 | self.unlabelled_image_num = 0 22 | self.labelled_image_num = 0 23 | self.augment = False 24 | self.l_mix = None 25 | self.ul_mix = None 26 | 27 | self.w_uns_adv_M = 10 28 | self.w_uns_rec_X = 5 29 | self.w_uns_adv_X = 5 30 | self.w_fake_M = 10 31 | self.w_fake_X = 10 32 | self.w_rec_X = 1 33 | self.w_adv_M = 10 34 | self.w_adv_X_fromreal = 1 35 | 36 | def save(self): 37 | fname = os.path.join(self.folder, 'config.json') 38 | if not os.path.exists(self.folder): 39 | os.makedirs(self.folder) 40 | 41 | with open(fname, 'w') as outfile: 42 | json.dump(self.__dict__, outfile) 43 | -------------------------------------------------------------------------------- /costs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend as K 3 | 4 | 5 | def dice(y_true, y_pred, smooth=0.1): 6 | y_pred = y_pred[..., 0:y_true.shape[-1]] 7 | 8 | # Symbolically compute the intersection 9 | y_int = y_true * y_pred 10 | return np.mean((2 * np.sum(y_int, axis=(1, 2, 3)) + smooth) 11 | / (np.sum(y_true, axis=(1, 2, 3)) + np.sum(y_pred, axis=(1, 2, 3)) + smooth)) 12 | 13 | 14 | def dice_coef(y_true, y_pred): 15 | ''' 16 | DICE coefficient. 17 | :param y_true: a tensor of ground truth data 18 | :param y_pred: a tensor of predicted data 19 | ''' 20 | # Symbolically compute the intersection 21 | intersection = K.sum(y_true * y_pred, axis=(1, 2, 3)) + 0.1 22 | union = K.sum(y_true, axis=(1, 2, 3)) + K.sum(y_pred, axis=(1, 2, 3)) + 0.1 23 | return K.mean(2 * intersection / union, axis=0) 24 | 25 | 26 | def dice_coef_loss(y_true, y_pred): 27 | return 1 - dice_coef(y_true, y_pred) 28 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/spatial_factorisation/233d72511ffb52f52214a68f1c996555345991d0/layers/__init__.py -------------------------------------------------------------------------------- /layers/instance_normalization.py: -------------------------------------------------------------------------------- 1 | 2 | from keras.engine.topology import Layer 3 | import keras.backend as K 4 | 5 | 6 | class InstanceNormalization(Layer): 7 | '''Instance Normalization adapted from https://github.com/PiscesDream/CycleGAN-keras''' 8 | 9 | def __init__(self, **kwargs): 10 | super(InstanceNormalization, self).__init__() 11 | self.inshape = kwargs['inshape'] if 'inshape' in kwargs else None 12 | 13 | def build(self, input_shape): 14 | self.scale = self.add_weight(name='scale', shape=(input_shape[get_channel_dim()],), initializer="one", trainable=True) 15 | self.shift = self.add_weight(name='shift', shape=(input_shape[get_channel_dim()],), initializer="zero", trainable=True) 16 | super(InstanceNormalization, self).build(input_shape) 17 | 18 | def call(self, x, mask=None): 19 | if get_channel_dim() == 1: 20 | h, w = 2, 3 21 | exp_dim = -1 22 | else: 23 | h, w = 1, 2 24 | exp_dim = 1 25 | 26 | x_shape = self.inshape if self.inshape else x.shape 27 | hw = K.cast(x_shape[h] * x_shape[w], K.floatx()) 28 | mu = K.sum(x, [h, w]) / hw 29 | mu_vec = K.expand_dims(K.expand_dims(mu, 1), 1) 30 | sig2 = K.sum(K.square(x - mu_vec), [h, w]) / hw 31 | sig2_vec = K.expand_dims(K.expand_dims(sig2, 1), 1) 32 | y = (x - mu_vec) / (K.sqrt(sig2_vec) + K.epsilon()) 33 | 34 | scale = K.expand_dims(K.expand_dims(K.expand_dims(self.scale, 0), exp_dim), exp_dim) 35 | shift = K.expand_dims(K.expand_dims(K.expand_dims(self.shift, 0), exp_dim), exp_dim) 36 | return scale * y + shift 37 | 38 | def compute_output_shape(self, input_shape): 39 | return input_shape 40 | 41 | def get_channel_dim(): 42 | if K.image_data_format() == 'channels_first': 43 | return 1 44 | return -1 -------------------------------------------------------------------------------- /layers/rounding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from keras.engine.topology import Layer 4 | from tensorflow.python.framework import ops 5 | 6 | 7 | class Rounding(Layer): 8 | """ 9 | Custom layer that rounds a tensor. 10 | """ 11 | def __init__(self, **kwargs): 12 | super(Rounding, self).__init__(**kwargs) 13 | 14 | def build(self, input_shape): 15 | super(Rounding, self).build(input_shape) 16 | 17 | def call(self, x, **kwargs): 18 | return roundWithGrad(x) 19 | 20 | def compute_output_shape(self, input_shape): 21 | return input_shape 22 | 23 | 24 | # Define custom py_func which takes also a grad op as argument: 25 | def py_func(func, inp, Tout, stateful=True, name=None, grad=None): 26 | rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8)) # generate a unique name to avoid duplicates 27 | tf.RegisterGradient(rnd_name)(grad) 28 | g = tf.get_default_graph() 29 | with g.gradient_override_map({"PyFunc": rnd_name}): 30 | res = tf.py_func(func, inp, Tout, stateful=stateful, name=name) 31 | res[0].set_shape(inp[0].get_shape()) 32 | return res 33 | 34 | 35 | def roundWithGrad(x, name=None): 36 | with ops.name_scope(name, "roundWithGrad", [x]) as name: 37 | round_x = py_func(lambda x: np.round(x).astype('float32'), [x], [tf.float32], name=name, 38 | grad=_roundWithGrad_grad) # <-- here's the call to the gradient 39 | return round_x[0] 40 | 41 | 42 | def _roundWithGrad_grad(op, grad): 43 | x = op.inputs[0] 44 | return grad * 1 # do whatever with gradient here (e.g. could return grad * 2 * x if op was f(x)=x**2) 45 | -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/spatial_factorisation/233d72511ffb52f52214a68f1c996555345991d0/loaders/__init__.py -------------------------------------------------------------------------------- /loaders/acdc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import nibabel as nib 5 | import numpy as np 6 | 7 | from loaders.base_loader import Loader 8 | from loaders.data import Data 9 | from parameters import conf 10 | from utils import data_utils 11 | 12 | 13 | class ACDCLoader(Loader): 14 | """ 15 | ACDC Challenge loader. Annotations for LV, MYO, RV with labels 3, 2, 1 respectively. 16 | """ 17 | 18 | def __init__(self): 19 | super(ACDCLoader, self).__init__() 20 | self.num_volumes = 100 21 | self.input_shape = (224, 224, 1) 22 | self.data_folder = conf['acdc'] 23 | self.log = logging.getLogger('acdc') 24 | 25 | def splits(self): 26 | """ 27 | :return: an array of splits into validation, test and train indices 28 | """ 29 | 30 | splits = [ 31 | {'validation': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 32 | 'test': [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], 33 | 'training': [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 34 | 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 35 | 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 36 | 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100] 37 | }, 38 | {'validation': [85, 13, 9, 74, 73, 68, 59, 79, 47, 80, 14, 95, 25, 92, 87], 39 | 'test': [54, 55, 99, 63, 91, 24, 51, 3, 64, 43, 61, 66, 96, 27, 76], 40 | 'training': [46, 57, 49, 34, 17, 8, 19, 28, 97, 1, 90, 22, 88, 45, 12, 4, 5, 41 | 75, 53, 94, 62, 86, 35, 58, 82, 37, 84, 93, 6, 33, 15, 81, 23, 48, 42 | 71, 70, 11, 77, 36, 60, 31, 65, 32, 78, 98, 52, 100, 42, 38, 2, 20, 43 | 69, 26, 18, 40, 50, 16, 7, 41, 10, 83, 21, 39, 72, 56, 67, 44, 30, 89, 29] 44 | }, 45 | {'validation': [47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61], 46 | 'test': [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78], 47 | 'training': [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48 | 100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 62, 63, 49 | 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 79, 80, 50 | 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 51 | }, 52 | {'validation': [20, 23, 24, 25, 26, 27, 28, 29, 30, 31, 37, 33, 34, 35, 36], 53 | 'test': [38, 39, 40, 41, 43, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55], 54 | 'training': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 55 | 21, 22, 32, 42, 44, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 56 | 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 57 | 91, 92, 93, 94, 95, 96, 97, 98, 99, 100] 58 | } 59 | ] 60 | 61 | return splits 62 | 63 | def load_labelled_data(self, split, split_type, modality='MR', normalise=True, value_crop=True, downsample=1): 64 | """ 65 | Load labelled data, and return a Data object. In ACDC there are ES and ED annotations. Preprocessed data 66 | are saved in .npz files. If they don't exist, load the original images and preprocess. 67 | 68 | :param split: Cross validation split: can be 0, 1, 2. 69 | :param split_type: Cross validation type: can be training, validation, test, all 70 | :param modality: Data modality. Unused here. 71 | :param normalise: Use normalised data: can be True/False 72 | :param value_crop: Crop extreme values: can be True/False 73 | :param downsample: Downsample data to smaller size. Only used for testing. 74 | :return: a Data object 75 | """ 76 | if split < 0: 77 | raise ValueError('Invalid value for split: %d.' % split) 78 | if split_type not in ['training', 'validation', 'test', 'all']: 79 | raise ValueError('Invalid value for split_type: %s. Allowed values are training, validation, test, all' 80 | % split_type) 81 | 82 | npz_prefix = 'norm_' if normalise else 'unnorm_' 83 | 84 | # If numpy arrays are not saved, load and process raw data 85 | if not os.path.exists(os.path.join(self.data_folder, npz_prefix + 'acdc_images.npz')): 86 | images, masks_lv, masks_rv, masks, index = self.load_raw_labelled_data(normalise, value_crop) 87 | 88 | # save numpy arrays 89 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + 'acdc_images'), images) 90 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + 'acdc_masks_lv'), masks_lv) 91 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + 'acdc_masks_rv'), masks_rv) 92 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + 'acdc_masks_myo'), masks) 93 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + 'acdc_index'), index) 94 | # Load data from saved numpy arrays 95 | else: 96 | images = np.load(os.path.join(self.data_folder, npz_prefix + 'acdc_images.npz'))['arr_0'] 97 | masks = np.load(os.path.join(self.data_folder, npz_prefix + 'acdc_masks_myo.npz'))['arr_0'] 98 | index = np.load(os.path.join(self.data_folder, npz_prefix + 'acdc_index.npz'))['arr_0'] 99 | 100 | assert images is not None and masks is not None and index is not None, 'Could not find saved data' 101 | 102 | assert images.max() == 1 and images.min() == -1, 'Images max=%.3f, min=%.3f' % (images.max(), images.min()) 103 | assert masks.max() == 1 and masks.min() == 0, 'Masks max=%.3f, min=%.3f' % (masks.max(), masks.min()) 104 | 105 | self.log.debug('Loaded compressed acdc data of shape: ' + str(images.shape)) 106 | 107 | # Case to load data from all splits. 108 | if split_type == 'all': 109 | return Data(images, masks, index) 110 | 111 | # Select images belonging to the volumes of the split_type (training, validation, test) 112 | volumes = self.splits()[split][split_type] 113 | images = np.concatenate([images[index == v] for v in volumes]) 114 | masks = np.concatenate([masks[index == v] for v in volumes]) 115 | index = np.concatenate([index[index == v] for v in volumes]) 116 | 117 | self.log.debug(split_type + ' set: ' + str(images.shape)) 118 | return Data(images, masks, index, downsample) 119 | 120 | def load_unlabelled_data(self, split, split_type, modality='MR', normalise=True, value_crop=True): 121 | """ 122 | Load unlabelled data. In ACDC, this contains images from the cardiac phases between ES and ED. 123 | :param split: Cross validation split: can be 0, 1, 2. 124 | :param split_type: Cross validation type: can be training, validation, test, all 125 | :param modality: Data modality. Unused here. 126 | :param normalise: Use normalised data: can be True/False 127 | :param value_crop: Crop extreme values: can be True/False 128 | :return: a Data object 129 | """ 130 | images, index = self.base_load_unlabelled_images('acdc', split, split_type, False, normalise, value_crop) 131 | masks = np.zeros(shape=(images.shape[:-1]) + (1,)) 132 | return Data(images, masks, index) 133 | 134 | def load_all_data(self, split, split_type, modality='MR', normalise=True, value_crop=True): 135 | """ 136 | Load all images, unlabelled and labelled, meaning all images from all cardiac phases. 137 | :param split: Cross validation split: can be 0, 1, 2. 138 | :param split_type: Cross validation type: can be training, validation, test, all 139 | :param modality: Data modality. Unused here. 140 | :param normalise: Use normalised data: can be True/False 141 | :param value_crop: Crop extreme values: can be True/False 142 | :return: a Data object 143 | """ 144 | images, index = self.base_load_unlabelled_images('acdc', split, split_type, True, normalise, value_crop) 145 | masks = np.zeros(shape=(images.shape[:-1]) + (1,)) 146 | return Data(images, masks, index) 147 | 148 | def load_raw_labelled_data(self, normalise=True, value_crop=True): 149 | """ 150 | Load labelled data iterating through the ACDC folder structure. 151 | :param normalise: normalise data between -1, 1 152 | :param value_crop: crop between 5 and 95 percentile 153 | :return: a tuple of the image and mask arrays 154 | """ 155 | self.log.debug('Loading acdc data from original location') 156 | images, masks_lv, masks_rv, masks_myo, index = [], [], [], [], [] 157 | 158 | # Iterate through patient folders 159 | for patient_i in self.volumes: 160 | patient = 'patient%03d' % patient_i 161 | patient_folder = os.path.join(self.data_folder, patient) 162 | 163 | # load ground truth mask and image 164 | gt = [f for f in os.listdir(patient_folder) if 'gt' in f and f.startswith(patient + '_frame')] 165 | ims = [f.replace('_gt', '') for f in gt] 166 | 167 | # process every image slice 168 | for i in range(len(ims)): 169 | im = self.process_raw_image(ims[i], patient_folder, value_crop, normalise) 170 | m = self.resample_raw_image(gt[i], patient_folder) 171 | 172 | images.append(im) 173 | 174 | # convert 3-dim mask array to 3 binary mask arrays for lv, rv, myo 175 | m_lv = m.copy() 176 | m_lv[m != 3] = 0 177 | m_lv[m == 3] = 1 178 | masks_lv.append(m_lv) 179 | 180 | m_rv = m.copy() 181 | m_rv[m != 1] = 0 182 | m_rv[m == 1] = 1 183 | masks_rv.append(m_rv) 184 | 185 | m_myo = m.copy() 186 | m_myo[m != 2] = 0 187 | m_myo[m == 2] = 1 188 | masks_myo.append(m_myo) 189 | 190 | index += [patient_i] * im.shape[2] 191 | 192 | assert len(images) == len(masks_myo) 193 | 194 | # move slice axis to the first position 195 | images = [np.moveaxis(im, 2, 0) for im in images] 196 | masks_lv = [np.moveaxis(m, 2, 0) for m in masks_lv] 197 | masks_rv = [np.moveaxis(m, 2, 0) for m in masks_rv] 198 | masks_myo = [np.moveaxis(m, 2, 0) for m in masks_myo] 199 | 200 | # crop images and masks to the same pixel dimensions and concatenate all data 201 | images_cropped, masks_lv_cropped = data_utils.crop_same(images, masks_lv, (224, 224)) 202 | _, masks_rv_cropped = data_utils.crop_same(images, masks_rv, (224, 224)) 203 | _, masks_myo_cropped = data_utils.crop_same(images, masks_myo, (224, 224)) 204 | 205 | images_cropped = np.concatenate(images_cropped, axis=0) 206 | masks_lv_cropped = np.concatenate(masks_lv_cropped, axis=0) 207 | masks_rv_cropped = np.concatenate(masks_rv_cropped, axis=0) 208 | masks_myo_cropped = np.concatenate(masks_myo_cropped, axis=0) 209 | 210 | self.log.debug(str(images[0].shape) + ', ' + str(masks_lv[0].shape)) 211 | return images_cropped, masks_lv_cropped, masks_rv_cropped, masks_myo_cropped, index 212 | 213 | def resample_raw_image(self, mask_fname, patient_folder): 214 | """ 215 | Load raw data (image/mask) and resample to fixed resolution. 216 | :param mask_fname: filename of mask 217 | :param patient_folder: folder containing patient data 218 | :return: the resampled image 219 | """ 220 | m_nii_fname = os.path.join(patient_folder, mask_fname) 221 | m_nii_res_fname = os.path.join(patient_folder, 'res_' + mask_fname) 222 | 223 | # resample to fixed pixel resolution 224 | if not os.path.exists(m_nii_res_fname): 225 | data_utils.resample_ants(m_nii_fname, m_nii_res_fname) 226 | 227 | # load resampled mask 228 | m_nii = nib.load(m_nii_res_fname) 229 | m = m_nii.get_data() 230 | m_voxel_size = m_nii.header.get_zooms() 231 | assert m_voxel_size[0] - 1.37 < 0.0001 and m_voxel_size[1] - 1.37 < 0.0001 \ 232 | and m_voxel_size[2] - 10.0 < 0.0001 and m_voxel_size[3] - 1.0 < 0.0001, m_voxel_size 233 | return m 234 | 235 | def process_raw_image(self, im_fname, patient_folder, value_crop, normalise): 236 | """ 237 | Normalise between -1 and 1, and crop extreme values of an image 238 | :param im_fname: filename of the image 239 | :param patient_folder: folder of patient data 240 | :param value_crop: True/False to crop values between 5/95 percentiles 241 | :param normalise: True/False normalise images 242 | :return: a processed image 243 | """ 244 | im = self.resample_raw_image(im_fname, patient_folder) 245 | 246 | # crop to 5-95 percentile 247 | if value_crop: 248 | p5 = np.percentile(im.flatten(), 5) 249 | p95 = np.percentile(im.flatten(), 95) 250 | im = np.clip(im, p5, p95) 251 | 252 | # normalise to -1, 1 253 | if normalise: 254 | im = data_utils.normalise(im, -1, 1) 255 | 256 | return im 257 | 258 | def load_raw_unlabelled_data(self, include_labelled=True, normalise=True, value_crop=True): 259 | """ 260 | Load unlabelled data iterating through the ACDC folder structure. 261 | :param include_labelled: include images from ES, ED phases that are labelled. Can be True/False 262 | :param normalise: normalise data between -1, 1 263 | :param value_crop: crop between 5 and 95 percentile 264 | :return: an image array 265 | """ 266 | self.log.debug('Loading unlabelled acdc data from original location') 267 | images, index = [], [] 268 | 269 | # Iterate through patient folders 270 | for patient_i in self.volumes: 271 | patient = 'patient%03d' % patient_i 272 | self.log.debug('Loading patient %s' % patient) 273 | patient_folder = os.path.join(self.data_folder, patient) 274 | 275 | im_name = patient + '_4d.nii.gz' 276 | im = self.process_raw_image(im_name, patient_folder, value_crop, normalise) 277 | 278 | frames = range(im.shape[-1]) 279 | if not include_labelled: 280 | gt = [f for f in os.listdir(patient_folder) if 'gt' in f and not f.startswith('._')] 281 | gt_ims = [f.replace('_gt', '') for f in gt if not f.startswith('._')] 282 | 283 | exclude_frames = [int(gt_im.split('.')[0].split('frame')[1]) for gt_im in gt_ims] 284 | frames = [f for f in range(im.shape[-1]) if f not in exclude_frames] 285 | 286 | for frame in frames: 287 | im_res = im[:, :, :, frame] 288 | if im_res.sum() == 0: 289 | print('Skipping blank images') 290 | continue 291 | images.append(im_res) 292 | index += [patient_i] * im_res.shape[-1] 293 | 294 | images = [np.expand_dims(np.moveaxis(im, 2, 0), axis=3) for im in images] 295 | zeros = [np.zeros(im.shape) for im in images] 296 | images_cropped, _ = data_utils.crop_same(images, zeros, (224, 224)) 297 | images_cropped = np.concatenate(images_cropped, axis=0)[..., 0] 298 | index = np.array(index) 299 | 300 | return images_cropped, index 301 | -------------------------------------------------------------------------------- /loaders/base_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class Loader(object): 7 | """ 8 | Abstract class defining the behaviour of loaders for different datasets. 9 | """ 10 | def __init__(self): 11 | self.num_masks = 0 12 | self.num_volumes = 0 13 | self.input_shape = (None, None, 1) 14 | self.data_folder = None 15 | self.volumes = sorted(self.splits()[0]['training'] + 16 | self.splits()[0]['validation'] + 17 | self.splits()[0]['test']) 18 | self.log = None 19 | 20 | @abstractmethod 21 | def splits(self): 22 | """ 23 | :return: an array of splits into validation, test and train indices 24 | """ 25 | pass 26 | 27 | @abstractmethod 28 | def load_labelled_data(self, split, split_type, modality, normalise=True, value_crop=True, downsample=1): 29 | """ 30 | Load labelled data from saved numpy arrays. 31 | Assumes a naming convention of numpy arrays as: 32 | _images.npz, _masks_lv.npz, _masks_myo.npz etc. 33 | If numpy arrays are not found, then data is loaded from sources and saved in numpy arrays. 34 | 35 | :param split: the split number, e.g. 0, 1 36 | :param split_type: the split type, e.g. training, validation, test, all (for all data) 37 | :param modality: modality to load if the dataset has multimodal data 38 | :param normalise: True/False: normalise images to [-1, 1] 39 | :param value_crop: True/False: crop values between 5-95 percentiles 40 | :param downsample: downsample image ratio - used for for testing 41 | :return: a Data object containing the loaded data 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def load_unlabelled_data(self, split, split_type, modality='MR', normalise=True, value_crop=True): 47 | """ 48 | Load unlabelled data from saved numpy arrays. 49 | Assumes a naming convention of numpy arrays as ul__images.npz 50 | If numpy arrays are not found, then data is loaded from sources and saved in numpy arrays. 51 | :param split: the split number, e.g. 0, 1 52 | :param split_type: the split type, e.g. training, validation, test, all (for all data) 53 | :param modality: modality to load if the dataset has multimodal data 54 | :param normalise: True/False: normalise images to [-1, 1] 55 | :param value_crop: True/False: crop values between 5-95 percentiles 56 | :return: a Data object containing the loaded data 57 | """ 58 | pass 59 | 60 | @abstractmethod 61 | def load_all_data(self, split, split_type, modality='MR', normalise=True, value_crop=True): 62 | """ 63 | Load all images (labelled and unlabelled) from saved numpy arrays. 64 | Assumes a naming convention of numpy arrays as all__images.npz 65 | If numpy arrays are not found, then data is loaded from sources and saved in numpy arrays. 66 | :param split: the split number, e.g. 0, 1 67 | :param split_type: the split type, e.g. training, validation, test, all (for all data) 68 | :param modality: modality to load if the dataset has multimodal data 69 | :param normalise: True/False: normalise images to [-1, 1] 70 | :param value_crop: True/False: crop values between 5-95 percentiles 71 | :return: a Data object containing the loaded data 72 | """ 73 | pass 74 | 75 | @abstractmethod 76 | def load_raw_labelled_data(self, normalise=True, value_crop=True): 77 | """ 78 | Load raw data, do preprocessing e.g. normalisation, resampling, value cropping etc 79 | :param normalise: True or False to normalise data 80 | :param value_crop: True or False to crop in the 5-95 percentiles or not. 81 | :return: a pair of arrays (images, index) 82 | """ 83 | pass 84 | 85 | @abstractmethod 86 | def load_raw_unlabelled_data(self, include_labelled, normalise=True, value_crop=True): 87 | """ 88 | Load raw data, do preprocessing e.g. normalisation, resampling, value cropping etc 89 | :param include_labelled True or False to include labelled images or not 90 | :param normalise: True or False to normalise data 91 | :param value_crop: True or False to crop in the 5-95 percentiles or not. 92 | :return: a pair of arrays (images, index) 93 | """ 94 | pass 95 | 96 | def base_load_unlabelled_images(self, dataset, split, split_type, include_labelled, normalise, value_crop): 97 | npz_prefix_type = 'ul_' if not include_labelled else 'all_' 98 | npz_prefix = npz_prefix_type + 'norm_' if normalise else npz_prefix_type + 'unnorm_' 99 | 100 | # Load saved numpy array 101 | if os.path.exists(os.path.join(self.data_folder, npz_prefix + dataset + '_images.npz')): 102 | images = np.load(os.path.join(self.data_folder, npz_prefix + dataset + '_images.npz'))['arr_0'] 103 | index = np.load(os.path.join(self.data_folder, npz_prefix + dataset + '_index.npz'))['arr_0'] 104 | self.log.debug('Loaded compressed ' + dataset + ' unlabelled data of shape ' + str(images.shape)) 105 | # Load from source 106 | else: 107 | images, index = self.load_raw_unlabelled_data(include_labelled, normalise, value_crop) 108 | images = np.expand_dims(images, axis=3) 109 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + dataset + '_images'), images) 110 | np.savez_compressed(os.path.join(self.data_folder, npz_prefix + dataset + '_index'), index) 111 | 112 | assert split_type in ['training', 'validation', 'test', 'all'], 'Unknown split_type: ' + split_type 113 | 114 | if split_type == 'all': 115 | return images, index 116 | 117 | volumes = self.splits()[split][split_type] 118 | images = np.concatenate([images[index == v] for v in volumes]) 119 | index = np.concatenate([index[index==v] for v in volumes]) 120 | return images, index 121 | -------------------------------------------------------------------------------- /loaders/data.py: -------------------------------------------------------------------------------- 1 | from skimage.measure import block_reduce 2 | import numpy as np 3 | import logging 4 | log = logging.getLogger('data') 5 | 6 | 7 | class Data(object): 8 | """ 9 | Object that stores images and masks loaded from a dataset. It defines useful functions for data manipulation and 10 | data selection. 11 | """ 12 | 13 | def __init__(self, images, masks, index, downsample=1): 14 | """ 15 | Data constructor. 16 | :param images: a 4-D numpy array of images. Expected shape: (N, H, W, 1) 17 | :param masks: a 4-D numpy array of myocardium segmentation masks. Expected shape: (N, H, W, 1) 18 | :param index: a 1-D numpy array indicating the volume each image/mask belongs to. Used for data selection. 19 | """ 20 | if images is None: 21 | raise ValueError('Images cannot be None.') 22 | if masks is None: 23 | raise ValueError('Masks cannot be None.') 24 | if index is None: 25 | raise ValueError('Index cannot be None.') 26 | if images.shape != masks.shape: 27 | raise ValueError('Image shape=%s different from Mask shape=%s' % (str(images.shape), str(masks.shape))) 28 | if images.shape[0] != index.shape[0]: 29 | raise ValueError('Different number of images and indices: %d vs %d' % (images.shape[0], index.shape[0])) 30 | 31 | self.images = images 32 | self.masks = masks 33 | self.index = index 34 | self.downsample(downsample) 35 | num_volumes = len(self.volumes()) 36 | log.info('Created Data object with images of shape %s and %d volumes' % (str(images.shape), num_volumes)) 37 | log.info('Images value range [%.1f, %.1f]' % (images.min(), images.max())) 38 | log.info('Masks value range [%.1f, %.1f]' % (masks.min(), masks.max())) 39 | 40 | def volumes(self): 41 | return sorted(set(self.index)) 42 | 43 | def get_volume_image(self, vol): 44 | return self.images[self.index == vol] 45 | 46 | def get_volume_mask(self, vol): 47 | return self.masks[self.index == vol] 48 | 49 | def size(self): 50 | return len(self.images) 51 | 52 | def resize(self, num): 53 | self.images = self.images[:num] 54 | self.masks = self.masks[:num] 55 | 56 | def shape(self): 57 | return self.images.shape 58 | 59 | def downsample(self, ratio=2): 60 | if ratio == 1: 61 | return 62 | 63 | self.images = block_reduce(self.images, block_size=(1, ratio, ratio, 1), func=np.mean) 64 | if self.masks is not None: 65 | self.masks = block_reduce(self.masks, block_size=(1, ratio, ratio, 1), func=np.mean) 66 | log.info('Downsampled data by %d to shape %s' % (ratio, str(self.images.shape))) 67 | 68 | def sample(self, nb_samples, seed=-1): 69 | log.info('Sampling %d images out of total %d' % (nb_samples, self.size())) 70 | if seed > -1: 71 | np.random.seed(seed) 72 | 73 | idx = np.random.choice(self.size(), size=nb_samples, replace=False) 74 | log.debug('Indices sampled: ' + str(idx)) 75 | self.images = np.array([self.images[i] for i in idx]) 76 | self.masks = np.array([self.masks[i] for i in idx]) 77 | self.index = np.array([self.index[i] for i in idx]) 78 | -------------------------------------------------------------------------------- /loaders/loader_factory.py: -------------------------------------------------------------------------------- 1 | from loaders.acdc import ACDCLoader 2 | 3 | 4 | def init_loader(dataset): 5 | """ 6 | Factory method for initialising data loaders by name. 7 | """ 8 | if dataset == 'acdc': 9 | return ACDCLoader() 10 | return None 11 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Entry point for running an experiment with SDNet. 3 | """ 4 | import argparse 5 | import os 6 | import logging 7 | from config import Configuration 8 | from loaders import loader_factory 9 | from sdnet import SDNet 10 | from sdnet_trainer import SDNetTrainer 11 | 12 | 13 | def init_logging(config): 14 | if not os.path.exists(config.folder): 15 | os.makedirs(config.folder) 16 | logging.basicConfig(filename=config.folder + '/logfile.log', level=logging.DEBUG, format='%(asctime)s %(message)s') 17 | logging.getLogger().addHandler(logging.StreamHandler()) 18 | 19 | log = logging.getLogger() 20 | log.debug(config.__dict__) 21 | log.info('---- Setting up experiment at ' + config.folder + '----') 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser(description='Run SDNet') 26 | parser.add_argument('--epochs', help='Number of epochs to train', type=int) 27 | parser.add_argument('--dataset', help='Dataset to use', choices=['acdc'], required=True) 28 | parser.add_argument('--description', help='Experiment description') 29 | parser.add_argument('--split', help='Split for Cross Validation', type=int, required=True) 30 | parser.add_argument('--test', help='Test', type=bool) 31 | parser.add_argument('--ul_mix', help='Percentage of unlabelled data to mix', type=float, required=True) 32 | parser.add_argument('--l_mix', help='Percentage of labelled data to mix', type=float, required=True) 33 | args = parser.parse_args() 34 | 35 | # Create configuration object from parameters 36 | loader = loader_factory.init_loader(args.dataset) 37 | data = loader.load_labelled_data(args.split, 'training') 38 | 39 | folder = 'sdnet_%s_ul_%.3f_l_%.3f_split%d' % (args.dataset, args.ul_mix, args.l_mix, args.split) 40 | conf = Configuration(folder, data.size(), data.shape()[1:]) 41 | del data 42 | 43 | conf.description = args.description if args.description else '' 44 | if args.epochs: 45 | conf.epochs = args.epochs 46 | conf.dataset_name = args.dataset 47 | conf.ul_mix = args.ul_mix 48 | conf.l_mix = args.l_mix 49 | conf.split = args.split 50 | conf.save() 51 | 52 | init_logging(conf) 53 | 54 | sdnet = SDNet(conf) 55 | sdnet.build() 56 | 57 | trainer = SDNetTrainer(sdnet, conf) 58 | 59 | if not args.test: 60 | trainer.fit() 61 | trainer.test() 62 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/spatial_factorisation/233d72511ffb52f52214a68f1c996555345991d0/models/__init__.py -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | from keras import Input, Model 2 | from keras.layers import Conv2D, LeakyReLU, Flatten, Dense 3 | import logging 4 | log = logging.getLogger('discriminator') 5 | 6 | class Discriminator(object): 7 | """ 8 | LS-GAN Discriminator 9 | """ 10 | def __init__(self, inp_shape, output='2D', downsample_blocks=3, name=''): 11 | """ 12 | Discriminator constructor 13 | :param inp_shape: 3-D input shape: (H, W, 1) 14 | :param output: can be 1D if there's a single decision or 2D if the output is a 2D image 15 | :param downsample_blocks: number of downsample blocks 16 | :param name: Model name 17 | """ 18 | super(Discriminator, self).__init__() 19 | 20 | self.inp_shape = inp_shape 21 | self.output = output 22 | self.downsample_blocks = downsample_blocks 23 | self.name = name 24 | self.model = None 25 | 26 | def build(self): 27 | f = 32 28 | 29 | d_input = Input(self.inp_shape) 30 | l = Conv2D(f, 4, strides=2, padding='same')(d_input) 31 | l = LeakyReLU(0.2)(l) 32 | 33 | for i in range(self.downsample_blocks): 34 | s = 1 if i == self.downsample_blocks - 1 else 2 35 | l = Conv2D(f * 2 * (2 ** i), 4, strides=s, padding='same')(l) 36 | l = LeakyReLU(0.2)(l) 37 | 38 | if self.output == '2D': 39 | l = Conv2D(1, 4, padding='same')(l) 40 | elif self.output == '1D': 41 | l = Flatten()(l) 42 | l = Dense(1, activation='linear')(l) 43 | 44 | self.model = Model(d_input, l, name=self.name) 45 | log.info('Discriminator %s' % self.name) 46 | self.model.summary(print_fn=log.info) 47 | 48 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from keras import Input, Model 2 | from keras.layers import Conv2D, BatchNormalization, Add, UpSampling2D, LeakyReLU, Lambda 3 | import keras.backend as K 4 | from layers.instance_normalization import InstanceNormalization 5 | 6 | 7 | class ResNet(object): 8 | """ 9 | A ResNet neural network for image synthesis 10 | """ 11 | def __init__(self, input_shape, norm=None, nb_blocks=6, name=''): 12 | """ 13 | Resnet constructor. 14 | :param input_shape: image shape (H, W, 1) 15 | :param norm: layer normalisation: Can be batch for BatchNormalization or norm for InstanceNormalization 16 | :param nb_blocks: number of residual blocks 17 | :param name: model name 18 | """ 19 | self.input_shape = input_shape 20 | self.norm = norm 21 | self.nb_blocks = nb_blocks 22 | self.name = name 23 | self.model = None 24 | 25 | def build(self): 26 | """ 27 | Build the model 28 | """ 29 | input = Input(self.input_shape) 30 | l = self.downsample(input) 31 | l = self.residuals(l) 32 | l = self.upsample(l) 33 | self.output(input, l) 34 | 35 | def downsample(self, input): 36 | """ 37 | Build downsample layers: c7s1-32, c3s2-64, c3s2-128 38 | :param input: input layer 39 | :return: last layer of downsample operation 40 | """ 41 | f = 32 42 | 43 | # c7s1-32 44 | l = Conv2D(f, 7, padding='same')(input) 45 | l = normalise(self.norm, inshape=K.int_shape(l))(l) 46 | l = LeakyReLU()(l) 47 | 48 | # c3s2-64 49 | l = Conv2D(f * 2, 3, strides=2, padding='same')(l) 50 | l = normalise(self.norm, inshape=K.int_shape(l))(l) 51 | l = LeakyReLU()(l) 52 | 53 | # c3s2-128 54 | l = Conv2D(f * 4, 3, strides=2, padding='same')(l) 55 | l = normalise(self.norm, inshape=K.int_shape(l))(l) 56 | l = LeakyReLU()(l) 57 | 58 | return l 59 | 60 | def residuals(self, l, f=32 * 4): 61 | """ 62 | Build residual layers: R128 * nb_blocks 63 | :param l: input layers 64 | :param f: number of feature maps 65 | :return: the last layer of the residuals 66 | """ 67 | for block in range(self.nb_blocks): 68 | l = residual_block(l, f, self.norm) 69 | return l 70 | 71 | def upsample(self, l): 72 | """ 73 | Build uplample layers: u64, u32 74 | :param l: input layer 75 | :return: the last layer of the upsample operation 76 | """ 77 | f = 32 78 | 79 | # u64 80 | l = UpSampling2D(size=2)(l) 81 | l = Conv2D(f * 2, 3, padding='same')(l) 82 | l = normalise(self.norm, inshape=K.int_shape(l))(l) 83 | l = LeakyReLU()(l) 84 | 85 | # u32 86 | l = UpSampling2D(size=2)(l) 87 | l = Conv2D(f, 3, padding='same')(l) 88 | l = normalise(self.norm, inshape=K.int_shape(l))(l) 89 | l = LeakyReLU()(l) 90 | 91 | return l 92 | 93 | def output(self, input, l): 94 | """ 95 | Build last output layer and a ResNet model 96 | :param input: input layer 97 | :param l: last upsample layer 98 | """ 99 | l = Conv2D(1, 7, activation='tanh', padding='same')(l) 100 | self.model = Model(inputs=input, outputs=l, name=self.name) 101 | 102 | 103 | def normalise(norm=None, **kwargs): 104 | """ 105 | Build a Keras normalization layer 106 | :param norm: normalization option 107 | :return: a normalization layer 108 | """ 109 | if norm == 'instance': 110 | return InstanceNormalization(**kwargs) 111 | elif norm == 'batch': 112 | return BatchNormalization() 113 | else: 114 | return Lambda(lambda x: x) 115 | 116 | 117 | def residual_block(l0, f, norm): 118 | """ 119 | Build residual block 120 | :param l0: first layer 121 | :param f: number of feature maps 122 | :param norm: normalization type 123 | :return: last layer of the block 124 | """ 125 | l = Conv2D(f, 3, strides=1, padding='same')(l0) 126 | l = normalise(norm)(l) 127 | l = LeakyReLU()(l) 128 | l = Conv2D(f, 3, strides=1, padding='same')(l) 129 | l = normalise(norm)(l) 130 | l = Add()([l0, l]) 131 | return LeakyReLU()(l) 132 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from keras import Input, Model 2 | from keras.layers import Concatenate, Conv2D, BatchNormalization, LeakyReLU, UpSampling2D, Activation, MaxPooling2D 3 | 4 | 5 | class UNet(object): 6 | """ 7 | UNet implementation of 4 downsampling and 4 upsampling blocks for segmentation. 8 | Each block has 2 convolutions, batch normalisation, leaky relu and an optional residual connection. 9 | The number of filters for the 1st layer is 64 and at every block, this is doubled. Each upsampling blocks halves the 10 | number of filters. 11 | """ 12 | def __init__(self, input_shape, residual, f=64): 13 | """ 14 | Constructor 15 | :param input_shape: image shape: (H, W, 1) 16 | :param residual: option for residual blocks in the downsampling and upsampling paths 17 | :param f: number of feature maps in the first layer 18 | """ 19 | self.input_shape = input_shape 20 | self.residual = residual 21 | self.f = f 22 | 23 | # model layers 24 | self.model = None # the Keras model 25 | self.input = None # input layer 26 | self.d_l0 = None # downsample layer 1 27 | self.d_l1 = None # downsample layer 2 28 | self.d_l2 = None # downsample layer 3 29 | self.d_l3 = None # downsample layer 4 30 | self.bottleneck = None # most downsampled UNet layer 31 | self.u_l3 = None # upsample layer 1 32 | self.u_l2 = None # upsample layer 2 33 | self.u_l1 = None # upsample layer 3 34 | self.u_l0 = None # upsample layer 4 35 | 36 | def build(self): 37 | """ 38 | Build the model. 39 | """ 40 | self.input = Input(shape=self.input_shape) 41 | l = self.unet_downsample(self.input) 42 | self.unet_bottleneck(l) 43 | l = self.unet_upsample(self.bottleneck) 44 | out = self.out(l) 45 | self.model = Model(inputs=self.input, outputs=out) 46 | 47 | def unet_downsample(self, inp): 48 | """ 49 | Build downsampling path 50 | :param inp: input layer 51 | :return: last layer of the downsampling path 52 | """ 53 | self.d_l0 = conv_block(inp, self.f, self.residual) 54 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l0) 55 | self.d_l1 = conv_block(l, self.f * 2, self.residual) 56 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l1) 57 | self.d_l2 = conv_block(l, self.f * 4, self.residual) 58 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l2) 59 | self.d_l3 = conv_block(l, self.f * 8) 60 | l = MaxPooling2D(pool_size=(2, 2))(self.d_l3) 61 | return l 62 | 63 | def unet_bottleneck(self, l): 64 | """ 65 | Build bottleneck layers 66 | :param l: the input layer 67 | """ 68 | self.bottleneck = conv_block(l, self.f * 16, self.residual) 69 | 70 | def unet_upsample(self, l): 71 | """ 72 | Build upsampling path 73 | :param l: the input layer 74 | :return: the last layer of the upsampling path 75 | """ 76 | l = upsample_block(l, self.f * 8, activation='linear') 77 | l = Concatenate()([l, self.d_l3]) 78 | self.u_l3 = conv_block(l, self.f * 8) 79 | l = upsample_block(self.u_l3, self.f * 4, activation='linear') 80 | l = Concatenate()([l, self.d_l2]) 81 | self.u_l2 = conv_block(l, self.f * 4, self.residual) 82 | l = upsample_block(self.u_l2, self.f * 2, activation='linear') 83 | l = Concatenate()([l, self.d_l1]) 84 | self.u_l1 = conv_block(l, self.f * 2, self.residual) 85 | l = upsample_block(self.u_l1, self.f, activation='linear') 86 | l = Concatenate()([l, self.d_l0]) 87 | self.u_l0 = conv_block(l, self.f, self.residual) 88 | return self.u_l0 89 | 90 | def out(self, l): 91 | """ 92 | Build ouput layer 93 | :param l: last layer from the upsampling path 94 | :return: the final segmentation layer 95 | """ 96 | return Conv2D(1, 1, activation='sigmoid')(l) 97 | 98 | 99 | def conv_block(l0, f, residual=False): 100 | """ 101 | Convolutional block of the downsampling path 102 | :param l0: the input layer 103 | :param f: number of feature maps 104 | :param residual: True/False to define residual connections 105 | :return: the last layer of the convolutional block 106 | """ 107 | l = Conv2D(f, 3, strides=1, padding='same')(l0) 108 | l = BatchNormalization()(l) 109 | l = LeakyReLU()(l) 110 | l = Conv2D(f, 3, strides=1, padding='same')(l) 111 | l = BatchNormalization()(l) 112 | l = LeakyReLU()(l) 113 | return Concatenate()([l0, l]) if residual else l 114 | 115 | 116 | def upsample_block(l0, f, activation='relu'): 117 | """ 118 | Upsampling block. 119 | :param l0: input layer 120 | :param f: number of feature maps 121 | :param activation: activation name 122 | :return: the last layer of the upsampling block 123 | """ 124 | l = UpSampling2D(size=2)(l0) 125 | l = Conv2D(f, 3, padding='same')(l) 126 | l = BatchNormalization()(l) 127 | return Activation(activation)(l) 128 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | conf = { 2 | 'data_folder': '../data', 3 | 'acdc': '../../data/Cardiac/ACDC/segmentation/training' 4 | } 5 | -------------------------------------------------------------------------------- /sdnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') 7 | 8 | from keras import Model 9 | from keras.layers import Input, Flatten, Dense, Concatenate, Conv2D, Reshape, BatchNormalization, LeakyReLU, \ 10 | UpSampling2D 11 | from keras.optimizers import Adam 12 | 13 | import costs 14 | from loaders import loader_factory 15 | from models.discriminator import Discriminator 16 | from layers.rounding import Rounding 17 | from models.resnet import ResNet 18 | from models.unet import UNet 19 | 20 | log = logging.getLogger('sdnet') 21 | 22 | 23 | class SDNet(object): 24 | """ 25 | SDNet model for semi-supervised segmentation. 26 | """ 27 | 28 | def __init__(self, conf): 29 | """ 30 | SDNet constructor 31 | :param conf: configuration object 32 | """ 33 | super(SDNet, self).__init__() 34 | self.other_masks = None 35 | self.conf = conf 36 | self.loader = loader_factory.init_loader(self.conf.dataset_name) 37 | 38 | self.D_model = None # Discriminator trainer 39 | self.G_model = None # Unsupervised generator trainer 40 | self.G_supervised_model = None # Supervised generator trainer 41 | self.Decomposer = None # Decomposer 42 | self.Reconstructor = None # Reconstructor 43 | self.ImageDiscriminator = None # Image discriminator 44 | self.MaskDiscriminator = None # Mask discriminator 45 | 46 | def build(self): 47 | self.build_discriminator_trainer() 48 | self.build_generator_trainer() 49 | self.load_models() 50 | 51 | def load_models(self): 52 | """ 53 | Load weights from saved model files 54 | """ 55 | if os.path.exists(self.conf.folder + '/D_model'): 56 | print('Loading trained D_model from file') 57 | self.D_model.load_weights(self.conf.folder + '/D_model') 58 | self.ImageDiscriminator = get_net(self.D_model, 'D_X') 59 | self.MaskDiscriminator = get_net(self.D_model, 'D_M') 60 | 61 | if os.path.exists(self.conf.folder + '/G_model'): 62 | print('Loading trained G_model from file') 63 | self.G_model.load_weights(self.conf.folder + '/G_model') 64 | self.Decomposer = get_net(self.G_model, 'Decomposer') 65 | self.Reconstructor = get_net(self.G_model, 'Reconstructor') 66 | 67 | if os.path.exists(self.conf.folder + '/G_supervised_model'): 68 | print('Loading trained G_supervised_model from file') 69 | self.G_supervised_model.load_weights(self.conf.folder + '/G_supervised_model') 70 | self.Decomposer = get_net(self.G_model, 'Decomposer') 71 | self.Reconstructor = get_net(self.G_model, 'Reconstructor') 72 | 73 | def save_models(self): 74 | """ 75 | Save model weights in files. 76 | """ 77 | self.D_model.save_weights(self.conf.folder + '/D_model') 78 | self.G_model.save_weights(self.conf.folder + '/G_model') 79 | self.G_supervised_model.save_weights(self.conf.folder + '/G_supervised_model') 80 | 81 | def build_discriminator_trainer(self): 82 | """ 83 | Build a Keras model for training image and mask discriminators. 84 | """ 85 | # Mask Discriminator 86 | D_Mask = Discriminator(self.conf.input_shape, output='2D', downsample_blocks=3, name='D_M') 87 | D_Mask.build() 88 | self.MaskDiscriminator = D_Mask.model 89 | 90 | real_M = Input(self.conf.input_shape) 91 | fake_M = Input(self.conf.input_shape) 92 | dis_real_M = self.MaskDiscriminator(real_M) 93 | dis_fake_M = self.MaskDiscriminator(fake_M) 94 | 95 | D_Image = Discriminator(self.conf.input_shape, output='2D', downsample_blocks=3, name='D_X') 96 | D_Image.build() 97 | self.ImageDiscriminator = D_Image.model 98 | 99 | real_X = Input(self.conf.input_shape) 100 | fake_X = Input(self.conf.input_shape) 101 | dis_real_X = self.ImageDiscriminator(real_X) 102 | dis_fake_X = self.ImageDiscriminator(fake_X) 103 | 104 | self.D_model = Model(inputs=[real_M, fake_M, real_X, fake_X], 105 | outputs=[dis_real_M, dis_fake_M, dis_real_X, dis_fake_X]) 106 | self.D_model.compile(Adam(lr=0.0001, beta_1=0.5), loss='mse') 107 | log.info('Discriminators Trainer') 108 | self.D_model.summary(print_fn=log.info) 109 | 110 | def build_generator_trainer(self): 111 | """ 112 | Build Decomposer, Reconstructor and training models. 113 | """ 114 | assert self.D_model is not None, 'Discriminator has not been built yet' 115 | make_trainable(self.D_model, False) 116 | 117 | self.Decomposer = self._decomposer() 118 | self.Reconstructor = self._reconstructor() 119 | 120 | self.build_unsupervised_trainer() 121 | self.build_supervised_trainer() 122 | 123 | def build_unsupervised_trainer(self): 124 | """ 125 | Build a Keras model for training SDNet with no supervision, using adversarial training with a mask 126 | discriminator and an image reconstruction cost. 127 | """ 128 | # Decomposition/Segmentation X -> M', Z 129 | real_X = Input(self.conf.input_shape) 130 | fake_M, fake_Z = self.Decomposer(real_X) 131 | adv_M = self.MaskDiscriminator(fake_M) 132 | 133 | # Reconstruction M', Z' -> X' 134 | rec_X = self.Reconstructor([fake_M, fake_Z]) 135 | adv_X = self.ImageDiscriminator(rec_X) 136 | 137 | self.G_model = Model(inputs=real_X, outputs=[adv_M, rec_X, adv_X]) 138 | self.G_model.compile(Adam(lr=0.0001, beta_1=0.5), loss=['mse', 'mae', 'mse'], 139 | loss_weights=[self.conf.w_uns_adv_M, self.conf.w_uns_rec_X, self.conf.w_uns_adv_X]) 140 | log.info('Unsupervised trainer') 141 | self.G_model.summary(print_fn=log.info) 142 | 143 | def build_supervised_trainer(self): 144 | """ 145 | Build a Keras model for training SDNet with supervision, when we have labelled data. 146 | """ 147 | # Decomposition/Segmentation X -> M', Z' 148 | real_X = Input(self.conf.input_shape) 149 | fake_M, fake_Z = self.Decomposer(real_X) 150 | adv_M = self.MaskDiscriminator(fake_M) 151 | 152 | # Reconstruction M', Z' -> X' 153 | rec_X = self.Reconstructor([fake_M, fake_Z]) 154 | 155 | # Reconstruction using a real Mask: M, Z' -> X' 156 | real_M = Input(self.conf.input_shape) 157 | fake_X = self.Reconstructor([real_M, fake_Z]) 158 | adv_X = self.ImageDiscriminator(fake_X) 159 | 160 | self.G_supervised_model = Model(inputs=[real_X, real_M], outputs=[fake_M, fake_X, rec_X, adv_M, adv_X]) 161 | self.G_supervised_model.compile(Adam(lr=0.0001, beta_1=0.5), 162 | loss=[costs.dice_coef_loss, 'mae', 'mae', 'mse', 'mse'], 163 | loss_weights=[self.conf.w_fake_M, self.conf.w_fake_X, self.conf.w_rec_X, 164 | self.conf.w_adv_M, self.conf.w_adv_X_fromreal]) 165 | log.info('Supervised trainer') 166 | self.G_supervised_model.summary(print_fn=log.info) 167 | 168 | def _decomposer(self): 169 | """ 170 | Build an image decomposer into a spatial binary mask of the myocardium and a non-spatial vector z of the 171 | remaining image information. 172 | :return a Keras model of the decomposer 173 | """ 174 | input = Input(self.conf.input_shape) 175 | 176 | unet = UNet(self.conf.input_shape, residual=False) 177 | l = unet.unet_downsample(input) 178 | unet.unet_bottleneck(l) 179 | l = unet.bottleneck 180 | 181 | # build Z regressor 182 | modality = Conv2D(256, 3, strides=1, padding='same')(l) 183 | modality = BatchNormalization()(modality) 184 | modality = LeakyReLU()(modality) 185 | modality = Conv2D(64, 3, strides=1, padding='same')(modality) 186 | modality = BatchNormalization()(modality) 187 | modality = LeakyReLU()(modality) 188 | modality = Flatten()(modality) 189 | modality = Dense(32)(modality) 190 | modality = LeakyReLU()(modality) 191 | modality = Dense(16, activation='sigmoid')(modality) 192 | 193 | l = unet.unet_upsample(unet.bottleneck) 194 | anatomy = unet.out(l) 195 | 196 | m = Model(inputs=input, outputs=[anatomy, modality], name='Decomposer') 197 | log.info('Decomposer') 198 | m.summary(print_fn=log.info) 199 | return m 200 | 201 | def _reconstructor(self): 202 | """ 203 | Build an image reconstructor, that fuses an anatomy (binary mask) and Z to reconstructs the input image. 204 | :return: a Keras model of the reconstructor 205 | """ 206 | mask_input = Input(shape=self.conf.input_shape) 207 | round = Rounding()(mask_input) # rounding layer that binarises the anatomical representation. 208 | 209 | resnet = ResNet(self.conf.input_shape, norm='instance', nb_blocks=3, name='Reconstructor') 210 | 211 | # Map Z into a 8-channel feature map 212 | resd_input = Input((16,)) 213 | modality = Dense(32)(resd_input) 214 | modality = LeakyReLU()(modality) 215 | modality = Dense(self.conf.input_shape[0] * self.conf.input_shape[1])(modality) 216 | modality = LeakyReLU()(modality) 217 | modality = Reshape((int(self.conf.input_shape[0] / 4), int(self.conf.input_shape[1] / 4), 16))(modality) 218 | modality = UpSampling2D(size=2)(modality) 219 | modality = Conv2D(16, 3, padding='same')(modality) 220 | modality = BatchNormalization()(modality) 221 | modality = LeakyReLU()(modality) 222 | modality = UpSampling2D(size=2)(modality) 223 | modality = Conv2D(8, 3, padding='same')(modality) 224 | modality = BatchNormalization()(modality) 225 | modality = LeakyReLU()(modality) 226 | 227 | # Concatenate Mask and Z 228 | conc_lr = Concatenate()([round, modality]) 229 | l = resnet.residuals(conc_lr, f=9) 230 | resnet.output([mask_input, resd_input], l) 231 | resnet.model.summary(print_fn=log.info) 232 | return resnet.model 233 | 234 | 235 | def get_net(trainer_model, name): 236 | """ 237 | Helper method to get a layer with a given name out of a model 238 | :param trainer_model: base model 239 | :param name: layer name 240 | :return: a layer with the specified name 241 | """ 242 | layers = [l for l in trainer_model.layers if l.name == name] 243 | assert len(layers) == 1 244 | return layers[0] 245 | 246 | 247 | def make_trainable(model, val): 248 | """ 249 | Helper method to enable/disable training of a model 250 | :param model: a Keras model 251 | :param val: True/False 252 | """ 253 | model.trainable = val 254 | try: 255 | for l in model.layers: 256 | try: 257 | for k in l.layers: 258 | make_trainable(k, val) 259 | except: 260 | # Layer is not a model, so continue 261 | pass 262 | l.trainable = val 263 | except: 264 | # Layer is not a model, so continue 265 | pass 266 | -------------------------------------------------------------------------------- /sdnet_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import scipy 5 | from keras.callbacks import CSVLogger, EarlyStopping 6 | from keras.utils import Progbar 7 | 8 | import costs 9 | from callbacks.loss_callback import SaveLoss 10 | from callbacks.sdnet_callback import SDNetCallback 11 | from loaders import loader_factory 12 | from utils import data_utils 13 | 14 | log = logging.getLogger('sdnettrainer') 15 | 16 | 17 | class SDNetTrainer(object): 18 | """ 19 | Trainer class for running a segmentation experiment using SDNet. 20 | """ 21 | def __init__(self, sdnet, conf): 22 | self.sdnet = sdnet 23 | self.conf = conf 24 | self.loader = loader_factory.init_loader(self.conf.dataset_name) 25 | 26 | # Data iterators 27 | self.gen_X_L = None # labelled data: (image, mask) pairs 28 | self.gen_X_U = None # unlabelled data 29 | self.other_masks = None # real masks to use for discriminator training 30 | 31 | self.fake_image_pool = [] 32 | self.fake_mask_pool = [] 33 | self.batch = 0 34 | self.epoch = 0 35 | 36 | if not os.path.exists(self.conf.folder): 37 | os.makedirs(self.conf.folder) 38 | 39 | def init_train(self): 40 | """ 41 | Initialise data generators for iterating through the images and masks. 42 | """ 43 | data = self.loader.load_labelled_data(self.conf.split, 'training') 44 | 45 | # Initialise unlabelled data iterator 46 | num_ul = 0 47 | if self.conf.ul_mix > 0: 48 | ul_data = self.loader.load_unlabelled_data(self.conf.split, 'all') 49 | 50 | # calculate number of unlabelled images as a proportion of the labelled images 51 | num_ul = int(data.size() * self.conf.ul_mix) 52 | num_ul = num_ul if num_ul <= ul_data.size() else ul_data.size() 53 | log.info('Sampling %d unlabelled images out of total %d.' % (num_ul, ul_data.size())) 54 | ul_data.sample(num_ul) 55 | self.gen_X_U = data_utils.generator(self.conf.batch_size, 'overflow', ul_data.images) 56 | 57 | # Initialise labelled data iterator 58 | assert self.conf.l_mix >= 0 59 | 60 | # calculate number of labelled images 61 | num_l = int(data.size() * self.conf.l_mix) 62 | num_l = num_l if num_l <= data.size() else data.size() 63 | log.info('Using %d labelled images out of total %d.' % (num_l, data.size())) 64 | train_images = data.images[:num_l] 65 | train_masks = data.masks[:num_l] 66 | 67 | self.conf.unlabelled_image_num = num_ul 68 | self.conf.labelled_image_num = num_l 69 | self.conf.data_len = num_ul if num_ul > num_l else num_l 70 | self.conf.batches = int(np.ceil(self.conf.data_len / self.conf.batch_size)) 71 | self.conf.save() 72 | 73 | self.gen_X_L = data_utils.generator(self.conf.batch_size, 'overflow', train_images, train_masks) 74 | 75 | # Initialise real masks iterator for discriminator training, using the real masks from the data CV split. 76 | self.other_masks = data_utils.generator(self.conf.batch_size, 'overflow', data.masks + 0) 77 | 78 | def fit(self): 79 | """ 80 | Train SDNet 81 | """ 82 | log.info('Training SDNet') 83 | 84 | # Load data 85 | self.init_train() 86 | 87 | # Initialise callbacks 88 | sl = SaveLoss(self.conf.folder) 89 | cl = CSVLogger(self.conf.folder + '/training.csv') 90 | cl.on_train_begin() 91 | si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet) 92 | es = EarlyStopping('val_loss', min_delta=0.001, patience=20) 93 | es.on_train_begin() 94 | 95 | loss_names = ['adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z', 'dis_M', 'dis_X', 'mask', 'image', 'val_loss'] 96 | 97 | total_loss = {n: [] for n in loss_names} 98 | 99 | progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size) 100 | 101 | for self.epoch in range(self.conf.epochs): 102 | log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs)) 103 | 104 | real_lb_pool, real_ul_pool = [], [] # these are used only for printing images 105 | 106 | epoch_loss = {n: [] for n in loss_names} 107 | 108 | D_initial_weights = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()]) 109 | G_initial_weights = np.mean([np.mean(w) for w in self.sdnet.G_model.get_weights()]) 110 | for self.batch in range(self.conf.batches): 111 | real_lb = next(self.gen_X_L) 112 | real_ul = next(self.gen_X_U) 113 | 114 | # Add image/mask batch to the data pool 115 | x, m = real_lb 116 | real_lb_pool.extend([(x[i:i+1], m[i:i+1]) for i in range(x.shape[0])]) 117 | real_ul_pool.extend(real_ul) 118 | 119 | D_weights1 = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()]) 120 | self.train_batch_generator(real_lb, real_ul, epoch_loss) 121 | D_weights2 = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()]) 122 | assert D_weights1 == D_weights2 123 | 124 | self.train_batch_discriminator(real_lb, real_ul, epoch_loss) 125 | 126 | progress_bar.update((self.batch + 1) * self.conf.batch_size) 127 | 128 | G_final_weights = np.mean([np.mean(w) for w in self.sdnet.G_model.get_weights()]) 129 | D_final_weights = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()]) 130 | 131 | # Check training is altering weights 132 | assert D_initial_weights != D_final_weights 133 | assert G_initial_weights != G_final_weights 134 | 135 | # Plot some example images 136 | si.on_epoch_end(self.epoch, np.array(real_lb_pool), np.array(real_ul_pool)) 137 | 138 | self.validate(epoch_loss) 139 | 140 | # Calculate epoch losses 141 | for n in loss_names: 142 | total_loss[n].append(np.mean(epoch_loss[n])) 143 | log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) % \ 144 | ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names))) 145 | logs = {l: total_loss[l][-1] for l in loss_names} 146 | sl.on_epoch_end(self.epoch, logs) 147 | 148 | # log losses to csv 149 | cl.model = self.sdnet.D_model 150 | cl.model.stop_training = False 151 | cl.on_epoch_end(self.epoch, logs) 152 | 153 | # save models 154 | self.sdnet.save_models() 155 | 156 | # early stopping 157 | if self.stop_criterion(es, self.epoch, logs): 158 | log.info('Finished training from early stopping criterion') 159 | break 160 | 161 | def train_batch_generator(self, real_lb, real_ul, epoch_loss): 162 | """ 163 | Train Generator networks. This is done in two passes: 164 | (a) Unsupervised training using unlabelled data and masks 165 | (b) Supervised training using labelled data and masks. 166 | :param real_lb: labelled tuple of images and masks 167 | :param real_ul: unlabelled images 168 | :param epoch_loss: loss dictionary for the epoch 169 | """ 170 | 171 | X_L, X_M = real_lb 172 | X_U = real_ul 173 | 174 | # Train unlabelled path (G_model) 175 | adv_M, rec_X, adv_X = [], [], [] 176 | if X_U.shape[0] > 0: 177 | zeros = np.zeros((X_U.shape[0],) + self.sdnet.D_model.output_shape[0][1:]) 178 | _, ul_l_adv_M, ul_l_rec_X, ul_l_adv_X = self.sdnet.G_model.train_on_batch(X_U, [zeros, X_U, zeros]) 179 | assert np.mean(ul_l_adv_M) >= 0, "loss_fake_M: " + str(ul_l_adv_M) 180 | assert ul_l_rec_X >= 0, "loss_rec_X: " + str(ul_l_rec_X) 181 | adv_M.append(ul_l_adv_M) 182 | rec_X.append(ul_l_rec_X) 183 | adv_X.append(ul_l_adv_X) 184 | 185 | # Train labelled path (G_supervised_model) 186 | if X_L.shape[0] > 0: 187 | zeros = np.zeros((X_L.shape[0],) + self.sdnet.D_model.output_shape[0][1:]) 188 | _, Z = self.sdnet.Decomposer.predict(X_L) 189 | x = [X_L, X_M] 190 | y = [X_M, X_L, X_L, zeros, zeros] 191 | _, l_mask, l_img, l_rec_X , l_adv_M, l_adv_X = self.sdnet.G_supervised_model.train_on_batch(x, y) 192 | epoch_loss['mask'].append(l_mask) 193 | epoch_loss['image'].append(l_img) 194 | adv_M.append(l_adv_M) 195 | rec_X.append(l_rec_X) 196 | adv_X.append(l_adv_X) 197 | 198 | epoch_loss['adv_M'].append(np.mean(adv_M)) 199 | epoch_loss['adv_X'].append(np.mean(adv_X)) 200 | epoch_loss['rec_X'].append(np.mean(rec_X)) 201 | 202 | def train_batch_discriminator(self, real_lb, real_ul, epoch_loss): 203 | """ 204 | Train a discriminator with real X / fake X and real M / fake M. To produce a fake X we use a real M and a Z 205 | produced by the mask's decomposition 206 | :param real_lb: tuple of labelled images and masks 207 | :param real_B: unlabelled images 208 | :param epoch_loss: dictionary of losses for the epoch 209 | """ 210 | X_L, X_M = real_lb 211 | X_U = real_ul 212 | 213 | # When reaching the end of the array, the array size might be less than the true batch size 214 | batch_size = np.min([X_L.shape[0], X_U.shape[0]]) 215 | 216 | if batch_size < X_M.shape[0]: 217 | idx = np.random.choice(X_M.shape[0], size=batch_size, replace=False) 218 | X_M = np.array([X_M[i] for i in idx]) 219 | 220 | X = self.sample_X(X_L, X_U, size=batch_size) 221 | fake_M, Z = self.sdnet.Decomposer.predict(X) 222 | fake_X = self.sdnet.Reconstructor.predict([X_M, Z]) 223 | 224 | # Pool of fake images. Using one pool regularises the Mask discriminator in the first epochs. 225 | self.fake_mask_pool, fake_M = self.get_fake(fake_M, self.fake_image_pool, size=batch_size) 226 | self.fake_image_pool, fake_X = self.get_fake(fake_X, self.fake_image_pool, size=batch_size) 227 | 228 | # If we have a pool of other images use some of it for real examples 229 | if self.other_masks is not None: 230 | M_other = next(self.other_masks) 231 | X_M = data_utils.sample(np.concatenate([X_M, M_other], axis=0), batch_size) 232 | 233 | # Train Discriminator 234 | zeros = np.zeros((X_M.shape[0],) + self.sdnet.D_model.output_shape[0][1:]) 235 | ones = np.ones(zeros.shape) 236 | 237 | x = [X_M, fake_M, X, fake_X] 238 | y = [zeros, ones, zeros, ones] 239 | _, D_loss_real_M, D_loss_fake_M, D_loss_real_X, D_loss_fake_X = self.sdnet.D_model.train_on_batch(x, y) 240 | epoch_loss['dis_M'].append(np.mean([D_loss_real_M, D_loss_fake_M])) 241 | epoch_loss['dis_X'].append(np.mean([D_loss_real_X, D_loss_fake_X])) 242 | 243 | def sample_X(self, X_L, X_U, size): 244 | """ 245 | Sample images of size=batch size from the labelled and unlabelled array. If we've passed through all 246 | labelled images, ignore them when sampling (and vice versa). 247 | :param X_L: array of labelled images 248 | :param X_U: array of unlabelled images 249 | :return: an image array to be used for calculating fake_masks 250 | """ 251 | # find the batch number that the iterator of the labelled (or unlabelled) images finishes 252 | bn_end = np.min([self.conf.unlabelled_image_num, self.conf.labelled_image_num]) / self.conf.batch_size 253 | if self.batch < bn_end: 254 | all = np.concatenate([X_L, X_U], axis=0) 255 | idx = np.random.choice(all.shape[0], size=size, replace=False) 256 | X = np.array([all[i] for i in idx]) 257 | elif self.conf.labelled_image_num > self.conf.unlabelled_image_num: 258 | idx = np.random.choice(X_L.shape[0], size=size, replace=False) 259 | X = np.array([X_L[i] for i in idx]) 260 | else: 261 | idx = np.random.choice(X_U.shape[0], size=size, replace=False) 262 | X = np.array([X_U[i] for i in idx]) 263 | 264 | return X 265 | 266 | def get_fake(self, pred, fake_pool, size): 267 | """ 268 | Add item to the pool of data. Then select a random number of items from the pool. 269 | :param pred: new datum to add to the pool 270 | :param fake_pool: the data pool of fake images/masks 271 | :return: the sampled data from the pool 272 | """ 273 | fake_pool.extend(pred) 274 | fake_pool = fake_pool[-self.conf.pool_size:] 275 | sel = np.random.choice(len(fake_pool), size=size, replace=False) 276 | fake_A = np.array([fake_pool[ind] for ind in sel]) 277 | return fake_pool, fake_A 278 | 279 | def validate(self, epoch_loss): 280 | """ 281 | Report validation error 282 | :param epoch_loss: dictionary of losses 283 | """ 284 | valid_data = self.loader.load_labelled_data(self.conf.split, 'validation') 285 | mask, _ = self.sdnet.Decomposer.predict(valid_data.images) 286 | assert mask.shape == valid_data.masks.shape 287 | epoch_loss['val_loss'].append((1-costs.dice(valid_data.masks, mask))) 288 | 289 | def stop_criterion(self, es, epoch, logs): 290 | """ 291 | Criterion for early stopping of training 292 | :param es: Keras EarlyStopping callback 293 | :param epoch: epoch number 294 | :param logs: dictionary of losses 295 | :return: True/False: stop training or not 296 | """ 297 | es.model = self.sdnet.Decomposer 298 | es.on_epoch_end(epoch, logs) 299 | if es.stopped_epoch > 0: 300 | return True 301 | 302 | def test(self): 303 | """ 304 | Evaluate a model on the test data. 305 | """ 306 | log.info('Evaluating model on test data') 307 | folder = os.path.join(self.conf.folder, 'test_results_%s' % self.conf.dataset_name) 308 | if not os.path.exists(folder): 309 | os.makedirs(folder) 310 | 311 | test_loader = loader_factory.init_loader(self.conf.dataset_name) 312 | test_data = test_loader.load_labelled_data(self.conf.split, 'test') 313 | 314 | synth = [] 315 | im_dice = {} 316 | samples = os.path.join(folder, 'samples') 317 | if not os.path.exists(samples): 318 | os.makedirs(samples) 319 | 320 | f = open(os.path.join(folder, 'results.csv'), 'w') 321 | f.writelines('Vol, Dice\n') 322 | 323 | for vol_i in test_data.volumes(): 324 | vol_folder = os.path.join(samples, 'vol_%s' % str(vol_i)) 325 | if not os.path.exists(vol_folder): 326 | os.makedirs(vol_folder) 327 | 328 | vol_image = test_data.get_volume_image(vol_i) 329 | vol_mask = test_data.get_volume_mask(vol_i) 330 | assert vol_image.shape[0] > 0 and vol_image.shape == vol_mask.shape 331 | pred, _ = self.sdnet.Decomposer.predict(vol_image) 332 | 333 | synth.append(pred) 334 | im_dice[vol_i] = costs.dice(vol_mask, pred) 335 | f.writelines('%s, %.3f\n' % (str(vol_i), im_dice[vol_i])) 336 | 337 | for i in range(vol_image.shape[0]): 338 | im = np.concatenate([vol_image[i, :, :, 0], pred[i, :, :, 0], vol_mask[i, :, :, 0]], axis=1) 339 | scipy.misc.imsave(os.path.join(vol_folder, 'test_vol%d_sl%d.png' % (vol_i, i)), im) 340 | 341 | print('Dice score: %.3f' % np.mean(list(im_dice.values()))) 342 | f.close() 343 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agis85/spatial_factorisation/233d72511ffb52f52214a68f1c996555345991d0/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import logging 4 | log = logging.getLogger('data_utils') 5 | 6 | 7 | def resample_ants(nii_file, nii_file_newres, new_res=(1.37, 1.37, 10, 1)): 8 | ''' 9 | Call ANTs to resample an image to a given resolution and save a new resampled file. 10 | :param nii_file: the path of the input file 11 | :param nii_file_newres: the path of the output file 12 | :param new_res: the pixel resolution to resample 13 | ''' 14 | print('Resampling %s at resolution %s to file %s' % (nii_file, str(new_res), nii_file_newres)) 15 | os.system('~/bin/ants/bin/ResampleImage %d %s %s %s' % 16 | (len(new_res), nii_file, nii_file_newres, 'x'.join([str(r) for r in new_res]))) 17 | 18 | 19 | def normalise(array, min_value, max_value): 20 | array = (max_value - min_value) * (array - float(array.min())) / (array.max() - array.min()) + min_value 21 | assert array.max() == max_value and array.min() == min_value 22 | return array 23 | 24 | 25 | def crop_same(image_list, mask_list, size=(None, None), mode='equal', pad_mode='constant'): 26 | ''' 27 | Crop the data in the image and mask lists, so that they have the same size. 28 | :param image_list: a list of images. Each element should be 4-dimensional, (sl,h,w,chn) 29 | :param mask_list: a list of masks. Each element should be 4-dimensional, (sl,h,w,chn) 30 | :param size: dimensions to crop the images to. 31 | :param mode: can be one of [equal, left, right]. Denotes where to crop pixels from. Defaults to middle. 32 | :param pad_mode: can be one of ['edge', 'constant']. 'edge' pads using the values of the edge pixels, 33 | 'constant' pads with a constant value 34 | :return: the modified arrays 35 | ''' 36 | min_w = np.min([m.shape[1] for m in mask_list]) if size[0] is None else size[0] 37 | min_h = np.min([m.shape[2] for m in mask_list]) if size[1] is None else size[1] 38 | 39 | # log.debug('Resizing list1 of size %s to size %s' % (str(image_list[0].shape), str((min_w, min_h)))) 40 | # log.debug('Resizing list2 of size %s to size %s' % (str(mask_list[0].shape), str((min_w, min_h)))) 41 | 42 | img_result, msk_result = [], [] 43 | for i in range(len(mask_list)): 44 | im = image_list[i] 45 | m = mask_list[i] 46 | 47 | if m.shape[1] > min_w: 48 | m = _crop(m, 1, min_w, mode) 49 | if im.shape[1] > min_w: 50 | im = _crop(im, 1, min_w, mode) 51 | if m.shape[1] < min_w: 52 | m = _pad(m, 1, min_w, pad_mode) 53 | if im.shape[1] < min_w: 54 | im = _pad(im, 1, min_w, pad_mode) 55 | 56 | if m.shape[2] > min_h: 57 | m = _crop(m, 2, min_h, mode) 58 | if im.shape[2] > min_h: 59 | im = _crop(im, 2, min_h, mode) 60 | if m.shape[2] < min_h: 61 | m = _pad(m, 2, min_h, pad_mode) 62 | if im.shape[2] < min_h: 63 | im = _pad(im, 2, min_h, pad_mode) 64 | 65 | img_result.append(im) 66 | msk_result.append(m) 67 | return img_result, msk_result 68 | 69 | 70 | def _crop(image, dim, nb_pixels, mode): 71 | diff = image.shape[dim] - nb_pixels 72 | if mode == 'equal': 73 | l = int(np.ceil(diff / 2)) 74 | r = image.shape[dim] - l 75 | elif mode == 'right': 76 | l = 0 77 | r = nb_pixels 78 | elif mode == 'left': 79 | l = diff 80 | r = image.shape[dim] 81 | else: 82 | raise 'Unexpected mode: %s. Expected to be one of [equal, left, right].' % mode 83 | 84 | if dim == 1: 85 | return image[:, l:r, :, :] 86 | elif dim == 2: 87 | return image[:, :, l:r, :] 88 | else: 89 | return None 90 | 91 | 92 | def _pad(image, dim, nb_pixels, mode): 93 | diff = nb_pixels - image.shape[dim] 94 | l = int(diff / 2) 95 | r = int(diff - l) 96 | if dim == 1: 97 | pad_width = ((0, 0), (l, r), (0, 0), (0, 0)) 98 | elif dim == 2: 99 | pad_width = ((0, 0), (0, 0), (l, r), (0, 0)) 100 | else: 101 | return None 102 | 103 | if mode == 'edge': 104 | new_image = np.pad(image, pad_width, 'edge') 105 | elif mode == 'constant': 106 | new_image = np.pad(image, pad_width, 'constant', constant_values=0) 107 | else: 108 | raise Exception('Invalid pad mode: ' + mode) 109 | 110 | return new_image 111 | 112 | 113 | def sample(data, nb_samples, seed=-1): 114 | if seed > -1: 115 | np.random.seed(seed) 116 | idx = np.random.choice(len(data), size=nb_samples, replace=False) 117 | return np.array([data[i] for i in idx]) 118 | 119 | 120 | def generator(batch, mode, *x): 121 | assert mode in ['overflow', 'no_overflow'] 122 | imshape = x[0].shape 123 | for ar in x: 124 | # case where all inputs are images 125 | if len(ar.shape) == len(imshape): 126 | assert ar.shape[:-1] == imshape[:-1], str(ar.shape) + ' vs ' + str(imshape) 127 | # case where inputs might be arrays of different dimensions 128 | else: 129 | assert ar.shape[0] == imshape[0], str(ar.shape) + ' vs ' + str(imshape) 130 | 131 | start = 0 132 | while 1: 133 | if isempty(*x): # if the arrays are empty do not process and yield empty arrays 134 | log.info('Empty inputs. Return empty arrays') 135 | res = [] 136 | for ar in x: 137 | res.append(np.empty(shape=ar.shape)) 138 | if len(res) > 1: 139 | yield res 140 | else: 141 | yield res[0] 142 | else: 143 | start, ims = generate(start, batch, mode, *x) 144 | if len(ims) == 1: 145 | yield ims[0] 146 | else: 147 | yield ims 148 | 149 | 150 | def isempty(*x): 151 | for ar in x: 152 | if ar.shape[0] > 0: 153 | return False 154 | return True 155 | 156 | 157 | def generate(start, batch, mode, *images): 158 | result = [] 159 | 160 | if mode == 'no_overflow': 161 | for ar in images: 162 | result.append(ar[start:start + batch] + 0) 163 | start += batch 164 | 165 | if start >= len(images[0]): 166 | index = np.array(range(len(images[0]))) 167 | np.random.shuffle(index) 168 | for ar in images: 169 | ar[:] = ar[index] # shuffle array 170 | start = 0 171 | 172 | return start, result 173 | 174 | if start + batch <= len(images[0]): 175 | for ar in images: 176 | result.append(ar[start:start + batch] + 0) 177 | start += batch 178 | return start, result 179 | else: 180 | # shuffle images 181 | index = np.array(range(len(images[0]))) 182 | np.random.shuffle(index) 183 | 184 | extra = batch + start - len(images[0]) # extra images to use from the beginning 185 | for ar in images: 186 | ims = ar[start:] + 0 # last images of array 187 | ar[:] = ar[index] # shuffle array 188 | if extra > 0: 189 | result.append(np.concatenate([ims, ar[0:extra]], axis=0)) 190 | 191 | return extra, result --------------------------------------------------------------------------------