├── pytorch_ssim ├── __pycache__ │ └── __init__.cpython-35.pyc └── __init__.py ├── LICENSE ├── ATLAS_dataset.py ├── ADNI_dataset.py ├── Model_WGAN.py ├── BRATS_dataset.py ├── README.md ├── Model_alphaWGAN.py ├── Model_alphaGAN.py ├── Model_VAEGAN.py ├── WGAN_ADNI_train.ipynb ├── VAEGAN_ADNI_train.ipynb ├── Alpha_GAN_ADNI_train.ipynb ├── Alpha_WGAN_ADNI_train.ipynb └── Test.ipynb /pytorch_ssim/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyclomon/3dbraingen/HEAD/pytorch_ssim/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 cyclomon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ATLAS_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | from torch.utils.data.dataset import Dataset 5 | import os 6 | from torchvision import transforms 7 | from skimage.transform import resize 8 | import nibabel as nib 9 | from skimage import exposure 10 | 11 | class ATLASdataset(Dataset): 12 | def __init__(self,augmentation=True): 13 | list_path = [] 14 | for i in range(9): 15 | root = '../ATLAS_R1.1/Site'+str(i+1) 16 | 17 | list_img = os.listdir(root) 18 | for s in range(len(list_img)): 19 | list_path.append(os.path.join(root,list_img[s])) 20 | 21 | list_path.sort() 22 | self.augmentation= augmentation 23 | self.imglist = list_path 24 | 25 | def __len__(self): 26 | return len(self.imglist) 27 | 28 | def __getitem__(self, index): 29 | path = os.path.join(self.imglist[index],'t01') 30 | tempimg = nib.load(os.path.join(path,'T1w_p.nii')) 31 | B = np.flip(tempimg.get_data(),1) 32 | sp_size = 64 33 | img = resize(B, (sp_size,sp_size,sp_size), mode='constant') 34 | img = 1.0*img 35 | img = (img-np.min(img))/(np.max(img)-np.min(img)) 36 | 37 | if self.augmentation: 38 | random_n = torch.rand(1) 39 | if random_n[0] > 0.5: 40 | img = np.flip(img,0) 41 | 42 | img = np.ascontiguousarray(img,dtype=np.float32) 43 | 44 | imageout = torch.from_numpy(img).float().view(1,sp_size,sp_size,sp_size) 45 | imageout = 2*imageout-1 46 | 47 | return imageout 48 | 49 | 50 | -------------------------------------------------------------------------------- /ADNI_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | from torch.utils.data.dataset import Dataset 5 | import os 6 | from torchvision import transforms 7 | from skimage.transform import resize 8 | from nilearn import surface 9 | import nibabel as nib 10 | 11 | class ADNIdataset(Dataset): 12 | def __init__(self, root='../ADNI', augmentation=False): 13 | self.root = root 14 | self.basis = 'FreeSurfer_Cross-Sectional_Processing_brainmask' 15 | self.augmentation = augmentation 16 | f = open('CN_list.csv','r') 17 | rdr = csv.reader(f) 18 | 19 | name = [] 20 | labels = [] 21 | date = [] 22 | for line in rdr: 23 | [month,day,year] = line[9].split('/') 24 | month = month.zfill(2) 25 | date.append(year+'-'+month+'-'+day) 26 | name.append(line[1]) 27 | 28 | name = np.asarray(name) 29 | date = np.asarray(date) 30 | 31 | self.name =name 32 | self.date =date 33 | def __len__(self): 34 | return len(self.name) 35 | 36 | def __getitem__(self, index): 37 | path = os.path.join(self.root,self.name[index],self.basis) 38 | files = os.listdir(path) 39 | for file in files: 40 | if file[:10] == self.date[index]: 41 | rname = file 42 | aname = os.listdir(os.path.join(path,rname))[0] 43 | path = os.path.join(path,rname,aname,'mri') 44 | img = nib.load(os.path.join(path,'image.nii')) 45 | 46 | img = np.swapaxes(img.get_data(),1,2) 47 | img = np.flip(img,1) 48 | img = np.flip(img,2) 49 | sp_size = 64 50 | img = resize(img, (sp_size,sp_size,sp_size), mode='constant') 51 | if self.augmentation: 52 | random_n = torch.rand(1) 53 | random_i = 0.3*torch.rand(1)[0]+0.7 54 | if random_n[0] > 0.5: 55 | img = np.flip(img,0) 56 | 57 | img = img*random_i.data.cpu().numpy() 58 | 59 | imageout = torch.from_numpy(img).float().view(1,sp_size,sp_size,sp_size) 60 | imageout = imageout*2-1 61 | 62 | return imageout 63 | 64 | -------------------------------------------------------------------------------- /Model_WGAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, channel=512): 10 | super(Discriminator, self).__init__() 11 | self.channel = channel 12 | n_class = 1 13 | 14 | self.conv1 = nn.Conv3d(1, channel//8, kernel_size=4, stride=2, padding=1) 15 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) 16 | self.bn2 = nn.BatchNorm3d(channel//4) 17 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) 18 | self.bn3 = nn.BatchNorm3d(channel//2) 19 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) 20 | self.bn4 = nn.BatchNorm3d(channel) 21 | 22 | self.conv5 = nn.Conv3d(channel, n_class, kernel_size=4, stride=2, padding=1) 23 | 24 | def forward(self, x, _return_activations=False): 25 | h1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) 26 | h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2) 27 | h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2) 28 | h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2) 29 | h5 = self.conv5(h4) 30 | output = h5 31 | 32 | return output 33 | 34 | 35 | class Generator(nn.Module): 36 | def __init__(self, noise:int=1000, channel:int=64): 37 | super(Generator, self).__init__() 38 | _c = channel 39 | 40 | self.noise = noise 41 | self.fc = nn.Linear(1000,512*4*4*4) 42 | self.bn1 = nn.BatchNorm3d(_c*8) 43 | 44 | self.tp_conv2 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm3d(_c*4) 46 | 47 | self.tp_conv3 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=False) 48 | self.bn3 = nn.BatchNorm3d(_c*2) 49 | 50 | self.tp_conv4 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=False) 51 | self.bn4 = nn.BatchNorm3d(_c) 52 | 53 | self.tp_conv5 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=False) 54 | 55 | def forward(self, noise): 56 | noise = noise.view(-1, 1000) 57 | h = self.fc(noise) 58 | h = h.view(-1,512,4,4,4) 59 | h = F.relu(self.bn1(h)) 60 | 61 | h = F.upsample(h,scale_factor = 2) 62 | h = self.tp_conv2(h) 63 | h = F.relu(self.bn2(h)) 64 | 65 | h = F.upsample(h,scale_factor = 2) 66 | h = self.tp_conv3(h) 67 | h = F.relu(self.bn3(h)) 68 | 69 | h = F.upsample(h,scale_factor = 2) 70 | h = self.tp_conv4(h) 71 | h = F.relu(self.bn4(h)) 72 | 73 | h = F.upsample(h,scale_factor = 2) 74 | h = self.tp_conv5(h) 75 | 76 | h = F.tanh(h) 77 | 78 | return h -------------------------------------------------------------------------------- /BRATS_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | from torch.utils.data.dataset import Dataset 5 | import os 6 | from skimage.transform import resize 7 | from nilearn import surface 8 | import nibabel as nib 9 | from skimage import exposure 10 | 11 | class BRATSdataset(Dataset): 12 | def __init__(self, train=True, imgtype = 'flair', severity='HGG',is_flip=False,augmentation=True): 13 | self.augmentation = augmentation 14 | if train: 15 | self.root = '../Training_brats/' + severity 16 | else: 17 | self.root = '../Validation_brats' 18 | self.imgtype = imgtype 19 | list_img = os.listdir(self.root) 20 | list_img.sort() 21 | self.imglist = list_img 22 | self.is_flip = is_flip 23 | 24 | def __len__(self): 25 | return len(self.imglist) 26 | 27 | def __getitem__(self, index): 28 | 29 | path = os.path.join(self.root,self.imglist[index]) 30 | 31 | img = nib.load(os.path.join(path,self.imglist[index]+'_'+self.imgtype+'.nii.gz')) 32 | gt = nib.load(os.path.join(path,self.imglist[index])+'_'+'seg.nii.gz') 33 | 34 | A = np.zeros((240,240,166)) 35 | G = np.zeros((240,240,166)) 36 | A[:,:,11:] = img.get_data() 37 | G[:,:,11:] = gt.get_data() 38 | x=[] 39 | y=[] 40 | z=[] 41 | 42 | for i in range(240): 43 | if np.all(A[i,:,:] ==0): 44 | x.append(i) 45 | if np.all(A[:,i,:]==0): 46 | y.append(i) 47 | if i <155: 48 | if np.all(A[:,:,i]==0): 49 | z.append(i) 50 | 51 | xl,yl,zl = 0,0,0 52 | xh,yh,zh = 240,240,155 53 | for xn in x: 54 | if xn < 120: 55 | if xn> xl: 56 | xl = xn 57 | else: 58 | if xn yl: 63 | yl = yn 64 | else: 65 | if yn zl: 70 | zl = zn 71 | else: 72 | if zn 0.5: 89 | img = np.flip(img,0) 90 | 91 | img = 1.0*img 92 | img = exposure.rescale_intensity(img) 93 | img = (img-np.min(img))/(np.max(img)-np.min(img)) 94 | img = 2*img-1 95 | 96 | imageout = torch.from_numpy(img).float().view(1,sp_size,sp_size,sp_size) 97 | 98 | return imageout 99 | 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official Pytorch Implementation of "Generation of 3D Brain MRI Using Auto-Encoding Generative Adversarial Networks" (accepted by MICCAI 2019) 2 | 3 | This repository provides a PyTorch implementation of 3D brain Generation. It can successfully generates plausible 3-dimensional brain MRI with Generative Adversarial Networks. Trained models are also provided in this page. 4 | 5 | ## Paper 6 | "Generation of 3D Brain MRI Using Auto-Encoding Generative Adversarial Networks" 7 | 8 | The 22nd International Conference on Medical Image Computing and Computer Assisted Intervention(MICCAI 2019) 9 | : (https://arxiv.org/abs/1908.02498) 10 | 11 | ## Cite 12 | ``` 13 | @inproceedings{kwon2019generation, 14 | title={Generation of 3D brain MRI using auto-encoding generative adversarial networks}, 15 | author={Kwon, Gihyun and Han, Chihye and Kim, Dae-shik}, 16 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 17 | pages={118--126}, 18 | year={2019}, 19 | organization={Springer} 20 | } 21 | ``` 22 | 23 | ## Dependencies 24 | * [Python 3.5+](https://www.continuum.io/downloads) 25 | * [PyTorch 0.4.0+](http://pytorch.org/) 26 | * [Jupyter Notebook](https://jupyter.org/) 27 | * [Nilearn](https://nilearn.github.io/) 28 | * [Nibabel](https://nipy.org/nibabel/) 29 | 30 | We highly recommend you to use Jupyter Notebook for the better visualization! 31 | 32 | ## Dataset 33 | You can download the Normal MRI data in [Alzheimer's Disease Neuroimaging Initiative(ADNI)](http://adni.loni.usc.edu/) 34 | , Tumor MRI data in [BRATS2018](https://www.med.upenn.edu/sbia/brats2018/data.html) and Stroke MRI data in [Anatomical Tracings of Lesions After Stroke (ATLAS)](http://fcon_1000.projects.nitrc.org/indi/retro/atlas.html). 35 | 36 | We converted all the DICOM(.dcm) files of ADNI into Nifti(.nii) file format using [SPM12](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/) I/O tools. 37 | 38 | ADNI : Download Post-processed(processed with 'recon-all' command of [Freesurfer](https://surfer.nmr.mgh.harvard.edu/)) Structural images labeled as 'Control Normal'. 39 | 40 | BRATS : Download dataset from BRATS2018 website. 41 | 42 | ATLAS : Download dataset from ATLAS website. 43 | Obtain probability maps(masks) from the original .nii images with SPM12 'Segmentation' function. 44 | Extract Brain areas with multiplying masks(c1,c2,c3 / GM,WM,CSF) with original images. 45 | 46 | ## Training Details 47 | For each training, run 12,000 iterations (100 epochs in VAE-GAN) 48 | 49 | Each run takes ~12 hour with one NVIDIA TITAN X GPU. 50 | 51 | Run the Jupyter Notebook code for training (~train.ipynb) 52 | 53 | ## Test Details 54 | You can download our Pre-trained models in our [Google Drive](https://drive.google.com/open?id=1Q5kkI_GxCY066c9owqzFFjzB_iEFCefJ) 55 | 56 | Download the models and save them in the directory './checkpoint' 57 | Then you can run the test code ('Test.ipynb') 58 | 59 | Quantitative calculation (MS-SSIM / MMD score) & Image sampling is availble in the code. 60 | 61 | For the PCA visualization, please follow the PCA tutorial that Nilearn provides. 62 | 63 | ## Model Details 64 | You can get the detailed settings of used models in our model codes 65 | 66 | (Model_alphaGAN.py , Model_alphaWGAN.py , Model_VAEGAN.py, Model_WGAN.py) 67 | 68 | 69 | ## Details for Dataset 70 | If you have any question about data, feel free to e-mail me! 71 | 72 | cyclomon@kaist.ac.kr 73 | -------------------------------------------------------------------------------- /Model_alphaWGAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | 8 | #*********************************************** 9 | #Encoder and Discriminator has same architecture 10 | #*********************************************** 11 | class Discriminator(nn.Module): 12 | def __init__(self, channel=512,out_class=1,is_dis =True): 13 | super(Discriminator, self).__init__() 14 | self.is_dis=is_dis 15 | self.channel = channel 16 | n_class = out_class 17 | 18 | self.conv1 = nn.Conv3d(1, channel//8, kernel_size=4, stride=2, padding=1) 19 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) 20 | self.bn2 = nn.BatchNorm3d(channel//4) 21 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) 22 | self.bn3 = nn.BatchNorm3d(channel//2) 23 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) 24 | self.bn4 = nn.BatchNorm3d(channel) 25 | self.conv5 = nn.Conv3d(channel, n_class, kernel_size=4, stride=1, padding=0) 26 | 27 | def forward(self, x, _return_activations=False): 28 | h1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) 29 | h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2) 30 | h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2) 31 | h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2) 32 | h5 = self.conv5(h4) 33 | output = h5 34 | 35 | return output 36 | 37 | class Code_Discriminator(nn.Module): 38 | def __init__(self, code_size=100,num_units=750): 39 | super(Code_Discriminator, self).__init__() 40 | n_class = 1 41 | self.l1 = nn.Sequential(nn.Linear(code_size, num_units), 42 | nn.BatchNorm1d(num_units), 43 | nn.LeakyReLU(0.2,inplace=True)) 44 | self.l2 = nn.Sequential(nn.Linear(num_units, num_units), 45 | nn.BatchNorm1d(num_units), 46 | nn.LeakyReLU(0.2,inplace=True)) 47 | self.l3 = nn.Linear(num_units, 1) 48 | 49 | def forward(self, x): 50 | h1 = self.l1(x) 51 | h2 = self.l2(h1) 52 | h3 = self.l3(h2) 53 | output = h3 54 | 55 | return output 56 | 57 | class Generator(nn.Module): 58 | def __init__(self, noise:int=100, channel:int=64): 59 | super(Generator, self).__init__() 60 | _c = channel 61 | 62 | self.relu = nn.ReLU() 63 | self.noise = noise 64 | self.tp_conv1 = nn.ConvTranspose3d(noise, _c*8, kernel_size=4, stride=1, padding=0, bias=False) 65 | self.bn1 = nn.BatchNorm3d(_c*8) 66 | 67 | self.tp_conv2 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm3d(_c*4) 69 | 70 | self.tp_conv3 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn3 = nn.BatchNorm3d(_c*2) 72 | 73 | self.tp_conv4 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn4 = nn.BatchNorm3d(_c) 75 | 76 | self.tp_conv5 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=False) 77 | 78 | def forward(self, noise): 79 | 80 | noise = noise.view(-1,self.noise,1,1,1) 81 | h = self.tp_conv1(noise) 82 | h = self.relu(self.bn1(h)) 83 | 84 | h = F.upsample(h,scale_factor = 2) 85 | h = self.tp_conv2(h) 86 | h = self.relu(self.bn2(h)) 87 | 88 | h = F.upsample(h,scale_factor = 2) 89 | h = self.tp_conv3(h) 90 | h = self.relu(self.bn3(h)) 91 | 92 | h = F.upsample(h,scale_factor = 2) 93 | h = self.tp_conv4(h) 94 | h = self.relu(self.bn4(h)) 95 | 96 | h = F.upsample(h,scale_factor = 2) 97 | h = self.tp_conv5(h) 98 | 99 | h = F.tanh(h) 100 | 101 | return h -------------------------------------------------------------------------------- /Model_alphaGAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | 8 | #*********************************************** 9 | #Encoder and Discriminator has same architecture 10 | #*********************************************** 11 | class Discriminator(nn.Module): 12 | def __init__(self, channel=512,out_class=1,is_dis =True): 13 | super(Discriminator, self).__init__() 14 | self.is_dis=is_dis 15 | self.channel = channel 16 | n_class = out_class 17 | 18 | self.conv1 = nn.Conv3d(1, channel//8, kernel_size=4, stride=2, padding=1) 19 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) 20 | self.bn2 = nn.BatchNorm3d(channel//4) 21 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) 22 | self.bn3 = nn.BatchNorm3d(channel//2) 23 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) 24 | self.bn4 = nn.BatchNorm3d(channel) 25 | 26 | self.conv5 = nn.Conv3d(channel, n_class, kernel_size=4, stride=1, padding=0) 27 | 28 | def forward(self, x, _return_activations=False): 29 | h1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) 30 | h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2) 31 | h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2) 32 | h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2) 33 | h5 = self.conv5(h4) 34 | 35 | if self.is_dis: 36 | output = F.sigmoid(h5.view(h5.size()[0],-1)) 37 | else: 38 | output = h5.view(h5.size()[0],-1) 39 | 40 | return output 41 | 42 | class Code_Discriminator(nn.Module): 43 | def __init__(self, code_size=100,num_units=750): 44 | super(Code_Discriminator, self).__init__() 45 | n_class = 1 46 | self.l1 = nn.Sequential(nn.Linear(code_size, num_units), 47 | nn.BatchNorm1d(num_units), 48 | nn.LeakyReLU(0.2,inplace=True)) 49 | self.l2 = nn.Sequential(nn.Linear(num_units, num_units), 50 | nn.BatchNorm1d(num_units), 51 | nn.LeakyReLU(0.2,inplace=True)) 52 | self.l3 = nn.Linear(num_units, 1) 53 | 54 | def forward(self, x): 55 | h1 = self.l1(x) 56 | h2 = self.l2(h1) 57 | h3 = self.l3(h2) 58 | output = F.sigmoid(h3) 59 | 60 | return output 61 | 62 | class Generator(nn.Module): 63 | def __init__(self, noise:int=100, channel:int=64): 64 | super(Generator, self).__init__() 65 | _c = channel 66 | 67 | self.relu = nn.ReLU() 68 | self.noise = noise 69 | self.tp_conv1 = nn.ConvTranspose3d(noise, _c*8, kernel_size=4, stride=1, padding=0, bias=False) 70 | self.bn1 = nn.BatchNorm3d(_c*8) 71 | 72 | self.tp_conv2 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm3d(_c*4) 74 | 75 | self.tp_conv3 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn3 = nn.BatchNorm3d(_c*2) 77 | 78 | self.tp_conv4 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=False) 79 | self.bn4 = nn.BatchNorm3d(_c) 80 | 81 | self.tp_conv5 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=False) 82 | 83 | def forward(self, noise): 84 | 85 | noise = noise.view(-1,self.noise,1,1,1) 86 | h = self.tp_conv1(noise) 87 | h = self.relu(self.bn1(h)) 88 | 89 | h = F.upsample(h,scale_factor = 2) 90 | h = self.tp_conv2(h) 91 | h = self.relu(self.bn2(h)) 92 | 93 | h = F.upsample(h,scale_factor = 2) 94 | h = self.tp_conv3(h) 95 | h = self.relu(self.bn3(h)) 96 | 97 | h = F.upsample(h,scale_factor = 2) 98 | h = self.tp_conv4(h) 99 | h = self.relu(self.bn4(h)) 100 | 101 | h = F.upsample(h,scale_factor = 2) 102 | h = self.tp_conv5(h) 103 | 104 | h = F.tanh(h) 105 | 106 | return h -------------------------------------------------------------------------------- /Model_VAEGAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, channel=512,out_class=1): 10 | super(Discriminator, self).__init__() 11 | 12 | self.channel = channel 13 | n_class = out_class 14 | 15 | self.conv1 = nn.Conv3d(1, channel//8, kernel_size=4, stride=2, padding=1) 16 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) 17 | self.bn2 = nn.BatchNorm3d(channel//4) 18 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) 19 | self.bn3 = nn.BatchNorm3d(channel//2) 20 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) 21 | self.bn4 = nn.BatchNorm3d(channel) 22 | self.conv5 = nn.Conv3d(channel, n_class, kernel_size=4, stride=1, padding=0) 23 | 24 | 25 | def forward(self, x): 26 | batch_size = x.size()[0] 27 | h1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) 28 | h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2) 29 | h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2) 30 | h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2) 31 | h5 = self.conv5(h4) 32 | output = F.sigmoid(h5.view(h5.size()[0],-1)) 33 | return output 34 | 35 | class Encoder(nn.Module): 36 | def __init__(self, channel=512,out_class=1): 37 | super(Encoder, self).__init__() 38 | 39 | self.channel = channel 40 | n_class = out_class 41 | 42 | self.conv1 = nn.Conv3d(1, channel//8, kernel_size=4, stride=2, padding=1) 43 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) 44 | self.bn2 = nn.BatchNorm3d(channel//4) 45 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) 46 | self.bn3 = nn.BatchNorm3d(channel//2) 47 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) 48 | self.bn4 = nn.BatchNorm3d(channel) 49 | 50 | self.mean = nn.Sequential( 51 | nn.Linear(32768, 2048), 52 | nn.BatchNorm1d(2048), 53 | nn.ReLU(), 54 | nn.Linear(2048, 1000)) 55 | self.logvar = nn.Sequential( 56 | nn.Linear(32768, 2048), 57 | nn.BatchNorm1d(2048), 58 | nn.ReLU(), 59 | nn.Linear(2048, 1000)) 60 | 61 | def forward(self, x, _return_activations=False): 62 | batch_size = x.size()[0] 63 | h1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) 64 | h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2) 65 | h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2) 66 | h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2) 67 | 68 | mean = self.mean(h4.view(batch_size,-1)) 69 | logvar = self.logvar(h4.view(batch_size,-1)) 70 | 71 | std = logvar.mul(0.5).exp_() 72 | reparametrized_noise = Variable(torch.randn((batch_size, 1000))).cuda() 73 | reparametrized_noise = mean + std * reparametrized_noise 74 | return mean,logvar ,reparametrized_noise 75 | 76 | class Generator(nn.Module): 77 | def __init__(self, noise:int=100, channel:int=64): 78 | super(Generator, self).__init__() 79 | _c = channel 80 | 81 | self.noise = noise 82 | self.fc = nn.Linear(1000,512*4*4*4) 83 | self.bn1 = nn.BatchNorm3d(_c*8) 84 | 85 | self.tp_conv2 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=False) 86 | self.bn2 = nn.BatchNorm3d(_c*4) 87 | 88 | self.tp_conv3 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=False) 89 | self.bn3 = nn.BatchNorm3d(_c*2) 90 | 91 | self.tp_conv4 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn4 = nn.BatchNorm3d(_c) 93 | 94 | self.tp_conv5 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=False) 95 | 96 | def forward(self, noise): 97 | noise = noise.view(-1, 1000) 98 | h = self.fc(noise) 99 | h = h.view(-1,512,4,4,4) 100 | h = F.relu(self.bn1(h)) 101 | 102 | h = F.upsample(h,scale_factor = 2) 103 | h = self.tp_conv2(h) 104 | h = F.relu(self.bn2(h)) 105 | 106 | h = F.upsample(h,scale_factor = 2) 107 | h = self.tp_conv3(h) 108 | h = F.relu(self.bn3(h)) 109 | 110 | h = F.upsample(h,scale_factor = 2) 111 | h = self.tp_conv4(h) 112 | h = F.relu(self.bn4(h)) 113 | 114 | h = F.upsample(h,scale_factor = 2) 115 | h = self.tp_conv5(h) 116 | 117 | h = F.tanh(h) 118 | 119 | return h 120 | -------------------------------------------------------------------------------- /WGAN_ADNI_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pylab inline\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "import os\n", 13 | "\n", 14 | "from torch import nn\n", 15 | "from torch import optim\n", 16 | "from torch.nn import functional as F\n", 17 | "from torch import autograd\n", 18 | "from torch.autograd import Variable\n", 19 | "import nibabel as nib\n", 20 | "from torch.utils.data.dataset import Dataset\n", 21 | "from torch.utils.data import dataloader\n", 22 | "from nilearn import plotting\n", 23 | "from ADNI_dataset import *\n", 24 | "from BRATS_dataset import *\n", 25 | "from ATLAS_dataset import *\n", 26 | "from Model_WGAN import *" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Configuration" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "BATCH_SIZE=4\n", 43 | "max_epoch = 100\n", 44 | "lr = 0.0001\n", 45 | "gpu = True\n", 46 | "workers = 4\n", 47 | "\n", 48 | "LAMBDA= 10\n", 49 | "#setting latent variable sizes\n", 50 | "latent_dim = 1000" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "trainset = ADNIdataset(augmentation=True)\n", 60 | "train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 61 | " shuffle=True,num_workers=workers)\n", 62 | "if Use_BRATS:\n", 63 | " #'flair' or 't2' or 't1ce'\n", 64 | " trainset = BRATSdataset(imgtype='flair')\n", 65 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE, shuffle=True,\n", 66 | " num_workers=workers)\n", 67 | "if Use_ATLAS:\n", 68 | " trainset = ATLASdataset(augmentation=True)\n", 69 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 70 | " shuffle=True,num_workers=workers)\n", 71 | "\n", 72 | "def inf_train_gen(data_loader):\n", 73 | " while True:\n", 74 | " for _,images in enumerate(data_loader):\n", 75 | " yield images\n" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "D = Discriminator()\n", 85 | "G = Generator(noise = latent_dim)\n", 86 | "\n", 87 | "g_optimizer = optim.Adam(G.parameters(), lr=0.0002)\n", 88 | "d_optimizer = optim.Adam(D.parameters(), lr=0.0002)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "def calc_gradient_penalty(netD, real_data, fake_data): \n", 98 | " alpha = torch.rand(real_data.size(0),1,1,1,1)\n", 99 | " alpha = alpha.expand(real_data.size())\n", 100 | " \n", 101 | " alpha = alpha.cuda()\n", 102 | "\n", 103 | " interpolates = alpha * real_data + ((1 - alpha) * fake_data)\n", 104 | "\n", 105 | " interpolates = interpolates.cuda()\n", 106 | " interpolates = Variable(interpolates, requires_grad=True)\n", 107 | "\n", 108 | " disc_interpolates = netD(interpolates)\n", 109 | "\n", 110 | " gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,\n", 111 | " grad_outputs=torch.ones(disc_interpolates.size()).cuda(),\n", 112 | " create_graph=True, retain_graph=True, only_inputs=True)[0]\n", 113 | "\n", 114 | " gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA\n", 115 | " return gradient_penalty" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "# Training" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "real_y = Variable(torch.ones((BATCH_SIZE, 1)).cuda())\n", 132 | "fake_y = Variable(torch.zeros((BATCH_SIZE, 1)).cuda())\n", 133 | "loss_f = nn.BCELoss()\n", 134 | "\n", 135 | "d_real_losses = list()\n", 136 | "d_fake_losses = list()\n", 137 | "d_losses = list()\n", 138 | "g_losses = list()\n", 139 | "divergences = list()" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "scrolled": true 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "TOTAL_ITER = 200000\n", 151 | "gen_load = inf_train_gen(train_loader)\n", 152 | "for iteration in range(TOTAL_ITER):\n", 153 | " ###############################################\n", 154 | " # Train D \n", 155 | " ###############################################\n", 156 | " for p in D.parameters(): \n", 157 | " p.requires_grad = True \n", 158 | "\n", 159 | " real_images = gen_load.__next__()\n", 160 | " D.zero_grad()\n", 161 | " real_images = Variable(real_images).cuda()\n", 162 | "\n", 163 | " _batch_size = real_images.size(0)\n", 164 | "\n", 165 | "\n", 166 | " y_real_pred = D(real_images)\n", 167 | "\n", 168 | " d_real_loss = y_real_pred.mean()\n", 169 | " \n", 170 | " noise = Variable(torch.randn((_batch_size, latent_dim, 1, 1, 1)),volatile=True).cuda()\n", 171 | " fake_images = G(noise)\n", 172 | " y_fake_pred = D(fake_images.detach())\n", 173 | "\n", 174 | " d_fake_loss = y_fake_pred.mean()\n", 175 | "\n", 176 | " gradient_penalty = calc_gradient_penalty(D,real_images.data, fake_images.data)\n", 177 | " \n", 178 | " d_loss = - d_real_loss + d_fake_loss +gradient_penalty\n", 179 | " d_loss.backward()\n", 180 | " Wasserstein_D = d_real_loss - d_fake_loss\n", 181 | "\n", 182 | " d_optimizer.step()\n", 183 | "\n", 184 | " ###############################################\n", 185 | " # Train G \n", 186 | " ###############################################\n", 187 | " for p in D.parameters():\n", 188 | " p.requires_grad = False\n", 189 | " \n", 190 | " for iters in range(5):\n", 191 | " G.zero_grad()\n", 192 | " noise = Variable(torch.randn((_batch_size, latent_dim, 1, 1 ,1)).cuda())\n", 193 | " fake_image =G(noise)\n", 194 | " y_fake_g = D(fake_image)\n", 195 | "\n", 196 | " g_loss = -y_fake_g.mean()\n", 197 | "\n", 198 | " g_loss.backward()\n", 199 | " g_optimizer.step()\n", 200 | "\n", 201 | " ###############################################\n", 202 | " # Visualization\n", 203 | " ###############################################\n", 204 | " if iteration%10 == 0:\n", 205 | " d_real_losses.append(d_real_loss.data[0])\n", 206 | " d_fake_losses.append(d_fake_loss.data[0])\n", 207 | " d_losses.append(d_loss.data[0])\n", 208 | " g_losses.append(g_loss.data.cpu().numpy())\n", 209 | "\n", 210 | " print('[{}/{}]'.format(iteration,TOTAL_ITER),\n", 211 | " 'D: {:<8.3}'.format(d_loss.data[0].cpu().numpy()), \n", 212 | " 'D_real: {:<8.3}'.format(d_real_loss.data[0].cpu().numpy()),\n", 213 | " 'D_fake: {:<8.3}'.format(d_fake_loss.data[0].cpu().numpy()), \n", 214 | " 'G: {:<8.3}'.format(g_loss.data[0].cpu().numpy()))\n", 215 | "\n", 216 | " featmask = np.squeeze((0.5*fake_image+0.5)[0].data.cpu().numpy())\n", 217 | "\n", 218 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 219 | " plotting.plot_img(featmask,title=\"FAKE\")\n", 220 | " plotting.show()\n", 221 | " \n", 222 | " if (iteration+1)%500 ==0:\n", 223 | " torch.save(G.state_dict(),'./checkpoint/G_W_iter'+str(iteration+1)+'.pth')\n", 224 | " torch.save(D.state_dict(),'./checkpoint/D_W_iter'+str(iteration+1)+'.pth')" 225 | ] 226 | } 227 | ], 228 | "metadata": { 229 | "kernelspec": { 230 | "display_name": "Python 3", 231 | "language": "python", 232 | "name": "python3" 233 | }, 234 | "language_info": { 235 | "codemirror_mode": { 236 | "name": "ipython", 237 | "version": 3 238 | }, 239 | "file_extension": ".py", 240 | "mimetype": "text/x-python", 241 | "name": "python", 242 | "nbconvert_exporter": "python", 243 | "pygments_lexer": "ipython3", 244 | "version": "3.5.6" 245 | } 246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 2 249 | } 250 | -------------------------------------------------------------------------------- /VAEGAN_ADNI_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pylab inline\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "import os\n", 13 | "\n", 14 | "from torch import nn\n", 15 | "from torch import optim\n", 16 | "from torch.nn import functional as F\n", 17 | "from torch import autograd\n", 18 | "from torch.autograd import Variable\n", 19 | "import nibabel as nib\n", 20 | "from torch.utils.data.dataset import Dataset\n", 21 | "from torch.utils.data import dataloader\n", 22 | "from skimage.transform import resize\n", 23 | "from nilearn import plotting\n", 24 | "from ADNI_dataset import *\n", 25 | "from BRATS_dataset import *\n", 26 | "from ATLAS_dataset import *\n", 27 | "from Model_VAEGAN import *" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# Configuration" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "BATCH_SIZE=4\n", 44 | "max_epoch = 100\n", 45 | "gpu = True\n", 46 | "workers = 4\n", 47 | "\n", 48 | "reg = 5e-10\n", 49 | "\n", 50 | "gamma = 20\n", 51 | "beta = 10\n", 52 | "\n", 53 | "Use_BRATS=False\n", 54 | "Use_ATLAS = False\n", 55 | "\n", 56 | "#setting latent variable sizes\n", 57 | "latent_dim = 1000" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "trainset = ADNIdataset(augmentation=True)\n", 67 | "train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 68 | " shuffle=True,num_workers=workers)\n", 69 | "if Use_BRATS:\n", 70 | " #'flair' or 't2' or 't1ce'\n", 71 | " trainset = BRATSdataset(imgtype='flair')\n", 72 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE, shuffle=True,\n", 73 | " num_workers=workers)\n", 74 | "if Use_ATLAS:\n", 75 | " trainset = ATLASdataset(augmentation=True)\n", 76 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 77 | " shuffle=True,num_workers=workers)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "G = Generator(noise = latent_dim)\n", 87 | "D = Discriminator()\n", 88 | "E = Encoder()\n", 89 | "\n", 90 | "G.cuda()\n", 91 | "D.cuda()\n", 92 | "E.cuda()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "g_optimizer = optim.Adam(G.parameters(), lr=0.0001)\n", 102 | "d_optimizer = optim.Adam(D.parameters(), lr=0.0001)\n", 103 | "e_optimizer = optim.Adam(E.parameters(), lr = 0.0001)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "# Training" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "N_EPOCH = 100\n", 120 | "\n", 121 | "real_y = Variable(torch.ones((BATCH_SIZE, 1)).cuda())\n", 122 | "fake_y = Variable(torch.zeros((BATCH_SIZE, 1)).cuda())\n", 123 | "criterion_bce = nn.BCELoss()\n", 124 | "criterion_l1 = nn.L1Loss()\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "scrolled": true 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "for epoch in range(N_EPOCH):\n", 136 | " for step, real_images in enumerate(train_loader):\n", 137 | " _batch_size = real_images.size(0)\n", 138 | " real_images = Variable(real_images,requires_grad=False).cuda()\n", 139 | " z_rand = Variable(torch.randn((_batch_size, latent_dim)),requires_grad=False).cuda()\n", 140 | " mean,logvar,code = E(real_images)\n", 141 | " x_rec = G(code)\n", 142 | " x_rand = G(z_rand)\n", 143 | " ###############################################\n", 144 | " # Train D \n", 145 | " ###############################################\n", 146 | " d_optimizer.zero_grad()\n", 147 | " \n", 148 | " d_real_loss = criterion_bce(D(real_images),real_y[:_batch_size])\n", 149 | " d_recon_loss = criterion_bce(D(x_rec), fake_y[:_batch_size])\n", 150 | " d_fake_loss = criterion_bce(D(x_rand), fake_y[:_batch_size])\n", 151 | " \n", 152 | " dis_loss = d_recon_loss+d_real_loss + d_fake_loss\n", 153 | " dis_loss.backward(retain_graph=True)\n", 154 | " \n", 155 | " d_optimizer.step()\n", 156 | " \n", 157 | " ###############################################\n", 158 | " # Train G\n", 159 | " ###############################################\n", 160 | " g_optimizer.zero_grad()\n", 161 | " output = D(real_images)\n", 162 | " d_real_loss = criterion_bce(output,real_y[:_batch_size])\n", 163 | " output = D(x_rec)\n", 164 | " d_recon_loss = criterion_bce(output,fake_y[:_batch_size])\n", 165 | " output = D(x_rand)\n", 166 | " d_fake_loss = criterion_bce(output,fake_y[:_batch_size])\n", 167 | " \n", 168 | " d_img_loss = d_real_loss + d_recon_loss+ d_fake_loss\n", 169 | " gen_img_loss = -d_img_loss\n", 170 | " \n", 171 | " rec_loss = ((x_rec - real_images)**2).mean()\n", 172 | " \n", 173 | " err_dec = gamma* rec_loss + gen_img_loss\n", 174 | " \n", 175 | " err_dec.backward(retain_graph=True)\n", 176 | " g_optimizer.step()\n", 177 | " ###############################################\n", 178 | " # Train E\n", 179 | " ###############################################\n", 180 | " prior_loss = 1+logvar-mean.pow(2) - logvar.exp()\n", 181 | " prior_loss = (-0.5*torch.sum(prior_loss))/torch.numel(mean.data)\n", 182 | " err_enc = prior_loss + beta*rec_loss\n", 183 | " \n", 184 | " e_optimizer.zero_grad()\n", 185 | " err_enc.backward()\n", 186 | " e_optimizer.step()\n", 187 | " ###############################################\n", 188 | " # Visualization\n", 189 | " ###############################################\n", 190 | " \n", 191 | " if step % 10 == 0:\n", 192 | " print('[{}/{}]'.format(epoch,N_EPOCH),\n", 193 | " 'D: {:<8.3}'.format(dis_loss.data[0].cpu().numpy()), \n", 194 | " 'En: {:<8.3}'.format(err_enc.data[0].cpu().numpy()),\n", 195 | " 'De: {:<8.3}'.format(err_dec.data[0].cpu().numpy()) \n", 196 | " )\n", 197 | " \n", 198 | " featmask = np.squeeze((0.5*real_images[0]+0.5).data.cpu().numpy())\n", 199 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 200 | " plotting.plot_img(featmask,title=\"X_Real\")\n", 201 | " plotting.show()\n", 202 | " \n", 203 | " featmask = np.squeeze((0.5*x_rec[0]+0.5).data.cpu().numpy())\n", 204 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 205 | " plotting.plot_img(featmask,title=\"X_DEC\")\n", 206 | " plotting.show()\n", 207 | " \n", 208 | " featmask = np.squeeze((0.5*x_rand[0]+0.5).data.cpu().numpy())\n", 209 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 210 | " plotting.plot_img(featmask,title=\"X_rand\")\n", 211 | " plotting.show()\n", 212 | "\n", 213 | " torch.save(G.state_dict(),'./chechpoint/G_VG_ep_'+str(epoch+1)+'.pth')\n", 214 | " torch.save(D.state_dict(),'./chechpoint/D_VG_ep_'+str(epoch+1)+'.pth')\n", 215 | " torch.save(E.state_dict(),'./chechpoint/E_VG_ep_'+str(epoch+1)+'.pth')" 216 | ] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python 3", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.5.6" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | import scipy.ndimage as ndimage 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | 13 | def create_window(window_size, channel=1): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 17 | _3D_window = filters.gaussian_filter 18 | return window 19 | def ssim_exact(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2): 20 | 21 | mu1 = ndimage.gaussian_filter(img1, sd) 22 | mu2 = ndimage.gaussian_filter(img2, sd) 23 | mu1_sq = mu1 * mu1 24 | mu2_sq = mu2 * mu2 25 | mu1_mu2 = mu1 * mu2 26 | sigma1_sq = ndimage.gaussian_filter(img1 * img1, sd) - mu1_sq 27 | sigma2_sq = ndimage.gaussian_filter(img2 * img2, sd) - mu2_sq 28 | sigma12 = ndimage.gaussian_filter(img1 * img2, sd) - mu1_mu2 29 | 30 | ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) 31 | 32 | ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 33 | 34 | ssim_map = ssim_num / ssim_den 35 | 36 | v1 = 2.0 * sigma12 + C2 37 | v2 = sigma1_sq + sigma2_sq + C2 38 | cs = np.mean(v1 / v2) # contrast sensitivity 39 | 40 | # ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 41 | 42 | return np.mean(ssim_map),cs 43 | 44 | def ssim_3d(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 45 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 46 | if val_range is None: 47 | if torch.max(img1) > 128: 48 | max_val = 255 49 | else: 50 | max_val = 1 51 | 52 | if torch.min(img1) < -0.5: 53 | min_val = -1 54 | else: 55 | min_val = 0 56 | L = max_val - min_val 57 | else: 58 | L = val_range 59 | 60 | padd = 0 61 | (_, channel, height, width,width2) = img1.size() 62 | if window is None: 63 | real_size = min(window_size, height, width,width2) 64 | window = create_window(real_size, channel=channel).to(img1.device) 65 | 66 | mu1 = F.conv3d(img1, window, padding=padd, groups=channel) 67 | mu2 = F.conv3d(img2, window, padding=padd, groups=channel) 68 | 69 | mu1_sq = mu1.pow(2) 70 | mu2_sq = mu2.pow(2) 71 | mu1_mu2 = mu1 * mu2 72 | 73 | sigma1_sq = F.conv3d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 74 | sigma2_sq = F.conv3d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 75 | sigma12 = F.conv3d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 76 | 77 | C1 = (0.01 * L) ** 2 78 | C2 = (0.03 * L) ** 2 79 | 80 | v1 = 2.0 * sigma12 + C2 81 | v2 = sigma1_sq + sigma2_sq + C2 82 | cs = torch.mean(v1 / v2) # contrast sensitivity 83 | 84 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 85 | 86 | if size_average: 87 | ret = ssim_map.mean() 88 | else: 89 | ret = ssim_map.mean(1).mean(1).mean(1) 90 | 91 | if full: 92 | return ret, cs 93 | return ret 94 | 95 | 96 | def msssim_3d(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 97 | device = img1.device 98 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 99 | levels = weights.size()[0] 100 | mssim = [] 101 | mcs = [] 102 | for _ in range(levels): 103 | sim, cs = ssim_exact(img1.data.cpu().numpy(), img2.data.cpu().numpy()) 104 | mssim.append(sim) 105 | mcs.append(cs) 106 | 107 | img1 = F.avg_pool3d(img1, (2, 2,2)) 108 | img2 = F.avg_pool3d(img2, (2, 2,2)) 109 | 110 | # mssim = torch.stack(torch.from_numpy(mssim)) 111 | # mcs = torch.stack(torch.from_numpy(mcs)) 112 | mssim = np.asarray(mssim) 113 | mcs = np.asarray(mcs) 114 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 115 | if normalize: 116 | mssim = (mssim + 1) / 2 117 | mcs = (mcs + 1) / 2 118 | 119 | pow1 = mcs ** weights 120 | pow2 = mssim ** weights 121 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 122 | output = torch.prod(pow1[:-1] * pow2[-1]) 123 | return output 124 | 125 | 126 | # Classes to re-use window 127 | class SSIM(torch.nn.Module): 128 | def __init__(self, window_size=11, size_average=True, val_range=None): 129 | super(SSIM, self).__init__() 130 | self.window_size = window_size 131 | self.size_average = size_average 132 | self.val_range = val_range 133 | 134 | # Assume 1 channel for SSIM 135 | self.channel = 1 136 | self.window = create_window(window_size) 137 | 138 | def forward(self, img1, img2): 139 | (_, channel, _, _) = img1.size() 140 | 141 | if channel == self.channel and self.window.dtype == img1.dtype: 142 | window = self.window 143 | else: 144 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 145 | self.window = window 146 | self.channel = channel 147 | 148 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 149 | 150 | class MSSSIM_3d(torch.nn.Module): 151 | def __init__(self, window_size=11, size_average=True, channel=3): 152 | super(MSSSIM_3d, self).__init__() 153 | self.window_size = window_size 154 | self.size_average = size_average 155 | self.channel = channel 156 | 157 | def forward(self, img1, img2): 158 | # TODO: store window between calls if possible 159 | return msssim_3d(img1, img2, window_size=self.window_size, size_average=self.size_average) 160 | 161 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 162 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 163 | if val_range is None: 164 | if torch.max(img1) > 128: 165 | max_val = 255 166 | else: 167 | max_val = 1 168 | 169 | if torch.min(img1) < -0.5: 170 | min_val = -1 171 | else: 172 | min_val = 0 173 | L = max_val - min_val 174 | else: 175 | L = val_range 176 | 177 | padd = 0 178 | (_, channel, height, width) = img1.size() 179 | if window is None: 180 | real_size = min(window_size, height, width) 181 | window = create_window(real_size, channel=channel).to(img1.device) 182 | 183 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 184 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 185 | 186 | mu1_sq = mu1.pow(2) 187 | mu2_sq = mu2.pow(2) 188 | mu1_mu2 = mu1 * mu2 189 | 190 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 191 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 192 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 193 | 194 | C1 = (0.01 * L) ** 2 195 | C2 = (0.03 * L) ** 2 196 | 197 | v1 = 2.0 * sigma12 + C2 198 | v2 = sigma1_sq + sigma2_sq + C2 199 | cs = torch.mean(v1 / v2) # contrast sensitivity 200 | 201 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 202 | 203 | if size_average: 204 | ret = ssim_map.mean() 205 | else: 206 | ret = ssim_map.mean(1).mean(1).mean(1) 207 | 208 | if full: 209 | return ret, cs 210 | return ret 211 | 212 | 213 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 214 | device = img1.device 215 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 216 | levels = weights.size()[0] 217 | mssim = [] 218 | mcs = [] 219 | for _ in range(levels): 220 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 221 | mssim.append(sim) 222 | mcs.append(cs) 223 | 224 | img1 = F.avg_pool2d(img1, (2, 2)) 225 | img2 = F.avg_pool2d(img2, (2, 2)) 226 | 227 | mssim = torch.stack(mssim) 228 | mcs = torch.stack(mcs) 229 | 230 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 231 | if normalize: 232 | mssim = (mssim + 1) / 2 233 | mcs = (mcs + 1) / 2 234 | 235 | pow1 = mcs ** weights 236 | pow2 = mssim ** weights 237 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 238 | output = torch.prod(pow1[:-1] * pow2[-1]) 239 | return output 240 | 241 | 242 | # Classes to re-use window 243 | class SSIM(torch.nn.Module): 244 | def __init__(self, window_size=11, size_average=True, val_range=None): 245 | super(SSIM, self).__init__() 246 | self.window_size = window_size 247 | self.size_average = size_average 248 | self.val_range = val_range 249 | 250 | # Assume 1 channel for SSIM 251 | self.channel = 1 252 | self.window = create_window(window_size) 253 | 254 | def forward(self, img1, img2): 255 | (_, channel, _, _) = img1.size() 256 | 257 | if channel == self.channel and self.window.dtype == img1.dtype: 258 | window = self.window 259 | else: 260 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 261 | self.window = window 262 | self.channel = channel 263 | 264 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 265 | 266 | class MSSSIM(torch.nn.Module): 267 | def __init__(self, window_size=11, size_average=True, channel=3): 268 | super(MSSSIM, self).__init__() 269 | self.window_size = window_size 270 | self.size_average = size_average 271 | self.channel = channel 272 | 273 | def forward(self, img1, img2): 274 | # TODO: store window between calls if possible 275 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 276 | 277 | -------------------------------------------------------------------------------- /Alpha_GAN_ADNI_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pylab inline\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "import os\n", 13 | "from torch import nn\n", 14 | "from torch import optim\n", 15 | "from torch.nn import functional as F\n", 16 | "from torch import autograd\n", 17 | "from torch.autograd import Variable\n", 18 | "import nibabel as nib\n", 19 | "from torch.utils.data.dataset import Dataset\n", 20 | "from torch.utils.data import dataloader\n", 21 | "from nilearn import plotting\n", 22 | "from ADNI_dataset import *\n", 23 | "from BRATS_dataset import *\n", 24 | "from ATLAS_dataset import *\n", 25 | "from Model_alphaGAN import *" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "# Configuration" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "BATCH_SIZE=4\n", 42 | "gpu = True\n", 43 | "workers = 4\n", 44 | "LAMBDA= 10\n", 45 | "_eps = 1e-15\n", 46 | "\n", 47 | "Use_BRATS = False\n", 48 | "Use_ATLAS = False\n", 49 | "\n", 50 | "#setting latent variable sizes\n", 51 | "latent_dim = 1000" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "trainset = ADNIdataset(augmentation=False)\n", 63 | "train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 64 | " shuffle=True,num_workers=workers)\n", 65 | "\n", 66 | "if Use_BRATS:\n", 67 | " #imgtype -> 'flair' or 't2' or 't1ce'\n", 68 | " trainset = BRATSdataset(train=True, imgtype = 'flair',augmentation=False)\n", 69 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 70 | " shuffle=True,num_workers=workers)\n", 71 | "if Use_ATLAS:\n", 72 | " trainset = ATLASdataset(augmentation=True)\n", 73 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 74 | " shuffle=True,num_workers=workers)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def inf_train_gen(data_loader):\n", 84 | " while True:\n", 85 | " for _,images in enumerate(data_loader):\n", 86 | " yield images" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "G = Generator(noise = latent_dim)\n", 96 | "CD = Code_Discriminator(code_size = latent_dim ,num_units = 4096)\n", 97 | "D = Discriminator(is_dis=True)\n", 98 | "E = Discriminator(out_class = latent_dim ,is_dis=False)\n", 99 | "\n", 100 | "G.cuda()\n", 101 | "D.cuda()\n", 102 | "CD.cuda()\n", 103 | "E.cuda()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "g_optimizer = optim.Adam(G.parameters(), lr=0.0002)\n", 113 | "d_optimizer = optim.Adam(D.parameters(), lr=0.0002)\n", 114 | "e_optimizer = optim.Adam(E.parameters(), lr = 0.0002)\n", 115 | "cd_optimizer = optim.Adam(CD.parameters(), lr = 0.0002)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "def calc_gradient_penalty(model, x, x_gen, w=10):\n", 125 | " \"\"\"WGAN-GP gradient penalty\"\"\"\n", 126 | " assert x.size()==x_gen.size(), \"real and sampled sizes do not match\"\n", 127 | " alpha_size = tuple((len(x), *(1,)*(x.dim()-1)))\n", 128 | " alpha_t = torch.cuda.FloatTensor if x.is_cuda else torch.Tensor\n", 129 | " alpha = alpha_t(*alpha_size).uniform_()\n", 130 | " x_hat = x.data*alpha + x_gen.data*(1-alpha)\n", 131 | " x_hat = Variable(x_hat, requires_grad=True)\n", 132 | "\n", 133 | " def eps_norm(x):\n", 134 | " x = x.view(len(x), -1)\n", 135 | " return (x*x+_eps).sum(-1).sqrt()\n", 136 | " def bi_penalty(x):\n", 137 | " return (x-1)**2\n", 138 | "\n", 139 | " grad_xhat = torch.autograd.grad(model(x_hat).sum(), x_hat, create_graph=True, only_inputs=True)[0]\n", 140 | "\n", 141 | " penalty = w*bi_penalty(eps_norm(grad_xhat)).mean()\n", 142 | " return penalty" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "# Training" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "real_y = Variable(torch.ones((BATCH_SIZE, 1)).cuda(async=True))\n", 159 | "fake_y = Variable(torch.zeros((BATCH_SIZE, 1)).cuda(async=True))\n", 160 | "\n", 161 | "criterion_bce = nn.BCELoss()\n", 162 | "criterion_l1 = nn.L1Loss()\n", 163 | "criterion_mse = nn.MSELoss()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "gen_load = inf_train_gen(train_loader)\n", 173 | "MAX_ITER = 200000\n", 174 | "for iteration in range(MAX_ITER):\n", 175 | " ###############################################\n", 176 | " # Train Encoder - Generator \n", 177 | " ###############################################\n", 178 | " for p in D.parameters(): # reset requires_grad\n", 179 | " p.requires_grad = False\n", 180 | " for p in CD.parameters(): # reset requires_grad\n", 181 | " p.requires_grad = False\n", 182 | " for p in E.parameters(): # reset requires_grad\n", 183 | " p.requires_grad = True\n", 184 | " for p in G.parameters(): # reset requires_grad\n", 185 | " p.requires_grad = True\n", 186 | "\n", 187 | " g_optimizer.zero_grad()\n", 188 | " e_optimizer.zero_grad()\n", 189 | "\n", 190 | "\n", 191 | " for iters in range(1):\n", 192 | " real_images = gen_load.__next__()\n", 193 | " real_images = Variable(real_images,volatile=True).cuda(async=True)\n", 194 | " _batch_size = real_images.size(0)\n", 195 | " z_hat = E(real_images).view(_batch_size,-1)\n", 196 | " z_rand = Variable(torch.randn((_batch_size,latent_dim)),requires_grad=False).cuda()\n", 197 | "\n", 198 | " x_hat = G(z_hat)\n", 199 | " x_rand = G(z_rand)\n", 200 | "\n", 201 | " l1_loss = 10 * criterion_l1(x_hat, real_images)\n", 202 | " c_loss = criterion_bce(CD(z_hat), real_y[:_batch_size])\n", 203 | " d_real_loss = criterion_bce(D(x_hat), real_y[:_batch_size]) \n", 204 | " d_fake_loss = criterion_bce(D(x_rand), real_y[:_batch_size])\n", 205 | "\n", 206 | " loss1 = l1_loss + c_loss + d_real_loss + d_fake_loss\n", 207 | "\n", 208 | " loss1.backward(retain_graph=True)\n", 209 | " e_optimizer.step()\n", 210 | "\n", 211 | " g_optimizer.step()\n", 212 | " g_optimizer.step()\n", 213 | "\n", 214 | " ###############################################\n", 215 | " # Train D\n", 216 | " ###############################################\n", 217 | " for p in D.parameters(): \n", 218 | " p.requires_grad = True\n", 219 | " for p in CD.parameters(): \n", 220 | " p.requires_grad = False\n", 221 | " for p in E.parameters(): \n", 222 | " p.requires_grad = False\n", 223 | " for p in G.parameters(): \n", 224 | " p.requires_grad = False\n", 225 | "\n", 226 | " for iters in range(1):\n", 227 | " d_optimizer.zero_grad()\n", 228 | "\n", 229 | " z_rand = Variable(torch.randn((_batch_size,latent_dim)),volatile=True).cuda()\n", 230 | " z_hat = E(real_images).view(_batch_size,-1)\n", 231 | " x_hat = G(z_hat)\n", 232 | " x_rand = G(z_rand)\n", 233 | "\n", 234 | " x_loss2 = 2.0 * criterion_bce(D(real_images), real_y[:_batch_size])+criterion_bce(D(x_hat), fake_y[:_batch_size])\n", 235 | " z_loss2 = criterion_bce(D(x_rand), fake_y[:_batch_size])\n", 236 | " loss2 = x_loss2 + z_loss2\n", 237 | "\n", 238 | " if iters<4:\n", 239 | " loss2.backward(retain_graph=True)\n", 240 | " else:\n", 241 | " loss2.backward(retain_graph=True)\n", 242 | " d_optimizer.step()\n", 243 | " ###############################################\n", 244 | " # Train CD\n", 245 | " ###############################################\n", 246 | " for p in D.parameters(): # reset requires_grad\n", 247 | " p.requires_grad = False\n", 248 | " for p in CD.parameters(): # reset requires_grad\n", 249 | " p.requires_grad = True\n", 250 | " for p in E.parameters(): # reset requires_grad\n", 251 | " p.requires_grad = False\n", 252 | " for p in G.parameters(): # reset requires_grad\n", 253 | " p.requires_grad = False\n", 254 | "\n", 255 | " for iters in range(1):\n", 256 | " cd_optimizer.zero_grad()\n", 257 | " z_hat = E(real_images).view(_batch_size,-1)\n", 258 | " x_loss3 = criterion_bce(CD(z_hat), fake_y[:_batch_size])\n", 259 | " z_rand = Variable(torch.randn((_batch_size,latent_dim)),volatile=True).cuda()\n", 260 | " z_loss3 = criterion_bce(CD(z_rand), real_y[:_batch_size])\n", 261 | " loss3 = x_loss3 + z_loss3\n", 262 | " loss3.backward(retain_graph=True)\n", 263 | " cd_optimizer.step()\n", 264 | " \n", 265 | " ###############################################\n", 266 | " # Visualization\n", 267 | " ###############################################\n", 268 | "\n", 269 | " if iteration % 50 == 0:\n", 270 | " print('[{}/{}]'.format(iteration,50000),\n", 271 | " 'D: {:<8.3}'.format(loss2.data[0].cpu().numpy()), \n", 272 | " 'En_Ge: {:<8.3}'.format(loss1.data[0].cpu().numpy()),\n", 273 | " 'Code: {:<8.3}'.format(loss3.data[0].cpu().numpy()),\n", 274 | " )\n", 275 | "\n", 276 | " featmask = np.squeeze((0.5*real_images[0]+0.5).data.cpu().numpy())\n", 277 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 278 | " plotting.plot_img(featmask,title=\"Real\")\n", 279 | " plotting.show()\n", 280 | "\n", 281 | " featmask = np.squeeze((0.5*x_hat[0]+0.5).data.cpu().numpy())\n", 282 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 283 | " plotting.plot_img(featmask,title=\"DEC\")\n", 284 | " plotting.show()\n", 285 | "\n", 286 | " featmask = np.squeeze((0.5*x_rand[0]+0.5).data.cpu().numpy())\n", 287 | " featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 288 | " plotting.plot_img(featmask,title=\"Rand\")\n", 289 | " plotting.show()\n", 290 | "\n", 291 | " if (iteration+1)%500 ==0: \n", 292 | " torch.save(G.state_dict(),'./checkpoint/G_noW_iter'+str(iteration+1)+'.pth')\n", 293 | " torch.save(D.state_dict(),'./checkpoint/D_noW_iter'+str(iteration+1)+'.pth')\n", 294 | " torch.save(E.state_dict(),'./checkpoint/E_noW_iter'+str(iteration+1)+'.pth')\n", 295 | " torch.save(CD.state_dict(),'./checkpoint/CD_noW_iter'+str(iteration+1)+'.pth')" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.5.6" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 2 320 | } 321 | -------------------------------------------------------------------------------- /Alpha_WGAN_ADNI_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pylab inline\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "import os\n", 13 | "from torch import nn\n", 14 | "from torch import optim\n", 15 | "from torch.nn import functional as F\n", 16 | "from torch import autograd\n", 17 | "from torch.autograd import Variable\n", 18 | "import nibabel as nib\n", 19 | "from torch.utils.data.dataset import Dataset\n", 20 | "from torch.utils.data import dataloader\n", 21 | "from nilearn import plotting\n", 22 | "from ADNI_dataset import *\n", 23 | "from BRATS_dataset import *\n", 24 | "from ATLAS_dataset import *\n", 25 | "from Model_alphaWGAN import *" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "# Configuration" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "BATCH_SIZE=4\n", 42 | "gpu = True\n", 43 | "workers = 4\n", 44 | "\n", 45 | "LAMBDA= 10\n", 46 | "_eps = 1e-15\n", 47 | "Use_BRATS=False\n", 48 | "Use_ATLAS = False\n", 49 | "\n", 50 | "#setting latent variable sizes\n", 51 | "latent_dim = 1000" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "trainset = ADNIdataset(augmentation=True)\n", 63 | "train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 64 | " shuffle=True,num_workers=workers)\n", 65 | "if Use_BRATS:\n", 66 | " #'flair' or 't2' or 't1ce'\n", 67 | " trainset = BRATSdataset(imgtype='flair')\n", 68 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE, shuffle=True,\n", 69 | " num_workers=workers)\n", 70 | "if Use_ATLAS:\n", 71 | " trainset = ATLASdataset(augmentation=True)\n", 72 | " train_loader = torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,\n", 73 | " shuffle=True,num_workers=workers)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def inf_train_gen(data_loader):\n", 83 | " while True:\n", 84 | " for _,images in enumerate(data_loader):\n", 85 | " yield images" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "G = Generator(noise = latent_dim)\n", 95 | "CD = Code_Discriminator(code_size = latent_dim ,num_units = 4096)\n", 96 | "D = Discriminator(is_dis=True)\n", 97 | "E = Discriminator(out_class = latent_dim,is_dis=False)\n", 98 | "\n", 99 | "G.cuda()\n", 100 | "D.cuda()\n", 101 | "CD.cuda()\n", 102 | "E.cuda()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "g_optimizer = optim.Adam(G.parameters(), lr=0.0002)\n", 112 | "d_optimizer = optim.Adam(D.parameters(), lr=0.0002)\n", 113 | "e_optimizer = optim.Adam(E.parameters(), lr = 0.0002)\n", 114 | "cd_optimizer = optim.Adam(CD.parameters(), lr = 0.0002)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "def calc_gradient_penalty(model, x, x_gen, w=10):\n", 124 | " \"\"\"WGAN-GP gradient penalty\"\"\"\n", 125 | " assert x.size()==x_gen.size(), \"real and sampled sizes do not match\"\n", 126 | " alpha_size = tuple((len(x), *(1,)*(x.dim()-1)))\n", 127 | " alpha_t = torch.cuda.FloatTensor if x.is_cuda else torch.Tensor\n", 128 | " alpha = alpha_t(*alpha_size).uniform_()\n", 129 | " x_hat = x.data*alpha + x_gen.data*(1-alpha)\n", 130 | " x_hat = Variable(x_hat, requires_grad=True)\n", 131 | "\n", 132 | " def eps_norm(x):\n", 133 | " x = x.view(len(x), -1)\n", 134 | " return (x*x+_eps).sum(-1).sqrt()\n", 135 | " def bi_penalty(x):\n", 136 | " return (x-1)**2\n", 137 | "\n", 138 | " grad_xhat = torch.autograd.grad(model(x_hat).sum(), x_hat, create_graph=True, only_inputs=True)[0]\n", 139 | "\n", 140 | " penalty = w*bi_penalty(eps_norm(grad_xhat)).mean()\n", 141 | " return penalty" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# Training" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "real_y = Variable(torch.ones((BATCH_SIZE, 1)).cuda(async=True))\n", 158 | "fake_y = Variable(torch.zeros((BATCH_SIZE, 1)).cuda(async=True))\n", 159 | "\n", 160 | "criterion_bce = nn.BCELoss()\n", 161 | "criterion_l1 = nn.L1Loss()\n", 162 | "criterion_mse = nn.MSELoss()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "g_iter = 1\n", 172 | "d_iter = 1\n", 173 | "cd_iter =1\n", 174 | "TOTAL_ITER = 200000\n", 175 | "gen_load = inf_train_gen(train_loader)\n", 176 | "for iteration in range(TOTAL_ITER):\n", 177 | " for p in D.parameters(): \n", 178 | " p.requires_grad = False\n", 179 | " for p in CD.parameters(): \n", 180 | " p.requires_grad = False\n", 181 | " for p in E.parameters(): \n", 182 | " p.requires_grad = True\n", 183 | " for p in G.parameters(): \n", 184 | " p.requires_grad = True\n", 185 | "\n", 186 | " ###############################################\n", 187 | " # Train Encoder - Generator \n", 188 | " ###############################################\n", 189 | " for iters in range(g_iter):\n", 190 | " G.zero_grad()\n", 191 | " E.zero_grad()\n", 192 | " real_images = gen_load.__next__()\n", 193 | " _batch_size = real_images.size(0)\n", 194 | " real_images = Variable(real_images,volatile=True).cuda(async=True)\n", 195 | " z_rand = Variable(torch.randn((_batch_size,latent_dim)),volatile=True).cuda()\n", 196 | " z_hat = E(real_images).view(_batch_size,-1)\n", 197 | " x_hat = G(z_hat)\n", 198 | " x_rand = G(z_rand)\n", 199 | " c_loss = -CD(z_hat).mean()\n", 200 | "\n", 201 | " d_real_loss = D(x_hat).mean()\n", 202 | " d_fake_loss = D(x_rand).mean()\n", 203 | " d_loss = -d_fake_loss-d_real_loss\n", 204 | " l1_loss =10* criterion_l1(x_hat,real_images)\n", 205 | " loss1 = l1_loss + c_loss + d_loss\n", 206 | "\n", 207 | " if iters" 158 | ] 159 | }, 160 | "metadata": {}, 161 | "output_type": "display_data" 162 | }, 163 | { 164 | "data": { 165 | "image/png": "\n", 166 | "text/plain": [ 167 | "
" 168 | ] 169 | }, 170 | "metadata": {}, 171 | "output_type": "display_data" 172 | } 173 | ], 174 | "source": [ 175 | "Show_color = False\n", 176 | "\n", 177 | "noise = Variable(torch.randn((1, 1000)).cuda())\n", 178 | "fake_image = G(noise)\n", 179 | "featmask = np.squeeze(fake_image[0].data.cpu().numpy())\n", 180 | "featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 181 | "\n", 182 | "arr1 = [4,6,8,10,12,14,16,18,20,22,24,26,28,30,32]\n", 183 | "arr2 = [34,36,38,40,42,44,46,48,50,52,54,56,58,60]\n", 184 | "if Show_color:\n", 185 | " disp = plotting.plot_img(featmask,cut_coords=arr1,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 186 | " # disp.annotate(size=25,left_right=False,positions=True)\n", 187 | " plotting.show()\n", 188 | " disp=plotting.plot_img(featmask,cut_coords=arr2,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 189 | " # disp.annotate(size=25,left_right=False)\n", 190 | " plotting.show()\n", 191 | "else:\n", 192 | " disp = plotting.plot_anat(featmask,cut_coords=arr1,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 193 | " plotting.show()\n", 194 | " # disp.annotate(size=25,left_right=False)\n", 195 | " disp=plotting.plot_anat(featmask,cut_coords=arr2,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 196 | " # disp.annotate(size=25,left_right=False)\n", 197 | " plotting.show()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "Fake Image - Center cut slices Visualization" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "noise = Variable(torch.randn((1, 1000)).cuda())\n", 214 | "fake_image = G(noise)\n", 215 | "featmask = np.squeeze(fake_image[0].data.cpu().numpy())\n", 216 | "featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 217 | "plotting.plot_img(featmask,cut_coords=(32,32,32),draw_cross=False,annotate=False,black_bg=True)\n", 218 | "plotting.plot_anat(featmask,cut_coords=(32,32,32),draw_cross=False,annotate=False,black_bg=True)\n", 219 | "plotting.show()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 29, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "noise = Variable(torch.randn((1, 1000)).cuda())\n", 229 | "fake_image = G(noise)\n", 230 | "featmask = np.squeeze(fake_image[0].data.cpu().numpy())\n", 231 | "# featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 232 | "\n", 233 | "from PIL import Image\n", 234 | "for i in range(64):\n", 235 | " A = 0.5*featmask[:,:,i]+0.5\n", 236 | " im = Image.fromarray(np.uint8(255*A)).convert(\"L\")\n", 237 | " im.save('./SCANS/T2_num'+str(i)+'.png')\n", 238 | " \n", 239 | "\n", 240 | "# plotting.plot_anat(featmask,cut_coords=(32,32),draw_cross=False,annotate=False,black_bg=True,display_mode ='x')" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "Real Image - Slice series visualization" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "scrolled": true 255 | }, 256 | "outputs": [], 257 | "source": [ 258 | "Show_color = False\n", 259 | "\n", 260 | "image = gen_load.__next__()\n", 261 | "featmask = np.squeeze(image[0].data.cpu().numpy())\n", 262 | "featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 263 | "arr1 = [4,6,8,10,12,14,16,18,20,22,24,26,28,30,32]\n", 264 | "arr2 = [34,36,38,40,42,44,46,48,50,52,54,56,58,60]\n", 265 | "\n", 266 | "if Show_color:\n", 267 | " disp = plotting.plot_img(featmask,cut_coords=arr1,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 268 | " # disp.annotate(size=25,left_right=False,positions=True)\n", 269 | " plotting.show()\n", 270 | " disp=plotting.plot_img(featmask,cut_coords=arr2,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 271 | " # disp.annotate(size=25,left_right=False)\n", 272 | " plotting.show()\n", 273 | "else:\n", 274 | " disp = plotting.plot_anat(featmask,cut_coords=arr1,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 275 | " plotting.show()\n", 276 | " # disp.annotate(size=25,left_right=False)\n", 277 | " disp=plotting.plot_anat(featmask,cut_coords=arr2,draw_cross=False,annotate=False,black_bg=True,display_mode='x')\n", 278 | " # disp.annotate(size=25,left_right=False)\n", 279 | " plotting.show()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 47, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "gen_load = inf_train_gen(train_loader)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 51, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "image = gen_load.__next__()\n", 298 | "featmask = np.squeeze(image[0].data.cpu().numpy())\n", 299 | "# featmask = nib.Nifti1Image(featmask,affine = np.eye(4))\n", 300 | "\n", 301 | "from PIL import Image\n", 302 | "for i in range(64):\n", 303 | " A = 0.5*featmask[:,:,i]+0.5\n", 304 | " im = Image.fromarray(np.uint8(255*A)).convert(\"L\")\n", 305 | " im.save('./SCANS/Real_T2_num'+str(i)+'.png')\n", 306 | " " 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "# MS-SSIM Calculation" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 4, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "sum_ssim = 0\n", 323 | "for k in range(20):\n", 324 | " for i,dat in enumerate(train_loader):\n", 325 | " if len(dat)!=2:\n", 326 | " break\n", 327 | " img1 = dat[0]\n", 328 | " img2 = dat[1]\n", 329 | "\n", 330 | " msssim = pytorch_ssim.msssim_3d(img1,img2)\n", 331 | " sum_ssim = sum_ssim+msssim\n", 332 | " print(sum_ssim/((k+1)*(i+1)))" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "sum_ssim = 0\n", 342 | "for i in range(1000):\n", 343 | " noise = Variable(torch.randn((2, 1000)).cuda())\n", 344 | " fake_image = G(noise)\n", 345 | "\n", 346 | " img1 = fake_image[0]\n", 347 | " img2 = fake_image[1]\n", 348 | "\n", 349 | " msssim = pytorch_ssim.msssim_3d(img1,img2)\n", 350 | " sum_ssim = sum_ssim+msssim\n", 351 | "print(sum_ssim/1000)\n" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "# Maximum-Mean Discrepancy Score" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "for p in G.parameters():\n", 368 | " p.requires_grad = False\n", 369 | "\n", 370 | "meanarr = []\n", 371 | "for s in range(100):\n", 372 | " distmean = 0.0\n", 373 | " for i,(y) in enumerate(train_loader):\n", 374 | " y = Variable(y).cuda()\n", 375 | " noise = Variable(torch.randn((y.size(0), 1000)).cuda())\n", 376 | " x = G(noise)\n", 377 | "\n", 378 | " B = y.size(0)\n", 379 | " x = x.view(x.size(0), x.size(2) * x.size(3)*x.size(4))\n", 380 | " y = y.view(y.size(0), y.size(2) * y.size(3)*y.size(4))\n", 381 | "\n", 382 | " xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())\n", 383 | "\n", 384 | " beta = (1./(B*B))\n", 385 | " gamma = (2./(B*B)) \n", 386 | "\n", 387 | " Dist = beta * (torch.sum(xx)+torch.sum(yy)) - gamma * torch.sum(zz)\n", 388 | " distmean += Dist\n", 389 | " print('Mean:'+str(distmean/(i+1)))\n", 390 | " meanarr.append(distmean/(i+1))\n", 391 | "meanarr = numpy.array(meanarr)\n", 392 | "print('Total_mean:'+str(np.mean(meanarr))+' STD:'+str(np.std(meanarr)))" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": "Python 3", 399 | "language": "python", 400 | "name": "python3" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.5.6" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | --------------------------------------------------------------------------------