├── no_augmentation.png
├── with_random_shift.png
├── with_random_shift_rotation.png
├── rand_shift_rotation_gamma_noise.png
├── README.md
├── LICENSE.md
├── .gitignore
├── example.py
└── transforms.py
/no_augmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/no_augmentation.png
--------------------------------------------------------------------------------
/with_random_shift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/with_random_shift.png
--------------------------------------------------------------------------------
/with_random_shift_rotation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/with_random_shift_rotation.png
--------------------------------------------------------------------------------
/rand_shift_rotation_gamma_noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gsmartensson/nifti_transforms/HEAD/rand_shift_rotation_gamma_noise.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NIfTI transforms of 3D MRI images for PyTorch
2 |
3 | Code with examples that can be used for data loading and data augmentation of 3D MRI images.
4 |
5 | Cropping, scaling and rotation are computed as individual transformation matrices that are multiplied together and then applied to the image data all at once in `ApplyAffine()`. This way the interpolation, which is the major bottleneck, is done only once.
6 |
7 | `example.py` contains examples of how to use the transformations in `transforms.py`.
8 |
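9 | ![No augmentation](no_augmentation.png) ![With random shift](with_random_shift.png) ![With random shift + rotation](with_random_shift_rotation.png) ![Random shift, rotation, gamma and noise](rand_shift_rotation_gamma_noise.png)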
10 |
11 | Figure: Examples of augmentations.
12 | ## License
13 |
14 | Code is licensed under the MIT License - see [LICENSE.md](LICENSE.md) for details.
15 |
16 | Please note that the code relies on third-party libraries, such as nibabel ([license](http://nipy.org/nibabel/legal.html)), NumPy, SciPy, PyTorch and matplotlib.
17 |
18 | ## Contact
19 | gustav.martensson@ki.se
20 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Gustav Mårtensson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | logs/
4 | plots/
5 | runs/
6 | *.py[cod]
7 | *$py.class
8 | #*.pth.tar
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Example script of how to use the transformations of .nii(.gz) images.
5 |
6 | The transformation matrices (data augmentations) are multiplied together before
7 | being applied to the image, so the time-consuming interpolation only has to be
8 | done once.
9 |
10 | Input is the path to the .nii or .nii.gz image.
11 |
12 | @author: gustav
13 | """
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | import os
17 | import transforms as tfs
np.random.seed(0)

# Filename to augment: the MNI brain template shipped with FSL. If FSL is not
# installed, point `fname` at any other .nii/.nii.gz image instead.
fsl_path = os.environ.get('FSLDIR')
if fsl_path is None:
    raise SystemExit('FSLDIR is not set: install FSL or edit `fname` in '
                     'example.py to point to a .nii/.nii.gz image.')
fname = os.path.join(fsl_path, 'data', 'standard', 'MNI152_T1_1mm.nii.gz')

#%% Four different augmentations
new_dim = [160, 160, 160]  # output dimension of transformed images
new_res = [1, 1, 1]  # output resolution (1mm x 1mm x 1mm)


def _base_transforms():
    '''Return fresh instances of the deterministic steps shared by all pipelines.'''
    return [
        tfs.LoadNifti(),  # image information stored as dict
        tfs.TranslateToCom(),  # translate to image's center of mass
        tfs.SetResolution(new_dim=new_dim, new_res=new_res),
        tfs.CropShift(np.array([0, -1, -30])),  # "manually" shift image to center
    ]


# no random augmentation in transform
title_dict = {0: 'No augmentation'}
t0 = tfs.ComposeMRI(_base_transforms() + [
    tfs.ApplyAffine(so=2),  # apply all transforms
    tfs.ReturnImageData(),  # dict -> numpy array
    tfs.ToTensor(),  # from numpy to torch.tensor
    tfs.SwapAxes(0, 1),  # swap primary axis
    tfs.UnitInterval(),  # normalize image to be in [-1,1]
])

# adding random shift
title_dict[1] = 'With random shift'
t1 = tfs.ComposeMRI(_base_transforms() + [
    tfs.RandomShift([30, 0, 0]),  # random shift
    tfs.ApplyAffine(so=2),
    tfs.ReturnImageData(),
    tfs.SwapAxes(0, 1),
    tfs.ToTensor(),
    tfs.UnitInterval(),
])

# adding rotation
title_dict[2] = 'With random shift + rotation'
t2 = tfs.ComposeMRI(_base_transforms() + [
    tfs.RandomShift([10, 0, 0]),
    tfs.RandomRotation(angle_interval=[-10, 10], rotation_axis=[0, 1, 0]),  # random rotation
    tfs.ApplyAffine(so=2),
    tfs.ReturnImageData(),
    tfs.SwapAxes(0, 1),
    tfs.ToTensor(),
    tfs.UnitInterval(),
])

# adding gamma transform + gaussian noise
title_dict[3] = 'Rand shift, rotation, gamma, noise'
t3 = tfs.ComposeMRI(_base_transforms() + [
    tfs.RandomShift([10, 0, 0]),
    # same random interval as t2; [10, 10] would always rotate by exactly 10 degrees
    tfs.RandomRotation(angle_interval=[-10, 10], rotation_axis=[0, 1, 0]),
    tfs.ApplyAffine(so=2),
    tfs.ReturnImageData(),
    tfs.SwapAxes(0, 1),
    tfs.ToTensor(),
    tfs.UnitInterval(),
    tfs.Gamma(gamma_range=[.8, 1.5], chance=1),
    tfs.RandomNoise(noise_var=.05),
])
91 |
92 | #%% Plot example images
def plot_img(img, title_str=''):
    '''
    Show the middle slice (along axis 0) of a 3D image and save it as a PNG.

    args:
        img: 3D array/tensor, e.g. the output of one of the pipelines above.
        title_str: plot title; also used to derive the file name. The title is
            lower-cased and every run of spaces, '+' and ',' becomes a single
            underscore, e.g. 'With random shift + rotation' ->
            'with_random_shift_rotation.png'.
    '''
    import re

    fig = plt.figure()
    # Middle slice of the image actually passed in (not of the global new_dim,
    # whose first entry need not match axis 0 after SwapAxes).
    mid_slice = img[img.shape[0] // 2, :, :]
    plt.imshow(np.rot90(mid_slice), cmap='gray', vmin=-1, vmax=.6)
    plt.axis('off')
    plt.title(title_str)
    # Save before show(): some backends close/clear the figure in show(),
    # which would otherwise write an empty PNG.
    fig.savefig(re.sub(r'[ +,]+', '_', title_str).lower() + '.png')
    plt.show()
101 |
# Run every pipeline once on the same input file and plot/save the result.
for idx, pipeline in enumerate((t0, t1, t2, t3)):
    augmented = pipeline(fname)
    plot_img(augmented, title_dict[idx])
105 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Created on Thu Jun 18 14:39:07 2020
5 |
6 | @author: gustav
7 | '''
8 | from __future__ import print_function, division
9 | import nibabel
10 | from scipy import ndimage
11 | import numpy as np
12 | import torch
13 | import matplotlib.pyplot as plt
14 | import math
15 |
16 |
class LoadNifti(object):
    '''
    Load a NIfTI (.nii/.nii.gz) file into the dict format used by the other
    transforms in this module.

    The image is first reoriented to the closest canonical (RAS) orientation.
    The returned dict holds:
        data:    voxel array converted to float32
        pixdim:  the three voxel sizes, header 'pixdim' entries 1..3
        affine:  empty list; later transforms append the names of their matrices
        new_dim: shape of ``data``
    '''
    def __call__(self, file):
        canonical = nibabel.as_closest_canonical(nibabel.load(file))
        voxel_sizes = canonical.header.get('pixdim')[1:4]
        # Read the voxel values, then cast to float32.
        volume = np.float32(np.array(canonical.dataobj))
        return {
            'data': volume,
            'pixdim': voxel_sizes,
            'affine': [],
            'new_dim': volume.shape,
        }
26 |
class CropShift(object):
    '''
    Shift the center voxel of the crop by a fixed offset.

    Stores a homogeneous translation matrix under image['crop_shift'] and
    registers it in image['affine'] so that ApplyAffine() includes it.
    args:
        shift: per-axis translation (one entry per axis of image['data']).
    '''
    def __init__(self, shift):
        self.shift = np.array(shift)

    def __call__(self, image):
        n_axes = len(image['data'].shape)
        translation = np.eye(n_axes + 1)
        translation[:n_axes, n_axes] = self.shift
        image['crop_shift'] = translation
        image['affine'].append('crop_shift')
        return image
41 |
42 |
43 |
class RandomShift(object):
    '''
    Randomly shift the center voxel of the crop by up to +/- max_shift.

    Adds the translation matrix to the list of affine transformations
    (stored as image['random_shift'], registered in image['affine']).
    args:
        max_shift: per-axis maximum absolute shift; each axis draws its shift
            uniformly from [-max_shift, max_shift).
    '''
    def __init__(self, max_shift=(0, 0, 0)):
        self.max_shift = np.array(max_shift)

    def __call__(self, image):
        ndims = len(image['data'].shape)
        # rand() is in [0, 1), so 2*(rand-.5) is in [-1, 1).
        shift = 2 * (np.random.rand(ndims) - .5) * self.max_shift
        T = np.identity(ndims + 1)
        T[0:ndims, -1] = shift
        image['random_shift'] = T
        image['affine'].append('random_shift')
        return image
63 |
class RandomScaling(object):
    '''
    Add an isotropic random scaling of the image with a factor s drawn
    uniformly from [scale_range[0], scale_range[1]).

    The diagonal scaling matrix is stored as image['random_scale'] and
    registered in image['affine'].
    '''
    def __init__(self, scale_range=(.95, 1.05)):
        self.scale_range = scale_range

    def __call__(self, image):
        n_axes = np.asarray(image['pixdim']).size
        low, high = self.scale_range
        factor = np.random.rand() * (high - low) + low
        # Same factor on every spatial axis; homogeneous coordinate stays 1.
        S = np.diag(np.append(np.full(n_axes, factor), 1.0))
        image['random_scale'] = S
        image['affine'].append('random_scale')
        return image
85 |
86 |
class SetResolution(object):
    '''
    Specify resolution and size of the output array.
    args:
        new_dim: output dimensions, e.g. [128,128,64]; stored as image['new_dim'].
        new_res: output voxel size, e.g. [1,1,2] (1mm x 1mm x 2mm). The image
            is scaled by pixdim/new_res, using the original resolution from
            the nifti header. If None, the resolution is kept (scale factor 1)
            and only the output matrix size changes.
    '''
    def __init__(self, new_dim, new_res=None):
        self.new_res = new_res
        self.new_dim = np.array(new_dim)

    def __call__(self, image):
        old_res = np.array(image['pixdim'])
        # `is None` (not `== None`) so numpy arrays are accepted for new_res.
        if self.new_res is None:
            # keep resolution, only change matrix size
            scale_factor = np.ones(old_res.size)
        else:
            scale_factor = old_res / np.array(self.new_res)
        image['new_dim'] = self.new_dim
        S = np.diag(np.append(scale_factor, 1.0))
        image['scale'] = S
        image['affine'].append('scale')
        return image
117 |
class TranslateToCom(object):
    '''
    Translate image to center of mass ("Com"). Can be useful as a
    quick-and-dirty method to e.g. "center" the brain in an MRI image.

    Stores the translation (array center minus COM) as image['com'] and
    registers it in image['affine']. If the COM is undefined (e.g. an
    all-zero image), no translation is applied.
    args:
        scale_f: subsampling step used when calculating the center of mass.
            scale_f=1 yields precise COM coordinates but is more
            computationally expensive than e.g. scale_f=4.
    '''
    def __init__(self, scale_f=4):
        self.f = scale_f

    def __call__(self, image):
        data = image['data']
        # Take every f-th voxel along every axis (works for any number of dims).
        step = (slice(None, None, self.f),) * data.ndim
        with np.errstate(divide='ignore', invalid='ignore'):
            com = self.f * np.array(ndimage.center_of_mass(data[step]))
        mid = np.array(data.shape) / 2
        if not np.all(np.isfinite(com)):
            # zero total intensity -> COM is NaN; fall back to no translation
            com = mid
        T = np.identity(len(mid) + 1)
        T[0:len(mid), -1] = mid - com
        image['com'] = T
        image['affine'].append('com')
        return image
145 |
class RandomRotation(object):
    '''
    Adds a random rotation about a random or specified axis.

    The 4x4 homogeneous rotation matrix is stored as image['rotation'] and
    registered in image['affine'].
    args:
        angle_interval: lower and upper bound (in degrees) of the interval
            from which the rotation angle is drawn uniformly.
        rotation_axis: axis about which the rotation occurs, in array-axis
            coordinates. If None, a random axis is drawn on every call.
            E.g. [1,0,0] rotates about array axis 0, i.e. within the plane
            spanned by array axes 1 and 2 (which anatomical plane that is
            depends on the orientation of the data).

    Code adapted from https://stackoverflow.com/questions/47623582/efficiently-calculate-list-of-3d-rotation-matrices-in-numpy-or-scipy
    '''
    def __init__(self, angle_interval=(-5, 5), rotation_axis=None):
        self.a_l, self.a_u = angle_interval
        self.rotation_axis = rotation_axis

    def unit_vector(self, data, axis=None, out=None):
        '''
        Return ndarray normalized by length, i.e. Euclidean norm, along axis.
        If ``out`` is given, the result is written into it and None is returned.
        '''
        if out is None:
            data = np.array(data, dtype=np.float64, copy=True)
            if data.ndim == 1:
                data /= math.sqrt(np.dot(data, data))
                return data
        else:
            if out is not data:
                # np.asarray instead of np.array(..., copy=False), which
                # raises in NumPy >= 2 whenever a copy is required.
                out[:] = np.asarray(data)
            data = out
        length = np.atleast_1d(np.sum(data * data, axis))
        np.sqrt(length, length)
        if axis is not None:
            length = np.expand_dims(length, axis)
        data /= length
        if out is None:
            return data

    def rotation_matrix(self, angle, direction, point=None):
        '''
        Return the 4x4 matrix rotating by ``angle`` (radians) about the axis
        defined by ``direction`` and, optionally, a ``point`` on that axis.
        '''
        sina = math.sin(angle)
        cosa = math.cos(angle)
        direction = self.unit_vector(direction[:3])
        # rotation matrix around unit vector (Rodrigues' formula)
        R = np.diag([cosa, cosa, cosa])
        R += np.outer(direction, direction) * (1.0 - cosa)
        direction *= sina
        R += np.array([[0.0, -direction[2], direction[1]],
                       [direction[2], 0.0, -direction[0]],
                       [-direction[1], direction[0], 0.0]])
        M = np.identity(4)
        M[:3, :3] = R
        if point is not None:
            # rotation not around origin
            point = np.asarray(point[:3], dtype=np.float64)
            M[:3, 3] = point - np.dot(R, point)
        return M

    def __call__(self, image):
        theta = np.random.uniform(self.a_l, self.a_u)
        angle = theta / 180 * np.pi
        if self.rotation_axis is None:
            u = np.random.rand(3) - .5
        else:
            u = np.asarray(self.rotation_axis, dtype=np.float64)
        # rotation_matrix() normalizes the axis to unit length itself.
        R = self.rotation_matrix(-angle, u)
        image['rotation'] = R
        image['affine'].append('rotation')
        return image
217 |
class ApplyAffine(object):
    '''
    Multiply all previously added affine transformation matrices and apply
    them to image['data'] with a single interpolation. Returns the image dict.
    args:
        new_dim: output array shape. If None, image['new_dim'] (e.g. set by
            SetResolution()) is used. A value that conflicts with a new_dim
            already set by an earlier transform raises an Exception.
        so: order of the spline interpolation passed to
            scipy.ndimage.affine_transform. Trade-off between speed and accuracy.
        chance: probability in [0, 1] of applying the transformation; otherwise
            the image dict is returned unchanged.
    '''
    def __init__(self, new_dim=None, so=3, chance=1):
        self.chance = chance  # if random number is below self.chance then apply transformation
        # `is not None` so array-valued new_dim works; array([None]) marks "not given".
        if new_dim is not None:
            self.new_dim = np.array(new_dim)
        else:
            self.new_dim = np.array([new_dim])
        self.so = so

    def __call__(self, image):
        # The empty-list check comes first, so no random number is drawn
        # when there is nothing to apply.
        if not image['affine'] or np.random.rand() > self.chance:
            return image

        if self.new_dim[0] is not None:
            # array_equal also copes with shapes of different length.
            if (not np.array_equal(image['new_dim'], image['data'].shape)
                    and not np.array_equal(image['new_dim'], self.new_dim)):
                raise Exception('Error - two different new_dim were given. '
                                'Probably also in SetResolution()?')
            image['new_dim'] = self.new_dim
            new_dim = self.new_dim
        else:
            new_dim = image['new_dim']

        ndims = len(image['pixdim'])
        # Compose in registration order: the first registered matrix acts first.
        T = np.identity(ndims + 1)
        for name in image['affine']:
            T = np.dot(image[name], T)
        T_inv = np.linalg.inv(T)

        # affine_transform maps output -> input coordinates. Offset by s so the
        # output center maps to the input center (plus T_inv's translation).
        c_in = np.array(image['data'].shape) * .5
        c_out = np.array(new_dim) * .5
        s = c_in - np.dot(T_inv[:ndims, :ndims], c_out)
        recenter = np.identity(ndims + 1)
        recenter[0:ndims, -1] = s
        T_inv = np.dot(recenter, T_inv)

        image['data'] = ndimage.affine_transform(
            image['data'], T_inv,
            output_shape=tuple(int(d) for d in new_dim), order=self.so)
        return image
269 |
class ReturnImageData(object):
    '''
    Unwrap the image dict and return only its voxel array, image['data'].
    Intended to follow ApplyAffine() in a ComposeMRI pipeline.
    '''
    def __call__(self, image):
        voxels = image['data']
        return voxels
276 |
277 | class Gamma(object):
278 | '''
279 | Apply gamma correction V_out = V_in^(gamma) at probability chance with
280 | a random gamma value \in gamma_range.
281 |
282 | This tranformation should be applied after ReturnImageData() has been called.
283 | '''
284 | def __init__(self, gamma_range = [.8,1.2],chance=1):
285 | self.chance=chance # a value between [0,1]
286 | self.gamma_range=gamma_range
287 | def __call__(self, image):
288 | if np.random.rand()=self.ul]=self.ul
363 | return image
364 |
365 | class RandomNoise(object):
366 | '''Add random normally distributed noise to image tensor.
367 |
368 | Args:
369 | noise_var (float): maximum variance of added noise
370 | p (float): probability of adding noise
371 | '''
372 |
373 | def __init__(self, noise_var=.1, p=1):
374 |
375 | self.noise_var = noise_var
376 | self.p = p
377 |
378 | def __call__(self, image):
379 | if torch.rand(1)[0]0: # 50/50 if to rotate
394 | image = np.flip(image,self.axis).copy()
395 | return image
396 |
class PerImageNormalization(object):
    '''
    Transform all pixel values to have mean 0 and std 1.

    Operates in place on the given array/tensor and also returns it. A
    constant image (std == 0) is only mean-centered instead of producing NaN
    from a division by zero.
    '''
    def __call__(self, image):
        image -= image.mean()
        std = image.std()
        if std > 0:
            image /= std
        return image
406 |
class Window(object):
    '''
    Cap image values to lie in [low, high] (in place; the image is also
    returned). Boolean masking keeps this usable for numpy arrays and
    torch tensors alike.
    '''
    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, image):
        image[image < self.low] = self.low
        image[image > self.high] = self.high
        return image
421 |
class PrcCap(object):
    '''
    Cap all pixel values to lie between two percentiles (in place; the image
    is also returned).
    args:
        low: percentile (0-100) whose value low pixel values are raised to.
        high: percentile (0-100) whose value high pixel values are capped to.
    '''
    def __init__(self, low=5, high=99):
        self.low = low
        self.high = high

    def __call__(self, image):
        l = np.percentile(image, self.low)
        h = np.percentile(image, self.high)
        image[image < l] = l
        image[image > h] = h
        return image
440 |
class UnitInterval(object):
    '''
    Linearly rescale all pixel values to lie in [-1, 1] (min -> -1, max -> 1).

    The min-shift is done in place on the input. A constant image cannot be
    stretched and maps to all -1 instead of producing NaN from 0/0.
    '''

    def __call__(self, image):
        image -= image.min()
        peak = image.max()
        if peak > 0:
            image /= peak
        image = (image - .5) * 2
        return image
453 |
class ComposeMRI(object):
    """Chain several transforms: each one receives the previous one's output.

    For example:
        >>> pipeline = ComposeMRI([
        >>>     LoadNifti(),
        >>>     ApplyAffine(),
        >>>     ReturnImageData(),
        >>> ])
        >>> img = pipeline('image.nii.gz')
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, input):
        result = input
        for step in self.transforms:
            result = step(result)
        return result
--------------------------------------------------------------------------------