├── LICENSE ├── PreprocessingMRI.md ├── README.md ├── ViT-V-Net ├── __pycache__ │ ├── ViT_V_Net.cpython-38.pyc │ ├── configs.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── models.cpython-38.pyc │ ├── utils.cpython-38.pyc │ └── vit_reg_configs.cpython-38.pyc ├── configs.py ├── data │ ├── __pycache__ │ │ ├── data_utils.cpython-38.pyc │ │ ├── datasets.cpython-38.pyc │ │ ├── rand.cpython-38.pyc │ │ └── trans.cpython-38.pyc │ ├── data_utils.py │ ├── datasets.py │ ├── rand.py │ └── trans.py ├── infer.py ├── label_info.txt ├── losses.py ├── models.py ├── train.py └── utils.py └── figures ├── ViTVNet_res.jpg ├── dice_details_.jpg ├── net_arch.jpg └── trans_arch.jpg /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junyu Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PreprocessingMRI.md: -------------------------------------------------------------------------------- 1 | 1. Install FreeSurfer from https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall 2 | 2. ```export FREESURFER_HOME=/your_freesurfer_directory``` 3 | 3. ```source $FREESURFER_HOME/SetUpFreeSurfer.sh``` 4 | 4. ```export SUBJECTS_DIR=/dataset_directory``` 5 | 5. ```recon-all -parallel -i dataset_directory/img_name.nii -autorecon1 -subjid img_name``` -> This step performs motion correction, skull stripping, affine transform computation, and intensity normalization. 6 | 6. ```mri_convert dataset_directory/img_name/mri/brainmask.mgz dataset_directory/img_name/mri/brainmask.nii.gz``` -> This step converts the preprocessed image from .mgz into .nii format. 7 | 7. ```mri_convert dataset_directory/img_name/mri/brainmask.mgz --apply_transform dataset_directory/img_name/mri/transforms/talairach.xfm -o dataset_directory/img_name/mri/brainmask_align.mgz``` -> This step applies the affine transform to Talairach space. 8 | 8. ```mri_convert dataset_directory/img_name/mri/brainmask_align.mgz dataset_directory/img_name/mri/brainmask_align.nii.gz``` -> This step converts the transformed image from .mgz into .nii format. 9 | 9. ```recon-all -parallel -s dataset_directory/img_name.nii -subcortseg -subjid img_name``` -> This step does subcortical segmentation. 10 | 10. ```mri_convert dataset_directory/img_name/mri/aseg.auto.mgz dataset_directory/img_name/mri/aseg.nii.gz``` -> This step converts the label image from .mgz into .nii format. 11 | 11.
```mri_convert -rt nearest dataset_directory/img_name/mri/aseg.auto.mgz --apply_transform dataset_directory/img_name/mri/transforms/talairach.xfm -o dataset_directory/img_name/mri/aseg_align.mgz``` -> This step applies the affine transform to Talairach space using nearest-neighbor interpolation for the label image. 12 | 12. ```mri_convert dataset_directory/img_name/mri/aseg_align.mgz dataset_directory/img_name/mri/aseg_align.nii.gz``` -> This step converts the transformed label image from .mgz into .nii format. 13 | 14 | Note that these steps may take up to **12-24 hours per image** based on our experience. Therefore, running these commands in parallel on a server or a cluster is recommended. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViT-V-Net: Vision Transformer for Volumetric Medical Image Registration 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2104.06468-b31b1b.svg)](https://arxiv.org/abs/2104.06468) 4 | 5 | **Please also check out our newly proposed registration model :point_right: [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration)**\ 6 | **The pretrained model and the quantitative results of ViT-V-Net on the IXI dataset are available here: [IXI_dataset](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/main/IXI/TransMorph_on_IXI.md).\ 7 | Additionally, we have made our preprocessed IXI dataset publicly available!** 8 | 9 | keywords: vision transformer, convolutional neural networks, image registration 10 | 11 | This is a **PyTorch** implementation of my short paper: 12 | 13 | Chen, Junyu, et al. "ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration." Medical Imaging with Deep Learning (MIDL), 2021. 14 | 15 | 16 | ***train.py*** is the training script. 17 | ***models.py*** contains the ViT-V-Net model. 18 | 19 | ***Pretrained ViT-V-Net:*** pretrained model 20 | 21 | ***Dataset:*** Due to restrictions, we cannot distribute our brain MRI data. However, several brain MRI datasets are publicly available online: IXI, ADNI, OASIS, ABIDE, etc. Note that those datasets may not contain labels (segmentation). To generate labels, you can use FreeSurfer, which is open-source software for preprocessing and segmenting brain MRI images. Useful FreeSurfer commands for brain MRI preprocessing and subcortical segmentation are listed in PreprocessingMRI.md.
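A minimal usage sketch, assuming the ***ViT-V-Net/*** directory is the working directory and a CUDA GPU is available (it mirrors how ***infer.py*** drives the model; the random tensors below are placeholders for real, preprocessed volumes):

```python
import torch
import models
from models import CONFIGS as CONFIGS_ViT_seg

# Build ViT-V-Net with its default registration configuration (the same call used in infer.py).
config_vit = CONFIGS_ViT_seg['ViT-V-Net']
model = models.ViTVNet(config_vit, img_size=(160, 192, 224)).cuda()

# Moving and fixed volumes, shape [batch, channel, H, W, D].
moving = torch.rand(1, 1, 160, 192, 224).cuda()
fixed = torch.rand(1, 1, 160, 192, 224).cuda()

# The network takes the image pair concatenated along the channel axis and returns
# the warped moving image together with the dense displacement field.
x_in = torch.cat((moving, fixed), dim=1)
warped, flow = model(x_in)
```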
22 | 23 | ## Model Architecture: 24 | 25 | 26 | ### Vision Transformer Architecture: 27 | 28 | 29 | ## Example Results: 30 | 31 | 32 | ## Quantitative Results: 33 | 34 | 35 | 36 | ## Reference: 37 | TransUnet 38 | 39 | ViT-pytorch 40 | 41 | VoxelMorph 42 | 43 | 44 | If you find this code useful in your research, please consider citing: 45 | 46 | @inproceedings{chen2021vitvnet, 47 | title={ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration}, 48 | author={Junyu Chen and Yufan He and Eric Frey and Ye Li and Yong Du}, 49 | booktitle={Medical Imaging with Deep Learning}, 50 | year={2021}, 51 | url={https://openreview.net/forum?id=h3HC1EU7AEz} 52 | } 53 | 54 | ### About Me 55 | -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/ViT_V_Net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/ViT_V_Net.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/configs.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/__pycache__/vit_reg_configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/__pycache__/vit_reg_configs.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_3DReg_config(): 4 | config = ml_collections.ConfigDict() 5 | config.patches = ml_collections.ConfigDict({'size': (8, 8, 8)}) 6 | config.patches.grid = (8, 8, 8) 7 | config.hidden_size = 252 8 | config.transformer = ml_collections.ConfigDict() 9 | config.transformer.mlp_dim = 3072
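    # NOTE: hidden_size (252) is divisible by num_heads (12) set below, giving 21-dimensional attention heads;
    # mlp_dim is the hidden width of each Transformer MLP block.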
10 | config.transformer.num_heads = 12 11 | config.transformer.num_layers = 12 12 | config.transformer.attention_dropout_rate = 0.0 13 | config.transformer.dropout_rate = 0.1 14 | config.patch_size = 8 15 | 16 | config.conv_first_channel = 512 17 | config.encoder_channels = (16, 32, 32) 18 | config.down_factor = 2 19 | config.down_num = 2 20 | config.decoder_channels = (96, 48, 32, 32, 16) 21 | config.skip_channels = (32, 32, 32, 32, 16) 22 | config.n_dims = 3 23 | config.n_skip = 5 24 | return config 25 | -------------------------------------------------------------------------------- /ViT-V-Net/data/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/data/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/data/__pycache__/rand.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/rand.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/data/__pycache__/trans.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/ViT-V-Net/data/__pycache__/trans.cpython-38.pyc -------------------------------------------------------------------------------- /ViT-V-Net/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | M = 2 ** 32 - 1 7 | 8 | 9 | def init_fn(worker): 10 | seed = torch.LongTensor(1).random_().item() 11 | seed = (seed + worker) % M 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | 15 | 16 | def add_mask(x, mask, dim=1): 17 | mask = mask.unsqueeze(dim) 18 | shape = list(x.shape); 19 | shape[dim] += 21 20 | new_x = x.new(*shape).zero_() 21 | new_x = new_x.scatter_(dim, mask, 1.0) 22 | s = [slice(None)] * len(shape) 23 | s[dim] = slice(21, None) 24 | new_x[s] = x 25 | return new_x 26 | 27 | 28 | def sample(x, size): 29 | # https://gist.github.com/yoavram/4134617 30 | i = random.sample(range(x.shape[0]), size) 31 | return torch.tensor(x[i], dtype=torch.int16) 32 | # x = np.random.permutation(x) 33 | # return torch.tensor(x[:size]) 34 | 35 | 36 | def pkload(fname): 37 | with open(fname, 'rb') as f: 38 | return pickle.load(f) 39 | 40 | 41 | _shape = (240, 240, 155) 42 | 43 | 44 | def get_all_coords(stride): 45 | return torch.tensor( 46 | np.stack([v.reshape(-1) for v in 47 | np.meshgrid( 48 | *[stride // 2 + np.arange(0, s, stride) for s in _shape], 49 | indexing='ij')], 50 | -1), dtype=torch.int16) 51 | 52 | 53 | _zero = torch.tensor([0]) 54 | 55 | 56 | def 
gen_feats(): 57 | x, y, z = 240, 240, 155 58 | feats = np.stack( 59 | np.meshgrid( 60 | np.arange(x), np.arange(y), np.arange(z), 61 | indexing='ij'), -1).astype('float32') 62 | shape = np.array([x, y, z]) 63 | feats -= shape / 2.0 64 | feats /= shape 65 | 66 | return feats -------------------------------------------------------------------------------- /ViT-V-Net/data/datasets.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch, sys 3 | from torch.utils.data import Dataset 4 | from .data_utils import pkload 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | 9 | 10 | class JHUBrainDataset(Dataset): 11 | def __init__(self, data_path, transforms): 12 | self.paths = data_path 13 | self.transforms = transforms 14 | 15 | def one_hot(self, img, C): 16 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 17 | for i in range(C): 18 | out[i,...] = img == i 19 | return out 20 | 21 | def __getitem__(self, index): 22 | path = self.paths[index] 23 | x, y = pkload(path) 24 | #print(x.shape) 25 | #print(x.shape) 26 | #print(np.unique(y)) 27 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155) 28 | # transforms work with nhwtc 29 | x, y = x[None, ...], y[None, ...] 30 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155) 31 | x,y = self.transforms([x, y]) 32 | #y = self.one_hot(y, 2) 33 | #print(y.shape) 34 | #sys.exit(0) 35 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 36 | y = np.ascontiguousarray(y) 37 | #plt.figure() 38 | #plt.subplot(1, 2, 1) 39 | #plt.imshow(x[0, :, :, 8], cmap='gray') 40 | #plt.subplot(1, 2, 2) 41 | #plt.imshow(y[0, :, :, 8], cmap='gray') 42 | #plt.show() 43 | #sys.exit(0) 44 | #y = np.squeeze(y, axis=0) 45 | x, y = torch.from_numpy(x), torch.from_numpy(y) 46 | return x, y 47 | 48 | def __len__(self): 49 | return len(self.paths) 50 | 51 | 52 | class JHUBrainInferDataset(Dataset): 53 | def __init__(self, data_path, transforms): 54 | self.paths = data_path 55 | self.transforms = transforms 56 | 57 | def one_hot(self, img, C): 58 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 59 | for i in range(C): 60 | out[i,...] = img == i 61 | return out 62 | 63 | def __getitem__(self, index): 64 | path = self.paths[index] 65 | x, y, x_seg, y_seg = pkload(path) 66 | #print(x.shape) 67 | #print(x.shape) 68 | #print(np.unique(y)) 69 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155) 70 | # transforms work with nhwtc 71 | x, y = x[None, ...], y[None, ...] 72 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...] 
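        # the [None, ...] indexing above prepends a channel axis so the image/segmentation pairs
        # match the nhwtc layout expected by the transforms below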
73 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155) 74 | x, x_seg = self.transforms([x, x_seg]) 75 | y, y_seg = self.transforms([y, y_seg]) 76 | #y = self.one_hot(y, 2) 77 | #print(y.shape) 78 | #sys.exit(0) 79 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 80 | y = np.ascontiguousarray(y) 81 | x_seg = np.ascontiguousarray(x_seg) # [Bsize,channelsHeight,,Width,Depth] 82 | y_seg = np.ascontiguousarray(y_seg) 83 | #plt.figure() 84 | #plt.subplot(1, 2, 1) 85 | #plt.imshow(x[0, :, :, 8], cmap='gray') 86 | #plt.subplot(1, 2, 2) 87 | #plt.imshow(y[0, :, :, 8], cmap='gray') 88 | #plt.show() 89 | #sys.exit(0) 90 | #y = np.squeeze(y, axis=0) 91 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg) 92 | return x, y, x_seg, y_seg 93 | 94 | def __len__(self): 95 | return len(self.paths) -------------------------------------------------------------------------------- /ViT-V-Net/data/rand.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class Uniform(object): 5 | def __init__(self, a, b): 6 | self.a = a 7 | self.b = b 8 | 9 | def sample(self): 10 | return random.uniform(self.a, self.b) 11 | 12 | 13 | class Gaussian(object): 14 | def __init__(self, mean, std): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def sample(self): 19 | return random.gauss(self.mean, self.std) 20 | 21 | 22 | class Constant(object): 23 | def __init__(self, val): 24 | self.val = val 25 | 26 | def sample(self): 27 | return self.val -------------------------------------------------------------------------------- /ViT-V-Net/data/trans.py: -------------------------------------------------------------------------------- 1 | # import math 2 | import random 3 | import collections 4 | import numpy as np 5 | import torch, sys, random, math 6 | from scipy import ndimage 7 | 8 | from .rand import Constant, Uniform, Gaussian 9 | from scipy.ndimage import rotate 10 | from skimage.transform import rescale, resize 11 | 12 | class Base(object): 13 | def sample(self, *shape): 14 | return shape 15 | 16 | def tf(self, img, k=0): 17 | return img 18 | 19 | def __call__(self, img, dim=3, reuse=False): # class -> func() 20 | # image: nhwtc 21 | # shape: no first dim 22 | if not reuse: 23 | im = img if isinstance(img, np.ndarray) else img[0] 24 | # how to know if the last dim is channel?? 25 | # nhwtc vs nhwt?? 
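            # keep only the spatial dims (axis 0 is the batch axis); subclasses use sample() to
            # cache per-call random parameters derived from these dims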
26 | shape = im.shape[1:dim+1] 27 | # print(dim,shape) # 3, (240,240,155) 28 | self.sample(*shape) 29 | 30 | if isinstance(img, collections.Sequence): 31 | return [self.tf(x, k) for k, x in enumerate(img)] # img:k=0,label:k=1 32 | 33 | return self.tf(img) 34 | 35 | def __str__(self): 36 | return 'Identity()' 37 | 38 | Identity = Base 39 | 40 | # gemetric transformations, need a buffers 41 | # first axis is N 42 | class Rot90(Base): 43 | def __init__(self, axes=(0, 1)): 44 | self.axes = axes 45 | 46 | for a in self.axes: 47 | assert a > 0 48 | 49 | def sample(self, *shape): 50 | shape = list(shape) 51 | i, j = self.axes 52 | 53 | # shape: no first dim 54 | i, j = i-1, j-1 55 | shape[i], shape[j] = shape[j], shape[i] 56 | 57 | return shape 58 | 59 | def tf(self, img, k=0): 60 | return np.rot90(img, axes=self.axes) 61 | 62 | def __str__(self): 63 | return 'Rot90(axes=({}, {})'.format(*self.axes) 64 | 65 | # class RandomRotion(Base): 66 | # def __init__(self, angle=20):# angle :in degress, float, [0,360] 67 | # assert angle >= 0.0 68 | # self.axes = (0,1) # 只对HW方向进行旋转 69 | # self.angle = angle # 70 | # self.buffer = None 71 | # 72 | # def sample(self, *shape):# shape : [H,W,D] 73 | # shape = list(shape) 74 | # self.buffer = round(np.random.uniform(low=-self.angle,high=self.angle),2) # 2个小数点 75 | # if self.buffer < 0: 76 | # self.buffer += 180 77 | # return shape 78 | # 79 | # def tf(self, img, k=0): # img shape [1,H,W,D,c] while label shape is [1,H,W,D] 80 | # return ndimage.rotate(img, angle=self.buffer, reshape=False) 81 | # 82 | # def __str__(self): 83 | # return 'RandomRotion(axes=({}, {}),Angle:{}'.format(*self.axes,self.buffer) 84 | 85 | class RandomRotion(Base): 86 | def __init__(self,angle_spectrum=10): 87 | assert isinstance(angle_spectrum,int) 88 | # axes = [(2, 1), (3, 1),(3, 2)] 89 | axes = [(1, 0), (2, 1),(2, 0)] 90 | self.angle_spectrum = angle_spectrum 91 | self.axes = axes 92 | 93 | def sample(self,*shape): 94 | self.axes_buffer = self.axes[np.random.choice(list(range(len(self.axes))))] # choose the random direction 95 | self.angle_buffer = np.random.randint(-self.angle_spectrum, self.angle_spectrum) # choose the random direction 96 | return list(shape) 97 | 98 | def tf(self, img, k=0): 99 | """ Introduction: The rotation function supports the shape [H,W,D,C] or shape [H,W,D] 100 | :param img: if x, shape is [1,H,W,D,c]; if label, shape is [1,H,W,D] 101 | :param k: if x, k=0; if label, k=1 102 | """ 103 | bsize = img.shape[0] 104 | 105 | for bs in range(bsize): 106 | if k == 0: 107 | # [[H,W,D], ...] 108 | # print(img.shape) # (1, 128, 128, 128, 4) 109 | channels = [rotate(img[bs,:,:,:,c], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) for c in 110 | range(img.shape[4])] 111 | img[bs,...] = np.stack(channels, axis=-1) 112 | 113 | if k == 1: 114 | img[bs,...] 
= rotate(img[bs,...], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) 115 | 116 | return img 117 | 118 | def __str__(self): 119 | return 'RandomRotion(axes={},Angle:{}'.format(self.axes_buffer,self.angle_buffer) 120 | 121 | 122 | class Flip(Base): 123 | def __init__(self, axis=0): 124 | self.axis = axis 125 | 126 | def tf(self, img, k=0): 127 | return np.flip(img, self.axis) 128 | 129 | def __str__(self): 130 | return 'Flip(axis={})'.format(self.axis) 131 | 132 | class RandomFlip(Base): 133 | # mirror flip across all x,y,z 134 | def __init__(self,axis=0): 135 | # assert axis == (1,2,3) # For both data and label, it has to specify the axis. 136 | self.axis = (1,2,3) 137 | self.x_buffer = None 138 | self.y_buffer = None 139 | self.z_buffer = None 140 | 141 | def sample(self, *shape): 142 | self.x_buffer = np.random.choice([True,False]) 143 | self.y_buffer = np.random.choice([True,False]) 144 | self.z_buffer = np.random.choice([True,False]) 145 | return list(shape) # the shape is not changed 146 | 147 | def tf(self,img,k=0): # img shape is (1, 240, 240, 155, 4) 148 | if self.x_buffer: 149 | img = np.flip(img,axis=self.axis[0]) 150 | if self.y_buffer: 151 | img = np.flip(img,axis=self.axis[1]) 152 | if self.z_buffer: 153 | img = np.flip(img,axis=self.axis[2]) 154 | return img 155 | 156 | 157 | class RandSelect(Base): 158 | def __init__(self, prob=0.5, tf=None): 159 | self.prob = prob 160 | self.ops = tf if isinstance(tf, collections.Sequence) else (tf, ) 161 | self.buff = False 162 | 163 | def sample(self, *shape): 164 | self.buff = random.random() < self.prob 165 | 166 | if self.buff: 167 | for op in self.ops: 168 | shape = op.sample(*shape) 169 | 170 | return shape 171 | 172 | def tf(self, img, k=0): 173 | if self.buff: 174 | for op in self.ops: 175 | img = op.tf(img, k) 176 | return img 177 | 178 | def __str__(self): 179 | if len(self.ops) == 1: 180 | ops = str(self.ops[0]) 181 | else: 182 | ops = '[{}]'.format(', '.join([str(op) for op in self.ops])) 183 | return 'RandSelect({}, {})'.format(self.prob, ops) 184 | 185 | 186 | class CenterCrop(Base): 187 | def __init__(self, size): 188 | self.size = size 189 | self.buffer = None 190 | 191 | def sample(self, *shape): 192 | size = self.size 193 | start = [(s -size)//2 for s in shape] 194 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 195 | return [size] * len(shape) 196 | 197 | def tf(self, img, k=0): 198 | # print(img.shape)#(1, 240, 240, 155, 4) 199 | return img[tuple(self.buffer)] 200 | # return img[self.buffer] 201 | 202 | def __str__(self): 203 | return 'CenterCrop({})'.format(self.size) 204 | 205 | class CenterCropBySize(CenterCrop): 206 | def sample(self, *shape): 207 | assert len(self.size) == 3 # random crop [H,W,T] from img [240,240,155] 208 | if not isinstance(self.size, list): 209 | size = list(self.size) 210 | else: 211 | size = self.size 212 | start = [(s-i)//2 for i, s in zip(size, shape)] 213 | self.buffer = [slice(None)] + [slice(s, s+i) for i, s in zip(size, start)] 214 | return size 215 | 216 | def __str__(self): 217 | return 'CenterCropBySize({})'.format(self.size) 218 | 219 | class RandCrop(CenterCrop): 220 | def sample(self, *shape): 221 | size = self.size 222 | start = [random.randint(0, s-size) for s in shape] 223 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 224 | return [size]*len(shape) 225 | 226 | def __str__(self): 227 | return 'RandCrop({})'.format(self.size) 228 | 229 | 230 | class RandCrop3D(CenterCrop): 231 | def sample(self, 
*shape): # shape : [240,240,155] 232 | assert len(self.size)==3 # random crop [H,W,T] from img [240,240,155] 233 | if not isinstance(self.size,list): 234 | size = list(self.size) 235 | else: 236 | size = self.size 237 | start = [random.randint(0, s-i) for i,s in zip(size,shape)] 238 | self.buffer = [slice(None)] + [slice(s, s+k) for s,k in zip(start,size)] 239 | return size 240 | 241 | def __str__(self): 242 | return 'RandCrop({})'.format(self.size) 243 | 244 | # for data only 245 | class RandomIntensityChange(Base): 246 | def __init__(self,factor): 247 | shift,scale = factor 248 | assert (shift >0) and (scale >0) 249 | self.shift = shift 250 | self.scale = scale 251 | 252 | def tf(self,img,k=0): 253 | if k==1: 254 | return img 255 | 256 | shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,img.shape[1],1,1,img.shape[4]]) # [-0.1,+0.1] 257 | scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,img.shape[1],1,1,img.shape[4]]) # [0.9,1.1) 258 | # shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,1,1,img.shape[3],img.shape[4]]) # [-0.1,+0.1] 259 | # scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,1,1,img.shape[3],img.shape[4]]) # [0.9,1.1) 260 | return img * scale_factor + shift_factor 261 | 262 | def __str__(self): 263 | return 'random intensity shift per channels on the input image, including' 264 | 265 | class RandomGammaCorrection(Base): 266 | def __init__(self,factor): 267 | lower, upper = factor 268 | assert (lower >0) and (upper >0) 269 | self.lower = lower 270 | self.upper = upper 271 | 272 | def tf(self,img,k=0): 273 | if k==1: 274 | return img 275 | img = img + np.min(img) 276 | img_max = np.max(img) 277 | img = img/img_max 278 | factor = random.choice(np.arange(self.lower, self.upper, 0.1)) 279 | gamma = random.choice([1, factor]) 280 | if gamma == 1: 281 | return img 282 | img = img ** gamma * img_max 283 | img = (img - img.mean())/img.std() 284 | return img 285 | 286 | def __str__(self): 287 | return 'random intensity shift per channels on the input image, including' 288 | 289 | class MinMax_norm(Base): 290 | def __init__(self, ): 291 | a = None 292 | 293 | def tf(self, img, k=0): 294 | if k == 1: 295 | return img 296 | img = (img - img.min()) / (img.max()-img.min()) 297 | return img 298 | 299 | class Seg_norm(Base): 300 | def __init__(self, ): 301 | a = None 302 | self.seg_table = np.array([0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 303 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 304 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255]) 305 | def tf(self, img, k=0): 306 | if k == 0: 307 | return img 308 | img_out = np.zeros_like(img) 309 | for i in range(len(self.seg_table)): 310 | img_out[img == self.seg_table[i]] = i 311 | return img_out 312 | 313 | class Resize_img(Base): 314 | def __init__(self, shape): 315 | self.shape = shape 316 | 317 | def tf(self, img, k=0): 318 | if k == 1: 319 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 320 | anti_aliasing=False, order=0) 321 | else: 322 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 323 | anti_aliasing=False, order=3) 324 | return img 325 | 326 | class Pad(Base): 327 | def __init__(self, pad): # [0,0,0,5,0] 328 | self.pad = pad 329 | self.px = tuple(zip([0]*len(pad), pad)) 330 | 331 | def sample(self, *shape): 332 | 333 | shape = list(shape) 334 | 335 | # shape: no first dim 336 | for i in range(len(shape)): 337 | shape[i] += self.pad[i+1] 338 
| 339 | return shape 340 | 341 | def tf(self, img, k=0): 342 | #nhwtc, nhwt 343 | dim = len(img.shape) 344 | return np.pad(img, self.px[:dim], mode='constant') 345 | 346 | def __str__(self): 347 | return 'Pad(({}, {}, {}))'.format(*self.pad) 348 | 349 | class Pad3DIfNeeded(Base): 350 | def __init__(self, shape, value=0, mask_value=0): # [0,0,0,5,0] 351 | self.shape = shape 352 | self.value = value 353 | self.mask_value = mask_value 354 | 355 | def tf(self, img, k=0): 356 | pad = [(0,0)] 357 | if k==0: 358 | img_shape = img.shape[1:-1] 359 | else: 360 | img_shape = img.shape[1:] 361 | for i, t in zip(img_shape, self.shape): 362 | if i < t: 363 | diff = t-i 364 | pad.append((math.ceil(diff/2),math.floor(diff/2))) 365 | else: 366 | pad.append((0,0)) 367 | if k == 0: 368 | pad.append((0,0)) 369 | pad = tuple(pad) 370 | if k==0: 371 | return np.pad(img, pad, mode='constant', constant_values=img.min()) 372 | else: 373 | return np.pad(img, pad, mode='constant', constant_values=self.mask_value) 374 | 375 | def __str__(self): 376 | return 'Pad(({}, {}, {}))'.format(*self.pad) 377 | 378 | class Noise(Base): 379 | def __init__(self, dim, sigma=0.1, channel=True, num=-1): 380 | self.dim = dim 381 | self.sigma = sigma 382 | self.channel = channel 383 | self.num = num 384 | 385 | def tf(self, img, k=0): 386 | if self.num > 0 and k >= self.num: 387 | return img 388 | 389 | if self.channel: 390 | #nhwtc, hwtc, hwt 391 | shape = [1] if len(img.shape) < self.dim+2 else [img.shape[-1]] 392 | else: 393 | shape = img.shape 394 | return img * np.exp(self.sigma * torch.randn(shape, dtype=torch.float32).numpy()) 395 | 396 | def __str__(self): 397 | return 'Noise()' 398 | 399 | 400 | # dim could come from shape 401 | class GaussianBlur(Base): 402 | def __init__(self, dim, sigma=Constant(1.5), app=-1): 403 | # 1.5 pixel 404 | self.dim = dim 405 | self.sigma = sigma 406 | self.eps = 0.001 407 | self.app = app 408 | 409 | def tf(self, img, k=0): 410 | if self.num > 0 and k >= self.num: 411 | return img 412 | 413 | # image is nhwtc 414 | for n in range(img.shape[0]): 415 | sig = self.sigma.sample() 416 | # sample each channel saperately to avoid correlations 417 | if sig > self.eps: 418 | if len(img.shape) == self.dim+2: 419 | C = img.shape[-1] 420 | for c in range(C): 421 | img[n,..., c] = ndimage.gaussian_filter(img[n, ..., c], sig) 422 | elif len(img.shape) == self.dim+1: 423 | img[n] = ndimage.gaussian_filter(img[n], sig) 424 | else: 425 | raise ValueError('image shape is not supported') 426 | 427 | return img 428 | 429 | def __str__(self): 430 | return 'GaussianBlur()' 431 | 432 | 433 | class ToNumpy(Base): 434 | def __init__(self, num=-1): 435 | self.num = num 436 | 437 | def tf(self, img, k=0): 438 | if self.num > 0 and k >= self.num: 439 | return img 440 | return img.numpy() 441 | 442 | def __str__(self): 443 | return 'ToNumpy()' 444 | 445 | 446 | class ToTensor(Base): 447 | def __init__(self, num=-1): 448 | self.num = num 449 | 450 | def tf(self, img, k=0): 451 | if self.num > 0 and k >= self.num: 452 | return img 453 | 454 | return torch.from_numpy(img) 455 | 456 | def __str__(self): 457 | return 'ToTensor' 458 | 459 | 460 | class TensorType(Base): 461 | def __init__(self, types, num=-1): 462 | self.types = types # ('torch.float32', 'torch.int64') 463 | self.num = num 464 | 465 | def tf(self, img, k=0): 466 | if self.num > 0 and k >= self.num: 467 | return img 468 | # make this work with both Tensor and Numpy 469 | return img.type(self.types[k]) 470 | 471 | def __str__(self): 472 | s = ', '.join([str(s) for 
s in self.types]) 473 | return 'TensorType(({}))'.format(s) 474 | 475 | 476 | class NumpyType(Base): 477 | def __init__(self, types, num=-1): 478 | self.types = types # ('float32', 'int64') 479 | self.num = num 480 | 481 | def tf(self, img, k=0): 482 | if self.num > 0 and k >= self.num: 483 | return img 484 | # make this work with both Tensor and Numpy 485 | return img.astype(self.types[k]) 486 | 487 | def __str__(self): 488 | s = ', '.join([str(s) for s in self.types]) 489 | return 'NumpyType(({}))'.format(s) 490 | 491 | 492 | class Normalize(Base): 493 | def __init__(self, mean=0.0, std=1.0, num=-1): 494 | self.mean = mean 495 | self.std = std 496 | self.num = num 497 | 498 | def tf(self, img, k=0): 499 | if self.num > 0 and k >= self.num: 500 | return img 501 | img -= self.mean 502 | img /= self.std 503 | return img 504 | 505 | def __str__(self): 506 | return 'Normalize()' 507 | 508 | 509 | class Compose(Base): 510 | def __init__(self, ops): 511 | if not isinstance(ops, collections.Sequence): 512 | ops = ops, 513 | self.ops = ops 514 | 515 | def sample(self, *shape): 516 | for op in self.ops: 517 | shape = op.sample(*shape) 518 | 519 | def tf(self, img, k=0): 520 | #is_tensor = isinstance(img, torch.Tensor) 521 | #if is_tensor: 522 | # img = img.numpy() 523 | 524 | for op in self.ops: 525 | # print(op,img.shape,k) 526 | img = op.tf(img, k) # do not use op(img) here 527 | 528 | #if is_tensor: 529 | # img = np.ascontiguousarray(img) 530 | # img = torch.from_numpy(img) 531 | 532 | return img 533 | 534 | def __str__(self): 535 | ops = ', '.join([str(op) for op in self.ops]) 536 | return 'Compose([{}])'.format(ops) -------------------------------------------------------------------------------- /ViT-V-Net/infer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from torch.utils.tensorboard import SummaryWriter 3 | import logging 4 | import os, losses, utils, nrrd 5 | import shutil 6 | import sys 7 | from torch.utils.data import DataLoader 8 | from data import datasets, trans 9 | import numpy as np 10 | import torch, models 11 | from torchvision import transforms 12 | from torch import optim 13 | import torch.nn as nn 14 | from ignite.contrib.handlers import ProgressBar 15 | from torchsummary import summary 16 | import matplotlib.pyplot as plt 17 | from models import CONFIGS as CONFIGS_ViT_seg 18 | from mpl_toolkits.mplot3d import axes3d 19 | from natsort import natsorted 20 | 21 | 22 | 23 | def plot_grid(gridx,gridy, **kwargs): 24 | for i in range(gridx.shape[1]): 25 | plt.plot(gridx[i,:], gridy[i,:], linewidth=0.8, **kwargs) 26 | for i in range(gridx.shape[0]): 27 | plt.plot(gridx[:,i], gridy[:,i], linewidth=0.8, **kwargs) 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | self.vals = [] 40 | self.std = 0 41 | 42 | def update(self, val, n=1): 43 | self.val = val 44 | self.sum += val * n 45 | self.count += n 46 | self.avg = self.sum / self.count 47 | self.vals.append(val) 48 | self.std = np.std(self.vals) 49 | 50 | def MSE_torch(x, y): 51 | return torch.mean((x - y) ** 2) 52 | 53 | def MAE_torch(x, y): 54 | return torch.mean(torch.abs(x - y)) 55 | 56 | def main(): 57 | test_dir = 'D:/DATA/JHUBrain/Test/' 58 | model_idx = -1 59 | model_folder = 'ViTVNet_reg0.02_mse_diff/' 60 | model_dir = 'experiments/' + model_folder 61 | config_vit = 
CONFIGS_ViT_seg['ViT-V-Net'] 62 | dict = utils.process_label() 63 | if os.path.exists('experiments/'+model_folder[:-1]+'.csv'): 64 | os.remove('experiments/'+model_folder[:-1]+'.csv') 65 | csv_writter(model_folder[:-1], 'experiments/' + model_folder[:-1]) 66 | line = '' 67 | for i in range(46): 68 | line = line + ',' + dict[i] 69 | csv_writter(line, 'experiments/' + model_folder[:-1]) 70 | model = models.ViTVNet(config_vit, img_size=(160, 192, 224)) 71 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])['state_dict'] 72 | print('Best model: {}'.format(natsorted(os.listdir(model_dir))[model_idx])) 73 | model.load_state_dict(best_model) 74 | model.cuda() 75 | reg_model = utils.register_model((160, 192, 224), 'nearest') 76 | reg_model.cuda() 77 | test_composed = transforms.Compose([trans.Seg_norm(), 78 | trans.NumpyType((np.float32, np.int16)), 79 | ]) 80 | test_set = datasets.JHUBrainInferDataset(glob.glob(test_dir + '*.pkl'), transforms=test_composed) 81 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=True) 82 | eval_dsc_def = AverageMeter() 83 | eval_dsc_raw = AverageMeter() 84 | eval_det = AverageMeter() 85 | with torch.no_grad(): 86 | stdy_idx = 0 87 | for data in test_loader: 88 | model.eval() 89 | data = [t.cuda() for t in data] 90 | x = data[0] 91 | y = data[1] 92 | x_seg = data[2] 93 | y_seg = data[3] 94 | 95 | x_in = torch.cat((x,y),dim=1) 96 | x_def, flow = model(x_in) 97 | def_out = reg_model([x_seg.cuda().float(), flow.cuda()]) 98 | tar = y.detach().cpu().numpy()[0, 0, :, :, :] 99 | #jac_det = utils.jacobian_determinant(flow.detach().cpu().numpy()[0, :, :, :, :]) 100 | line = utils.dice_val_substruct(def_out.long(), y_seg.long(), stdy_idx) 101 | line = line #+','+str(np.sum(jac_det <= 0)/np.prod(tar.shape)) 102 | csv_writter(line, 'experiments/' + model_folder[:-1]) 103 | #eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0)) 104 | 105 | dsc_trans = utils.dice_val(def_out.long(), y_seg.long(), 46) 106 | dsc_raw = utils.dice_val(x_seg.long(), y_seg.long(), 46) 107 | print('Trans diff: {:.4f}, Raw diff: {:.4f}'.format(dsc_trans.item(),dsc_raw.item())) 108 | eval_dsc_def.update(dsc_trans.item(), x.size(0)) 109 | eval_dsc_raw.update(dsc_raw.item(), x.size(0)) 110 | stdy_idx += 1 111 | 112 | # flip moving and fixed images 113 | y_in = torch.cat((y, x), dim=1) 114 | y_def, flow = model(y_in) 115 | def_out = reg_model([y_seg.cuda().float(), flow.cuda()]) 116 | tar = x.detach().cpu().numpy()[0, 0, :, :, :] 117 | 118 | #jac_det = utils.jacobian_determinant(flow.detach().cpu().numpy()[0, :, :, :, :]) 119 | line = utils.dice_val_substruct(def_out.long(), x_seg.long(), stdy_idx) 120 | line = line #+ ',' + str(np.sum(jac_det < 0) / np.prod(tar.shape)) 121 | out = def_out.detach().cpu().numpy()[0, 0, :, :, :] 122 | #print('det < 0: {}'.format(np.sum(jac_det <= 0)/np.prod(tar.shape))) 123 | csv_writter(line, 'experiments/' + model_folder[:-1]) 124 | #eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0)) 125 | 126 | dsc_trans = utils.dice_val(def_out.long(), x_seg.long(), 46) 127 | dsc_raw = utils.dice_val(y_seg.long(), x_seg.long(), 46) 128 | print('Trans diff: {:.4f}, Raw diff: {:.4f}'.format(dsc_trans.item(), dsc_raw.item())) 129 | eval_dsc_def.update(dsc_trans.item(), x.size(0)) 130 | eval_dsc_raw.update(dsc_raw.item(), x.size(0)) 131 | stdy_idx += 1 132 | 133 | print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg, 134 | 
eval_dsc_def.std, 135 | eval_dsc_raw.avg, 136 | eval_dsc_raw.std)) 137 | print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std)) 138 | 139 | def csv_writter(line, name): 140 | with open(name+'.csv', 'a') as file: 141 | file.write(line) 142 | file.write('\n') 143 | 144 | if __name__ == '__main__': 145 | ''' 146 | GPU configuration 147 | ''' 148 | GPU_iden = 0 149 | GPU_num = torch.cuda.device_count() 150 | print('Number of GPU: ' + str(GPU_num)) 151 | for GPU_idx in range(GPU_num): 152 | GPU_name = torch.cuda.get_device_name(GPU_idx) 153 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name) 154 | torch.cuda.set_device(GPU_iden) 155 | GPU_avai = torch.cuda.is_available() 156 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden)) 157 | print('If the GPU is available? ' + str(GPU_avai)) 158 | main() -------------------------------------------------------------------------------- /ViT-V-Net/label_info.txt: -------------------------------------------------------------------------------- 1 | 0 Unknown 0 0 0 0 2 | 1 Left-Cerebral-Exterior 70 130 180 0 3 | 2 Left-Cerebral-White-Matter 245 245 245 0 4 | 3 Left-Cerebral-Cortex 205 62 78 0 5 | 4 Left-Lateral-Ventricle 120 18 134 0 6 | 5 Left-Inf-Lat-Vent 196 58 250 0 7 | 6 Left-Cerebellum-Exterior 0 148 0 0 8 | 7 Left-Cerebellum-White-Matter 220 248 164 0 9 | 8 Left-Cerebellum-Cortex 230 148 34 0 10 | 9 Left-Thalamus 0 118 14 0 11 | 10 Left-Thalamus-Proper* 0 118 14 0 12 | 11 Left-Caudate 122 186 220 0 13 | 12 Left-Putamen 236 13 176 0 14 | 13 Left-Pallidum 12 48 255 0 15 | 14 3rd-Ventricle 204 182 142 0 16 | 15 4th-Ventricle 42 204 164 0 17 | 16 Brain-Stem 119 159 176 0 18 | 17 Left-Hippocampus 220 216 20 0 19 | 18 Left-Amygdala 103 255 255 0 20 | 19 Left-Insula 80 196 98 0 21 | 20 Left-Operculum 60 58 210 0 22 | 21 Line-1 60 58 210 0 23 | 22 Line-2 60 58 210 0 24 | 23 Line-3 60 58 210 0 25 | 24 CSF 60 60 60 0 26 | 25 Left-Lesion 255 165 0 0 27 | 26 Left-Accumbens-area 255 165 0 0 28 | 27 Left-Substancia-Nigra 0 255 127 0 29 | 28 Left-VentralDC 165 42 42 0 30 | 29 Left-undetermined 135 206 235 0 31 | 30 Left-vessel 160 32 240 0 32 | 31 Left-choroid-plexus 0 200 200 0 33 | 32 Left-F3orb 100 50 100 0 34 | 33 Left-lOg 135 50 74 0 35 | 34 Left-aOg 122 135 50 0 36 | 35 Left-mOg 51 50 135 0 37 | 36 Left-pOg 74 155 60 0 38 | 37 Left-Stellate 120 62 43 0 39 | 38 Left-Porg 74 155 60 0 40 | 39 Left-Aorg 122 135 50 0 41 | 40 Right-Cerebral-Exterior 70 130 180 0 42 | 41 Right-Cerebral-White-Matter 245 245 245 0 43 | 42 Right-Cerebral-Cortex 205 62 78 0 44 | 43 Right-Lateral-Ventricle 120 18 134 0 45 | 44 Right-Inf-Lat-Vent 196 58 250 0 46 | 45 Right-Cerebellum-Exterior 0 148 0 0 47 | 46 Right-Cerebellum-White-Matter 220 248 164 0 48 | 47 Right-Cerebellum-Cortex 230 148 34 0 49 | 48 Right-Thalamus 0 118 14 0 50 | 49 Right-Thalamus-Proper* 0 118 14 0 51 | 50 Right-Caudate 122 186 220 0 52 | 51 Right-Putamen 236 13 176 0 53 | 52 Right-Pallidum 13 48 255 0 54 | 53 Right-Hippocampus 220 216 20 0 55 | 54 Right-Amygdala 103 255 255 0 56 | 55 Right-Insula 80 196 98 0 57 | 56 Right-Operculum 60 58 210 0 58 | 57 Right-Lesion 255 165 0 0 59 | 58 Right-Accumbens-area 255 165 0 0 60 | 59 Right-Substancia-Nigra 0 255 127 0 61 | 60 Right-VentralDC 165 42 42 0 62 | 61 Right-undetermined 135 206 235 0 63 | 62 Right-vessel 160 32 240 0 64 | 63 Right-choroid-plexus 0 200 221 0 65 | 64 Right-F3orb 100 50 100 0 66 | 65 Right-lOg 135 50 74 0 67 | 66 Right-aOg 122 135 50 0 68 | 67 Right-mOg 51 50 135 0 69 | 68 Right-pOg 74 155 60 0 70 | 69 Right-Stellate 120 62 
43 0 71 | 70 Right-Porg 74 155 60 0 72 | 71 Right-Aorg 122 135 50 0 73 | 72 5th-Ventricle 120 190 150 0 74 | 73 Left-Interior 122 135 50 0 75 | 74 Right-Interior 122 135 50 0 76 | 77 | 77 WM-hypointensities 200 70 255 0 78 | 78 Left-WM-hypointensities 255 148 10 0 79 | 79 Right-WM-hypointensities 255 148 10 0 80 | 80 non-WM-hypointensities 164 108 226 0 81 | 81 Left-non-WM-hypointensities 164 108 226 0 82 | 82 Right-non-WM-hypointensities 164 108 226 0 83 | 83 Left-F1 255 218 185 0 84 | 84 Right-F1 255 218 185 0 85 | 85 Optic-Chiasm 234 169 30 0 86 | 192 Corpus_Callosum 250 255 50 0 87 | 88 | 86 Left_future_WMSA 200 120 255 0 89 | 87 Right_future_WMSA 200 121 255 0 90 | 88 future_WMSA 200 122 255 0 91 | 92 | 93 | 96 Left-Amygdala-Anterior 205 10 125 0 94 | 97 Right-Amygdala-Anterior 205 10 125 0 95 | 98 Dura 160 32 240 0 96 | 97 | 100 Left-wm-intensity-abnormality 124 140 178 0 98 | 101 Left-caudate-intensity-abnormality 125 140 178 0 99 | 102 Left-putamen-intensity-abnormality 126 140 178 0 100 | 103 Left-accumbens-intensity-abnormality 127 140 178 0 101 | 104 Left-pallidum-intensity-abnormality 124 141 178 0 102 | 105 Left-amygdala-intensity-abnormality 124 142 178 0 103 | 106 Left-hippocampus-intensity-abnormality 124 143 178 0 104 | 107 Left-thalamus-intensity-abnormality 124 144 178 0 105 | 108 Left-VDC-intensity-abnormality 124 140 179 0 106 | 109 Right-wm-intensity-abnormality 124 140 178 0 107 | 110 Right-caudate-intensity-abnormality 125 140 178 0 108 | 111 Right-putamen-intensity-abnormality 126 140 178 0 109 | 112 Right-accumbens-intensity-abnormality 127 140 178 0 110 | 113 Right-pallidum-intensity-abnormality 124 141 178 0 111 | 114 Right-amygdala-intensity-abnormality 124 142 178 0 112 | 115 Right-hippocampus-intensity-abnormality 124 143 178 0 113 | 116 Right-thalamus-intensity-abnormality 124 144 178 0 114 | 117 Right-VDC-intensity-abnormality 124 140 179 0 115 | 116 | 118 Epidermis 255 20 147 0 117 | 119 Conn-Tissue 205 179 139 0 118 | 120 SC-Fat-Muscle 238 238 209 0 119 | 121 Cranium 200 200 200 0 120 | 122 CSF-SA 74 255 74 0 121 | 123 Muscle 238 0 0 0 122 | 124 Ear 0 0 139 0 123 | 125 Adipose 173 255 47 0 124 | 126 Spinal-Cord 133 203 229 0 125 | 127 Soft-Tissue 26 237 57 0 126 | 128 Nerve 34 139 34 0 127 | 129 Bone 30 144 255 0 128 | 130 Air 147 19 173 0 129 | 131 Orbital-Fat 238 59 59 0 130 | 132 Tongue 221 39 200 0 131 | 133 Nasal-Structures 238 174 238 0 132 | 134 Globe 255 0 0 0 133 | 135 Teeth 72 61 139 0 134 | 136 Left-Caudate-Putamen 21 39 132 0 135 | 137 Right-Caudate-Putamen 21 39 132 0 136 | 138 Left-Claustrum 65 135 20 0 137 | 139 Right-Claustrum 65 135 20 0 138 | 140 Cornea 134 4 160 0 139 | 142 Diploe 221 226 68 0 140 | 143 Vitreous-Humor 255 255 254 0 141 | 144 Lens 52 209 226 0 142 | 145 Aqueous-Humor 239 160 223 0 143 | 146 Outer-Table 70 130 180 0 144 | 147 Inner-Table 70 130 181 0 145 | 148 Periosteum 139 121 94 0 146 | 149 Endosteum 224 224 224 0 147 | 150 R-C-S 255 0 0 0 148 | 151 Iris 205 205 0 0 149 | 152 SC-Adipose-Muscle 238 238 209 0 150 | 153 SC-Tissue 139 121 94 0 151 | 154 Orbital-Adipose 238 59 59 0 152 | 153 | 155 Left-IntCapsule-Ant 238 59 59 0 154 | 156 Right-IntCapsule-Ant 238 59 59 0 155 | 157 Left-IntCapsule-Pos 62 10 205 0 156 | 158 Right-IntCapsule-Pos 62 10 205 0 157 | 158 | # These labels are for babies/children 159 | 159 Left-Cerebral-WM-unmyelinated 0 118 14 0 160 | 160 Right-Cerebral-WM-unmyelinated 0 118 14 0 161 | 161 Left-Cerebral-WM-myelinated 220 216 21 0 162 | 162 Right-Cerebral-WM-myelinated 220 216 21 0 163 | 163 
Left-Subcortical-Gray-Matter 122 186 220 0 164 | 164 Right-Subcortical-Gray-Matter 122 186 220 0 165 | 165 Skull 120 120 120 0 166 | 166 Posterior-fossa 14 48 255 0 167 | 167 Scalp 166 42 42 0 168 | 168 Hematoma 121 18 134 0 169 | 169 Left-Basal-Ganglia 236 13 127 0 170 | 176 Right-Basal-Ganglia 236 13 126 0 171 | 172 | # Label names and colors for Brainstem consituents 173 | # No. Label Name: R G B A 174 | 170 brainstem 119 159 176 0 175 | 171 DCG 119 0 176 0 176 | 172 Vermis 119 100 176 0 177 | 173 Midbrain 242 104 76 0 178 | 174 Pons 206 195 58 0 179 | 175 Medulla 119 159 176 0 180 | 177 Vermis-White-Matter 119 50 176 0 181 | 178 SCP 142 182 0 0 182 | 179 Floculus 19 100 176 0 183 | 184 | 180 Left-Cortical-Dysplasia 73 61 139 0 185 | 181 Right-Cortical-Dysplasia 73 62 139 0 186 | 182 CblumNodulus 10 100 176 0 187 | 188 | 193 Left-hippocampal_fissure 0 196 255 0 189 | 194 Left-CADG-head 255 164 164 0 190 | 195 Left-subiculum 196 196 0 0 191 | 196 Left-fimbria 0 100 255 0 192 | 197 Right-hippocampal_fissure 128 196 164 0 193 | 198 Right-CADG-head 0 126 75 0 194 | 199 Right-subiculum 128 96 64 0 195 | 200 Right-fimbria 0 50 128 0 196 | 201 alveus 255 204 153 0 197 | 202 perforant_pathway 255 128 128 0 198 | 203 parasubiculum 255 255 0 0 199 | 204 presubiculum 64 0 64 0 200 | 205 subiculum 0 0 255 0 201 | 206 CA1 255 0 0 0 202 | 207 CA2 128 128 255 0 203 | 208 CA3 0 128 0 0 204 | 209 CA4 196 160 128 0 205 | 210 GC-DG 32 200 255 0 206 | 211 HATA 128 255 128 0 207 | 212 fimbria 204 153 204 0 208 | 213 lateral_ventricle 121 17 136 0 209 | 214 molecular_layer_HP 128 0 0 0 210 | 215 hippocampal_fissure 128 32 255 0 211 | 216 entorhinal_cortex 255 204 102 0 212 | 217 molecular_layer_subiculum 128 128 128 0 213 | 218 Amygdala 104 255 255 0 214 | 219 Cerebral_White_Matter 0 226 0 0 215 | 220 Cerebral_Cortex 205 63 78 0 216 | 221 Inf_Lat_Vent 197 58 250 0 217 | 222 Perirhinal 33 150 250 0 218 | 223 Cerebral_White_Matter_Edge 226 0 0 0 219 | 224 Background 100 100 100 0 220 | 225 Ectorhinal 197 150 250 0 221 | 226 HP_tail 170 170 255 0 222 | 223 | 250 Fornix 255 0 0 0 224 | 251 CC_Posterior 0 0 64 0 225 | 252 CC_Mid_Posterior 0 0 112 0 226 | 253 CC_Central 0 0 160 0 227 | 254 CC_Mid_Anterior 0 0 208 0 228 | 255 CC_Anterior 0 0 255 0 229 | -------------------------------------------------------------------------------- /ViT-V-Net/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | import math 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 11 | return gauss / gauss.sum() 12 | 13 | 14 | def create_window(window_size, channel): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 18 | return window 19 | 20 | 21 | def create_window_3D(window_size, channel): 22 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 23 | _2D_window = _1D_window.mm(_1D_window.t()) 24 | _3D_window = _1D_window.mm(_2D_window.reshape(1, -1)).reshape(window_size, window_size, 25 | window_size).float().unsqueeze(0).unsqueeze(0) 26 | window = Variable(_3D_window.expand(channel, 1, window_size, window_size, window_size).contiguous()) 27 | return window 28 | 29 | 
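# Illustrative check of the construction above: the window is built separably from one normalized
# 1-D Gaussian, so each per-channel 3-D window sums to ~1, e.g.:
#   w = create_window_3D(11, channel=1)   # shape [1, 1, 11, 11, 11]
#   assert abs(w.sum().item() - 1.0) < 1e-4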
30 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 33 | 34 | mu1_sq = mu1.pow(2) 35 | mu2_sq = mu2.pow(2) 36 | mu1_mu2 = mu1 * mu2 37 | 38 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 39 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 40 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 41 | 42 | C1 = 0.01 ** 2 43 | C2 = 0.03 ** 2 44 | 45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 46 | 47 | if size_average: 48 | return ssim_map.mean() 49 | else: 50 | return ssim_map.mean(1).mean(1).mean(1) 51 | 52 | 53 | def _ssim_3D(img1, img2, window, window_size, channel, size_average=True): 54 | mu1 = F.conv3d(img1, window, padding=window_size // 2, groups=channel) 55 | mu2 = F.conv3d(img2, window, padding=window_size // 2, groups=channel) 56 | 57 | mu1_sq = mu1.pow(2) 58 | mu2_sq = mu2.pow(2) 59 | 60 | mu1_mu2 = mu1 * mu2 61 | 62 | sigma1_sq = F.conv3d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 63 | sigma2_sq = F.conv3d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 64 | sigma12 = F.conv3d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 65 | 66 | C1 = 0.01 ** 2 67 | C2 = 0.03 ** 2 68 | 69 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 70 | 71 | if size_average: 72 | return ssim_map.mean() 73 | else: 74 | return ssim_map.mean(1).mean(1).mean(1) 75 | 76 | 77 | class SSIM(torch.nn.Module): 78 | def __init__(self, window_size=11, size_average=True): 79 | super(SSIM, self).__init__() 80 | self.window_size = window_size 81 | self.size_average = size_average 82 | self.channel = 1 83 | self.window = create_window(window_size, self.channel) 84 | 85 | def forward(self, img1, img2): 86 | (_, channel, _, _) = img1.size() 87 | 88 | if channel == self.channel and self.window.data.type() == img1.data.type(): 89 | window = self.window 90 | else: 91 | window = create_window(self.window_size, channel) 92 | 93 | if img1.is_cuda: 94 | window = window.cuda(img1.get_device()) 95 | window = window.type_as(img1) 96 | 97 | self.window = window 98 | self.channel = channel 99 | 100 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 101 | 102 | 103 | class SSIM3D(torch.nn.Module): 104 | def __init__(self, window_size=11, size_average=True): 105 | super(SSIM3D, self).__init__() 106 | self.window_size = window_size 107 | self.size_average = size_average 108 | self.channel = 1 109 | self.window = create_window_3D(window_size, self.channel) 110 | 111 | def forward(self, img1, img2): 112 | (_, channel, _, _, _) = img1.size() 113 | 114 | if channel == self.channel and self.window.data.type() == img1.data.type(): 115 | window = self.window 116 | else: 117 | window = create_window_3D(self.window_size, channel) 118 | 119 | if img1.is_cuda: 120 | window = window.cuda(img1.get_device()) 121 | window = window.type_as(img1) 122 | 123 | self.window = window 124 | self.channel = channel 125 | 126 | return 1-_ssim_3D(img1, img2, window, self.window_size, channel, self.size_average) 127 | 128 | 129 | def ssim(img1, img2, window_size=11, size_average=True): 130 | (_, channel, _, _) = 
img1.size() 131 | window = create_window(window_size, channel) 132 | 133 | if img1.is_cuda: 134 | window = window.cuda(img1.get_device()) 135 | window = window.type_as(img1) 136 | 137 | return _ssim(img1, img2, window, window_size, channel, size_average) 138 | 139 | 140 | def ssim3D(img1, img2, window_size=11, size_average=True): 141 | (_, channel, _, _, _) = img1.size() 142 | window = create_window_3D(window_size, channel) 143 | 144 | if img1.is_cuda: 145 | window = window.cuda(img1.get_device()) 146 | window = window.type_as(img1) 147 | 148 | return _ssim_3D(img1, img2, window, window_size, channel, size_average) 149 | 150 | 151 | class Grad(torch.nn.Module): 152 | """ 153 | N-D gradient loss. 154 | """ 155 | 156 | def __init__(self, penalty='l1', loss_mult=None): 157 | super(Grad, self).__init__() 158 | self.penalty = penalty 159 | self.loss_mult = loss_mult 160 | 161 | def forward(self, y_pred, y_true): 162 | dy = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]) 163 | dx = torch.abs(y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1]) 164 | #dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1]) 165 | 166 | if self.penalty == 'l2': 167 | dy = dy * dy 168 | dx = dx * dx 169 | #dz = dz * dz 170 | 171 | d = torch.mean(dx) + torch.mean(dy)# + torch.mean(dz) 172 | grad = d / 2.0 173 | 174 | if self.loss_mult is not None: 175 | grad *= self.loss_mult 176 | return grad 177 | 178 | class Grad3d(torch.nn.Module): 179 | """ 180 | N-D gradient loss. 181 | """ 182 | 183 | def __init__(self, penalty='l1', loss_mult=None): 184 | super(Grad3d, self).__init__() 185 | self.penalty = penalty 186 | self.loss_mult = loss_mult 187 | 188 | def forward(self, y_pred, y_true): 189 | dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :]) 190 | dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :]) 191 | dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1]) 192 | 193 | if self.penalty == 'l2': 194 | dy = dy * dy 195 | dx = dx * dx 196 | dz = dz * dz 197 | 198 | d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz) 199 | grad = d / 3.0 200 | 201 | if self.loss_mult is not None: 202 | grad *= self.loss_mult 203 | return grad 204 | 205 | class Grad3DiTV(torch.nn.Module): 206 | """ 207 | N-D gradient loss. 208 | """ 209 | 210 | def __init__(self): 211 | super(Grad3DiTV, self).__init__() 212 | a = 1 213 | 214 | def forward(self, y_pred, y_true): 215 | dy = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, :-1, 1:, 1:]) 216 | dx = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, 1:, :-1, 1:]) 217 | dz = torch.abs(y_pred[:, :, 1:, 1:, 1:] - y_pred[:, :, 1:, 1:, :-1]) 218 | dy = dy * dy 219 | dx = dx * dx 220 | dz = dz * dz 221 | d = torch.mean(torch.sqrt(dx+dy+dz+1e-6)) 222 | grad = d / 3.0 223 | return grad 224 | 225 | class NCC(torch.nn.Module): 226 | """ 227 | Local (over window) normalized cross correlation loss. 228 | """ 229 | 230 | def __init__(self, win=None): 231 | super(NCC, self).__init__() 232 | self.win = win 233 | 234 | def forward(self, y_pred, y_true): 235 | 236 | I = y_true 237 | J = y_pred 238 | 239 | # get dimension of volume 240 | # assumes I, J are sized [batch_size, *vol_shape, nb_feats] 241 | ndims = len(list(I.size())) - 2 242 | assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. 
found: %d" % ndims 243 | 244 | # set window size 245 | win = [9] * ndims if self.win is None else self.win 246 | 247 | # compute filters 248 | sum_filt = torch.ones([1, 1, *win]).to("cuda") 249 | 250 | pad_no = math.floor(win[0]/2) 251 | 252 | if ndims == 1: 253 | stride = (1) 254 | padding = (pad_no) 255 | elif ndims == 2: 256 | stride = (1,1) 257 | padding = (pad_no, pad_no) 258 | else: 259 | stride = (1,1,1) 260 | padding = (pad_no, pad_no, pad_no) 261 | 262 | # get convolution function 263 | conv_fn = getattr(F, 'conv%dd' % ndims) 264 | 265 | # compute CC squares 266 | I2 = I * I 267 | J2 = J * J 268 | IJ = I * J 269 | 270 | I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding) 271 | J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding) 272 | I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding) 273 | J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding) 274 | IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding) 275 | 276 | win_size = np.prod(win) 277 | u_I = I_sum / win_size 278 | u_J = J_sum / win_size 279 | 280 | cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size 281 | I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size 282 | J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size 283 | 284 | cc = cross * cross / (I_var * J_var + 1e-5) 285 | 286 | return -torch.mean(cc) 287 | 288 | class MutualInformation(torch.nn.Module): 289 | """ 290 | Mutual Information 291 | """ 292 | def __init__(self, sigma_ratio=1, minval=0., maxval=1., num_bin=32): 293 | super(MutualInformation, self).__init__() 294 | 295 | """Create bin centers""" 296 | bin_centers = np.linspace(minval, maxval, num=num_bin) 297 | vol_bin_centers = Variable(torch.linspace(minval, maxval, num_bin), requires_grad=False).cuda() 298 | num_bins = len(bin_centers) 299 | 300 | """Sigma for Gaussian approx.""" 301 | sigma = np.mean(np.diff(bin_centers)) * sigma_ratio 302 | print(sigma) 303 | 304 | self.preterm = 1 / (2 * sigma**2) 305 | self.bin_centers = bin_centers 306 | self.max_clip = maxval 307 | self.num_bins = num_bins 308 | self.vol_bin_centers = vol_bin_centers 309 | 310 | def mi(self, y_true, y_pred): 311 | y_pred = torch.clamp(y_pred, 0., self.max_clip) 312 | y_true = torch.clamp(y_true, 0, self.max_clip) 313 | 314 | y_true = y_true.view(y_true.shape[0], -1) 315 | y_true = torch.unsqueeze(y_true, 2) 316 | y_pred = y_pred.view(y_pred.shape[0], -1) 317 | y_pred = torch.unsqueeze(y_pred, 2) 318 | 319 | nb_voxels = y_pred.shape[1] # total num of voxels 320 | 321 | """Reshape bin centers""" 322 | o = [1, 1, np.prod(self.vol_bin_centers.shape)] 323 | vbc = torch.reshape(self.vol_bin_centers, o).cuda() 324 | 325 | """compute image terms by approx. 
Gaussian dist.""" 326 | I_a = torch.exp(- self.preterm * torch.square(y_true - vbc)) 327 | I_a = I_a / torch.sum(I_a, dim=-1, keepdim=True) 328 | 329 | I_b = torch.exp(- self.preterm * torch.square(y_pred - vbc)) 330 | I_b = I_b / torch.sum(I_b, dim=-1, keepdim=True) 331 | 332 | # compute probabilities 333 | pab = torch.bmm(I_a.permute(0, 2, 1), I_b) 334 | pab = pab/nb_voxels 335 | pa = torch.mean(I_a, dim=1, keepdim=True) 336 | pb = torch.mean(I_b, dim=1, keepdim=True) 337 | 338 | papb = torch.bmm(pa.permute(0, 2, 1), pb) + 1e-6 339 | mi = torch.sum(torch.sum(pab * torch.log(pab / papb + 1e-6), dim=1), dim=1) 340 | return mi.mean() #average across batch 341 | 342 | def forward(self, y_true, y_pred): 343 | return -self.mi(y_true, y_pred) 344 | 345 | class localMutualInformation(torch.nn.Module): 346 | """ 347 | Local Mutual Information for non-overlapping patches 348 | """ 349 | def __init__(self, sigma_ratio=1, minval=0., maxval=1., num_bin=32, patch_size=5): 350 | super(localMutualInformation, self).__init__() 351 | 352 | """Create bin centers""" 353 | bin_centers = np.linspace(minval, maxval, num=num_bin) 354 | vol_bin_centers = Variable(torch.linspace(minval, maxval, num_bin), requires_grad=False).cuda() 355 | num_bins = len(bin_centers) 356 | 357 | """Sigma for Gaussian approx.""" 358 | sigma = np.mean(np.diff(bin_centers)) * sigma_ratio 359 | 360 | self.preterm = 1 / (2 * sigma**2) 361 | self.bin_centers = bin_centers 362 | self.max_clip = maxval 363 | self.num_bins = num_bins 364 | self.vol_bin_centers = vol_bin_centers 365 | self.patch_size = patch_size 366 | 367 | def local_mi(self, y_true, y_pred): 368 | y_pred = torch.clamp(y_pred, 0., self.max_clip) 369 | y_true = torch.clamp(y_true, 0, self.max_clip) 370 | 371 | """Reshape bin centers""" 372 | o = [1, 1, np.prod(self.vol_bin_centers.shape)] 373 | vbc = torch.reshape(self.vol_bin_centers, o).cuda() 374 | 375 | """Making image paddings""" 376 | if len(list(y_pred.size())[2:]) == 3: 377 | ndim = 3 378 | x, y, z = list(y_pred.size())[2:] 379 | # compute padding sizes 380 | x_r = -x % self.patch_size 381 | y_r = -y % self.patch_size 382 | z_r = -z % self.patch_size 383 | padding = (z_r // 2, z_r - z_r // 2, y_r // 2, y_r - y_r // 2, x_r // 2, x_r - x_r // 2, 0, 0, 0, 0) 384 | elif len(list(y_pred.size())[2:]) == 2: 385 | ndim = 2 386 | x, y = list(y_pred.size())[2:] 387 | # compute padding sizes 388 | x_r = -x % self.patch_size 389 | y_r = -y % self.patch_size 390 | padding = (y_r // 2, y_r - y_r // 2, x_r // 2, x_r - x_r // 2, 0, 0, 0, 0) 391 | else: 392 | raise Exception('Supports 2D and 3D but not {}'.format(list(y_pred.size()))) 393 | y_true = F.pad(y_true, padding, "constant", 0) 394 | y_pred = F.pad(y_pred, padding, "constant", 0) 395 | 396 | """Reshaping images into non-overlapping patches""" 397 | if ndim == 3: 398 | y_true_patch = torch.reshape(y_true, (y_true.shape[0], y_true.shape[1], 399 | (x + x_r) // self.patch_size, self.patch_size, 400 | (y + y_r) // self.patch_size, self.patch_size, 401 | (z + z_r) // self.patch_size, self.patch_size)) 402 | y_true_patch = y_true_patch.permute(0, 1, 2, 4, 6, 3, 5, 7) 403 | y_true_patch = torch.reshape(y_true_patch, (-1, self.patch_size ** 3, 1)) 404 | 405 | y_pred_patch = torch.reshape(y_pred, (y_pred.shape[0], y_pred.shape[1], 406 | (x + x_r) // self.patch_size, self.patch_size, 407 | (y + y_r) // self.patch_size, self.patch_size, 408 | (z + z_r) // self.patch_size, self.patch_size)) 409 | y_pred_patch = y_pred_patch.permute(0, 1, 2, 4, 6, 3, 5, 7) 410 | y_pred_patch = 
torch.reshape(y_pred_patch, (-1, self.patch_size ** 3, 1)) 411 | else: 412 | y_true_patch = torch.reshape(y_true, (y_true.shape[0], y_true.shape[1], 413 | (x + x_r) // self.patch_size, self.patch_size, 414 | (y + y_r) // self.patch_size, self.patch_size)) 415 | y_true_patch = y_true_patch.permute(0, 1, 2, 4, 3, 5) 416 | y_true_patch = torch.reshape(y_true_patch, (-1, self.patch_size ** 2, 1)) 417 | 418 | y_pred_patch = torch.reshape(y_pred, (y_pred.shape[0], y_pred.shape[1], 419 | (x + x_r) // self.patch_size, self.patch_size, 420 | (y + y_r) // self.patch_size, self.patch_size)) 421 | y_pred_patch = y_pred_patch.permute(0, 1, 2, 4, 3, 5) 422 | y_pred_patch = torch.reshape(y_pred_patch, (-1, self.patch_size ** 2, 1)) 423 | 424 | """Compute MI""" 425 | I_a_patch = torch.exp(- self.preterm * torch.square(y_true_patch - vbc)) 426 | I_a_patch = I_a_patch / torch.sum(I_a_patch, dim=-1, keepdim=True) 427 | 428 | I_b_patch = torch.exp(- self.preterm * torch.square(y_pred_patch - vbc)) 429 | I_b_patch = I_b_patch / torch.sum(I_b_patch, dim=-1, keepdim=True) 430 | 431 | pab = torch.bmm(I_a_patch.permute(0, 2, 1), I_b_patch) 432 | pab = pab / self.patch_size ** ndim 433 | pa = torch.mean(I_a_patch, dim=1, keepdim=True) 434 | pb = torch.mean(I_b_patch, dim=1, keepdim=True) 435 | 436 | papb = torch.bmm(pa.permute(0, 2, 1), pb) + 1e-6 437 | mi = torch.sum(torch.sum(pab * torch.log(pab / papb + 1e-6), dim=1), dim=1) 438 | return mi.mean() 439 | 440 | def forward(self,y_true, y_pred): 441 | return -self.local_mi(y_true, y_pred) 442 | -------------------------------------------------------------------------------- /ViT-V-Net/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as nnf 12 | from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm 13 | from torch.nn.modules.utils import _pair, _triple 14 | import configs as configs 15 | from torch.distributions.normal import Normal 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 21 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 22 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 23 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 24 | FC_0 = "MlpBlock_3/Dense_0" 25 | FC_1 = "MlpBlock_3/Dense_1" 26 | ATTENTION_NORM = "LayerNorm_0" 27 | MLP_NORM = "LayerNorm_2" 28 | 29 | 30 | def np2th(weights, conv=False): 31 | """Possibly convert HWIO to OIHW.""" 32 | if conv: 33 | weights = weights.transpose([3, 2, 0, 1]) 34 | return torch.from_numpy(weights) 35 | 36 | 37 | def swish(x): 38 | return x * torch.sigmoid(x) 39 | 40 | 41 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, config, vis): 46 | super(Attention, self).__init__() 47 | self.vis = vis 48 | self.num_attention_heads = config.transformer["num_heads"] 49 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 50 | self.all_head_size = self.num_attention_heads * self.attention_head_size 51 | 52 | self.query = Linear(config.hidden_size, self.all_head_size) 53 | self.key = Linear(config.hidden_size, self.all_head_size) 54 | self.value = Linear(config.hidden_size, self.all_head_size) 
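# The query/key/value projections above each map hidden_size to all_head_size
# (= num_attention_heads * attention_head_size). transpose_for_scores() below
# reshapes their outputs to (B, num_heads, n_patches, head_size), so scaled
# dot-product attention is computed independently for every head before the
# results are concatenated and passed through self.out.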
55 | 56 | self.out = Linear(config.hidden_size, config.hidden_size) 57 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 58 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 59 | 60 | self.softmax = Softmax(dim=-1) 61 | 62 | def transpose_for_scores(self, x): 63 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 64 | x = x.view(*new_x_shape) 65 | return x.permute(0, 2, 1, 3) 66 | 67 | def forward(self, hidden_states): 68 | mixed_query_layer = self.query(hidden_states) 69 | mixed_key_layer = self.key(hidden_states) 70 | mixed_value_layer = self.value(hidden_states) 71 | 72 | query_layer = self.transpose_for_scores(mixed_query_layer) 73 | key_layer = self.transpose_for_scores(mixed_key_layer) 74 | value_layer = self.transpose_for_scores(mixed_value_layer) 75 | 76 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 77 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 78 | attention_probs = self.softmax(attention_scores) 79 | weights = attention_probs if self.vis else None 80 | attention_probs = self.attn_dropout(attention_probs) 81 | 82 | context_layer = torch.matmul(attention_probs, value_layer) 83 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 84 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 85 | context_layer = context_layer.view(*new_context_layer_shape) 86 | attention_output = self.out(context_layer) 87 | attention_output = self.proj_dropout(attention_output) 88 | return attention_output, weights 89 | 90 | 91 | class Mlp(nn.Module): 92 | def __init__(self, config): 93 | super(Mlp, self).__init__() 94 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 95 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 96 | self.act_fn = ACT2FN["gelu"] 97 | self.dropout = Dropout(config.transformer["dropout_rate"]) 98 | 99 | self._init_weights() 100 | 101 | def _init_weights(self): 102 | nn.init.xavier_uniform_(self.fc1.weight) 103 | nn.init.xavier_uniform_(self.fc2.weight) 104 | nn.init.normal_(self.fc1.bias, std=1e-6) 105 | nn.init.normal_(self.fc2.bias, std=1e-6) 106 | 107 | def forward(self, x): 108 | x = self.fc1(x) 109 | x = self.act_fn(x) 110 | x = self.dropout(x) 111 | x = self.fc2(x) 112 | x = self.dropout(x) 113 | return x 114 | 115 | 116 | class Embeddings(nn.Module): 117 | """Construct the embeddings from patch, position embeddings. 118 | """ 119 | def __init__(self, config, img_size): 120 | super(Embeddings, self).__init__() 121 | self.config = config 122 | down_factor = config.down_factor 123 | patch_size = _triple(config.patches["size"]) 124 | n_patches = int((img_size[0]/2**down_factor// patch_size[0]) * (img_size[1]/2**down_factor// patch_size[1]) * (img_size[2]/2**down_factor// patch_size[2])) 125 | self.hybrid_model = CNNEncoder(config, n_channels=2) 126 | in_channels = config['encoder_channels'][-1] 127 | self.patch_embeddings = Conv3d(in_channels=in_channels, 128 | out_channels=config.hidden_size, 129 | kernel_size=patch_size, 130 | stride=patch_size) 131 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) 132 | 133 | self.dropout = Dropout(config.transformer["dropout_rate"]) 134 | 135 | def forward(self, x): 136 | x, features = self.hybrid_model(x) 137 | x = self.patch_embeddings(x) # (B, hidden. 
n_patches^(1/2), n_patches^(1/2)) 138 | x = x.flatten(2) 139 | x = x.transpose(-1, -2) # (B, n_patches, hidden) 140 | embeddings = x + self.position_embeddings 141 | embeddings = self.dropout(embeddings) 142 | return embeddings, features 143 | 144 | 145 | class Block(nn.Module): 146 | def __init__(self, config, vis): 147 | super(Block, self).__init__() 148 | self.hidden_size = config.hidden_size 149 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 150 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 151 | self.ffn = Mlp(config) 152 | self.attn = Attention(config, vis) 153 | 154 | def forward(self, x): 155 | h = x 156 | 157 | x = self.attention_norm(x) 158 | x, weights = self.attn(x) 159 | x = x + h 160 | 161 | h = x 162 | x = self.ffn_norm(x) 163 | x = self.ffn(x) 164 | x = x + h 165 | return x, weights 166 | 167 | class Encoder(nn.Module): 168 | def __init__(self, config, vis): 169 | super(Encoder, self).__init__() 170 | self.vis = vis 171 | self.layer = nn.ModuleList() 172 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 173 | for _ in range(config.transformer["num_layers"]): 174 | layer = Block(config, vis) 175 | self.layer.append(copy.deepcopy(layer)) 176 | 177 | def forward(self, hidden_states): 178 | attn_weights = [] 179 | for layer_block in self.layer: 180 | hidden_states, weights = layer_block(hidden_states) 181 | if self.vis: 182 | attn_weights.append(weights) 183 | encoded = self.encoder_norm(hidden_states) 184 | return encoded, attn_weights 185 | 186 | 187 | class Transformer(nn.Module): 188 | def __init__(self, config, img_size, vis): 189 | super(Transformer, self).__init__() 190 | self.embeddings = Embeddings(config, img_size=img_size) 191 | self.encoder = Encoder(config, vis) 192 | 193 | def forward(self, input_ids): 194 | embedding_output, features = self.embeddings(input_ids) 195 | encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) 196 | return encoded, attn_weights, features 197 | 198 | 199 | class Conv3dReLU(nn.Sequential): 200 | def __init__( 201 | self, 202 | in_channels, 203 | out_channels, 204 | kernel_size, 205 | padding=0, 206 | stride=1, 207 | use_batchnorm=True, 208 | ): 209 | conv = nn.Conv3d( 210 | in_channels, 211 | out_channels, 212 | kernel_size, 213 | stride=stride, 214 | padding=padding, 215 | bias=not (use_batchnorm), 216 | ) 217 | relu = nn.ReLU(inplace=True) 218 | 219 | bn = nn.BatchNorm3d(out_channels) 220 | 221 | super(Conv3dReLU, self).__init__(conv, bn, relu) 222 | 223 | 224 | class DecoderBlock(nn.Module): 225 | def __init__( 226 | self, 227 | in_channels, 228 | out_channels, 229 | skip_channels=0, 230 | use_batchnorm=True, 231 | ): 232 | super().__init__() 233 | self.conv1 = Conv3dReLU( 234 | in_channels + skip_channels, 235 | out_channels, 236 | kernel_size=3, 237 | padding=1, 238 | use_batchnorm=use_batchnorm, 239 | ) 240 | self.conv2 = Conv3dReLU( 241 | out_channels, 242 | out_channels, 243 | kernel_size=3, 244 | padding=1, 245 | use_batchnorm=use_batchnorm, 246 | ) 247 | self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) 248 | 249 | def forward(self, x, skip=None): 250 | x = self.up(x) 251 | if skip is not None: 252 | x = torch.cat([x, skip], dim=1) 253 | x = self.conv1(x) 254 | x = self.conv2(x) 255 | return x 256 | 257 | class DecoderCup(nn.Module): 258 | def __init__(self, config, img_size): 259 | super().__init__() 260 | self.config = config 261 | self.down_factor = config.down_factor 262 | head_channels = config.conv_first_channel 263 | self.img_size 
= img_size 264 | self.conv_more = Conv3dReLU( 265 | config.hidden_size, 266 | head_channels, 267 | kernel_size=3, 268 | padding=1, 269 | use_batchnorm=True, 270 | ) 271 | decoder_channels = config.decoder_channels 272 | in_channels = [head_channels] + list(decoder_channels[:-1]) 273 | out_channels = decoder_channels 274 | self.patch_size = _triple(config.patches["size"]) 275 | skip_channels = self.config.skip_channels 276 | blocks = [ 277 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) 278 | ] 279 | self.blocks = nn.ModuleList(blocks) 280 | 281 | def forward(self, hidden_states, features=None): 282 | B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) 283 | l, h, w = (self.img_size[0]//2**self.down_factor//self.patch_size[0]), (self.img_size[1]//2**self.down_factor//self.patch_size[1]), (self.img_size[2]//2**self.down_factor//self.patch_size[2]) 284 | x = hidden_states.permute(0, 2, 1) 285 | x = x.contiguous().view(B, hidden, l, h, w) 286 | x = self.conv_more(x) 287 | for i, decoder_block in enumerate(self.blocks): 288 | if features is not None: 289 | skip = features[i] if (i < self.config.n_skip) else None 290 | #print(skip.shape) 291 | else: 292 | skip = None 293 | x = decoder_block(x, skip=skip) 294 | return x 295 | 296 | class SpatialTransformer(nn.Module): 297 | """ 298 | N-D Spatial Transformer 299 | 300 | Obtained from https://github.com/voxelmorph/voxelmorph 301 | """ 302 | 303 | def __init__(self, size, mode='bilinear'): 304 | super().__init__() 305 | 306 | self.mode = mode 307 | 308 | # create sampling grid 309 | vectors = [torch.arange(0, s) for s in size] 310 | grids = torch.meshgrid(vectors) 311 | grid = torch.stack(grids) 312 | grid = torch.unsqueeze(grid, 0) 313 | grid = grid.type(torch.FloatTensor) 314 | 315 | # registering the grid as a buffer cleanly moves it to the GPU, but it also 316 | # adds it to the state dict. this is annoying since everything in the state dict 317 | # is included when saving weights to disk, so the model files are way bigger 318 | # than they need to be. so far, there does not appear to be an elegant solution. 319 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict 320 | self.register_buffer('grid', grid) 321 | 322 | def forward(self, src, flow): 323 | # new locations 324 | new_locs = self.grid + flow 325 | shape = flow.shape[2:] 326 | 327 | # need to normalize grid values to [-1, 1] for resampler 328 | for i in range(len(shape)): 329 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] 
/ (shape[i] - 1) - 0.5) 330 | 331 | # move channels dim to last position 332 | # also not sure why, but the channels need to be reversed 333 | if len(shape) == 2: 334 | new_locs = new_locs.permute(0, 2, 3, 1) 335 | new_locs = new_locs[..., [1, 0]] 336 | elif len(shape) == 3: 337 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 338 | new_locs = new_locs[..., [2, 1, 0]] 339 | 340 | return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode) 341 | 342 | class DoubleConv(nn.Module): 343 | """(convolution => [BN] => ReLU) * 2""" 344 | 345 | def __init__(self, in_channels, out_channels, mid_channels=None): 346 | super().__init__() 347 | if not mid_channels: 348 | mid_channels = out_channels 349 | self.double_conv = nn.Sequential( 350 | nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1), 351 | nn.ReLU(inplace=True), 352 | nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1), 353 | nn.ReLU(inplace=True) 354 | ) 355 | 356 | def forward(self, x): 357 | return self.double_conv(x) 358 | 359 | 360 | class Down(nn.Module): 361 | """Downscaling with maxpool then double conv""" 362 | 363 | def __init__(self, in_channels, out_channels): 364 | super().__init__() 365 | self.maxpool_conv = nn.Sequential( 366 | nn.MaxPool3d(2), 367 | DoubleConv(in_channels, out_channels) 368 | ) 369 | 370 | def forward(self, x): 371 | return self.maxpool_conv(x) 372 | 373 | class CNNEncoder(nn.Module): 374 | def __init__(self, config, n_channels=2): 375 | super(CNNEncoder, self).__init__() 376 | self.n_channels = n_channels 377 | decoder_channels = config.decoder_channels 378 | encoder_channels = config.encoder_channels 379 | self.down_num = config.down_num 380 | self.inc = DoubleConv(n_channels, encoder_channels[0]) 381 | self.down1 = Down(encoder_channels[0], encoder_channels[1]) 382 | self.down2 = Down(encoder_channels[1], encoder_channels[2]) 383 | self.width = encoder_channels[-1] 384 | def forward(self, x): 385 | features = [] 386 | x1 = self.inc(x) 387 | features.append(x1) 388 | x2 = self.down1(x1) 389 | features.append(x2) 390 | feats = self.down2(x2) 391 | features.append(feats) 392 | feats_down = feats 393 | for i in range(self.down_num): 394 | feats_down = nn.MaxPool3d(2)(feats_down) 395 | features.append(feats_down) 396 | return feats, features[::-1] 397 | 398 | class RegistrationHead(nn.Sequential): 399 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): 400 | conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 401 | conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) 402 | conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) 403 | super().__init__(conv3d) 404 | 405 | class ViTVNet(nn.Module): 406 | def __init__(self, config, img_size=(64, 256, 256), int_steps=7, vis=False): 407 | super(ViTVNet, self).__init__() 408 | self.transformer = Transformer(config, img_size, vis) 409 | self.decoder = DecoderCup(config, img_size) 410 | self.reg_head = RegistrationHead( 411 | in_channels=config.decoder_channels[-1], 412 | out_channels=config['n_dims'], 413 | kernel_size=3, 414 | ) 415 | self.spatial_trans = SpatialTransformer(img_size) 416 | self.config = config 417 | #self.integrate = VecInt(img_size, int_steps) 418 | def forward(self, x): 419 | 420 | source = x[:,0:1,:,:] 421 | 422 | x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) 423 | x = self.decoder(x, features) 424 | flow = self.reg_head(x) 425 | #flow = self.integrate(flow) 426 | out = self.spatial_trans(source, 
flow) 427 | return out, flow 428 | 429 | class VecInt(nn.Module): 430 | """ 431 | Integrates a vector field via scaling and squaring. 432 | 433 | Obtained from https://github.com/voxelmorph/voxelmorph 434 | """ 435 | 436 | def __init__(self, inshape, nsteps): 437 | super().__init__() 438 | 439 | assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps 440 | self.nsteps = nsteps 441 | self.scale = 1.0 / (2 ** self.nsteps) 442 | self.transformer = SpatialTransformer(inshape) 443 | 444 | def forward(self, vec): 445 | vec = vec * self.scale 446 | for _ in range(self.nsteps): 447 | vec = vec + self.transformer(vec, vec) 448 | return vec 449 | 450 | CONFIGS = { 451 | 'ViT-V-Net': configs.get_3DReg_config(), 452 | } 453 | -------------------------------------------------------------------------------- /ViT-V-Net/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import os, utils, glob, losses 3 | import sys 4 | from torch.utils.data import DataLoader 5 | from data import datasets, trans 6 | import numpy as np 7 | import torch, models 8 | from torchvision import transforms 9 | from torch import optim 10 | import torch.nn as nn 11 | import matplotlib.pyplot as plt 12 | from models import CONFIGS as CONFIGS_ViT_seg 13 | from natsort import natsorted 14 | 15 | class AverageMeter(object): 16 | """Computes and stores the average and current value""" 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def MSE_torch(x, y): 33 | return torch.mean((x - y) ** 2) 34 | 35 | def main(): 36 | batch_size = 2 37 | train_dir = 'D:/DATA/JHUBrain/Train/' 38 | val_dir = 'D:/DATA/JHUBrain/Val/' 39 | save_dir = 'ViTVNet_reg0.02_mse_diff/' 40 | lr = 0.0001 41 | epoch_start = 0 42 | max_epoch = 500 43 | cont_training = False 44 | config_vit = CONFIGS_ViT_seg['ViT-V-Net'] 45 | reg_model = utils.register_model((160, 192, 224), 'nearest') 46 | reg_model.cuda() 47 | model = models.ViTVNet(config_vit, img_size=(160, 192, 224)) 48 | if cont_training: 49 | epoch_start = 335 50 | model_dir = 'experiments/'+save_dir 51 | updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch,0.9),8) 52 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[0])['state_dict'] 53 | model.load_state_dict(best_model) 54 | else: 55 | updated_lr = lr 56 | model.cuda() 57 | train_composed = transforms.Compose([trans.RandomFlip(0), 58 | trans.NumpyType((np.float32, np.float32)), 59 | ]) 60 | 61 | val_composed = transforms.Compose([trans.Seg_norm(), #rearrange segmentation label to 1 to 46 62 | trans.NumpyType((np.float32, np.int16)), 63 | ]) 64 | 65 | train_set = datasets.JHUBrainDataset(glob.glob(train_dir + '*.pkl'), transforms=train_composed) 66 | val_set = datasets.JHUBrainInferDataset(glob.glob(val_dir + '*.pkl'), transforms=val_composed) 67 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) 68 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 69 | 70 | optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True) 71 | criterion = nn.MSELoss() 72 | criterions = [criterion] 73 | weights = [1] 74 | # prepare deformation loss 75 | criterions += 
[losses.Grad3d(penalty='l2')] 76 | weights += [0.02] 77 | best_mse = 0 78 | writer = SummaryWriter(log_dir='ViTVNet_log') 79 | for epoch in range(epoch_start, max_epoch): 80 | print('Training Starts') 81 | ''' 82 | Training 83 | ''' 84 | loss_all = AverageMeter() 85 | idx = 0 86 | for data in train_loader: 87 | idx += 1 88 | model.train() 89 | adjust_learning_rate(optimizer, epoch, max_epoch, lr) 90 | data = [t.cuda() for t in data] 91 | x = data[0] 92 | y = data[1] 93 | x_in = torch.cat((x,y), dim=1) 94 | output = model(x_in) 95 | loss = 0 96 | loss_vals = [] 97 | for n, loss_function in enumerate(criterions): 98 | curr_loss = loss_function(output[n], y) * weights[n] 99 | loss_vals.append(curr_loss) 100 | loss += curr_loss 101 | loss_all.update(loss.item(), y.numel()) 102 | # compute gradient and do SGD step 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | 107 | del x_in 108 | del output 109 | # flip fixed and moving images 110 | loss = 0 111 | x_in = torch.cat((y, x), dim=1) 112 | output = model(x_in) 113 | for n, loss_function in enumerate(criterions): 114 | curr_loss = loss_function(output[n], x) * weights[n] 115 | loss_vals[n] += curr_loss 116 | loss += curr_loss 117 | loss_all.update(loss.item(), y.numel()) 118 | # compute gradient and do SGD step 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader), loss.item(), loss_vals[0].item()/2, loss_vals[1].item()/2)) 124 | 125 | writer.add_scalar('Loss/train', loss_all.avg, epoch) 126 | print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg)) 127 | ''' 128 | Validation 129 | ''' 130 | eval_dsc = AverageMeter() 131 | with torch.no_grad(): 132 | for data in val_loader: 133 | model.eval() 134 | data = [t.cuda() for t in data] 135 | x = data[0] 136 | y = data[1] 137 | x_seg = data[2] 138 | y_seg = data[3] 139 | # x = x.squeeze(0).permute(1, 0, 2, 3) 140 | # y = y.squeeze(0).permute(1, 0, 2, 3) 141 | x_in = torch.cat((x, y), dim=1) 142 | output = model(x_in) 143 | def_out = reg_model([x_seg.cuda().float(), output[1].cuda()]) 144 | dsc = utils.dice_val(def_out.long(), y_seg.long(), 46) 145 | eval_dsc.update(dsc.item(), x.size(0)) 146 | print(eval_dsc.avg) 147 | best_mse = max(eval_dsc.avg, best_mse) 148 | save_checkpoint({ 149 | 'epoch': epoch + 1, 150 | 'state_dict': model.state_dict(), 151 | 'best_mse': best_mse, 152 | 'optimizer': optimizer.state_dict(), 153 | }, save_dir='experiments/'+save_dir, filename='dsc{:.3f}.pth.tar'.format(eval_dsc.avg)) 154 | writer.add_scalar('MSE/validate', eval_dsc.avg, epoch) 155 | plt.switch_backend('agg') 156 | pred_fig = comput_fig(def_out) 157 | x_fig = comput_fig(x_seg) 158 | tar_fig = comput_fig(y_seg) 159 | writer.add_figure('input', x_fig, epoch) 160 | plt.close(x_fig) 161 | writer.add_figure('ground truth', tar_fig, epoch) 162 | plt.close(tar_fig) 163 | writer.add_figure('prediction', pred_fig, epoch) 164 | plt.close(pred_fig) 165 | loss_all.reset() 166 | writer.close() 167 | 168 | def comput_fig(img): 169 | img = img.detach().cpu().numpy()[0, 0, 48:64, :, :] 170 | fig = plt.figure(figsize=(12,12), dpi=180) 171 | for i in range(img.shape[0]): 172 | plt.subplot(4, 4, i + 1) 173 | plt.axis('off') 174 | plt.imshow(img[i, :, :], cmap='gray') 175 | fig.subplots_adjust(wspace=0, hspace=0) 176 | return fig 177 | 178 | def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9): 179 | for param_group in optimizer.param_groups: 180 | 
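# Polynomial ("poly") learning-rate decay: the assignment below sets
#   lr = INIT_LR * (1 - epoch / MAX_EPOCHES) ** power
# (power defaults to 0.9), rounded to 8 decimal places, for every parameter group.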
param_group['lr'] = round(INIT_LR * np.power( 1 - (epoch) / MAX_EPOCHES ,power),8) 181 | 182 | 183 | def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8): 184 | torch.save(state, save_dir+filename) 185 | model_lists = natsorted(glob.glob(save_dir + '*')) 186 | while len(model_lists) > max_model_num: 187 | os.remove(model_lists[0]) 188 | model_lists = natsorted(glob.glob(save_dir + '*')) 189 | 190 | if __name__ == '__main__': 191 | ''' 192 | GPU configuration 193 | ''' 194 | GPU_iden = 0 195 | GPU_num = torch.cuda.device_count() 196 | print('Number of GPU: ' + str(GPU_num)) 197 | for GPU_idx in range(GPU_num): 198 | GPU_name = torch.cuda.get_device_name(GPU_idx) 199 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name) 200 | torch.cuda.set_device(GPU_iden) 201 | GPU_avai = torch.cuda.is_available() 202 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden)) 203 | print('If the GPU is available? ' + str(GPU_avai)) 204 | main() -------------------------------------------------------------------------------- /ViT-V-Net/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch, sys 5 | from torch import nn 6 | import pystrum.pynd.ndutils as nd 7 | 8 | def sliding_predict(model, image, tile_size, n_dims, overlap=1/2, flip=False): 9 | image_size = image.shape 10 | stride_x = math.ceil(tile_size[0] * (1 - overlap)) 11 | stride_y = math.ceil(tile_size[1] * (1 - overlap)) 12 | stride_z = math.ceil(tile_size[2] * (1 - overlap)) 13 | num_rows = int(math.ceil((image_size[2] - tile_size[0]) / stride_x) + 1) 14 | num_cols = int(math.ceil((image_size[3] - tile_size[1]) / stride_y) + 1) 15 | num_slcs = int(math.ceil((image_size[4] - tile_size[2]) / stride_z) + 1) 16 | total_predictions = torch.zeros((1, n_dims, image_size[2], image_size[3], image_size[4])).cuda() 17 | count_predictions = torch.zeros((image_size[2], image_size[3], image_size[4])).cuda() 18 | tile_counter = 0 19 | print(num_rows) 20 | for row in range(num_rows): 21 | for col in range(num_cols): 22 | for slc in range(num_slcs): 23 | x_min, y_min, z_min = int(row * stride_x), int(col * stride_y), int(slc * stride_z) 24 | x_max = x_min + tile_size[0] 25 | y_max = y_min + tile_size[1] 26 | z_max = z_min + tile_size[2] 27 | if x_max > image_size[2]: 28 | x_min = image_size[2] - stride_x 29 | x_max = image_size[2] 30 | if y_max > image_size[3]: 31 | y_min = image_size[3] - stride_y 32 | y_max = image_size[3] 33 | if z_max > image_size[4]: 34 | z_min = image_size[4] - stride_z 35 | y_max = image_size[4] 36 | img = image[:, :, x_min:x_max, y_min:y_max, z_min:z_max] 37 | padded_img = pad_image(img, tile_size) 38 | #print(padded_img.shape) 39 | 40 | tile_counter += 1 41 | padded_prediction = model(padded_img)[1] 42 | if flip: 43 | for dim in [-1, -2, -3]: 44 | fliped_img = padded_img.flip(dim) 45 | fliped_predictions = model(fliped_img)[1] 46 | padded_prediction = (fliped_predictions.flip(dim) + padded_prediction) 47 | padded_prediction = padded_prediction/4 48 | predictions = padded_prediction[:, :, :img.shape[2], :img.shape[3], :img.shape[4]] 49 | count_predictions[x_min:x_max, y_min:y_max, z_min:z_max] += 1 50 | total_predictions[:, :, x_min:x_max, y_min:y_max, z_min:z_max] += predictions.cuda()#.data.cpu().numpy() 51 | total_predictions /= count_predictions 52 | return total_predictions 53 | 54 | def pad_image(img, target_size): 55 | rows_to_pad = max(target_size[0] - 
img.shape[2], 0) 56 | cols_to_pad = max(target_size[1] - img.shape[3], 0) 57 | slcs_to_pad = max(target_size[2] - img.shape[4], 0) 58 | padded_img = F.pad(img, (0, slcs_to_pad, 0, cols_to_pad, 0, rows_to_pad), "constant", 0) 59 | return padded_img 60 | 61 | class SpatialTransformer(nn.Module): 62 | """ 63 | N-D Spatial Transformer 64 | """ 65 | 66 | def __init__(self, size, mode='bilinear'): 67 | super().__init__() 68 | 69 | self.mode = mode 70 | 71 | # create sampling grid 72 | vectors = [torch.arange(0, s) for s in size] 73 | grids = torch.meshgrid(vectors) 74 | grid = torch.stack(grids) 75 | grid = torch.unsqueeze(grid, 0) 76 | grid = grid.type(torch.FloatTensor).cuda() 77 | 78 | # registering the grid as a buffer cleanly moves it to the GPU, but it also 79 | # adds it to the state dict. this is annoying since everything in the state dict 80 | # is included when saving weights to disk, so the model files are way bigger 81 | # than they need to be. so far, there does not appear to be an elegant solution. 82 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict 83 | self.register_buffer('grid', grid) 84 | 85 | def forward(self, src, flow): 86 | # new locations 87 | new_locs = self.grid + flow 88 | shape = flow.shape[2:] 89 | 90 | # need to normalize grid values to [-1, 1] for resampler 91 | for i in range(len(shape)): 92 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) 93 | 94 | # move channels dim to last position 95 | # also not sure why, but the channels need to be reversed 96 | if len(shape) == 2: 97 | new_locs = new_locs.permute(0, 2, 3, 1) 98 | new_locs = new_locs[..., [1, 0]] 99 | elif len(shape) == 3: 100 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 101 | new_locs = new_locs[..., [2, 1, 0]] 102 | 103 | return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode) 104 | 105 | class register_model(nn.Module): 106 | def __init__(self, img_size=(64, 256, 256), mode='bilinear'): 107 | super(register_model, self).__init__() 108 | self.spatial_trans = SpatialTransformer(img_size, mode) 109 | 110 | def forward(self, x): 111 | img = x[0].cuda() 112 | flow = x[1].cuda() 113 | out = self.spatial_trans(img, flow) 114 | return out 115 | 116 | def dice_val(y_pred, y_true, num_clus): 117 | y_pred = nn.functional.one_hot(y_pred, num_classes=num_clus) 118 | y_pred = torch.squeeze(y_pred, 1) 119 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 120 | y_true = nn.functional.one_hot(y_true, num_classes=num_clus) 121 | y_true = torch.squeeze(y_true, 1) 122 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 123 | intersection = y_pred * y_true 124 | intersection = intersection.sum(dim=[2, 3, 4]) 125 | union = y_pred.sum(dim=[2, 3, 4]) + y_true.sum(dim=[2, 3, 4]) 126 | dsc = (2.*intersection) / (union + 1e-5) 127 | return torch.mean(torch.mean(dsc, dim=1)) 128 | 129 | def jacobian_determinant(disp): 130 | """ 131 | jacobian determinant of a displacement field. 132 | NB: to compute the spatial gradients, we use np.gradient. 
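Voxels where the returned determinant is non-positive indicate local folding (loss of invertibility) of the deformation, a common registration-quality check.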
133 | Parameters: 134 | disp: 3D displacement field of size [nb_dims, *vol_shape] 135 | Returns: 136 | jacobian determinant (matrix) 137 | """ 138 | 139 | # check inputs 140 | volshape = disp.shape[1:] 141 | nb_dims = len(volshape) 142 | assert len(volshape) in (2, 3), 'flow has to be 2D or 3D' 143 | 144 | # compute grid 145 | grid_lst = nd.volsize2ndgrid(volshape) 146 | grid = np.stack(grid_lst, 0) 147 | 148 | # compute gradients 149 | [xFX, xFY, xFZ] = np.gradient(grid[0] - disp[0]) 150 | [yFX, yFY, yFZ] = np.gradient(grid[1] - disp[1]) 151 | [zFX, zFY, zFZ] = np.gradient(grid[2] - disp[2]) 152 | 153 | jac_det = np.zeros(grid[0].shape) 154 | for i in range(grid.shape[1]): 155 | for j in range(grid.shape[2]): 156 | for k in range(grid.shape[3]): 157 | jac_mij = [[xFX[i, j, k], xFY[i, j, k], xFZ[i, j, k]], [yFX[i, j, k], yFY[i, j, k], yFZ[i, j, k]], [zFX[i, j, k], zFY[i, j, k], zFZ[i, j, k]]] 158 | jac_det[i, j, k] = np.linalg.det(jac_mij) 159 | return jac_det 160 | 161 | 162 | import re 163 | def process_label(): 164 | #process labeling information for FreeSurfer 165 | seg_table = [0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 166 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 167 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255] 168 | 169 | 170 | file1 = open('label_info.txt', 'r') 171 | Lines = file1.readlines() 172 | dict = {} 173 | seg_i = 0 174 | seg_look_up = [] 175 | for seg_label in seg_table: 176 | for line in Lines: 177 | line = re.sub(' +', ' ',line).split(' ') 178 | try: 179 | int(line[0]) 180 | except: 181 | continue 182 | if int(line[0]) == seg_label: 183 | seg_look_up.append([seg_i, int(line[0]), line[1]]) 184 | dict[seg_i] = line[1] 185 | seg_i += 1 186 | return dict 187 | 188 | def write2csv(line, name): 189 | with open(name+'.csv', 'a') as file: 190 | file.write(line) 191 | file.write('\n') 192 | 193 | def dice_val_substruct(y_pred, y_true, std_idx): 194 | with torch.no_grad(): 195 | y_pred = nn.functional.one_hot(y_pred, num_classes=46) 196 | y_pred = torch.squeeze(y_pred, 1) 197 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 198 | y_true = nn.functional.one_hot(y_true, num_classes=46) 199 | y_true = torch.squeeze(y_true, 1) 200 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 201 | y_pred = y_pred.detach().cpu().numpy() 202 | y_true = y_true.detach().cpu().numpy() 203 | 204 | line = 'p_{}'.format(std_idx) 205 | for i in range(46): 206 | pred_clus = y_pred[0, i, ...] 207 | true_clus = y_true[0, i, ...] 
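# Per-structure Dice for FreeSurfer label i: 2*|pred ∩ true| / (|pred| + |true|),
# with a small epsilon added to the denominator to avoid division by zero when a
# structure is absent from both segmentations.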
208 | intersection = pred_clus * true_clus 209 | intersection = intersection.sum() 210 | union = pred_clus.sum() + true_clus.sum() 211 | dsc = (2.*intersection) / (union + 1e-5) 212 | line = line+','+str(dsc) 213 | return line 214 | 215 | -------------------------------------------------------------------------------- /figures/ViTVNet_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/ViTVNet_res.jpg -------------------------------------------------------------------------------- /figures/dice_details_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/dice_details_.jpg -------------------------------------------------------------------------------- /figures/net_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/net_arch.jpg -------------------------------------------------------------------------------- /figures/trans_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch/a5096d918a88f4fa7492f17cbcbaeb195eb51ced/figures/trans_arch.jpg --------------------------------------------------------------------------------
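For orientation, the following is a minimal sketch of how the pieces above fit together at inference time. It mirrors the calls made in train.py (model construction, channel-wise concatenation of the moving and fixed volumes, MSE + Grad3d loss with weight 0.02, warping the moving segmentation with register_model, and Dice evaluation over 46 labels), but the random tensors are purely illustrative and this snippet is not part of the repository; a CUDA-capable GPU is assumed because several components call .cuda() internally.

```python
import torch
from models import ViTVNet, CONFIGS           # ViT-V-Net/models.py
from losses import Grad3d
from utils import register_model, dice_val

config = CONFIGS['ViT-V-Net']
img_size = (160, 192, 224)                    # volume size used in train.py
model = ViTVNet(config, img_size=img_size).cuda()
reg_model = register_model(img_size, 'nearest').cuda()

# illustrative random inputs; real usage loads moving/fixed MRI volumes scaled to [0, 1]
moving = torch.rand(1, 1, *img_size).cuda()
fixed = torch.rand(1, 1, *img_size).cuda()
moving_seg = torch.randint(0, 46, (1, 1, *img_size)).cuda()
fixed_seg = torch.randint(0, 46, (1, 1, *img_size)).cuda()

with torch.no_grad():
    warped, flow = model(torch.cat((moving, fixed), dim=1))   # (B, 2, D, H, W) input
    sim_loss = torch.nn.MSELoss()(warped, fixed)              # image-similarity term
    reg_loss = Grad3d(penalty='l2')(flow, fixed) * 0.02       # smoothness term (weight from train.py)
    warped_seg = reg_model([moving_seg.float(), flow])        # warp labels with nearest-neighbour sampling
    dsc = dice_val(warped_seg.long(), fixed_seg.long(), 46)   # mean Dice over 46 FreeSurfer labels
print(float(sim_loss), float(reg_loss), float(dsc))
```

At this volume size the forward pass and the one-hot Dice computation need a fairly large GPU; train.py wraps the same calls in a training loop that additionally swaps the roles of the moving and fixed images in each iteration.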