├── models
│   ├── __init__.py
│   ├── layers.py
│   ├── Model_HA_GAN_128.py
│   └── Model_HA_GAN_256.py
├── figures
│   ├── main_github.png
│   ├── tensorboard.png
│   └── sample_HA_GAN.png
├── LICENSE
├── volume_dataset.py
├── preprocess.py
├── utils.py
├── README.md
├── environment.yml
├── evaluation
│   ├── resnet3D.py
│   └── fid_score.py
└── train.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/main_github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/main_github.png
--------------------------------------------------------------------------------
/figures/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/tensorboard.png
--------------------------------------------------------------------------------
/figures/sample_HA_GAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/batmanlab/HA-GAN/HEAD/figures/sample_HA_GAN.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 batmanlab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/volume_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from sklearn.model_selection import KFold
3 | import numpy as np
4 | import glob
5 |
6 | class Volume_Dataset(Dataset):
7 |
8 | def __init__(self, data_dir, mode='train', fold=0, num_class=0):
9 | self.sid_list = []
10 | self.data_dir = data_dir
11 | self.num_class = num_class
12 |
13 | for item in glob.glob(self.data_dir+"*.npy"):
14 | self.sid_list.append(item.split('/')[-1])
15 |
16 | self.sid_list.sort()
17 | self.sid_list = np.asarray(self.sid_list)
18 |
19 | kf = KFold(n_splits=5, shuffle=True, random_state=0)
20 | train_index, valid_index = list(kf.split(self.sid_list))[fold]
21 | print("Fold:", fold)
22 | if mode=="train":
23 | self.sid_list = self.sid_list[train_index]
24 | else:
25 | self.sid_list = self.sid_list[valid_index]
26 | print("Dataset size:", len(self))
27 |
28 | self.class_label_dict = dict()
29 | if self.num_class > 0: # conditional
30 | FILE = open("class_label.csv", "r")
31 | FILE.readline() # header
32 | for myline in FILE.readlines():
33 | mylist = myline.strip("\n").split(",")
34 | self.class_label_dict[mylist[0]] = int(mylist[1])
35 | FILE.close()
36 |
37 | def __len__(self):
38 | return len(self.sid_list)
39 |
40 | def __getitem__(self, idx):
41 | img = np.load(self.data_dir+self.sid_list[idx])
42 | class_label = self.class_label_dict.get(self.sid_list[idx], -1) # -1 if no class label
43 | return img[None,:,:,:], class_label
44 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | # resize and rescale images for preprocessing
2 |
3 | import glob
4 | import SimpleITK as sitk
5 | import numpy as np
6 | from skimage.transform import resize
7 | import os
8 | import multiprocessing as mp
9 |
10 | ### Configs
11 | # number of parallel worker processes used for preprocessing
12 | NUM_JOBS = 8
13 | # resized output size, can be 128 or 256
14 | IMG_SIZE = 256
15 | INPUT_DATA_DIR = '/path_to_imgs/'
16 | OUTPUT_DATA_DIR = '/output_folder/'
17 | # the intensity range is clipped with the two thresholds, this default is used for our CT images, please adapt to your own dataset
18 | LOW_THRESHOLD = -1024
19 | HIGH_THRESHOLD = 600
20 | # suffix (ext.) of input images
21 | SUFFIX = '.nii.gz'
22 | # whether to trim blank axial slices; recommended to set to True
23 | TRIM_BLANK_SLICES = True
24 |
25 | def resize_img(img):
26 | nan_mask = np.isnan(img) # Remove NaN
27 | img[nan_mask] = LOW_THRESHOLD
28 | img = np.interp(img, [LOW_THRESHOLD, HIGH_THRESHOLD], [-1,1])
29 |
30 | if TRIM_BLANK_SLICES:
31 | valid_plane_i = np.mean(img, (1,2)) != -1 # Remove blank axial planes
32 | img = img[valid_plane_i,:,:]
33 |
34 | img = resize(img, (IMG_SIZE, IMG_SIZE, IMG_SIZE), mode='constant', cval=-1)
35 | return img
36 |
37 | def main():
38 | img_list = list(glob.glob(INPUT_DATA_DIR+"*"+SUFFIX))
39 |
40 | processes = []
41 | for i in range(NUM_JOBS):
42 | processes.append(mp.Process(target=batch_resize, args=(i, img_list)))
43 | for p in processes:
44 | p.start()
45 |
46 | def batch_resize(batch_idx, img_list):
47 | for idx in range(len(img_list)):
48 | if idx % NUM_JOBS != batch_idx:
49 | continue
50 | imgname = img_list[idx].split('/')[-1]
51 | if os.path.exists(OUTPUT_DATA_DIR+imgname.split('.')[0]+".npy"):
52 | # skip images that already finished pre-processing
53 | continue
54 | try:
55 |             img = sitk.ReadImage(img_list[idx])  # img_list entries are already full paths from glob
56 | except Exception as e:
57 | # skip corrupted images
58 | print(e)
59 | print("Image loading error:", imgname)
60 | continue
61 | img = sitk.GetArrayFromImage(img)
62 | try:
63 | img = resize_img(img)
64 | except Exception as e: # Some images are corrupted
65 | print(e)
66 | print("Image resize error:", imgname)
67 | continue
68 | # preprocessed images are saved in numpy arrays
69 | np.save(OUTPUT_DATA_DIR+imgname.split('.')[0]+".npy", img)
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | from skimage.transform import resize
6 |
7 | import torch
8 |
9 | def post_process_brain(x_pred):
10 | x_pred = resize(x_pred, (256-90,256-40,256-40), mode='constant', cval=0.)
11 | x_canvas = np.zeros((256,256,256))
12 | x_canvas[50:-40,20:-20,20:-20] = x_pred
13 | x_canvas = np.flip(x_canvas,0)
14 | return x_canvas
15 |
16 | def _itensity_normalize(volume):
17 | pixels = volume[volume > 0]
18 | mean = pixels.mean()
19 | std = pixels.std()
20 | out = (volume - mean)/std
21 | return out
22 |
23 | class Flatten(torch.nn.Module):
24 | def forward(self, inp):
25 | return inp.view(inp.size(0), -1)
26 |
27 | def calculate_nmse(img1, img2):
28 | img1 = img1.astype(np.float64)
29 | img2 = img2.astype(np.float64)
30 | mse = np.mean((img1 - img2)**2)
31 | mse0 = np.mean(img1**2)
32 | if mse == 0:
33 | return float('inf')
34 | return mse / mse0 * 100.
35 |
36 | def calculate_psnr(img1, img2):
37 | # img1 and img2 have range [0, 1]
38 | img1 = img1.astype(np.float64)
39 | img2 = img2.astype(np.float64)
40 | mse = np.mean((img1 - img2)**2)
41 | if mse == 0:
42 | return float('inf')
43 | return 20 * math.log10(1.0 / math.sqrt(mse))
44 |
45 | class KLN01Loss(torch.nn.Module):
46 |
47 | def __init__(self, direction, minimize):
48 | super(KLN01Loss, self).__init__()
49 | self.minimize = minimize
50 | assert direction in ['pq', 'qp'], 'direction?'
51 |
52 | self.direction = direction
53 |
54 | def forward(self, samples):
55 |
56 | assert samples.nelement() == samples.size(1) * samples.size(0), 'wtf?'
57 |
58 | samples = samples.view(samples.size(0), -1)
59 |
60 |         self.samples_var = samples.var(0, unbiased=False)  # biased batch variance (the original `var` helper is not defined here)
61 | self.samples_mean = samples.mean(0)
62 |
63 | samples_mean = self.samples_mean
64 | samples_var = self.samples_var
65 |
66 | if self.direction == 'pq':
67 | # mu_1 = 0; sigma_1 = 1
68 |
69 | t1 = (1 + samples_mean.pow(2)) / (2 * samples_var.pow(2))
70 | t2 = samples_var.log()
71 |
72 | KL = (t1 + t2 - 0.5).mean()
73 | else:
74 | # mu_2 = 0; sigma_2 = 1
75 |
76 | t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2
77 | t2 = -samples_var.log()
78 |
79 | KL = (t1 + t2 - 0.5).mean()
80 |
81 | if not self.minimize:
82 | KL *= -1
83 |
84 | return KL
85 |
86 | def trim_state_dict_name(state_dict):
87 | for k in list(state_dict.keys()):
88 | if k.startswith('module.'):
89 | # remove prefix
90 | state_dict[k[len("module."):]] = state_dict[k]
91 | del state_dict[k]
92 | return state_dict
93 |
94 | def inf_train_gen(data_loader):
95 | while True:
96 | for _,batch in enumerate(data_loader):
97 | yield batch
98 |
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.nn import Parameter as P
8 |
9 | # Projection of x onto y
10 | def proj(x, y):
11 | return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
12 |
13 |
14 | # Orthogonalize x wrt list of vectors ys
15 | def gram_schmidt(x, ys):
16 | for y in ys:
17 | x = x - proj(x, y)
18 | return x
19 |
20 | # Apply num_itrs steps of the power method to estimate top N singular values.
21 | def power_iteration(W, u_, update=True, eps=1e-12):
22 | # Lists holding singular vectors and values
23 | us, vs, svs = [], [], []
24 | for i, u in enumerate(u_):
25 | # Run one step of the power iteration
26 | with torch.no_grad():
27 | v = torch.matmul(u, W)
28 | # Run Gram-Schmidt to subtract components of all other singular vectors
29 | v = F.normalize(gram_schmidt(v, vs), eps=eps)
30 | # Add to the list
31 | vs += [v]
32 | # Update the other singular vector
33 | u = torch.matmul(v, W.t())
34 | # Run Gram-Schmidt to subtract components of all other singular vectors
35 | u = F.normalize(gram_schmidt(u, us), eps=eps)
36 | # Add to the list
37 | us += [u]
38 | if update:
39 | u_[i][:] = u
40 | # Compute this singular value and add it to the list
41 | svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
42 | #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
43 | return svs, us, vs
44 |
45 | # Convenience passthrough function
46 | class identity(nn.Module):
47 | def forward(self, input):
48 | return input
49 |
50 | # Spectral normalization base class
51 | class SN(object):
52 | def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
53 | # Number of power iterations per step
54 | self.num_itrs = num_itrs
55 | # Number of singular values
56 | self.num_svs = num_svs
57 | # Transposed?
58 | self.transpose = transpose
59 | # Epsilon value for avoiding divide-by-0
60 | self.eps = eps
61 | # Register a singular vector for each sv
62 | for i in range(self.num_svs):
63 | self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
64 | self.register_buffer('sv%d' % i, torch.ones(1))
65 |
66 | # Singular vectors (u side)
67 | @property
68 | def u(self):
69 | return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
70 |
71 | # Singular values;
72 | # note that these buffers are just for logging and are not used in training.
73 | @property
74 | def sv(self):
75 | return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
76 |
77 | # Compute the spectrally-normalized weight
78 | def W_(self):
79 | W_mat = self.weight.view(self.weight.size(0), -1)
80 | if self.transpose:
81 | W_mat = W_mat.t()
82 | # Apply num_itrs power iterations
83 | for _ in range(self.num_itrs):
84 | svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
85 | # Update the svs
86 | if self.training:
87 | with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
88 | for i, sv in enumerate(svs):
89 | self.sv[i][:] = sv
90 | return self.weight / svs[0]
91 |
92 | # 3D Conv layer with spectral norm
93 | class SNConv3d(nn.Conv3d, SN):
94 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
95 | padding=0, dilation=1, groups=1, bias=True,
96 | num_svs=1, num_itrs=1, eps=1e-12):
97 | nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, stride,
98 | padding, dilation, groups, bias)
99 | SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
100 | def forward(self, x):
101 | return F.conv3d(x, self.W_(), self.bias, self.stride,
102 | self.padding, self.dilation, self.groups)
103 |
104 |
105 | # Linear layer with spectral norm
106 | class SNLinear(nn.Linear, SN):
107 | def __init__(self, in_features, out_features, bias=True,
108 | num_svs=1, num_itrs=1, eps=1e-12):
109 | nn.Linear.__init__(self, in_features, out_features, bias)
110 | SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
111 | def forward(self, x):
112 | return F.linear(x, self.W_(), self.bias)
113 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Amortized GAN (HA-GAN)
2 |
3 | Official PyTorch implementation of the paper *Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis*, accepted to *IEEE Journal of Biomedical and Health Informatics*.
4 |
5 |
6 | ![HA-GAN overview](figures/main_github.png)
7 |
8 |
9 | #### [[Paper & Supplementary Material]](https://ieeexplore.ieee.org/abstract/document/9770375)
10 |
11 | Generative Adversarial Networks (GAN) have many potential medical imaging applications. Due to the limited memory of Graphical Processing Units (GPUs), most current 3D GAN models are trained on low-resolution medical images. In this work, we propose a novel end-to-end GAN architecture that can generate high-resolution 3D images. We achieve this goal by using different configurations between training and inference. During training, we adopt a hierarchical structure that simultaneously generates a low-resolution version of the image and a randomly selected sub-volume of the high-resolution image. The hierarchical design has two advantages: First, the memory demand for training on high-resolution images is amortized among sub-volumes. Furthermore, anchoring the high-resolution sub-volumes to a single low-resolution image ensures anatomical consistency between sub-volumes. During inference, our model can directly generate full high-resolution images. We also incorporate an encoder (hidden in the figure to improve clarity) into the model to extract features from the images.
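
As a concrete illustration of the two configurations described above, the sketch below (an assumed usage example, not the exact loop in `train.py`) calls the 256³ generator from `models/Model_HA_GAN_256.py` in training mode, which returns a high-resolution sub-volume plus the low-resolution image, and in eval mode, which emits the full volume; the batch size and `crop_idx` value are illustrative.

```python
# Minimal usage sketch (assumed, for illustration only; not the exact training loop).
import torch
from models.Model_HA_GAN_256 import Generator

latent_dim = 1024
noise = torch.randn(1, latent_dim)

# Training configuration: a random high-res sub-volume anchored to a low-res image.
G_train = Generator(mode='train', latent_dim=latent_dim, num_class=0)
crop_idx = 64  # illustrative axial start index in [0, 224]; chosen at random during training
with torch.no_grad():  # no_grad only to keep this illustration lightweight
    x_sub, x_low = G_train(noise, crop_idx=crop_idx)
print(x_sub.shape)  # torch.Size([1, 1, 32, 256, 256]) -- high-res sub-volume
print(x_low.shape)  # torch.Size([1, 1, 64, 64, 64])   -- low-res full image

# Inference configuration: the same weights directly generate the full volume
# (memory-heavy; in practice this runs on GPU, see evaluation/fid_score.py).
G_eval = Generator(mode='eval', latent_dim=latent_dim, num_class=0)
with torch.no_grad():
    x_full = G_eval(noise)
print(x_full.shape)  # torch.Size([1, 1, 256, 256, 256])
```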
12 |
13 | ### Requirements
14 | - PyTorch
15 | - scikit-image
16 | - nibabel
17 | - nilearn
18 | - tensorboardX
19 | - SimpleITK
20 |
21 | ```bash
22 | conda env create --name hagan -f environment.yml
23 | conda activate hagan
24 | ```
25 |
26 | ### Data Preprocessing
27 | The volume data need to be cropped or resized to 128³ or 256³, and the intensity values need to be scaled to [-1,1]. In addition, we advise trimming blank axial slices. More details can be found in `preprocess.py`:
28 | ```bash
29 | python preprocess.py
30 | ```
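
After preprocessing, each volume is saved as a single `.npy` array of size `IMG_SIZE`³ with intensities rescaled to [-1, 1]. A quick sanity check along the lines of the sketch below (the filename is a placeholder) confirms the output matches what `volume_dataset.py` expects:

```python
# Hedged sanity check for one preprocessed volume (the filename is a placeholder).
import numpy as np

vol = np.load('/output_folder/subject001.npy')
print(vol.shape)             # expected (256, 256, 256) when IMG_SIZE = 256
print(vol.min(), vol.max())  # expected to lie within [-1, 1]
```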
31 |
32 | ### Training
33 | #### Unconditional HA-GAN
34 | ```bash
35 | python train.py --workers 8 --img-size 256 --num-class 0 --exp-name 'HA_GAN_run1' --data-dir DATA_DIR
36 | ```
37 | #### Conditional HA-GAN
38 | ```bash
39 | python train.py --workers 8 --img-size 256 --num-class N --exp-name 'HA_GAN_cond_run1' --data-dir DATA_DIR
40 | ```
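
Conditional training additionally requires a `class_label.csv` file in the working directory, read by `volume_dataset.py`: a header row followed by one `<npy filename>,<integer label>` row per volume, with labels typically in `[0, N-1]`. The filenames below are placeholders:

```
filename,label
subject001.npy,0
subject002.npy,1
```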
41 |
42 | Track your training with Tensorboard:
43 |
44 | ![Tensorboard training curves](figures/tensorboard.png)
45 |
46 |
47 | It takes around 22 hours to train unconditional HA-GAN for 80000 iterations with two NVIDIA Tesla V100 GPUs. We suggest having at least 3000 training images to avoid mode collapse; otherwise, consider [data augmentation](https://docs.monai.io/en/stable/transforms.html#intensity-dict).
48 |
49 | ### Testing
50 | ```bash
51 | visualization.ipynb
52 | evaluation/visualize_feature_MDS.ipynb
53 | python evaluation/fid_score.py
54 | ```
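
As an illustrative example (all values are placeholders), an FID run with the flags defined in `evaluation/fid_score.py` could look like the command below; the script expects generator checkpoints under `./checkpoint/<basename><fold>/G_iter<ckpt_step>.pth` and precomputed real-data statistics under `./results/fid/`:

```bash
python evaluation/fid_score.py --fold 0 --ckpt_step 80000 --batch_size 2 --num_samples 2048
```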
55 |
56 | ### Sample images
57 |
58 | ![Sample images generated by HA-GAN](figures/sample_HA_GAN.png)
59 |
60 |
61 | ### Pretrained weights
62 |
63 | | Dataset  | Anatomy | Iteration | Checkpoint |
64 | |----------|---------|-----------|------------|
65 | | COPDGene | Lung    | 80000     | Download   |
66 | | GSP      | Brain   | 80000     | Download   |
82 |
83 | ### Citation
84 | ```
85 | @ARTICLE{hagan2022,
86 | author={Sun, Li and Chen, Junxiang and Xu, Yanwu and Gong, Mingming and Yu, Ke and Batmanghelich, Kayhan},
87 | journal={IEEE Journal of Biomedical and Health Informatics},
88 | title={Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis},
89 | year={2022},
90 | volume={26},
91 | number={8},
92 | pages={3966-3975},
93 | doi={10.1109/JBHI.2022.3172976}}
94 | ```
95 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: 3dgan
2 | channels:
3 | - simpleitk
4 | - pytorch-lts
5 | - nvidia
6 | - anaconda
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=main
11 | - _openmp_mutex=5.1=1_gnu
12 | - blas=1.0=mkl
13 | - brotli=1.0.9=he6710b0_2
14 | - brotlipy=0.7.0=py38h0a891b7_1004
15 | - bzip2=1.0.8=h7b6447c_0
16 | - ca-certificates=2022.4.26=h06a4308_0
17 | - cached-property=1.5.2=hd8ed1ab_1
18 | - cached_property=1.5.2=pyha770c72_1
19 | - certifi=2022.5.18.1=py38h06a4308_0
20 | - cffi=1.15.0=py38hd667e15_1
21 | - charset-normalizer=2.0.12=pyhd8ed1ab_0
22 | - cloudpickle=2.0.0=pyhd3eb1b0_0
23 | - cryptography=37.0.2=py38h2b5fc30_0
24 | - cudatoolkit=11.1.74=h6bb024c_0
25 | - cycler=0.11.0=pyhd3eb1b0_0
26 | - cytoolz=0.11.0=py38h7b6447c_0
27 | - dask-core=2022.2.1=pyhd3eb1b0_0
28 | - dbus=1.13.18=hb2f20db_0
29 | - expat=2.4.4=h295c915_0
30 | - ffmpeg=4.2.2=h20bf706_0
31 | - fontconfig=2.13.1=h6c09931_0
32 | - fonttools=4.25.0=pyhd3eb1b0_0
33 | - freetype=2.11.0=h70c0345_0
34 | - fsspec=2022.2.0=pyhd3eb1b0_0
35 | - giflib=5.2.1=h7b6447c_0
36 | - glib=2.69.1=h4ff587b_1
37 | - gmp=6.2.1=h2531618_2
38 | - gnutls=3.6.15=he1e5248_0
39 | - gst-plugins-base=1.14.0=h8213a91_2
40 | - gstreamer=1.14.0=h28cd5cc_2
41 | - h5py=3.2.1=nompi_py38h9915d05_100
42 | - hdf5=1.10.6=nompi_h3c11f04_101
43 | - icu=58.2=he6710b0_3
44 | - idna=3.3=pyhd8ed1ab_0
45 | - imageio=2.9.0=pyhd3eb1b0_0
46 | - intel-openmp=2021.4.0=h06a4308_3561
47 | - joblib=1.1.0=pyhd8ed1ab_0
48 | - jpeg=9b=h024ee3a_2
49 | - kiwisolver=1.4.2=py38h295c915_0
50 | - lame=3.100=h7b6447c_0
51 | - lcms2=2.12=h3be6417_0
52 | - ld_impl_linux-64=2.38=h1181459_1
53 | - libffi=3.3=he6710b0_2
54 | - libgcc-ng=11.2.0=h1234567_0
55 | - libgfortran-ng=7.5.0=h14aa051_20
56 | - libgfortran4=7.5.0=h14aa051_20
57 | - libgomp=11.2.0=h1234567_0
58 | - libiconv=1.16=h516909a_0
59 | - libidn2=2.3.2=h7f8727e_0
60 | - libopus=1.3.1=h7b6447c_0
61 | - libpng=1.6.37=hbc83047_0
62 | - libstdcxx-ng=11.2.0=h1234567_0
63 | - libtasn1=4.16.0=h27cfd23_0
64 | - libtiff=4.1.0=h2733197_1
65 | - libunistring=0.9.10=h27cfd23_0
66 | - libuuid=1.0.3=h7f8727e_2
67 | - libuv=1.40.0=h7b6447c_0
68 | - libvpx=1.7.0=h439df22_0
69 | - libwebp=1.2.0=h89dd481_0
70 | - libxcb=1.15=h7f8727e_0
71 | - libxml2=2.9.12=h74e7548_2
72 | - libxslt=1.1.34=hc22bd24_0
73 | - locket=0.2.1=py38h06a4308_2
74 | - lxml=4.8.0=py38h1f438cf_0
75 | - lz4-c=1.9.3=h295c915_1
76 | - matplotlib=3.5.1=py38h06a4308_1
77 | - matplotlib-base=3.5.1=py38ha18d171_1
78 | - mkl=2021.4.0=h06a4308_640
79 | - mkl-service=2.4.0=py38h7f8727e_0
80 | - mkl_fft=1.3.1=py38hd3c417c_0
81 | - mkl_random=1.2.2=py38h51133e4_0
82 | - munkres=1.1.4=py_0
83 | - ncurses=6.3=h7f8727e_2
84 | - nettle=3.7.3=hbbd107a_1
85 | - networkx=2.7.1=pyhd3eb1b0_0
86 | - nibabel=3.2.2=pyhd8ed1ab_0
87 | - nilearn=0.9.1=pyhd8ed1ab_0
88 | - ninja=1.10.2=h06a4308_5
89 | - ninja-base=1.10.2=hd09550d_5
90 | - numpy=1.22.3=py38he7a7128_0
91 | - numpy-base=1.22.3=py38hf524024_0
92 | - openh264=2.1.1=h4ff587b_0
93 | - openssl=1.1.1o=h7f8727e_0
94 | - packaging=21.3=pyhd8ed1ab_0
95 | - pandas=1.4.2=py38h47df419_1
96 | - partd=1.2.0=pyhd3eb1b0_1
97 | - pcre=8.45=h295c915_0
98 | - pillow=9.0.1=py38h22f2fdc_0
99 | - pip=21.2.4=py38h06a4308_0
100 | - pycparser=2.21=pyhd8ed1ab_0
101 | - pydicom=2.3.0=pyh6c4a22f_0
102 | - pyopenssl=22.0.0=pyhd8ed1ab_0
103 | - pyparsing=3.0.9=pyhd8ed1ab_0
104 | - pyqt=5.9.2=py38h05f1152_4
105 | - pysocks=1.7.1=py38h578d9bd_5
106 | - python=3.8.13=h12debd9_0
107 | - python-dateutil=2.8.2=pyhd8ed1ab_0
108 | - python_abi=3.8=2_cp38
109 | - pytorch=1.8.2=py3.8_cuda11.1_cudnn8.0.5_0
110 | - pytz=2022.1=pyhd8ed1ab_0
111 | - pywavelets=1.3.0=py38h7f8727e_0
112 | - pyyaml=6.0=py38h7f8727e_1
113 | - qt=5.9.7=h5867ecd_1
114 | - readline=8.1.2=h7f8727e_1
115 | - requests=2.27.1=pyhd8ed1ab_0
116 | - scikit-image=0.19.2=py38h51133e4_0
117 | - scikit-learn=1.0.2=py38h51133e4_1
118 | - scipy=1.7.3=py38hc147768_0
119 | - setuptools=61.2.0=py38h06a4308_0
120 | - simpleitk=2.0.2=py38h3fd9d12_0
121 | - sip=4.19.13=py38h295c915_0
122 | - six=1.16.0=pyhd3eb1b0_1
123 | - sqlite=3.38.3=hc218d9a_0
124 | - threadpoolctl=3.1.0=pyh8a188c0_0
125 | - tifffile=2020.10.1=py38hdd07704_2
126 | - tk=8.6.11=h1ccaba5_1
127 | - toolz=0.11.2=pyhd3eb1b0_0
128 | - torchaudio=0.8.2=py38
129 | - torchvision=0.9.2=py38_cu111
130 | - tornado=6.1=py38h27cfd23_0
131 | - typing_extensions=4.1.1=pyh06a4308_0
132 | - urllib3=1.26.9=pyhd8ed1ab_0
133 | - wheel=0.37.1=pyhd3eb1b0_0
134 | - x264=1!157.20191217=h7b6447c_0
135 | - xz=5.2.5=h7f8727e_1
136 | - yaml=0.2.5=h7b6447c_0
137 | - zlib=1.2.12=h7f8727e_2
138 | - zstd=1.4.9=haebb681_0
139 | - pip:
140 | - protobuf==3.20.1
141 | - tensorboardx==2.5
142 | prefix: /ocean/projects/asc170022p/lisun/miniconda3/envs/3dgan
143 |
--------------------------------------------------------------------------------
/evaluation/resnet3D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | from functools import partial
7 |
8 | __all__ = [
9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnet200'
11 | ]
12 |
13 |
14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
15 | # 3x3x3 convolution with padding
16 | return nn.Conv3d(
17 | in_planes,
18 | out_planes,
19 | kernel_size=3,
20 | dilation=dilation,
21 | stride=stride,
22 | padding=dilation,
23 | bias=False)
24 |
25 |
26 | def downsample_basic_block(x, planes, stride, no_cuda=False):
27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
28 | zero_pads = torch.Tensor(
29 | out.size(0), planes - out.size(1), out.size(2), out.size(3),
30 | out.size(4)).zero_()
31 | if not no_cuda:
32 | if isinstance(out.data, torch.cuda.FloatTensor):
33 | zero_pads = zero_pads.cuda()
34 |
35 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
36 |
37 | return out
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
44 | super(BasicBlock, self).__init__()
45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
46 | self.bn1 = nn.BatchNorm3d(planes)
47 | self.relu = nn.ReLU(inplace=True)
48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
49 | self.bn2 = nn.BatchNorm3d(planes)
50 | self.downsample = downsample
51 | self.stride = stride
52 | self.dilation = dilation
53 |
54 | def forward(self, x):
55 | residual = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 | out = self.conv2(out)
61 | out = self.bn2(out)
62 |
63 | if self.downsample is not None:
64 | residual = self.downsample(x)
65 |
66 | out += residual
67 | out = self.relu(out)
68 |
69 | return out
70 |
71 |
72 | class Bottleneck(nn.Module):
73 | expansion = 4
74 |
75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
76 | super(Bottleneck, self).__init__()
77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
78 | self.bn1 = nn.BatchNorm3d(planes)
79 | self.conv2 = nn.Conv3d(
80 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
81 | self.bn2 = nn.BatchNorm3d(planes)
82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
83 | self.bn3 = nn.BatchNorm3d(planes * 4)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.stride = stride
87 | self.dilation = dilation
88 |
89 | def forward(self, x):
90 | residual = x
91 |
92 | out = self.conv1(x)
93 | out = self.bn1(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv2(out)
97 | out = self.bn2(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv3(out)
101 | out = self.bn3(out)
102 |
103 | if self.downsample is not None:
104 | residual = self.downsample(x)
105 |
106 | out += residual
107 | out = self.relu(out)
108 |
109 | return out
110 |
111 | class upsample(nn.Module):
112 | def forward(self, inp):
113 | return F.interpolate(inp, scale_factor = 2)
114 |
115 | class ResNet(nn.Module):
116 |
117 | def __init__(self,
118 | block,
119 | layers,
120 | num_seg_classes=2,
121 | shortcut_type='B',
122 | no_cuda = False):
123 | self.inplanes = 64
124 | self.no_cuda = no_cuda
125 | super(ResNet, self).__init__()
126 | self.conv1 = nn.Conv3d(
127 | 1,
128 | 64,
129 | kernel_size=7,
130 | stride=(2, 2, 2),
131 | padding=(3, 3, 3),
132 | bias=False)
133 |
134 | self.bn1 = nn.BatchNorm3d(64)
135 | self.relu = nn.ReLU(inplace=True)
136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
138 | self.layer2 = self._make_layer(
139 | block, 128, layers[1], shortcut_type, stride=2)
140 | self.layer3 = self._make_layer(
141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2)
142 | self.layer4 = self._make_layer(
143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4)
144 |
145 | self.conv_seg = nn.Sequential(
146 | upsample(),
147 | nn.ConvTranspose3d(
148 | 512 * block.expansion,
149 | 32,
150 | 2,
151 | stride=2
152 | ),
153 | nn.BatchNorm3d(32),
154 | nn.ReLU(inplace=True),
155 | upsample(),
156 | nn.Conv3d(
157 | 32,
158 | 32,
159 | kernel_size=3,
160 | stride=(1, 1, 1),
161 | padding=(1, 1, 1),
162 | bias=False),
163 | nn.BatchNorm3d(32),
164 | nn.ReLU(inplace=True),
165 | nn.Conv3d(
166 | 32,
167 | num_seg_classes,
168 | kernel_size=1,
169 | stride=(1, 1, 1),
170 | bias=False),
171 | nn.Sigmoid()
172 | )
173 |
174 | for m in self.modules():
175 | if isinstance(m, nn.Conv3d):
176 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
177 | elif isinstance(m, nn.BatchNorm3d):
178 | m.weight.data.fill_(1)
179 | m.bias.data.zero_()
180 |
181 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
182 | downsample = None
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | if shortcut_type == 'A':
185 | downsample = partial(
186 | downsample_basic_block,
187 | planes=planes * block.expansion,
188 | stride=stride,
189 | no_cuda=self.no_cuda)
190 | else:
191 | downsample = nn.Sequential(
192 | nn.Conv3d(
193 | self.inplanes,
194 | planes * block.expansion,
195 | kernel_size=1,
196 | stride=stride,
197 | bias=False), nn.BatchNorm3d(planes * block.expansion))
198 |
199 | layers = []
200 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
201 | self.inplanes = planes * block.expansion
202 | for i in range(1, blocks):
203 | layers.append(block(self.inplanes, planes, dilation=dilation))
204 |
205 | return nn.Sequential(*layers)
206 |
207 | def forward(self, x):
208 | x = self.conv1(x)
209 | x = self.bn1(x)
210 | x = self.relu(x)
211 | x = self.maxpool(x)
212 | x = self.layer1(x)
213 | x = self.layer2(x)
214 | x = self.layer3(x)
215 | x = self.layer4(x)
216 | #print(x.shape)
217 | x = self.conv_seg(x)
218 |
219 | return x
220 |
221 | def resnet10(**kwargs):
222 |     """Constructs a ResNet-10 model.
223 | """
224 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
225 | return model
226 |
227 |
228 | def resnet18(**kwargs):
229 | """Constructs a ResNet-18 model.
230 | """
231 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
232 | return model
233 |
234 |
235 | def resnet34(**kwargs):
236 | """Constructs a ResNet-34 model.
237 | """
238 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
239 | return model
240 |
241 |
242 | def resnet50(**kwargs):
243 | """Constructs a ResNet-50 model.
244 | """
245 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
246 | return model
247 |
248 |
249 | def resnet101(**kwargs):
250 | """Constructs a ResNet-101 model.
251 | """
252 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
253 | return model
254 |
255 |
256 | def resnet152(**kwargs):
257 |     """Constructs a ResNet-152 model.
258 | """
259 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
260 | return model
261 |
262 |
263 | def resnet200(**kwargs):
264 |     """Constructs a ResNet-200 model.
265 | """
266 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
267 | return model
268 |
--------------------------------------------------------------------------------
/models/Model_HA_GAN_128.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from torch import nn
5 | from torch import optim
6 | from torch.nn import functional as F
7 | from models.layers import SNConv3d, SNLinear
8 |
9 | class Code_Discriminator(nn.Module):
10 | def __init__(self, code_size, num_units=256):
11 | super(Code_Discriminator, self).__init__()
12 |
13 | self.l1 = nn.Sequential(SNLinear(code_size, num_units),
14 | nn.LeakyReLU(0.2,inplace=True))
15 | self.l2 = nn.Sequential(SNLinear(num_units, num_units),
16 | nn.LeakyReLU(0.2,inplace=True))
17 | self.l3 = SNLinear(num_units, 1)
18 |
19 | def forward(self, x):
20 | x = self.l1(x)
21 | x = self.l2(x)
22 | x = self.l3(x)
23 |
24 | return x
25 |
26 | class Sub_Encoder(nn.Module):
27 | def __init__(self, channel=256, latent_dim=1024):
28 | super(Sub_Encoder, self).__init__()
29 |
30 | self.relu = nn.ReLU()
31 | self.conv2 = nn.Conv3d(channel//4, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16]
32 | self.bn2 = nn.GroupNorm(8, channel//4)
33 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8]
34 | self.bn3 = nn.GroupNorm(8, channel//2)
35 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4]
36 | self.bn4 = nn.GroupNorm(8, channel)
37 | self.conv5 = nn.Conv3d(channel, latent_dim, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1]
38 |
39 | def forward(self, h):
40 | h = self.conv2(h)
41 | h = self.relu(self.bn2(h))
42 | h = self.conv3(h)
43 | h = self.relu(self.bn3(h))
44 | h = self.conv4(h)
45 | h = self.relu(self.bn4(h))
46 | h = self.conv5(h).squeeze()
47 | return h
48 |
49 | class Encoder(nn.Module):
50 | def __init__(self, channel=64):
51 | super(Encoder, self).__init__()
52 |
53 | self.relu = nn.ReLU()
54 | self.conv1 = nn.Conv3d(1, channel//2, kernel_size=4, stride=2, padding=1) # in:[16,128,128], out:[8,64,64]
55 | self.bn1 = nn.GroupNorm(8, channel//2)
56 | self.conv2 = nn.Conv3d(channel//2, channel//2, kernel_size=3, stride=1, padding=1) # out:[8,64,64]
57 | self.bn2 = nn.GroupNorm(8, channel//2)
58 | self.conv3 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,32,32]
59 | self.bn3 = nn.GroupNorm(8, channel)
60 |
61 | def forward(self, h):
62 | h = self.conv1(h)
63 | h = self.relu(self.bn1(h))
64 |
65 | h = self.conv2(h)
66 | h = self.relu(self.bn2(h))
67 |
68 | h = self.conv3(h)
69 | h = self.relu(self.bn3(h))
70 | return h
71 |
72 | class Sub_Discriminator(nn.Module):
73 | def __init__(self, num_class=0, channel=256):
74 | super(Sub_Discriminator, self).__init__()
75 | self.channel = channel
76 | self.num_class = num_class
77 |
78 | self.conv2 = SNConv3d(1, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16]
79 | self.conv3 = SNConv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8]
80 | self.conv4 = SNConv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4]
81 | self.conv5 = SNConv3d(channel, 1+num_class, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1]
82 |
83 | def forward(self, h):
84 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2)
85 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2)
86 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2)
87 | if self.num_class == 0:
88 | h = self.conv5(h).view((-1,1))
89 | return h
90 | else:
91 | h = self.conv5(h).view((-1,1+self.num_class))
92 | return h[:,:1], h[:,1:]
93 |
94 | class Discriminator(nn.Module):
95 | def __init__(self, num_class=0, channel=512):
96 | super(Discriminator, self).__init__()
97 | self.channel = channel
98 | self.num_class = num_class
99 |
100 | # D^H
101 | self.conv2 = SNConv3d(1, channel//16, kernel_size=4, stride=2, padding=1) # out:[8,64,64,64]
102 | self.conv3 = SNConv3d(channel//16, channel//8, kernel_size=4, stride=2, padding=1) # out:[4,32,32,32]
103 | self.conv4 = SNConv3d(channel//8, channel//4, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[2,16,16,16]
104 | self.conv5 = SNConv3d(channel//4, channel//2, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[1,8,8,8]
105 | self.conv6 = SNConv3d(channel//2, channel, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1)) # out:[1,4,4,4]
106 | self.conv7 = SNConv3d(channel, channel//4, kernel_size=(1,4,4), stride=1, padding=0) # out:[1,1,1,1]
107 | self.fc1 = SNLinear(channel//4+1, channel//8)
108 | self.fc2 = SNLinear(channel//8, 1)
109 | if num_class>0:
110 | self.fc2_class = SNLinear(channel//8, num_class)
111 |
112 | # D^L
113 | self.sub_D = Sub_Discriminator(num_class)
114 |
115 | def forward(self, h, h_small, crop_idx):
116 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2)
117 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2)
118 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2)
119 | h = F.leaky_relu(self.conv5(h), negative_slope=0.2)
120 | h = F.leaky_relu(self.conv6(h), negative_slope=0.2)
121 | h = F.leaky_relu(self.conv7(h), negative_slope=0.2).squeeze()
122 | h = torch.cat([h, (crop_idx / 112. * torch.ones((h.size(0), 1))).cuda()], 1) # 128*7/8
123 | h = F.leaky_relu(self.fc1(h), negative_slope=0.2)
124 | h_logit = self.fc2(h)
125 | if self.num_class>0:
126 | h_class_logit = self.fc2_class(h)
127 |
128 | h_small_logit, h_small_class_logit = self.sub_D(h_small)
129 | return (h_logit+ h_small_logit)/2., (h_class_logit+ h_small_class_logit)/2.
130 | else:
131 | h_small_logit = self.sub_D(h_small)
132 | return (h_logit+ h_small_logit)/2.
133 |
134 |
135 | class Sub_Generator(nn.Module):
136 | def __init__(self, channel:int=16):
137 | super(Sub_Generator, self).__init__()
138 | _c = channel
139 |
140 | self.relu = nn.ReLU()
141 | self.tp_conv1 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True)
142 | self.bn1 = nn.GroupNorm(8, _c*2)
143 |
144 | self.tp_conv2 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True)
145 | self.bn2 = nn.GroupNorm(8, _c)
146 |
147 | self.tp_conv3 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True)
148 |
149 | def forward(self, h):
150 |
151 | h = self.tp_conv1(h)
152 | h = self.relu(self.bn1(h))
153 |
154 | h = self.tp_conv2(h)
155 | h = self.relu(self.bn2(h))
156 |
157 | h = self.tp_conv3(h)
158 | h = torch.tanh(h)
159 | return h
160 |
161 | class Generator(nn.Module):
162 | def __init__(self, mode="train", latent_dim=1024, channel=32, num_class=0):
163 | super(Generator, self).__init__()
164 | _c = channel
165 |
166 | self.mode = mode
167 | self.relu = nn.ReLU()
168 | self.num_class = num_class
169 |
170 | # G^A and G^H
171 | self.fc1 = nn.Linear(latent_dim+num_class, 4*4*4*_c*16)
172 |
173 | self.tp_conv1 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True)
174 | self.bn1 = nn.GroupNorm(8, _c*16)
175 |
176 | self.tp_conv2 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True)
177 | self.bn2 = nn.GroupNorm(8, _c*16)
178 |
179 | self.tp_conv3 = nn.Conv3d(_c*16, _c*8, kernel_size=3, stride=1, padding=1, bias=True)
180 | self.bn3 = nn.GroupNorm(8, _c*8)
181 |
182 | self.tp_conv4 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=True)
183 | self.bn4 = nn.GroupNorm(8, _c*4)
184 |
185 | self.tp_conv5 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True)
186 | self.bn5 = nn.GroupNorm(8, _c*2)
187 |
188 | self.tp_conv6 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True)
189 | self.bn6 = nn.GroupNorm(8, _c)
190 |
191 | self.tp_conv7 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True)
192 |
193 | # G^L
194 | self.sub_G = Sub_Generator(channel=_c//2)
195 |
196 | def forward(self, h, crop_idx=None, class_label=None):
197 |
198 | # Generate from random noise
199 | if crop_idx != None or self.mode=='eval':
200 | if self.num_class > 0:
201 | h = torch.cat((h, class_label), dim=1)
202 |
203 | h = self.fc1(h)
204 |
205 | h = h.view(-1,512,4,4,4)
206 | h = self.tp_conv1(h)
207 | h = self.relu(self.bn1(h))
208 |
209 | h = F.interpolate(h,scale_factor = 2)
210 | h = self.tp_conv2(h)
211 | h = self.relu(self.bn2(h))
212 |
213 | h = F.interpolate(h,scale_factor = 2)
214 | h = self.tp_conv3(h)
215 | h = self.relu(self.bn3(h))
216 |
217 | h = F.interpolate(h,scale_factor = 2)
218 | h = self.tp_conv4(h)
219 | h = self.relu(self.bn4(h))
220 |
221 | h = self.tp_conv5(h)
222 | h_latent = self.relu(self.bn5(h)) # (32, 32, 32), channel:128
223 |
224 | if self.mode == "train":
225 | h_small = self.sub_G(h_latent)
226 | h = h_latent[:,:,crop_idx//4:crop_idx//4+4,:,:] # Crop sub-volume, out: (4, 32, 32)
227 | else:
228 | h = h_latent
229 |
230 | # Generate from latent feature
231 | h = F.interpolate(h,scale_factor = 2)
232 | h = self.tp_conv6(h)
233 | h = self.relu(self.bn6(h)) # (64, 64, 64)
234 |
235 | h = F.interpolate(h,scale_factor = 2)
236 | h = self.tp_conv7(h)
237 |
238 | h = torch.tanh(h) # (128, 128, 128)
239 |
240 | if crop_idx != None and self.mode == "train":
241 | return h, h_small
242 | return h
243 |
--------------------------------------------------------------------------------
/models/Model_HA_GAN_256.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from models.layers import SNConv3d, SNLinear
6 |
7 | class Code_Discriminator(nn.Module):
8 | def __init__(self, code_size, num_units=256):
9 | super(Code_Discriminator, self).__init__()
10 |
11 | self.l1 = nn.Sequential(SNLinear(code_size, num_units),
12 | nn.LeakyReLU(0.2,inplace=True))
13 | self.l2 = nn.Sequential(SNLinear(num_units, num_units),
14 | nn.LeakyReLU(0.2,inplace=True))
15 | self.l3 = SNLinear(num_units, 1)
16 |
17 | def forward(self, x):
18 | x = self.l1(x)
19 | x = self.l2(x)
20 | x = self.l3(x)
21 |
22 | return x
23 |
24 | class Sub_Encoder(nn.Module):
25 | def __init__(self, channel=256, latent_dim=1024):
26 | super(Sub_Encoder, self).__init__()
27 |
28 | self.relu = nn.ReLU()
29 | self.conv1 = nn.Conv3d(channel//4, channel//8, kernel_size=4, stride=2, padding=1) # in:[64,64,64], out:[32,32,32]
30 | self.bn1 = nn.GroupNorm(8, channel//8)
31 | self.conv2 = nn.Conv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16]
32 | self.bn2 = nn.GroupNorm(8, channel//4)
33 | self.conv3 = nn.Conv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8]
34 | self.bn3 = nn.GroupNorm(8, channel//2)
35 | self.conv4 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4]
36 | self.bn4 = nn.GroupNorm(8, channel)
37 | self.conv5 = nn.Conv3d(channel, latent_dim, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1]
38 |
39 | def forward(self, h):
40 | h = self.conv1(h)
41 | h = self.relu(self.bn1(h))
42 | h = self.conv2(h)
43 | h = self.relu(self.bn2(h))
44 | h = self.conv3(h)
45 | h = self.relu(self.bn3(h))
46 | h = self.conv4(h)
47 | h = self.relu(self.bn4(h))
48 | h = self.conv5(h).squeeze()
49 | return h
50 |
51 | class Encoder(nn.Module):
52 | def __init__(self, channel=64):
53 | super(Encoder, self).__init__()
54 |
55 | self.relu = nn.ReLU()
56 | self.conv1 = nn.Conv3d(1, channel//2, kernel_size=4, stride=2, padding=1) # in:[32,256,256], out:[16,128,128]
57 | self.bn1 = nn.GroupNorm(8, channel//2)
58 | self.conv2 = nn.Conv3d(channel//2, channel//2, kernel_size=3, stride=1, padding=1) # out:[16,128,128]
59 | self.bn2 = nn.GroupNorm(8, channel//2)
60 | self.conv3 = nn.Conv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[8,64,64]
61 | self.bn3 = nn.GroupNorm(8, channel)
62 |
63 | def forward(self, h):
64 | h = self.conv1(h)
65 | h = self.relu(self.bn1(h))
66 |
67 | h = self.conv2(h)
68 | h = self.relu(self.bn2(h))
69 |
70 | h = self.conv3(h)
71 | h = self.relu(self.bn3(h))
72 | return h
73 |
74 | # D^L
75 | class Sub_Discriminator(nn.Module):
76 | def __init__(self, num_class=0, channel=256):
77 | super(Sub_Discriminator, self).__init__()
78 | self.channel = channel
79 | self.num_class = num_class
80 |
81 | self.conv1 = SNConv3d(1, channel//8, kernel_size=4, stride=2, padding=1) # in:[64,64,64], out:[32,32,32]
82 | self.conv2 = SNConv3d(channel//8, channel//4, kernel_size=4, stride=2, padding=1) # out:[16,16,16]
83 | self.conv3 = SNConv3d(channel//4, channel//2, kernel_size=4, stride=2, padding=1) # out:[8,8,8]
84 | self.conv4 = SNConv3d(channel//2, channel, kernel_size=4, stride=2, padding=1) # out:[4,4,4]
85 | self.conv5 = SNConv3d(channel, 1+num_class, kernel_size=4, stride=1, padding=0) # out:[1,1,1,1]
86 |
87 | def forward(self, h):
88 | h = F.leaky_relu(self.conv1(h), negative_slope=0.2)
89 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2)
90 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2)
91 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2)
92 | if self.num_class == 0:
93 | h = self.conv5(h).view((-1,1))
94 | return h
95 | else:
96 | h = self.conv5(h).view((-1,1+self.num_class))
97 | return h[:,:1], h[:,1:]
98 |
99 | class Discriminator(nn.Module):
100 | def __init__(self, num_class=0, channel=512):
101 | super(Discriminator, self).__init__()
102 | self.channel = channel
103 | self.num_class = num_class
104 |
105 | # D^H
106 | self.conv1 = SNConv3d(1, channel//32, kernel_size=4, stride=2, padding=1) # in:[32,256,256], out:[16,128,128]
107 | self.conv2 = SNConv3d(channel//32, channel//16, kernel_size=4, stride=2, padding=1) # out:[8,64,64,64]
108 | self.conv3 = SNConv3d(channel//16, channel//8, kernel_size=4, stride=2, padding=1) # out:[4,32,32,32]
109 | self.conv4 = SNConv3d(channel//8, channel//4, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[2,16,16,16]
110 | self.conv5 = SNConv3d(channel//4, channel//2, kernel_size=(2,4,4), stride=(2,2,2), padding=(0,1,1)) # out:[1,8,8,8]
111 | self.conv6 = SNConv3d(channel//2, channel, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1)) # out:[1,4,4,4]
112 | self.conv7 = SNConv3d(channel, channel//4, kernel_size=(1,4,4), stride=1, padding=0) # out:[1,1,1,1]
113 | self.fc1 = SNLinear(channel//4+1, channel//8)
114 | self.fc2 = SNLinear(channel//8, 1)
115 | if num_class>0:
116 | self.fc2_class = SNLinear(channel//8, num_class)
117 |
118 | # D^L
119 | self.sub_D = Sub_Discriminator(num_class)
120 |
121 | def forward(self, h, h_small, crop_idx):
122 | h = F.leaky_relu(self.conv1(h), negative_slope=0.2)
123 | h = F.leaky_relu(self.conv2(h), negative_slope=0.2)
124 | h = F.leaky_relu(self.conv3(h), negative_slope=0.2)
125 | h = F.leaky_relu(self.conv4(h), negative_slope=0.2)
126 | h = F.leaky_relu(self.conv5(h), negative_slope=0.2)
127 | h = F.leaky_relu(self.conv6(h), negative_slope=0.2)
128 | h = F.leaky_relu(self.conv7(h), negative_slope=0.2).squeeze()
129 | h = torch.cat([h, (crop_idx / 224. * torch.ones((h.size(0), 1))).cuda()], 1) # 256*7/8
130 | h = F.leaky_relu(self.fc1(h), negative_slope=0.2)
131 | h_logit = self.fc2(h)
132 | if self.num_class>0:
133 | h_class_logit = self.fc2_class(h)
134 |
135 | h_small_logit, h_small_class_logit = self.sub_D(h_small)
136 | return (h_logit+ h_small_logit)/2., (h_class_logit+ h_small_class_logit)/2.
137 | else:
138 | h_small_logit = self.sub_D(h_small)
139 | return (h_logit+ h_small_logit)/2.
140 |
141 |
142 | class Sub_Generator(nn.Module):
143 | def __init__(self, channel:int=16):
144 | super(Sub_Generator, self).__init__()
145 | _c = channel
146 |
147 | self.relu = nn.ReLU()
148 | self.tp_conv1 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True)
149 | self.bn1 = nn.GroupNorm(8, _c*2)
150 |
151 | self.tp_conv2 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True)
152 | self.bn2 = nn.GroupNorm(8, _c)
153 |
154 | self.tp_conv3 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True)
155 |
156 | def forward(self, h):
157 |
158 | h = self.tp_conv1(h)
159 | h = self.relu(self.bn1(h))
160 |
161 | h = self.tp_conv2(h)
162 | h = self.relu(self.bn2(h))
163 |
164 | h = self.tp_conv3(h)
165 | h = torch.tanh(h)
166 | return h
167 |
168 | class Generator(nn.Module):
169 | def __init__(self, mode="train", latent_dim=1024, channel=32, num_class=0):
170 | super(Generator, self).__init__()
171 | _c = channel
172 |
173 | self.mode = mode
174 | self.relu = nn.ReLU()
175 | self.num_class = num_class
176 |
177 | # G^A and G^H
178 | self.fc1 = nn.Linear(latent_dim+num_class, 4*4*4*_c*16)
179 |
180 | self.tp_conv1 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True)
181 | self.bn1 = nn.GroupNorm(8, _c*16)
182 |
183 | self.tp_conv2 = nn.Conv3d(_c*16, _c*16, kernel_size=3, stride=1, padding=1, bias=True)
184 | self.bn2 = nn.GroupNorm(8, _c*16)
185 |
186 | self.tp_conv3 = nn.Conv3d(_c*16, _c*8, kernel_size=3, stride=1, padding=1, bias=True)
187 | self.bn3 = nn.GroupNorm(8, _c*8)
188 |
189 | self.tp_conv4 = nn.Conv3d(_c*8, _c*4, kernel_size=3, stride=1, padding=1, bias=True)
190 | self.bn4 = nn.GroupNorm(8, _c*4)
191 |
192 | self.tp_conv5 = nn.Conv3d(_c*4, _c*2, kernel_size=3, stride=1, padding=1, bias=True)
193 | self.bn5 = nn.GroupNorm(8, _c*2)
194 |
195 | self.tp_conv6 = nn.Conv3d(_c*2, _c, kernel_size=3, stride=1, padding=1, bias=True)
196 | self.bn6 = nn.GroupNorm(8, _c)
197 |
198 | self.tp_conv7 = nn.Conv3d(_c, 1, kernel_size=3, stride=1, padding=1, bias=True)
199 |
200 | # G^L
201 | self.sub_G = Sub_Generator(channel=_c//2)
202 |
203 | def forward(self, h, crop_idx=None, class_label=None):
204 |
205 | # Generate from random noise
206 | if crop_idx != None or self.mode=='eval':
207 | if self.num_class > 0:
208 | h = torch.cat((h, class_label), dim=1)
209 |
210 | h = self.fc1(h)
211 |
212 | h = h.view(-1,512,4,4,4)
213 | h = self.tp_conv1(h)
214 | h = self.relu(self.bn1(h))
215 |
216 | h = F.interpolate(h,scale_factor = 2)
217 | h = self.tp_conv2(h)
218 | h = self.relu(self.bn2(h))
219 |
220 | h = F.interpolate(h,scale_factor = 2)
221 | h = self.tp_conv3(h)
222 | h = self.relu(self.bn3(h))
223 |
224 | h = F.interpolate(h,scale_factor = 2)
225 | h = self.tp_conv4(h)
226 | h = self.relu(self.bn4(h))
227 |
228 | h = F.interpolate(h,scale_factor = 2)
229 | h = self.tp_conv5(h)
230 | h_latent = self.relu(self.bn5(h)) # (64, 64, 64), channel:128
231 |
232 | if self.mode == "train":
233 | h_small = self.sub_G(h_latent)
234 | h = h_latent[:,:,crop_idx//4:crop_idx//4+8,:,:] # Crop, out: (8, 64, 64)
235 | else:
236 | h = h_latent
237 |
238 | # Generate from latent feature
239 | h = F.interpolate(h,scale_factor = 2)
240 | h = self.tp_conv6(h)
241 | h = self.relu(self.bn6(h)) # (128, 128, 128)
242 |
243 | h = F.interpolate(h,scale_factor = 2)
244 | h = self.tp_conv7(h)
245 |
246 | h = torch.tanh(h) # (256, 256, 256)
247 |
248 | if crop_idx != None and self.mode == "train":
249 | return h, h_small
250 | return h
251 |
--------------------------------------------------------------------------------
/evaluation/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import time
5 | from argparse import ArgumentParser
6 | from collections import OrderedDict
7 | import numpy as np
8 | from scipy import linalg
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import functional as F
13 |
14 | from volume_dataset import Volume_Dataset
15 |
16 | from models.Model_HA_GAN_256 import Generator, Encoder, Sub_Encoder
17 | from resnet3D import resnet50
18 |
19 | torch.manual_seed(0)
20 | torch.cuda.manual_seed_all(0)
21 | torch.backends.cudnn.benchmark = True
22 |
23 | parser = ArgumentParser()
24 | parser.add_argument('--path', type=str, default='')
25 | parser.add_argument('--real_suffix', type=str, default='eval_600_size_256_resnet50_fold')
26 | parser.add_argument('--img_size', type=int, default=256)
27 | parser.add_argument('--batch_size', type=int, default=2)
28 | parser.add_argument('--num_workers', type=int, default=8)
29 | parser.add_argument('--num_samples', type=int, default=2048)
30 | parser.add_argument('--dims', type=int, default=2048)
31 | parser.add_argument('--ckpt_step', type=int, default=80000)
32 | parser.add_argument('--latent_dim', type=int, default=1024)
33 | parser.add_argument('--basename', type=str, default="256_1024_Alpha_SN_v4plus_4_l1_GN_threshold_600_fold")
34 | parser.add_argument('--fold', type=int)
35 |
36 | def trim_state_dict_name(ckpt):
37 | new_state_dict = OrderedDict()
38 | for k, v in ckpt.items():
39 | name = k[7:] # remove `module.`
40 | new_state_dict[name] = v
41 | return new_state_dict
42 |
43 | class Flatten(torch.nn.Module):
44 | def forward(self, inp):
45 | return inp.view(inp.size(0), -1)
46 |
47 | def generate_samples(args):
48 | G = Generator(mode='eval', latent_dim=args.latent_dim, num_class=0)
49 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/G_iter"+str(args.ckpt_step)+".pth"
50 | ckpt = torch.load(ckpt_path)['model']
51 | ckpt = trim_state_dict_name(ckpt)
52 | G.load_state_dict(ckpt)
53 |
54 | E = Encoder()
55 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/E_iter"+str(args.ckpt_step)+".pth"
56 | ckpt = torch.load(ckpt_path)['model']
57 | ckpt = trim_state_dict_name(ckpt)
58 | E.load_state_dict(ckpt)
59 |
60 |     Sub_E = Sub_Encoder(latent_dim=args.latent_dim)
61 | ckpt_path = "./checkpoint/"+args.basename+str(args.fold)+"/Sub_E_iter"+str(args.ckpt_step)+".pth"
62 | ckpt = torch.load(ckpt_path)['model']
63 | ckpt = trim_state_dict_name(ckpt)
64 | Sub_E.load_state_dict(ckpt)
65 | print("Weights step", args.ckpt_step, "loaded.")
66 | del ckpt
67 |
68 | G = nn.DataParallel(G).cuda()
69 | E = nn.DataParallel(E).cuda()
70 | Sub_E = nn.DataParallel(Sub_E).cuda()
71 |
72 | G.eval()
73 | E.eval()
74 | Sub_E.eval()
75 |
76 |
77 | model = get_feature_extractor()
78 | pred_arr = np.empty((args.num_samples, args.dims))
79 |
80 | for i in range(args.num_samples//args.batch_size):
81 | if i % 10 == 0:
82 | print('\rPropagating batch %d' % i, end='', flush=True)
83 | with torch.no_grad():
84 |
85 |             noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
86 |             x_rand = G(noise)  # eval mode: generates the full volume directly
87 | # range: [-1,1]
88 | x_rand = x_rand.detach()
89 | pred = model(x_rand)
90 |
91 | if (i+1)*args.batch_size > pred_arr.shape[0]:
92 | pred_arr[i*args.batch_size:] = pred.cpu().numpy()
93 | else:
94 | pred_arr[i*args.batch_size:(i+1)*args.batch_size] = pred.cpu().numpy()
95 |
96 | print(' done')
97 | return pred_arr
98 |
99 | def get_activations_from_dataloader(model, data_loader, args):
100 |
101 | pred_arr = np.empty((args.num_samples, args.dims))
102 | for i, batch in enumerate(data_loader):
103 | if i % 10 == 0:
104 | print('\rPropagating batch %d' % i, end='', flush=True)
105 |         batch = batch[0].float().cuda()  # Volume_Dataset yields (image, label) pairs; keep the image
106 | with torch.no_grad():
107 | pred = model(batch)
108 |
109 |         if (i+1)*args.batch_size > pred_arr.shape[0]:
110 | pred_arr[i*args.batch_size:] = pred.cpu().numpy()
111 | else:
112 | pred_arr[i*args.batch_size:(i+1)*args.batch_size] = pred.cpu().numpy()
113 | print(' done')
114 | return pred_arr
115 |
116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
117 | """Numpy implementation of the Frechet Distance.
118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
119 | and X_2 ~ N(mu_2, C_2) is
120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
121 |
122 | Stable version by Dougal J. Sutherland.
123 |
124 | Params:
125 | -- mu1 : Numpy array containing the activations of a layer of the
126 | inception net (like returned by the function 'get_predictions')
127 | for generated samples.
128 | -- mu2 : The sample mean over activations, precalculated on an
129 | representative data set.
130 | -- sigma1: The covariance matrix over activations for generated samples.
131 | -- sigma2: The covariance matrix over activations, precalculated on an
132 | representative data set.
133 |
134 | Returns:
135 | -- : The Frechet Distance.
136 | """
137 |
138 | mu1 = np.atleast_1d(mu1)
139 | mu2 = np.atleast_1d(mu2)
140 |
141 | sigma1 = np.atleast_2d(sigma1)
142 | sigma2 = np.atleast_2d(sigma2)
143 |
144 | assert mu1.shape == mu2.shape, \
145 | 'Training and test mean vectors have different lengths'
146 | assert sigma1.shape == sigma2.shape, \
147 | 'Training and test covariances have different dimensions'
148 |
149 | diff = mu1 - mu2
150 |
151 | # Product might be almost singular
152 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
153 | if not np.isfinite(covmean).all():
154 | msg = ('fid calculation produces singular product; '
155 | 'adding %s to diagonal of cov estimates') % eps
156 | print(msg)
157 | offset = np.eye(sigma1.shape[0]) * eps
158 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
159 |
160 | # Numerical error might give slight imaginary component
161 | if np.iscomplexobj(covmean):
162 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
163 | m = np.max(np.abs(covmean.imag))
164 | raise ValueError('Imaginary component {}'.format(m))
165 | covmean = covmean.real
166 |
167 | tr_covmean = np.trace(covmean)
168 |
169 | return (diff.dot(diff) + np.trace(sigma1) +
170 | np.trace(sigma2) - 2 * tr_covmean)
171 |
172 | def post_process(act):
173 | mu = np.mean(act, axis=0)
174 | sigma = np.cov(act, rowvar=False)
175 | return mu, sigma
176 |
177 | def get_feature_extractor():
178 | model = resnet50(shortcut_type='B')
179 | model.conv_seg = nn.Sequential(nn.AdaptiveAvgPool3d((1, 1, 1)),
180 |                                    Flatten()) # (N, 2048) features from the resnet50 backbone
181 | # ckpt from https://drive.google.com/file/d/1399AsrYpQDi1vq6ciKRQkfknLsQQyigM/view?usp=sharing
182 | ckpt = torch.load("../gnn_shared/ckpt/pretrain/resnet_50.pth")
183 | ckpt = trim_state_dict_name(ckpt["state_dict"])
184 | model.load_state_dict(ckpt) # No conv_seg module in ckpt
185 | model = nn.DataParallel(model).cuda()
186 | model.eval()
187 | print("Feature extractor weights loaded")
188 | return model
189 |
190 | def calculate_fid_real(args):
191 | """Calculates the FID of two paths"""
192 | assert os.path.exists("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy")
193 |
194 | model = get_feature_extractor()
195 |     # The dataset-specific loaders (COPD_dataset / Brain_dataset) are not included in this repo,
196 |     dataset = Volume_Dataset(data_dir=args.path, mode="train", fold=args.fold)  # so use the generic Volume_Dataset imported above
197 | args.num_samples = len(dataset)
198 | print("Number of samples:", args.num_samples)
199 | data_loader = torch.utils.data.DataLoader(dataset,batch_size=args.batch_size,drop_last=False,
200 | shuffle=False,num_workers=args.num_workers)
201 | act = get_activations_from_dataloader(model, data_loader, args)
202 | np.save("./results/fid/pred_arr_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", act)
203 | #np.save("./results/fid/pred_arr_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", act)
204 | #calculate_mmd(args, act)
205 | m, s = post_process(act)
206 |
207 | m1 = np.load("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy")
208 | s1 = np.load("./results/fid/s_real_"+args.real_suffix+str(args.fold)+".npy")
209 |
210 | fid_value = calculate_frechet_distance(m1, s1, m, s)
211 | print('FID: ', fid_value)
212 | #np.save("./results/fid/m_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", m)
213 | #np.save("./results/fid/s_real_train_600_size_"+str(args.img_size)+"_resnet50_fold"+str(args.fold)+".npy", s)
214 | #np.save("./results/fid/m_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", m)
215 | #np.save("./results/fid/s_real_train_size_"+str(args.img_size)+"_resnet50_GSP_fold"+str(args.fold)+".npy", s)
216 |
217 | def calculate_mmd_fake(args):
218 | assert os.path.exists("./results/fid/pred_arr_real_"+args.real_suffix+str(args.fold)+".npy")
219 | act = generate_samples(args)
220 | calculate_mmd(args, act)
221 |
222 | def calculate_fid_fake(args):
223 | #assert os.path.exists("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy")
224 | act = generate_samples(args)
225 | m2, s2 = post_process(act)
226 |
227 | m1 = np.load("./results/fid/m_real_"+args.real_suffix+str(args.fold)+".npy")
228 | s1 = np.load("./results/fid/s_real_"+args.real_suffix+str(args.fold)+".npy")
229 |
230 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
231 | print('FID: ', fid_value)
232 |
233 |
234 | def calculate_mmd(args, act):
235 | from torch_two_sample.statistics_diff import MMDStatistic
236 |
237 | act_real = np.load("./results/fid/pred_arr_real_"+args.real_suffix+str(args.fold)+".npz")['arr_0']
238 | mmd = MMDStatistic(act_real.shape[0], act.shape[0])
239 | sample_1 = torch.from_numpy(act_real)
240 | sample_2 = torch.from_numpy(act)
241 |
242 | # Requires the updated MMD package at https://github.com/lisun-ai/torch-two-sample, which supports alphas='median'
243 | test_statistics, ret_matrix = mmd(sample_1, sample_2, alphas='median', ret_matrix=True)
244 | #p = mmd.pval(ret_matrix.float(), n_permutations=1000)
245 |
246 | print("\nMMD test statistics:", test_statistics.item())
247 |
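    | # A minimal sketch (not part of this script) of computing FID directly from two
    | # activation arrays with the helpers above; `act_real` and `act_fake` are
    | # hypothetical (N, feature_dim) arrays produced by the ResNet-50 extractor:
    | #   m_real, s_real = post_process(act_real)
    | #   m_fake, s_fake = post_process(act_fake)
    | #   fid = calculate_frechet_distance(m_real, s_real, m_fake, s_fake)
    |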
248 | if __name__ == '__main__':
249 | args = parser.parse_args()
250 | start_time = time.time()
251 | calculate_fid_real(args)
252 | calculate_fid_fake(args)
253 | print("Done. Using", (time.time()-start_time)//60, "minutes.")
254 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # train HA-GAN
4 | # Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis
5 | # https://ieeexplore.ieee.org/abstract/document/9770375
6 |
7 | import numpy as np
8 | import torch
9 | import os
10 | import json
11 | import argparse
12 |
13 | from torch import nn
14 | from torch import optim
15 | from torch.nn import functional as F
16 | from tensorboardX import SummaryWriter
17 | import nibabel as nib
18 | from nilearn import plotting
19 |
20 | from utils import trim_state_dict_name, inf_train_gen
21 | from volume_dataset import Volume_Dataset
22 |
23 | import matplotlib.pyplot as plt
24 |
25 | torch.manual_seed(0)
26 | torch.cuda.manual_seed_all(0)
27 | torch.backends.cudnn.benchmark = True
28 |
29 | parser = argparse.ArgumentParser(description='PyTorch HA-GAN Training')
30 | parser.add_argument('--batch-size', default=4, type=int,
31 | help='mini-batch size (default: 4); this is the total '
32 | 'batch size across all GPUs')
33 | parser.add_argument('--workers', default=8, type=int,
34 | help='number of data loading workers (default: 8)')
35 | parser.add_argument('--img-size', default=256, type=int,
36 | help='size of training images (default: 256, can be 128 or 256)')
37 | parser.add_argument('--num-iter', default=80000, type=int,
38 | help='number of training iterations (default: 80000)')
39 | parser.add_argument('--log-iter', default=20, type=int,
40 | help='number of iterations between logging (default: 20)')
41 | parser.add_argument('--continue-iter', default=0, type=int,
42 | help='continue from a checkpoint that has run for n iterations (0 for a new run)')
43 | parser.add_argument('--latent-dim', default=1024, type=int,
44 | help='size of the input latent variable')
45 | parser.add_argument('--g-iter', default=1, type=int,
46 | help='number of generator passes per iteration')
47 | parser.add_argument('--lr-g', default=0.0001, type=float,
48 | help='learning rate for the generator')
49 | parser.add_argument('--lr-d', default=0.0004, type=float,
50 | help='learning rate for the discriminator')
51 | parser.add_argument('--lr-e', default=0.0001, type=float,
52 | help='learning rate for the encoder')
53 | parser.add_argument('--data-dir', type=str,
54 | help='path to the preprocessed data folder')
55 | parser.add_argument('--exp-name', default='HA_GAN_run1', type=str,
56 | help='name of the experiment')
57 | parser.add_argument('--fold', default=0, type=int,
58 | help='fold number for cross validation')
59 |
60 | # configs for conditional generation
61 | parser.add_argument('--lambda-class', default=0.1, type=float,
62 | help='weights for the auxiliary classifier loss')
63 | parser.add_argument('--num-class', default=0, type=int,
64 | help='number of classes for the auxiliary classifier (0 if unconditional)')
65 |
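   | # Example invocation (a sketch; paths and values below are placeholders, not a tested command):
   | #   python train.py --data-dir ./preprocessed_npy/ --img-size 256 --exp-name HA_GAN_run1 \
   | #       --batch-size 4 --num-class 0 --fold 0
   |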
66 | def main():
67 | # Configuration
68 | args = parser.parse_args()
69 |
70 | trainset = Volume_Dataset(data_dir=args.data_dir, fold=args.fold, num_class=args.num_class)
71 | train_loader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size,drop_last=True,
72 | shuffle=False,num_workers=args.workers)
73 | gen_load = inf_train_gen(train_loader)
74 |
75 | if args.img_size == 256:
76 | from models.Model_HA_GAN_256 import Discriminator, Generator, Encoder, Sub_Encoder
77 | elif args.img_size == 128:
78 | from models.Model_HA_GAN_128 import Discriminator, Generator, Encoder, Sub_Encoder
79 | else:
80 | raise NotImplementedError("img-size must be 128 or 256")
81 |
82 | G = Generator(mode='train', latent_dim=args.latent_dim, num_class=args.num_class).cuda()
83 | D = Discriminator(num_class=args.num_class).cuda()
84 | E = Encoder().cuda()
85 | Sub_E = Sub_Encoder(latent_dim=args.latent_dim).cuda()
86 |
87 | g_optimizer = optim.Adam(G.parameters(), lr=args.lr_g, betas=(0.0,0.999), eps=1e-8)
88 | d_optimizer = optim.Adam(D.parameters(), lr=args.lr_d, betas=(0.0,0.999), eps=1e-8)
89 | e_optimizer = optim.Adam(E.parameters(), lr=args.lr_e, betas=(0.0,0.999), eps=1e-8)
90 | sub_e_optimizer = optim.Adam(Sub_E.parameters(), lr=args.lr_e, betas=(0.0,0.999), eps=1e-8)
91 |
92 | # Resume from a previous checkpoint
93 | if args.continue_iter != 0:
94 | ckpt_path = './checkpoint/'+args.exp_name+'/G_iter'+str(args.continue_iter)+'.pth'
95 | ckpt = torch.load(ckpt_path, map_location='cuda')
96 | ckpt['model'] = trim_state_dict_name(ckpt['model'])
97 | G.load_state_dict(ckpt['model'])
98 | g_optimizer.load_state_dict(ckpt['optimizer'])
99 | ckpt_path = './checkpoint/'+args.exp_name+'/D_iter'+str(args.continue_iter)+'.pth'
100 | ckpt = torch.load(ckpt_path, map_location='cuda')
101 | ckpt['model'] = trim_state_dict_name(ckpt['model'])
102 | D.load_state_dict(ckpt['model'])
103 | d_optimizer.load_state_dict(ckpt['optimizer'])
104 | ckpt_path = './checkpoint/'+args.exp_name+'/E_iter'+str(args.continue_iter)+'.pth'
105 | ckpt = torch.load(ckpt_path, map_location='cuda')
106 | ckpt['model'] = trim_state_dict_name(ckpt['model'])
107 | E.load_state_dict(ckpt['model'])
108 | e_optimizer.load_state_dict(ckpt['optimizer'])
109 | ckpt_path = './checkpoint/'+args.exp_name+'/Sub_E_iter'+str(args.continue_iter)+'.pth'
110 | ckpt = torch.load(ckpt_path, map_location='cuda')
111 | ckpt['model'] = trim_state_dict_name(ckpt['model'])
112 | Sub_E.load_state_dict(ckpt['model'])
113 | sub_e_optimizer.load_state_dict(ckpt['optimizer'])
114 | del ckpt
115 | print("Ckpt", args.exp_name, args.continue_iter, "loaded.")
116 |
117 | G = nn.DataParallel(G)
118 | D = nn.DataParallel(D)
119 | E = nn.DataParallel(E)
120 | Sub_E = nn.DataParallel(Sub_E)
121 |
122 | G.train()
123 | D.train()
124 | E.train()
125 | Sub_E.train()
126 |
127 | real_y = torch.ones((args.batch_size, 1)).cuda() # unused; real_labels/fake_labels below are used instead
128 | fake_y = torch.zeros((args.batch_size, 1)).cuda() # unused
129 |
130 | loss_f = nn.BCEWithLogitsLoss()
131 | loss_mse = nn.L1Loss() # note: despite the name, this is an L1 reconstruction loss
132 |
133 | fake_labels = torch.zeros((args.batch_size, 1)).cuda()
134 | real_labels = torch.ones((args.batch_size, 1)).cuda()
135 |
136 | summary_writer = SummaryWriter("./checkpoint/"+args.exp_name)
137 |
138 | # save configurations to a dictionary
139 | with open(os.path.join("./checkpoint/"+args.exp_name, 'configs.json'), 'w') as f:
140 | json.dump(vars(args), f, indent=2)
141 |
142 | for p in D.parameters():
143 | p.requires_grad = False
144 | for p in G.parameters():
145 | p.requires_grad = False
146 | for p in E.parameters():
147 | p.requires_grad = False
148 | for p in Sub_E.parameters():
149 | p.requires_grad = False
150 |
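   | # All parameters start frozen; inside the loop, requires_grad is re-enabled only for
   | # the sub-network being updated in the current phase (D, then G, then E, then Sub_E).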
151 | for iteration in range(args.continue_iter, args.num_iter):
152 |
153 | ###############################################
154 | # Train Discriminator (D^H and D^L)
155 | ###############################################
156 | for p in D.parameters():
157 | p.requires_grad = True
158 | for p in Sub_E.parameters():
159 | p.requires_grad = False
160 |
161 | real_images, class_label = next(gen_load)
162 | D.zero_grad()
163 | real_images = real_images.float().cuda()
164 | # low-res full volume of real image
165 | real_images_small = F.interpolate(real_images, scale_factor = 0.25)
166 |
167 | # randomly select a high-res sub-volume from real image
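   | # (for the default img_size=256, crop_idx is drawn from [0, 224] and the sub-volume
   | # spans img_size//8 = 32 consecutive slices along the first spatial axis)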
168 | crop_idx = np.random.randint(0, args.img_size*7//8+1) # 256*7//8 + 1 = 225
169 | real_images_crop = real_images[:,:,crop_idx:crop_idx+args.img_size//8,:,:]
170 |
171 | if args.num_class == 0: # unconditional
172 | y_real_pred = D(real_images_crop, real_images_small, crop_idx)
173 | d_real_loss = loss_f(y_real_pred, real_labels)
174 |
175 | # random generation
176 | noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
177 | # fake_images: high-res sub-volume of generated image
178 | # fake_images_small: low-res full volume of generated image
179 | fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=None)
180 | y_fake_pred = D(fake_images, fake_images_small, crop_idx)
181 |
182 | else: # conditional
183 | class_label = class_label.long().cuda()
184 | class_label_onehot = F.one_hot(class_label, num_classes=args.num_class)
185 | class_label_onehot = class_label_onehot.float().cuda()
186 |
187 | y_real_pred, y_real_class = D(real_images_crop, real_images_small, crop_idx)
188 | # GAN loss + auxiliary classifier loss
189 | d_real_loss = loss_f(y_real_pred, real_labels) + \
190 | F.cross_entropy(y_real_class, class_label)
191 |
192 | # random generation
193 | noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
194 | fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=class_label_onehot)
195 | y_fake_pred, y_fake_class = D(fake_images, fake_images_small, crop_idx)
196 |
197 | d_fake_loss = loss_f(y_fake_pred, fake_labels)
198 |
199 | d_loss = d_real_loss + d_fake_loss
200 | d_loss.backward()
201 |
202 | d_optimizer.step()
203 |
204 | ###############################################
205 | # Train Generator (G^A, G^H and G^L)
206 | ###############################################
207 | for p in D.parameters():
208 | p.requires_grad = False
209 | for p in G.parameters():
210 | p.requires_grad = True
211 |
212 | for iters in range(args.g_iter):
213 | G.zero_grad()
214 |
215 | noise = torch.randn((args.batch_size, args.latent_dim)).cuda()
216 | if args.num_class == 0: # unconditional
217 | fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=None)
218 |
219 | y_fake_g = D(fake_images, fake_images_small, crop_idx)
220 | g_loss = loss_f(y_fake_g, real_labels)
221 | else: # conditional
222 | fake_images, fake_images_small = G(noise, crop_idx=crop_idx, class_label=class_label_onehot)
223 |
224 | y_fake_g, y_fake_g_class = D(fake_images, fake_images_small, crop_idx)
225 | g_loss = loss_f(y_fake_g, real_labels) + \
226 | args.lambda_class * F.cross_entropy(y_fake_g_class, class_label)
227 |
228 | g_loss.backward()
229 | g_optimizer.step()
230 |
231 | ###############################################
232 | # Train Encoder (E^H)
233 | ###############################################
234 | for p in E.parameters():
235 | p.requires_grad = True
236 | for p in G.parameters():
237 | p.requires_grad = False
238 | E.zero_grad()
239 |
240 | z_hat = E(real_images_crop)
241 | x_hat = G(z_hat, crop_idx=None)
242 |
243 | e_loss = loss_mse(x_hat, real_images_crop)
244 | e_loss.backward()
245 | e_optimizer.step()
246 |
247 | ###############################################
248 | # Train Sub Encoder (E^G)
249 | ###############################################
250 | for p in Sub_E.parameters():
251 | p.requires_grad = True
252 | for p in E.parameters():
253 | p.requires_grad = False
254 | Sub_E.zero_grad()
255 |
256 | with torch.no_grad():
257 | z_hat_i_list = []
258 | # Process all sub-volumes and concatenate their encodings
259 | for crop_idx_i in range(0,args.img_size,args.img_size//8):
260 | real_images_crop_i = real_images[:,:,crop_idx_i:crop_idx_i+args.img_size//8,:,:]
261 | z_hat_i = E(real_images_crop_i)
262 | z_hat_i_list.append(z_hat_i)
263 | z_hat = torch.cat(z_hat_i_list, dim=2).detach()
264 | sub_z_hat = Sub_E(z_hat)
265 | # Reconstruction
266 | if args.num_class == 0: # unconditional
267 | sub_x_hat_rec, sub_x_hat_rec_small = G(sub_z_hat, crop_idx=crop_idx)
268 | else: # conditional
269 | sub_x_hat_rec, sub_x_hat_rec_small = G(sub_z_hat, crop_idx=crop_idx, class_label=class_label_onehot)
270 |
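   | # average the L1 reconstruction errors of the high-res sub-volume and the low-res full volume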
271 | sub_e_loss = (loss_mse(sub_x_hat_rec,real_images_crop) + loss_mse(sub_x_hat_rec_small,real_images_small))/2.
272 |
273 | sub_e_loss.backward()
274 | sub_e_optimizer.step()
275 |
276 | # Logging
277 | if iteration%args.log_iter == 0:
278 | summary_writer.add_scalar('D', d_loss.item(), iteration)
279 | summary_writer.add_scalar('D_real', d_real_loss.item(), iteration)
280 | summary_writer.add_scalar('D_fake', d_fake_loss.item(), iteration)
281 | summary_writer.add_scalar('G_fake', g_loss.item(), iteration)
282 | summary_writer.add_scalar('E', e_loss.item(), iteration)
283 | summary_writer.add_scalar('Sub_E', sub_e_loss.item(), iteration)
284 |
285 | ###############################################
286 | # Visualization with Tensorboard
287 | ###############################################
288 | if iteration%200 == 0:
289 | print('[{}/{}]'.format(iteration,args.num_iter),
290 | 'D_real: {:<8.3}'.format(d_real_loss.item()),
291 | 'D_fake: {:<8.3}'.format(d_fake_loss.item()),
292 | 'G_fake: {:<8.3}'.format(g_loss.item()),
293 | 'Sub_E: {:<8.3}'.format(sub_e_loss.item()),
294 | 'E: {:<8.3}'.format(e_loss.item()))
295 |
296 | featmask = np.squeeze((0.5*real_images_crop[0]+0.5).data.cpu().numpy())
297 | featmask = nib.Nifti1Image(featmask.transpose((2,1,0)),affine = np.eye(4))
298 | fig=plt.figure()
299 | plotting.plot_img(featmask,title="REAL",cut_coords=(args.img_size//2,args.img_size//2,args.img_size//16),figure=fig,draw_cross=False,cmap="gray")
300 | summary_writer.add_figure('Real', fig, iteration, close=True)
301 |
302 | featmask = np.squeeze((0.5*sub_x_hat_rec[0]+0.5).data.cpu().numpy())
303 | featmask = nib.Nifti1Image(featmask.transpose((2,1,0)),affine = np.eye(4))
304 | fig=plt.figure()
305 | plotting.plot_img(featmask,title="REC",cut_coords=(args.img_size//2,args.img_size//2,args.img_size//16),figure=fig,draw_cross=False,cmap="gray")
306 | summary_writer.add_figure('Rec', fig, iteration, close=True)
307 |
308 | featmask = np.squeeze((0.5*fake_images[0]+0.5).data.cpu().numpy())
309 | featmask = nib.Nifti1Image(featmask.transpose((2,1,0)),affine = np.eye(4))
310 | fig=plt.figure()
311 | plotting.plot_img(featmask,title="FAKE",cut_coords=(args.img_size//2,args.img_size//2,args.img_size//16),figure=fig,draw_cross=False,cmap="gray")
312 | summary_writer.add_figure('Fake', fig, iteration, close=True)
313 |
314 | if iteration > 30000 and (iteration+1)%500 == 0:
315 | torch.save({'model':G.state_dict(), 'optimizer':g_optimizer.state_dict()},'./checkpoint/'+args.exp_name+'/G_iter'+str(iteration+1)+'.pth')
316 | torch.save({'model':D.state_dict(), 'optimizer':d_optimizer.state_dict()},'./checkpoint/'+args.exp_name+'/D_iter'+str(iteration+1)+'.pth')
317 | torch.save({'model':E.state_dict(), 'optimizer':e_optimizer.state_dict()},'./checkpoint/'+args.exp_name+'/E_iter'+str(iteration+1)+'.pth')
318 | torch.save({'model':Sub_E.state_dict(), 'optimizer':sub_e_optimizer.state_dict()},'./checkpoint/'+args.exp_name+'/Sub_E_iter'+str(iteration+1)+'.pth')
319 |
320 | if __name__ == '__main__':
321 | main()
322 |
--------------------------------------------------------------------------------
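
For reference, a minimal sketch (not part of the repository) of loading a generator checkpoint saved by train.py, mirroring the resume logic in main(); the experiment name, iteration number, latent size, and class count below are assumptions taken from the defaults above, and trim_state_dict_name strips the "module." prefixes added by DataParallel. Training curves and the Real/Rec/Fake figures can be monitored with `tensorboard --logdir ./checkpoint/<exp-name>`.

import torch
from models.Model_HA_GAN_256 import Generator
from utils import trim_state_dict_name

# Assumed checkpoint path and hyper-parameters; adjust to the actual run.
G = Generator(mode='train', latent_dim=1024, num_class=0).cuda()
ckpt = torch.load('./checkpoint/HA_GAN_run1/G_iter80000.pth', map_location='cuda')
G.load_state_dict(trim_state_dict_name(ckpt['model']))
G.eval()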