├── .gitignore ├── LICENSE ├── README.md ├── augmentation_utils.py ├── demo.py └── images ├── augmented.png └── diagram_gmm.png /.gitignore: -------------------------------------------------------------------------------- 1 | ## Folders to ignore 2 | .ipynb_checkpoints 3 | .idea 4 | results/ 5 | data/ 6 | 7 | ## Filetypes to ignore 8 | *.pyc 9 | *.pdf 10 | *.xml 11 | *.json 12 | *.nii 13 | *.nii.gz 14 | *.code-workspace 15 | *.save 16 | ~*# 17 | 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 icometrix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An augmentation strategy to mimic multi-scanner variability in MRI 2 | 3 | Implementation of a data augmentation (DA) approach with the aim of reducing the scanner bias of models trained on 4 | single-scanner data. If you use this code please cite our paper: 5 | 6 | >Meyer, M.I., de la Rosa, E., Pedrosa de Barros, N., Paolella, R., Van Leemput, K., Sima, D.M. 7 | > [A Contrast Augmentation Approach to Improve Multi-Scanner Generalization in MRI.](https://www.frontiersin.org/articles/10.3389/fnins.2021.708196/full) 8 | > Front. Neurosci. 15:708196. doi: 10.3389/fnins.2021.708196 9 | 10 | 11 | Method 12 | ------------ 13 | The method aims to increase the intensity and contrast variability of a single-scanner dataset such that it is representative of the variability found in a large multi-scanner cohort. To do this we randomly modify the MRI tissue intensities using a Gaussian Mixture Model based approach. As a result the contrast between tissues varies, as seen when different scanners are used at acquisition. 14 | 15 | ![](images/diagram_gmm.png) 16 | 17 | Requirements 18 | ------------ 19 | The following packages are necessary to run the code: 20 | - python>=3.7 21 | - numpy>=1.18 22 | - matplotlib=3.3 23 | - scikit-learn=0.23 24 | - nibabel=3.2 25 | 26 | Applying the method to a single image 27 | ------------------------ 28 | The method is exemplified in the `demo.py` script. To run this you will need a T1 image and the corresponding brain 29 | mask in nifti format (nii.gz). 30 | We do not provide images to test the model on. The example results we show were 31 | acquired on the publicly available OASIS dataset [1]. 
32 | You can run this with all the default options by running 33 | 34 | ``` 35 | python demo.py -i -m 36 | ``` 37 | Optional arguments include: 38 | ``` 39 | -n number of components in the mixture (default: 3) 40 | -mu range of variability to change the mean of each component in the mixture (default: [0.03 0.06 0.08]) 41 | -s range of variability to change the standard deviation of each component in the mixture 42 | (default: [0.012 0.011 0.015]) 43 | -p percentiles to use when clipping the intensities (default: [1 99]) 44 | 45 | ``` 46 | The default ranges of variability for the components were estimated on a real clinical dataset, and are used as defaults to represent the method as submitted to ISBI. For the demo, if you wish to use a different number of components you also need to provide new values for -mu and -s, one per component. 47 | Let's say you want to run the method using 4 components, to allow the mean and standard deviation of each component 48 | to change by (0.1, 0.2, 0.3, 0.4) and (0.01, 0.02, 0.03, 0.04), and to clip the intensity percentiles at (0.1, 99.9): 49 | 50 | 51 | ``` 52 | python demo.py -i -m -n 4 -mu 0.1 0.2 0.3 0.4 -s 0.01 0.02 0.03 0.04 -p 0.1 99.9 53 | ``` 54 | 55 | This script saves a NIfTI image to a `results/` folder, which is created automatically if it does not exist. 56 | A comparison with the original image is also plotted and saved for inspection. 57 | 58 | ![](images/augmented.png) 59 | 60 | ##### NOTE: 61 | 62 | The method performs better if images have been bias-field corrected, but this is not strictly necessary. 63 | When exploring the method it is also possible to set a fixed variation term for each of the components (see 64 | function generate_gmm_image() in augmentation_utils.py). 65 | 66 | 67 | 68 | 69 | ---------------------------------- 70 | 71 | 72 | -------- 73 | 74 | ### References 75 | [1] Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner, RL.
76 | *Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults* 77 | Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498 78 | -------------------------------------------------------------------------------- /augmentation_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | File containing auxiliary functions for GMM-based data augmentation. 3 | """ 4 | 5 | import numpy as np 6 | from sklearn import mixture 7 | 8 | def normalize_image(image, clip_percentiles=False, pmin=1, pmax=99): 9 | """ 10 | Function to normalize the images between [0,1]. If percentiles is set to True it clips the intensities at 11 | percentile 1 and 99 12 | :param image: numpy array containing the image 13 | :param clip_percentiles: set to True to clip intensities. (default: False) 14 | :param pmin: lower percentile to clip 15 | :param pmax: upper percentile to clip 16 | :return: normalized image [0,1] 17 | """ 18 | if clip_percentiles is True: 19 | pmin = np.percentile(image, pmin) 20 | pmax = np.percentile(image, pmax) 21 | v = np.clip(image, pmin, pmax) 22 | else: 23 | v = image.copy() 24 | 25 | v_min = v.min(axis=(0, 1, 2), keepdims=True) 26 | v_max = v.max(axis=(0, 1, 2), keepdims=True) 27 | 28 | return (v - v_min) / (v_max - v_min) 29 | 30 | 31 | def select_component_size(x, min_components=1, max_components=5): 32 | """ 33 | Function that selects the optimal number of components for a given image, based on the BIC criterion. 34 | :param x: non-zero values of image. Should be of shape (N,1). 
Make sure to use x=np.expand_dims(data[data>0],1) 35 | :param min_components: minimum number of components allowed for the gmm 36 | :param max_components: upper bound (exclusive) on the number of components tried for the gmm 37 | :return: optimal number of components for x, and the corresponding fitted GMM 38 | """ 39 | lowest_bic = np.inf 40 | bic = [] 41 | n_components_range = range(min_components, max_components) 42 | cv_type = 'full' # covariance type for the GaussianMixture function 43 | best_gmm = None 44 | for n_components in n_components_range: 45 | # Fit a Gaussian mixture with EM 46 | gmm = mixture.GaussianMixture(n_components=n_components, 47 | covariance_type=cv_type) 48 | gmm.fit(x) 49 | bic.append(gmm.aic(x)) # note: this scores with AIC; use gmm.bic(x) for the BIC named in the docstring 50 | if bic[-1] < lowest_bic: 51 | lowest_bic = bic[-1] 52 | best_gmm = gmm 53 | 54 | return best_gmm.n_components, best_gmm 55 | 56 | 57 | def fit_gmm(x, n_components=3): 58 | """ Fit the GMM to the data 59 | :param x: non-zero values of image. Should be of shape (N,1). Make sure to use x=np.expand_dims(data[data>0],1) 60 | :param n_components: number of components in the mixture. Set to None to select the optimal component number based 61 | on the BIC criterion. Default: 3 62 | :return: GMM model, fit to the data 63 | """ 64 | if n_components is None: 65 | n_components, gmm = select_component_size(x, min_components=3, max_components=7) 66 | else: 67 | gmm = mixture.GaussianMixture(n_components=n_components, covariance_type='diag', tol=1e-3) 68 | 69 | return gmm.fit(x) 70 | 71 | 72 | def get_new_components(gmm, p_mu=None, q_sigma=None, 73 | std_means=None, std_sigma=None): 74 | """ Computes the new intensity components by shifting the mean and standard deviation of the individual components 75 | predicted by the GMM 76 | :param gmm: GMM model fit to the data 77 | :param p_mu: tuple containing the value by which to change the mean in each component in the mixture. Use only to 78 | check the behaviour of the method or to get specific changes.
79 | :param q_sigma: tuple containing the value by which to change the standard deviation in each component in the mixture. 80 | Use only to check the behaviour of the method or to get specific changes. 81 | :param std_means: tuple containing the range of variability to change the mean of each component in the 82 | mixture. If set to None, all components are augmented in a similar way 83 | :param std_sigma: tuple containing the range of variability to change the standard deviation of each component in 84 | the mixture. If set to None, all components are augmented in a similar way 85 | :return dictionary containing original and updated means and standard deviations for each component 86 | """ 87 | 88 | sort_indices = gmm.means_[:, 0].argsort(axis=0) 89 | mu = np.array(gmm.means_[:, 0][sort_indices]) 90 | std = np.array(np.sqrt(gmm.covariances_[:, 0])[sort_indices]) 91 | 92 | n_components = mu.shape[0] 93 | 94 | if std_means is not None: 95 | # use pre-computed intervals to draw values for each component in the mixture 96 | rng = np.random.default_rng() 97 | if p_mu is None: 98 | var_mean_diffs = np.array(std_means) 99 | p_mu = rng.uniform(-var_mean_diffs, var_mean_diffs) 100 | if std_sigma is not None: 101 | rng = np.random.default_rng() 102 | if q_sigma is None: 103 | var_std_diffs = np.array(std_sigma) 104 | q_sigma = rng.uniform(-var_std_diffs, var_std_diffs) 105 | else: 106 | # Draw random values for each component in the mixture 107 | # Multiply by random int for shifting left (-1), right (1) or not changing (0) the parameter. 
108 | if p_mu is None: 109 | p_mu = 0.06 * np.random.random(n_components) * np.random.randint(-1, 2, n_components) 110 | if q_sigma is None: 111 | q_sigma = 0.005 * np.random.random(n_components) * np.random.randint(-1, 2, n_components) 112 | 113 | new_mu = mu + p_mu 114 | new_std = std + q_sigma 115 | 116 | return {'mu': mu, 'std': std, 'new_mu': new_mu, 'new_std': new_std} 117 | 118 | 119 | def reconstruct_intensities(data, dict_parameters): 120 | """ Reconstruct the new image intensities from the new components. 121 | :param data: original image intensities (nonzero) 122 | :param dict_parameters: dictionary containing original and updated means and standard deviations for each component 123 | :return intensities_im: updated intensities for the new components 124 | """ 125 | mu, std = dict_parameters['mu'], dict_parameters['std'] 126 | new_mu, new_std = dict_parameters['new_mu'], dict_parameters['new_std'] 127 | n_components = len(mu) 128 | 129 | # If we know the mean (mu) and standard deviation (sigma) of a component, we can find the new value of a voxel v. 130 | # First we find the factor d that locates the voxel within the component distribution: v - mu = d*sigma 131 | d_im = np.zeros(((n_components,) + data.shape)) 132 | for k in range(n_components): 133 | d_im[k] = (data.ravel() - mu[k]) / (std[k] + 1e-7) 134 | 135 | # We force the new voxel intensity to lie within the same percentile in the new distribution as in the original 136 | # distribution: new_v = new_mu + d*new_sigma 137 | intensities_im = np.zeros(((n_components,) + data.shape)) 138 | for k in range(n_components): 139 | intensities_im[k] = new_mu[k] + d_im[k] * new_std[k] 140 | 141 | return intensities_im 142 | 143 | 144 | def get_new_image_composed(intensities_im, probas_original): 145 | """ 146 | Compose the new image (brain) 147 | :param intensities_im: image intensities for the new components 148 | :param probas_original: initial probabilities for each component (as predicted by the GMM model) 149 |
:return new_image_composed: new image after augmentation (skull stripped) 150 | """ 151 | n_components = probas_original.shape[1] 152 | new_image_composed = np.zeros(intensities_im[0].shape) 153 | for k in range(n_components): 154 | new_image_composed = new_image_composed + probas_original[:, k] * intensities_im[k] 155 | 156 | return new_image_composed 157 | 158 | 159 | def generate_gmm_image(image, mask=None, n_components=None, 160 | q_sigma=None, p_mu=None, 161 | std_means=None, std_sigma=None, 162 | normalize=True, percentiles=True): 163 | """ 164 | Function that takes an image and generates a new one by shifting the mean and standard deviation of its GMM intensity components 165 | :param image: image to transform 166 | :param mask: brain mask. 167 | :param n_components: number of components in the mixture. If set to None, the best model is selected based on AIC 168 | :param p_mu: tuple of length n_components containing the values to add to the mean of individual 169 | components. (default: None) 170 | :param q_sigma: tuple of length n_components containing the values to add to the std of individual 171 | components. (default: None) 172 | :param std_means: tuple containing the range of variability to change the mean of each component in the 173 | mixture. If set to None, all components are augmented in a similar way 174 | :param std_sigma: tuple containing the range of variability to change the standard deviation of each component in 175 | the mixture. If set to None, all components are augmented in a similar way 176 | :param normalize: if set to True, normalizes the image to [0,1] before augmentation (clipping is controlled by percentiles) 177 | :param percentiles: set to True to clip at percentiles 1 and 99 when normalizing.
178 | :return: new_image: image with updated intensities inside the brain mask, normalized to [0,1] 179 | """ 180 | 181 | if normalize: 182 | image = normalize_image(image, clip_percentiles=percentiles, pmin=1, pmax=99) # the percentiles can be changed 183 | 184 | if mask is None: 185 | masked_image = image 186 | else: 187 | masked_image = image * mask 188 | 189 | # We only want nonzero values 190 | data = masked_image[masked_image > 0] 191 | x = np.expand_dims(data, 1) 192 | 193 | gmm = fit_gmm(x, n_components) 194 | sort_indices = gmm.means_[:, 0].argsort(axis=0) 195 | 196 | # Estimate the posterior probabilities 197 | probas_original = gmm.predict_proba(x)[:, sort_indices] 198 | 199 | # Get the new intensity components 200 | params_dict = get_new_components(gmm, p_mu=p_mu, q_sigma=q_sigma, 201 | std_means=std_means, std_sigma=std_sigma) 202 | intensities_im = reconstruct_intensities(data, params_dict) 203 | 204 | # Then we add the per-component images, weighting each by the probability that a voxel belongs to that 205 | # component of the Gaussian mixture (probas_original) 206 | new_image_composed = get_new_image_composed(intensities_im, probas_original) 207 | 208 | # Reconstruct the image 209 | new_image = np.zeros(image.shape) 210 | new_image[np.where(masked_image > 0)] = new_image_composed 211 | 212 | # Put the skull back 213 | new_image[np.where(masked_image == 0)] = image[np.where(masked_image == 0)] 214 | 215 | # Return the image in [0,1] 216 | new_image = normalize_image(new_image, clip_percentiles=False) 217 | 218 | return new_image 219 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ Example run of the GMM augmentation """ 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import nibabel as nib 6 | 7 | import argparse 8 | import os 9 | 10 | from augmentation_utils import
generate_gmm_image, normalize_image 11 | 12 | parser = argparse.ArgumentParser(description='GMM-based augmentation') 13 | 14 | parser.add_argument('-i', 15 | '--img_dir', 16 | help='Path to the T1 image', 17 | type=str, 18 | default='data/t1.nii.gz') 19 | 20 | parser.add_argument('-m', 21 | '--mask_dir', 22 | help='Path to the brain mask', 23 | type=str, 24 | default='data/mask.nii.gz') 25 | 26 | parser.add_argument('-n', 27 | '--n_components', 28 | type=int, 29 | default=3 30 | ) 31 | 32 | parser.add_argument('-mu', 33 | '--std_means', 34 | nargs='+', 35 | type=float, 36 | default=(0.03, 0.06, 0.08) 37 | ) 38 | 39 | parser.add_argument('-s', 40 | '--std_sigma', 41 | nargs='+', 42 | type=float, 43 | default=(0.012, 0.011, 0.015) 44 | ) 45 | 46 | parser.add_argument('-p', 47 | '--percentiles', 48 | nargs='+', 49 | type=float, 50 | default=(1, 99) 51 | ) 52 | 53 | 54 | def create_dir(mypath): 55 | """Create a directory if it does not exist.""" 56 | try: 57 | os.makedirs(mypath) 58 | except OSError: 59 | if os.path.isdir(mypath): 60 | pass 61 | else: 62 | raise 63 | 64 | 65 | def demo_gmm_augmentation(args): 66 | 67 | # Load the image and mask 68 | print('Loading image and mask...') 69 | t1_path = args.img_dir 70 | mask_path = args.mask_dir 71 | t1 = nib.load(t1_path) 72 | t1_data = t1.get_fdata() 73 | mask = nib.load(mask_path).get_fdata() 74 | 75 | n_components = args.n_components 76 | 77 | # Clip at the requested percentiles and normalize the image 78 | t1_normalized = normalize_image(t1_data, clip_percentiles=True, pmin=args.percentiles[0], pmax=args.percentiles[1]) 79 | 80 | print('Augmenting image...') 81 | new_image = generate_gmm_image(t1_normalized, mask=mask, 82 | n_components=n_components, 83 | std_means=tuple(args.std_means), 84 | std_sigma=tuple(args.std_sigma), 85 | p_mu=None, q_sigma=None, 86 | normalize=False) 87 | 88 | # Save the augmented image 89 | print('Saving augmented image...') 90 | create_dir('results/') 91 | nib.save(nib.Nifti1Image(new_image, t1.affine, header=t1.header),
'results/augmented_img.nii.gz') 92 | 93 | # Plot and save the figure 94 | fig, ax = plt.subplots(1, 2, figsize=(15, 8)) 95 | 96 | ax[0].imshow(t1_normalized[:, :, int(np.shape(t1_normalized)[2] / 2)].T, origin='lower', cmap='gray', vmin=0, vmax=1) 97 | ax[0].set_title('Original image') 98 | 99 | ax[1].imshow(new_image[:, :, int(np.shape(new_image)[2] / 2)].T, origin='lower', cmap='gray', vmin=0, vmax=1) 100 | ax[1].set_title('Augmented image') 101 | plt.savefig('results/demo.png') 102 | 103 | 104 | if __name__ == "__main__": 105 | demo_gmm_augmentation(parser.parse_args()) 106 | -------------------------------------------------------------------------------- /images/augmented.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icometrix/gmm-augmentation/3dbddc3fe4a02a787e92bc2ab6dc2c80a0a07987/images/augmented.png -------------------------------------------------------------------------------- /images/diagram_gmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icometrix/gmm-augmentation/3dbddc3fe4a02a787e92bc2ab6dc2c80a0a07987/images/diagram_gmm.png --------------------------------------------------------------------------------
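The remapping done by `reconstruct_intensities()` and `get_new_image_composed()` in augmentation_utils.py can be illustrated on 1-D intensities with a short NumPy-only sketch. It is not part of the repository: `remap_intensities` is a hypothetical helper, and it assumes the component parameters (`mu`, `std`, `weights`) are already known. The repository instead takes them, and the posterior probabilities, from a fitted scikit-learn `GaussianMixture`.

```python
import numpy as np


def remap_intensities(values, mu, std, weights, new_mu, new_std):
    """Hypothetical helper (not in the repository): move each intensity to the
    same z-score in the shifted components, blended by component posteriors."""
    values = np.asarray(values, dtype=float)
    mu, std, weights = (np.asarray(a, dtype=float) for a in (mu, std, weights))
    new_mu = np.asarray(new_mu, dtype=float)
    new_std = np.asarray(new_std, dtype=float)

    # z-score of every value under every component, shape (K, N)
    d = (values[None, :] - mu[:, None]) / std[:, None]

    # Posterior probability of each component for each value, shape (K, N).
    # Computed in log space; the 1/sqrt(2*pi) constant cancels on normalization.
    log_p = -0.5 * d ** 2 - np.log(std[:, None]) + np.log(weights[:, None])
    log_p -= log_p.max(axis=0, keepdims=True)
    resp = np.exp(log_p)
    resp /= resp.sum(axis=0, keepdims=True)

    # Same z-score in the shifted component, then a posterior-weighted sum
    per_component = new_mu[:, None] + d * new_std[:, None]
    return (resp * per_component).sum(axis=0)


if __name__ == "__main__":
    values = np.array([0.2, 0.5, 0.8])
    mu, std, w = [0.25, 0.75], [0.1, 0.1], [0.5, 0.5]
    # Shift the first component's mean down and the second one's up
    print(remap_intensities(values, mu, std, w, [0.20, 0.80], std))
```

When the new parameters equal the original ones, the function returns its input unchanged, because the posteriors sum to one for every value. That makes a convenient sanity check before applying random shifts.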