├── LICENSE
├── Overview.png
├── README.md
├── configs
│   ├── brats_linear.yml
│   ├── ldfd_linear.yml
│   └── pmub_linear.yml
├── datasets
│   ├── BRATS.py
│   ├── LDFDCT.py
│   ├── __init__.py
│   ├── pmub.py
│   ├── sr_util.py
│   └── utils.py
├── ddpm_main.py
├── fast_ddpm_main.py
├── functions
│   ├── __init__.py
│   ├── ckpt_util.py
│   ├── denoising.py
│   └── losses.py
├── models
│   ├── diffusion.py
│   └── ema.py
└── runners
    ├── __init__.py
    └── diffusion.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 mirth AI lab at UF
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mirthAI/Fast-DDPM/649a14a6093d14f4286a6b6f9963dd208ce07928/Overview.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast-DDPM
2 |
3 | Official PyTorch implementation of:
4 |
5 | [Fast-DDPM: Fast Denoising Diffusion Probabilistic Models for Medical Image-to-Image Generation](https://ieeexplore.ieee.org/abstract/document/10979336) (JBHI 2025)
6 |
7 | We propose Fast-DDPM, a simple yet effective approach that simultaneously improves the training speed, sampling speed, and generation quality of diffusion models. Fast-DDPM trains and samples with only 10 time steps, reducing training time to 0.2x and sampling time to 0.01x compared to DDPM.
8 |
9 | ![Fast-DDPM overview](Overview.png)
10 | 
11 | 
12 | 
13 | The code is for research purposes only. If you have any questions about how to use it, feel free to contact Hongxu Jiang (hongxu.jiang@medicine.ufl.edu).
14 |
15 | ## Requirements
16 | * Python==3.10.6
17 | * torch==1.12.1
18 | * torchvision==0.15.2
19 | * numpy
20 | * opencv-python
21 | * tqdm
22 | * tensorboard
23 | * tensorboardX
24 | * scikit-image
25 | * medpy
26 | * pillow
27 | * scipy
28 | * `pip install -r requirements.txt`
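
A minimal setup, assuming a fresh Python virtual environment (the environment name `venv` below is illustrative), might be:

```
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```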
29 |
30 | ## Publicly Available Datasets
31 | - Prostate-MRI-US-Biopsy dataset
32 | - LDCT-and-Projection-data dataset
33 | - BraTS 2018 dataset
34 | - The processed dataset can be accessed here: https://drive.google.com/file/d/1kF0g8fMR5XPQ2FTbutfTQ-hwG_mTqerx/view?usp=drive_link.
35 |
36 | ## Usage
37 | ### 1. Clone or download the code.
38 |
39 | ### 2. Pretrained model weights
40 | * We provide pretrained model weights for all three tasks; you can access them here: https://drive.google.com/file/d/1ndS-eLegqwCOUoLT1B-HQiqRQqZUMKVF/view?usp=sharing.
41 | * As shown in the ablation study, the default of 10 time steps may not be optimal for every task; you are welcome to train the Fast-DDPM model on your own dataset with different settings.
42 |
43 | ### 3. Prepare data
44 | * Please download our processed dataset or obtain the original data from the official websites.
45 | * After downloading, extract the archive into the "data/" folder. The directory structure should be as follows:
46 |
47 | ```bash
48 | ├── configs
49 | │
50 | ├── data
51 | │ ├── LD_FD_CT_train
52 | │ ├── LD_FD_CT_test
53 | │ ├── PMUB-train
54 | │ ├── PMUB-test
55 | │ ├── Brats_train
56 | │ └── Brats_test
57 | │
58 | ├── datasets
59 | │
60 | ├── functions
61 | │
62 | ├── models
63 | │
64 | └── runners
65 |
66 | ```
67 |
68 | ### 4. Training/Sampling a Fast-DDPM model
69 | * Please make sure that the hyperparameters such as scheduler type and timesteps are consistent between training and sampling.
70 | * The total number of diffusion time steps defaults to 1000 in the paper, so the number of involved time steps for Fast-DDPM should be an integer smaller than 1000.
71 | ```
72 | python fast_ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --scheduler_type {SAMPLING STRATEGY} --timesteps {STEPS}
73 | ```
74 | ```
75 | python fast_ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --scheduler_type {SAMPLING STRATEGY} --timesteps {STEPS}
76 | ```
77 |
78 | where
79 | - `DATASET_NAME` should be one of `LDFDCT` (image denoising), `BRATS` (image-to-image translation), or `PMUB` (multi-image super-resolution).
80 | - `SAMPLING STRATEGY` selects the time-step sampling strategy proposed in the paper (either `uniform` or `non-uniform`).
81 | - `STEPS` controls how many time steps are used during training and inference. For Fast-DDPM it should be an integer smaller than 1000; the paper uses 10.
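
For example, a concrete Fast-DDPM training run on the low-dose CT task might look like the following (the `--exp` and `--doc` values here are illustrative, not required names):

```
python fast_ddpm_main.py --config ldfd_linear.yml --dataset LDFDCT --exp exp --doc fast_ddpm_ldfdct --scheduler_type non-uniform --timesteps 10
```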
82 |
83 |
84 | ### 5. Training/Sampling a DDPM model
85 | * Please make sure that the hyperparameters such as scheduler type and timesteps are consistent between training and sampling.
86 | * The total number of diffusion time steps defaults to 1000 in the paper, so the number of time steps for DDPM also defaults to 1000.
87 | ```
88 | python ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --timesteps {STEPS}
89 | ```
90 | ```
91 | python ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --timesteps {STEPS}
92 | ```
93 |
94 | where
95 | - `DATASET_NAME` should be one of `LDFDCT` (image denoising), `BRATS` (image-to-image translation), or `PMUB` (multi-image super-resolution).
96 | - `STEPS` controls how many time steps are used during training and inference. It should be 1000 in the setting of this paper.
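
For example, sampling from a trained DDPM on the BraTS translation task might look like the following (again, the `--exp` and `--doc` values are illustrative):

```
python ddpm_main.py --config brats_linear.yml --dataset BRATS --exp exp --doc ddpm_brats --sample --fid --timesteps 1000
```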
97 |
98 |
99 | ## References
100 | * The code is mainly adapted from [DDIM](https://github.com/ermongroup/ddim).
101 |
102 |
103 | ## Citations
104 | If you use our code or dataset, please cite our paper as follows:
105 | ```bibtex
106 | @article{jiang2025fast,
107 | title={Fast-DDPM: Fast denoising diffusion probabilistic models for medical image-to-image generation},
108 | author={Jiang, Hongxu and Imran, Muhammad and Zhang, Teng and Zhou, Yuyin and Liang, Muxuan and Gong, Kuang and Shao, Wei},
109 | journal={IEEE Journal of Biomedical and Health Informatics},
110 | year={2025},
111 | publisher={IEEE}
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/configs/brats_linear.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   dataset: "BRATS"
3 | train_dataroot: "data/Brats_train"
4 | sample_dataroot: "data/Brats_test"
5 | image_size: 256
6 | channels: 1
7 | logit_transform: false
8 | uniform_dequantization: false
9 | gaussian_dequantization: false
10 | random_flip: true
11 | rescaled: true
12 | num_workers: 8
13 |
14 | model:
15 | type: "sg"
16 | in_channels: 2
17 | out_ch: 1
18 | ch: 128
19 | ch_mult: [1, 1, 2, 2, 4, 4]
20 | num_res_blocks: 2
21 | attn_resolutions: [16, ]
22 | dropout: 0.0
23 | var_type: fixedsmall
24 | ema_rate: 0.999
25 | ema: True
26 | resamp_with_conv: True
27 |
28 | diffusion:
29 | beta_schedule: linear
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | num_diffusion_timesteps: 1000
33 |
34 | training:
35 | batch_size: 16
36 | n_epochs: 10000
37 | n_iters: 5000000
38 | snapshot_freq: 100000
39 | validation_freq: 5000000000
40 |
41 | sampling:
42 | batch_size: 8
43 | ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000]
44 | last_only: True
45 |
46 | sampling_inter:
47 | batch_size: 59
48 | last_only: True
49 |
50 | sampling_fid:
51 | batch_size: 128
52 | last_only: True
53 |
54 | optim:
55 | weight_decay: 0.000
56 | optimizer: "Adam"
57 | lr: 0.00002
58 | beta1: 0.9
59 | amsgrad: false
60 | eps: 0.00000001
61 |
--------------------------------------------------------------------------------
/configs/ldfd_linear.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   dataset: "LDFDCT"
3 | train_dataroot: "data/LD_FD_CT_train"
4 | val_dataroot: "data/PMUB-val"
5 | sample_dataroot: "data/LD_FD_CT_test"
6 | image_size: 256
7 | channels: 1
8 | logit_transform: false
9 | uniform_dequantization: false
10 | gaussian_dequantization: false
11 | random_flip: true
12 | rescaled: true
13 | num_workers: 8
14 |
15 | model:
16 | type: "sg"
17 | in_channels: 2
18 | out_ch: 1
19 | ch: 128
20 | ch_mult: [1, 1, 2, 2, 4, 4]
21 | num_res_blocks: 2
22 | attn_resolutions: [16,]
23 | dropout: 0.0
24 | var_type: fixedsmall
25 | ema_rate: 0.999
26 | ema: True
27 | resamp_with_conv: True
28 |
29 | diffusion:
30 | beta_schedule: linear
31 | beta_start: 0.0001
32 | beta_end: 0.02
33 | num_diffusion_timesteps: 1000
34 |
35 | training:
36 | batch_size: 16
37 | n_epochs: 10000
38 | n_iters: 5000000
39 | snapshot_freq: 100000
40 | validation_freq: 5000000000
41 |
42 | sampling:
43 | batch_size: 8
44 | ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000]
45 | last_only: True
46 |
47 | sampling_inter:
48 | batch_size: 59
49 | last_only: True
50 |
51 | sampling_fid:
52 | batch_size: 128
53 | last_only: True
54 |
55 | optim:
56 | weight_decay: 0.000
57 | optimizer: "Adam"
58 | lr: 0.00002
59 | beta1: 0.9
60 | amsgrad: false
61 | eps: 0.00000001
62 |
--------------------------------------------------------------------------------
/configs/pmub_linear.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "PMUB"
3 | train_dataroot: "data/PMUB-train"
4 | val_dataroot: "data/PMUB-val"
5 | sample_dataroot: "data/PMUB-test"
6 | image_size: 256
7 | channels: 1
8 | logit_transform: false
9 | uniform_dequantization: false
10 | gaussian_dequantization: false
11 | random_flip: true
12 | rescaled: true
13 | num_workers: 8
14 |
15 | model:
16 | type: "sr"
17 | in_channels: 3
18 | out_ch: 1
19 | ch: 128
20 | ch_mult: [1, 1, 2, 2, 4, 4]
21 | num_res_blocks: 2
22 | attn_resolutions: [16]
23 | dropout: 0.0
24 | var_type: fixedsmall
25 | ema_rate: 0.999
26 | ema: True
27 | resamp_with_conv: True
28 |
29 | diffusion:
30 | beta_schedule: linear
31 | beta_start: 0.0001
32 | beta_end: 0.02
33 | num_diffusion_timesteps: 1000
34 |
35 | training:
36 | batch_size: 16
37 | n_epochs: 10000
38 | n_iters: 5000000
39 | snapshot_freq: 100000
40 | validation_freq: 5000000000
41 |
42 | sampling:
43 | batch_size: 8
44 | ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000]
45 | last_only: True
46 |
47 | sampling_inter:
48 | batch_size: 59
49 | last_only: True
50 |
51 | sampling_fid:
52 | batch_size: 58
53 | last_only: True
54 |
55 | optim:
56 | weight_decay: 0.000
57 | optimizer: "Adam"
58 | lr: 0.00002
59 | beta1: 0.9
60 | amsgrad: false
61 | eps: 0.00000001
62 |
--------------------------------------------------------------------------------
/datasets/BRATS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | import random
5 | import torch
6 |
7 | from .sr_util import get_paths_from_npys, brats_transform_augment
8 |
9 |
10 | class BRATS(Dataset):
11 | def __init__(self, dataroot, img_size, split='train', data_len=-1):
12 | self.img_size = img_size
13 | self.data_len = data_len
14 | self.split = split
15 | img_root = dataroot + '/A/'
16 | gt_root = dataroot + '/B/'
17 | self.img_npy_path, self.gt_npy_path = get_paths_from_npys(img_root, gt_root)
18 | self.data_len = len(self.img_npy_path)
19 |
20 | def __len__(self):
21 | return self.data_len
22 |
23 |     def __getitem__(self, index):
24 |         # Case name is the .npy file name without its extension.
25 |         base_name = self.img_npy_path[index].split('/')[-1]
26 |         case_name = base_name.split('.')[0]
27 | 
28 |         img_npy = np.load(self.img_npy_path[index])
29 |         img = Image.fromarray(img_npy)
30 |         gt_npy = np.load(self.gt_npy_path[index])
31 |         gt = Image.fromarray(gt_npy)
32 |         img = img.resize((self.img_size, self.img_size))
33 |         gt = gt.resize((self.img_size, self.img_size))
34 | 
35 |         [img, gt] = brats_transform_augment(
36 |             [img, gt], split=self.split)
37 | 
38 |         # The input image is returned under 'LD' and the ground truth under 'FD',
39 |         # matching the keys used by the LDFDCT loader.
40 |         return {'FD': gt, 'LD': img, 'case_name': case_name}
41 | 
--------------------------------------------------------------------------------
/datasets/LDFDCT.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import lmdb
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | import random
6 | import torch
7 | from .sr_util import get_paths_from_images, get_valid_paths_from_images, get_valid_paths_from_test_images, transform_augment
8 |
9 |
10 | class LDFDCT(Dataset):
11 | def __init__(self, dataroot, img_size, split='train', data_len=-1):
12 | self.img_size = img_size
13 | self.data_len = data_len
14 | self.split = split
15 | self.img_ld_path, self.img_fd_path = get_paths_from_images(dataroot)
16 | self.data_len = len(self.img_ld_path)
17 |
18 | def __len__(self):
19 | return self.data_len
20 |
21 |     def __getitem__(self, index):
22 |         # Case name is the file-name prefix before the first underscore.
23 |         base_name = self.img_ld_path[index].split('/')[-1]
24 |         case_name = base_name.split('_')[0]
25 | 
26 |         img_LD = Image.open(self.img_ld_path[index]).convert("L")
27 |         img_FD = Image.open(self.img_fd_path[index]).convert("L")
28 |         img_LD = img_LD.resize((self.img_size, self.img_size))
29 |         img_FD = img_FD.resize((self.img_size, self.img_size))
30 | 
31 |         # Rescale both images to [-1, 1]; random flips are applied only in the train split.
32 |         [img_LD, img_FD] = transform_augment(
33 |             [img_LD, img_FD], split=self.split, min_max=(-1, 1))
34 | 
35 |         return {'FD': img_FD, 'LD': img_LD, 'case_name': case_name}
36 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numbers
4 | import torchvision.transforms as transforms
5 | import torchvision.transforms.functional as F
6 | from torch.utils.data import Subset
7 | import numpy as np
8 |
9 |
10 | class Crop(object):
11 | def __init__(self, x1, x2, y1, y2):
12 | self.x1 = x1
13 | self.x2 = x2
14 | self.y1 = y1
15 | self.y2 = y2
16 |
17 | def __call__(self, img):
18 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
19 |
20 | def __repr__(self):
21 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
22 | self.x1, self.x2, self.y1, self.y2
23 | )
24 |
25 |
26 |
27 | def logit_transform(image, lam=1e-6):
28 | image = lam + (1 - 2 * lam) * image
29 | return torch.log(image) - torch.log1p(-image)
30 |
31 |
32 | def data_transform(config, X):
33 | if config.data.uniform_dequantization:
34 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
35 | if config.data.gaussian_dequantization:
36 | X = X + torch.randn_like(X) * 0.01
37 |
38 | if config.data.rescaled:
39 | X = 2 * X - 1.0
40 | elif config.data.logit_transform:
41 | X = logit_transform(X)
42 |
43 | if hasattr(config, "image_mean"):
44 | return X - config.image_mean.to(X.device)[None, ...]
45 |
46 | return X
47 |
48 |
49 | def inverse_data_transform(config, X):
50 | if hasattr(config, "image_mean"):
51 | X = X + config.image_mean.to(X.device)[None, ...]
52 |
53 | if config.data.logit_transform:
54 | X = torch.sigmoid(X)
55 | elif config.data.rescaled:
56 | X = (X + 1.0) / 2.0
57 |
58 | return torch.clamp(X, 0.0, 1.0)
59 |
--------------------------------------------------------------------------------
/datasets/pmub.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import lmdb
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | import torch
6 | import random
7 | import matplotlib.pyplot as plt
8 | from .sr_util import get_valid_paths_from_images, get_valid_paths_from_test_images, transform_augment
9 |
10 |
11 | class PMUB(Dataset):
12 | def __init__(self, dataroot, img_size, split='train', data_len=-1):
13 | self.img_size = img_size
14 | self.data_len = data_len
15 | self.split = split
16 |
17 | self.img_path = get_valid_paths_from_images(dataroot)
18 | self.test_img_path = get_valid_paths_from_test_images(dataroot)
19 |
20 | if self.split == 'test':
21 | self.dataset_len = len(self.test_img_path)
22 | else:
23 | self.dataset_len = len(self.img_path)
24 |
25 | if self.data_len <= 0:
26 | self.data_len = self.dataset_len
27 | else:
28 | self.data_len = min(self.data_len, self.dataset_len)
29 |
30 | def __len__(self):
31 | return self.data_len
32 |
33 |     def __getitem__(self, index):
34 |         # Slice paths look like "<prefix>-<case>_<number>.<ext>"; the backward (BW)
35 |         # and forward (FW) slices are the immediate neighbors of the middle (MD) slice.
36 |         base_name = self.img_path[index].split('_')[0]
37 |         case_name = int(base_name.split('/')[-1].split('-')[-1])
38 |         extension = self.img_path[index].split('_')[-1].split('.')[-1]
39 |         number = int(self.img_path[index].split('_')[-1].split('.')[0])
40 |         FW_path = base_name + '_' + str(number+1) + '.' + extension
41 |         BW_path = base_name + '_' + str(number-1) + '.' + extension
42 | 
43 |         img_BW = Image.open(BW_path).convert("L")
44 |         img_MD = Image.open(self.img_path[index]).convert("L")
45 |         img_FW = Image.open(FW_path).convert("L")
46 | 
47 |         img_BW = img_BW.resize((self.img_size, self.img_size))
48 |         img_MD = img_MD.resize((self.img_size, self.img_size))
49 |         img_FW = img_FW.resize((self.img_size, self.img_size))
50 | 
51 |         [img_BW, img_MD, img_FW] = transform_augment(
52 |             [img_BW, img_MD, img_FW], split=self.split, min_max=(-1, 1))
53 | 
54 |         return {'BW': img_BW, 'MD': img_MD, 'FW': img_FW, 'Index': index, 'case_name': case_name}
55 | 
--------------------------------------------------------------------------------
/datasets/sr_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import random
5 | import numpy as np
6 | import glob
7 |
8 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
10 |
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14 |
15 |
16 | def extract_number(filename):
17 | number = int(filename.split('_')[1].split('.')[0])
18 | return number
19 |
20 | # LDFDCT
21 | def get_paths_from_images(path):
22 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
23 |
24 | ld_images = glob.glob(path + "**/**/*ld.png", recursive=True)
25 | fd_images = glob.glob(path + "**/**/*fd.png", recursive=True)
26 |
27 | assert ld_images, '{:s} has no valid ld image file'.format(path)
28 | assert fd_images, '{:s} has no valid fd image file'.format(path)
29 |     assert len(ld_images) == len(fd_images), 'Low Dose images and Full Dose images are not paired!'
30 | return sorted(ld_images), sorted(fd_images)
31 |
32 | # Single SR
33 | def get_paths_from_single_sr_images(path):
34 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
35 |
36 | lr_images = glob.glob(path + "**/**/*lr.png", recursive=True)
37 | hr_images = glob.glob(path + "**/**/*hr.png", recursive=True)
38 |
39 | assert lr_images, '{:s} has no valid lr image file'.format(path)
40 | assert hr_images, '{:s} has no valid hr image file'.format(path)
41 |     assert len(lr_images) == len(hr_images), 'Low Resolution images and High Resolution images are not paired!'
42 | return sorted(lr_images), sorted(hr_images)
43 |
44 |
45 | def get_paths_from_npys(path_data, path_gt):
46 | assert os.path.isdir(path_data), '{:s} is not a valid directory'.format(path_data)
47 | assert os.path.isdir(path_gt), '{:s} is not a valid directory'.format(path_gt)
48 |
49 | data_npy = glob.glob(path_data + "*.npy")
50 | gt_npy = glob.glob(path_gt + "*.npy")
51 |
52 | assert data_npy, '{:s} has no valid data npy file'.format(path_data)
53 | assert gt_npy, '{:s} has no valid GT npy file'.format(path_gt)
54 |     assert len(data_npy) == len(gt_npy), 'Data npy files and GT npy files are not paired!'
55 | return sorted(data_npy), sorted(gt_npy)
56 |
57 |
58 | # Delete head and tail for train and val
59 | def get_valid_paths_from_images(path):
60 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
61 | images = []
62 |
63 | for dirpath, folder_path, fnames in sorted(os.walk(path)):
64 |
65 | filtered_fnames = [fname for fname in fnames if fname.endswith('.png') and not fname.startswith('.')]
66 | fnames = filtered_fnames
67 |
68 | fnames = sorted(fnames, key=extract_number)
69 | new_fnames = fnames[1:-1]
70 |
71 | for fname in new_fnames:
72 | if is_image_file(fname):
73 | img_path = os.path.join(dirpath, fname)
74 | images.append(img_path)
75 |
76 | assert images, '{:s} has no valid image file'.format(path)
77 | return images
78 |
79 |
80 | # Delete tail for test
81 | def get_valid_paths_from_test_images(path):
82 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
83 | images = []
84 |
85 | for dirpath, _, fnames in sorted(os.walk(path)):
86 | filtered_fnames = [fname for fname in fnames if not fname.startswith('.')]
87 | fnames = filtered_fnames
88 |
89 | fnames = sorted(fnames, key=extract_number)
90 | new_fnames = fnames[:-1]
91 |
92 | for fname in new_fnames:
93 | if is_image_file(fname):
94 | img_path = os.path.join(dirpath, fname)
95 | images.append(img_path)
96 |
97 | assert images, '{:s} has no valid image file'.format(path)
98 | return images
99 |
100 |
101 | def augment(img_list, hflip=True, rot=True, split='val'):
102 | # horizontal flip OR rotate
103 | hflip = hflip and (split == 'train' and random.random() < 0.5)
104 | vflip = rot and (split == 'train' and random.random() < 0.5)
105 | rot90 = rot and (split == 'train' and random.random() < 0.5)
106 |
107 | def _augment(img):
108 | if hflip:
109 | img = img[:, ::-1, :]
110 | if vflip:
111 | img = img[::-1, :, :]
112 | if rot90:
113 | img = img.transpose(1, 0, 2)
114 | return img
115 |
116 | return [_augment(img) for img in img_list]
117 |
118 |
119 | def transform2numpy(img):
120 | img = np.array(img)
121 | img = img.astype(np.float32) / 255.
122 | if img.ndim == 2:
123 | img = np.expand_dims(img, axis=2)
124 | # some images have 4 channels
125 | if img.shape[2] > 3:
126 | img = img[:, :, :3]
127 | return img
128 |
129 |
130 | def transform2tensor(img, min_max=(0, 1)):
131 | # HWC to CHW
132 | img = torch.from_numpy(np.ascontiguousarray(
133 | np.transpose(img, (2, 0, 1)))).float()
134 | # to range min_max
135 | img = img*(min_max[1] - min_max[0]) + min_max[0]
136 | return img
137 |
138 |
139 | totensor = torchvision.transforms.ToTensor()
140 | hflip = torchvision.transforms.RandomHorizontalFlip()
141 | Resize = torchvision.transforms.Resize((224, 224), antialias=True)
142 | def transform_augment(img_list, split='val', min_max=(0, 1)):
143 | imgs = [totensor(img) for img in img_list]
144 | if split == 'train':
145 | imgs = torch.stack(imgs, 0)
146 | imgs = hflip(imgs)
147 | imgs = torch.unbind(imgs, dim=0)
148 |
149 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs]
150 | return ret_img
151 |
152 |
153 | def brats_transform_augment(img_list, split='val'):
154 | imgs = [totensor(img) for img in img_list]
155 | # imgs = [Resize(img) for img in imgs_tlist]
156 | # if split == 'train':
157 | # imgs = torch.stack(imgs, 0)
158 | # imgs = hflip(imgs)
159 | # imgs = torch.unbind(imgs, dim=0)
160 | ret_img = [img.clamp(-1., 1.) for img in imgs]
161 |
162 | return ret_img
163 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import errno
5 | from torch.utils.model_zoo import tqdm
6 |
7 |
8 | def gen_bar_updater():
9 | pbar = tqdm(total=None)
10 |
11 | def bar_update(count, block_size, total_size):
12 | if pbar.total is None and total_size:
13 | pbar.total = total_size
14 | progress_bytes = count * block_size
15 | pbar.update(progress_bytes - pbar.n)
16 |
17 | return bar_update
18 |
19 |
20 | def check_integrity(fpath, md5=None):
21 | if md5 is None:
22 | return True
23 | if not os.path.isfile(fpath):
24 | return False
25 | md5o = hashlib.md5()
26 | with open(fpath, 'rb') as f:
27 | # read in 1MB chunks
28 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
29 | md5o.update(chunk)
30 | md5c = md5o.hexdigest()
31 | if md5c != md5:
32 | return False
33 | return True
34 |
35 |
36 | def makedir_exist_ok(dirpath):
37 | """
38 | Python2 support for os.makedirs(.., exist_ok=True)
39 | """
40 | try:
41 | os.makedirs(dirpath)
42 | except OSError as e:
43 | if e.errno == errno.EEXIST:
44 | pass
45 | else:
46 | raise
47 |
48 |
49 | def download_url(url, root, filename=None, md5=None):
50 | """Download a file from a url and place it in root.
51 |
52 | Args:
53 | url (str): URL to download file from
54 | root (str): Directory to place downloaded file in
55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
56 | md5 (str, optional): MD5 checksum of the download. If None, do not check
57 | """
58 | from six.moves import urllib
59 |
60 | root = os.path.expanduser(root)
61 | if not filename:
62 | filename = os.path.basename(url)
63 | fpath = os.path.join(root, filename)
64 |
65 | makedir_exist_ok(root)
66 |
67 | # downloads file
68 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
69 | print('Using downloaded and verified file: ' + fpath)
70 | else:
71 | try:
72 | print('Downloading ' + url + ' to ' + fpath)
73 | urllib.request.urlretrieve(
74 | url, fpath,
75 | reporthook=gen_bar_updater()
76 | )
77 | except OSError:
78 | if url[:5] == 'https':
79 | url = url.replace('https:', 'http:')
80 | print('Failed download. Trying https -> http instead.'
81 | ' Downloading ' + url + ' to ' + fpath)
82 | urllib.request.urlretrieve(
83 | url, fpath,
84 | reporthook=gen_bar_updater()
85 | )
86 |
87 |
88 | def list_dir(root, prefix=False):
89 | """List all directories at a given root
90 |
91 | Args:
92 | root (str): Path to directory whose folders need to be listed
93 | prefix (bool, optional): If true, prepends the path to each result, otherwise
94 | only returns the name of the directories found
95 | """
96 | root = os.path.expanduser(root)
97 | directories = list(
98 | filter(
99 | lambda p: os.path.isdir(os.path.join(root, p)),
100 | os.listdir(root)
101 | )
102 | )
103 |
104 | if prefix is True:
105 | directories = [os.path.join(root, d) for d in directories]
106 |
107 | return directories
108 |
109 |
110 | def list_files(root, suffix, prefix=False):
111 | """List all files ending with a suffix at a given root
112 |
113 | Args:
114 | root (str): Path to directory whose folders need to be listed
115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
116 | It uses the Python "str.endswith" method and is passed directly
117 | prefix (bool, optional): If true, prepends the path to each result, otherwise
118 | only returns the name of the files found
119 | """
120 | root = os.path.expanduser(root)
121 | files = list(
122 | filter(
123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
124 | os.listdir(root)
125 | )
126 | )
127 |
128 | if prefix is True:
129 | files = [os.path.join(root, d) for d in files]
130 |
131 | return files
132 |
133 |
134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
135 | """Download a Google Drive file from and place it in root.
136 |
137 | Args:
138 | file_id (str): id of file to be downloaded
139 | root (str): Directory to place downloaded file in
140 | filename (str, optional): Name to save the file under. If None, use the id of the file.
141 | md5 (str, optional): MD5 checksum of the download. If None, do not check
142 | """
143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
144 | import requests
145 | url = "https://docs.google.com/uc?export=download"
146 |
147 | root = os.path.expanduser(root)
148 | if not filename:
149 | filename = file_id
150 | fpath = os.path.join(root, filename)
151 |
152 | makedir_exist_ok(root)
153 |
154 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
155 | print('Using downloaded and verified file: ' + fpath)
156 | else:
157 | session = requests.Session()
158 |
159 | response = session.get(url, params={'id': file_id}, stream=True)
160 | token = _get_confirm_token(response)
161 |
162 | if token:
163 | params = {'id': file_id, 'confirm': token}
164 | response = session.get(url, params=params, stream=True)
165 |
166 | _save_response_content(response, fpath)
167 |
168 |
169 | def _get_confirm_token(response):
170 | for key, value in response.cookies.items():
171 | if key.startswith('download_warning'):
172 | return value
173 |
174 | return None
175 |
176 |
177 | def _save_response_content(response, destination, chunk_size=32768):
178 | with open(destination, "wb") as f:
179 | pbar = tqdm(total=None)
180 | progress = 0
181 | for chunk in response.iter_content(chunk_size):
182 | if chunk: # filter out keep-alive new chunks
183 | f.write(chunk)
184 | progress += len(chunk)
185 | pbar.update(progress - pbar.n)
186 | pbar.close()
187 |
--------------------------------------------------------------------------------
/ddpm_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import shutil
4 | import logging
5 | import yaml
6 | import sys
7 | import os
8 | import torch
9 | import numpy as np
10 | import torch.utils.tensorboard as tb
11 |
12 | from runners.diffusion import Diffusion
13 |
14 | torch.set_printoptions(sci_mode=False)
15 |
16 |
17 | def parse_args_and_config():
18 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
19 |
20 | parser.add_argument(
21 | "--config", type=str, default="pmub_linear.yml", help="Path to the config file"
22 | )
23 | parser.add_argument(
24 | "--dataset", type=str, default="PMUB", help="Name of dataset(LDFDCT, BRATS, PMUB)"
25 | )
26 | parser.add_argument("--seed", type=int, default=1244, help="Random seed")
27 | parser.add_argument(
28 | "--exp", type=str, default="exp", help="Path for saving running related data."
29 | )
30 | parser.add_argument(
31 | "--doc",
32 | type=str,
33 | default="DDPM_experiments",
34 | help="A string for documentation purpose. "
35 | "Will be the name of the log folder.",
36 | )
37 | parser.add_argument(
38 | "--comment", type=str, default="", help="A string for experiment comment"
39 | )
40 | parser.add_argument(
41 | "--verbose",
42 | type=str,
43 | default="info",
44 | help="Verbose level: info | debug | warning | critical",
45 | )
46 | parser.add_argument("--test", action="store_true", help="Whether to test the model")
47 | parser.add_argument(
48 | "--sample",
49 | action="store_true",
50 | help="Whether to produce samples from the model",
51 | )
52 | parser.add_argument("--fid", action="store_true")
53 | parser.add_argument("--interpolation", action="store_true")
54 | parser.add_argument(
55 | "--resume_training", action="store_true", help="Whether to resume training"
56 | )
57 | parser.add_argument(
58 | "-i",
59 | "--image_folder",
60 | type=str,
61 | default="images",
62 | help="The folder name of samples",
63 | )
64 | parser.add_argument(
65 | "--ni",
66 | action="store_false",
67 | help="No interaction. Suitable for Slurm Job launcher",
68 | )
69 | parser.add_argument("--use_pretrained", action="store_true")
70 | parser.add_argument(
71 | "--sample_type",
72 | type=str,
73 | default="ddpm_noisy",
74 | help="sampling approach (generalized or ddpm_noisy)",
75 | )
76 | parser.add_argument(
77 | "--timesteps", type=int, default=1000, help="number of steps involved"
78 | )
79 | parser.add_argument(
80 | "--eta",
81 | type=float,
82 | default=0.0,
83 | help="eta used to control the variances of sigma",
84 | )
85 | parser.add_argument("--sequence", action="store_true")
86 |
87 | args = parser.parse_args()
88 | args.log_path = os.path.join(args.exp, "logs", args.doc)
89 |
90 | # parse config file
91 | with open(os.path.join("configs", args.config), "r") as f:
92 | config = yaml.safe_load(f)
93 | new_config = dict2namespace(config)
94 |
95 | tb_path = os.path.join(args.exp, "tensorboard", args.doc)
96 |
97 |     # Not testing or sampling: set up a training run
98 | if not args.test and not args.sample:
99 | if not args.resume_training:
100 | if os.path.exists(args.log_path):
101 | overwrite = False
102 | if args.ni:
103 | overwrite = True
104 | else:
105 | response = input("Folder already exists. Overwrite? (Y/N)")
106 | if response.upper() == "Y":
107 | overwrite = True
108 |
109 |                 if overwrite:
110 |                     shutil.rmtree(args.log_path)
111 |                     # tb_path may not exist yet, so guard before removing stale logs
112 |                     if os.path.exists(tb_path):
113 |                         shutil.rmtree(tb_path)
114 |                     os.makedirs(args.log_path)
115 | else:
116 | print("Folder exists. Program halted.")
117 | sys.exit(0)
118 | else:
119 | os.makedirs(args.log_path)
120 |
121 | with open(os.path.join(args.log_path, "config.yml"), "w") as f:
122 | yaml.dump(new_config, f, default_flow_style=False)
123 |
124 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
125 | # setup logger
126 | level = getattr(logging, args.verbose.upper(), None)
127 | if not isinstance(level, int):
128 | raise ValueError("level {} not supported".format(args.verbose))
129 |
130 | handler1 = logging.StreamHandler()
131 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
132 | formatter = logging.Formatter(
133 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
134 | )
135 | handler1.setFormatter(formatter)
136 | handler2.setFormatter(formatter)
137 | logger = logging.getLogger()
138 | logger.addHandler(handler1)
139 | logger.addHandler(handler2)
140 | logger.setLevel(level)
141 |
142 | else:
143 | level = getattr(logging, args.verbose.upper(), None)
144 | if not isinstance(level, int):
145 | raise ValueError("level {} not supported".format(args.verbose))
146 |
147 | handler1 = logging.StreamHandler()
148 | formatter = logging.Formatter(
149 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
150 | )
151 | handler1.setFormatter(formatter)
152 | logger = logging.getLogger()
153 | logger.addHandler(handler1)
154 | logger.setLevel(level)
155 |
156 | # Sample from the model
157 | if args.sample:
158 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
159 | if args.fid:
160 | args.image_folder = os.path.join(
161 | args.exp, "image_samples", args.doc, "images_fid")
162 | if args.interpolation:
163 | args.image_folder = os.path.join(
164 | args.exp, "image_samples", args.doc, "images_interpolation")
165 |
166 | if not os.path.exists(args.image_folder):
167 | os.makedirs(args.image_folder)
168 | else:
169 | if not (args.fid or args.interpolation):
170 | overwrite = False
171 | if args.ni:
172 | overwrite = True
173 | else:
174 | response = input(
175 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
176 | )
177 | if response.upper() == "Y":
178 | overwrite = True
179 |
180 | if overwrite:
181 | shutil.rmtree(args.image_folder)
182 | os.makedirs(args.image_folder)
183 | else:
184 | print("Output image folder exists. Program halted.")
185 | sys.exit(0)
186 |
187 | # add device
188 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
189 | logging.info("Using device: {}".format(device))
190 | new_config.device = device
191 |
192 | # set random seed
193 | torch.manual_seed(args.seed)
194 | np.random.seed(args.seed)
195 | if torch.cuda.is_available():
196 | torch.cuda.manual_seed_all(args.seed)
197 |
198 | torch.backends.cudnn.benchmark = True
199 |
200 | return args, new_config
201 |
202 |
203 | def dict2namespace(config):
204 | namespace = argparse.Namespace()
205 | for key, value in config.items():
206 | if isinstance(value, dict):
207 | new_value = dict2namespace(value)
208 | else:
209 | new_value = value
210 | setattr(namespace, key, new_value)
211 | return namespace
212 |
213 |
214 | def main():
215 | args, config = parse_args_and_config()
216 | logging.info("Writing log file to {}".format(args.log_path))
217 | logging.info("Exp instance id = {}".format(os.getpid()))
218 | logging.info("Exp comment = {}".format(args.comment))
219 |
220 | try:
221 | runner = Diffusion(args, config)
222 | if args.sample:
223 | if args.dataset=='PMUB':
224 | runner.sr_sample()
225 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS':
226 | runner.sg_sample()
227 | else:
228 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as sampling dataset. Feel free to add your own.")
229 | elif args.test:
230 | runner.test()
231 | else:
232 | if args.dataset=='PMUB':
233 | runner.sr_ddpm_train()
234 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS':
235 | runner.sg_ddpm_train()
236 | else:
237 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as training dataset. Feel free to add your own.")
238 | except Exception:
239 | logging.error(traceback.format_exc())
240 |
241 | return 0
242 |
243 |
244 | if __name__ == "__main__":
245 | sys.exit(main())
246 |
--------------------------------------------------------------------------------
/fast_ddpm_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import shutil
4 | import logging
5 | import yaml
6 | import sys
7 | import os
8 | import torch
9 | import numpy as np
10 | import torch.utils.tensorboard as tb
11 |
12 | from runners.diffusion import Diffusion
13 |
14 | torch.set_printoptions(sci_mode=False)
15 |
16 |
17 | def parse_args_and_config():
18 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
19 |
20 | parser.add_argument(
21 | "--config", type=str, default="pmub_linear.yml", help="Path to the config file"
22 | )
23 | parser.add_argument(
24 | "--dataset", type=str, default="PMUB", help="Name of dataset(LDFDCT, BRATS, PMUB)"
25 | )
26 | parser.add_argument("--seed", type=int, default=1244, help="Random seed")
27 | parser.add_argument(
28 | "--exp", type=str, default="exp", help="Path for saving running related data."
29 | )
30 | parser.add_argument(
31 | "--doc",
32 | type=str,
33 | default="Fast-DDPM_experiments",
34 | help="A string for documentation purpose. "
35 | "Will be the name of the log folder.",
36 | )
37 | parser.add_argument(
38 | "--comment", type=str, default="", help="A string for experiment comment"
39 | )
40 | parser.add_argument(
41 | "--verbose",
42 | type=str,
43 | default="info",
44 | help="Verbose level: info | debug | warning | critical",
45 | )
46 | parser.add_argument("--test", action="store_true", help="Whether to test the model")
47 | parser.add_argument(
48 | "--sample",
49 | action="store_true",
50 | help="Whether to produce samples from the model",
51 | )
52 | parser.add_argument("--fid", action="store_true")
53 | parser.add_argument("--interpolation", action="store_true")
54 | parser.add_argument(
55 | "--resume_training", action="store_true", help="Whether to resume training"
56 | )
57 | parser.add_argument(
58 | "-i",
59 | "--image_folder",
60 | type=str,
61 | default="images",
62 | help="The folder name of samples",
63 | )
64 | parser.add_argument(
65 | "--ni",
66 | action="store_false",
67 | help="No interaction. Suitable for Slurm Job launcher",
68 | )
69 | parser.add_argument("--use_pretrained", action="store_true")
70 | parser.add_argument(
71 | "--sample_type",
72 | type=str,
73 | default="generalized",
74 | help="sampling approach (generalized or ddpm_noisy)",
75 | )
76 | parser.add_argument(
77 | "--scheduler_type",
78 | type=str,
79 | default="uniform",
80 | help="sample involved time steps according to (uniform or non-uniform)",
81 | )
82 | parser.add_argument(
83 | "--timesteps", type=int, default=100, help="number of steps involved"
84 | )
85 | parser.add_argument(
86 | "--eta",
87 | type=float,
88 | default=0.0,
89 | help="eta used to control the variances of sigma",
90 | )
91 | parser.add_argument("--sequence", action="store_true")
92 |
93 | args = parser.parse_args()
94 | args.log_path = os.path.join(args.exp, "logs", args.doc)
95 |
96 | # parse config file
97 | with open(os.path.join("configs", args.config), "r") as f:
98 | config = yaml.safe_load(f)
99 | new_config = dict2namespace(config)
100 |
101 | tb_path = os.path.join(args.exp, "tensorboard", args.doc)
102 |
103 |     # Not testing or sampling: set up a training run
104 | if not args.test and not args.sample:
105 | if not args.resume_training:
106 | if os.path.exists(args.log_path):
107 | overwrite = False
108 | if args.ni:
109 | overwrite = True
110 | else:
111 | response = input("Folder already exists. Overwrite? (Y/N)")
112 | if response.upper() == "Y":
113 | overwrite = True
114 |
115 |                 if overwrite:
116 |                     shutil.rmtree(args.log_path)
117 |                     # tb_path may not exist yet, so guard before removing stale logs
118 |                     if os.path.exists(tb_path):
119 |                         shutil.rmtree(tb_path)
120 |                     os.makedirs(args.log_path)
121 | else:
122 | print("Folder exists. Program halted.")
123 | sys.exit(0)
124 | else:
125 | os.makedirs(args.log_path)
126 |
127 | with open(os.path.join(args.log_path, "config.yml"), "w") as f:
128 | yaml.dump(new_config, f, default_flow_style=False)
129 |
130 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
131 | # setup logger
132 | level = getattr(logging, args.verbose.upper(), None)
133 | if not isinstance(level, int):
134 | raise ValueError("level {} not supported".format(args.verbose))
135 |
136 | handler1 = logging.StreamHandler()
137 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
138 | formatter = logging.Formatter(
139 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
140 | )
141 | handler1.setFormatter(formatter)
142 | handler2.setFormatter(formatter)
143 | logger = logging.getLogger()
144 | logger.addHandler(handler1)
145 | logger.addHandler(handler2)
146 | logger.setLevel(level)
147 |
148 | else:
149 | level = getattr(logging, args.verbose.upper(), None)
150 | if not isinstance(level, int):
151 | raise ValueError("level {} not supported".format(args.verbose))
152 |
153 | handler1 = logging.StreamHandler()
154 | formatter = logging.Formatter(
155 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
156 | )
157 | handler1.setFormatter(formatter)
158 | logger = logging.getLogger()
159 | logger.addHandler(handler1)
160 | logger.setLevel(level)
161 |
162 | # Sample from the model
163 | if args.sample:
164 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
165 | if args.fid:
166 | args.image_folder = os.path.join(
167 | args.exp, "image_samples", args.doc, "images_fid")
168 | if args.interpolation:
169 | args.image_folder = os.path.join(
170 | args.exp, "image_samples", args.doc, "images_interpolation")
171 |
172 | if not os.path.exists(args.image_folder):
173 | os.makedirs(args.image_folder)
174 | else:
175 | if not (args.fid or args.interpolation):
176 | overwrite = False
177 | if args.ni:
178 | overwrite = True
179 | else:
180 | response = input(
181 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
182 | )
183 | if response.upper() == "Y":
184 | overwrite = True
185 |
186 | if overwrite:
187 | shutil.rmtree(args.image_folder)
188 | os.makedirs(args.image_folder)
189 | else:
190 | print("Output image folder exists. Program halted.")
191 | sys.exit(0)
192 |
193 | # add device
194 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
195 | logging.info("Using device: {}".format(device))
196 | new_config.device = device
197 |
198 | # set random seed
199 | torch.manual_seed(args.seed)
200 | np.random.seed(args.seed)
201 | if torch.cuda.is_available():
202 | torch.cuda.manual_seed_all(args.seed)
203 |
204 | torch.backends.cudnn.benchmark = True
205 |
206 | return args, new_config
207 |
208 |
209 | def dict2namespace(config):
210 | namespace = argparse.Namespace()
211 | for key, value in config.items():
212 | if isinstance(value, dict):
213 | new_value = dict2namespace(value)
214 | else:
215 | new_value = value
216 | setattr(namespace, key, new_value)
217 | return namespace
218 |
219 |
220 | def main():
221 | args, config = parse_args_and_config()
222 | logging.info("Writing log file to {}".format(args.log_path))
223 | logging.info("Exp instance id = {}".format(os.getpid()))
224 | logging.info("Exp comment = {}".format(args.comment))
225 |
226 | try:
227 | runner = Diffusion(args, config)
228 | if args.sample:
229 | if args.dataset=='PMUB':
230 | runner.sr_sample()
231 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS':
232 | runner.sg_sample()
233 | else:
234 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as sampling dataset. Feel free to add your own.")
235 | elif args.test:
236 | runner.test()
237 | else:
238 | if args.dataset=='PMUB':
239 | runner.sr_train()
240 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS':
241 | runner.sg_train()
242 | else:
243 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as training dataset. Feel free to add your own.")
244 | except Exception:
245 | logging.error(traceback.format_exc())
246 |
247 | return 0
248 |
249 |
250 | if __name__ == "__main__":
251 | sys.exit(main())
252 |
--------------------------------------------------------------------------------
/functions/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def get_optimizer(config, parameters):
5 | if config.optim.optimizer == 'Adam':
6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
8 | eps=config.optim.eps)
9 | elif config.optim.optimizer == 'RMSProp':
10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
11 | elif config.optim.optimizer == 'SGD':
12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
13 | else:
14 | raise NotImplementedError(
15 | 'Optimizer {} not understood.'.format(config.optim.optimizer))
16 |
--------------------------------------------------------------------------------
/functions/ckpt_util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
14 | }
15 | CKPT_MAP = {
16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
24 | }
25 | MD5_MAP = {
26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
34 | }
35 |
36 |
37 | def download(url, local_path, chunk_size=1024):
38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
39 | with requests.get(url, stream=True) as r:
40 | total_size = int(r.headers.get("content-length", 0))
41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
42 | with open(local_path, "wb") as f:
43 | for data in r.iter_content(chunk_size=chunk_size):
44 | if data:
45 | f.write(data)
46 |                         pbar.update(len(data))
47 |
48 |
49 | def md5_hash(path):
50 | with open(path, "rb") as f:
51 | content = f.read()
52 | return hashlib.md5(content).hexdigest()
53 |
54 |
55 | def get_ckpt_path(name, root=None, check=False):
56 | if 'church_outdoor' in name:
57 | name = name.replace('church_outdoor', 'church')
58 | assert name in URL_MAP
59 |     # Default cache location; override with XDG_CACHE_HOME or the root argument
60 |     cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
61 | root = (
62 | root
63 | if root is not None
64 | else os.path.join(cachedir, "diffusion_models_converted")
65 | )
66 | path = os.path.join(root, CKPT_MAP[name])
67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
69 | download(URL_MAP[name], path)
70 | md5 = md5_hash(path)
71 | assert md5 == MD5_MAP[name], md5
72 | return path
73 |
--------------------------------------------------------------------------------
/functions/denoising.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def compute_alpha(beta, t):
5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
6 |     # after prepending beta = 0, index t+1 yields alpha_bar_t (and t = -1 yields 1)
7 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
8 | return a
9 |
10 |
11 | def generalized_steps(x, seq, model, b, **kwargs):
12 | with torch.no_grad():
13 | n = x.size(0)
14 | seq_next = [-1] + list(seq[:-1])
15 | x0_preds = []
16 | xs = [x]
17 | for i, j in zip(reversed(seq), reversed(seq_next)):
18 | t = (torch.ones(n) * i).to(x.device)
19 | next_t = (torch.ones(n) * j).to(x.device)
20 | at = compute_alpha(b, t.long())
21 | at_next = compute_alpha(b, next_t.long())
22 |             xt = xs[-1].to(x.device)
23 | et = model(xt, t)
24 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
25 | x0_preds.append(x0_t.to('cpu'))
26 | # Equation (12)
27 | c1 = (
28 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
29 | )
30 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
31 |
32 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
33 | xs.append(xt_next.to('cpu'))
34 |
35 | return xs, x0_preds
36 |
37 |
38 | def ddpm_steps(x, seq, model, b, **kwargs):
39 | with torch.no_grad():
40 | n = x.size(0)
41 | seq_next = [-1] + list(seq[:-1])
42 | xs = [x]
43 | x0_preds = []
44 | betas = b
45 | for i, j in zip(reversed(seq), reversed(seq_next)):
46 | t = (torch.ones(n) * i).to(x.device)
47 | next_t = (torch.ones(n) * j).to(x.device)
48 | at = compute_alpha(betas, t.long())
49 | atm1 = compute_alpha(betas, next_t.long())
50 | beta_t = 1 - at / atm1
51 |             x = xs[-1].to(x.device)
52 |
53 | output = model(x, t.float())
54 | e = output
55 |
56 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
57 | x0_from_e = torch.clamp(x0_from_e, -1, 1)
58 | x0_preds.append(x0_from_e.to('cpu'))
59 | mean_eps = (
60 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
61 | ) / (1.0 - at)
62 |
63 | mean = mean_eps
64 | noise = torch.randn_like(x)
65 | mask = 1 - (t == 0).float()
66 | mask = mask.view(-1, 1, 1, 1)
67 | logvar = beta_t.log()
68 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
69 | xs.append(sample.to('cpu'))
70 | return xs, x0_preds
71 |
72 |
73 | def sr_generalized_steps(x, x_bw, x_fw, seq, model, b, **kwargs):
74 | with torch.no_grad():
75 | n = x.size(0)
76 | seq_next = [-1] + list(seq[:-1])
77 | x0_preds = []
78 | xs = [x]
79 |
80 | for i, j in zip(reversed(seq), reversed(seq_next)):
81 | t = (torch.ones(n) * i).to(x.device)
82 | next_t = (torch.ones(n) * j).to(x.device)
83 | at = compute_alpha(b, t.long())
84 | at_next = compute_alpha(b, next_t.long())
85 |             xt = xs[-1].to(x.device)
86 | et = model(torch.cat([x_bw, x_fw, xt], dim=1), t)
87 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
88 | x0_preds.append(x0_t.to('cpu'))
89 | # Equation (12)
90 | c1 = (
91 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
92 | )
93 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
94 |
95 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
96 | xs.append(xt_next.to('cpu'))
97 |
98 | return xs, x0_preds
99 |
100 |
101 | def sr_ddpm_steps(x, x_bw, x_fw, seq, model, b, **kwargs):
102 | with torch.no_grad():
103 | n = x.size(0)
104 | seq_next = [-1] + list(seq[:-1])
105 | xs = [x]
106 | x0_preds = []
107 | betas = b
108 | for i, j in zip(reversed(seq), reversed(seq_next)):
109 | t = (torch.ones(n) * i).to(x.device)
110 | next_t = (torch.ones(n) * j).to(x.device)
111 | at = compute_alpha(betas, t.long())
112 | atm1 = compute_alpha(betas, next_t.long())
113 | beta_t = 1 - at / atm1
114 |             x = xs[-1].to(x.device)
115 |
116 | output = model(torch.cat([x_bw, x_fw, x], dim=1), t.float())
117 | e = output
118 |
119 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
120 | x0_from_e = torch.clamp(x0_from_e, -1, 1)
121 | x0_preds.append(x0_from_e.to('cpu'))
122 | mean_eps = (
123 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
124 | ) / (1.0 - at)
125 |
126 | mean = mean_eps
127 | noise = torch.randn_like(x)
128 | mask = 1 - (t == 0).float()
129 | mask = mask.view(-1, 1, 1, 1)
130 | logvar = beta_t.log()
131 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
132 | xs.append(sample.to('cpu'))
133 | return xs, x0_preds
134 |
135 |
136 | def sg_generalized_steps(x, x_img, seq, model, b, **kwargs):
137 | with torch.no_grad():
138 | n = x.size(0)
139 | seq_next = [-1] + list(seq[:-1])
140 | x0_preds = []
141 | xs = [x]
142 |
143 | for i, j in zip(reversed(seq), reversed(seq_next)):
144 | t = (torch.ones(n) * i).to(x.device)
145 | next_t = (torch.ones(n) * j).to(x.device)
146 | at = compute_alpha(b, t.long())
147 | at_next = compute_alpha(b, next_t.long())
148 |             xt = xs[-1].to(x.device)
149 | et = model(torch.cat([x_img, xt], dim=1), t)
150 |
151 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
152 | x0_preds.append(x0_t.to('cpu'))
153 | # Equation (12)
154 | c1 = (
155 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
156 | )
157 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
158 |
159 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
160 | xs.append(xt_next.to('cpu'))
161 |
162 | return xs, x0_preds
163 |
164 |
165 | def sg_ddpm_steps(x, x_img, seq, model, b, **kwargs):
166 | with torch.no_grad():
167 | n = x.size(0)
168 | seq_next = [-1] + list(seq[:-1])
169 | xs = [x]
170 | x0_preds = []
171 | betas = b
172 | for i, j in zip(reversed(seq), reversed(seq_next)):
173 | t = (torch.ones(n) * i).to(x.device)
174 | next_t = (torch.ones(n) * j).to(x.device)
175 | at = compute_alpha(betas, t.long())
176 | atm1 = compute_alpha(betas, next_t.long())
177 | beta_t = 1 - at / atm1
178 |             x = xs[-1].to(x.device)
179 |
180 | output = model(torch.cat([x_img, x], dim=1), t.float())
181 | e = output
182 |
183 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
184 | x0_from_e = torch.clamp(x0_from_e, -1, 1)
185 | x0_preds.append(x0_from_e.to('cpu'))
186 | mean_eps = (
187 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
188 | ) / (1.0 - at)
189 |
190 | mean = mean_eps
191 | noise = torch.randn_like(x)
192 | mask = 1 - (t == 0).float()
193 | mask = mask.view(-1, 1, 1, 1)
194 | logvar = beta_t.log()
195 | sample = mean + mask * torch.exp(0.5 * logvar) * noise
196 | xs.append(sample.to('cpu'))
197 | return xs, x0_preds
198 |
--------------------------------------------------------------------------------
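Note: the conditional samplers above (`sr_generalized_steps`, `sr_ddpm_steps`, `sg_generalized_steps`, `sg_ddpm_steps`) all share the same DDIM-style update labeled "Equation (12)". The sketch below walks through a single such update on toy tensors; it is illustrative only, re-declaring `compute_alpha` (defined earlier in `functions/denoising.py`) and using random noise as a stand-in for the model's prediction `et`.

```python
import torch

def compute_alpha(beta, t):
    # Cumulative product of (1 - beta); a leading zero is prepended so that
    # t = -1 maps to alpha_bar = 1 (a fully clean sample).
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    return (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)

betas = torch.linspace(1e-4, 2e-2, 1000)      # linear beta schedule
xt = torch.randn(1, 1, 8, 8)                  # current noisy sample x_t
et = torch.randn_like(xt)                     # stand-in for the predicted noise
t, t_next = torch.tensor([999]), torch.tensor([949])

at, at_next = compute_alpha(betas, t), compute_alpha(betas, t_next)
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()    # predicted x_0
eta = 0.0                                         # eta = 0: deterministic DDIM step
c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
c2 = ((1 - at_next) - c1 ** 2).sqrt()
xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et
print(xt_next.shape)  # torch.Size([1, 1, 8, 8])
```

With `eta = 0` the stochastic term vanishes and the update is deterministic; the `*_ddpm_steps` variants instead sample from the DDPM posterior with variance `beta_t`.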
/functions/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 | from medpy import metric
5 | import numpy as np
6 | np.bool = np.bool_  # compatibility shim: np.bool was removed in NumPy 1.24, but some dependencies still reference it
7 |
8 |
9 | def calculate_psnr(img1, img2):
10 |     # img1: generated image
11 |     # img2: ground truth
12 |     # img1 and img2 have range [0, 255]
13 |     img1 = img1.astype(np.float64)
14 |     img2 = img2.astype(np.float64)
15 |
16 |     mse = np.mean((img1 - img2)**2)
17 |     psnr = 20 * math.log10(255.0 / math.sqrt(mse)) if mse > 0 else float('inf')  # guard against identical images
18 |
19 |     return psnr
20 |
21 |
22 | def noise_estimation_loss(model,
23 | x0: torch.Tensor,
24 | t: torch.LongTensor,
25 | e: torch.Tensor,
26 | b: torch.Tensor, keepdim=False):
27 |     # a: alpha_bar_t, the cumulative product of (1 - beta) up to step t
28 |     # 1 - a: the total noise variance at step t
29 |     a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
30 |     # x_t: the forward-diffused sample at step t
31 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
32 | output = model(x, t.float())
33 | if keepdim:
34 | return (e - output).square().sum(dim=(1, 2, 3))
35 | else:
36 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
37 |
38 |
39 |
40 | def sr_noise_estimation_loss(model,
41 | x_bw: torch.Tensor,
42 | x_md: torch.Tensor,
43 | x_fw: torch.Tensor,
44 | t: torch.LongTensor,
45 | e: torch.Tensor,
46 | b: torch.Tensor, keepdim=False):
47 |     # a: alpha_bar_t, the cumulative product of (1 - beta) up to step t
48 |     # 1 - a: the total noise variance at step t
49 |     a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
50 |     # x_t: the forward-diffused middle frame at step t
51 | x = x_md * a.sqrt() + e * (1.0 - a).sqrt()
52 |
53 | output = model(torch.cat([x_bw, x_fw, x], dim=1), t.float())
54 | if keepdim:
55 | return (e - output).square().sum(dim=(1, 2, 3))
56 | else:
57 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
58 |
59 |
60 |
61 | def sg_noise_estimation_loss(model,
62 | x_img: torch.Tensor,
63 | x_gt: torch.Tensor,
64 | t: torch.LongTensor,
65 | e: torch.Tensor,
66 | b: torch.Tensor, keepdim=False):
67 |     # a: alpha_bar_t, the cumulative product of (1 - beta) up to step t
68 |     # 1 - a: the total noise variance at step t
69 |     a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
70 |     # x_t: the forward-diffused ground truth at step t
71 | x = x_gt * a.sqrt() + e * (1.0 - a).sqrt()
72 | output = model(torch.cat([x_img, x], dim=1), t.float())
73 |
74 | if keepdim:
75 | return (e - output).square().sum(dim=(1, 2, 3))
76 | else:
77 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
78 |
79 |
80 | loss_registry = {
81 | 'simple': noise_estimation_loss,
82 | 'sr': sr_noise_estimation_loss,
83 | 'sg': sg_noise_estimation_loss
84 | }
--------------------------------------------------------------------------------
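As a quick illustration of how the registry above is consumed (the runners call `loss_registry[config.model.type](...)`), here is a minimal smoke test of the single-condition `'sg'` loss. The tensor shapes and the stand-in model are hypothetical; in the real pipeline the model is the U-Net from `models/diffusion.py`.

```python
import torch
from functions.losses import loss_registry

class DummyModel(torch.nn.Module):
    # Stand-in for the U-Net: maps the 2-channel input (condition + noisy
    # target) and the timestep to a 1-channel noise estimate.
    def forward(self, x, t):
        return torch.zeros_like(x[:, :1])

model = DummyModel()
b = torch.linspace(1e-4, 2e-2, 1000)    # beta schedule
x_img = torch.randn(4, 1, 32, 32)       # condition (e.g. low-dose CT)
x_gt = torch.randn(4, 1, 32, 32)        # target (e.g. full-dose CT)
e = torch.randn_like(x_gt)              # Gaussian noise to be predicted
t = torch.randint(0, 1000, (4,))        # sampled diffusion steps

loss = loss_registry['sg'](model, x_img, x_gt, t, e, b)
print(loss.item())                      # scalar noise-prediction loss
```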
/models/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def get_timestep_embedding(timesteps, embedding_dim):
7 |     """
8 |     Build sinusoidal timestep embeddings.
9 |     This matches the implementation in Denoising Diffusion Probabilistic
10 |     Models (originally from Fairseq) and in tensor2tensor, but differs
11 |     slightly from the description in Section 3.5 of
12 |     "Attention Is All You Need".
13 |     """
14 | assert len(timesteps.shape) == 1
15 |
16 | half_dim = embedding_dim // 2
17 | emb = math.log(10000) / (half_dim - 1)
18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
19 | emb = emb.to(device=timesteps.device)
20 | emb = timesteps.float()[:, None] * emb[None, :]
21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
22 | if embedding_dim % 2 == 1: # zero pad
23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
24 | return emb
25 |
26 |
27 | def nonlinearity(x):
28 | # swish
29 | return x*torch.sigmoid(x)
30 |
31 |
32 | def Normalize(in_channels):
33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34 |
35 |
36 | class Upsample(nn.Module):
37 | def __init__(self, in_channels, with_conv):
38 | super().__init__()
39 | self.with_conv = with_conv
40 | if self.with_conv:
41 | self.conv = torch.nn.Conv2d(in_channels,
42 | in_channels,
43 | kernel_size=3,
44 | stride=1,
45 | padding=1)
46 |
47 | def forward(self, x):
48 | x = torch.nn.functional.interpolate(
49 | x, scale_factor=2.0, mode="nearest")
50 | if self.with_conv:
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class Downsample(nn.Module):
56 | def __init__(self, in_channels, with_conv):
57 | super().__init__()
58 | self.with_conv = with_conv
59 | if self.with_conv:
60 | # no asymmetric padding in torch conv, must do it ourselves
61 | self.conv = torch.nn.Conv2d(in_channels,
62 | in_channels,
63 | kernel_size=3,
64 | stride=2,
65 | padding=0)
66 |
67 | def forward(self, x):
68 | if self.with_conv:
69 | pad = (0, 1, 0, 1)
70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71 | x = self.conv(x)
72 | else:
73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74 | return x
75 |
76 |
77 | class ResnetBlock(nn.Module):
78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
79 | dropout, temb_channels=512):
80 | super().__init__()
81 | self.in_channels = in_channels
82 | out_channels = in_channels if out_channels is None else out_channels
83 | self.out_channels = out_channels
84 | self.use_conv_shortcut = conv_shortcut
85 |
86 | self.norm1 = Normalize(in_channels)
87 | self.conv1 = torch.nn.Conv2d(in_channels,
88 | out_channels,
89 | kernel_size=3,
90 | stride=1,
91 | padding=1)
92 | self.temb_proj = torch.nn.Linear(temb_channels,
93 | out_channels)
94 | self.norm2 = Normalize(out_channels)
95 | self.dropout = torch.nn.Dropout(dropout)
96 | self.conv2 = torch.nn.Conv2d(out_channels,
97 | out_channels,
98 | kernel_size=3,
99 | stride=1,
100 | padding=1)
101 | if self.in_channels != self.out_channels:
102 | if self.use_conv_shortcut:
103 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
104 | out_channels,
105 | kernel_size=3,
106 | stride=1,
107 | padding=1)
108 | else:
109 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
110 | out_channels,
111 | kernel_size=1,
112 | stride=1,
113 | padding=0)
114 |
115 | def forward(self, x, temb):
116 | h = x
117 | h = self.norm1(h)
118 | h = nonlinearity(h)
119 | h = self.conv1(h)
120 |
121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
122 |
123 | h = self.norm2(h)
124 | h = nonlinearity(h)
125 | h = self.dropout(h)
126 | h = self.conv2(h)
127 |
128 | if self.in_channels != self.out_channels:
129 | if self.use_conv_shortcut:
130 | x = self.conv_shortcut(x)
131 | else:
132 | x = self.nin_shortcut(x)
133 |
134 | return x+h
135 |
136 |
137 | class AttnBlock(nn.Module):
138 | def __init__(self, in_channels):
139 | super().__init__()
140 | self.in_channels = in_channels
141 |
142 | self.norm = Normalize(in_channels)
143 | self.q = torch.nn.Conv2d(in_channels,
144 | in_channels,
145 | kernel_size=1,
146 | stride=1,
147 | padding=0)
148 | self.k = torch.nn.Conv2d(in_channels,
149 | in_channels,
150 | kernel_size=1,
151 | stride=1,
152 | padding=0)
153 | self.v = torch.nn.Conv2d(in_channels,
154 | in_channels,
155 | kernel_size=1,
156 | stride=1,
157 | padding=0)
158 | self.proj_out = torch.nn.Conv2d(in_channels,
159 | in_channels,
160 | kernel_size=1,
161 | stride=1,
162 | padding=0)
163 |
164 | def forward(self, x):
165 | h_ = x
166 | h_ = self.norm(h_)
167 | q = self.q(h_)
168 | k = self.k(h_)
169 | v = self.v(h_)
170 |
171 | # compute attention
172 | b, c, h, w = q.shape
173 | q = q.reshape(b, c, h*w)
174 | q = q.permute(0, 2, 1) # b,hw,c
175 | k = k.reshape(b, c, h*w) # b,c,hw
176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
177 | w_ = w_ * (int(c)**(-0.5))
178 | w_ = torch.nn.functional.softmax(w_, dim=2)
179 |
180 | # attend to values
181 | v = v.reshape(b, c, h*w)
182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
184 | h_ = torch.bmm(v, w_)
185 | h_ = h_.reshape(b, c, h, w)
186 |
187 | h_ = self.proj_out(h_)
188 |
189 | return x+h_
190 |
191 |
192 | class Model(nn.Module):
193 | def __init__(self, config):
194 | super().__init__()
195 | self.config = config
196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
197 | num_res_blocks = config.model.num_res_blocks
198 | attn_resolutions = config.model.attn_resolutions
199 | dropout = config.model.dropout
200 | in_channels = config.model.in_channels
201 | resolution = config.data.image_size
202 | resamp_with_conv = config.model.resamp_with_conv
203 | num_timesteps = config.diffusion.num_diffusion_timesteps
204 |
205 | if config.model.type == 'bayesian':
206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps))
207 |
208 | self.ch = ch
209 | self.temb_ch = self.ch*4
210 | self.num_resolutions = len(ch_mult)
211 | self.num_res_blocks = num_res_blocks
212 | self.resolution = resolution
213 | self.in_channels = in_channels
214 |
215 | # timestep embedding
216 | self.temb = nn.Module()
217 | self.temb.dense = nn.ModuleList([
218 | torch.nn.Linear(self.ch,
219 | self.temb_ch),
220 | torch.nn.Linear(self.temb_ch,
221 | self.temb_ch),
222 | ])
223 |
224 | # downsampling
225 | self.conv_in = torch.nn.Conv2d(in_channels,
226 | self.ch,
227 | kernel_size=3,
228 | stride=1,
229 | padding=1)
230 |
231 | curr_res = resolution
232 | in_ch_mult = (1,)+ch_mult
233 | self.down = nn.ModuleList()
234 | block_in = None
235 | for i_level in range(self.num_resolutions):
236 | block = nn.ModuleList()
237 | attn = nn.ModuleList()
238 | block_in = ch*in_ch_mult[i_level]
239 | block_out = ch*ch_mult[i_level]
240 | for i_block in range(self.num_res_blocks):
241 | block.append(ResnetBlock(in_channels=block_in,
242 | out_channels=block_out,
243 | temb_channels=self.temb_ch,
244 | dropout=dropout))
245 | block_in = block_out
246 | if curr_res in attn_resolutions:
247 | attn.append(AttnBlock(block_in))
248 | down = nn.Module()
249 | down.block = block
250 | down.attn = attn
251 | if i_level != self.num_resolutions-1:
252 | down.downsample = Downsample(block_in, resamp_with_conv)
253 | curr_res = curr_res // 2
254 | self.down.append(down)
255 |
256 | # middle
257 | self.mid = nn.Module()
258 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
259 | out_channels=block_in,
260 | temb_channels=self.temb_ch,
261 | dropout=dropout)
262 | self.mid.attn_1 = AttnBlock(block_in)
263 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
264 | out_channels=block_in,
265 | temb_channels=self.temb_ch,
266 | dropout=dropout)
267 |
268 | # upsampling
269 | self.up = nn.ModuleList()
270 | for i_level in reversed(range(self.num_resolutions)):
271 | block = nn.ModuleList()
272 | attn = nn.ModuleList()
273 | block_out = ch*ch_mult[i_level]
274 | skip_in = ch*ch_mult[i_level]
275 | for i_block in range(self.num_res_blocks+1):
276 | if i_block == self.num_res_blocks:
277 | skip_in = ch*in_ch_mult[i_level]
278 | block.append(ResnetBlock(in_channels=block_in+skip_in,
279 | out_channels=block_out,
280 | temb_channels=self.temb_ch,
281 | dropout=dropout))
282 | block_in = block_out
283 | if curr_res in attn_resolutions:
284 | attn.append(AttnBlock(block_in))
285 | up = nn.Module()
286 | up.block = block
287 | up.attn = attn
288 | if i_level != 0:
289 | up.upsample = Upsample(block_in, resamp_with_conv)
290 | curr_res = curr_res * 2
291 | self.up.insert(0, up) # prepend to get consistent order
292 |
293 | # end
294 | self.norm_out = Normalize(block_in)
295 | self.conv_out = torch.nn.Conv2d(block_in,
296 | out_ch,
297 | kernel_size=3,
298 | stride=1,
299 | padding=1)
300 |
301 | def forward(self, x, t):
302 | assert x.shape[2] == x.shape[3] == self.resolution
303 |
304 | # timestep embedding
305 | temb = get_timestep_embedding(t, self.ch)
306 | temb = self.temb.dense[0](temb)
307 | temb = nonlinearity(temb)
308 | temb = self.temb.dense[1](temb)
309 |
310 | # downsampling
311 | hs = [self.conv_in(x)]
312 | for i_level in range(self.num_resolutions):
313 | for i_block in range(self.num_res_blocks):
314 | h = self.down[i_level].block[i_block](hs[-1], temb)
315 | if len(self.down[i_level].attn) > 0:
316 | h = self.down[i_level].attn[i_block](h)
317 | hs.append(h)
318 | if i_level != self.num_resolutions-1:
319 | hs.append(self.down[i_level].downsample(hs[-1]))
320 |
321 | # middle
322 | h = hs[-1]
323 |
324 | h = self.mid.block_1(h, temb)
325 | h = self.mid.attn_1(h)
326 | h = self.mid.block_2(h, temb)
327 |
328 | # upsampling
329 | for i_level in reversed(range(self.num_resolutions)):
330 | for i_block in range(self.num_res_blocks+1):
331 | h = self.up[i_level].block[i_block](
332 | torch.cat([h, hs.pop()], dim=1), temb)
333 | if len(self.up[i_level].attn) > 0:
334 | h = self.up[i_level].attn[i_block](h)
335 | if i_level != 0:
336 | h = self.up[i_level].upsample(h)
337 |
338 | # end
339 | h = self.norm_out(h)
340 | h = nonlinearity(h)
341 | h = self.conv_out(h)
342 | return h
343 |
--------------------------------------------------------------------------------
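To see the U-Net above end to end, the following sketch instantiates `Model` with a small hypothetical config (the real hyperparameters live in `configs/*.yml`) and runs one forward pass. The attribute names mirror what `Model.__init__` reads; the values are chosen only to keep the example tiny.

```python
import torch
from types import SimpleNamespace as NS
from models.diffusion import Model

config = NS(
    model=NS(ch=32, out_ch=1, ch_mult=[1, 2], num_res_blocks=1,
             attn_resolutions=[16], dropout=0.0, in_channels=2,
             resamp_with_conv=True, type='simple'),
    data=NS(image_size=32),
    diffusion=NS(num_diffusion_timesteps=1000),
)

model = Model(config)
x = torch.randn(2, 2, 32, 32)      # e.g. condition and noisy target, concatenated
t = torch.randint(0, 1000, (2,))   # one timestep per batch element
print(model(x, t).shape)           # torch.Size([2, 1, 32, 32])
```

Note that `ch` should be a multiple of 32, since `Normalize` wraps `GroupNorm` with 32 groups.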
/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | """
4 | Exponential moving average (EMA) of model weights: a method that stabilizes a model’s convergence and helps it reach a better overall solution than the raw weights alone.
5 | A shadow copy of the weights is registered at the start of training; after each optimization step, the copy is updated to a weighted average of itself and the model’s new weights.
6 | At evaluation time, the shadow weights are copied back into the model.
7 | """
8 |
9 |
10 | class EMAHelper(object):
11 | def __init__(self, mu=0.999):
12 | self.mu = mu
13 | self.shadow = {}
14 |
15 | def register(self, module):
16 | if isinstance(module, nn.DataParallel):
17 | module = module.module
18 | for name, param in module.named_parameters():
19 | if param.requires_grad:
20 | self.shadow[name] = param.data.clone()
21 |
22 | def update(self, module):
23 | if isinstance(module, nn.DataParallel):
24 | module = module.module
25 | for name, param in module.named_parameters():
26 | if param.requires_grad:
27 | self.shadow[name].data = (
28 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data
29 |
30 | def ema(self, module):
31 | if isinstance(module, nn.DataParallel):
32 | module = module.module
33 | for name, param in module.named_parameters():
34 | if param.requires_grad:
35 | param.data.copy_(self.shadow[name].data)
36 |
37 | def ema_copy(self, module):
38 | if isinstance(module, nn.DataParallel):
39 | inner_module = module.module
40 | module_copy = type(inner_module)(
41 | inner_module.config).to(inner_module.config.device)
42 | module_copy.load_state_dict(inner_module.state_dict())
43 | module_copy = nn.DataParallel(module_copy)
44 | else:
45 | module_copy = type(module)(module.config).to(module.config.device)
46 | module_copy.load_state_dict(module.state_dict())
47 | # module_copy = copy.deepcopy(module)
48 | self.ema(module_copy)
49 | return module_copy
50 |
51 | def state_dict(self):
52 | return self.shadow
53 |
54 | def load_state_dict(self, state_dict):
55 | self.shadow = state_dict
56 |
--------------------------------------------------------------------------------
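A minimal sketch of how `EMAHelper` is meant to be driven (mirroring the calls in `runners/diffusion.py`): register once, update after every optimizer step, and copy the shadow weights into the model for evaluation. The toy model and data here are placeholders.

```python
import torch
import torch.nn as nn
from models.ema import EMAHelper

model = nn.Linear(4, 4)
ema_helper = EMAHelper(mu=0.999)
ema_helper.register(model)            # snapshot current weights as the shadow copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 4)).square().mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

ema_helper.update(model)              # shadow <- 0.001 * new_weights + 0.999 * shadow
ema_helper.ema(model)                 # overwrite model weights with the shadow for eval
```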
/runners/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mirthAI/Fast-DDPM/649a14a6093d14f4286a6b6f9963dd208ce07928/runners/__init__.py
--------------------------------------------------------------------------------
/runners/diffusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | import glob
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import math
9 | import tqdm
10 | import torch
11 | import torch.utils.data as data
12 |
13 | from models.diffusion import Model
14 | from models.ema import EMAHelper
15 | from functions import get_optimizer
16 | from functions.losses import loss_registry, calculate_psnr
17 | from datasets import data_transform, inverse_data_transform
18 | from datasets.pmub import PMUB
19 | from datasets.LDFDCT import LDFDCT
20 | from datasets.BRATS import BRATS
21 | from functions.ckpt_util import get_ckpt_path
22 | from skimage.metrics import structural_similarity as ssim
23 | import torchvision.utils as tvu
24 | import torchvision
25 | from PIL import Image
26 |
27 |
28 | def torch2hwcuint8(x, clip=False):
29 | if clip:
30 | x = torch.clamp(x, -1, 1)
31 | x = (x + 1.0) / 2.0
32 | return x
33 |
34 |
35 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
36 | def sigmoid(x):
37 | return 1 / (np.exp(x) + 1)
38 | def tanh(x):
39 | return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
40 |
41 | if beta_schedule == "quad":
42 | betas = (
43 | np.linspace(
44 | beta_start ** 0.5,
45 | beta_end ** 0.5,
46 | num_diffusion_timesteps,
47 | dtype=np.float64,
48 | )
49 | ** 2
50 | )
51 | elif beta_schedule == "linear":
52 | betas = np.linspace(
53 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
54 | )
55 | elif beta_schedule == "sigmoid":
56 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
57 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
58 |     elif beta_schedule == 'alpha_cosine':
59 | s = 0.008
60 | timesteps = np.arange(0, num_diffusion_timesteps+1, dtype=np.float64)/num_diffusion_timesteps
61 | alphas_cumprod = np.cos((timesteps + s) / (1 + s) * math.pi * 0.5) ** 2
62 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
63 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
64 | betas = np.clip(betas, a_min=None, a_max=0.999)
65 | elif beta_schedule == 'alpha_sigmoid':
66 |         x = np.linspace(-6, 6, num_diffusion_timesteps + 1)  # one extra point so betas has num_diffusion_timesteps entries
67 | alphas_cumprod = sigmoid(x)
68 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
69 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
70 | betas = np.clip(betas, a_min=None, a_max=0.999)
71 | elif beta_schedule == 'alpha_linear':
72 | timesteps = np.arange(0, num_diffusion_timesteps+1, dtype=np.float64)/num_diffusion_timesteps
73 | alphas_cumprod = -timesteps+1
74 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
75 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
76 | betas = np.clip(betas, a_min=None, a_max=0.999)
77 |
78 | else:
79 | raise NotImplementedError(beta_schedule)
80 | assert betas.shape == (num_diffusion_timesteps,)
81 | return betas
82 |
83 |
84 | class Diffusion(object):
85 | def __init__(self, args, config, device=None):
86 | self.args = args
87 | self.config = config
88 | if device is None:
89 | device = (
90 | torch.device("cuda")
91 | if torch.cuda.is_available()
92 | else torch.device("cpu")
93 | )
94 | self.device = device
95 |
96 | self.model_var_type = config.model.var_type
97 | betas = get_beta_schedule(
98 | beta_schedule=config.diffusion.beta_schedule,
99 | beta_start=config.diffusion.beta_start,
100 | beta_end=config.diffusion.beta_end,
101 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
102 | )
103 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
104 | self.num_timesteps = betas.shape[0]
105 |
106 | alphas = 1.0 - betas
107 | alphas_cumprod = alphas.cumprod(dim=0)
108 | alphas_cumprod_prev = torch.cat(
109 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
110 | )
111 | posterior_variance = (
112 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
113 | )
114 | if self.model_var_type == "fixedlarge":
115 | self.logvar = betas.log()
116 | elif self.model_var_type == "fixedsmall":
117 | self.logvar = posterior_variance.clamp(min=1e-20).log()
118 |
119 |
120 | # Training Fast-DDPM for tasks that have only one condition: image translation and CT denoising.
121 | def sg_train(self):
122 | args, config = self.args, self.config
123 | tb_logger = self.config.tb_logger
124 |
125 |         if self.args.dataset == 'LDFDCT':  # CT image denoising
126 |             dataset = LDFDCT(self.config.data.train_dataroot, self.config.data.image_size, split='train')
127 |             print('Start training your Fast-DDPM model on LDFDCT dataset.')
128 |         elif self.args.dataset == 'BRATS':  # brain image translation
129 |             dataset = BRATS(self.config.data.train_dataroot, self.config.data.image_size, split='train')
130 |             print('Start training your Fast-DDPM model on BRATS dataset.')
131 |         else:
132 |             raise ValueError('sg_train supports only the LDFDCT and BRATS datasets.')
133 | print('The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.scheduler_type, self.args.timesteps))
134 |
135 | train_loader = data.DataLoader(
136 | dataset,
137 | batch_size=config.training.batch_size,
138 | shuffle=True,
139 | num_workers=config.data.num_workers,
140 | pin_memory=True)
141 |
142 | model = Model(config)
143 | model = model.to(self.device)
144 | model = torch.nn.DataParallel(model)
145 |
146 | optimizer = get_optimizer(self.config, model.parameters())
147 |
148 | if self.config.model.ema:
149 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
150 | ema_helper.register(model)
151 | else:
152 | ema_helper = None
153 |
154 | start_epoch, step = 0, 0
155 | if self.args.resume_training:
156 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
157 | model.load_state_dict(states[0])
158 |
159 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps
160 | optimizer.load_state_dict(states[1])
161 | start_epoch = states[2]
162 | step = states[3]
163 | if self.config.model.ema:
164 | ema_helper.load_state_dict(states[4])
165 |
166 | for epoch in range(start_epoch, self.config.training.n_epochs):
167 | for i, x in enumerate(train_loader):
168 | n = x['LD'].size(0)
169 | model.train()
170 | step += 1
171 |
172 | x_img = x['LD'].to(self.device)
173 | x_gt = x['FD'].to(self.device)
174 |
175 | e = torch.randn_like(x_gt)
176 | b = self.betas
177 |
178 | if self.args.scheduler_type == 'uniform':
179 | skip = self.num_timesteps // self.args.timesteps
180 | t_intervals = torch.arange(-1, self.num_timesteps, skip)
181 | t_intervals[0] = 0
182 | elif self.args.scheduler_type == 'non-uniform':
183 | t_intervals = torch.tensor([0, 199, 399, 599, 699, 799, 849, 899, 949, 999])
184 |
185 | if self.args.timesteps != 10:
186 | num_1 = int(self.args.timesteps*0.4)
187 | num_2 = int(self.args.timesteps*0.6)
188 | stage_1 = torch.linspace(0, 699, num_1+1)[:-1]
189 | stage_2 = torch.linspace(699, 999, num_2)
190 | stage_1 = torch.ceil(stage_1).long()
191 | stage_2 = torch.ceil(stage_2).long()
192 | t_intervals = torch.cat((stage_1, stage_2))
193 | else:
194 |                     raise Exception("The scheduler type must be either 'uniform' or 'non-uniform'.")
195 |
196 | # antithetic sampling
197 | idx_1 = torch.randint(0, len(t_intervals), size=(n // 2 + 1,))
198 | idx_2 = len(t_intervals)-idx_1-1
199 | idx = torch.cat([idx_1, idx_2], dim=0)[:n]
200 | t = t_intervals[idx].to(self.device)
201 |
202 | loss = loss_registry[config.model.type](model, x_img, x_gt, t, e, b)
203 |
204 | tb_logger.add_scalar("loss", loss, global_step=step)
205 |
206 | logging.info(
207 | f"step: {step}, loss: {loss.item()}"
208 | )
209 |
210 | optimizer.zero_grad()
211 | loss.backward()
212 |
213 |                 try:  # gradient clipping; skipped if config.optim.grad_clip is not set
214 | torch.nn.utils.clip_grad_norm_(
215 | model.parameters(), config.optim.grad_clip
216 | )
217 | except Exception:
218 | pass
219 | optimizer.step()
220 |
221 | if self.config.model.ema:
222 | ema_helper.update(model)
223 |
224 | if step % self.config.training.snapshot_freq == 0 or step == 1:
225 | states = [
226 | model.state_dict(),
227 | optimizer.state_dict(),
228 | epoch,
229 | step,
230 | ]
231 | if self.config.model.ema:
232 | states.append(ema_helper.state_dict())
233 |
234 | torch.save(
235 | states,
236 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
237 | )
238 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
239 |
240 |
241 | # Training Fast-DDPM for tasks that have two conditions: multi image super-resolution.
242 | def sr_train(self):
243 | args, config = self.args, self.config
244 | tb_logger = self.config.tb_logger
245 |
246 | dataset = PMUB(self.config.data.train_dataroot, self.config.data.image_size, split='train')
247 | print('Start training your Fast-DDPM model on PMUB dataset.')
248 | print('The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.scheduler_type, self.args.timesteps))
249 | train_loader = data.DataLoader(
250 | dataset,
251 | batch_size=config.training.batch_size,
252 | shuffle=True,
253 | num_workers=config.data.num_workers,
254 | pin_memory=True)
255 |
256 | model = Model(config)
257 | model = model.to(self.device)
258 | model = torch.nn.DataParallel(model)
259 |
260 | optimizer = get_optimizer(self.config, model.parameters())
261 |
262 | if self.config.model.ema:
263 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
264 | ema_helper.register(model)
265 | else:
266 | ema_helper = None
267 |
268 | start_epoch, step = 0, 0
269 | if self.args.resume_training:
270 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
271 | model.load_state_dict(states[0])
272 |
273 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps
274 | optimizer.load_state_dict(states[1])
275 | start_epoch = states[2]
276 | step = states[3]
277 | if self.config.model.ema:
278 | ema_helper.load_state_dict(states[4])
279 |
280 | for epoch in range(start_epoch, self.config.training.n_epochs):
281 | for i, x in enumerate(train_loader):
282 | n = x['BW'].size(0)
283 | model.train()
284 | step += 1
285 |
286 | x_bw = x['BW'].to(self.device)
287 | x_md = x['MD'].to(self.device)
288 | x_fw = x['FW'].to(self.device)
289 |
290 | e = torch.randn_like(x_md)
291 | b = self.betas
292 |
293 | if self.args.scheduler_type == 'uniform':
294 | skip = self.num_timesteps // self.args.timesteps
295 | t_intervals = torch.arange(-1, self.num_timesteps, skip)
296 | t_intervals[0] = 0
297 | elif self.args.scheduler_type == 'non-uniform':
298 | t_intervals = torch.tensor([0, 199, 399, 599, 699, 799, 849, 899, 949, 999])
299 |
300 | if self.args.timesteps != 10:
301 | num_1 = int(self.args.timesteps*0.4)
302 | num_2 = int(self.args.timesteps*0.6)
303 | stage_1 = torch.linspace(0, 699, num_1+1)[:-1]
304 | stage_2 = torch.linspace(699, 999, num_2)
305 | stage_1 = torch.ceil(stage_1).long()
306 | stage_2 = torch.ceil(stage_2).long()
307 | t_intervals = torch.cat((stage_1, stage_2))
308 | else:
309 |                     raise Exception("The scheduler type must be either 'uniform' or 'non-uniform'.")
310 |
311 | # antithetic sampling
312 | idx_1 = torch.randint(0, len(t_intervals), size=(n // 2 + 1,))
313 | idx_2 = len(t_intervals)-idx_1-1
314 | idx = torch.cat([idx_1, idx_2], dim=0)[:n]
315 | t = t_intervals[idx].to(self.device)
316 |
317 | loss = loss_registry[config.model.type](model, x_bw, x_md, x_fw, t, e, b)
318 |
319 | tb_logger.add_scalar("loss", loss, global_step=step)
320 |
321 | logging.info(
322 | f"step: {step}, loss: {loss.item()}"
323 | )
324 |
325 | optimizer.zero_grad()
326 | loss.backward()
327 |
328 |                 try:  # gradient clipping; skipped if config.optim.grad_clip is not set
329 | torch.nn.utils.clip_grad_norm_(
330 | model.parameters(), config.optim.grad_clip
331 | )
332 | except Exception:
333 | pass
334 | optimizer.step()
335 |
336 | if self.config.model.ema:
337 | ema_helper.update(model)
338 |
339 | if step % self.config.training.snapshot_freq == 0 or step == 1:
340 | states = [
341 | model.state_dict(),
342 | optimizer.state_dict(),
343 | epoch,
344 | step,
345 | ]
346 | if self.config.model.ema:
347 | states.append(ema_helper.state_dict())
348 |
349 | torch.save(
350 | states,
351 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
352 | )
353 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
354 |
355 |
356 | # Training original DDPM for tasks that have only one condition: image translation and CT denoising.
357 | def sg_ddpm_train(self):
358 | args, config = self.args, self.config
359 | tb_logger = self.config.tb_logger
360 |
361 |         if self.args.dataset == 'LDFDCT':  # CT image denoising
362 |             dataset = LDFDCT(self.config.data.train_dataroot, self.config.data.image_size, split='train')
363 |             print('Start training DDPM model on LDFDCT dataset.')
364 |         elif self.args.dataset == 'BRATS':  # brain image translation
365 |             dataset = BRATS(self.config.data.train_dataroot, self.config.data.image_size, split='train')
366 |             print('Start training DDPM model on BRATS dataset.')
367 |         else:
368 |             raise ValueError('sg_ddpm_train supports only the LDFDCT and BRATS datasets.')
369 |
370 | print('The number of involved time steps is {} out of 1000.'.format(self.args.timesteps))
371 | train_loader = data.DataLoader(
372 | dataset,
373 | batch_size=config.training.batch_size,
374 | shuffle=True,
375 | num_workers=config.data.num_workers,
376 | pin_memory=True)
377 |
378 | model = Model(config)
379 | model = model.to(self.device)
380 | model = torch.nn.DataParallel(model)
381 |
382 | optimizer = get_optimizer(self.config, model.parameters())
383 |
384 | if self.config.model.ema:
385 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
386 | ema_helper.register(model)
387 | else:
388 | ema_helper = None
389 |
390 | start_epoch, step = 0, 0
391 | if self.args.resume_training:
392 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
393 | model.load_state_dict(states[0])
394 |
395 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps
396 | optimizer.load_state_dict(states[1])
397 | start_epoch = states[2]
398 | step = states[3]
399 | if self.config.model.ema:
400 | ema_helper.load_state_dict(states[4])
401 |
402 | for epoch in range(start_epoch, self.config.training.n_epochs):
403 | for i, x in enumerate(train_loader):
404 | n = x['LD'].size(0)
405 | model.train()
406 | step += 1
407 |
408 | x_img = x['LD'].to(self.device)
409 | x_gt = x['FD'].to(self.device)
410 |
411 | e = torch.randn_like(x_gt)
412 | b = self.betas
413 |
414 | t = torch.randint(
415 | low=0, high=self.num_timesteps, size=(n // 2 + 1,)
416 | ).to(self.device)
417 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
418 |
419 | loss = loss_registry[config.model.type](model, x_img, x_gt, t, e, b)
420 |
421 | tb_logger.add_scalar("loss", loss, global_step=step)
422 |
423 | logging.info(
424 | f"step: {step}, loss: {loss.item()}"
425 | )
426 |
427 | optimizer.zero_grad()
428 | loss.backward()
429 |
430 |                 try:  # gradient clipping; skipped if config.optim.grad_clip is not set
431 | torch.nn.utils.clip_grad_norm_(
432 | model.parameters(), config.optim.grad_clip
433 | )
434 | except Exception:
435 | pass
436 | optimizer.step()
437 |
438 | if self.config.model.ema:
439 | ema_helper.update(model)
440 |
441 | if step % self.config.training.snapshot_freq == 0 or step == 1:
442 | states = [
443 | model.state_dict(),
444 | optimizer.state_dict(),
445 | epoch,
446 | step,
447 | ]
448 | if self.config.model.ema:
449 | states.append(ema_helper.state_dict())
450 |
451 | torch.save(
452 | states,
453 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
454 | )
455 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
456 |
457 |
458 | # Training original DDPM for tasks that have two conditions: multi image super-resolution.
459 | def sr_ddpm_train(self):
460 | args, config = self.args, self.config
461 | tb_logger = self.config.tb_logger
462 |
463 | dataset = PMUB(self.config.data.train_dataroot, self.config.data.image_size, split='train')
464 | print('Start training DDPM model on PMUB dataset.')
465 | print('The number of involved time steps is {} out of 1000.'.format(self.args.timesteps))
466 |
467 | train_loader = data.DataLoader(
468 | dataset,
469 | batch_size=config.training.batch_size,
470 | shuffle=True,
471 | num_workers=config.data.num_workers,
472 | pin_memory=True)
473 |
474 | model = Model(config)
475 | model = model.to(self.device)
476 | model = torch.nn.DataParallel(model)
477 |
478 | optimizer = get_optimizer(self.config, model.parameters())
479 |
480 | if self.config.model.ema:
481 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
482 | ema_helper.register(model)
483 | else:
484 | ema_helper = None
485 |
486 | start_epoch, step = 0, 0
487 | if self.args.resume_training:
488 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))
489 | model.load_state_dict(states[0])
490 |
491 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps
492 | optimizer.load_state_dict(states[1])
493 | start_epoch = states[2]
494 | step = states[3]
495 | if self.config.model.ema:
496 | ema_helper.load_state_dict(states[4])
497 |
498 | time_start = time.time()
499 | total_time = 0
500 | for epoch in range(start_epoch, self.config.training.n_epochs):
501 | for i, x in enumerate(train_loader):
502 | n = x['BW'].size(0)
503 | model.train()
504 | step += 1
505 |
506 | x_bw = x['BW'].to(self.device)
507 | x_md = x['MD'].to(self.device)
508 | x_fw = x['FW'].to(self.device)
509 |
510 | e = torch.randn_like(x_md)
511 | b = self.betas
512 |
513 | # antithetic sampling
514 | t = torch.randint(
515 | low=0, high=self.num_timesteps, size=(n // 2 + 1,)
516 | ).to(self.device)
517 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
518 | loss = loss_registry[config.model.type](model, x_bw, x_md, x_fw, t, e, b)
519 |
520 | tb_logger.add_scalar("loss", loss, global_step=step)
521 |
522 | logging.info(
523 | f"step: {step}, loss: {loss.item()}"
524 | )
525 |
526 | optimizer.zero_grad()
527 | loss.backward()
528 |
529 |                 try:  # gradient clipping; skipped if config.optim.grad_clip is not set
530 | torch.nn.utils.clip_grad_norm_(
531 | model.parameters(), config.optim.grad_clip
532 | )
533 | except Exception:
534 | pass
535 | optimizer.step()
536 |
537 | if self.config.model.ema:
538 | ema_helper.update(model)
539 |
540 | if step % self.config.training.snapshot_freq == 0 or step == 1:
541 | states = [
542 | model.state_dict(),
543 | optimizer.state_dict(),
544 | epoch,
545 | step,
546 | ]
547 | if self.config.model.ema:
548 | states.append(ema_helper.state_dict())
549 |
550 | torch.save(
551 | states,
552 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
553 | )
554 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))
555 |
556 |
557 | # Sampling for tasks that have two conditions: multi image super-resolution.
558 | def sr_sample(self):
559 | ckpt_list = self.config.sampling.ckpt_id
560 | for ckpt_idx in ckpt_list:
561 | self.ckpt_idx = ckpt_idx
562 | model = Model(self.config)
563 |             print('Start inference with the checkpoint saved at training step {}.'.format(ckpt_idx))
564 |
565 | if not self.args.use_pretrained:
566 | states = torch.load(
567 | os.path.join(
568 | self.args.log_path, f"ckpt_{ckpt_idx}.pth"
569 | ),
570 | map_location=self.config.device,
571 | )
572 | model = model.to(self.device)
573 | model = torch.nn.DataParallel(model)
574 | model.load_state_dict(states[0], strict=True)
575 |
576 | if self.config.model.ema:
577 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
578 | ema_helper.register(model)
579 | ema_helper.load_state_dict(states[-1])
580 | ema_helper.ema(model)
581 | else:
582 | ema_helper = None
583 | else:
584 |                 # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
585 | if self.config.data.dataset == "CIFAR10":
586 | name = "cifar10"
587 | elif self.config.data.dataset == "LSUN":
588 | name = f"lsun_{self.config.data.category}"
589 | else:
590 | raise ValueError
591 | ckpt = get_ckpt_path(f"ema_{name}")
592 | print("Loading checkpoint {}".format(ckpt))
593 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
594 | model.to(self.device)
595 | model = torch.nn.DataParallel(model)
596 |
597 | model.eval()
598 |
599 | if self.args.fid:
600 | self.sr_sample_fid(model)
601 | elif self.args.interpolation:
602 | self.sr_sample_interpolation(model)
603 | elif self.args.sequence:
604 | self.sample_sequence(model)
605 | else:
606 |                 raise NotImplementedError("Sample procedure not defined")
607 |
608 |
609 | # Sampling for tasks that have only one condition: image translation and CT denoising.
610 | def sg_sample(self):
611 | ckpt_list = self.config.sampling.ckpt_id
612 | for ckpt_idx in ckpt_list:
613 | self.ckpt_idx = ckpt_idx
614 | model = Model(self.config)
615 |             print('Start inference with the checkpoint saved at training step {}.'.format(ckpt_idx))
616 |
617 | if not self.args.use_pretrained:
618 | states = torch.load(
619 | os.path.join(
620 | self.args.log_path, f"ckpt_{ckpt_idx}.pth"
621 | ),
622 | map_location=self.config.device,
623 | )
624 | model = model.to(self.device)
625 | model = torch.nn.DataParallel(model)
626 | model.load_state_dict(states[0], strict=True)
627 |
628 | if self.config.model.ema:
629 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
630 | ema_helper.register(model)
631 | ema_helper.load_state_dict(states[-1])
632 | ema_helper.ema(model)
633 | else:
634 | ema_helper = None
635 | else:
636 |                 # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
637 | if self.config.data.dataset == "CIFAR10":
638 | name = "cifar10"
639 | elif self.config.data.dataset == "LSUN":
640 | name = f"lsun_{self.config.data.category}"
641 | else:
642 | raise ValueError
643 | ckpt = get_ckpt_path(f"ema_{name}")
644 | print("Loading checkpoint {}".format(ckpt))
645 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
646 | model.to(self.device)
647 | model = torch.nn.DataParallel(model)
648 |
649 | model.eval()
650 |
651 | if self.args.fid:
652 | self.sg_sample_fid(model)
653 | elif self.args.interpolation:
654 | self.sr_sample_interpolation(model)
655 | elif self.args.sequence:
656 | self.sample_sequence(model)
657 | else:
658 |                 raise NotImplementedError("Sample procedure not defined")
659 |
660 |
661 | def sr_sample_fid(self, model):
662 | config = self.config
663 | img_id = len(glob.glob(f"{self.args.image_folder}/*"))
664 | print(f"starting from image {img_id}")
665 |
666 | sample_dataset = PMUB(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate')
667 | print('Start sampling model on PMUB dataset.')
668 | print('The inference sample type is {}. The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.sample_type, self.args.scheduler_type, self.args.timesteps))
669 |
670 | sample_loader = data.DataLoader(
671 | sample_dataset,
672 | batch_size=config.sampling_fid.batch_size,
673 | shuffle=False,
674 | num_workers=config.data.num_workers)
675 |
676 | with torch.no_grad():
677 | data_num = len(sample_dataset)
678 | print('The length of test set is:', data_num)
679 | avg_psnr = 0.0
680 | avg_ssim = 0.0
681 | time_list = []
682 | psnr_list = []
683 | ssim_list = []
684 |
685 | for batch_idx, img in tqdm.tqdm(enumerate(sample_loader), desc="Generating image samples for FID evaluation."):
686 | n = img['BW'].shape[0]
687 |
688 | x = torch.randn(
689 | n,
690 | config.data.channels,
691 | config.data.image_size,
692 | config.data.image_size,
693 | device=self.device,
694 | )
695 | x_bw = img['BW'].to(self.device)
696 | x_md = img['MD'].to(self.device)
697 | x_fw = img['FW'].to(self.device)
698 | case_name = img['case_name'][0]
699 |
700 | time_start = time.time()
701 | x = self.sr_sample_image(x, x_bw, x_fw, model)
702 | time_end = time.time()
703 |
704 | x = inverse_data_transform(config, x)
705 | x_md = inverse_data_transform(config, x_md)
706 | x_tensor = x
707 | x_md_tensor = x_md
708 | x_md = x_md.squeeze().float().cpu().numpy()
709 | x = x.squeeze().float().cpu().numpy()
710 | x_md = (x_md*255.0).round()
711 | x = (x*255.0).round()
712 |
713 | PSNR = 0.0
714 | SSIM = 0.0
715 | for i in range(x.shape[0]):
716 | psnr_temp = calculate_psnr(x[i,:,:], x_md[i,:,:])
717 | ssim_temp = ssim(x_md[i,:,:], x[i,:,:], data_range=255)
718 | PSNR += psnr_temp
719 | SSIM += ssim_temp
720 | psnr_list.append(psnr_temp)
721 | ssim_list.append(ssim_temp)
722 |
723 | PSNR_print = PSNR/x.shape[0]
724 | SSIM_print = SSIM/x.shape[0]
725 |
726 | case_time = time_end-time_start
727 | time_list.append(case_time)
728 |
729 | avg_psnr += PSNR
730 | avg_ssim += SSIM
731 | logging.info('Case {}: PSNR {}, SSIM {}, time {}'.format(case_name, PSNR_print, SSIM_print, case_time))
732 |
733 | for i in range(0, n):
734 | # image:(0-1)
735 | tvu.save_image(
736 | x_tensor[i], os.path.join(self.args.image_folder, "{}_{}_pt.png".format(self.ckpt_idx, img_id))
737 | )
738 | tvu.save_image(
739 | x_md_tensor[i], os.path.join(self.args.image_folder, "{}_{}_gt.png".format(self.ckpt_idx, img_id))
740 | )
741 | img_id += 1
742 |
743 | avg_psnr = avg_psnr / data_num
744 | avg_ssim = avg_ssim / data_num
745 | # Drop first and last for time calculation.
746 | avg_time = sum(time_list[1:-1])/(len(time_list)-2)
747 | logging.info('Average: PSNR {}, SSIM {}, time {}'.format(avg_psnr, avg_ssim, avg_time))
748 |
749 |
750 | def sg_sample_fid(self, model):
751 | config = self.config
752 | img_id = len(glob.glob(f"{self.args.image_folder}/*"))
753 | print(f"starting from image {img_id}")
754 |
755 |
756 |         if self.args.dataset == 'LDFDCT':  # CT image denoising
757 |             sample_dataset = LDFDCT(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate')
758 |             print('Start sampling model on LDFDCT dataset.')
759 |         elif self.args.dataset == 'BRATS':  # brain image translation
760 |             sample_dataset = BRATS(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate')
761 |             print('Start sampling model on BRATS dataset.')
762 |         else:
763 |             raise ValueError('sg_sample_fid supports only the LDFDCT and BRATS datasets.')
764 | print('The inference sample type is {}. The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.sample_type, self.args.scheduler_type, self.args.timesteps))
765 |
766 | sample_loader = data.DataLoader(
767 | sample_dataset,
768 | batch_size=config.sampling_fid.batch_size,
769 | shuffle=False,
770 | num_workers=config.data.num_workers)
771 |
772 | with torch.no_grad():
773 | data_num = len(sample_dataset)
774 | print('The length of test set is:', data_num)
775 | avg_psnr = 0.0
776 | avg_ssim = 0.0
777 | time_list = []
778 | psnr_list = []
779 | ssim_list = []
780 |
781 | for batch_idx, sample in tqdm.tqdm(enumerate(sample_loader), desc="Generating image samples for FID evaluation."):
782 | n = sample['LD'].shape[0]
783 |
784 | x = torch.randn(
785 | n,
786 | config.data.channels,
787 | config.data.image_size,
788 | config.data.image_size,
789 | device=self.device,
790 | )
791 | x_img = sample['LD'].to(self.device)
792 | x_gt = sample['FD'].to(self.device)
793 | case_name = sample['case_name']
794 |
795 | time_start = time.time()
796 | x = self.sg_sample_image(x, x_img, model)
797 | time_end = time.time()
798 |
799 | x = inverse_data_transform(config, x)
800 | x_gt = inverse_data_transform(config, x_gt)
801 | x_tensor = x
802 | x_gt_tensor = x_gt
803 | x_gt = x_gt.squeeze().float().cpu().numpy()
804 | x = x.squeeze().float().cpu().numpy()
805 | x_gt = x_gt*255
806 | x = x*255
807 |
808 | PSNR = 0.0
809 | SSIM = 0.0
810 | for i in range(x.shape[0]):
811 | psnr_temp = calculate_psnr(x[i,:,:], x_gt[i,:,:])
812 | ssim_temp = ssim(x_gt[i,:,:], x[i,:,:], data_range=255)
813 | PSNR += psnr_temp
814 | SSIM += ssim_temp
815 | psnr_list.append(psnr_temp)
816 | ssim_list.append(ssim_temp)
817 |
818 | PSNR_print = PSNR/x.shape[0]
819 | SSIM_print = SSIM/x.shape[0]
820 |
821 | case_time = time_end-time_start
822 | time_list.append(case_time)
823 |
824 | avg_psnr += PSNR
825 | avg_ssim += SSIM
826 | logging.info('Case {}: PSNR {}, SSIM {}, time {}'.format(case_name[0], PSNR_print, SSIM_print, case_time))
827 |
828 | for i in range(0, n):
829 | # image:(0-1)
830 | tvu.save_image(
831 | x_tensor[i], os.path.join(self.args.image_folder, "{}_{}_pt.png".format(self.ckpt_idx, img_id))
832 | )
833 | tvu.save_image(
834 | x_gt_tensor[i], os.path.join(self.args.image_folder, "{}_{}_gt.png".format(self.ckpt_idx, img_id))
835 | )
836 | img_id += 1
837 |
838 | avg_psnr = avg_psnr / data_num
839 | avg_ssim = avg_ssim / data_num
840 | # Drop first and last for time calculation.
841 | avg_time = sum(time_list[1:-1])/(len(time_list)-2)
842 | logging.info('Average: PSNR {}, SSIM {}, time {}'.format(avg_psnr, avg_ssim, avg_time))
843 |
844 |
845 | def sr_sample_image(self, x, x_bw, x_fw, model, last=True):
846 | try:
847 | skip = self.args.skip
848 | except Exception:
849 | skip = 1
850 |
851 | if self.args.sample_type == "generalized":
852 | if self.args.scheduler_type == 'uniform':
853 | skip = self.num_timesteps // self.args.timesteps
854 | seq = range(-1, self.num_timesteps, skip)
855 | seq = list(seq)
856 | seq[0] = 0
857 | elif self.args.scheduler_type == 'non-uniform':
858 | seq = [0, 199, 399, 599, 699, 799, 849, 899, 949, 999]
859 |
860 | if self.args.timesteps != 10:
861 | num_1 = int(self.args.timesteps*0.4)
862 | num_2 = int(self.args.timesteps*0.6)
863 | stage_1 = np.linspace(0, 699, num_1+1)[:-1]
864 | stage_2 = np.linspace(699, 999, num_2)
865 | stage_1 = np.ceil(stage_1).astype(int)
866 | stage_2 = np.ceil(stage_2).astype(int)
867 | seq = np.concatenate((stage_1, stage_2))
868 | else:
869 |                 raise Exception("The scheduler type must be either 'uniform' or 'non-uniform'.")
870 |
871 | from functions.denoising import generalized_steps, sr_generalized_steps
872 |
873 | xs = sr_generalized_steps(x, x_bw, x_fw, seq, model, self.betas, eta=self.args.eta)
874 | x = xs
875 |
876 | elif self.args.sample_type == "ddpm_noisy":
877 | skip = self.num_timesteps // self.args.timesteps
878 | seq = range(0, self.num_timesteps, skip)
879 |
880 | from functions.denoising import ddpm_steps, sr_ddpm_steps
881 |
882 | x = sr_ddpm_steps(x, x_bw, x_fw, seq, model, self.betas)
883 | else:
884 | raise NotImplementedError
885 | if last:
886 | x = x[0][-1]
887 | return x
888 |
889 |
890 | def sg_sample_image(self, x, x_img, model, last=True):
891 | try:
892 | skip = self.args.skip
893 | except Exception:
894 | skip = 1
895 |
896 | if self.args.sample_type == "generalized":
897 | if self.args.scheduler_type == 'uniform':
898 | skip = self.num_timesteps // self.args.timesteps
899 | seq = range(-1, self.num_timesteps, skip)
900 | seq = list(seq)
901 | seq[0] = 0
902 | elif self.args.scheduler_type == 'non-uniform':
903 | seq = [0, 199, 399, 599, 699, 799, 849, 899, 949, 999]
904 |
905 | if self.args.timesteps != 10:
906 | num_1 = int(self.args.timesteps*0.4)
907 | num_2 = int(self.args.timesteps*0.6)
908 | stage_1 = np.linspace(0, 699, num_1+1)[:-1]
909 | stage_2 = np.linspace(699, 999, num_2)
910 | stage_1 = np.ceil(stage_1).astype(int)
911 | stage_2 = np.ceil(stage_2).astype(int)
912 | seq = np.concatenate((stage_1, stage_2))
913 | else:
914 |                 raise Exception("The scheduler type must be either 'uniform' or 'non-uniform'.")
915 |
916 | from functions.denoising import generalized_steps, sr_generalized_steps, sg_generalized_steps
917 |
918 | xs = sg_generalized_steps(x, x_img, seq, model, self.betas, eta=self.args.eta)
919 | x = xs
920 |
921 | elif self.args.sample_type == "ddpm_noisy":
922 | skip = self.num_timesteps // self.args.timesteps
923 | seq = range(0, self.num_timesteps, skip)
924 |
925 | from functions.denoising import ddpm_steps, sr_ddpm_steps, sg_ddpm_steps
926 |
927 | x = sg_ddpm_steps(x, x_img, seq, model, self.betas)
928 | else:
929 | raise NotImplementedError
930 | if last:
931 | x = x[0][-1]
932 | return x
933 |
934 |
935 | def test(self):
936 | pass
937 |
--------------------------------------------------------------------------------
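The uniform and non-uniform timestep schedules are rebuilt in several places in `runners/diffusion.py`; this standalone sketch reproduces both for the default 10-step Fast-DDPM setting, which can be handy when sanity-checking a config.

```python
import torch

num_timesteps, timesteps = 1000, 10

# Uniform: every (1000 // 10)-th step, with the first entry clamped to 0.
t_uniform = torch.arange(-1, num_timesteps, num_timesteps // timesteps)
t_uniform[0] = 0
print(t_uniform.tolist())    # [0, 99, 199, 299, 399, 499, 599, 699, 799, 899]

# Non-uniform: steps concentrated near t = 999, as hardcoded in the runners.
t_nonuniform = torch.tensor([0, 199, 399, 599, 699, 799, 849, 899, 949, 999])
print(t_nonuniform.tolist())
```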