├── models
│   ├── __init__.py
│   ├── layers.py
│   ├── Model_HA_GAN_128.py
│   └── Model_HA_GAN_256.py
├── figures
│   ├── main_github.png
│   ├── tensorboard.png
│   └── sample_HA_GAN.png
├── LICENSE
├── volume_dataset.py
├── preprocess.py
├── utils.py
├── README.md
├── environment.yml
├── evaluation
│   ├── resnet3D.py
│   └── fid_score.py
└── train.py
/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/main_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/main_github.png -------------------------------------------------------------------------------- /figures/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/tensorboard.png -------------------------------------------------------------------------------- /figures/sample_HA_GAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/sample_HA_GAN.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 batmanlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /volume_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from sklearn.model_selection import KFold 3 | import numpy as np 4 | import glob 5 | 6 | class Volume_Dataset(Dataset): 7 | 8 | def __init__(self, data_dir, mode='train', fold=0, num_class=0): 9 | self.sid_list = [] 10 | self.data_dir = data_dir 11 | self.num_class = num_class 12 | 13 | for item in glob.glob(self.data_dir+"*.npy"): 14 | self.sid_list.append(item.split('/')[-1]) 15 | 16 | self.sid_list.sort() 17 | self.sid_list = np.asarray(self.sid_list) 18 | 19 | kf = KFold(n_splits=5, shuffle=True, random_state=0) 20 | train_index, valid_index = list(kf.split(self.sid_list))[fold] 21 | print("Fold:", fold) 22 | if mode=="train": 23 | self.sid_list = self.sid_list[train_index] 24 | else: 25 | self.sid_list = self.sid_list[valid_index] 26 | print("Dataset size:", len(self)) 27 | 28 | self.class_label_dict = dict() 29 | if self.num_class > 0: # conditional 30 | with open("class_label.csv", "r") as FILE: 31 | FILE.readline() # skip header 32 | for myline in FILE: 33 | mylist = myline.strip("\n").split(",") 34 | self.class_label_dict[mylist[0]] = int(mylist[1]) 35 | 36 | 37 | def __len__(self): 38 | return len(self.sid_list) 39 | 40 | def __getitem__(self, idx): 41 | img = np.load(self.data_dir+self.sid_list[idx]) 42 | class_label = self.class_label_dict.get(self.sid_list[idx], -1) # -1 if no class label 43 | return img[None,:,:,:], class_label 44 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # resize and rescale images for preprocessing 2 | 3 | import glob 4 | import SimpleITK as sitk 5 | import numpy as np 6 | from skimage.transform import resize 7 | import os 8 | import multiprocessing as mp 9 | 10 | ### Configs 11 | # 8 worker processes are used for parallel preprocessing 12 | NUM_JOBS = 8 13 | # resized output size, can be 128 or 256 14 | IMG_SIZE = 256 15 | INPUT_DATA_DIR = '/path_to_imgs/' 16 | OUTPUT_DATA_DIR = '/output_folder/' 17 | # the intensity range is clipped with the two thresholds; these defaults are used for our CT images, please adapt them to your own dataset 18 | LOW_THRESHOLD = -1024 19 | HIGH_THRESHOLD = 600 20 | # suffix (ext.)
of input images 21 | SUFFIX = '.nii.gz' 22 | # whether or not to trim blank axial slices; recommended to set to True 23 | TRIM_BLANK_SLICES = True 24 | 25 | def resize_img(img): 26 | nan_mask = np.isnan(img) # Remove NaN 27 | img[nan_mask] = LOW_THRESHOLD 28 | img = np.interp(img, [LOW_THRESHOLD, HIGH_THRESHOLD], [-1,1]) 29 | 30 | if TRIM_BLANK_SLICES: 31 | valid_plane_i = np.mean(img, (1,2)) != -1 # Remove blank axial planes 32 | img = img[valid_plane_i,:,:] 33 | 34 | img = resize(img, (IMG_SIZE, IMG_SIZE, IMG_SIZE), mode='constant', cval=-1) 35 | return img 36 | 37 | def main(): 38 | img_list = list(glob.glob(INPUT_DATA_DIR+"*"+SUFFIX)) 39 | 40 | processes = [] 41 | for i in range(NUM_JOBS): 42 | processes.append(mp.Process(target=batch_resize, args=(i, img_list))) 43 | for p in processes: 44 | p.start() 45 | 46 | def batch_resize(batch_idx, img_list): 47 | for idx in range(len(img_list)): 48 | if idx % NUM_JOBS != batch_idx: 49 | continue 50 | imgname = img_list[idx].split('/')[-1] 51 | if os.path.exists(OUTPUT_DATA_DIR+imgname.split('.')[0]+".npy"): 52 | # skip images that already finished pre-processing 53 | continue 54 | try: 55 | img = sitk.ReadImage(img_list[idx]) # glob already returns paths prefixed with INPUT_DATA_DIR 56 | except Exception as e: 57 | # skip corrupted images 58 | print(e) 59 | print("Image loading error:", imgname) 60 | continue 61 | img = sitk.GetArrayFromImage(img) 62 | try: 63 | img = resize_img(img) 64 | except Exception as e: # Some images are corrupted 65 | print(e) 66 | print("Image resize error:", imgname) 67 | continue 68 | # preprocessed images are saved as numpy arrays 69 | np.save(OUTPUT_DATA_DIR+imgname.split('.')[0]+".npy", img) 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | from skimage.transform import resize 6 | 7 | import torch 8 | 9 | def post_process_brain(x_pred): 10 | x_pred = resize(x_pred, (256-90,256-40,256-40), mode='constant', cval=0.) 11 | x_canvas = np.zeros((256,256,256)) 12 | x_canvas[50:-40,20:-20,20:-20] = x_pred 13 | x_canvas = np.flip(x_canvas,0) 14 | return x_canvas 15 | 16 | def _itensity_normalize(volume): 17 | pixels = volume[volume > 0] # z-score over positive (foreground) voxels 18 | mean = pixels.mean() 19 | std = pixels.std() 20 | out = (volume - mean)/std 21 | return out 22 | 23 | class Flatten(torch.nn.Module): 24 | def forward(self, inp): 25 | return inp.view(inp.size(0), -1) 26 | 27 | def calculate_nmse(img1, img2): 28 | img1 = img1.astype(np.float64) 29 | img2 = img2.astype(np.float64) 30 | mse = np.mean((img1 - img2)**2) 31 | mse0 = np.mean(img1**2) 32 | if mse0 == 0: # avoid division by zero in the normalizer 33 | return float('inf') 34 | return mse / mse0 * 100. 35 | 36 | def calculate_psnr(img1, img2): 37 | # img1 and img2 have range [0, 1] 38 | img1 = img1.astype(np.float64) 39 | img2 = img2.astype(np.float64) 40 | mse = np.mean((img1 - img2)**2) 41 | if mse == 0: 42 | return float('inf') 43 | return 20 * math.log10(1.0 / math.sqrt(mse)) 44 | 45 | class KLN01Loss(torch.nn.Module): 46 | 47 | def __init__(self, direction, minimize): 48 | super(KLN01Loss, self).__init__() 49 | self.minimize = minimize 50 | assert direction in ['pq', 'qp'], "direction must be 'pq' or 'qp'" 51 | 52 | self.direction = direction 53 | 54 | def forward(self, samples): 55 | 56 | assert samples.nelement() == samples.size(1) * samples.size(0), 'expected a 2D (batch, dim) tensor'
57 | 58 | samples = samples.view(samples.size(0), -1) 59 | 60 | self.samples_var = samples.std(0) # per-dimension std; the closed-form KL terms below use samples_var as sigma 61 | self.samples_mean = samples.mean(0) 62 | 63 | samples_mean = self.samples_mean 64 | samples_var = self.samples_var 65 | 66 | if self.direction == 'pq': 67 | # mu_1 = 0; sigma_1 = 1: KL(N(0,1) || N(mu, sigma^2)) = log(sigma) + (1 + mu^2)/(2*sigma^2) - 1/2 68 | 69 | t1 = (1 + samples_mean.pow(2)) / (2 * samples_var.pow(2)) 70 | t2 = samples_var.log() 71 | 72 | KL = (t1 + t2 - 0.5).mean() 73 | else: 74 | # mu_2 = 0; sigma_2 = 1: KL(N(mu, sigma^2) || N(0,1)) = (sigma^2 + mu^2)/2 - log(sigma) - 1/2 75 | 76 | t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2 77 | t2 = -samples_var.log() 78 | 79 | KL = (t1 + t2 - 0.5).mean() 80 | 81 | if not self.minimize: 82 | KL *= -1 83 | 84 | return KL 85 | 86 | def trim_state_dict_name(state_dict): 87 | for k in list(state_dict.keys()): 88 | if k.startswith('module.'): 89 | # remove prefix 90 | state_dict[k[len("module."):]] = state_dict[k] 91 | del state_dict[k] 92 | return state_dict 93 | 94 | def inf_train_gen(data_loader): 95 | while True: 96 | for _,batch in enumerate(data_loader): 97 | yield batch 98 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.nn import Parameter as P 8 | 9 | # Projection of x onto y 10 | def proj(x, y): 11 | return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) 12 | 13 | 14 | # Orthogonalize x wrt list of vectors ys 15 | def gram_schmidt(x, ys): 16 | for y in ys: 17 | x = x - proj(x, y) 18 | return x 19 | 20 | # Apply num_itrs steps of the power method to estimate top N singular values. 21 | def power_iteration(W, u_, update=True, eps=1e-12): 22 | # Lists holding singular vectors and values 23 | us, vs, svs = [], [], [] 24 | for i, u in enumerate(u_): 25 | # Run one step of the power iteration 26 | with torch.no_grad(): 27 | v = torch.matmul(u, W) 28 | # Run Gram-Schmidt to subtract components of all other singular vectors 29 | v = F.normalize(gram_schmidt(v, vs), eps=eps) 30 | # Add to the list 31 | vs += [v] 32 | # Update the other singular vector 33 | u = torch.matmul(v, W.t()) 34 | # Run Gram-Schmidt to subtract components of all other singular vectors 35 | u = F.normalize(gram_schmidt(u, us), eps=eps) 36 | # Add to the list 37 | us += [u] 38 | if update: 39 | u_[i][:] = u 40 | # Compute this singular value and add it to the list 41 | svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] 42 | #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] 43 | return svs, us, vs 44 | 45 | # Convenience passthrough function 46 | class identity(nn.Module): 47 | def forward(self, input): 48 | return input 49 | 50 | # Spectral normalization base class 51 | class SN(object): 52 | def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): 53 | # Number of power iterations per step 54 | self.num_itrs = num_itrs 55 | # Number of singular values 56 | self.num_svs = num_svs 57 | # Transposed?
58 | self.transpose = transpose 59 | # Epsilon value for avoiding divide-by-0 60 | self.eps = eps 61 | # Register a singular vector for each sv 62 | for i in range(self.num_svs): 63 | self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) 64 | self.register_buffer('sv%d' % i, torch.ones(1)) 65 | 66 | # Singular vectors (u side) 67 | @property 68 | def u(self): 69 | return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] 70 | 71 | # Singular values; 72 | # note that these buffers are just for logging and are not used in training. 73 | @property 74 | def sv(self): 75 | return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] 76 | 77 | # Compute the spectrally-normalized weight 78 | def W_(self): 79 | W_mat = self.weight.view(self.weight.size(0), -1) 80 | if self.transpose: 81 | W_mat = W_mat.t() 82 | # Apply num_itrs power iterations 83 | for _ in range(self.num_itrs): 84 | svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) 85 | # Update the svs 86 | if self.training: 87 | with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! 88 | for i, sv in enumerate(svs): 89 | self.sv[i][:] = sv 90 | return self.weight / svs[0] 91 | 92 | # 3D Conv layer with spectral norm 93 | class SNConv3d(nn.Conv3d, SN): 94 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 95 | padding=0, dilation=1, groups=1, bias=True, 96 | num_svs=1, num_itrs=1, eps=1e-12): 97 | nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, stride, 98 | padding, dilation, groups, bias) 99 | SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) 100 | def forward(self, x): 101 | return F.conv3d(x, self.W_(), self.bias, self.stride, 102 | self.padding, self.dilation, self.groups) 103 | 104 | 105 | # Linear layer with spectral norm 106 | class SNLinear(nn.Linear, SN): 107 | def __init__(self, in_features, out_features, bias=True, 108 | num_svs=1, num_itrs=1, eps=1e-12): 109 | nn.Linear.__init__(self, in_features, out_features, bias) 110 | SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) 111 | def forward(self, x): 112 | return F.linear(x, self.W_(), self.bias) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Amortized GAN (HA-GAN) 2 | 3 | Official PyTorch implementation for paper *Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis*, accepted to *IEEE Journal of Biomedical and Health Informatics* 4 | 5 |
5 | ![Overview of HA-GAN](figures/main_github.png) 6 | 7 |
8 | 9 | #### [[Paper & Supplementary Material]](https://ieeexplore.ieee.org/abstract/document/9770375) 10 | 11 | Generative Adversarial Networks (GANs) have many potential medical imaging applications. Due to the limited memory of Graphical Processing Units (GPUs), most current 3D GAN models are trained on low-resolution medical images. In this work, we propose a novel end-to-end GAN architecture that can generate high-resolution 3D images. We achieve this goal by using different configurations between training and inference. During training, we adopt a hierarchical structure that simultaneously generates a low-resolution version of the image and a randomly selected sub-volume of the high-resolution image. The hierarchical design has two advantages: First, the memory demand for training on high-resolution images is amortized among sub-volumes. Furthermore, anchoring the high-resolution sub-volumes to a single low-resolution image ensures anatomical consistency between sub-volumes. During inference, our model can directly generate full high-resolution images. We also incorporate an encoder (hidden in the figure to improve clarity) into the model to extract features from the images. 12 | 13 | ### Requirements 14 | - PyTorch 15 | - scikit-image 16 | - nibabel 17 | - nilearn 18 | - tensorboardX 19 | - SimpleITK 20 | 21 | ```bash 22 | conda env create --name hagan -f environment.yml 23 | conda activate hagan 24 | ``` 25 | 26 | ### Data Preprocessing 27 | The volumes need to be cropped or resized to 128³ or 256³, and the intensity values need to be scaled to [-1,1]. In addition, we advise trimming blank axial slices. More details can be found in `preprocess.py`: 28 | ```bash 29 | python preprocess.py 30 | ``` 31 | 32 | ### Training 33 | #### Unconditional HA-GAN 34 | ```bash 35 | python train.py --workers 8 --img-size 256 --num-class 0 --exp-name 'HA_GAN_run1' --data-dir DATA_DIR 36 | ``` 37 | #### Conditional HA-GAN 38 | ```bash 39 | python train.py --workers 8 --img-size 256 --num-class N --exp-name 'HA_GAN_cond_run1' --data-dir DATA_DIR 40 | ``` 41 | 42 | Track your training with TensorBoard: 43 |
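A typical invocation (a sketch; `<log_dir>` is a placeholder for whatever directory your run's `SummaryWriter` writes events to, which depends on your configuration):
```bash
tensorboard --logdir <log_dir> --port 6006
```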
![TensorBoard training curves](figures/tensorboard.png) 44 | 45 |
46 | 47 | It will take around 22 hours to train the unconditional HA-GAN for 80000 iterations with two NVIDIA Tesla V100 GPUs. It is suggested to have at least 3,000 images for training to avoid mode collapse; with a smaller dataset, consider [data augmentation](https://docs.monai.io/en/stable/transforms.html#intensity-dict). 48 | 49 | ### Testing 50 | ```bash 51 | visualization.ipynb 52 | evaluation/visualize_feature_MDS.ipynb 53 | python evaluation/fid_score.py 54 | ``` 55 | 56 | ### Sample images 57 |
![Samples generated by HA-GAN](figures/sample_HA_GAN.png) 58 | 59 |
60 | 61 | ### Pretrained weights 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 |
| Dataset | Anatomy | Iteration | Checkpoint |
| --- | --- | --- | --- |
| COPDGene | Lung | 80000 | Download |
| GSP | Brain | 80000 | Download |
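To draw a sample from a downloaded checkpoint, the sketch below mirrors the checkpoint-loading logic in `evaluation/fid_score.py`. The checkpoint path is a placeholder, a full 256³ forward pass needs several GB of memory, and the identity affine used when saving is an assumption (substitute your dataset's spacing/orientation if known):

```python
# Sketch: sample one volume from a pretrained unconditional HA-GAN generator.
import torch
import numpy as np
import nibabel as nib
from models.Model_HA_GAN_256 import Generator
from utils import trim_state_dict_name

CKPT_PATH = "./checkpoint/G_iter80000.pth"  # placeholder path to a downloaded checkpoint

G = Generator(mode='eval', latent_dim=1024, num_class=0)
ckpt = trim_state_dict_name(torch.load(CKPT_PATH, map_location='cpu')['model'])
G.load_state_dict(ckpt)
G.eval()

with torch.no_grad():
    volume = G(torch.randn(1, 1024))  # shape (1, 1, 256, 256, 256), values in [-1, 1]

# Save with an identity affine (assumption; adjust to your data if needed)
nib.save(nib.Nifti1Image(volume[0, 0].numpy().astype(np.float32), np.eye(4)),
         "sample_HA_GAN.nii.gz")
```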
82 | 83 | ### Citation 84 | ``` 85 | @ARTICLE{hagan2022, 86 | author={Sun, Li and Chen, Junxiang and Xu, Yanwu and Gong, Mingming and Yu, Ke and Batmanghelich, Kayhan}, 87 | journal={IEEE Journal of Biomedical and Health Informatics}, 88 | title={Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis}, 89 | year={2022}, 90 | volume={26}, 91 | number={8}, 92 | pages={3966-3975}, 93 | doi={10.1109/JBHI.2022.3172976}} 94 | ``` 95 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: 3dgan 2 | channels: 3 | - simpleitk 4 | - pytorch-lts 5 | - nvidia 6 | - anaconda 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=5.1=1_gnu 12 | - blas=1.0=mkl 13 | - brotli=1.0.9=he6710b0_2 14 | - brotlipy=0.7.0=py38h0a891b7_1004 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2022.4.26=h06a4308_0 17 | - cached-property=1.5.2=hd8ed1ab_1 18 | - cached_property=1.5.2=pyha770c72_1 19 | - certifi=2022.5.18.1=py38h06a4308_0 20 | - cffi=1.15.0=py38hd667e15_1 21 | - charset-normalizer=2.0.12=pyhd8ed1ab_0 22 | - cloudpickle=2.0.0=pyhd3eb1b0_0 23 | - cryptography=37.0.2=py38h2b5fc30_0 24 | - cudatoolkit=11.1.74=h6bb024c_0 25 | - cycler=0.11.0=pyhd3eb1b0_0 26 | - cytoolz=0.11.0=py38h7b6447c_0 27 | - dask-core=2022.2.1=pyhd3eb1b0_0 28 | - dbus=1.13.18=hb2f20db_0 29 | - expat=2.4.4=h295c915_0 30 | - ffmpeg=4.2.2=h20bf706_0 31 | - fontconfig=2.13.1=h6c09931_0 32 | - fonttools=4.25.0=pyhd3eb1b0_0 33 | - freetype=2.11.0=h70c0345_0 34 | - fsspec=2022.2.0=pyhd3eb1b0_0 35 | - giflib=5.2.1=h7b6447c_0 36 | - glib=2.69.1=h4ff587b_1 37 | - gmp=6.2.1=h2531618_2 38 | - gnutls=3.6.15=he1e5248_0 39 | - gst-plugins-base=1.14.0=h8213a91_2 40 | - gstreamer=1.14.0=h28cd5cc_2 41 | - h5py=3.2.1=nompi_py38h9915d05_100 42 | - hdf5=1.10.6=nompi_h3c11f04_101 43 | - icu=58.2=he6710b0_3 44 | - idna=3.3=pyhd8ed1ab_0 45 | - imageio=2.9.0=pyhd3eb1b0_0 46 | - intel-openmp=2021.4.0=h06a4308_3561 47 | - joblib=1.1.0=pyhd8ed1ab_0 48 | - jpeg=9b=h024ee3a_2 49 | - kiwisolver=1.4.2=py38h295c915_0 50 | - lame=3.100=h7b6447c_0 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.38=h1181459_1 53 | - libffi=3.3=he6710b0_2 54 | - libgcc-ng=11.2.0=h1234567_0 55 | - libgfortran-ng=7.5.0=h14aa051_20 56 | - libgfortran4=7.5.0=h14aa051_20 57 | - libgomp=11.2.0=h1234567_0 58 | - libiconv=1.16=h516909a_0 59 | - libidn2=2.3.2=h7f8727e_0 60 | - libopus=1.3.1=h7b6447c_0 61 | - libpng=1.6.37=hbc83047_0 62 | - libstdcxx-ng=11.2.0=h1234567_0 63 | - libtasn1=4.16.0=h27cfd23_0 64 | - libtiff=4.1.0=h2733197_1 65 | - libunistring=0.9.10=h27cfd23_0 66 | - libuuid=1.0.3=h7f8727e_2 67 | - libuv=1.40.0=h7b6447c_0 68 | - libvpx=1.7.0=h439df22_0 69 | - libwebp=1.2.0=h89dd481_0 70 | - libxcb=1.15=h7f8727e_0 71 | - libxml2=2.9.12=h74e7548_2 72 | - libxslt=1.1.34=hc22bd24_0 73 | - locket=0.2.1=py38h06a4308_2 74 | - lxml=4.8.0=py38h1f438cf_0 75 | - lz4-c=1.9.3=h295c915_1 76 | - matplotlib=3.5.1=py38h06a4308_1 77 | - matplotlib-base=3.5.1=py38ha18d171_1 78 | - mkl=2021.4.0=h06a4308_640 79 | - mkl-service=2.4.0=py38h7f8727e_0 80 | - mkl_fft=1.3.1=py38hd3c417c_0 81 | - mkl_random=1.2.2=py38h51133e4_0 82 | - munkres=1.1.4=py_0 83 | - ncurses=6.3=h7f8727e_2 84 | - nettle=3.7.3=hbbd107a_1 85 | - networkx=2.7.1=pyhd3eb1b0_0 86 | - nibabel=3.2.2=pyhd8ed1ab_0 87 | - nilearn=0.9.1=pyhd8ed1ab_0 88 | - ninja=1.10.2=h06a4308_5 89 | - ninja-base=1.10.2=hd09550d_5 90 | - 
numpy=1.22.3=py38he7a7128_0 91 | - numpy-base=1.22.3=py38hf524024_0 92 | - openh264=2.1.1=h4ff587b_0 93 | - openssl=1.1.1o=h7f8727e_0 94 | - packaging=21.3=pyhd8ed1ab_0 95 | - pandas=1.4.2=py38h47df419_1 96 | - partd=1.2.0=pyhd3eb1b0_1 97 | - pcre=8.45=h295c915_0 98 | - pillow=9.0.1=py38h22f2fdc_0 99 | - pip=21.2.4=py38h06a4308_0 100 | - pycparser=2.21=pyhd8ed1ab_0 101 | - pydicom=2.3.0=pyh6c4a22f_0 102 | - pyopenssl=22.0.0=pyhd8ed1ab_0 103 | - pyparsing=3.0.9=pyhd8ed1ab_0 104 | - pyqt=5.9.2=py38h05f1152_4 105 | - pysocks=1.7.1=py38h578d9bd_5 106 | - python=3.8.13=h12debd9_0 107 | - python-dateutil=2.8.2=pyhd8ed1ab_0 108 | - python_abi=3.8=2_cp38 109 | - pytorch=1.8.2=py3.8_cuda11.1_cudnn8.0.5_0 110 | - pytz=2022.1=pyhd8ed1ab_0 111 | - pywavelets=1.3.0=py38h7f8727e_0 112 | - pyyaml=6.0=py38h7f8727e_1 113 | - qt=5.9.7=h5867ecd_1 114 | - readline=8.1.2=h7f8727e_1 115 | - requests=2.27.1=pyhd8ed1ab_0 116 | - scikit-image=0.19.2=py38h51133e4_0 117 | - scikit-learn=1.0.2=py38h51133e4_1 118 | - scipy=1.7.3=py38hc147768_0 119 | - setuptools=61.2.0=py38h06a4308_0 120 | - simpleitk=2.0.2=py38h3fd9d12_0 121 | - sip=4.19.13=py38h295c915_0 122 | - six=1.16.0=pyhd3eb1b0_1 123 | - sqlite=3.38.3=hc218d9a_0 124 | - threadpoolctl=3.1.0=pyh8a188c0_0 125 | - tifffile=2020.10.1=py38hdd07704_2 126 | - tk=8.6.11=h1ccaba5_1 127 | - toolz=0.11.2=pyhd3eb1b0_0 128 | - torchaudio=0.8.2=py38 129 | - torchvision=0.9.2=py38_cu111 130 | - tornado=6.1=py38h27cfd23_0 131 | - typing_extensions=4.1.1=pyh06a4308_0 132 | - urllib3=1.26.9=pyhd8ed1ab_0 133 | - wheel=0.37.1=pyhd3eb1b0_0 134 | - x264=1!157.20191217=h7b6447c_0 135 | - xz=5.2.5=h7f8727e_1 136 | - yaml=0.2.5=h7b6447c_0 137 | - zlib=1.2.12=h7f8727e_2 138 | - zstd=1.4.9=haebb681_0 139 | - pip: 140 | - protobuf==3.20.1 141 | - tensorboardx==2.5 142 | prefix: /ocean/projects/asc170022p/lisun/miniconda3/envs/3dgan 143 | -------------------------------------------------------------------------------- /evaluation/resnet3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | dilation=dilation, 21 | stride=stride, 22 | padding=dilation, 23 | bias=False) 24 | 25 | 26 | def downsample_basic_block(x, planes, stride, no_cuda=False): 27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 28 | zero_pads = torch.Tensor( 29 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 30 | out.size(4)).zero_() 31 | if not no_cuda: 32 | if isinstance(out.data, torch.cuda.FloatTensor): 33 | zero_pads = zero_pads.cuda() 34 | 35 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 36 | 37 | return out 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 44 | super(BasicBlock, self).__init__() 45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) 46 | self.bn1 = nn.BatchNorm3d(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation) 49 | self.bn2 = nn.BatchNorm3d(planes) 50 | 
self.downsample = downsample 51 | self.stride = stride 52 | self.dilation = dilation 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if self.downsample is not None: 64 | residual = self.downsample(x) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 76 | super(Bottleneck, self).__init__() 77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 78 | self.bn1 = nn.BatchNorm3d(planes) 79 | self.conv2 = nn.Conv3d( 80 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) 81 | self.bn2 = nn.BatchNorm3d(planes) 82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm3d(planes * 4) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | self.dilation = dilation 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | class upsample(nn.Module): 112 | def forward(self, inp): 113 | return F.interpolate(inp, scale_factor = 2) 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, 118 | block, 119 | layers, 120 | num_seg_classes=2, 121 | shortcut_type='B', 122 | no_cuda = False): 123 | self.inplanes = 64 124 | self.no_cuda = no_cuda 125 | super(ResNet, self).__init__() 126 | self.conv1 = nn.Conv3d( 127 | 1, 128 | 64, 129 | kernel_size=7, 130 | stride=(2, 2, 2), 131 | padding=(3, 3, 3), 132 | bias=False) 133 | 134 | self.bn1 = nn.BatchNorm3d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 138 | self.layer2 = self._make_layer( 139 | block, 128, layers[1], shortcut_type, stride=2) 140 | self.layer3 = self._make_layer( 141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2) 142 | self.layer4 = self._make_layer( 143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4) 144 | 145 | self.conv_seg = nn.Sequential( 146 | upsample(), 147 | nn.ConvTranspose3d( 148 | 512 * block.expansion, 149 | 32, 150 | 2, 151 | stride=2 152 | ), 153 | nn.BatchNorm3d(32), 154 | nn.ReLU(inplace=True), 155 | upsample(), 156 | nn.Conv3d( 157 | 32, 158 | 32, 159 | kernel_size=3, 160 | stride=(1, 1, 1), 161 | padding=(1, 1, 1), 162 | bias=False), 163 | nn.BatchNorm3d(32), 164 | nn.ReLU(inplace=True), 165 | nn.Conv3d( 166 | 32, 167 | num_seg_classes, 168 | kernel_size=1, 169 | stride=(1, 1, 1), 170 | bias=False), 171 | nn.Sigmoid() 172 | ) 173 | 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv3d): 176 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 177 | elif isinstance(m, nn.BatchNorm3d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): 182 | downsample = None 183 | if stride != 1 or 
self.inplanes != planes * block.expansion: 184 | if shortcut_type == 'A': 185 | downsample = partial( 186 | downsample_basic_block, 187 | planes=planes * block.expansion, 188 | stride=stride, 189 | no_cuda=self.no_cuda) 190 | else: 191 | downsample = nn.Sequential( 192 | nn.Conv3d( 193 | self.inplanes, 194 | planes * block.expansion, 195 | kernel_size=1, 196 | stride=stride, 197 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 198 | 199 | layers = [] 200 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)) 201 | self.inplanes = planes * block.expansion 202 | for i in range(1, blocks): 203 | layers.append(block(self.inplanes, planes, dilation=dilation)) 204 | 205 | return nn.Sequential(*layers) 206 | 207 | def forward(self, x): 208 | x = self.conv1(x) 209 | x = self.bn1(x) 210 | x = self.relu(x) 211 | x = self.maxpool(x) 212 | x = self.layer1(x) 213 | x = self.layer2(x) 214 | x = self.layer3(x) 215 | x = self.layer4(x) 216 | #print(x.shape) 217 | x = self.conv_seg(x) 218 | 219 | return x 220 | 221 | def resnet10(**kwargs): 222 | """Constructs a ResNet-10 model. 223 | """ 224 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 225 | return model 226 | 227 | 228 | def resnet18(**kwargs): 229 | """Constructs a ResNet-18 model. 230 | """ 231 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 232 | return model 233 | 234 | 235 | def resnet34(**kwargs): 236 | """Constructs a ResNet-34 model. 237 | """ 238 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 239 | return model 240 | 241 | 242 | def resnet50(**kwargs): 243 | """Constructs a ResNet-50 model. 244 | """ 245 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 246 | return model 247 | 248 | 249 | def resnet101(**kwargs): 250 | """Constructs a ResNet-101 model. 251 | """ 252 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 253 | return model 254 | 255 | 256 | def resnet152(**kwargs): 257 | """Constructs a ResNet-152 model. 258 | """ 259 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 260 | return model 261 | 262 | 263 | def resnet200(**kwargs): 264 | """Constructs a ResNet-200 model.
265 | """ 266 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 267 | return model 268 | -------------------------------------------------------------------------------- /models/Model_HA_GAN_128.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | from models.layers import SNConv3d, SNLinear 8 | 9 | class Code_Discriminator(nn.Module): 10 | def __init__(self, code_size, num_units=256): 11 | super(Code_Discriminator, self).__init__() 12 | 13 | self.l1 = nn.Sequential(SNLinear(code_size, num_units), 14 | nn.LeakyReLU(0.2,inplace=True)) 15 | self.l2 = nn.Sequential(SNLinear(num_units, num_units), 16 | nn.LeakyReLU(0.2,inplace=True)) 17 | self.l3 = SNLinear(num_units, 1) 18 | 19 | def forward(self, x): 20 | x = self.l1(x) 21 | x = self.l2(x) 22 | x = self.l3(x) 23 | 24 | return x 25 | 26 | class Sub_Encoder(nn.Module): 27 | def __init__(self, channel=256, latent_dim=1024): 28 | super(Sub_Encoder, self).__init__() 29 | 30 | self.relu = nn.ReLU() 31 | self.conv2 = nn.Conv3d(channel//4, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16] 32 | self.bn2 = nn.GroupNorm(8, channel//4) 33 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8] 34 | self.bn3 = nn.GroupNorm(8, channel//2) 35 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4] 36 | self.bn4 = nn.GroupNorm(8, channel) 37 | self.conv5 = nn.Conv3d(channel, latent_dim, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1] 38 | 39 | def forward(self, h): 40 | h = self.conv2(h) 41 | h = self.relu(self.bn2(h)) 42 | h = self.conv3(h) 43 | h = self.relu(self.bn3(h)) 44 | h = self.conv4(h) 45 | h = self.relu(self.bn4(h)) 46 | h = self.conv5(h).squeeze() 47 | return h 48 | 49 | class Encoder(nn.Module): 50 | def __init__(self, channel=64): 51 | super(Encoder, self).__init__() 52 | 53 | self.relu = nn.ReLU() 54 | self.conv1 = nn.Conv3d(1, channel//2, kernel_size=4, stride=2, padding=1) # in:[16,128,128], out:[8,64,64] 55 | self.bn1 = nn.GroupNorm(8, channel//2) 56 | self.conv2 = nn.Conv3d(channel//2, channel//2, kernel_size=3, stride=1, padding=1) # out:[8,64,64] 57 | self.bn2 = nn.GroupNorm(8, channel//2) 58 | self.conv3 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,32,32] 59 | self.bn3 = nn.GroupNorm(8, channel) 60 | 61 | def forward(self, h): 62 | h = self.conv1(h) 63 | h = self.relu(self.bn1(h)) 64 | 65 | h = self.conv2(h) 66 | h = self.relu(self.bn2(h)) 67 | 68 | h = self.conv3(h) 69 | h = self.relu(self.bn3(h)) 70 | return h 71 | 72 | class Sub_Discriminator(nn.Module): 73 | def __init__(self, num_class=0, channel=256): 74 | super(Sub_Discriminator, self).__init__() 75 | self.channel = channel 76 | self.num_class = num_class 77 | 78 | self.conv2 = SNConv3d(1, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16] 79 | self.conv3 = SNConv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8] 80 | self.conv4 = SNConv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4] 81 | self.conv5 = SNConv3d(channel, 1+num_class, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1] 82 | 83 | def forward(self, h): 84 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2) 85 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2) 86 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2) 87 | if 
self.num_class == 0: 88 | h = self.conv5(h).view((-1,1)) 89 | return h 90 | else: 91 | h = self.conv5(h).view((-1,1+self.num_class)) 92 | return h[:,:1], h[:,1:] 93 | 94 | class Discriminator(nn.Module): 95 | def __init__(self, num_class=0, channel=512): 96 | super(Discriminator, self).__init__() 97 | self.channel = channel 98 | self.num_class = num_class 99 | 100 | # D^H 101 | self.conv2 = SNConv3d(1, channel//16, kernel_size=4, stride=2, padding=1) # out:[8,64,64,64] 102 | self.conv3 = SNConv3d(channel//16, channel//8, kernel_size=4, stride=2, padding=1) # out:[4,32,32,32] 103 | self.conv4 = SNConv3d(channel//8, channel//4, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[2,16,16,16] 104 | self.conv5 = SNConv3d(channel//4, channel//2, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[1,8,8,8] 105 | self.conv6 = SNConv3d(channel//2, channel, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1)) # out:[1,4,4,4] 106 | self.conv7 = SNConv3d(channel, channel//4, kernel_size=(1,4,4), stride=1, padding=0) # out:[1,1,1,1] 107 | self.fc1 = SNLinear(channel//4+1, channel//8) 108 | self.fc2 = SNLinear(channel//8, 1) 109 | if num_class>0: 110 | self.fc2_class = SNLinear(channel//8, num_class) 111 | 112 | # D^L 113 | self.sub_D = Sub_Discriminator(num_class) 114 | 115 | def forward(self, h, h_small, crop_idx): 116 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2) 117 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2) 118 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2) 119 | h = F.leaky_relu(self.conv5(h), negative_slope=0.2) 120 | h = F.leaky_relu(self.conv6(h), negative_slope=0.2) 121 | h = F.leaky_relu(self.conv7(h), negative_slope=0.2).squeeze() 122 | h = torch.cat([h, (crop_idx / 112. * torch.ones((h.size(0), 1))).cuda()], 1) # 128*7/8 123 | h = F.leaky_relu(self.fc1(h), negative_slope=0.2) 124 | h_logit = self.fc2(h) 125 | if self.num_class>0: 126 | h_class_logit = self.fc2_class(h) 127 | 128 | h_small_logit, h_small_class_logit = self.sub_D(h_small) 129 | return (h_logit+ h_small_logit)/2., (h_class_logit+ h_small_class_logit)/2. 130 | else: 131 | h_small_logit = self.sub_D(h_small) 132 | return (h_logit+ h_small_logit)/2. 
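# Design note on the Discriminator above: the final score is the average of
# the high-resolution critic D^H (applied to the sampled sub-volume) and the
# low-resolution critic D^L (applied to the full downsampled volume). crop_idx
# is appended to the D^H features as a scalar normalized by 112 (= 128 * 7/8,
# the largest valid crop offset), so D^H knows which sub-volume it is scoring.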
133 | 134 | 135 | class Sub_Generator(nn.Module): 136 | def __init__(self, channel:int=16): 137 | super(Sub_Generator, self).__init__() 138 | _c = channel 139 | 140 | self.relu = nn.ReLU() 141 | self.tp_conv1 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True) 142 | self.bn1 = nn.GroupNorm(8, _c*2) 143 | 144 | self.tp_conv2 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True) 145 | self.bn2 = nn.GroupNorm(8, _c) 146 | 147 | self.tp_conv3 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True) 148 | 149 | def forward(self, h): 150 | 151 | h = self.tp_conv1(h) 152 | h = self.relu(self.bn1(h)) 153 | 154 | h = self.tp_conv2(h) 155 | h = self.relu(self.bn2(h)) 156 | 157 | h = self.tp_conv3(h) 158 | h = torch.tanh(h) 159 | return h 160 | 161 | class Generator(nn.Module): 162 | def __init__(self, mode="train", latent_dim=1024, channel=32, num_class=0): 163 | super(Generator, self).__init__() 164 | _c = channel 165 | 166 | self.mode = mode 167 | self.relu = nn.ReLU() 168 | self.num_class = num_class 169 | 170 | # G^A and G^H 171 | self.fc1 = nn.Linear(latent_dim+num_class, 4*4*4*_c*16) 172 | 173 | self.tp_conv1 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True) 174 | self.bn1 = nn.GroupNorm(8, _c*16) 175 | 176 | self.tp_conv2 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True) 177 | self.bn2 = nn.GroupNorm(8, _c*16) 178 | 179 | self.tp_conv3 = nn.Conv3d(_c*16, _c*8, kernel_size=3, stride=1, padding=1, bias=True) 180 | self.bn3 = nn.GroupNorm(8, _c*8) 181 | 182 | self.tp_conv4 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=True) 183 | self.bn4 = nn.GroupNorm(8, _c*4) 184 | 185 | self.tp_conv5 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True) 186 | self.bn5 = nn.GroupNorm(8, _c*2) 187 | 188 | self.tp_conv6 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True) 189 | self.bn6 = nn.GroupNorm(8, _c) 190 | 191 | self.tp_conv7 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True) 192 | 193 | # G^L 194 | self.sub_G = Sub_Generator(channel=_c//2) 195 | 196 | def forward(self, h, crop_idx=None, class_label=None): 197 | 198 | # Generate from random noise 199 | if crop_idx != None or self.mode=='eval': 200 | if self.num_class > 0: 201 | h = torch.cat((h, class_label), dim=1) 202 | 203 | h = self.fc1(h) 204 | 205 | h = h.view(-1,512,4,4,4) 206 | h = self.tp_conv1(h) 207 | h = self.relu(self.bn1(h)) 208 | 209 | h = F.interpolate(h,scale_factor = 2) 210 | h = self.tp_conv2(h) 211 | h = self.relu(self.bn2(h)) 212 | 213 | h = F.interpolate(h,scale_factor = 2) 214 | h = self.tp_conv3(h) 215 | h = self.relu(self.bn3(h)) 216 | 217 | h = F.interpolate(h,scale_factor = 2) 218 | h = self.tp_conv4(h) 219 | h = self.relu(self.bn4(h)) 220 | 221 | h = self.tp_conv5(h) 222 | h_latent = self.relu(self.bn5(h)) # (32, 32, 32), channel:128 223 | 224 | if self.mode == "train": 225 | h_small = self.sub_G(h_latent) 226 | h = h_latent[:,:,crop_idx//4:crop_idx//4+4,:,:] # Crop sub-volume, out: (4, 32, 32) 227 | else: 228 | h = h_latent 229 | 230 | # Generate from latent feature 231 | h = F.interpolate(h,scale_factor = 2) 232 | h = self.tp_conv6(h) 233 | h = self.relu(self.bn6(h)) # (64, 64, 64) 234 | 235 | h = F.interpolate(h,scale_factor = 2) 236 | h = self.tp_conv7(h) 237 | 238 | h = torch.tanh(h) # (128, 128, 128) 239 | 240 | if crop_idx != None and self.mode == "train": 241 | return h, h_small 242 | return h 243 | 
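A minimal shape check for the model above (a sketch, not part of the repo; run from the repository root so `models` is importable):

```python
import torch
from models.Model_HA_GAN_128 import Generator

# Training mode: returns a high-res sub-volume plus the low-res full volume.
G = Generator(mode='train', latent_dim=1024, channel=32, num_class=0)
noise = torch.randn(2, 1024)
sub_volume, low_res = G(noise, crop_idx=64)   # crop_idx: sub-volume start in [0, 112]
print(sub_volume.shape)  # torch.Size([2, 1, 16, 128, 128])
print(low_res.shape)     # torch.Size([2, 1, 32, 32, 32])

# Eval mode: directly synthesizes the full 128^3 volume.
G_eval = Generator(mode='eval', latent_dim=1024, channel=32, num_class=0)
with torch.no_grad():
    full = G_eval(torch.randn(1, 1024))
print(full.shape)        # torch.Size([1, 1, 128, 128, 128])
```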
-------------------------------------------------------------------------------- /models/Model_HA_GAN_256.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from models.layers import SNConv3d, SNLinear 6 | 7 | class Code_Discriminator(nn.Module): 8 | def __init__(self, code_size, num_units=256): 9 | super(Code_Discriminator, self).__init__() 10 | 11 | self.l1 = nn.Sequential(SNLinear(code_size, num_units), 12 | nn.LeakyReLU(0.2,inplace=True)) 13 | self.l2 = nn.Sequential(SNLinear(num_units, num_units), 14 | nn.LeakyReLU(0.2,inplace=True)) 15 | self.l3 = SNLinear(num_units, 1) 16 | 17 | def forward(self, x): 18 | x = self.l1(x) 19 | x = self.l2(x) 20 | x = self.l3(x) 21 | 22 | return x 23 | 24 | class Sub_Encoder(nn.Module): 25 | def __init__(self, channel=256, latent_dim=1024): 26 | super(Sub_Encoder, self).__init__() 27 | 28 | self.relu = nn.ReLU() 29 | self.conv1 = nn.Conv3d(channel//4, channel//8, kernel_size=4, stride=2, padding=1) # in:[64,64,64], out:[32,32,32] 30 | self.bn1 = nn.GroupNorm(8, channel//8) 31 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16] 32 | self.bn2 = nn.GroupNorm(8, channel//4) 33 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8] 34 | self.bn3 = nn.GroupNorm(8, channel//2) 35 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4] 36 | self.bn4 = nn.GroupNorm(8, channel) 37 | self.conv5 = nn.Conv3d(channel, latent_dim, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1] 38 | 39 | def forward(self, h): 40 | h = self.conv1(h) 41 | h = self.relu(self.bn1(h)) 42 | h = self.conv2(h) 43 | h = self.relu(self.bn2(h)) 44 | h = self.conv3(h) 45 | h = self.relu(self.bn3(h)) 46 | h = self.conv4(h) 47 | h = self.relu(self.bn4(h)) 48 | h = self.conv5(h).squeeze() 49 | return h 50 | 51 | class Encoder(nn.Module): 52 | def __init__(self, channel=64): 53 | super(Encoder, self).__init__() 54 | 55 | self.relu = nn.ReLU() 56 | self.conv1 = nn.Conv3d(1, channel//2, kernel_size=4, stride=2, padding=1) # in:[32,256,256], out:[16,128,128] 57 | self.bn1 = nn.GroupNorm(8, channel//2) 58 | self.conv2 = nn.Conv3d(channel//2, channel//2, kernel_size=3, stride=1, padding=1) # out:[16,128,128] 59 | self.bn2 = nn.GroupNorm(8, channel//2) 60 | self.conv3 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[8,64,64] 61 | self.bn3 = nn.GroupNorm(8, channel) 62 | 63 | def forward(self, h): 64 | h = self.conv1(h) 65 | h = self.relu(self.bn1(h)) 66 | 67 | h = self.conv2(h) 68 | h = self.relu(self.bn2(h)) 69 | 70 | h = self.conv3(h) 71 | h = self.relu(self.bn3(h)) 72 | return h 73 | 74 | # D^L 75 | class Sub_Discriminator(nn.Module): 76 | def __init__(self, num_class=0, channel=256): 77 | super(Sub_Discriminator, self).__init__() 78 | self.channel = channel 79 | self.num_class = num_class 80 | 81 | self.conv1 = SNConv3d(1, channel//8, kernel_size=4, stride=2, padding=1) # in:[64,64,64], out:[32,32,32] 82 | self.conv2 = SNConv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16] 83 | self.conv3 = SNConv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8] 84 | self.conv4 = SNConv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4] 85 | self.conv5 = SNConv3d(channel, 1+num_class, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1] 
86 | 87 | def forward(self, h): 88 | h = F.leaky_relu(self.conv1(h), negative_slope=0.2) 89 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2) 90 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2) 91 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2) 92 | if self.num_class == 0: 93 | h = self.conv5(h).view((-1,1)) 94 | return h 95 | else: 96 | h = self.conv5(h).view((-1,1+self.num_class)) 97 | return h[:,:1], h[:,1:] 98 | 99 | class Discriminator(nn.Module): 100 | def __init__(self, num_class=0, channel=512): 101 | super(Discriminator, self).__init__() 102 | self.channel = channel 103 | self.num_class = num_class 104 | 105 | # D^H 106 | self.conv1 = SNConv3d(1, channel//32, kernel_size=4, stride=2, padding=1) # in:[32,256,256], out:[16,128,128] 107 | self.conv2 = SNConv3d(channel//32, channel//16, kernel_size=4, stride=2, padding=1) # out:[8,64,64,64] 108 | self.conv3 = SNConv3d(channel//16, channel//8, kernel_size=4, stride=2, padding=1) # out:[4,32,32,32] 109 | self.conv4 = SNConv3d(channel//8, channel//4, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[2,16,16,16] 110 | self.conv5 = SNConv3d(channel//4, channel//2, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[1,8,8,8] 111 | self.conv6 = SNConv3d(channel//2, channel, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1)) # out:[1,4,4,4] 112 | self.conv7 = SNConv3d(channel, channel//4, kernel_size=(1,4,4), stride=1, padding=0) # out:[1,1,1,1] 113 | self.fc1 = SNLinear(channel//4+1, channel//8) 114 | self.fc2 = SNLinear(channel//8, 1) 115 | if num_class>0: 116 | self.fc2_class = SNLinear(channel//8, num_class) 117 | 118 | # D^L 119 | self.sub_D = Sub_Discriminator(num_class) 120 | 121 | def forward(self, h, h_small, crop_idx): 122 | h = F.leaky_relu(self.conv1(h), negative_slope=0.2) 123 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2) 124 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2) 125 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2) 126 | h = F.leaky_relu(self.conv5(h), negative_slope=0.2) 127 | h = F.leaky_relu(self.conv6(h), negative_slope=0.2) 128 | h = F.leaky_relu(self.conv7(h), negative_slope=0.2).squeeze() 129 | h = torch.cat([h, (crop_idx / 224. * torch.ones((h.size(0), 1))).cuda()], 1) # 256*7/8 130 | h = F.leaky_relu(self.fc1(h), negative_slope=0.2) 131 | h_logit = self.fc2(h) 132 | if self.num_class>0: 133 | h_class_logit = self.fc2_class(h) 134 | 135 | h_small_logit, h_small_class_logit = self.sub_D(h_small) 136 | return (h_logit+ h_small_logit)/2., (h_class_logit+ h_small_class_logit)/2. 137 | else: 138 | h_small_logit = self.sub_D(h_small) 139 | return (h_logit+ h_small_logit)/2. 
140 | 141 | 142 | class Sub_Generator(nn.Module): 143 | def __init__(self, channel:int=16): 144 | super(Sub_Generator, self).__init__() 145 | _c = channel 146 | 147 | self.relu = nn.ReLU() 148 | self.tp_conv1 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True) 149 | self.bn1 = nn.GroupNorm(8, _c*2) 150 | 151 | self.tp_conv2 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True) 152 | self.bn2 = nn.GroupNorm(8, _c) 153 | 154 | self.tp_conv3 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True) 155 | 156 | def forward(self, h): 157 | 158 | h = self.tp_conv1(h) 159 | h = self.relu(self.bn1(h)) 160 | 161 | h = self.tp_conv2(h) 162 | h = self.relu(self.bn2(h)) 163 | 164 | h = self.tp_conv3(h) 165 | h = torch.tanh(h) 166 | return h 167 | 168 | class Generator(nn.Module): 169 | def __init__(self, mode="train", latent_dim=1024, channel=32, num_class=0): 170 | super(Generator, self).__init__() 171 | _c = channel 172 | 173 | self.mode = mode 174 | self.relu = nn.ReLU() 175 | self.num_class = num_class 176 | 177 | # G^A and G^H 178 | self.fc1 = nn.Linear(latent_dim+num_class, 4*4*4*_c*16) 179 | 180 | self.tp_conv1 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True) 181 | self.bn1 = nn.GroupNorm(8, _c*16) 182 | 183 | self.tp_conv2 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True) 184 | self.bn2 = nn.GroupNorm(8, _c*16) 185 | 186 | self.tp_conv3 = nn.Conv3d(_c*16, _c*8, kernel_size=3, stride=1, padding=1, bias=True) 187 | self.bn3 = nn.GroupNorm(8, _c*8) 188 | 189 | self.tp_conv4 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=True) 190 | self.bn4 = nn.GroupNorm(8, _c*4) 191 | 192 | self.tp_conv5 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True) 193 | self.bn5 = nn.GroupNorm(8, _c*2) 194 | 195 | self.tp_conv6 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True) 196 | self.bn6 = nn.GroupNorm(8, _c) 197 | 198 | self.tp_conv7 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True) 199 | 200 | # G^L 201 | self.sub_G = Sub_Generator(channel=_c//2) 202 | 203 | def forward(self, h, crop_idx=None, class_label=None): 204 | 205 | # Generate from random noise 206 | if crop_idx != None or self.mode=='eval': 207 | if self.num_class > 0: 208 | h = torch.cat((h, class_label), dim=1) 209 | 210 | h = self.fc1(h) 211 | 212 | h = h.view(-1,512,4,4,4) 213 | h = self.tp_conv1(h) 214 | h = self.relu(self.bn1(h)) 215 | 216 | h = F.interpolate(h,scale_factor = 2) 217 | h = self.tp_conv2(h) 218 | h = self.relu(self.bn2(h)) 219 | 220 | h = F.interpolate(h,scale_factor = 2) 221 | h = self.tp_conv3(h) 222 | h = self.relu(self.bn3(h)) 223 | 224 | h = F.interpolate(h,scale_factor = 2) 225 | h = self.tp_conv4(h) 226 | h = self.relu(self.bn4(h)) 227 | 228 | h = F.interpolate(h,scale_factor = 2) 229 | h = self.tp_conv5(h) 230 | h_latent = self.relu(self.bn5(h)) # (64, 64, 64), channel:128 231 | 232 | if self.mode == "train": 233 | h_small = self.sub_G(h_latent) 234 | h = h_latent[:,:,crop_idx//4:crop_idx//4+8,:,:] # Crop, out: (8, 64, 64) 235 | else: 236 | h = h_latent 237 | 238 | # Generate from latent feature 239 | h = F.interpolate(h,scale_factor = 2) 240 | h = self.tp_conv6(h) 241 | h = self.relu(self.bn6(h)) # (128, 128, 128) 242 | 243 | h = F.interpolate(h,scale_factor = 2) 244 | h = self.tp_conv7(h) 245 | 246 | h = torch.tanh(h) # (256, 256, 256) 247 | 248 | if crop_idx != None and self.mode == "train": 249 | return h, h_small 250 | return h 251 | 
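In the conditional setting, the class label is concatenated to the latent code before `fc1` (hence the `latent_dim+num_class` input size), which implies a one-hot (or soft) label vector. A minimal sketch, assuming a hypothetical 2-class setup; note that a full 256³ forward pass is memory-hungry:

```python
import torch
import torch.nn.functional as F
from models.Model_HA_GAN_256 import Generator

NUM_CLASS = 2  # hypothetical number of classes
G = Generator(mode='eval', latent_dim=1024, num_class=NUM_CLASS).eval()

noise = torch.randn(1, 1024)
label = F.one_hot(torch.tensor([1]), num_classes=NUM_CLASS).float()  # condition on class 1
with torch.no_grad():
    volume = G(noise, class_label=label)  # (1, 1, 256, 256, 256), values in [-1, 1]
```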
-------------------------------------------------------------------------------- /evaluation/fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import time 5 | from argparse import ArgumentParser 6 | from collections import OrderedDict # used by trim_state_dict_name 7 | import numpy as np 8 | from scipy import linalg 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | from volume_dataset import Volume_Dataset 15 | 16 | from models.Model_HA_GAN_256 import Generator, Encoder, Sub_Encoder 17 | from resnet3D import resnet50 18 | 19 | torch.manual_seed(0) 20 | torch.cuda.manual_seed_all(0) 21 | torch.backends.cudnn.benchmark = True 22 | 23 | parser = ArgumentParser() 24 | parser.add_argument('--path', type=str, default='') 25 | parser.add_argument('--real_suffix', type=str, default='eval_600_size_256_resnet50_fold') 26 | parser.add_argument('--img_size', type=int, default=256) 27 | parser.add_argument('--batch_size', type=int, default=2) 28 | parser.add_argument('--num_workers', type=int, default=8) 29 | parser.add_argument('--num_samples', type=int, default=2048) 30 | parser.add_argument('--dims', type=int, default=2048) 31 | parser.add_argument('--ckpt_step', type=int, default=80000) 32 | parser.add_argument('--latent_dim', type=int, default=1024) 33 | parser.add_argument('--basename', type=str, default="256_1024_Alpha_SN_v4plus_4_l1_GN_threshold_600_fold") 34 | parser.add_argument('--fold', type=int) 35 | 36 | def trim_state_dict_name(ckpt): 37 | new_state_dict = OrderedDict() 38 | for k, v in ckpt.items(): 39 | name = k[7:] # remove `module.` 40 | new_state_dict[name] = v 41 | return new_state_dict 42 | 43 | class Flatten(torch.nn.Module): 44 | def forward(self, inp): 45 | return inp.view(inp.size(0), -1) 46 | 47 | def generate_samples(args): 48 | G = Generator(mode='eval', latent_dim=args.latent_dim, num_class=0) 49 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/G_iter"+str(args.ckpt_step)+".pth" 50 | ckpt = torch.load(ckpt_path)['model'] 51 | ckpt = trim_state_dict_name(ckpt) 52 | G.load_state_dict(ckpt) 53 | 54 | E = Encoder() 55 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/E_iter"+str(args.ckpt_step)+".pth" 56 | ckpt = torch.load(ckpt_path)['model'] 57 | ckpt = trim_state_dict_name(ckpt) 58 | E.load_state_dict(ckpt) 59 | 60 | Sub_E = Sub_Encoder(latent_dim=args.latent_dim) 61 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/Sub_E_iter"+str(args.ckpt_step)+".pth" 62 | ckpt = torch.load(ckpt_path)['model'] 63 | ckpt = trim_state_dict_name(ckpt) 64 | Sub_E.load_state_dict(ckpt) 65 | print("Weights step", args.ckpt_step, "loaded.") 66 | del ckpt 67 | 68 | G = nn.DataParallel(G).cuda() 69 | E = nn.DataParallel(E).cuda() 70 | Sub_E = nn.DataParallel(Sub_E).cuda() 71 | 72 | G.eval() 73 | E.eval() 74 | Sub_E.eval() 75 | 76 | 77 | model = get_feature_extractor() 78 | pred_arr = np.empty((args.num_samples, args.dims)) 79 | 80 | for i in range(args.num_samples//args.batch_size): 81 | if i % 10 == 0: 82 | print('\rPropagating batch %d' % i, end='', flush=True) 83 | with torch.no_grad(): 84 | 85 | noise = torch.randn((args.batch_size, args.latent_dim)).cuda() 86 | x_rand = G(noise) # mode='eval': synthesize the full volume directly from noise 87 | # range: [-1,1] 88 | x_rand = x_rand.detach() 89 | pred = model(x_rand) 90 | 91 | if (i+1)*args.batch_size > pred_arr.shape[0]: 92 | pred_arr[i*args.batch_size:] = pred.cpu().numpy() 93 | else: 94 | pred_arr[i*args.batch_size:(i+1)*args.batch_size] = pred.cpu().numpy() 95 |
96 | print(' done') 97 | return pred_arr 98 | 99 | def get_activations_from_dataloader(model, data_loader, args): 100 | 101 | pred_arr = np.empty((args.num_samples, args.dims)) 102 | for i, batch in enumerate(data_loader): 103 | if i % 10 == 0: 104 | print('\rPropagating batch %d' % i, end='', flush=True) 105 | batch = batch.float().cuda() 106 | with torch.no_grad(): 107 | pred = model(batch) 108 | 109 | if (i+1)*args.batch_size > pred_arr.shape[0]: # guard the final (possibly partial) batch 110 | pred_arr[i*args.batch_size:] = pred.cpu().numpy() 111 | else: 112 | pred_arr[i*args.batch_size:(i+1)*args.batch_size] = pred.cpu().numpy() 113 | print(' done') 114 | return pred_arr 115 | 116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 117 | """Numpy implementation of the Frechet Distance. 118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 119 | and X_2 ~ N(mu_2, C_2) is 120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 121 | 122 | Stable version by Dougal J. Sutherland. 123 | 124 | Params: 125 | -- mu1 : Numpy array containing the activations of a layer of the 126 | inception net (like returned by the function 'get_predictions') 127 | for generated samples. 128 | -- mu2 : The sample mean over activations, precalculated on a 129 | representative data set. 130 | -- sigma1: The covariance matrix over activations for generated samples. 131 | -- sigma2: The covariance matrix over activations, precalculated on a 132 | representative data set. 133 | 134 | Returns: 135 | -- : The Frechet Distance. 136 | """ 137 | 138 | mu1 = np.atleast_1d(mu1) 139 | mu2 = np.atleast_1d(mu2) 140 | 141 | sigma1 = np.atleast_2d(sigma1) 142 | sigma2 = np.atleast_2d(sigma2) 143 | 144 | assert mu1.shape == mu2.shape, \ 145 | 'Training and test mean vectors have different lengths' 146 | assert sigma1.shape == sigma2.shape, \ 147 | 'Training and test covariances have different dimensions' 148 | 149 | diff = mu1 - mu2 150 | 151 | # Product might be almost singular 152 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 153 | if not np.isfinite(covmean).all(): 154 | msg = ('fid calculation produces singular product; ' 155 | 'adding %s to diagonal of cov estimates') % eps 156 | print(msg) 157 | offset = np.eye(sigma1.shape[0]) * eps 158 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 159 | 160 | # Numerical error might give slight imaginary component 161 | if np.iscomplexobj(covmean): 162 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 163 | m = np.max(np.abs(covmean.imag)) 164 | raise ValueError('Imaginary component {}'.format(m)) 165 | covmean = covmean.real 166 | 167 | tr_covmean = np.trace(covmean) 168 | 169 | return (diff.dot(diff) + np.trace(sigma1) + 170 | np.trace(sigma2) - 2 * tr_covmean) 171 | 172 | def post_process(act): 173 | mu = np.mean(act, axis=0) 174 | sigma = np.cov(act, rowvar=False) 175 | return mu, sigma 176 | 177 | def get_feature_extractor(): 178 | model = resnet50(shortcut_type='B') 179 | model.conv_seg = nn.Sequential(nn.AdaptiveAvgPool3d((1, 1, 1)), 180 | Flatten()) # (N, 512) 181 | # ckpt from https://drive.google.com/file/d/1399AsrYpQDi1vq6ciKRQkfknLsQQyigM/view?usp=sharing 182 | ckpt = torch.load("../gnn_shared/ckpt/pretrain/resnet_50.pth") 183 | ckpt = trim_state_dict_name(ckpt["state_dict"]) 184 | model.load_state_dict(ckpt) # No conv_seg module in ckpt 185 | model = nn.DataParallel(model).cuda() 186 | model.eval() 187 | print("Feature extractor weights loaded") 188 | return model 189 | 190 | def calculate_fid_real(args): 191 |
"""Calculates the FID of two paths""" 192 | assert os.path.exists("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy") 193 | 194 | model = get_feature_extractor() 195 | #dataset = COPD_dataset(img_size=args.img_size, stage="train", fold=args.fold, threshold=600) 196 | dataset = Brain_dataset(img_size=args.img_size, stage="train", fold=args.fold) 197 | args.num_samples = len(dataset) 198 | print("Number of samples:", args.num_samples) 199 | data_loader = torch.utils.data.DataLoader(dataset,batch_size=args.batch_size,drop_last=False, 200 | shuffle=False,num_workers=args.num_workers) 201 | act = get_activations_from_dataloader(model, data_loader, args) 202 | np.save("./results/fid/pred_arr_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", act) 203 | #np.save("./results/fid/pred_arr_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", act) 204 | #calculate_mmd(args, act) 205 | m, s = post_process(act) 206 | 207 | m1 = np.load("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy") 208 | s1 = np.load("./results/fid/s_real_"+args.real_suffix+str(args.fold)+".npy") 209 | 210 | fid_value = calculate_frechet_distance(m1, s1, m, s) 211 | print('FID: ', fid_value) 212 | #np.save("./results/fid/m_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", m) 213 | #np.save("./results/fid/s_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", s) 214 | #np.save("./results/fid/m_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", m) 215 | #np.save("./results/fid/s_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", s) 216 | 217 | def calculate_mmd_fake(args): 218 | assert os.path.exists("./results/fid/pred_arr_real_"+args.real_suffix+str(args.fold)+".npy") 219 | act = generate_samples(args) 220 | calculate_mmd(args, act) 221 | 222 | def calculate_fid_fake(args): 223 | #assert os.path.exists("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy") 224 | act = generate_samples(args) 225 | m2, s2 = post_process(act) 226 | 227 | m1 = np.load("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy") 228 | s1 = np.load("./results/fid/s_real_"+args.real_suffix+str(args.fold)+".npy") 229 | 230 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 231 | print('FID: ', fid_value) 232 | 233 | 234 | def calculate_mmd(args, act): 235 | from torch_two_sample.statistics_diff import MMDStatistic 236 | 237 | act_real = np.load("./results/fid/pred_arr_real_"+args.real_suffix+str(args.fold)+".npz")['arr_0'] 238 | mmd = MMDStatistic(act_real.shape[0], act.shape[0]) 239 | sample_1 = torch.from_numpy(act_real) 240 | sample_2 = torch.from_numpy(act) 241 | 242 | # Need to install updated MMD package at https://github.com/lisun-ai/torch-two-sample for support of median alphas 243 | test_statistics, ret_matrix = mmd(sample_1, sample_2, alphas='median', ret_matrix=True) 244 | #p = mmd.pval(ret_matrix.float(), n_permutations=1000) 245 | 246 | print("\nMMD test statistics:", test_statistics.item()) 247 | 248 | if __name__ == '__main__': 249 | args = parser.parse_args() 250 | start_time = time.time() 251 | calculate_fid_real(args) 252 | calculate_fid_fake(args) 253 | print("Done. 
Using", (time.time()-start_time)//60, "minutes.") 254 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # train HA-GAN 4 | # Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis 5 | # https://ieeexplore.ieee.org/abstract/document/9770375 6 | 7 | import numpy as np 8 | import torch 9 | import os 10 | import json 11 | import argparse 12 | 13 | from torch import nn 14 | from torch import optim 15 | from torch.nn import functional as F 16 | from tensorboardX import SummaryWriter 17 | import nibabel as nib 18 | from nilearn import plotting 19 | 20 | from utils import trim_state_dict_name, inf_train_gen 21 | from volume_dataset import Volume_Dataset 22 | 23 | import matplotlib.pyplot as plt 24 | 25 | torch.manual_seed(0) 26 | torch.cuda.manual_seed_all(0) 27 | torch.backends.cudnn.benchmark = True 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch HA-GAN Training') 30 | parser.add_argument('--batch-size', default=4, type=int, 31 | help='mini-batch size (default: 4), this is the total ' 32 | 'batch size of all GPUs') 33 | parser.add_argument('--workers', default=8, type=int, 34 | help='number of data loading workers (default: 8)') 35 | parser.add_argument('--img-size', default=256, type=int, 36 | help='size of training images (default: 256, can be 128 or 256)') 37 | parser.add_argument('--num-iter', default=80000, type=int, 38 | help='number of iteration for training (default: 80000)') 39 | parser.add_argument('--log-iter', default=20, type=int, 40 | help='number of iteration between logging (default: 20)') 41 | parser.add_argument('--continue-iter', default=0, type=int, 42 | help='continue from a ckeckpoint that has run for n iteration (0 if a new run)') 43 | parser.add_argument('--latent-dim', default=1024, type=int, 44 | help='size of the input latent variable') 45 | parser.add_argument('--g-iter', default=1, type=int, 46 | help='number of generator pass per iteration') 47 | parser.add_argument('--lr-g', default=0.0001, type=float, 48 | help='learning rate for the generator') 49 | parser.add_argument('--lr-d', default=0.0004, type=float, 50 | help='learning rate for the discriminator') 51 | parser.add_argument('--lr-e', default=0.0001, type=float, 52 | help='learning rate for the encoder') 53 | parser.add_argument('--data-dir', type=str, 54 | help='path to the preprocessed data folder') 55 | parser.add_argument('--exp-name', default='HA_GAN_run1', type=str, 56 | help='name of the experiment') 57 | parser.add_argument('--fold', default=0, type=int, 58 | help='fold number for cross validation') 59 | 60 | # configs for conditional generation 61 | parser.add_argument('--lambda-class', default=0.1, type=float, 62 | help='weights for the auxiliary classifier loss') 63 | parser.add_argument('--num-class', default=0, type=int, 64 | help='number of class for auxiliary classifier (0 if unconditional)') 65 | 66 | def main(): 67 | # Configuration 68 | args = parser.parse_args() 69 | 70 | trainset = Volume_Dataset(data_dir=args.data_dir, fold=args.fold, num_class=args.num_class) 71 | train_loader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size,drop_last=True, 72 | shuffle=False,num_workers=args.workers) 73 | gen_load = inf_train_gen(train_loader) 74 | 75 | if args.img_size == 256: 76 | from models.Model_HA_GAN_256 import Discriminator, Generator, Encoder, Sub_Encoder 77 | elif args.img_size == 128: 
 66 | def main():
 67 |     # Configuration
 68 |     args = parser.parse_args()
 69 | 
 70 |     trainset = Volume_Dataset(data_dir=args.data_dir, fold=args.fold, num_class=args.num_class)
 71 |     train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True,
 72 |                                                shuffle=False, num_workers=args.workers)
 73 |     gen_load = inf_train_gen(train_loader)
 74 | 
 75 |     if args.img_size == 256:
 76 |         from models.Model_HA_GAN_256 import Discriminator, Generator, Encoder, Sub_Encoder
 77 |     elif args.img_size == 128:
 78 |         from models.Model_HA_GAN_128 import Discriminator, Generator, Encoder, Sub_Encoder
 79 |     else:
 80 |         raise NotImplementedError
 81 | 
 82 |     G = Generator(mode='train', latent_dim=args.latent_dim, num_class=args.num_class).cuda()
 83 |     D = Discriminator(num_class=args.num_class).cuda()
 84 |     E = Encoder().cuda()
 85 |     Sub_E = Sub_Encoder(latent_dim=args.latent_dim).cuda()
 86 | 
 87 |     g_optimizer = optim.Adam(G.parameters(), lr=args.lr_g, betas=(0.0,0.999), eps=1e-8)
 88 |     d_optimizer = optim.Adam(D.parameters(), lr=args.lr_d, betas=(0.0,0.999), eps=1e-8)
 89 |     e_optimizer = optim.Adam(E.parameters(), lr=args.lr_e, betas=(0.0,0.999), eps=1e-8)
 90 |     sub_e_optimizer = optim.Adam(Sub_E.parameters(), lr=args.lr_e, betas=(0.0,0.999), eps=1e-8)
 91 | 
 92 |     # Resume from a previous checkpoint
 93 |     if args.continue_iter != 0:
 94 |         ckpt_path = './checkpoint/'+args.exp_name+'/G_iter'+str(args.continue_iter)+'.pth'
 95 |         ckpt = torch.load(ckpt_path, map_location='cuda')
 96 |         ckpt['model'] = trim_state_dict_name(ckpt['model'])
 97 |         G.load_state_dict(ckpt['model'])
 98 |         g_optimizer.load_state_dict(ckpt['optimizer'])
 99 |         ckpt_path = './checkpoint/'+args.exp_name+'/D_iter'+str(args.continue_iter)+'.pth'
100 |         ckpt = torch.load(ckpt_path, map_location='cuda')
101 |         ckpt['model'] = trim_state_dict_name(ckpt['model'])
102 |         D.load_state_dict(ckpt['model'])
103 |         d_optimizer.load_state_dict(ckpt['optimizer'])
104 |         ckpt_path = './checkpoint/'+args.exp_name+'/E_iter'+str(args.continue_iter)+'.pth'
105 |         ckpt = torch.load(ckpt_path, map_location='cuda')
106 |         ckpt['model'] = trim_state_dict_name(ckpt['model'])
107 |         E.load_state_dict(ckpt['model'])
108 |         e_optimizer.load_state_dict(ckpt['optimizer'])
109 |         ckpt_path = './checkpoint/'+args.exp_name+'/Sub_E_iter'+str(args.continue_iter)+'.pth'
110 |         ckpt = torch.load(ckpt_path, map_location='cuda')
111 |         ckpt['model'] = trim_state_dict_name(ckpt['model'])
112 |         Sub_E.load_state_dict(ckpt['model'])
113 |         sub_e_optimizer.load_state_dict(ckpt['optimizer'])
114 |         del ckpt
115 |         print("Ckpt", args.exp_name, args.continue_iter, "loaded.")
116 | 
117 |     G = nn.DataParallel(G)
118 |     D = nn.DataParallel(D)
119 |     E = nn.DataParallel(E)
120 |     Sub_E = nn.DataParallel(Sub_E)
121 | 
122 |     G.train()
123 |     D.train()
124 |     E.train()
125 |     Sub_E.train()
126 | 
127 |     # target labels for the adversarial BCE-with-logits loss
128 |     real_labels = torch.ones((args.batch_size, 1)).cuda()
129 |     fake_labels = torch.zeros((args.batch_size, 1)).cuda()
130 | 
131 |     loss_f = nn.BCEWithLogitsLoss()
132 |     loss_mse = nn.L1Loss()  # L1 reconstruction loss, despite the name
133 | 
134 | 
135 | 
136 |     summary_writer = SummaryWriter("./checkpoint/"+args.exp_name)
137 | 
138 |     # save the run configuration to a JSON file
139 |     with open(os.path.join("./checkpoint/"+args.exp_name, 'configs.json'), 'w') as f:
140 |         json.dump(vars(args), f, indent=2)
141 | 
142 |     for p in D.parameters():
143 |         p.requires_grad = False
144 |     for p in G.parameters():
145 |         p.requires_grad = False
146 |     for p in E.parameters():
147 |         p.requires_grad = False
148 |     for p in Sub_E.parameters():
149 |         p.requires_grad = False
150 | 
151 |     for iteration in range(args.continue_iter, args.num_iter):
152 | 
153 |         ###############################################
154 |         # Train Discriminator (D^H and D^L)
155 |         ###############################################
156 |         for p in D.parameters():
157 |             p.requires_grad = True
158 |         for p in Sub_E.parameters():
159 |             p.requires_grad = False
160 | 
161 |         real_images, class_label = next(gen_load)
162 |         D.zero_grad()
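        # Added note: each step trains on a randomly cropped high-res
        # sub-volume (img_size//8 slices, e.g. 32 of 256) plus a 4x
        # downsampled full volume, so the full 256^3 image never has to
        # pass through the discriminator at once.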
163 |         real_images = real_images.float().cuda()
164 |         # low-res full volume of real image
165 |         real_images_small = F.interpolate(real_images, scale_factor=0.25)
166 | 
167 |         # randomly select a high-res sub-volume from real image
168 |         crop_idx = np.random.randint(0, args.img_size*7//8+1)  # 0..224 when img_size is 256
169 |         real_images_crop = real_images[:,:,crop_idx:crop_idx+args.img_size//8,:,:]
170 | 
171 |         if args.num_class == 0:  # unconditional
172 |             y_real_pred = D(real_images_crop, real_images_small, crop_idx)
173 |             d_real_loss = loss_f(y_real_pred, real_labels)
174 | 
175 |             # random generation
176 |             noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
177 |             # fake_images: high-res sub-volume of generated image
178 |             # fake_images_small: low-res full volume of generated image
179 |             fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=None)
180 |             y_fake_pred = D(fake_images, fake_images_small, crop_idx)
181 | 
182 |         else:  # conditional
183 |             class_label = class_label.long().cuda()  # F.one_hot requires a long tensor
184 |             class_label_onehot = F.one_hot(class_label, num_classes=args.num_class)
185 |             class_label_onehot = class_label_onehot.float()
186 | 
187 |             y_real_pred, y_real_class = D(real_images_crop, real_images_small, crop_idx)
188 |             # GAN loss + auxiliary classifier loss
189 |             d_real_loss = loss_f(y_real_pred, real_labels) + \
190 |                           F.cross_entropy(y_real_class, class_label)
191 | 
192 |             # random generation
193 |             noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
194 |             fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=class_label_onehot)
195 |             y_fake_pred, y_fake_class = D(fake_images, fake_images_small, crop_idx)
196 | 
197 |         d_fake_loss = loss_f(y_fake_pred, fake_labels)
198 | 
199 |         d_loss = d_real_loss + d_fake_loss
200 |         d_loss.backward()
201 | 
202 |         d_optimizer.step()
203 | 
204 |         ###############################################
205 |         # Train Generator (G^A, G^H and G^L)
206 |         ###############################################
207 |         for p in D.parameters():
208 |             p.requires_grad = False
209 |         for p in G.parameters():
210 |             p.requires_grad = True
211 | 
212 |         for iters in range(args.g_iter):
213 |             G.zero_grad()
214 | 
215 |             noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
216 |             if args.num_class == 0:  # unconditional
217 |                 fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=None)
218 | 
219 |                 y_fake_g = D(fake_images, fake_images_small, crop_idx)
220 |                 g_loss = loss_f(y_fake_g, real_labels)
221 |             else:  # conditional
222 |                 fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=class_label_onehot)
223 | 
224 |                 y_fake_g, y_fake_g_class = D(fake_images, fake_images_small, crop_idx)
225 |                 g_loss = loss_f(y_fake_g, real_labels) + \
226 |                          args.lambda_class * F.cross_entropy(y_fake_g_class, class_label)
227 | 
228 |             g_loss.backward()
229 |             g_optimizer.step()
230 | 
231 |         ###############################################
232 |         # Train Encoder (E^H)
233 |         ###############################################
234 |         for p in E.parameters():
235 |             p.requires_grad = True
236 |         for p in G.parameters():
237 |             p.requires_grad = False
238 |         E.zero_grad()
239 | 
240 |         z_hat = E(real_images_crop)
241 |         x_hat = G(z_hat, crop_idx=None)
242 | 
243 |         e_loss = loss_mse(x_hat, real_images_crop)
244 |         e_loss.backward()
245 |         e_optimizer.step()
246 | 
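        # Added note: E encodes each of the 8 high-res sub-volumes
        # separately; Sub_E then fuses the concatenated feature maps into a
        # single latent vector, so the reconstruction can be supervised at
        # both scales without ever decoding the full volume in high-res.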
247 |         ###############################################
248 |         # Train Sub Encoder (E^G)
249 |         ###############################################
250 |         for p in Sub_E.parameters():
251 |             p.requires_grad = True
252 |         for p in E.parameters():
253 |             p.requires_grad = False
254 |         Sub_E.zero_grad()
255 | 
256 |         with torch.no_grad():
257 |             z_hat_i_list = []
258 |             # encode all sub-volumes and concatenate
259 |             for crop_idx_i in range(0, args.img_size, args.img_size//8):
260 |                 real_images_crop_i = real_images[:,:,crop_idx_i:crop_idx_i+args.img_size//8,:,:]
261 |                 z_hat_i = E(real_images_crop_i)
262 |                 z_hat_i_list.append(z_hat_i)
263 |             z_hat = torch.cat(z_hat_i_list, dim=2).detach()
264 |         sub_z_hat = Sub_E(z_hat)
265 |         # Reconstruction
266 |         if args.num_class == 0:  # unconditional
267 |             sub_x_hat_rec, sub_x_hat_rec_small = G(sub_z_hat, crop_idx=crop_idx)
268 |         else:  # conditional
269 |             sub_x_hat_rec, sub_x_hat_rec_small = G(sub_z_hat, crop_idx=crop_idx, class_label=class_label_onehot)
270 | 
271 |         sub_e_loss = (loss_mse(sub_x_hat_rec, real_images_crop) + loss_mse(sub_x_hat_rec_small, real_images_small))/2.
272 | 
273 |         sub_e_loss.backward()
274 |         sub_e_optimizer.step()
275 | 
276 |         # Logging
277 |         if iteration%args.log_iter == 0:
278 |             summary_writer.add_scalar('D', d_loss.item(), iteration)
279 |             summary_writer.add_scalar('D_real', d_real_loss.item(), iteration)
280 |             summary_writer.add_scalar('D_fake', d_fake_loss.item(), iteration)
281 |             summary_writer.add_scalar('G_fake', g_loss.item(), iteration)
282 |             summary_writer.add_scalar('E', e_loss.item(), iteration)
283 |             summary_writer.add_scalar('Sub_E', sub_e_loss.item(), iteration)
284 | 
285 |         ###############################################
286 |         # Visualization with Tensorboard
287 |         ###############################################
288 |         if iteration%200 == 0:
289 |             print('[{}/{}]'.format(iteration, args.num_iter),
290 |                   'D_real: {:<8.3}'.format(d_real_loss.item()),
291 |                   'D_fake: {:<8.3}'.format(d_fake_loss.item()),
292 |                   'G_fake: {:<8.3}'.format(g_loss.item()),
293 |                   'Sub_E: {:<8.3}'.format(sub_e_loss.item()),
294 |                   'E: {:<8.3}'.format(e_loss.item()))
295 | 
296 |             featmask = np.squeeze((0.5*real_images_crop[0]+0.5).data.cpu().numpy())
297 |             featmask = nib.Nifti1Image(featmask.transpose((2,1,0)), affine=np.eye(4))
298 |             fig = plt.figure()
299 |             plotting.plot_img(featmask, title="REAL", cut_coords=(args.img_size//2, args.img_size//2, args.img_size//16), figure=fig, draw_cross=False, cmap="gray")
300 |             summary_writer.add_figure('Real', fig, iteration, close=True)
301 | 
302 |             featmask = np.squeeze((0.5*sub_x_hat_rec[0]+0.5).data.cpu().numpy())
303 |             featmask = nib.Nifti1Image(featmask.transpose((2,1,0)), affine=np.eye(4))
304 |             fig = plt.figure()
305 |             plotting.plot_img(featmask, title="REC", cut_coords=(args.img_size//2, args.img_size//2, args.img_size//16), figure=fig, draw_cross=False, cmap="gray")
306 |             summary_writer.add_figure('Rec', fig, iteration, close=True)
307 | 
308 |             featmask = np.squeeze((0.5*fake_images[0]+0.5).data.cpu().numpy())
309 |             featmask = nib.Nifti1Image(featmask.transpose((2,1,0)), affine=np.eye(4))
310 |             fig = plt.figure()
311 |             plotting.plot_img(featmask, title="FAKE", cut_coords=(args.img_size//2, args.img_size//2, args.img_size//16), figure=fig, draw_cross=False, cmap="gray")
312 |             summary_writer.add_figure('Fake', fig, iteration, close=True)
313 | 
314 |         if iteration > 30000 and (iteration+1)%500 == 0:
315 |             torch.save({'model':G.state_dict(), 'optimizer':g_optimizer.state_dict()}, './checkpoint/'+args.exp_name+'/G_iter'+str(iteration+1)+'.pth')
316 |             torch.save({'model':D.state_dict(), 'optimizer':d_optimizer.state_dict()}, './checkpoint/'+args.exp_name+'/D_iter'+str(iteration+1)+'.pth')
317 |             torch.save({'model':E.state_dict(), 'optimizer':e_optimizer.state_dict()}, './checkpoint/'+args.exp_name+'/E_iter'+str(iteration+1)+'.pth')
318 |             torch.save({'model':Sub_E.state_dict(), 'optimizer':sub_e_optimizer.state_dict()}, './checkpoint/'+args.exp_name+'/Sub_E_iter'+str(iteration+1)+'.pth')
319 | 
320 | if __name__ == '__main__':
321 |     main()
322 | 
--------------------------------------------------------------------------------
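For quick inspection of a trained model, a minimal sampling sketch (added here for reference, not part of the repository): it assumes the unconditional 256^3 configuration, and the checkpoint path and iteration number are placeholders. It reuses only constructs shown in train.py above, namely the Generator signature, the checkpoint layout written by torch.save, and trim_state_dict_name (which strips the DataParallel 'module.' prefix).

# sample_sketch.py -- illustrative only
import torch
from models.Model_HA_GAN_256 import Generator
from utils import trim_state_dict_name

# load a generator checkpoint saved by train.py
ckpt = torch.load('./checkpoint/HA_GAN_run1/G_iter80000.pth', map_location='cuda')
G = Generator(mode='train', latent_dim=1024, num_class=0).cuda()
G.load_state_dict(trim_state_dict_name(ckpt['model']))
G.eval()

with torch.no_grad():
    noise = torch.randn((1, 1024)).cuda()
    # crop_idx picks which high-res sub-volume (32 slices of 256) to render
    fake_sub, fake_small = G(noise, crop_idx=0, class_label=None)
print(fake_sub.shape, fake_small.shape)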