├── no_augmentation.png ├── with_random_shift.png ├── with_random_shift_rotation.png ├── rand_shift_rotation_gamma_noise.png ├── README.md ├── LICENSE.md ├── .gitignore ├── example.py └── transforms.py /no_augmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/no_augmentation.png -------------------------------------------------------------------------------- /with_random_shift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/with_random_shift.png -------------------------------------------------------------------------------- /with_random_shift_rotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/with_random_shift_rotation.png -------------------------------------------------------------------------------- /rand_shift_rotation_gamma_noise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/rand_shift_rotation_gamma_noise.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nifti transforms of 3D MRI images for pytorch 2 | 3 | Code with examples that can be used for data loading and data augmentation of 3D MRI images. 4 | 5 | Cropping, scaling and rotation are computed as individual transformation matrices that are mutliplied before being applied (all at once) to the image data in ApplyAffine(). This way the interpolation, which is the major bottleneck, is done only once. 6 | 7 | `example.py` contains examples of how to use the transformations in `transforms.py`. 8 | 9 | vrsvrsvrsvrs 10 | 11 | Figure: Examples of augmentations. 12 | ## License 13 | 14 | Code is licensed under the MIT License - see [LICENSE.md](LICENSE.md) for details. 15 | 16 | Please note that code relies on third-party libraries, s.a. nibabel [license](http://nipy.org/nibabel/legal.html) and scipy. 17 | 18 | ## Contact 19 | gustav.martensson@ki.se 20 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Gustav Mårtensson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | logs/ 4 | plots/ 5 | runs/ 6 | *.py[cod] 7 | *$py.class 8 | #*.pth.tar 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Example script of how to use the transformations of .nii(.gz) images. 5 | 6 | The transformations matrices (data augmentations) are multiplied before 7 | being applied to the image, resulting in only having to do the time consuming 8 | part once. 9 | 10 | Input is the path to the .nii or .nii.gz image. 11 | 12 | @author: gustav 13 | """ 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | import os 17 | import transforms as tfs 18 | np.random.seed(0) 19 | 20 | # Filename to augment. Uses MNI brain template in FSL dir if installed. 21 | fsl_path=os.environ['FSLDIR'] 22 | fname = os.path.join(fsl_path,'data','standard','MNI152_T1_1mm.nii.gz' ) 23 | 24 | #%% Four different augmentations 25 | new_dim = [160,160,160] # output dimension of transformed images 26 | new_res = [1,1,1]# output resolution (1mm x 1mm x 1mm) 27 | 28 | 29 | # no random augmentation in transform 30 | title_dict = {0:'No augmentation'} 31 | t0=tfs.ComposeMRI([ 32 | tfs.LoadNifti(), # image information stored as dict 33 | tfs.TranslateToCom(), # translate to image's center of mass 34 | tfs.SetResolution(new_dim= new_dim, new_res=new_res), 35 | tfs.CropShift(np.array([0,-1,-30])), # to "manually" shift image to center 36 | tfs.ApplyAffine(so=2), # apply all transforms 37 | tfs.ReturnImageData(), # dict -> numpy array 38 | tfs.ToTensor(), # from numpy to torch.tensor 39 | tfs.SwapAxes(0,1), # swap primary axis 40 | tfs.UnitInterval(), # normalize image to be in [-1,1] 41 | ]) 42 | 43 | # adding random shift 44 | title_dict[1]='With random shift' 45 | t1=tfs.ComposeMRI([ 46 | tfs.LoadNifti(), 47 | tfs.TranslateToCom(), 48 | tfs.SetResolution(new_dim= new_dim, new_res=new_res), 49 | tfs.CropShift(np.array([0,-1,-30])), 50 | tfs.RandomShift([30,0,0]), # random shift 51 | tfs.ApplyAffine(so=2), 52 | tfs.ReturnImageData(), 53 | tfs.SwapAxes(0,1), 54 | tfs.ToTensor(), 55 | tfs.UnitInterval(), 56 | ]) 57 | 58 | # adding rotation 59 | title_dict[2]='With random shift + rotation' 60 | t2=tfs.ComposeMRI([ 61 | tfs.LoadNifti(), 62 | tfs.TranslateToCom(), 63 | tfs.SetResolution(new_dim= new_dim, new_res=new_res), 64 | tfs.CropShift(np.array([0,-1,-30])), 65 | tfs.RandomShift([10,0,0]), 66 | tfs.RandomRotation(angle_interval=[-10,10],rotation_axis=[0,1,0]),# random rotation 67 | tfs.ApplyAffine(so=2), 68 | tfs.ReturnImageData(), 69 | tfs.SwapAxes(0,1), 70 | tfs.ToTensor(), 71 | tfs.UnitInterval(), 72 | ]) 73 | 74 | # adding gamma transform + gaussian noise 75 | title_dict[3]='Rand shift, rotation, gamma, noise' 76 | t3=tfs.ComposeMRI([ 77 | tfs.LoadNifti(), 78 | tfs.TranslateToCom(), 79 | tfs.SetResolution(new_dim= new_dim, new_res=new_res), 80 | tfs.CropShift(np.array([0,-1,-30])), 81 | tfs.RandomShift([10,0,0]), 82 | tfs.RandomRotation(angle_interval=[10,10],rotation_axis=[0,1,0]), 83 | tfs.ApplyAffine(so=2), 84 | tfs.ReturnImageData(), 85 | tfs.SwapAxes(0,1), 86 | tfs.ToTensor(), 87 | tfs.UnitInterval(), 88 | tfs.Gamma(gamma_range=[.8,1.5],chance=1), 89 | tfs.RandomNoise(noise_var=.05), 90 | ]) 91 | 92 | #%% Plot example images 93 | def plot_img(img,title_str=''): 94 | fig=plt.figure() 95 | plt.imshow(np.rot90(img[new_dim[0]//2,:,:]),cmap='gray',vmin=-1,vmax=.6); 96 | 97 | plt.axis('off')#;plt.colorbar() 98 | plt.title(title_str) 99 | plt.show() 100 | fig.savefig(title_str.replace(' ','_').replace('+','_').replace(',','_').replace('__','_').replace('__','_').lower()+'.png') 101 | 102 | for i,transforms in enumerate([t0,t1,t2,t3]): 103 | img = transforms(fname) 104 | plot_img(img,title_dict[i]) 105 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Created on Thu Jun 18 14:39:07 2020 5 | 6 | @author: gustav 7 | ''' 8 | from __future__ import print_function, division 9 | import nibabel 10 | from scipy import ndimage 11 | import numpy as np 12 | import torch 13 | import matplotlib.pyplot as plt 14 | import math 15 | 16 | 17 | class LoadNifti(object): 18 | def __call__(self,file): 19 | # Load mri feils with same orientation and numpy format 20 | a = nibabel.load(file) # load file 21 | a= nibabel.as_closest_canonical(a) # transform into RAS orientation 22 | pixdim = a.header.get('pixdim')[1:4] 23 | a = np.array(a.dataobj) # fetch data as float 32bit 24 | a=np.float32(a) 25 | return {'data':a,'pixdim':pixdim,'affine':[],'new_dim':a.shape} 26 | 27 | class CropShift(object): 28 | ''' 29 | Shift center voxel of croping and adds to affine transformation matrix 30 | ''' 31 | def __init__(self,shift): 32 | self.shift =np.array(shift) 33 | def __call__(self,image): 34 | ndims = len(image['data'].shape) 35 | T = np.identity(ndims+1) 36 | T[0:ndims,-1]=self.shift 37 | 38 | image['crop_shift']=T 39 | image['affine'].append('crop_shift') 40 | return image 41 | 42 | 43 | 44 | class RandomShift(object): 45 | ''' 46 | Randomly shift center voxel of croping between \pm max_shift. 47 | Adds transformation matrix to list of affine transformations. 48 | ''' 49 | def __init__(self, max_shift=[0,0,0]): 50 | self.max_shift =np.array(max_shift) 51 | def __call__(self, image): 52 | ndims = len(image['data'].shape) 53 | 54 | shift = 2*(np.random.rand(ndims)-.5)*self.max_shift 55 | 56 | T = np.identity(ndims+1) 57 | 58 | T[0:ndims,-1]=shift 59 | 60 | image['random_shift']=T 61 | image['affine'].append('random_shift') 62 | return image 63 | 64 | class RandomScaling(object): 65 | ''' 66 | Adds an isotropic random scaling of the 3D image with a 67 | factor s \in scale_range 68 | ''' 69 | def __init__(self,scale_range=[.95,1.05]): 70 | self.scale_range=scale_range 71 | 72 | def __call__(self,image): 73 | old_res = np.array(image['pixdim']) 74 | 75 | scale_factor = np.random.rand()*np.diff(self.scale_range)+self.scale_range[0] 76 | scale_factor = np.ones(len(old_res))*scale_factor 77 | 78 | S = np.ones(old_res.size+1) 79 | S[0:len(scale_factor)] = scale_factor 80 | S = np.diag(S) 81 | 82 | image['random_scale']=S 83 | image['affine'].append('random_scale') 84 | return image 85 | 86 | 87 | class SetResolution(object): 88 | ''' 89 | Specify resolution and size of output array. 90 | args: 91 | new_dim: output dimensions, e.g. [128,128,64] 92 | new_res: output resolution, e.g. [1,1,2] (1mmx1mmx2mm), and scales 93 | image appropriately based on information of original resolution 94 | in nifti header. 95 | ''' 96 | def __init__(self,new_dim,new_res=None): 97 | self.new_res=new_res 98 | self.new_dim = np.array(new_dim) 99 | def __call__(self,image): 100 | old_res = np.array(image['pixdim']) 101 | if self.new_res==None: # don't change resolution, only matrix size 102 | new_res_tmp=image['new_dim'] 103 | else: 104 | new_res_tmp=self.new_res 105 | image['new_dim'] = self.new_dim 106 | new_res_tmp=np.array(new_res_tmp) 107 | #old_size = np.array(image['data'].shape) 108 | scale_factor = (old_res/new_res_tmp) 109 | #scale_factor *= (self.new_dim/old_size) 110 | S = np.ones(old_res.size+1) 111 | S[0:len(scale_factor)] = scale_factor 112 | S = np.diag(S) 113 | #print(S) 114 | image['scale']=S 115 | image['affine'].append('scale') 116 | return image 117 | 118 | class TranslateToCom(object): 119 | ''' 120 | Translate image to center of mass ("Com"). Can be useful as a 121 | quick-and-dirty method to e.g. "center" the brain in an MRI image. 122 | args: 123 | scale_f: downscaling factor used when calculating the center of mass. 124 | scale_f=1 yields precise COM coordinates but more computationally 125 | expensive than e.g. scale_f=4. 126 | ''' 127 | def __init__(self,scale_f=4): 128 | self.f = scale_f 129 | def __call__(self,image): 130 | img_tmp=image['data'][::self.f,::self.f,::self.f] 131 | # prc5 = np.percentile(img_tmp.ravel()[::4],5) 132 | # img_tmp=img_tmp>prc5 133 | # #plt.hist(img_tmp.ravel()) 134 | # #plt.show() 135 | com = self.f*np.array(ndimage.center_of_mass(img_tmp)) 136 | # com = self.f*np.array(ndimage.center_of_mass(image['data'][::self.f,::self.f,::self.f])) 137 | 138 | mid = np.array(image['data'].shape)/2 139 | #print(com) 140 | T = np.identity(len(mid)+1) 141 | T[0:len(mid),-1] = mid-com 142 | image['com'] = T 143 | image['affine'].append('com') 144 | return image 145 | 146 | class RandomRotation(object): 147 | ''' 148 | Adds a random rotation along a random or specifed axis. 149 | args: 150 | angle_interval: upper and lower bound of interval of angle (in degrees) 151 | of random rotation. 152 | rotation_axis: axis of which rotation occurs. If None, 153 | then axis is random. If e.g. [1,0,0], the rotation is done in 154 | coronal plane for brain images. 155 | 156 | Code adapted from https://stackoverflow.com/questions/47623582/efficiently-calculate-list-of-3d-rotation-matrices-in-numpy-or-scipy 157 | ''' 158 | def __init__(self,angle_interval=[-5,5],rotation_axis=None): 159 | self.a_l,self.a_u =angle_interval 160 | self.rotation_axis =rotation_axis 161 | def unit_vector(self,data, axis=None, out=None): 162 | ''' 163 | Return ndarray normalized by length, i.e. Euclidean norm, along axis. 164 | ''' 165 | if out is None: 166 | data = np.array(data, dtype=np.float64, copy=True) 167 | if data.ndim == 1: 168 | data /= math.sqrt(np.dot(data, data)) 169 | return data 170 | else: 171 | if out is not data: 172 | out[:] = np.array(data, copy=False) 173 | data = out 174 | length = np.atleast_1d(np.sum(data*data, axis)) 175 | np.sqrt(length, length) 176 | if axis is not None: 177 | length = np.expand_dims(length, axis) 178 | data /= length 179 | if out is None: 180 | return data 181 | def rotation_matrix(self,angle, direction, point=None): 182 | ''' 183 | Return matrix to rotate about axis defined by point and direction. 184 | 185 | ''' 186 | sina = math.sin(angle) 187 | cosa = math.cos(angle) 188 | direction = self.unit_vector(direction[:3]) 189 | # rotation matrix around unit vector 190 | R = np.diag([cosa, cosa, cosa]) 191 | R += np.outer(direction, direction) * (1.0 - cosa) 192 | direction *= sina 193 | R += np.array([[ 0.0, -direction[2], direction[1]], 194 | [ direction[2], 0.0, -direction[0]], 195 | [-direction[1], direction[0], 0.0]]) 196 | M = np.identity(4) 197 | M[:3, :3] = R 198 | if point is not None: 199 | # rotation not around origin 200 | point = np.array(point[:3], dtype=np.float64, copy=False) 201 | M[:3, 3] = point - np.dot(R, point) 202 | return M 203 | def __call__(self,image): 204 | theta = np.random.uniform(self.a_l,self.a_u) 205 | 206 | angle =theta/180*np.pi 207 | if self.rotation_axis is None: 208 | u=np.random.rand(3)-.5 209 | else: 210 | u=self.rotation_axis 211 | u = u/(np.dot(u,u)) 212 | R = self.rotation_matrix(-angle,u) 213 | 214 | image['rotation']=R 215 | image['affine'].append('rotation') 216 | return image 217 | 218 | class ApplyAffine(object): 219 | ''' 220 | Multiply all previously added affine transformations matrices and 221 | apply them. Returns transformed image. 222 | args: 223 | new_dim: 224 | so: order of interpolation. Trade-off between speed and accuracy. 225 | ''' 226 | def __init__(self,new_dim=None, so = 3,chance=1): 227 | self.chance=chance # if random number is below self.chance then apply transformation 228 | if not new_dim==None: 229 | self.new_dim = np.array(new_dim) 230 | else: 231 | self.new_dim = np.array([new_dim]) 232 | self.so=so 233 | def __call__(self,image): 234 | if image['affine']==[] or np.random.rand()>self.chance: 235 | #print('no transform') 236 | return image 237 | else: 238 | # forward mapping 239 | if self.new_dim[0] is not None: 240 | if np.any(image['new_dim']!=image['data'].shape) and np.any(image['new_dim']!=self.new_dim): 241 | raise Exception('Error - two different new_dim were given. Probably also in SetResolution()?') 242 | # new_dim = image['new_dim'] 243 | else: 244 | # print('setting new dim (or same as before)') 245 | image['new_dim'] = self.new_dim 246 | new_dim = self.new_dim 247 | else: 248 | new_dim = image['new_dim'] 249 | # print('using previous new_dim') 250 | 251 | ndims = len(image['pixdim']) 252 | T=np.identity(ndims+1) 253 | for a in image['affine']: 254 | T = np.dot(image[a],T) 255 | T_inv = np.linalg.inv(T) 256 | 257 | # compute offset for centering translation 258 | c_in = np.array(image['data'].shape)*.5 259 | c_out=np.array(new_dim)*.5 260 | s=c_in-np.dot(T_inv[:ndims,:ndims],c_out) 261 | #tx,ty = -s 262 | translation =np.identity(ndims +1) 263 | translation[0:ndims,-1]=-s 264 | T_inv = np.dot(np.linalg.inv(translation),T_inv) 265 | 266 | image['data'] = ndimage.affine_transform( 267 | image['data'],T_inv,output_shape=new_dim,order=self.so) 268 | return image 269 | 270 | class ReturnImageData(object): 271 | ''' 272 | Return image from dict after ApplyAffine() has been called 273 | ''' 274 | def __call__(self,image): 275 | return image['data'] 276 | 277 | class Gamma(object): 278 | ''' 279 | Apply gamma correction V_out = V_in^(gamma) at probability chance with 280 | a random gamma value \in gamma_range. 281 | 282 | This tranformation should be applied after ReturnImageData() has been called. 283 | ''' 284 | def __init__(self, gamma_range = [.8,1.2],chance=1): 285 | self.chance=chance # a value between [0,1] 286 | self.gamma_range=gamma_range 287 | def __call__(self, image): 288 | if np.random.rand()=self.ul]=self.ul 363 | return image 364 | 365 | class RandomNoise(object): 366 | '''Add random normally distributed noise to image tensor. 367 | 368 | Args: 369 | noise_var (float): maximum variance of added noise 370 | p (float): probability of adding noise 371 | ''' 372 | 373 | def __init__(self, noise_var=.1, p=1): 374 | 375 | self.noise_var = noise_var 376 | self.p = p 377 | 378 | def __call__(self, image): 379 | if torch.rand(1)[0]0: # 50/50 if to rotate 394 | image = np.flip(image,self.axis).copy() 395 | return image 396 | 397 | class PerImageNormalization(object): 398 | ''' 399 | Transforms all pixel values to to have mean= 0 and std = 1 400 | ''' 401 | def __call__(self, image): 402 | image -=image.mean() 403 | image /=image.std() 404 | 405 | return image 406 | 407 | class Window(object): 408 | ''' 409 | Cap image to be between [low,high] 410 | ''' 411 | def __init__(self, low,high): 412 | self.low=low 413 | self.high=high 414 | def __call__(self, image): 415 | # transform data to 416 | 417 | image[imageself.high] =self.high 419 | 420 | return image 421 | 422 | class PrcCap(object): 423 | ''' 424 | Cap all pixel values between two percentiles 425 | args: 426 | low: lower percentile value to cap high pixel values to. 427 | high: upper percentile value to cap high pixel values to. 428 | ''' 429 | def __init__(self, low=5,high=99): 430 | self.low = low 431 | self.high= high 432 | def __call__(self, image): 433 | # transform data to 434 | 435 | l= np.percentile(image,self.low) 436 | h= np.percentile(image,self.high) 437 | image[imageh] = h 438 | 439 | return image 440 | 441 | class UnitInterval(object): 442 | ''' 443 | Transforms all pixel values to be in [-1,1] 444 | ''' 445 | 446 | def __call__(self, image): 447 | # transform data to 448 | 449 | image -=image.min() 450 | image /=image.max() 451 | image = (image-.5)*2 452 | return image 453 | 454 | class ComposeMRI(object): 455 | """ Composes several co_transforms together. 456 | For example: 457 | >>> co_transforms.Compose([ 458 | >>> co_transforms.CenterCrop(10), 459 | >>> co_transforms.ToTensor(), 460 | >>> ]) 461 | #TODO 462 | """ 463 | 464 | def __init__(self, transforms): 465 | self.transforms = transforms 466 | 467 | def __call__(self, input): 468 | for t in self.transforms: 469 | #print(t) 470 | input= t(input) 471 | return input --------------------------------------------------------------------------------