├── LICENSE
├── PreprocessingMRI.md
├── README.md
├── ViT-V-Net
│   ├── __pycache__
│   │   ├── ViT_V_Net.cpython-38.pyc
│   │   ├── configs.cpython-38.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── models.cpython-38.pyc
│   │   ├── utils.cpython-38.pyc
│   │   └── vit_reg_configs.cpython-38.pyc
│   ├── configs.py
│   ├── data
│   │   ├── __pycache__
│   │   │   ├── data_utils.cpython-38.pyc
│   │   │   ├── datasets.cpython-38.pyc
│   │   │   ├── rand.cpython-38.pyc
│   │   │   └── trans.cpython-38.pyc
│   │   ├── data_utils.py
│   │   ├── datasets.py
│   │   ├── rand.py
│   │   └── trans.py
│   ├── infer.py
│   ├── label_info.txt
│   ├── losses.py
│   ├── models.py
│   ├── train.py
│   └── utils.py
└── figures
    ├── ViTVNet_res.jpg
    ├── dice_details_.jpg
    ├── net_arch.jpg
    └── trans_arch.jpg
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Junyu Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PreprocessingMRI.md:
--------------------------------------------------------------------------------
1 | 1. Install FreeSurfer from https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall
2 | 2. ```export FREESURFER_HOME=/your_freesurfer_directory```
3 | 3. ```source $FREESURFER_HOME/SetUpFreeSurfer.sh```
4 | 4. ```export SUBJECTS_DIR=/dataset_directory```
5 | 5. ```recon-all -parallel -i dataset_directory/img_name.nii -autorecon1 -subjid img_name``` -> This step does motion correction, skull stripping, affine transform computation, and intensity normalization.
6 | 6. ```mri_convert dataset_directory/img_name/mri/brainmask.mgz dataset_directory/img_name/mri/brainmask.nii.gz``` -> This step converts the preprocessed image from .mgz into .nii format.
7 | 7. ```mri_convert dataset_directory/img_name/mri/brainmask.mgz --apply_transform dataset_directory/img_name/mri/transforms/talairach.xfm -o dataset_directory/img_name/mri/brainmask_align.mgz``` -> This step applies the affine transform to Talairach space.
8 | 8. ```mri_convert dataset_directory/img_name/mri/brainmask_align.mgz dataset_directory/img_name/mri/brainmask_align.nii.gz``` -> This step converts the transformed image from .mgz into .nii format.
9 | 9. ```recon-all -parallel -s dataset_directory/img_name.nii -subcortseg -subjid img_name``` -> This step does subcortical segmentation.
10 | 10. ```mri_convert dataset_directory/img_name/mri/aseg.auto.mgz dataset_directory/img_name/mri/aseg.nii.gz``` -> This step converts the label image from .mgz into .nii format.
11 | 11. ```mri_convert -rt nearest dataset_directory/img_name/mri/aseg.auto.mgz --apply_transform dataset_directory/img_name/mri/transforms/talairach.xfm -o dataset_directory/img_name/mri/aseg_align.mgz``` -> This step applies the affine transform to Talairach space using nearest-neighbor interpolation for the label image.
12 | 12. ```mri_convert dataset_directory/img_name/mri/aseg_align.mgz dataset_directory/img_name/mri/aseg_align.nii.gz``` -> This step converts the transformed label image from .mgz into .nii format.
13 |
14 | Note that these steps may take up to **12-24 hours per image** based on our experience, so running these commands in parallel on a server or a cluster is recommended. A minimal Python sketch for looping these commands over a dataset is shown below.
15 |
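A minimal sketch of scripting steps 5-12 with Python's `subprocess` (the dataset directory and image names are placeholders; FreeSurfer is assumed to be installed and sourced as in steps 1-4):

```python
import os
import subprocess

SUBJECTS_DIR = '/dataset_directory'   # placeholder; same directory as in step 4
images = ['img_name']                 # placeholder image/subject names

for name in images:
    mri = os.path.join(SUBJECTS_DIR, name, 'mri')
    xfm = os.path.join(mri, 'transforms', 'talairach.xfm')
    commands = [
        # step 5: motion correction, skull stripping, affine computation, intensity normalization
        ['recon-all', '-parallel', '-i', os.path.join(SUBJECTS_DIR, name + '.nii'),
         '-autorecon1', '-subjid', name],
        # steps 6-8: convert the brain mask and align it to Talairach space
        ['mri_convert', os.path.join(mri, 'brainmask.mgz'), os.path.join(mri, 'brainmask.nii.gz')],
        ['mri_convert', os.path.join(mri, 'brainmask.mgz'),
         '--apply_transform', xfm, '-o', os.path.join(mri, 'brainmask_align.mgz')],
        ['mri_convert', os.path.join(mri, 'brainmask_align.mgz'), os.path.join(mri, 'brainmask_align.nii.gz')],
        # step 9: subcortical segmentation
        ['recon-all', '-parallel', '-s', os.path.join(SUBJECTS_DIR, name + '.nii'),
         '-subcortseg', '-subjid', name],
        # steps 10-12: convert the label image and align it with nearest-neighbor resampling
        ['mri_convert', os.path.join(mri, 'aseg.auto.mgz'), os.path.join(mri, 'aseg.nii.gz')],
        ['mri_convert', '-rt', 'nearest', os.path.join(mri, 'aseg.auto.mgz'),
         '--apply_transform', xfm, '-o', os.path.join(mri, 'aseg_align.mgz')],
        ['mri_convert', os.path.join(mri, 'aseg_align.mgz'), os.path.join(mri, 'aseg_align.nii.gz')],
    ]
    for cmd in commands:
        subprocess.run(cmd, check=True)
```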
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ViT-V-Net: Vision Transformer for Volumetric Medical Image Registration
2 |
3 | [ViT-V-Net on arXiv](https://arxiv.org/abs/2104.06468)
4 |
5 | **Please also check out our newly proposed registration model :point_right: [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration)**\
6 | **The pretrained model and the quantitative results of ViT-V-Net on the IXI dataset are available here: [IXI_dataset](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/main/IXI/TransMorph_on_IXI.md).\
7 | Additionally, we have made our preprocessed IXI dataset publicly available!**
8 |
9 | keywords: vision transformer, convolutional neural networks, image registration
10 |
11 | This is a **PyTorch** implementation of my short paper:
12 |
13 | Chen, Junyu, et al. "ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration." Medical Imaging with Deep Learning (MIDL), 2021.
14 |
15 |
16 | ***train.py*** is the training script.
17 | ***models.py*** contains the ViT-V-Net model.
18 |
19 | ***Pretrained ViT-V-Net:*** pretrained model
20 |
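A minimal usage sketch (the checkpoint path and the random input volumes are placeholders; the input size follows `infer.py`): the network takes the moving and fixed volumes concatenated along the channel dimension and returns the warped moving image together with the dense displacement field.

```python
import torch
import models
from models import CONFIGS as CONFIGS_ViT_seg

config_vit = CONFIGS_ViT_seg['ViT-V-Net']
model = models.ViTVNet(config_vit, img_size=(160, 192, 224))
model.load_state_dict(torch.load('path/to/pretrained_ViTVNet.pth')['state_dict'])  # placeholder path
model.eval()

x = torch.rand(1, 1, 160, 192, 224)  # moving volume [B, 1, H, W, D]
y = torch.rand(1, 1, 160, 192, 224)  # fixed volume
with torch.no_grad():
    warped_x, flow = model(torch.cat((x, y), dim=1))
```

See `infer.py` for the full evaluation loop, including warping segmentation maps with `utils.register_model` and computing Dice scores.
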
21 | ***Dataset:*** Due to restrictions, we cannot distribute our brain MRI data. However, several brain MRI datasets are publicly available online: IXI, ADNI, OASIS, ABIDE, etc. Note that these datasets may not come with labels (segmentations). To generate labels, you can use FreeSurfer, an open-source software suite for preprocessing and segmenting brain MRI. The FreeSurfer commands we used for brain MRI preprocessing and subcortical segmentation are listed in [PreprocessingMRI.md](PreprocessingMRI.md).
22 |
23 | ## Model Architecture:
24 | ![ViT-V-Net architecture](figures/net_arch.jpg)
25 |
26 | ### Vision Transformer Architecture:
27 | ![Vision Transformer module](figures/trans_arch.jpg)
28 |
29 | ## Example Results:
30 | ![Example registration results](figures/ViTVNet_res.jpg)
31 |
32 | ## Quantitative Results:
33 | ![Dice scores for subcortical structures](figures/dice_details_.jpg)
34 |
35 |
36 | ## Reference:
37 | TransUnet
38 |
39 | ViT-pytorch
40 |
41 | VoxelMorph
42 |
43 |
44 | If you find this code useful in your research, please consider citing:
45 |
46 | @inproceedings{chen2021vitvnet,
47 | title={ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration},
48 | author={Junyu Chen and Yufan He and Eric Frey and Ye Li and Yong Du},
49 | booktitle={Medical Imaging with Deep Learning},
50 | year={2021},
51 | url={https://openreview.net/forum?id=h3HC1EU7AEz}
52 | }
53 |
54 | ### About Me
55 |
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/ViT_V_Net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/ViT_V_Net.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/configs.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/__pycache__/vit_reg_configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/vit_reg_configs.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/configs.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 | def get_3DReg_config():
4 | config = ml_collections.ConfigDict()
5 | config.patches = ml_collections.ConfigDict({'size': (8, 8, 8)})
6 | config.patches.grid = (8, 8, 8)
7 | config.hidden_size = 252
8 | config.transformer = ml_collections.ConfigDict()
9 | config.transformer.mlp_dim = 3072
10 | config.transformer.num_heads = 12
11 | config.transformer.num_layers = 12
12 | config.transformer.attention_dropout_rate = 0.0
13 | config.transformer.dropout_rate = 0.1
14 | config.patch_size = 8
15 |
16 | config.conv_first_channel = 512
17 | config.encoder_channels = (16, 32, 32)
18 | config.down_factor = 2
19 | config.down_num = 2
20 | config.decoder_channels = (96, 48, 32, 32, 16)
21 | config.skip_channels = (32, 32, 32, 32, 16)
22 | config.n_dims = 3
23 | config.n_skip = 5
24 | return config
25 |
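
# A minimal sketch of how this config is consumed: ml_collections.ConfigDict
# supports attribute access, and fields can be overridden before the model is
# built from it (infer.py passes models.CONFIGS['ViT-V-Net'] to models.ViTVNet).
if __name__ == '__main__':
    cfg = get_3DReg_config()
    print(cfg.hidden_size, cfg.patches.grid)   # 252 (8, 8, 8)
    cfg.transformer.dropout_rate = 0.0         # example override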
--------------------------------------------------------------------------------
/ViT-V-Net/data/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/data/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/data/__pycache__/rand.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/rand.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/data/__pycache__/trans.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/trans.cpython-38.pyc
--------------------------------------------------------------------------------
/ViT-V-Net/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import pickle
3 | import numpy as np
4 | import torch
5 |
6 | M = 2 ** 32 - 1
7 |
8 |
9 | def init_fn(worker):
10 | seed = torch.LongTensor(1).random_().item()
11 | seed = (seed + worker) % M
12 | np.random.seed(seed)
13 | random.seed(seed)
14 |
15 |
16 | def add_mask(x, mask, dim=1):
17 | mask = mask.unsqueeze(dim)
18 | shape = list(x.shape)
19 | shape[dim] += 21
20 | new_x = x.new(*shape).zero_()
21 | new_x = new_x.scatter_(dim, mask, 1.0)
22 | s = [slice(None)] * len(shape)
23 | s[dim] = slice(21, None)
24 | new_x[s] = x
25 | return new_x
26 |
27 |
28 | def sample(x, size):
29 | # https://gist.github.com/yoavram/4134617
30 | i = random.sample(range(x.shape[0]), size)
31 | return torch.tensor(x[i], dtype=torch.int16)
32 | # x = np.random.permutation(x)
33 | # return torch.tensor(x[:size])
34 |
35 |
36 | def pkload(fname):
37 | with open(fname, 'rb') as f:
38 | return pickle.load(f)
39 |
40 |
41 | _shape = (240, 240, 155)
42 |
43 |
44 | def get_all_coords(stride):
45 | return torch.tensor(
46 | np.stack([v.reshape(-1) for v in
47 | np.meshgrid(
48 | *[stride // 2 + np.arange(0, s, stride) for s in _shape],
49 | indexing='ij')],
50 | -1), dtype=torch.int16)
51 |
52 |
53 | _zero = torch.tensor([0])
54 |
55 |
56 | def gen_feats():
57 | x, y, z = 240, 240, 155
58 | feats = np.stack(
59 | np.meshgrid(
60 | np.arange(x), np.arange(y), np.arange(z),
61 | indexing='ij'), -1).astype('float32')
62 | shape = np.array([x, y, z])
63 | feats -= shape / 2.0
64 | feats /= shape
65 |
66 | return feats
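

if __name__ == '__main__':
    # Minimal sketch (dummy tensors): init_fn is intended to be passed as the
    # worker_init_fn of a DataLoader so numpy/random are re-seeded in each worker.
    from torch.utils.data import DataLoader, TensorDataset
    dummy = TensorDataset(torch.randn(8, 1, 4, 4, 4))
    loader = DataLoader(dummy, batch_size=2, num_workers=2, worker_init_fn=init_fn)
    for (batch,) in loader:
        print(batch.shape)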
--------------------------------------------------------------------------------
/ViT-V-Net/data/datasets.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch, sys
3 | from torch.utils.data import Dataset
4 | from .data_utils import pkload
5 | import matplotlib.pyplot as plt
6 |
7 | import numpy as np
8 |
9 |
10 | class JHUBrainDataset(Dataset):
11 | def __init__(self, data_path, transforms):
12 | self.paths = data_path
13 | self.transforms = transforms
14 |
15 | def one_hot(self, img, C):
16 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
17 | for i in range(C):
18 | out[i,...] = img == i
19 | return out
20 |
21 | def __getitem__(self, index):
22 | path = self.paths[index]
23 | x, y = pkload(path)
24 | #print(x.shape)
25 | #print(x.shape)
26 | #print(np.unique(y))
27 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155)
28 | # transforms work with nhwtc
29 | x, y = x[None, ...], y[None, ...]
30 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155)
31 | x,y = self.transforms([x, y])
32 | #y = self.one_hot(y, 2)
33 | #print(y.shape)
34 | #sys.exit(0)
35 | x = np.ascontiguousarray(x)  # [Bsize, channels, Height, Width, Depth]
36 | y = np.ascontiguousarray(y)
37 | #plt.figure()
38 | #plt.subplot(1, 2, 1)
39 | #plt.imshow(x[0, :, :, 8], cmap='gray')
40 | #plt.subplot(1, 2, 2)
41 | #plt.imshow(y[0, :, :, 8], cmap='gray')
42 | #plt.show()
43 | #sys.exit(0)
44 | #y = np.squeeze(y, axis=0)
45 | x, y = torch.from_numpy(x), torch.from_numpy(y)
46 | return x, y
47 |
48 | def __len__(self):
49 | return len(self.paths)
50 |
51 |
52 | class JHUBrainInferDataset(Dataset):
53 | def __init__(self, data_path, transforms):
54 | self.paths = data_path
55 | self.transforms = transforms
56 |
57 | def one_hot(self, img, C):
58 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
59 | for i in range(C):
60 | out[i,...] = img == i
61 | return out
62 |
63 | def __getitem__(self, index):
64 | path = self.paths[index]
65 | x, y, x_seg, y_seg = pkload(path)
66 | #print(x.shape)
67 | #print(x.shape)
68 | #print(np.unique(y))
69 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155)
70 | # transforms work with nhwtc
71 | x, y = x[None, ...], y[None, ...]
72 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...]
73 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155)
74 | x, x_seg = self.transforms([x, x_seg])
75 | y, y_seg = self.transforms([y, y_seg])
76 | #y = self.one_hot(y, 2)
77 | #print(y.shape)
78 | #sys.exit(0)
79 | x = np.ascontiguousarray(x)  # [Bsize, channels, Height, Width, Depth]
80 | y = np.ascontiguousarray(y)
81 | x_seg = np.ascontiguousarray(x_seg)  # [Bsize, channels, Height, Width, Depth]
82 | y_seg = np.ascontiguousarray(y_seg)
83 | #plt.figure()
84 | #plt.subplot(1, 2, 1)
85 | #plt.imshow(x[0, :, :, 8], cmap='gray')
86 | #plt.subplot(1, 2, 2)
87 | #plt.imshow(y[0, :, :, 8], cmap='gray')
88 | #plt.show()
89 | #sys.exit(0)
90 | #y = np.squeeze(y, axis=0)
91 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg)
92 | return x, y, x_seg, y_seg
93 |
94 | def __len__(self):
95 | return len(self.paths)
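

if __name__ == '__main__':
    # Minimal smoke test with made-up shapes and a temporary pickle; a lambda
    # stands in for the usual transform pipeline. Run as `python -m data.datasets`
    # from the ViT-V-Net folder because of the relative import above.
    import pickle, tempfile
    x = np.random.rand(8, 8, 8).astype('float32')
    y = np.random.randint(0, 3, (8, 8, 8)).astype('int16')
    fd, fname = tempfile.mkstemp(suffix='.pkl')
    with os.fdopen(fd, 'wb') as f:
        pickle.dump((x, y), f)
    ds = JHUBrainDataset([fname], transforms=lambda pair: pair)
    img, seg = ds[0]
    print(img.shape, seg.shape)  # torch.Size([1, 8, 8, 8]) torch.Size([1, 8, 8, 8])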
--------------------------------------------------------------------------------
/ViT-V-Net/data/rand.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | class Uniform(object):
5 | def __init__(self, a, b):
6 | self.a = a
7 | self.b = b
8 |
9 | def sample(self):
10 | return random.uniform(self.a, self.b)
11 |
12 |
13 | class Gaussian(object):
14 | def __init__(self, mean, std):
15 | self.mean = mean
16 | self.std = std
17 |
18 | def sample(self):
19 | return random.gauss(self.mean, self.std)
20 |
21 |
22 | class Constant(object):
23 | def __init__(self, val):
24 | self.val = val
25 |
26 | def sample(self):
27 | return self.val
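

if __name__ == '__main__':
    # Tiny sketch: all three samplers expose the same .sample() interface, so any
    # of them can be plugged in where data/trans.py expects a sampler (e.g. the
    # sigma of GaussianBlur defaults to Constant(1.5)).
    for dist in (Uniform(0.0, 1.0), Gaussian(0.0, 0.1), Constant(1.5)):
        print(type(dist).__name__, dist.sample())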
--------------------------------------------------------------------------------
/ViT-V-Net/data/trans.py:
--------------------------------------------------------------------------------
1 | # import math
2 | import random
3 | import collections
4 | import numpy as np
5 | import torch, sys, random, math
6 | from scipy import ndimage
7 |
8 | from .rand import Constant, Uniform, Gaussian
9 | from scipy.ndimage import rotate
10 | from skimage.transform import rescale, resize
11 |
12 | class Base(object):
13 | def sample(self, *shape):
14 | return shape
15 |
16 | def tf(self, img, k=0):
17 | return img
18 |
19 | def __call__(self, img, dim=3, reuse=False): # class -> func()
20 | # image: nhwtc
21 | # shape: no first dim
22 | if not reuse:
23 | im = img if isinstance(img, np.ndarray) else img[0]
24 | # how to know if the last dim is channel??
25 | # nhwtc vs nhwt??
26 | shape = im.shape[1:dim+1]
27 | # print(dim,shape) # 3, (240,240,155)
28 | self.sample(*shape)
29 |
30 | if isinstance(img, collections.Sequence):
31 | return [self.tf(x, k) for k, x in enumerate(img)] # img:k=0,label:k=1
32 |
33 | return self.tf(img)
34 |
35 | def __str__(self):
36 | return 'Identity()'
37 |
38 | Identity = Base
39 |
40 | # geometric transformations, need a buffer
41 | # first axis is N
42 | class Rot90(Base):
43 | def __init__(self, axes=(0, 1)):
44 | self.axes = axes
45 |
46 | for a in self.axes:
47 | assert a > 0
48 |
49 | def sample(self, *shape):
50 | shape = list(shape)
51 | i, j = self.axes
52 |
53 | # shape: no first dim
54 | i, j = i-1, j-1
55 | shape[i], shape[j] = shape[j], shape[i]
56 |
57 | return shape
58 |
59 | def tf(self, img, k=0):
60 | return np.rot90(img, axes=self.axes)
61 |
62 | def __str__(self):
63 | return 'Rot90(axes=({}, {}))'.format(*self.axes)
64 |
65 | # class RandomRotion(Base):
66 | # def __init__(self, angle=20):# angle :in degress, float, [0,360]
67 | # assert angle >= 0.0
68 | # self.axes = (0,1) # rotate only in the H,W plane
69 | # self.angle = angle #
70 | # self.buffer = None
71 | #
72 | # def sample(self, *shape):# shape : [H,W,D]
73 | # shape = list(shape)
74 | # self.buffer = round(np.random.uniform(low=-self.angle,high=self.angle),2) # rounded to 2 decimal places
75 | # if self.buffer < 0:
76 | # self.buffer += 180
77 | # return shape
78 | #
79 | # def tf(self, img, k=0): # img shape [1,H,W,D,c] while label shape is [1,H,W,D]
80 | # return ndimage.rotate(img, angle=self.buffer, reshape=False)
81 | #
82 | # def __str__(self):
83 | # return 'RandomRotion(axes=({}, {}),Angle:{}'.format(*self.axes,self.buffer)
84 |
85 | class RandomRotion(Base):
86 | def __init__(self,angle_spectrum=10):
87 | assert isinstance(angle_spectrum,int)
88 | # axes = [(2, 1), (3, 1),(3, 2)]
89 | axes = [(1, 0), (2, 1),(2, 0)]
90 | self.angle_spectrum = angle_spectrum
91 | self.axes = axes
92 |
93 | def sample(self,*shape):
94 | self.axes_buffer = self.axes[np.random.choice(list(range(len(self.axes))))] # choose the random direction
95 | self.angle_buffer = np.random.randint(-self.angle_spectrum, self.angle_spectrum) # choose the random angle
96 | return list(shape)
97 |
98 | def tf(self, img, k=0):
99 | """ Introduction: The rotation function supports the shape [H,W,D,C] or shape [H,W,D]
100 | :param img: if x, shape is [1,H,W,D,c]; if label, shape is [1,H,W,D]
101 | :param k: if x, k=0; if label, k=1
102 | """
103 | bsize = img.shape[0]
104 |
105 | for bs in range(bsize):
106 | if k == 0:
107 | # [[H,W,D], ...]
108 | # print(img.shape) # (1, 128, 128, 128, 4)
109 | channels = [rotate(img[bs,:,:,:,c], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) for c in
110 | range(img.shape[4])]
111 | img[bs,...] = np.stack(channels, axis=-1)
112 |
113 | if k == 1:
114 | img[bs,...] = rotate(img[bs,...], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1)
115 |
116 | return img
117 |
118 | def __str__(self):
119 | return 'RandomRotion(axes={}, Angle:{})'.format(self.axes_buffer, self.angle_buffer)
120 |
121 |
122 | class Flip(Base):
123 | def __init__(self, axis=0):
124 | self.axis = axis
125 |
126 | def tf(self, img, k=0):
127 | return np.flip(img, self.axis)
128 |
129 | def __str__(self):
130 | return 'Flip(axis={})'.format(self.axis)
131 |
132 | class RandomFlip(Base):
133 | # mirror flip across all x,y,z
134 | def __init__(self,axis=0):
135 | # assert axis == (1,2,3) # For both data and label, it has to specify the axis.
136 | self.axis = (1,2,3)
137 | self.x_buffer = None
138 | self.y_buffer = None
139 | self.z_buffer = None
140 |
141 | def sample(self, *shape):
142 | self.x_buffer = np.random.choice([True,False])
143 | self.y_buffer = np.random.choice([True,False])
144 | self.z_buffer = np.random.choice([True,False])
145 | return list(shape) # the shape is not changed
146 |
147 | def tf(self,img,k=0): # img shape is (1, 240, 240, 155, 4)
148 | if self.x_buffer:
149 | img = np.flip(img,axis=self.axis[0])
150 | if self.y_buffer:
151 | img = np.flip(img,axis=self.axis[1])
152 | if self.z_buffer:
153 | img = np.flip(img,axis=self.axis[2])
154 | return img
155 |
156 |
157 | class RandSelect(Base):
158 | def __init__(self, prob=0.5, tf=None):
159 | self.prob = prob
160 | self.ops = tf if isinstance(tf, collections.Sequence) else (tf, )
161 | self.buff = False
162 |
163 | def sample(self, *shape):
164 | self.buff = random.random() < self.prob
165 |
166 | if self.buff:
167 | for op in self.ops:
168 | shape = op.sample(*shape)
169 |
170 | return shape
171 |
172 | def tf(self, img, k=0):
173 | if self.buff:
174 | for op in self.ops:
175 | img = op.tf(img, k)
176 | return img
177 |
178 | def __str__(self):
179 | if len(self.ops) == 1:
180 | ops = str(self.ops[0])
181 | else:
182 | ops = '[{}]'.format(', '.join([str(op) for op in self.ops]))
183 | return 'RandSelect({}, {})'.format(self.prob, ops)
184 |
185 |
186 | class CenterCrop(Base):
187 | def __init__(self, size):
188 | self.size = size
189 | self.buffer = None
190 |
191 | def sample(self, *shape):
192 | size = self.size
193 | start = [(s -size)//2 for s in shape]
194 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start]
195 | return [size] * len(shape)
196 |
197 | def tf(self, img, k=0):
198 | # print(img.shape)#(1, 240, 240, 155, 4)
199 | return img[tuple(self.buffer)]
200 | # return img[self.buffer]
201 |
202 | def __str__(self):
203 | return 'CenterCrop({})'.format(self.size)
204 |
205 | class CenterCropBySize(CenterCrop):
206 | def sample(self, *shape):
207 | assert len(self.size) == 3 # random crop [H,W,T] from img [240,240,155]
208 | if not isinstance(self.size, list):
209 | size = list(self.size)
210 | else:
211 | size = self.size
212 | start = [(s-i)//2 for i, s in zip(size, shape)]
213 | self.buffer = [slice(None)] + [slice(s, s+i) for i, s in zip(size, start)]
214 | return size
215 |
216 | def __str__(self):
217 | return 'CenterCropBySize({})'.format(self.size)
218 |
219 | class RandCrop(CenterCrop):
220 | def sample(self, *shape):
221 | size = self.size
222 | start = [random.randint(0, s-size) for s in shape]
223 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start]
224 | return [size]*len(shape)
225 |
226 | def __str__(self):
227 | return 'RandCrop({})'.format(self.size)
228 |
229 |
230 | class RandCrop3D(CenterCrop):
231 | def sample(self, *shape): # shape : [240,240,155]
232 | assert len(self.size)==3 # random crop [H,W,T] from img [240,240,155]
233 | if not isinstance(self.size,list):
234 | size = list(self.size)
235 | else:
236 | size = self.size
237 | start = [random.randint(0, s-i) for i,s in zip(size,shape)]
238 | self.buffer = [slice(None)] + [slice(s, s+k) for s,k in zip(start,size)]
239 | return size
240 |
241 | def __str__(self):
242 | return 'RandCrop({})'.format(self.size)
243 |
244 | # for data only
245 | class RandomIntensityChange(Base):
246 | def __init__(self,factor):
247 | shift,scale = factor
248 | assert (shift >0) and (scale >0)
249 | self.shift = shift
250 | self.scale = scale
251 |
252 | def tf(self,img,k=0):
253 | if k==1:
254 | return img
255 |
256 | shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,img.shape[1],1,1,img.shape[4]]) # [-0.1,+0.1]
257 | scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,img.shape[1],1,1,img.shape[4]]) # [0.9,1.1)
258 | # shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,1,1,img.shape[3],img.shape[4]]) # [-0.1,+0.1]
259 | # scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,1,1,img.shape[3],img.shape[4]]) # [0.9,1.1)
260 | return img * scale_factor + shift_factor
261 |
262 | def __str__(self):
263 | return 'RandomIntensityChange(shift={}, scale={})'.format(self.shift, self.scale)
264 |
265 | class RandomGammaCorrection(Base):
266 | def __init__(self,factor):
267 | lower, upper = factor
268 | assert (lower >0) and (upper >0)
269 | self.lower = lower
270 | self.upper = upper
271 |
272 | def tf(self,img,k=0):
273 | if k==1:
274 | return img
275 | img = img - np.min(img) # shift intensities to be non-negative before applying gamma
276 | img_max = np.max(img)
277 | img = img/img_max
278 | factor = random.choice(np.arange(self.lower, self.upper, 0.1))
279 | gamma = random.choice([1, factor])
280 | if gamma == 1:
281 | return img
282 | img = img ** gamma * img_max
283 | img = (img - img.mean())/img.std()
284 | return img
285 |
286 | def __str__(self):
287 | return 'RandomGammaCorrection(range=[{}, {}])'.format(self.lower, self.upper)
288 |
289 | class MinMax_norm(Base):
290 | def __init__(self, ):
291 | a = None
292 |
293 | def tf(self, img, k=0):
294 | if k == 1:
295 | return img
296 | img = (img - img.min()) / (img.max()-img.min())
297 | return img
298 |
299 | class Seg_norm(Base):
300 | def __init__(self, ):
301 | a = None
302 | self.seg_table = np.array([0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26,
303 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62,
304 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255])
305 | def tf(self, img, k=0):
306 | if k == 0:
307 | return img
308 | img_out = np.zeros_like(img)
309 | for i in range(len(self.seg_table)):
310 | img_out[img == self.seg_table[i]] = i
311 | return img_out
312 |
313 | class Resize_img(Base):
314 | def __init__(self, shape):
315 | self.shape = shape
316 |
317 | def tf(self, img, k=0):
318 | if k == 1:
319 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]),
320 | anti_aliasing=False, order=0)
321 | else:
322 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]),
323 | anti_aliasing=False, order=3)
324 | return img
325 |
326 | class Pad(Base):
327 | def __init__(self, pad): # [0,0,0,5,0]
328 | self.pad = pad
329 | self.px = tuple(zip([0]*len(pad), pad))
330 |
331 | def sample(self, *shape):
332 |
333 | shape = list(shape)
334 |
335 | # shape: no first dim
336 | for i in range(len(shape)):
337 | shape[i] += self.pad[i+1]
338 |
339 | return shape
340 |
341 | def tf(self, img, k=0):
342 | #nhwtc, nhwt
343 | dim = len(img.shape)
344 | return np.pad(img, self.px[:dim], mode='constant')
345 |
346 | def __str__(self):
347 | return 'Pad(({}, {}, {}))'.format(*self.pad)
348 |
349 | class Pad3DIfNeeded(Base):
350 | def __init__(self, shape, value=0, mask_value=0): # [0,0,0,5,0]
351 | self.shape = shape
352 | self.value = value
353 | self.mask_value = mask_value
354 |
355 | def tf(self, img, k=0):
356 | pad = [(0,0)]
357 | if k==0:
358 | img_shape = img.shape[1:-1]
359 | else:
360 | img_shape = img.shape[1:]
361 | for i, t in zip(img_shape, self.shape):
362 | if i < t:
363 | diff = t-i
364 | pad.append((math.ceil(diff/2),math.floor(diff/2)))
365 | else:
366 | pad.append((0,0))
367 | if k == 0:
368 | pad.append((0,0))
369 | pad = tuple(pad)
370 | if k==0:
371 | return np.pad(img, pad, mode='constant', constant_values=img.min())
372 | else:
373 | return np.pad(img, pad, mode='constant', constant_values=self.mask_value)
374 |
375 | def __str__(self):
376 | return 'Pad3DIfNeeded({})'.format(self.shape)
377 |
378 | class Noise(Base):
379 | def __init__(self, dim, sigma=0.1, channel=True, num=-1):
380 | self.dim = dim
381 | self.sigma = sigma
382 | self.channel = channel
383 | self.num = num
384 |
385 | def tf(self, img, k=0):
386 | if self.num > 0 and k >= self.num:
387 | return img
388 |
389 | if self.channel:
390 | #nhwtc, hwtc, hwt
391 | shape = [1] if len(img.shape) < self.dim+2 else [img.shape[-1]]
392 | else:
393 | shape = img.shape
394 | return img * np.exp(self.sigma * torch.randn(shape, dtype=torch.float32).numpy())
395 |
396 | def __str__(self):
397 | return 'Noise()'
398 |
399 |
400 | # dim could come from shape
401 | class GaussianBlur(Base):
402 | def __init__(self, dim, sigma=Constant(1.5), app=-1):
403 | # 1.5 pixel
404 | self.dim = dim
405 | self.sigma = sigma
406 | self.eps = 0.001
407 | self.app = app
408 |
409 | def tf(self, img, k=0):
410 | if self.app > 0 and k >= self.app:
411 | return img
412 |
413 | # image is nhwtc
414 | for n in range(img.shape[0]):
415 | sig = self.sigma.sample()
416 | # sample each channel separately to avoid correlations
417 | if sig > self.eps:
418 | if len(img.shape) == self.dim+2:
419 | C = img.shape[-1]
420 | for c in range(C):
421 | img[n,..., c] = ndimage.gaussian_filter(img[n, ..., c], sig)
422 | elif len(img.shape) == self.dim+1:
423 | img[n] = ndimage.gaussian_filter(img[n], sig)
424 | else:
425 | raise ValueError('image shape is not supported')
426 |
427 | return img
428 |
429 | def __str__(self):
430 | return 'GaussianBlur()'
431 |
432 |
433 | class ToNumpy(Base):
434 | def __init__(self, num=-1):
435 | self.num = num
436 |
437 | def tf(self, img, k=0):
438 | if self.num > 0 and k >= self.num:
439 | return img
440 | return img.numpy()
441 |
442 | def __str__(self):
443 | return 'ToNumpy()'
444 |
445 |
446 | class ToTensor(Base):
447 | def __init__(self, num=-1):
448 | self.num = num
449 |
450 | def tf(self, img, k=0):
451 | if self.num > 0 and k >= self.num:
452 | return img
453 |
454 | return torch.from_numpy(img)
455 |
456 | def __str__(self):
457 | return 'ToTensor'
458 |
459 |
460 | class TensorType(Base):
461 | def __init__(self, types, num=-1):
462 | self.types = types # ('torch.float32', 'torch.int64')
463 | self.num = num
464 |
465 | def tf(self, img, k=0):
466 | if self.num > 0 and k >= self.num:
467 | return img
468 | # make this work with both Tensor and Numpy
469 | return img.type(self.types[k])
470 |
471 | def __str__(self):
472 | s = ', '.join([str(s) for s in self.types])
473 | return 'TensorType(({}))'.format(s)
474 |
475 |
476 | class NumpyType(Base):
477 | def __init__(self, types, num=-1):
478 | self.types = types # ('float32', 'int64')
479 | self.num = num
480 |
481 | def tf(self, img, k=0):
482 | if self.num > 0 and k >= self.num:
483 | return img
484 | # make this work with both Tensor and Numpy
485 | return img.astype(self.types[k])
486 |
487 | def __str__(self):
488 | s = ', '.join([str(s) for s in self.types])
489 | return 'NumpyType(({}))'.format(s)
490 |
491 |
492 | class Normalize(Base):
493 | def __init__(self, mean=0.0, std=1.0, num=-1):
494 | self.mean = mean
495 | self.std = std
496 | self.num = num
497 |
498 | def tf(self, img, k=0):
499 | if self.num > 0 and k >= self.num:
500 | return img
501 | img -= self.mean
502 | img /= self.std
503 | return img
504 |
505 | def __str__(self):
506 | return 'Normalize()'
507 |
508 |
509 | class Compose(Base):
510 | def __init__(self, ops):
511 | if not isinstance(ops, collections.Sequence):
512 | ops = ops,
513 | self.ops = ops
514 |
515 | def sample(self, *shape):
516 | for op in self.ops:
517 | shape = op.sample(*shape)
518 |
519 | def tf(self, img, k=0):
520 | #is_tensor = isinstance(img, torch.Tensor)
521 | #if is_tensor:
522 | # img = img.numpy()
523 |
524 | for op in self.ops:
525 | # print(op,img.shape,k)
526 | img = op.tf(img, k) # do not use op(img) here
527 |
528 | #if is_tensor:
529 | # img = np.ascontiguousarray(img)
530 | # img = torch.from_numpy(img)
531 |
532 | return img
533 |
534 | def __str__(self):
535 | ops = ', '.join([str(op) for op in self.ops])
536 | return 'Compose([{}])'.format(ops)
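

if __name__ == '__main__':
    # Minimal sketch with made-up shapes (run as `python -m data.trans` from the
    # ViT-V-Net folder because of the relative import above): an image/label pair
    # is passed through the pipeline together, so both see the same sampled flips.
    img = np.random.rand(1, 32, 32, 32).astype('float32')
    seg = np.random.randint(0, 2, (1, 32, 32, 32)).astype('int16')
    pipeline = Compose([RandomFlip(), NumpyType(('float32', 'int16'))])
    img_t, seg_t = pipeline([img, seg])
    print(img_t.shape, img_t.dtype, seg_t.dtype)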
--------------------------------------------------------------------------------
/ViT-V-Net/infer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from torch.utils.tensorboard import SummaryWriter
3 | import logging
4 | import os, losses, utils, nrrd
5 | import shutil
6 | import sys
7 | from torch.utils.data import DataLoader
8 | from data import datasets, trans
9 | import numpy as np
10 | import torch, models
11 | from torchvision import transforms
12 | from torch import optim
13 | import torch.nn as nn
14 | from ignite.contrib.handlers import ProgressBar
15 | from torchsummary import summary
16 | import matplotlib.pyplot as plt
17 | from models import CONFIGS as CONFIGS_ViT_seg
18 | from mpl_toolkits.mplot3d import axes3d
19 | from natsort import natsorted
20 |
21 |
22 |
23 | def plot_grid(gridx,gridy, **kwargs):
24 | for i in range(gridx.shape[1]):
25 | plt.plot(gridx[i,:], gridy[i,:], linewidth=0.8, **kwargs)
26 | for i in range(gridx.shape[0]):
27 | plt.plot(gridx[:,i], gridy[:,i], linewidth=0.8, **kwargs)
28 |
29 | class AverageMeter(object):
30 | """Computes and stores the average and current value"""
31 | def __init__(self):
32 | self.reset()
33 |
34 | def reset(self):
35 | self.val = 0
36 | self.avg = 0
37 | self.sum = 0
38 | self.count = 0
39 | self.vals = []
40 | self.std = 0
41 |
42 | def update(self, val, n=1):
43 | self.val = val
44 | self.sum += val * n
45 | self.count += n
46 | self.avg = self.sum / self.count
47 | self.vals.append(val)
48 | self.std = np.std(self.vals)
49 |
50 | def MSE_torch(x, y):
51 | return torch.mean((x - y) ** 2)
52 |
53 | def MAE_torch(x, y):
54 | return torch.mean(torch.abs(x - y))
55 |
56 | def main():
57 | test_dir = 'D:/DATA/JHUBrain/Test/'
58 | model_idx = -1
59 | model_folder = 'ViTVNet_reg0.02_mse_diff/'
60 | model_dir = 'experiments/' + model_folder
61 | config_vit = CONFIGS_ViT_seg['ViT-V-Net']
62 | label_dict = utils.process_label()
63 | if os.path.exists('experiments/'+model_folder[:-1]+'.csv'):
64 | os.remove('experiments/'+model_folder[:-1]+'.csv')
65 | csv_writter(model_folder[:-1], 'experiments/' + model_folder[:-1])
66 | line = ''
67 | for i in range(46):
68 | line = line + ',' + label_dict[i]
69 | csv_writter(line, 'experiments/' + model_folder[:-1])
70 | model = models.ViTVNet(config_vit, img_size=(160, 192, 224))
71 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])['state_dict']
72 | print('Best model: {}'.format(natsorted(os.listdir(model_dir))[model_idx]))
73 | model.load_state_dict(best_model)
74 | model.cuda()
75 | reg_model = utils.register_model((160, 192, 224), 'nearest')
76 | reg_model.cuda()
77 | test_composed = transforms.Compose([trans.Seg_norm(),
78 | trans.NumpyType((np.float32, np.int16)),
79 | ])
80 | test_set = datasets.JHUBrainInferDataset(glob.glob(test_dir + '*.pkl'), transforms=test_composed)
81 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=True)
82 | eval_dsc_def = AverageMeter()
83 | eval_dsc_raw = AverageMeter()
84 | eval_det = AverageMeter()
85 | with torch.no_grad():
86 | stdy_idx = 0
87 | for data in test_loader:
88 | model.eval()
89 | data = [t.cuda() for t in data]
90 | x = data[0]
91 | y = data[1]
92 | x_seg = data[2]
93 | y_seg = data[3]
94 |
95 | x_in = torch.cat((x,y),dim=1)
96 | x_def, flow = model(x_in)
97 | def_out = reg_model([x_seg.cuda().float(), flow.cuda()])
98 | tar = y.detach().cpu().numpy()[0, 0, :, :, :]
99 | #jac_det = utils.jacobian_determinant(flow.detach().cpu().numpy()[0, :, :, :, :])
100 | line = utils.dice_val_substruct(def_out.long(), y_seg.long(), stdy_idx)
101 | line = line #+','+str(np.sum(jac_det <= 0)/np.prod(tar.shape))
102 | csv_writter(line, 'experiments/' + model_folder[:-1])
103 | #eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0))
104 |
105 | dsc_trans = utils.dice_val(def_out.long(), y_seg.long(), 46)
106 | dsc_raw = utils.dice_val(x_seg.long(), y_seg.long(), 46)
107 | print('Trans diff: {:.4f}, Raw diff: {:.4f}'.format(dsc_trans.item(),dsc_raw.item()))
108 | eval_dsc_def.update(dsc_trans.item(), x.size(0))
109 | eval_dsc_raw.update(dsc_raw.item(), x.size(0))
110 | stdy_idx += 1
111 |
112 | # flip moving and fixed images
113 | y_in = torch.cat((y, x), dim=1)
114 | y_def, flow = model(y_in)
115 | def_out = reg_model([y_seg.cuda().float(), flow.cuda()])
116 | tar = x.detach().cpu().numpy()[0, 0, :, :, :]
117 |
118 | #jac_det = utils.jacobian_determinant(flow.detach().cpu().numpy()[0, :, :, :, :])
119 | line = utils.dice_val_substruct(def_out.long(), x_seg.long(), stdy_idx)
120 | line = line #+ ',' + str(np.sum(jac_det < 0) / np.prod(tar.shape))
121 | out = def_out.detach().cpu().numpy()[0, 0, :, :, :]
122 | #print('det < 0: {}'.format(np.sum(jac_det <= 0)/np.prod(tar.shape)))
123 | csv_writter(line, 'experiments/' + model_folder[:-1])
124 | #eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0))
125 |
126 | dsc_trans = utils.dice_val(def_out.long(), x_seg.long(), 46)
127 | dsc_raw = utils.dice_val(y_seg.long(), x_seg.long(), 46)
128 | print('Trans diff: {:.4f}, Raw diff: {:.4f}'.format(dsc_trans.item(), dsc_raw.item()))
129 | eval_dsc_def.update(dsc_trans.item(), x.size(0))
130 | eval_dsc_raw.update(dsc_raw.item(), x.size(0))
131 | stdy_idx += 1
132 |
133 | print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg,
134 | eval_dsc_def.std,
135 | eval_dsc_raw.avg,
136 | eval_dsc_raw.std))
137 | print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std))
138 |
139 | def csv_writter(line, name):
140 | with open(name+'.csv', 'a') as file:
141 | file.write(line)
142 | file.write('\n')
143 |
144 | if __name__ == '__main__':
145 | '''
146 | GPU configuration
147 | '''
148 | GPU_iden = 0
149 | GPU_num = torch.cuda.device_count()
150 | print('Number of GPUs: ' + str(GPU_num))
151 | for GPU_idx in range(GPU_num):
152 | GPU_name = torch.cuda.get_device_name(GPU_idx)
153 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
154 | torch.cuda.set_device(GPU_iden)
155 | GPU_avai = torch.cuda.is_available()
156 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
157 | print('Is the GPU available? ' + str(GPU_avai))
158 | main()
--------------------------------------------------------------------------------
/ViT-V-Net/label_info.txt:
--------------------------------------------------------------------------------
1 | 0 Unknown 0 0 0 0
2 | 1 Left-Cerebral-Exterior 70 130 180 0
3 | 2 Left-Cerebral-White-Matter 245 245 245 0
4 | 3 Left-Cerebral-Cortex 205 62 78 0
5 | 4 Left-Lateral-Ventricle 120 18 134 0
6 | 5 Left-Inf-Lat-Vent 196 58 250 0
7 | 6 Left-Cerebellum-Exterior 0 148 0 0
8 | 7 Left-Cerebellum-White-Matter 220 248 164 0
9 | 8 Left-Cerebellum-Cortex 230 148 34 0
10 | 9 Left-Thalamus 0 118 14 0
11 | 10 Left-Thalamus-Proper* 0 118 14 0
12 | 11 Left-Caudate 122 186 220 0
13 | 12 Left-Putamen 236 13 176 0
14 | 13 Left-Pallidum 12 48 255 0
15 | 14 3rd-Ventricle 204 182 142 0
16 | 15 4th-Ventricle 42 204 164 0
17 | 16 Brain-Stem 119 159 176 0
18 | 17 Left-Hippocampus 220 216 20 0
19 | 18 Left-Amygdala 103 255 255 0
20 | 19 Left-Insula 80 196 98 0
21 | 20 Left-Operculum 60 58 210 0
22 | 21 Line-1 60 58 210 0
23 | 22 Line-2 60 58 210 0
24 | 23 Line-3 60 58 210 0
25 | 24 CSF 60 60 60 0
26 | 25 Left-Lesion 255 165 0 0
27 | 26 Left-Accumbens-area 255 165 0 0
28 | 27 Left-Substancia-Nigra 0 255 127 0
29 | 28 Left-VentralDC 165 42 42 0
30 | 29 Left-undetermined 135 206 235 0
31 | 30 Left-vessel 160 32 240 0
32 | 31 Left-choroid-plexus 0 200 200 0
33 | 32 Left-F3orb 100 50 100 0
34 | 33 Left-lOg 135 50 74 0
35 | 34 Left-aOg 122 135 50 0
36 | 35 Left-mOg 51 50 135 0
37 | 36 Left-pOg 74 155 60 0
38 | 37 Left-Stellate 120 62 43 0
39 | 38 Left-Porg 74 155 60 0
40 | 39 Left-Aorg 122 135 50 0
41 | 40 Right-Cerebral-Exterior 70 130 180 0
42 | 41 Right-Cerebral-White-Matter 245 245 245 0
43 | 42 Right-Cerebral-Cortex 205 62 78 0
44 | 43 Right-Lateral-Ventricle 120 18 134 0
45 | 44 Right-Inf-Lat-Vent 196 58 250 0
46 | 45 Right-Cerebellum-Exterior 0 148 0 0
47 | 46 Right-Cerebellum-White-Matter 220 248 164 0
48 | 47 Right-Cerebellum-Cortex 230 148 34 0
49 | 48 Right-Thalamus 0 118 14 0
50 | 49 Right-Thalamus-Proper* 0 118 14 0
51 | 50 Right-Caudate 122 186 220 0
52 | 51 Right-Putamen 236 13 176 0
53 | 52 Right-Pallidum 13 48 255 0
54 | 53 Right-Hippocampus 220 216 20 0
55 | 54 Right-Amygdala 103 255 255 0
56 | 55 Right-Insula 80 196 98 0
57 | 56 Right-Operculum 60 58 210 0
58 | 57 Right-Lesion 255 165 0 0
59 | 58 Right-Accumbens-area 255 165 0 0
60 | 59 Right-Substancia-Nigra 0 255 127 0
61 | 60 Right-VentralDC 165 42 42 0
62 | 61 Right-undetermined 135 206 235 0
63 | 62 Right-vessel 160 32 240 0
64 | 63 Right-choroid-plexus 0 200 221 0
65 | 64 Right-F3orb 100 50 100 0
66 | 65 Right-lOg 135 50 74 0
67 | 66 Right-aOg 122 135 50 0
68 | 67 Right-mOg 51 50 135 0
69 | 68 Right-pOg 74 155 60 0
70 | 69 Right-Stellate 120 62 43 0
71 | 70 Right-Porg 74 155 60 0
72 | 71 Right-Aorg 122 135 50 0
73 | 72 5th-Ventricle 120 190 150 0
74 | 73 Left-Interior 122 135 50 0
75 | 74 Right-Interior 122 135 50 0
76 |
77 | 77 WM-hypointensities 200 70 255 0
78 | 78 Left-WM-hypointensities 255 148 10 0
79 | 79 Right-WM-hypointensities 255 148 10 0
80 | 80 non-WM-hypointensities 164 108 226 0
81 | 81 Left-non-WM-hypointensities 164 108 226 0
82 | 82 Right-non-WM-hypointensities 164 108 226 0
83 | 83 Left-F1 255 218 185 0
84 | 84 Right-F1 255 218 185 0
85 | 85 Optic-Chiasm 234 169 30 0
86 | 192 Corpus_Callosum 250 255 50 0
87 |
88 | 86 Left_future_WMSA 200 120 255 0
89 | 87 Right_future_WMSA 200 121 255 0
90 | 88 future_WMSA 200 122 255 0
91 |
92 |
93 | 96 Left-Amygdala-Anterior 205 10 125 0
94 | 97 Right-Amygdala-Anterior 205 10 125 0
95 | 98 Dura 160 32 240 0
96 |
97 | 100 Left-wm-intensity-abnormality 124 140 178 0
98 | 101 Left-caudate-intensity-abnormality 125 140 178 0
99 | 102 Left-putamen-intensity-abnormality 126 140 178 0
100 | 103 Left-accumbens-intensity-abnormality 127 140 178 0
101 | 104 Left-pallidum-intensity-abnormality 124 141 178 0
102 | 105 Left-amygdala-intensity-abnormality 124 142 178 0
103 | 106 Left-hippocampus-intensity-abnormality 124 143 178 0
104 | 107 Left-thalamus-intensity-abnormality 124 144 178 0
105 | 108 Left-VDC-intensity-abnormality 124 140 179 0
106 | 109 Right-wm-intensity-abnormality 124 140 178 0
107 | 110 Right-caudate-intensity-abnormality 125 140 178 0
108 | 111 Right-putamen-intensity-abnormality 126 140 178 0
109 | 112 Right-accumbens-intensity-abnormality 127 140 178 0
110 | 113 Right-pallidum-intensity-abnormality 124 141 178 0
111 | 114 Right-amygdala-intensity-abnormality 124 142 178 0
112 | 115 Right-hippocampus-intensity-abnormality 124 143 178 0
113 | 116 Right-thalamus-intensity-abnormality 124 144 178 0
114 | 117 Right-VDC-intensity-abnormality 124 140 179 0
115 |
116 | 118 Epidermis 255 20 147 0
117 | 119 Conn-Tissue 205 179 139 0
118 | 120 SC-Fat-Muscle 238 238 209 0
119 | 121 Cranium 200 200 200 0
120 | 122 CSF-SA 74 255 74 0
121 | 123 Muscle 238 0 0 0
122 | 124 Ear 0 0 139 0
123 | 125 Adipose 173 255 47 0
124 | 126 Spinal-Cord 133 203 229 0
125 | 127 Soft-Tissue 26 237 57 0
126 | 128 Nerve 34 139 34 0
127 | 129 Bone 30 144 255 0
128 | 130 Air 147 19 173 0
129 | 131 Orbital-Fat 238 59 59 0
130 | 132 Tongue 221 39 200 0
131 | 133 Nasal-Structures 238 174 238 0
132 | 134 Globe 255 0 0 0
133 | 135 Teeth 72 61 139 0
134 | 136 Left-Caudate-Putamen 21 39 132 0
135 | 137 Right-Caudate-Putamen 21 39 132 0
136 | 138 Left-Claustrum 65 135 20 0
137 | 139 Right-Claustrum 65 135 20 0
138 | 140 Cornea 134 4 160 0
139 | 142 Diploe 221 226 68 0
140 | 143 Vitreous-Humor 255 255 254 0
141 | 144 Lens 52 209 226 0
142 | 145 Aqueous-Humor 239 160 223 0
143 | 146 Outer-Table 70 130 180 0
144 | 147 Inner-Table 70 130 181 0
145 | 148 Periosteum 139 121 94 0
146 | 149 Endosteum 224 224 224 0
147 | 150 R-C-S 255 0 0 0
148 | 151 Iris 205 205 0 0
149 | 152 SC-Adipose-Muscle 238 238 209 0
150 | 153 SC-Tissue 139 121 94 0
151 | 154 Orbital-Adipose 238 59 59 0
152 |
153 | 155 Left-IntCapsule-Ant 238 59 59 0
154 | 156 Right-IntCapsule-Ant 238 59 59 0
155 | 157 Left-IntCapsule-Pos 62 10 205 0
156 | 158 Right-IntCapsule-Pos 62 10 205 0
157 |
158 | # These labels are for babies/children
159 | 159 Left-Cerebral-WM-unmyelinated 0 118 14 0
160 | 160 Right-Cerebral-WM-unmyelinated 0 118 14 0
161 | 161 Left-Cerebral-WM-myelinated 220 216 21 0
162 | 162 Right-Cerebral-WM-myelinated 220 216 21 0
163 | 163 Left-Subcortical-Gray-Matter 122 186 220 0
164 | 164 Right-Subcortical-Gray-Matter 122 186 220 0
165 | 165 Skull 120 120 120 0
166 | 166 Posterior-fossa 14 48 255 0
167 | 167 Scalp 166 42 42 0
168 | 168 Hematoma 121 18 134 0
169 | 169 Left-Basal-Ganglia 236 13 127 0
170 | 176 Right-Basal-Ganglia 236 13 126 0
171 |
172 | # Label names and colors for Brainstem consituents
173 | # No. Label Name: R G B A
174 | 170 brainstem 119 159 176 0
175 | 171 DCG 119 0 176 0
176 | 172 Vermis 119 100 176 0
177 | 173 Midbrain 242 104 76 0
178 | 174 Pons 206 195 58 0
179 | 175 Medulla 119 159 176 0
180 | 177 Vermis-White-Matter 119 50 176 0
181 | 178 SCP 142 182 0 0
182 | 179 Floculus 19 100 176 0
183 |
184 | 180 Left-Cortical-Dysplasia 73 61 139 0
185 | 181 Right-Cortical-Dysplasia 73 62 139 0
186 | 182 CblumNodulus 10 100 176 0
187 |
188 | 193 Left-hippocampal_fissure 0 196 255 0
189 | 194 Left-CADG-head 255 164 164 0
190 | 195 Left-subiculum 196 196 0 0
191 | 196 Left-fimbria 0 100 255 0
192 | 197 Right-hippocampal_fissure 128 196 164 0
193 | 198 Right-CADG-head 0 126 75 0
194 | 199 Right-subiculum 128 96 64 0
195 | 200 Right-fimbria 0 50 128 0
196 | 201 alveus 255 204 153 0
197 | 202 perforant_pathway 255 128 128 0
198 | 203 parasubiculum 255 255 0 0
199 | 204 presubiculum 64 0 64 0
200 | 205 subiculum 0 0 255 0
201 | 206 CA1 255 0 0 0
202 | 207 CA2 128 128 255 0
203 | 208 CA3 0 128 0 0
204 | 209 CA4 196 160 128 0
205 | 210 GC-DG 32 200 255 0
206 | 211 HATA 128 255 128 0
207 | 212 fimbria 204 153 204 0
208 | 213 lateral_ventricle 121 17 136 0
209 | 214 molecular_layer_HP 128 0 0 0
210 | 215 hippocampal_fissure 128 32 255 0
211 | 216 entorhinal_cortex 255 204 102 0
212 | 217 molecular_layer_subiculum 128 128 128 0
213 | 218 Amygdala 104 255 255 0
214 | 219 Cerebral_White_Matter 0 226 0 0
215 | 220 Cerebral_Cortex 205 63 78 0
216 | 221 Inf_Lat_Vent 197 58 250 0
217 | 222 Perirhinal 33 150 250 0
218 | 223 Cerebral_White_Matter_Edge 226 0 0 0
219 | 224 Background 100 100 100 0
220 | 225 Ectorhinal 197 150 250 0
221 | 226 HP_tail 170 170 255 0
222 |
223 | 250 Fornix 255 0 0 0
224 | 251 CC_Posterior 0 0 64 0
225 | 252 CC_Mid_Posterior 0 0 112 0
226 | 253 CC_Central 0 0 160 0
227 | 254 CC_Mid_Anterior 0 0 208 0
228 | 255 CC_Anterior 0 0 255 0
229 |
--------------------------------------------------------------------------------
/ViT-V-Net/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | import math
7 |
8 |
9 | def gaussian(window_size, sigma):
10 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
11 | return gauss / gauss.sum()
12 |
13 |
14 | def create_window(window_size, channel):
15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
18 | return window
19 |
20 |
21 | def create_window_3D(window_size, channel):
22 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
23 | _2D_window = _1D_window.mm(_1D_window.t())
24 | _3D_window = _1D_window.mm(_2D_window.reshape(1, -1)).reshape(window_size, window_size,
25 | window_size).float().unsqueeze(0).unsqueeze(0)
26 | window = Variable(_3D_window.expand(channel, 1, window_size, window_size, window_size).contiguous())
27 | return window
28 |
29 |
30 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
33 |
34 | mu1_sq = mu1.pow(2)
35 | mu2_sq = mu2.pow(2)
36 | mu1_mu2 = mu1 * mu2
37 |
38 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
39 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
40 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
41 |
42 | C1 = 0.01 ** 2
43 | C2 = 0.03 ** 2
44 |
45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
46 |
47 | if size_average:
48 | return ssim_map.mean()
49 | else:
50 | return ssim_map.mean(1).mean(1).mean(1)
51 |
52 |
53 | def _ssim_3D(img1, img2, window, window_size, channel, size_average=True):
54 | mu1 = F.conv3d(img1, window, padding=window_size // 2, groups=channel)
55 | mu2 = F.conv3d(img2, window, padding=window_size // 2, groups=channel)
56 |
57 | mu1_sq = mu1.pow(2)
58 | mu2_sq = mu2.pow(2)
59 |
60 | mu1_mu2 = mu1 * mu2
61 |
62 | sigma1_sq = F.conv3d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
63 | sigma2_sq = F.conv3d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
64 | sigma12 = F.conv3d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
65 |
66 | C1 = 0.01 ** 2
67 | C2 = 0.03 ** 2
68 |
69 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
70 |
71 | if size_average:
72 | return ssim_map.mean()
73 | else:
74 | return ssim_map.mean(1).mean(1).mean(1)
75 |
76 |
77 | class SSIM(torch.nn.Module):
78 | def __init__(self, window_size=11, size_average=True):
79 | super(SSIM, self).__init__()
80 | self.window_size = window_size
81 | self.size_average = size_average
82 | self.channel = 1
83 | self.window = create_window(window_size, self.channel)
84 |
85 | def forward(self, img1, img2):
86 | (_, channel, _, _) = img1.size()
87 |
88 | if channel == self.channel and self.window.data.type() == img1.data.type():
89 | window = self.window
90 | else:
91 | window = create_window(self.window_size, channel)
92 |
93 | if img1.is_cuda:
94 | window = window.cuda(img1.get_device())
95 | window = window.type_as(img1)
96 |
97 | self.window = window
98 | self.channel = channel
99 |
100 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
101 |
102 |
103 | class SSIM3D(torch.nn.Module):
104 | def __init__(self, window_size=11, size_average=True):
105 | super(SSIM3D, self).__init__()
106 | self.window_size = window_size
107 | self.size_average = size_average
108 | self.channel = 1
109 | self.window = create_window_3D(window_size, self.channel)
110 |
111 | def forward(self, img1, img2):
112 | (_, channel, _, _, _) = img1.size()
113 |
114 | if channel == self.channel and self.window.data.type() == img1.data.type():
115 | window = self.window
116 | else:
117 | window = create_window_3D(self.window_size, channel)
118 |
119 | if img1.is_cuda:
120 | window = window.cuda(img1.get_device())
121 | window = window.type_as(img1)
122 |
123 | self.window = window
124 | self.channel = channel
125 |
126 | return 1-_ssim_3D(img1, img2, window, self.window_size, channel, self.size_average)
127 |
128 |
129 | def ssim(img1, img2, window_size=11, size_average=True):
130 | (_, channel, _, _) = img1.size()
131 | window = create_window(window_size, channel)
132 |
133 | if img1.is_cuda:
134 | window = window.cuda(img1.get_device())
135 | window = window.type_as(img1)
136 |
137 | return _ssim(img1, img2, window, window_size, channel, size_average)
138 |
139 |
140 | def ssim3D(img1, img2, window_size=11, size_average=True):
141 | (_, channel, _, _, _) = img1.size()
142 | window = create_window_3D(window_size, channel)
143 |
144 | if img1.is_cuda:
145 | window = window.cuda(img1.get_device())
146 | window = window.type_as(img1)
147 |
148 | return _ssim_3D(img1, img2, window, window_size, channel, size_average)
149 |
150 |
151 | class Grad(torch.nn.Module):
152 | """
153 | N-D gradient loss.
154 | """
155 |
156 | def __init__(self, penalty='l1', loss_mult=None):
157 | super(Grad, self).__init__()
158 | self.penalty = penalty
159 | self.loss_mult = loss_mult
160 |
161 | def forward(self, y_pred, y_true):
162 | dy = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :])
163 | dx = torch.abs(y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1])
164 | #dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
165 |
166 | if self.penalty == 'l2':
167 | dy = dy * dy
168 | dx = dx * dx
169 | #dz = dz * dz
170 |
171 | d = torch.mean(dx) + torch.mean(dy)# + torch.mean(dz)
172 | grad = d / 2.0
173 |
174 | if self.loss_mult is not None:
175 | grad *= self.loss_mult
176 | return grad
177 |
178 | class Grad3d(torch.nn.Module):
179 | """
180 | N-D gradient loss.
181 | """
182 |
183 | def __init__(self, penalty='l1', loss_mult=None):
184 | super(Grad3d, self).__init__()
185 | self.penalty = penalty
186 | self.loss_mult = loss_mult
187 |
188 | def forward(self, y_pred, y_true):
189 | dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
190 | dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
191 | dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
192 |
193 | if self.penalty == 'l2':
194 | dy = dy * dy
195 | dx = dx * dx
196 | dz = dz * dz
197 |
198 | d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz)
199 | grad = d / 3.0
200 |
201 | if self.loss_mult is not None:
202 | grad *= self.loss_mult
203 | return grad
204 |
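if __name__ == '__main__':
    # Minimal sketch with random values: the smoothness regularizer above expects
    # a dense displacement field shaped [batch, 3, H, W, D]; y_true is unused.
    flow = torch.rand(1, 3, 8, 8, 8)
    print(Grad3d(penalty='l2')(flow, None).item())
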
205 | class Grad3DiTV(torch.nn.Module):
206 | """
207 | N-D gradient loss.
208 | """
209 |
210 | def __init__(self):
211 | super(Grad3DiTV, self).__init__()
212 | a = 1
213 |
214 | def forward(self, y_pred, y_true):
215 | dy = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, :-1, 1:, 1:])
216 | dx = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, 1:, :-1, 1:])
217 | dz = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, 1:, 1:, :-1])
218 | dy = dy * dy
219 | dx = dx * dx
220 | dz = dz * dz
221 | d = torch.mean(torch.sqrt(dx+dy+dz+1e-6))
222 | grad = d / 3.0
223 | return grad
224 |
225 | class NCC(torch.nn.Module):
226 | """
227 | Local (over window) normalized cross correlation loss.
228 | """
229 |
230 | def __init__(self, win=None):
231 | super(NCC, self).__init__()
232 | self.win = win
233 |
234 | def forward(self, y_pred, y_true):
235 |
236 | I = y_true
237 | J = y_pred
238 |
239 | # get dimension of volume
240 | # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
241 | ndims = len(list(I.size())) - 2
242 | assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
243 |
244 | # set window size
245 | win = [9] * ndims if self.win is None else self.win
246 |
247 | # compute filters
248 | sum_filt = torch.ones([1, 1, *win]).to("cuda")
249 |
250 | pad_no = math.floor(win[0]/2)
251 |
252 | if ndims == 1:
253 | stride = (1)
254 | padding = (pad_no)
255 | elif ndims == 2:
256 | stride = (1,1)
257 | padding = (pad_no, pad_no)
258 | else:
259 | stride = (1,1,1)
260 | padding = (pad_no, pad_no, pad_no)
261 |
262 | # get convolution function
263 | conv_fn = getattr(F, 'conv%dd' % ndims)
264 |
265 | # compute CC squares
266 | I2 = I * I
267 | J2 = J * J
268 | IJ = I * J
269 |
270 | I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
271 | J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
272 | I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
273 | J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
274 | IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
275 |
276 | win_size = np.prod(win)
277 | u_I = I_sum / win_size
278 | u_J = J_sum / win_size
279 |
280 | cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
281 | I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
282 | J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
283 |
284 | cc = cross * cross / (I_var * J_var + 1e-5)
285 |
286 | return -torch.mean(cc)
287 |
288 | class MutualInformation(torch.nn.Module):
289 | """
290 | Mutual Information
291 | """
292 | def __init__(self, sigma_ratio=1, minval=0., maxval=1., num_bin=32):
293 | super(MutualInformation, self).__init__()
294 |
295 | """Create bin centers"""
296 | bin_centers = np.linspace(minval, maxval, num=num_bin)
297 | vol_bin_centers = Variable(torch.linspace(minval, maxval, num_bin), requires_grad=False).cuda()
298 | num_bins = len(bin_centers)
299 |
300 | """Sigma for Gaussian approx."""
301 | sigma = np.mean(np.diff(bin_centers)) * sigma_ratio
302 | print(sigma)
303 |
304 | self.preterm = 1 / (2 * sigma**2)
305 | self.bin_centers = bin_centers
306 | self.max_clip = maxval
307 | self.num_bins = num_bins
308 | self.vol_bin_centers = vol_bin_centers
309 |
310 | def mi(self, y_true, y_pred):
311 | y_pred = torch.clamp(y_pred, 0., self.max_clip)
312 | y_true = torch.clamp(y_true, 0, self.max_clip)
313 |
314 | y_true = y_true.view(y_true.shape[0], -1)
315 | y_true = torch.unsqueeze(y_true, 2)
316 | y_pred = y_pred.view(y_pred.shape[0], -1)
317 | y_pred = torch.unsqueeze(y_pred, 2)
318 |
319 | nb_voxels = y_pred.shape[1] # total num of voxels
320 |
321 | """Reshape bin centers"""
322 | o = [1, 1, np.prod(self.vol_bin_centers.shape)]
323 | vbc = torch.reshape(self.vol_bin_centers, o).cuda()
324 |
325 | """compute image terms by approx. Gaussian dist."""
326 | I_a = torch.exp(- self.preterm * torch.square(y_true - vbc))
327 | I_a = I_a / torch.sum(I_a, dim=-1, keepdim=True)
328 |
329 | I_b = torch.exp(- self.preterm * torch.square(y_pred - vbc))
330 | I_b = I_b / torch.sum(I_b, dim=-1, keepdim=True)
331 |
332 | # compute probabilities
333 | pab = torch.bmm(I_a.permute(0, 2, 1), I_b)
334 | pab = pab/nb_voxels
335 | pa = torch.mean(I_a, dim=1, keepdim=True)
336 | pb = torch.mean(I_b, dim=1, keepdim=True)
337 |
338 | papb = torch.bmm(pa.permute(0, 2, 1), pb) + 1e-6
339 | mi = torch.sum(torch.sum(pab * torch.log(pab / papb + 1e-6), dim=1), dim=1)
340 | return mi.mean() #average across batch
341 |
342 | def forward(self, y_true, y_pred):
343 | return -self.mi(y_true, y_pred)
344 |
345 | class localMutualInformation(torch.nn.Module):
346 | """
347 | Local Mutual Information for non-overlapping patches
348 | """
349 | def __init__(self, sigma_ratio=1, minval=0., maxval=1., num_bin=32, patch_size=5):
350 | super(localMutualInformation, self).__init__()
351 |
352 | """Create bin centers"""
353 | bin_centers = np.linspace(minval, maxval, num=num_bin)
354 | vol_bin_centers = Variable(torch.linspace(minval, maxval, num_bin), requires_grad=False).cuda()
355 | num_bins = len(bin_centers)
356 |
357 | """Sigma for Gaussian approx."""
358 | sigma = np.mean(np.diff(bin_centers)) * sigma_ratio
359 |
360 | self.preterm = 1 / (2 * sigma**2)
361 | self.bin_centers = bin_centers
362 | self.max_clip = maxval
363 | self.num_bins = num_bins
364 | self.vol_bin_centers = vol_bin_centers
365 | self.patch_size = patch_size
366 |
367 | def local_mi(self, y_true, y_pred):
368 | y_pred = torch.clamp(y_pred, 0., self.max_clip)
369 | y_true = torch.clamp(y_true, 0, self.max_clip)
370 |
371 | """Reshape bin centers"""
372 | o = [1, 1, np.prod(self.vol_bin_centers.shape)]
373 | vbc = torch.reshape(self.vol_bin_centers, o).cuda()
374 |
375 | """Making image paddings"""
376 | if len(list(y_pred.size())[2:]) == 3:
377 | ndim = 3
378 | x, y, z = list(y_pred.size())[2:]
379 | # compute padding sizes
380 | x_r = -x % self.patch_size
381 | y_r = -y % self.patch_size
382 | z_r = -z % self.patch_size
383 | padding = (z_r // 2, z_r - z_r // 2, y_r // 2, y_r - y_r // 2, x_r // 2, x_r - x_r // 2, 0, 0, 0, 0)
384 | elif len(list(y_pred.size())[2:]) == 2:
385 | ndim = 2
386 | x, y = list(y_pred.size())[2:]
387 | # compute padding sizes
388 | x_r = -x % self.patch_size
389 | y_r = -y % self.patch_size
390 | padding = (y_r // 2, y_r - y_r // 2, x_r // 2, x_r - x_r // 2, 0, 0, 0, 0)
391 | else:
392 | raise Exception('Supports 2D and 3D but not {}'.format(list(y_pred.size())))
393 | y_true = F.pad(y_true, padding, "constant", 0)
394 | y_pred = F.pad(y_pred, padding, "constant", 0)
395 |
396 | """Reshaping images into non-overlapping patches"""
397 | if ndim == 3:
398 | y_true_patch = torch.reshape(y_true, (y_true.shape[0], y_true.shape[1],
399 | (x + x_r) // self.patch_size, self.patch_size,
400 | (y + y_r) // self.patch_size, self.patch_size,
401 | (z + z_r) // self.patch_size, self.patch_size))
402 | y_true_patch = y_true_patch.permute(0, 1, 2, 4, 6, 3, 5, 7)
403 | y_true_patch = torch.reshape(y_true_patch, (-1, self.patch_size ** 3, 1))
404 |
405 | y_pred_patch = torch.reshape(y_pred, (y_pred.shape[0], y_pred.shape[1],
406 | (x + x_r) // self.patch_size, self.patch_size,
407 | (y + y_r) // self.patch_size, self.patch_size,
408 | (z + z_r) // self.patch_size, self.patch_size))
409 | y_pred_patch = y_pred_patch.permute(0, 1, 2, 4, 6, 3, 5, 7)
410 | y_pred_patch = torch.reshape(y_pred_patch, (-1, self.patch_size ** 3, 1))
411 | else:
412 | y_true_patch = torch.reshape(y_true, (y_true.shape[0], y_true.shape[1],
413 | (x + x_r) // self.patch_size, self.patch_size,
414 | (y + y_r) // self.patch_size, self.patch_size))
415 | y_true_patch = y_true_patch.permute(0, 1, 2, 4, 3, 5)
416 | y_true_patch = torch.reshape(y_true_patch, (-1, self.patch_size ** 2, 1))
417 |
418 | y_pred_patch = torch.reshape(y_pred, (y_pred.shape[0], y_pred.shape[1],
419 | (x + x_r) // self.patch_size, self.patch_size,
420 | (y + y_r) // self.patch_size, self.patch_size))
421 | y_pred_patch = y_pred_patch.permute(0, 1, 2, 4, 3, 5)
422 | y_pred_patch = torch.reshape(y_pred_patch, (-1, self.patch_size ** 2, 1))
423 |
424 | """Compute MI"""
425 | I_a_patch = torch.exp(- self.preterm * torch.square(y_true_patch - vbc))
426 | I_a_patch = I_a_patch / torch.sum(I_a_patch, dim=-1, keepdim=True)
427 |
428 | I_b_patch = torch.exp(- self.preterm * torch.square(y_pred_patch - vbc))
429 | I_b_patch = I_b_patch / torch.sum(I_b_patch, dim=-1, keepdim=True)
430 |
431 | pab = torch.bmm(I_a_patch.permute(0, 2, 1), I_b_patch)
432 | pab = pab / self.patch_size ** ndim
433 | pa = torch.mean(I_a_patch, dim=1, keepdim=True)
434 | pb = torch.mean(I_b_patch, dim=1, keepdim=True)
435 |
436 | papb = torch.bmm(pa.permute(0, 2, 1), pb) + 1e-6
437 | mi = torch.sum(torch.sum(pab * torch.log(pab / papb + 1e-6), dim=1), dim=1)
438 | return mi.mean()
439 |
440 | def forward(self,y_true, y_pred):
441 | return -self.local_mi(y_true, y_pred)
442 |
--------------------------------------------------------------------------------
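The loss classes in losses.py are combined during training as a weighted sum of an image-similarity term and a smoothness regularizer on the predicted flow. Below is a minimal sketch of that pattern (not part of the repository; toy 32^3 volumes and a CUDA device are assumed, since several of these losses allocate tensors on the GPU):

```python
import torch
import losses  # the module listed above

sim_loss = losses.NCC(win=[9, 9, 9])     # local normalized cross-correlation
reg_loss = losses.Grad3d(penalty='l2')   # diffusion regularizer on the flow
reg_weight = 0.02                        # weighting used in train.py

fixed  = torch.rand(1, 1, 32, 32, 32, device='cuda')
warped = torch.rand(1, 1, 32, 32, 32, device='cuda', requires_grad=True)
flow   = torch.rand(1, 3, 32, 32, 32, device='cuda', requires_grad=True)

# NCC.forward takes (y_pred, y_true); Grad3d only penalizes its first argument.
loss = sim_loss(warped, fixed) + reg_weight * reg_loss(flow, fixed)
loss.backward()
```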
/ViT-V-Net/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as nnf
12 | from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm
13 | from torch.nn.modules.utils import _pair, _triple
14 | import configs as configs
15 | from torch.distributions.normal import Normal
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
21 | ATTENTION_K = "MultiHeadDotProductAttention_1/key"
22 | ATTENTION_V = "MultiHeadDotProductAttention_1/value"
23 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
24 | FC_0 = "MlpBlock_3/Dense_0"
25 | FC_1 = "MlpBlock_3/Dense_1"
26 | ATTENTION_NORM = "LayerNorm_0"
27 | MLP_NORM = "LayerNorm_2"
28 |
29 |
30 | def np2th(weights, conv=False):
31 | """Possibly convert HWIO to OIHW."""
32 | if conv:
33 | weights = weights.transpose([3, 2, 0, 1])
34 | return torch.from_numpy(weights)
35 |
36 |
37 | def swish(x):
38 | return x * torch.sigmoid(x)
39 |
40 |
41 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, config, vis):
46 | super(Attention, self).__init__()
47 | self.vis = vis
48 | self.num_attention_heads = config.transformer["num_heads"]
49 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
50 | self.all_head_size = self.num_attention_heads * self.attention_head_size
51 |
52 | self.query = Linear(config.hidden_size, self.all_head_size)
53 | self.key = Linear(config.hidden_size, self.all_head_size)
54 | self.value = Linear(config.hidden_size, self.all_head_size)
55 |
56 | self.out = Linear(config.hidden_size, config.hidden_size)
57 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
58 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
59 |
60 | self.softmax = Softmax(dim=-1)
61 |
62 | def transpose_for_scores(self, x):
63 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
64 | x = x.view(*new_x_shape)
65 | return x.permute(0, 2, 1, 3)
66 |
67 | def forward(self, hidden_states):
68 | mixed_query_layer = self.query(hidden_states)
69 | mixed_key_layer = self.key(hidden_states)
70 | mixed_value_layer = self.value(hidden_states)
71 |
72 | query_layer = self.transpose_for_scores(mixed_query_layer)
73 | key_layer = self.transpose_for_scores(mixed_key_layer)
74 | value_layer = self.transpose_for_scores(mixed_value_layer)
75 |
76 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
77 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
78 | attention_probs = self.softmax(attention_scores)
79 | weights = attention_probs if self.vis else None
80 | attention_probs = self.attn_dropout(attention_probs)
81 |
82 | context_layer = torch.matmul(attention_probs, value_layer)
83 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
84 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
85 | context_layer = context_layer.view(*new_context_layer_shape)
86 | attention_output = self.out(context_layer)
87 | attention_output = self.proj_dropout(attention_output)
88 | return attention_output, weights
89 |
90 |
91 | class Mlp(nn.Module):
92 | def __init__(self, config):
93 | super(Mlp, self).__init__()
94 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
95 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
96 | self.act_fn = ACT2FN["gelu"]
97 | self.dropout = Dropout(config.transformer["dropout_rate"])
98 |
99 | self._init_weights()
100 |
101 | def _init_weights(self):
102 | nn.init.xavier_uniform_(self.fc1.weight)
103 | nn.init.xavier_uniform_(self.fc2.weight)
104 | nn.init.normal_(self.fc1.bias, std=1e-6)
105 | nn.init.normal_(self.fc2.bias, std=1e-6)
106 |
107 | def forward(self, x):
108 | x = self.fc1(x)
109 | x = self.act_fn(x)
110 | x = self.dropout(x)
111 | x = self.fc2(x)
112 | x = self.dropout(x)
113 | return x
114 |
115 |
116 | class Embeddings(nn.Module):
117 |     """Construct the embeddings from patch and position embeddings.
118 | """
119 | def __init__(self, config, img_size):
120 | super(Embeddings, self).__init__()
121 | self.config = config
122 | down_factor = config.down_factor
123 | patch_size = _triple(config.patches["size"])
124 | n_patches = int((img_size[0]/2**down_factor// patch_size[0]) * (img_size[1]/2**down_factor// patch_size[1]) * (img_size[2]/2**down_factor// patch_size[2]))
125 | self.hybrid_model = CNNEncoder(config, n_channels=2)
126 | in_channels = config['encoder_channels'][-1]
127 | self.patch_embeddings = Conv3d(in_channels=in_channels,
128 | out_channels=config.hidden_size,
129 | kernel_size=patch_size,
130 | stride=patch_size)
131 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
132 |
133 | self.dropout = Dropout(config.transformer["dropout_rate"])
134 |
135 | def forward(self, x):
136 | x, features = self.hybrid_model(x)
137 |         x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/3), n_patches^(1/3), n_patches^(1/3))
138 | x = x.flatten(2)
139 | x = x.transpose(-1, -2) # (B, n_patches, hidden)
140 | embeddings = x + self.position_embeddings
141 | embeddings = self.dropout(embeddings)
142 | return embeddings, features
143 |
144 |
145 | class Block(nn.Module):
146 | def __init__(self, config, vis):
147 | super(Block, self).__init__()
148 | self.hidden_size = config.hidden_size
149 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
150 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
151 | self.ffn = Mlp(config)
152 | self.attn = Attention(config, vis)
153 |
154 | def forward(self, x):
155 | h = x
156 |
157 | x = self.attention_norm(x)
158 | x, weights = self.attn(x)
159 | x = x + h
160 |
161 | h = x
162 | x = self.ffn_norm(x)
163 | x = self.ffn(x)
164 | x = x + h
165 | return x, weights
166 |
167 | class Encoder(nn.Module):
168 | def __init__(self, config, vis):
169 | super(Encoder, self).__init__()
170 | self.vis = vis
171 | self.layer = nn.ModuleList()
172 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
173 | for _ in range(config.transformer["num_layers"]):
174 | layer = Block(config, vis)
175 | self.layer.append(copy.deepcopy(layer))
176 |
177 | def forward(self, hidden_states):
178 | attn_weights = []
179 | for layer_block in self.layer:
180 | hidden_states, weights = layer_block(hidden_states)
181 | if self.vis:
182 | attn_weights.append(weights)
183 | encoded = self.encoder_norm(hidden_states)
184 | return encoded, attn_weights
185 |
186 |
187 | class Transformer(nn.Module):
188 | def __init__(self, config, img_size, vis):
189 | super(Transformer, self).__init__()
190 | self.embeddings = Embeddings(config, img_size=img_size)
191 | self.encoder = Encoder(config, vis)
192 |
193 | def forward(self, input_ids):
194 | embedding_output, features = self.embeddings(input_ids)
195 | encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
196 | return encoded, attn_weights, features
197 |
198 |
199 | class Conv3dReLU(nn.Sequential):
200 | def __init__(
201 | self,
202 | in_channels,
203 | out_channels,
204 | kernel_size,
205 | padding=0,
206 | stride=1,
207 | use_batchnorm=True,
208 | ):
209 | conv = nn.Conv3d(
210 | in_channels,
211 | out_channels,
212 | kernel_size,
213 | stride=stride,
214 | padding=padding,
215 | bias=not (use_batchnorm),
216 | )
217 | relu = nn.ReLU(inplace=True)
218 |
219 |         bn = nn.BatchNorm3d(out_channels) if use_batchnorm else nn.Identity()
220 |
221 | super(Conv3dReLU, self).__init__(conv, bn, relu)
222 |
223 |
224 | class DecoderBlock(nn.Module):
225 | def __init__(
226 | self,
227 | in_channels,
228 | out_channels,
229 | skip_channels=0,
230 | use_batchnorm=True,
231 | ):
232 | super().__init__()
233 | self.conv1 = Conv3dReLU(
234 | in_channels + skip_channels,
235 | out_channels,
236 | kernel_size=3,
237 | padding=1,
238 | use_batchnorm=use_batchnorm,
239 | )
240 | self.conv2 = Conv3dReLU(
241 | out_channels,
242 | out_channels,
243 | kernel_size=3,
244 | padding=1,
245 | use_batchnorm=use_batchnorm,
246 | )
247 | self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
248 |
249 | def forward(self, x, skip=None):
250 | x = self.up(x)
251 | if skip is not None:
252 | x = torch.cat([x, skip], dim=1)
253 | x = self.conv1(x)
254 | x = self.conv2(x)
255 | return x
256 |
257 | class DecoderCup(nn.Module):
258 | def __init__(self, config, img_size):
259 | super().__init__()
260 | self.config = config
261 | self.down_factor = config.down_factor
262 | head_channels = config.conv_first_channel
263 | self.img_size = img_size
264 | self.conv_more = Conv3dReLU(
265 | config.hidden_size,
266 | head_channels,
267 | kernel_size=3,
268 | padding=1,
269 | use_batchnorm=True,
270 | )
271 | decoder_channels = config.decoder_channels
272 | in_channels = [head_channels] + list(decoder_channels[:-1])
273 | out_channels = decoder_channels
274 | self.patch_size = _triple(config.patches["size"])
275 | skip_channels = self.config.skip_channels
276 | blocks = [
277 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
278 | ]
279 | self.blocks = nn.ModuleList(blocks)
280 |
281 | def forward(self, hidden_states, features=None):
282 |         B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, l, h, w)
283 | l, h, w = (self.img_size[0]//2**self.down_factor//self.patch_size[0]), (self.img_size[1]//2**self.down_factor//self.patch_size[1]), (self.img_size[2]//2**self.down_factor//self.patch_size[2])
284 | x = hidden_states.permute(0, 2, 1)
285 | x = x.contiguous().view(B, hidden, l, h, w)
286 | x = self.conv_more(x)
287 | for i, decoder_block in enumerate(self.blocks):
288 | if features is not None:
289 | skip = features[i] if (i < self.config.n_skip) else None
290 | #print(skip.shape)
291 | else:
292 | skip = None
293 | x = decoder_block(x, skip=skip)
294 | return x
295 |
296 | class SpatialTransformer(nn.Module):
297 | """
298 | N-D Spatial Transformer
299 |
300 | Obtained from https://github.com/voxelmorph/voxelmorph
301 | """
302 |
303 | def __init__(self, size, mode='bilinear'):
304 | super().__init__()
305 |
306 | self.mode = mode
307 |
308 | # create sampling grid
309 | vectors = [torch.arange(0, s) for s in size]
310 | grids = torch.meshgrid(vectors)
311 | grid = torch.stack(grids)
312 | grid = torch.unsqueeze(grid, 0)
313 | grid = grid.type(torch.FloatTensor)
314 |
315 | # registering the grid as a buffer cleanly moves it to the GPU, but it also
316 | # adds it to the state dict. this is annoying since everything in the state dict
317 | # is included when saving weights to disk, so the model files are way bigger
318 | # than they need to be. so far, there does not appear to be an elegant solution.
319 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
320 | self.register_buffer('grid', grid)
321 |
322 | def forward(self, src, flow):
323 | # new locations
324 | new_locs = self.grid + flow
325 | shape = flow.shape[2:]
326 |
327 | # need to normalize grid values to [-1, 1] for resampler
328 | for i in range(len(shape)):
329 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
330 |
331 | # move channels dim to last position
332 | # also not sure why, but the channels need to be reversed
333 | if len(shape) == 2:
334 | new_locs = new_locs.permute(0, 2, 3, 1)
335 | new_locs = new_locs[..., [1, 0]]
336 | elif len(shape) == 3:
337 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
338 | new_locs = new_locs[..., [2, 1, 0]]
339 |
340 | return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
341 |
342 | class DoubleConv(nn.Module):
343 |     """(convolution => ReLU) * 2"""
344 |
345 | def __init__(self, in_channels, out_channels, mid_channels=None):
346 | super().__init__()
347 | if not mid_channels:
348 | mid_channels = out_channels
349 | self.double_conv = nn.Sequential(
350 | nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
351 | nn.ReLU(inplace=True),
352 | nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
353 | nn.ReLU(inplace=True)
354 | )
355 |
356 | def forward(self, x):
357 | return self.double_conv(x)
358 |
359 |
360 | class Down(nn.Module):
361 | """Downscaling with maxpool then double conv"""
362 |
363 | def __init__(self, in_channels, out_channels):
364 | super().__init__()
365 | self.maxpool_conv = nn.Sequential(
366 | nn.MaxPool3d(2),
367 | DoubleConv(in_channels, out_channels)
368 | )
369 |
370 | def forward(self, x):
371 | return self.maxpool_conv(x)
372 |
373 | class CNNEncoder(nn.Module):
374 | def __init__(self, config, n_channels=2):
375 | super(CNNEncoder, self).__init__()
376 | self.n_channels = n_channels
377 | decoder_channels = config.decoder_channels
378 | encoder_channels = config.encoder_channels
379 | self.down_num = config.down_num
380 | self.inc = DoubleConv(n_channels, encoder_channels[0])
381 | self.down1 = Down(encoder_channels[0], encoder_channels[1])
382 | self.down2 = Down(encoder_channels[1], encoder_channels[2])
383 | self.width = encoder_channels[-1]
384 | def forward(self, x):
385 | features = []
386 | x1 = self.inc(x)
387 | features.append(x1)
388 | x2 = self.down1(x1)
389 | features.append(x2)
390 | feats = self.down2(x2)
391 | features.append(feats)
392 | feats_down = feats
393 | for i in range(self.down_num):
394 | feats_down = nn.MaxPool3d(2)(feats_down)
395 | features.append(feats_down)
396 | return feats, features[::-1]
397 |
398 | class RegistrationHead(nn.Sequential):
399 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
400 | conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
401 | conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape))
402 | conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape))
403 | super().__init__(conv3d)
404 |
405 | class ViTVNet(nn.Module):
406 | def __init__(self, config, img_size=(64, 256, 256), int_steps=7, vis=False):
407 | super(ViTVNet, self).__init__()
408 | self.transformer = Transformer(config, img_size, vis)
409 | self.decoder = DecoderCup(config, img_size)
410 | self.reg_head = RegistrationHead(
411 | in_channels=config.decoder_channels[-1],
412 | out_channels=config['n_dims'],
413 | kernel_size=3,
414 | )
415 | self.spatial_trans = SpatialTransformer(img_size)
416 | self.config = config
417 | #self.integrate = VecInt(img_size, int_steps)
418 | def forward(self, x):
419 |
420 |         source = x[:, 0:1, ...]  # moving image: first channel of the two-channel input
421 |
422 | x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
423 | x = self.decoder(x, features)
424 | flow = self.reg_head(x)
425 | #flow = self.integrate(flow)
426 | out = self.spatial_trans(source, flow)
427 | return out, flow
428 |
429 | class VecInt(nn.Module):
430 | """
431 | Integrates a vector field via scaling and squaring.
432 |
433 | Obtained from https://github.com/voxelmorph/voxelmorph
434 | """
435 |
436 | def __init__(self, inshape, nsteps):
437 | super().__init__()
438 |
439 | assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
440 | self.nsteps = nsteps
441 | self.scale = 1.0 / (2 ** self.nsteps)
442 | self.transformer = SpatialTransformer(inshape)
443 |
444 | def forward(self, vec):
445 | vec = vec * self.scale
446 | for _ in range(self.nsteps):
447 | vec = vec + self.transformer(vec, vec)
448 | return vec
449 |
450 | CONFIGS = {
451 | 'ViT-V-Net': configs.get_3DReg_config(),
452 | }
453 |
--------------------------------------------------------------------------------
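A minimal usage sketch for the network defined above (not part of the repository): build ViTVNet from the bundled config, concatenate moving and fixed images along the channel dimension, and run one forward pass. The input size follows train.py; a GPU with sufficient memory is assumed.

```python
import torch
import models

config = models.CONFIGS['ViT-V-Net']
img_size = (160, 192, 224)                   # (D, H, W) used in train.py
model = models.ViTVNet(config, img_size=img_size).cuda()

moving = torch.rand(1, 1, *img_size).cuda()  # image to be deformed
fixed  = torch.rand(1, 1, *img_size).cuda()  # registration target
x_in = torch.cat((moving, fixed), dim=1)     # 2-channel input, as in train.py

warped, flow = model(x_in)                   # warped moving image and dense flow
print(warped.shape, flow.shape)              # (1, 1, *img_size) and (1, n_dims, *img_size)
```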
/ViT-V-Net/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | import os, utils, glob, losses
3 | import sys
4 | from torch.utils.data import DataLoader
5 | from data import datasets, trans
6 | import numpy as np
7 | import torch, models
8 | from torchvision import transforms
9 | from torch import optim
10 | import torch.nn as nn
11 | import matplotlib.pyplot as plt
12 | from models import CONFIGS as CONFIGS_ViT_seg
13 | from natsort import natsorted
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=1):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / self.count
31 |
32 | def MSE_torch(x, y):
33 | return torch.mean((x - y) ** 2)
34 |
35 | def main():
36 | batch_size = 2
37 | train_dir = 'D:/DATA/JHUBrain/Train/'
38 | val_dir = 'D:/DATA/JHUBrain/Val/'
39 | save_dir = 'ViTVNet_reg0.02_mse_diff/'
40 | lr = 0.0001
41 | epoch_start = 0
42 | max_epoch = 500
43 | cont_training = False
44 | config_vit = CONFIGS_ViT_seg['ViT-V-Net']
45 | reg_model = utils.register_model((160, 192, 224), 'nearest')
46 | reg_model.cuda()
47 | model = models.ViTVNet(config_vit, img_size=(160, 192, 224))
48 | if cont_training:
49 | epoch_start = 335
50 | model_dir = 'experiments/'+save_dir
51 |         updated_lr = round(lr * np.power(1 - epoch_start / max_epoch, 0.9), 8)
52 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[0])['state_dict']
53 | model.load_state_dict(best_model)
54 | else:
55 | updated_lr = lr
56 | model.cuda()
57 | train_composed = transforms.Compose([trans.RandomFlip(0),
58 | trans.NumpyType((np.float32, np.float32)),
59 | ])
60 |
61 | val_composed = transforms.Compose([trans.Seg_norm(), #rearrange segmentation label to 1 to 46
62 | trans.NumpyType((np.float32, np.int16)),
63 | ])
64 |
65 | train_set = datasets.JHUBrainDataset(glob.glob(train_dir + '*.pkl'), transforms=train_composed)
66 | val_set = datasets.JHUBrainInferDataset(glob.glob(val_dir + '*.pkl'), transforms=val_composed)
67 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
68 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
69 |
70 | optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True)
71 | criterion = nn.MSELoss()
72 | criterions = [criterion]
73 | weights = [1]
74 | # prepare deformation loss
75 | criterions += [losses.Grad3d(penalty='l2')]
76 | weights += [0.02]
77 |     best_mse = 0  # despite the name, this tracks the best validation Dice
78 | writer = SummaryWriter(log_dir='ViTVNet_log')
79 | for epoch in range(epoch_start, max_epoch):
80 | print('Training Starts')
81 | '''
82 | Training
83 | '''
84 | loss_all = AverageMeter()
85 | idx = 0
86 | for data in train_loader:
87 | idx += 1
88 | model.train()
89 | adjust_learning_rate(optimizer, epoch, max_epoch, lr)
90 | data = [t.cuda() for t in data]
91 | x = data[0]
92 | y = data[1]
93 | x_in = torch.cat((x,y), dim=1)
94 | output = model(x_in)
95 | loss = 0
96 | loss_vals = []
97 | for n, loss_function in enumerate(criterions):
98 | curr_loss = loss_function(output[n], y) * weights[n]
99 | loss_vals.append(curr_loss)
100 | loss += curr_loss
101 | loss_all.update(loss.item(), y.numel())
102 | # compute gradient and do SGD step
103 | optimizer.zero_grad()
104 | loss.backward()
105 | optimizer.step()
106 |
107 | del x_in
108 | del output
109 | # flip fixed and moving images
110 | loss = 0
111 | x_in = torch.cat((y, x), dim=1)
112 | output = model(x_in)
113 | for n, loss_function in enumerate(criterions):
114 | curr_loss = loss_function(output[n], x) * weights[n]
115 | loss_vals[n] += curr_loss
116 | loss += curr_loss
117 | loss_all.update(loss.item(), y.numel())
118 | # compute gradient and do SGD step
119 | optimizer.zero_grad()
120 | loss.backward()
121 | optimizer.step()
122 |
123 | print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader), loss.item(), loss_vals[0].item()/2, loss_vals[1].item()/2))
124 |
125 | writer.add_scalar('Loss/train', loss_all.avg, epoch)
126 | print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
127 | '''
128 | Validation
129 | '''
130 | eval_dsc = AverageMeter()
131 | with torch.no_grad():
132 | for data in val_loader:
133 | model.eval()
134 | data = [t.cuda() for t in data]
135 | x = data[0]
136 | y = data[1]
137 | x_seg = data[2]
138 | y_seg = data[3]
139 | # x = x.squeeze(0).permute(1, 0, 2, 3)
140 | # y = y.squeeze(0).permute(1, 0, 2, 3)
141 | x_in = torch.cat((x, y), dim=1)
142 | output = model(x_in)
143 | def_out = reg_model([x_seg.cuda().float(), output[1].cuda()])
144 | dsc = utils.dice_val(def_out.long(), y_seg.long(), 46)
145 | eval_dsc.update(dsc.item(), x.size(0))
146 |             print('Validation Dice (running avg): {:.4f}'.format(eval_dsc.avg))
147 | best_mse = max(eval_dsc.avg, best_mse)
148 | save_checkpoint({
149 | 'epoch': epoch + 1,
150 | 'state_dict': model.state_dict(),
151 | 'best_mse': best_mse,
152 | 'optimizer': optimizer.state_dict(),
153 | }, save_dir='experiments/'+save_dir, filename='dsc{:.3f}.pth.tar'.format(eval_dsc.avg))
154 |         writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
155 | plt.switch_backend('agg')
156 | pred_fig = comput_fig(def_out)
157 | x_fig = comput_fig(x_seg)
158 | tar_fig = comput_fig(y_seg)
159 | writer.add_figure('input', x_fig, epoch)
160 | plt.close(x_fig)
161 | writer.add_figure('ground truth', tar_fig, epoch)
162 | plt.close(tar_fig)
163 | writer.add_figure('prediction', pred_fig, epoch)
164 | plt.close(pred_fig)
165 | loss_all.reset()
166 | writer.close()
167 |
168 | def comput_fig(img):
169 | img = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
170 | fig = plt.figure(figsize=(12,12), dpi=180)
171 | for i in range(img.shape[0]):
172 | plt.subplot(4, 4, i + 1)
173 | plt.axis('off')
174 | plt.imshow(img[i, :, :], cmap='gray')
175 | fig.subplots_adjust(wspace=0, hspace=0)
176 | return fig
177 |
178 | def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
179 | for param_group in optimizer.param_groups:
180 |         param_group['lr'] = round(INIT_LR * np.power(1 - epoch / MAX_EPOCHES, power), 8)
181 |
182 |
183 | def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
184 | torch.save(state, save_dir+filename)
185 | model_lists = natsorted(glob.glob(save_dir + '*'))
186 | while len(model_lists) > max_model_num:
187 | os.remove(model_lists[0])
188 | model_lists = natsorted(glob.glob(save_dir + '*'))
189 |
190 | if __name__ == '__main__':
191 | '''
192 | GPU configuration
193 | '''
194 | GPU_iden = 0
195 | GPU_num = torch.cuda.device_count()
196 |     print('Number of GPUs: ' + str(GPU_num))
197 | for GPU_idx in range(GPU_num):
198 | GPU_name = torch.cuda.get_device_name(GPU_idx)
199 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
200 | torch.cuda.set_device(GPU_iden)
201 | GPU_avai = torch.cuda.is_available()
202 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
203 |     print('Is the GPU available? ' + str(GPU_avai))
204 | main()
--------------------------------------------------------------------------------
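The learning rate in train.py follows a polynomial decay applied by adjust_learning_rate at every iteration. A small standalone sketch of that schedule (init_lr, max_epochs, and power mirror the settings above; the printed values are illustrative):

```python
import numpy as np

def poly_lr(epoch, max_epochs=500, init_lr=1e-4, power=0.9):
    # lr(epoch) = init_lr * (1 - epoch / max_epochs) ** power, rounded to 8 decimals,
    # matching adjust_learning_rate in train.py
    return round(init_lr * np.power(1 - epoch / max_epochs, power), 8)

for e in (0, 100, 250, 499):
    print(e, poly_lr(e))   # 0 -> 1e-04, 250 -> ~5.36e-05, 499 -> ~3.7e-07
```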
/ViT-V-Net/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch, sys
5 | from torch import nn
6 | import pystrum.pynd.ndutils as nd
7 |
8 | def sliding_predict(model, image, tile_size, n_dims, overlap=1/2, flip=False):
9 | image_size = image.shape
10 | stride_x = math.ceil(tile_size[0] * (1 - overlap))
11 | stride_y = math.ceil(tile_size[1] * (1 - overlap))
12 | stride_z = math.ceil(tile_size[2] * (1 - overlap))
13 | num_rows = int(math.ceil((image_size[2] - tile_size[0]) / stride_x) + 1)
14 | num_cols = int(math.ceil((image_size[3] - tile_size[1]) / stride_y) + 1)
15 | num_slcs = int(math.ceil((image_size[4] - tile_size[2]) / stride_z) + 1)
16 | total_predictions = torch.zeros((1, n_dims, image_size[2], image_size[3], image_size[4])).cuda()
17 | count_predictions = torch.zeros((image_size[2], image_size[3], image_size[4])).cuda()
18 | tile_counter = 0
19 | print(num_rows)
20 | for row in range(num_rows):
21 | for col in range(num_cols):
22 | for slc in range(num_slcs):
23 | x_min, y_min, z_min = int(row * stride_x), int(col * stride_y), int(slc * stride_z)
24 | x_max = x_min + tile_size[0]
25 | y_max = y_min + tile_size[1]
26 | z_max = z_min + tile_size[2]
27 | if x_max > image_size[2]:
28 | x_min = image_size[2] - stride_x
29 | x_max = image_size[2]
30 | if y_max > image_size[3]:
31 | y_min = image_size[3] - stride_y
32 | y_max = image_size[3]
33 | if z_max > image_size[4]:
34 | z_min = image_size[4] - stride_z
35 |                         z_max = image_size[4]
36 | img = image[:, :, x_min:x_max, y_min:y_max, z_min:z_max]
37 | padded_img = pad_image(img, tile_size)
38 | #print(padded_img.shape)
39 |
40 | tile_counter += 1
41 | padded_prediction = model(padded_img)[1]
42 | if flip:
43 | for dim in [-1, -2, -3]:
44 | fliped_img = padded_img.flip(dim)
45 | fliped_predictions = model(fliped_img)[1]
46 | padded_prediction = (fliped_predictions.flip(dim) + padded_prediction)
47 |                         padded_prediction = padded_prediction/4  # average the original and three flipped predictions
48 | predictions = padded_prediction[:, :, :img.shape[2], :img.shape[3], :img.shape[4]]
49 | count_predictions[x_min:x_max, y_min:y_max, z_min:z_max] += 1
50 | total_predictions[:, :, x_min:x_max, y_min:y_max, z_min:z_max] += predictions.cuda()#.data.cpu().numpy()
51 | total_predictions /= count_predictions
52 | return total_predictions
53 |
54 | def pad_image(img, target_size):
55 | rows_to_pad = max(target_size[0] - img.shape[2], 0)
56 | cols_to_pad = max(target_size[1] - img.shape[3], 0)
57 | slcs_to_pad = max(target_size[2] - img.shape[4], 0)
58 | padded_img = F.pad(img, (0, slcs_to_pad, 0, cols_to_pad, 0, rows_to_pad), "constant", 0)
59 | return padded_img
60 |
61 | class SpatialTransformer(nn.Module):
62 | """
63 | N-D Spatial Transformer
64 | """
65 |
66 | def __init__(self, size, mode='bilinear'):
67 | super().__init__()
68 |
69 | self.mode = mode
70 |
71 | # create sampling grid
72 | vectors = [torch.arange(0, s) for s in size]
73 | grids = torch.meshgrid(vectors)
74 | grid = torch.stack(grids)
75 | grid = torch.unsqueeze(grid, 0)
76 | grid = grid.type(torch.FloatTensor).cuda()
77 |
78 | # registering the grid as a buffer cleanly moves it to the GPU, but it also
79 | # adds it to the state dict. this is annoying since everything in the state dict
80 | # is included when saving weights to disk, so the model files are way bigger
81 | # than they need to be. so far, there does not appear to be an elegant solution.
82 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
83 | self.register_buffer('grid', grid)
84 |
85 | def forward(self, src, flow):
86 | # new locations
87 | new_locs = self.grid + flow
88 | shape = flow.shape[2:]
89 |
90 | # need to normalize grid values to [-1, 1] for resampler
91 | for i in range(len(shape)):
92 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
93 |
94 | # move channels dim to last position
95 | # also not sure why, but the channels need to be reversed
96 | if len(shape) == 2:
97 | new_locs = new_locs.permute(0, 2, 3, 1)
98 | new_locs = new_locs[..., [1, 0]]
99 | elif len(shape) == 3:
100 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
101 | new_locs = new_locs[..., [2, 1, 0]]
102 |
103 | return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
104 |
105 | class register_model(nn.Module):
106 | def __init__(self, img_size=(64, 256, 256), mode='bilinear'):
107 | super(register_model, self).__init__()
108 | self.spatial_trans = SpatialTransformer(img_size, mode)
109 |
110 | def forward(self, x):
111 | img = x[0].cuda()
112 | flow = x[1].cuda()
113 | out = self.spatial_trans(img, flow)
114 | return out
115 |
116 | def dice_val(y_pred, y_true, num_clus):
117 | y_pred = nn.functional.one_hot(y_pred, num_classes=num_clus)
118 | y_pred = torch.squeeze(y_pred, 1)
119 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
120 | y_true = nn.functional.one_hot(y_true, num_classes=num_clus)
121 | y_true = torch.squeeze(y_true, 1)
122 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
123 | intersection = y_pred * y_true
124 | intersection = intersection.sum(dim=[2, 3, 4])
125 | union = y_pred.sum(dim=[2, 3, 4]) + y_true.sum(dim=[2, 3, 4])
126 | dsc = (2.*intersection) / (union + 1e-5)
127 | return torch.mean(torch.mean(dsc, dim=1))
128 |
129 | def jacobian_determinant(disp):
130 | """
131 | jacobian determinant of a displacement field.
132 | NB: to compute the spatial gradients, we use np.gradient.
133 | Parameters:
134 | disp: 3D displacement field of size [nb_dims, *vol_shape]
135 | Returns:
136 |         jacobian determinant at each voxel (array of shape vol_shape)
137 | """
138 |
139 | # check inputs
140 | volshape = disp.shape[1:]
141 | nb_dims = len(volshape)
142 | assert len(volshape) in (2, 3), 'flow has to be 2D or 3D'
143 |
144 | # compute grid
145 | grid_lst = nd.volsize2ndgrid(volshape)
146 | grid = np.stack(grid_lst, 0)
147 |
148 | # compute gradients
149 | [xFX, xFY, xFZ] = np.gradient(grid[0] - disp[0])
150 | [yFX, yFY, yFZ] = np.gradient(grid[1] - disp[1])
151 | [zFX, zFY, zFZ] = np.gradient(grid[2] - disp[2])
152 |
153 | jac_det = np.zeros(grid[0].shape)
154 | for i in range(grid.shape[1]):
155 | for j in range(grid.shape[2]):
156 | for k in range(grid.shape[3]):
157 | jac_mij = [[xFX[i, j, k], xFY[i, j, k], xFZ[i, j, k]], [yFX[i, j, k], yFY[i, j, k], yFZ[i, j, k]], [zFX[i, j, k], zFY[i, j, k], zFZ[i, j, k]]]
158 | jac_det[i, j, k] = np.linalg.det(jac_mij)
159 | return jac_det
160 |
161 |
162 | import re
163 | def process_label():
164 | #process labeling information for FreeSurfer
165 | seg_table = [0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26,
166 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62,
167 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255]
168 |
169 |
170 | file1 = open('label_info.txt', 'r')
171 | Lines = file1.readlines()
172 | dict = {}
173 | seg_i = 0
174 | seg_look_up = []
175 | for seg_label in seg_table:
176 | for line in Lines:
177 | line = re.sub(' +', ' ',line).split(' ')
178 | try:
179 | int(line[0])
180 | except:
181 | continue
182 | if int(line[0]) == seg_label:
183 | seg_look_up.append([seg_i, int(line[0]), line[1]])
184 | dict[seg_i] = line[1]
185 | seg_i += 1
186 | return dict
187 |
188 | def write2csv(line, name):
189 | with open(name+'.csv', 'a') as file:
190 | file.write(line)
191 | file.write('\n')
192 |
193 | def dice_val_substruct(y_pred, y_true, std_idx):
194 | with torch.no_grad():
195 | y_pred = nn.functional.one_hot(y_pred, num_classes=46)
196 | y_pred = torch.squeeze(y_pred, 1)
197 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
198 | y_true = nn.functional.one_hot(y_true, num_classes=46)
199 | y_true = torch.squeeze(y_true, 1)
200 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
201 | y_pred = y_pred.detach().cpu().numpy()
202 | y_true = y_true.detach().cpu().numpy()
203 |
204 | line = 'p_{}'.format(std_idx)
205 | for i in range(46):
206 | pred_clus = y_pred[0, i, ...]
207 | true_clus = y_true[0, i, ...]
208 | intersection = pred_clus * true_clus
209 | intersection = intersection.sum()
210 | union = pred_clus.sum() + true_clus.sum()
211 | dsc = (2.*intersection) / (union + 1e-5)
212 | line = line+','+str(dsc)
213 | return line
214 |
215 |
--------------------------------------------------------------------------------
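A minimal sketch of how register_model and dice_val from utils.py fit together in the validation loop (not part of the repository): warp a moving label map with a predicted flow using nearest-neighbour resampling, then score the overlap against the fixed labels. Toy 32^3 volumes and 46 label classes are assumed, as in train.py; a GPU is required because SpatialTransformer allocates its grid on CUDA, and pystrum must be installed since utils.py imports it.

```python
import torch
import utils

img_size = (32, 32, 32)                                    # toy size; train.py uses (160, 192, 224)
reg_model = utils.register_model(img_size, 'nearest').cuda()

x_seg = torch.randint(0, 46, (1, 1, *img_size)).cuda()     # moving segmentation
y_seg = torch.randint(0, 46, (1, 1, *img_size)).cuda()     # fixed segmentation
flow  = torch.zeros(1, 3, *img_size).cuda()                # identity deformation

def_seg = reg_model([x_seg.float(), flow])                 # warped moving labels
dsc = utils.dice_val(def_seg.long(), y_seg.long(), 46)     # mean Dice over 46 labels
print(dsc.item())
```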
/figures/ViTVNet_res.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/ViTVNet_res.jpg
--------------------------------------------------------------------------------
/figures/dice_details_.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/dice_details_.jpg
--------------------------------------------------------------------------------
/figures/net_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/net_arch.jpg
--------------------------------------------------------------------------------
/figures/trans_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/trans_arch.jpg
--------------------------------------------------------------------------------