├── LICENSE ├── Overview.png ├── README.md ├── configs ├── brats_linear.yml ├── ldfd_linear.yml └── pmub_linear.yml ├── datasets ├── BRATS.py ├── LDFDCT.py ├── __init__.py ├── pmub.py ├── sr_util.py └── utils.py ├── ddpm_main.py ├── fast_ddpm_main.py ├── functions ├── __init__.py ├── ckpt_util.py ├── denoising.py └── losses.py ├── models ├── diffusion.py └── ema.py └── runners ├── __init__.py └── diffusion.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mirth AI lab at UF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/Fast-DDPM/649a14a6093d14f4286a6b6f9963dd208ce07928/Overview.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-DDPM 2 | 3 | Official PyTorch implementation of: 4 | 5 | [Fast-DDPM: Fast Denoising Diffusion Probabilistic Models for Medical Image-to-Image Generation](https://ieeexplore.ieee.org/abstract/document/10979336) (JBHI 2025) 6 | 7 | We propose Fast-DDPM, a simple yet effective approach that improves training speed, sampling speed, and generation quality of diffusion models simultaneously. Fast-DDPM trains and samples using only 10 time steps, reducing the training time to 0.2x and the sampling time to 0.01x compared to DDPM. 8 | 9 |
10 | ![DDPM vs. Fast-DDPM](Overview.png)
11 | 
12 | 
13 | This code is intended for research purposes only. If you have any questions regarding how to use it, feel free to contact Hongxu Jiang (hongxu.jiang@medicine.ufl.edu).
14 | 
15 | ## Requirements
16 | * Python==3.10.6
17 | * torch==1.12.1
18 | * torchvision==0.15.2
19 | * numpy
20 | * opencv-python
21 | * tqdm
22 | * tensorboard
23 | * tensorboardX
24 | * scikit-image
25 | * medpy
26 | * pillow
27 | * scipy
28 | * Install with `pip install -r requirements.txt`
29 | 
30 | ## Publicly Available Datasets
31 | - Prostate-MRI-US-Biopsy dataset
32 | - LDCT-and-Projection-data dataset
33 | - BraTS 2018 dataset
34 | - The processed dataset can be accessed here: https://drive.google.com/file/d/1kF0g8fMR5XPQ2FTbutfTQ-hwG_mTqerx/view?usp=drive_link.
35 | 
36 | ## Usage
37 | ### 1. Clone or download this repository.
38 | 
39 | ### 2. Pretrained model weights
40 | * We provide pretrained model weights for all three tasks, which you can access here: https://drive.google.com/file/d/1ndS-eLegqwCOUoLT1B-HQiqRQqZUMKVF/view?usp=sharing.
41 | * As shown in the ablation study, the default of 10 time steps may not be optimal for every task, so you are welcome to train Fast-DDPM on your own dataset with different settings.
42 | 
43 | ### 3. Prepare data
44 | * Please download our processed dataset or obtain the original data from the official websites.
45 | * After downloading, extract the file and place it in the "data/" folder. The directory structure should be as follows:
46 | 
47 | ```bash
48 | ├── configs
49 | │
50 | ├── data
51 | │   ├── LD_FD_CT_train
52 | │   ├── LD_FD_CT_test
53 | │   ├── PMUB-train
54 | │   ├── PMUB-test
55 | │   ├── Brats_train
56 | │   └── Brats_test
57 | │
58 | ├── datasets
59 | │
60 | ├── functions
61 | │
62 | ├── models
63 | │
64 | └── runners
65 | 
66 | ```
67 | 
68 | ### 4. Training/Sampling a Fast-DDPM model
69 | * Please make sure that hyperparameters such as the scheduler type and the number of time steps are consistent between training and sampling.
70 | * The total number of diffusion time steps defaults to 1000 in the paper, so the number of time steps used by Fast-DDPM should be an integer less than 1000.
71 | ```
72 | python fast_ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --scheduler_type {SAMPLING STRATEGY} --timesteps {STEPS}
73 | ```
74 | ```
75 | python fast_ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --scheduler_type {SAMPLING STRATEGY} --timesteps {STEPS}
76 | ```
77 | 
78 | where
79 | - `DATASET_NAME` should be one of `LDFDCT` for the image denoising task, `BRATS` for the image-to-image translation task, or `PMUB` for the multi-image super-resolution task.
80 | - `SAMPLING STRATEGY` controls the scheduler sampling strategy proposed in the paper (either uniform or non-uniform).
81 | - `STEPS` controls how many time steps are used during training and inference. For Fast-DDPM it should be an integer less than 1000, and it is 10 by default.
82 | 
83 | 
84 | ### 5. Training/Sampling a DDPM model
85 | * Please make sure that hyperparameters such as the scheduler type and the number of time steps are consistent between training and sampling.
86 | * The total number of diffusion time steps defaults to 1000 in the paper, so DDPM uses all 1000 time steps by default, as illustrated below.
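Both scripts interpret `--timesteps` as the size of a sub-sequence drawn from the 1000-step diffusion grid. As a rough illustration of the uniform strategy (the non-uniform strategy follows the schedule described in the paper), the selected indices can be pictured as follows; this helper is a sketch for intuition only, not part of the codebase:

```python
def uniform_steps(num_diffusion_timesteps=1000, steps=10):
    # Evenly spaced indices into the diffusion grid,
    # e.g. [0, 100, 200, ..., 900] for steps=10.
    skip = num_diffusion_timesteps // steps
    return list(range(0, num_diffusion_timesteps, skip))
```

To make the command templates below concrete, here is an illustrative pair of DDPM runs for the low-dose CT task; the experiment path and model name are placeholders of our choosing, not values fixed by the repository:

```bash
# Train a DDPM baseline on the LDFDCT task with the default 1000 time steps.
python ddpm_main.py --config ldfd_linear.yml --dataset LDFDCT --exp exp --doc ldfdct_ddpm --timesteps 1000

# Sample from the trained model and save images for evaluation.
python ddpm_main.py --config ldfd_linear.yml --dataset LDFDCT --exp exp --doc ldfdct_ddpm --sample --fid --timesteps 1000
```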
87 | ```
88 | python ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --timesteps {STEPS}
89 | ```
90 | ```
91 | python ddpm_main.py --config {DATASET}.yml --dataset {DATASET_NAME} --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --timesteps {STEPS}
92 | ```
93 | 
94 | where
95 | - `DATASET_NAME` should be one of `LDFDCT` for the image denoising task, `BRATS` for the image-to-image translation task, or `PMUB` for the multi-image super-resolution task.
96 | - `STEPS` controls how many time steps are used during training and inference. It should be 1000 in the setting of this paper.
97 | 
98 | 
99 | ## References
100 | * The code is mainly adapted from [DDIM](https://github.com/ermongroup/ddim).
101 | 
102 | 
103 | ## Citations
104 | If you use our code or dataset, please cite our paper as follows:
105 | ```bibtex
106 | @article{jiang2025fast,
107 |   title={Fast-DDPM: Fast denoising diffusion probabilistic models for medical image-to-image generation},
108 |   author={Jiang, Hongxu and Imran, Muhammad and Zhang, Teng and Zhou, Yuyin and Liang, Muxuan and Gong, Kuang and Shao, Wei},
109 |   journal={IEEE Journal of Biomedical and Health Informatics},
110 |   year={2025},
111 |   publisher={IEEE}
112 | }
113 | ```
114 | 
--------------------------------------------------------------------------------
/configs/brats_linear.yml:
--------------------------------------------------------------------------------
1 | data:
2 |     dataset: "PMUB"
3 |     train_dataroot: "data/Brats_train"
4 |     sample_dataroot: "data/Brats_test"
5 |     image_size: 256
6 |     channels: 1
7 |     logit_transform: false
8 |     uniform_dequantization: false
9 |     gaussian_dequantization: false
10 |     random_flip: true
11 |     rescaled: true
12 |     num_workers: 8
13 | 
14 | model:
15 |     type: "sg"
16 |     in_channels: 2
17 |     out_ch: 1
18 |     ch: 128
19 |     ch_mult: [1, 1, 2, 2, 4, 4]
20 |     num_res_blocks: 2
21 |     attn_resolutions: [16, ]
22 |     dropout: 0.0
23 |     var_type: fixedsmall
24 |     ema_rate: 0.999
25 |     ema: True
26 |     resamp_with_conv: True
27 | 
28 | diffusion:
29 |     beta_schedule: linear
30 |     beta_start: 0.0001
31 |     beta_end: 0.02
32 |     num_diffusion_timesteps: 1000
33 | 
34 | training:
35 |     batch_size: 16
36 |     n_epochs: 10000
37 |     n_iters: 5000000
38 |     snapshot_freq: 100000
39 |     validation_freq: 5000000000
40 | 
41 | sampling:
42 |     batch_size: 8
43 |     ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000]
44 |     last_only: True
45 | 
46 | sampling_inter:
47 |     batch_size: 59
48 |     last_only: True
49 | 
50 | sampling_fid:
51 |     batch_size: 128
52 |     last_only: True
53 | 
54 | optim:
55 |     weight_decay: 0.000
56 |     optimizer: "Adam"
57 |     lr: 0.00002
58 |     beta1: 0.9
59 |     amsgrad: false
60 |     eps: 0.00000001
--------------------------------------------------------------------------------
/configs/ldfd_linear.yml:
--------------------------------------------------------------------------------
1 | data:
2 |     dataset: "PMUB"
3 |     train_dataroot: "data/LD_FD_CT_train"
4 |     val_dataroot: "data/PMUB-val"
5 |     sample_dataroot: "data/LD_FD_CT_test"
6 |     image_size: 256
7 |     channels: 1
8 |     logit_transform: false
9 |     uniform_dequantization: false
10 |     gaussian_dequantization: false
11 |     random_flip: true
12 |     rescaled: true
13 |     num_workers: 8
14 | 
15 | model:
16 |     type: "sg"
17 |     in_channels: 2
18 |     out_ch: 1
19 |     ch: 128
20 |     ch_mult: [1, 1, 2, 2, 4, 4]
21 |     num_res_blocks: 2
22 |     attn_resolutions: [16,]
23 |     dropout: 0.0
24 |     var_type: fixedsmall
25 |     ema_rate: 0.999
26 |     ema: True
27 | 
resamp_with_conv: True 28 | 29 | diffusion: 30 | beta_schedule: linear 31 | beta_start: 0.0001 32 | beta_end: 0.02 33 | num_diffusion_timesteps: 1000 34 | 35 | training: 36 | batch_size: 16 37 | n_epochs: 10000 38 | n_iters: 5000000 39 | snapshot_freq: 100000 40 | validation_freq: 5000000000 41 | 42 | sampling: 43 | batch_size: 8 44 | ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000] 45 | last_only: True 46 | 47 | sampling_inter: 48 | batch_size: 59 49 | last_only: True 50 | 51 | sampling_fid: 52 | batch_size: 128 53 | last_only: True 54 | 55 | optim: 56 | weight_decay: 0.000 57 | optimizer: "Adam" 58 | lr: 0.00002 59 | beta1: 0.9 60 | amsgrad: false 61 | eps: 0.00000001 62 | -------------------------------------------------------------------------------- /configs/pmub_linear.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "PMUB" 3 | train_dataroot: "data/PMUB-train" 4 | val_dataroot: "data/PMUB-val" 5 | sample_dataroot: "data/PMUB-test" 6 | image_size: 256 7 | channels: 1 8 | logit_transform: false 9 | uniform_dequantization: false 10 | gaussian_dequantization: false 11 | random_flip: true 12 | rescaled: true 13 | num_workers: 8 14 | 15 | model: 16 | type: "sr" 17 | in_channels: 3 18 | out_ch: 1 19 | ch: 128 20 | ch_mult: [1, 1, 2, 2, 4, 4] 21 | num_res_blocks: 2 22 | attn_resolutions: [16] 23 | dropout: 0.0 24 | var_type: fixedsmall 25 | ema_rate: 0.999 26 | ema: True 27 | resamp_with_conv: True 28 | 29 | diffusion: 30 | beta_schedule: linear 31 | beta_start: 0.0001 32 | beta_end: 0.02 33 | num_diffusion_timesteps: 1000 34 | 35 | training: 36 | batch_size: 16 37 | n_epochs: 10000 38 | n_iters: 5000000 39 | snapshot_freq: 100000 40 | validation_freq: 5000000000 41 | 42 | sampling: 43 | batch_size: 8 44 | ckpt_id: [100000, 200000, 300000, 400000, 500000, 600000, 700000, 800000, 900000, 1000000, 1500000, 2000000] 45 | last_only: True 46 | 47 | sampling_inter: 48 | batch_size: 59 49 | last_only: True 50 | 51 | sampling_fid: 52 | batch_size: 58 53 | last_only: True 54 | 55 | optim: 56 | weight_decay: 0.000 57 | optimizer: "Adam" 58 | lr: 0.00002 59 | beta1: 0.9 60 | amsgrad: false 61 | eps: 0.00000001 62 | -------------------------------------------------------------------------------- /datasets/BRATS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import random 5 | import torch 6 | 7 | from .sr_util import get_paths_from_npys, brats_transform_augment 8 | 9 | 10 | class BRATS(Dataset): 11 | def __init__(self, dataroot, img_size, split='train', data_len=-1): 12 | self.img_size = img_size 13 | self.data_len = data_len 14 | self.split = split 15 | img_root = dataroot + '/A/' 16 | gt_root = dataroot + '/B/' 17 | self.img_npy_path, self.gt_npy_path = get_paths_from_npys(img_root, gt_root) 18 | self.data_len = len(self.img_npy_path) 19 | 20 | def __len__(self): 21 | return self.data_len 22 | 23 | def __getitem__(self, index): 24 | img_FD = None 25 | img_LD = None 26 | base_name = None 27 | extension = None 28 | number = None 29 | FW_path = None 30 | BW_path = None 31 | 32 | base_name = self.img_npy_path[index].split('/')[-1] 33 | case_name = base_name.split('.')[0] 34 | 35 | img_npy = np.load(self.img_npy_path[index]) 36 | img = Image.fromarray(img_npy) 37 | gt_npy = np.load(self.gt_npy_path[index]) 38 | gt = Image.fromarray(gt_npy) 39 | img = 
img.resize((self.img_size, self.img_size)) 40 | gt = gt.resize((self.img_size, self.img_size)) 41 | 42 | [img, gt] = brats_transform_augment( 43 | [img, gt], split=self.split) 44 | 45 | return {'FD': gt, 'LD': img, 'case_name': case_name} 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /datasets/LDFDCT.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import lmdb 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import random 6 | import torch 7 | from .sr_util import get_paths_from_images, get_valid_paths_from_images, get_valid_paths_from_test_images, transform_augment 8 | 9 | 10 | class LDFDCT(Dataset): 11 | def __init__(self, dataroot, img_size, split='train', data_len=-1): 12 | self.img_size = img_size 13 | self.data_len = data_len 14 | self.split = split 15 | self.img_ld_path, self.img_fd_path = get_paths_from_images(dataroot) 16 | self.data_len = len(self.img_ld_path) 17 | 18 | def __len__(self): 19 | return self.data_len 20 | 21 | def __getitem__(self, index): 22 | img_FD = None 23 | img_LD = None 24 | base_name = None 25 | extension = None 26 | number = None 27 | FW_path = None 28 | BW_path = None 29 | 30 | base_name = self.img_ld_path[index].split('/')[-1] 31 | case_name = base_name.split('_')[0] 32 | 33 | img_LD = Image.open(self.img_ld_path[index]).convert("L") 34 | img_FD = Image.open(self.img_fd_path[index]).convert("L") 35 | img_LD = img_LD.resize((self.img_size, self.img_size)) 36 | img_FD = img_FD.resize((self.img_size, self.img_size)) 37 | 38 | [img_LD, img_FD] = transform_augment( 39 | [img_LD, img_FD], split=self.split, min_max=(-1, 1)) 40 | 41 | return {'FD': img_FD, 'LD': img_LD, 'case_name': case_name} 42 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from torch.utils.data import Subset 7 | import numpy as np 8 | 9 | 10 | class Crop(object): 11 | def __init__(self, x1, x2, y1, y2): 12 | self.x1 = x1 13 | self.x2 = x2 14 | self.y1 = y1 15 | self.y2 = y2 16 | 17 | def __call__(self, img): 18 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 19 | 20 | def __repr__(self): 21 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 22 | self.x1, self.x2, self.y1, self.y2 23 | ) 24 | 25 | 26 | 27 | def logit_transform(image, lam=1e-6): 28 | image = lam + (1 - 2 * lam) * image 29 | return torch.log(image) - torch.log1p(-image) 30 | 31 | 32 | def data_transform(config, X): 33 | if config.data.uniform_dequantization: 34 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 35 | if config.data.gaussian_dequantization: 36 | X = X + torch.randn_like(X) * 0.01 37 | 38 | if config.data.rescaled: 39 | X = 2 * X - 1.0 40 | elif config.data.logit_transform: 41 | X = logit_transform(X) 42 | 43 | if hasattr(config, "image_mean"): 44 | return X - config.image_mean.to(X.device)[None, ...] 45 | 46 | return X 47 | 48 | 49 | def inverse_data_transform(config, X): 50 | if hasattr(config, "image_mean"): 51 | X = X + config.image_mean.to(X.device)[None, ...] 
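# The branches below undo data_transform(): torch.sigmoid inverts the logit transform,
# and (X + 1) / 2 inverts the [-1, 1] rescaling, before the final clamp to [0, 1].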
52 | 53 | if config.data.logit_transform: 54 | X = torch.sigmoid(X) 55 | elif config.data.rescaled: 56 | X = (X + 1.0) / 2.0 57 | 58 | return torch.clamp(X, 0.0, 1.0) 59 | -------------------------------------------------------------------------------- /datasets/pmub.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import lmdb 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import torch 6 | import random 7 | import matplotlib.pyplot as plt 8 | from .sr_util import get_valid_paths_from_images, get_valid_paths_from_test_images, transform_augment 9 | 10 | 11 | class PMUB(Dataset): 12 | def __init__(self, dataroot, img_size, split='train', data_len=-1): 13 | self.img_size = img_size 14 | self.data_len = data_len 15 | self.split = split 16 | 17 | self.img_path = get_valid_paths_from_images(dataroot) 18 | self.test_img_path = get_valid_paths_from_test_images(dataroot) 19 | 20 | if self.split == 'test': 21 | self.dataset_len = len(self.test_img_path) 22 | else: 23 | self.dataset_len = len(self.img_path) 24 | 25 | if self.data_len <= 0: 26 | self.data_len = self.dataset_len 27 | else: 28 | self.data_len = min(self.data_len, self.dataset_len) 29 | 30 | def __len__(self): 31 | return self.data_len 32 | 33 | def __getitem__(self, index): 34 | img_FW = None 35 | img_MD = None 36 | img_BW = None 37 | base_name = None 38 | extension = None 39 | number = None 40 | FW_path = None 41 | BW_path = None 42 | 43 | base_name = self.img_path[index].split('_')[0] 44 | case_name = int(base_name.split('/')[-1].split('-')[-1]) 45 | extension = self.img_path[index].split('_')[-1].split('.')[-1] 46 | number = int(self.img_path[index].split('_')[-1].split('.')[0]) 47 | FW_path = base_name + '_' + str(number+1) + '.' + extension 48 | BW_path = base_name + '_' + str(number-1) + '.' 
+ extension
49 | 
50 |         img_BW = Image.open(BW_path).convert("L")
51 |         img_MD = Image.open(self.img_path[index]).convert("L")
52 |         img_FW = Image.open(FW_path).convert("L")
53 | 
54 |         img_BW = img_BW.resize((self.img_size, self.img_size))
55 |         img_MD = img_MD.resize((self.img_size, self.img_size))
56 |         img_FW = img_FW.resize((self.img_size, self.img_size))
57 | 
58 |         [img_BW, img_MD, img_FW] = transform_augment(
59 |             [img_BW, img_MD, img_FW], split=self.split, min_max=(-1, 1))
60 | 
61 |         return {'BW': img_BW, 'MD': img_MD, 'FW': img_FW, 'Index': index, 'case_name': case_name}
--------------------------------------------------------------------------------
/datasets/sr_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import random
5 | import numpy as np
6 | import glob
7 | 
8 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
9 |                   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
10 | 
11 | 
12 | def is_image_file(filename):
13 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14 | 
15 | 
16 | def extract_number(filename):
17 |     number = int(filename.split('_')[1].split('.')[0])
18 |     return number
19 | 
20 | # LDFDCT
21 | def get_paths_from_images(path):
22 |     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
23 | 
24 |     ld_images = glob.glob(path + "**/**/*ld.png", recursive=True)
25 |     fd_images = glob.glob(path + "**/**/*fd.png", recursive=True)
26 | 
27 |     assert ld_images, '{:s} has no valid ld image file'.format(path)
28 |     assert fd_images, '{:s} has no valid fd image file'.format(path)
29 |     assert len(ld_images) == len(fd_images), 'Low Dose images and Full Dose images are not paired!'
30 |     return sorted(ld_images), sorted(fd_images)
31 | 
32 | # Single SR
33 | def get_paths_from_single_sr_images(path):
34 |     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
35 | 
36 |     lr_images = glob.glob(path + "**/**/*lr.png", recursive=True)
37 |     hr_images = glob.glob(path + "**/**/*hr.png", recursive=True)
38 | 
39 |     assert lr_images, '{:s} has no valid lr image file'.format(path)
40 |     assert hr_images, '{:s} has no valid hr image file'.format(path)
41 |     assert len(lr_images) == len(hr_images), 'LR images and HR images are not paired!'
42 |     return sorted(lr_images), sorted(hr_images)
43 | 
44 | 
45 | def get_paths_from_npys(path_data, path_gt):
46 |     assert os.path.isdir(path_data), '{:s} is not a valid directory'.format(path_data)
47 |     assert os.path.isdir(path_gt), '{:s} is not a valid directory'.format(path_gt)
48 | 
49 |     data_npy = glob.glob(path_data + "*.npy")
50 |     gt_npy = glob.glob(path_gt + "*.npy")
51 | 
52 |     assert data_npy, '{:s} has no valid data npy file'.format(path_data)
53 |     assert gt_npy, '{:s} has no valid GT npy file'.format(path_gt)
54 |     assert len(data_npy) == len(gt_npy), 'Data npy files and GT npy files are not paired!'
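# Pairing is positional: sorting both lists aligns each data file with its ground
# truth by index, which assumes matching filenames across the two directories.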
55 | return sorted(data_npy), sorted(gt_npy) 56 | 57 | 58 | # Delete head and tail for train and val 59 | def get_valid_paths_from_images(path): 60 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 61 | images = [] 62 | 63 | for dirpath, folder_path, fnames in sorted(os.walk(path)): 64 | 65 | filtered_fnames = [fname for fname in fnames if fname.endswith('.png') and not fname.startswith('.')] 66 | fnames = filtered_fnames 67 | 68 | fnames = sorted(fnames, key=extract_number) 69 | new_fnames = fnames[1:-1] 70 | 71 | for fname in new_fnames: 72 | if is_image_file(fname): 73 | img_path = os.path.join(dirpath, fname) 74 | images.append(img_path) 75 | 76 | assert images, '{:s} has no valid image file'.format(path) 77 | return images 78 | 79 | 80 | # Delete tail for test 81 | def get_valid_paths_from_test_images(path): 82 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 83 | images = [] 84 | 85 | for dirpath, _, fnames in sorted(os.walk(path)): 86 | filtered_fnames = [fname for fname in fnames if not fname.startswith('.')] 87 | fnames = filtered_fnames 88 | 89 | fnames = sorted(fnames, key=extract_number) 90 | new_fnames = fnames[:-1] 91 | 92 | for fname in new_fnames: 93 | if is_image_file(fname): 94 | img_path = os.path.join(dirpath, fname) 95 | images.append(img_path) 96 | 97 | assert images, '{:s} has no valid image file'.format(path) 98 | return images 99 | 100 | 101 | def augment(img_list, hflip=True, rot=True, split='val'): 102 | # horizontal flip OR rotate 103 | hflip = hflip and (split == 'train' and random.random() < 0.5) 104 | vflip = rot and (split == 'train' and random.random() < 0.5) 105 | rot90 = rot and (split == 'train' and random.random() < 0.5) 106 | 107 | def _augment(img): 108 | if hflip: 109 | img = img[:, ::-1, :] 110 | if vflip: 111 | img = img[::-1, :, :] 112 | if rot90: 113 | img = img.transpose(1, 0, 2) 114 | return img 115 | 116 | return [_augment(img) for img in img_list] 117 | 118 | 119 | def transform2numpy(img): 120 | img = np.array(img) 121 | img = img.astype(np.float32) / 255. 122 | if img.ndim == 2: 123 | img = np.expand_dims(img, axis=2) 124 | # some images have 4 channels 125 | if img.shape[2] > 3: 126 | img = img[:, :, :3] 127 | return img 128 | 129 | 130 | def transform2tensor(img, min_max=(0, 1)): 131 | # HWC to CHW 132 | img = torch.from_numpy(np.ascontiguousarray( 133 | np.transpose(img, (2, 0, 1)))).float() 134 | # to range min_max 135 | img = img*(min_max[1] - min_max[0]) + min_max[0] 136 | return img 137 | 138 | 139 | totensor = torchvision.transforms.ToTensor() 140 | hflip = torchvision.transforms.RandomHorizontalFlip() 141 | Resize = torchvision.transforms.Resize((224, 224), antialias=True) 142 | def transform_augment(img_list, split='val', min_max=(0, 1)): 143 | imgs = [totensor(img) for img in img_list] 144 | if split == 'train': 145 | imgs = torch.stack(imgs, 0) 146 | imgs = hflip(imgs) 147 | imgs = torch.unbind(imgs, dim=0) 148 | 149 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs] 150 | return ret_img 151 | 152 | 153 | def brats_transform_augment(img_list, split='val'): 154 | imgs = [totensor(img) for img in img_list] 155 | # imgs = [Resize(img) for img in imgs_tlist] 156 | # if split == 'train': 157 | # imgs = torch.stack(imgs, 0) 158 | # imgs = hflip(imgs) 159 | # imgs = torch.unbind(imgs, dim=0) 160 | ret_img = [img.clamp(-1., 1.) 
for img in imgs] 161 | 162 | return ret_img 163 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 
116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /ddpm_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import shutil 4 | import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, default="pmub_linear.yml", help="Path to the config file" 22 | ) 23 | parser.add_argument( 24 | "--dataset", type=str, default="PMUB", help="Name of dataset(LDFDCT, BRATS, PMUB)" 25 | ) 26 | parser.add_argument("--seed", type=int, default=1244, help="Random seed") 27 | parser.add_argument( 28 | "--exp", type=str, default="exp", help="Path for saving running related data." 
29 | ) 30 | parser.add_argument( 31 | "--doc", 32 | type=str, 33 | default="DDPM_experiments", 34 | help="A string for documentation purpose. " 35 | "Will be the name of the log folder.", 36 | ) 37 | parser.add_argument( 38 | "--comment", type=str, default="", help="A string for experiment comment" 39 | ) 40 | parser.add_argument( 41 | "--verbose", 42 | type=str, 43 | default="info", 44 | help="Verbose level: info | debug | warning | critical", 45 | ) 46 | parser.add_argument("--test", action="store_true", help="Whether to test the model") 47 | parser.add_argument( 48 | "--sample", 49 | action="store_true", 50 | help="Whether to produce samples from the model", 51 | ) 52 | parser.add_argument("--fid", action="store_true") 53 | parser.add_argument("--interpolation", action="store_true") 54 | parser.add_argument( 55 | "--resume_training", action="store_true", help="Whether to resume training" 56 | ) 57 | parser.add_argument( 58 | "-i", 59 | "--image_folder", 60 | type=str, 61 | default="images", 62 | help="The folder name of samples", 63 | ) 64 | parser.add_argument( 65 | "--ni", 66 | action="store_false", 67 | help="No interaction. Suitable for Slurm Job launcher", 68 | ) 69 | parser.add_argument("--use_pretrained", action="store_true") 70 | parser.add_argument( 71 | "--sample_type", 72 | type=str, 73 | default="ddpm_noisy", 74 | help="sampling approach (generalized or ddpm_noisy)", 75 | ) 76 | parser.add_argument( 77 | "--timesteps", type=int, default=1000, help="number of steps involved" 78 | ) 79 | parser.add_argument( 80 | "--eta", 81 | type=float, 82 | default=0.0, 83 | help="eta used to control the variances of sigma", 84 | ) 85 | parser.add_argument("--sequence", action="store_true") 86 | 87 | args = parser.parse_args() 88 | args.log_path = os.path.join(args.exp, "logs", args.doc) 89 | 90 | # parse config file 91 | with open(os.path.join("configs", args.config), "r") as f: 92 | config = yaml.safe_load(f) 93 | new_config = dict2namespace(config) 94 | 95 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 96 | 97 | # No test No sampling No resume training 98 | if not args.test and not args.sample: 99 | if not args.resume_training: 100 | if os.path.exists(args.log_path): 101 | overwrite = False 102 | if args.ni: 103 | overwrite = True 104 | else: 105 | response = input("Folder already exists. Overwrite? (Y/N)") 106 | if response.upper() == "Y": 107 | overwrite = True 108 | 109 | if overwrite: 110 | shutil.rmtree(args.log_path) 111 | shutil.rmtree(tb_path) 112 | os.makedirs(args.log_path) 113 | if os.path.exists(tb_path): 114 | shutil.rmtree(tb_path) 115 | else: 116 | print("Folder exists. 
Program halted.") 117 | sys.exit(0) 118 | else: 119 | os.makedirs(args.log_path) 120 | 121 | with open(os.path.join(args.log_path, "config.yml"), "w") as f: 122 | yaml.dump(new_config, f, default_flow_style=False) 123 | 124 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path) 125 | # setup logger 126 | level = getattr(logging, args.verbose.upper(), None) 127 | if not isinstance(level, int): 128 | raise ValueError("level {} not supported".format(args.verbose)) 129 | 130 | handler1 = logging.StreamHandler() 131 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) 132 | formatter = logging.Formatter( 133 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 134 | ) 135 | handler1.setFormatter(formatter) 136 | handler2.setFormatter(formatter) 137 | logger = logging.getLogger() 138 | logger.addHandler(handler1) 139 | logger.addHandler(handler2) 140 | logger.setLevel(level) 141 | 142 | else: 143 | level = getattr(logging, args.verbose.upper(), None) 144 | if not isinstance(level, int): 145 | raise ValueError("level {} not supported".format(args.verbose)) 146 | 147 | handler1 = logging.StreamHandler() 148 | formatter = logging.Formatter( 149 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 150 | ) 151 | handler1.setFormatter(formatter) 152 | logger = logging.getLogger() 153 | logger.addHandler(handler1) 154 | logger.setLevel(level) 155 | 156 | # Sample from the model 157 | if args.sample: 158 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 159 | if args.fid: 160 | args.image_folder = os.path.join( 161 | args.exp, "image_samples", args.doc, "images_fid") 162 | if args.interpolation: 163 | args.image_folder = os.path.join( 164 | args.exp, "image_samples", args.doc, "images_interpolation") 165 | 166 | if not os.path.exists(args.image_folder): 167 | os.makedirs(args.image_folder) 168 | else: 169 | if not (args.fid or args.interpolation): 170 | overwrite = False 171 | if args.ni: 172 | overwrite = True 173 | else: 174 | response = input( 175 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 176 | ) 177 | if response.upper() == "Y": 178 | overwrite = True 179 | 180 | if overwrite: 181 | shutil.rmtree(args.image_folder) 182 | os.makedirs(args.image_folder) 183 | else: 184 | print("Output image folder exists. 
Program halted.") 185 | sys.exit(0) 186 | 187 | # add device 188 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 189 | logging.info("Using device: {}".format(device)) 190 | new_config.device = device 191 | 192 | # set random seed 193 | torch.manual_seed(args.seed) 194 | np.random.seed(args.seed) 195 | if torch.cuda.is_available(): 196 | torch.cuda.manual_seed_all(args.seed) 197 | 198 | torch.backends.cudnn.benchmark = True 199 | 200 | return args, new_config 201 | 202 | 203 | def dict2namespace(config): 204 | namespace = argparse.Namespace() 205 | for key, value in config.items(): 206 | if isinstance(value, dict): 207 | new_value = dict2namespace(value) 208 | else: 209 | new_value = value 210 | setattr(namespace, key, new_value) 211 | return namespace 212 | 213 | 214 | def main(): 215 | args, config = parse_args_and_config() 216 | logging.info("Writing log file to {}".format(args.log_path)) 217 | logging.info("Exp instance id = {}".format(os.getpid())) 218 | logging.info("Exp comment = {}".format(args.comment)) 219 | 220 | try: 221 | runner = Diffusion(args, config) 222 | if args.sample: 223 | if args.dataset=='PMUB': 224 | runner.sr_sample() 225 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS': 226 | runner.sg_sample() 227 | else: 228 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as sampling dataset. Feel free to add your own.") 229 | elif args.test: 230 | runner.test() 231 | else: 232 | if args.dataset=='PMUB': 233 | runner.sr_ddpm_train() 234 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS': 235 | runner.sg_ddpm_train() 236 | else: 237 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as training dataset. Feel free to add your own.") 238 | except Exception: 239 | logging.error(traceback.format_exc()) 240 | 241 | return 0 242 | 243 | 244 | if __name__ == "__main__": 245 | sys.exit(main()) 246 | -------------------------------------------------------------------------------- /fast_ddpm_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import shutil 4 | import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, default="pmub_linear.yml", help="Path to the config file" 22 | ) 23 | parser.add_argument( 24 | "--dataset", type=str, default="PMUB", help="Name of dataset(LDFDCT, BRATS, PMUB)" 25 | ) 26 | parser.add_argument("--seed", type=int, default=1244, help="Random seed") 27 | parser.add_argument( 28 | "--exp", type=str, default="exp", help="Path for saving running related data." 29 | ) 30 | parser.add_argument( 31 | "--doc", 32 | type=str, 33 | default="Fast-DDPM_experiments", 34 | help="A string for documentation purpose. 
" 35 | "Will be the name of the log folder.", 36 | ) 37 | parser.add_argument( 38 | "--comment", type=str, default="", help="A string for experiment comment" 39 | ) 40 | parser.add_argument( 41 | "--verbose", 42 | type=str, 43 | default="info", 44 | help="Verbose level: info | debug | warning | critical", 45 | ) 46 | parser.add_argument("--test", action="store_true", help="Whether to test the model") 47 | parser.add_argument( 48 | "--sample", 49 | action="store_true", 50 | help="Whether to produce samples from the model", 51 | ) 52 | parser.add_argument("--fid", action="store_true") 53 | parser.add_argument("--interpolation", action="store_true") 54 | parser.add_argument( 55 | "--resume_training", action="store_true", help="Whether to resume training" 56 | ) 57 | parser.add_argument( 58 | "-i", 59 | "--image_folder", 60 | type=str, 61 | default="images", 62 | help="The folder name of samples", 63 | ) 64 | parser.add_argument( 65 | "--ni", 66 | action="store_false", 67 | help="No interaction. Suitable for Slurm Job launcher", 68 | ) 69 | parser.add_argument("--use_pretrained", action="store_true") 70 | parser.add_argument( 71 | "--sample_type", 72 | type=str, 73 | default="generalized", 74 | help="sampling approach (generalized or ddpm_noisy)", 75 | ) 76 | parser.add_argument( 77 | "--scheduler_type", 78 | type=str, 79 | default="uniform", 80 | help="sample involved time steps according to (uniform or non-uniform)", 81 | ) 82 | parser.add_argument( 83 | "--timesteps", type=int, default=100, help="number of steps involved" 84 | ) 85 | parser.add_argument( 86 | "--eta", 87 | type=float, 88 | default=0.0, 89 | help="eta used to control the variances of sigma", 90 | ) 91 | parser.add_argument("--sequence", action="store_true") 92 | 93 | args = parser.parse_args() 94 | args.log_path = os.path.join(args.exp, "logs", args.doc) 95 | 96 | # parse config file 97 | with open(os.path.join("configs", args.config), "r") as f: 98 | config = yaml.safe_load(f) 99 | new_config = dict2namespace(config) 100 | 101 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 102 | 103 | # No test No sampling No resume training 104 | if not args.test and not args.sample: 105 | if not args.resume_training: 106 | if os.path.exists(args.log_path): 107 | overwrite = False 108 | if args.ni: 109 | overwrite = True 110 | else: 111 | response = input("Folder already exists. Overwrite? (Y/N)") 112 | if response.upper() == "Y": 113 | overwrite = True 114 | 115 | if overwrite: 116 | shutil.rmtree(args.log_path) 117 | shutil.rmtree(tb_path) 118 | os.makedirs(args.log_path) 119 | if os.path.exists(tb_path): 120 | shutil.rmtree(tb_path) 121 | else: 122 | print("Folder exists. 
Program halted.") 123 | sys.exit(0) 124 | else: 125 | os.makedirs(args.log_path) 126 | 127 | with open(os.path.join(args.log_path, "config.yml"), "w") as f: 128 | yaml.dump(new_config, f, default_flow_style=False) 129 | 130 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path) 131 | # setup logger 132 | level = getattr(logging, args.verbose.upper(), None) 133 | if not isinstance(level, int): 134 | raise ValueError("level {} not supported".format(args.verbose)) 135 | 136 | handler1 = logging.StreamHandler() 137 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) 138 | formatter = logging.Formatter( 139 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 140 | ) 141 | handler1.setFormatter(formatter) 142 | handler2.setFormatter(formatter) 143 | logger = logging.getLogger() 144 | logger.addHandler(handler1) 145 | logger.addHandler(handler2) 146 | logger.setLevel(level) 147 | 148 | else: 149 | level = getattr(logging, args.verbose.upper(), None) 150 | if not isinstance(level, int): 151 | raise ValueError("level {} not supported".format(args.verbose)) 152 | 153 | handler1 = logging.StreamHandler() 154 | formatter = logging.Formatter( 155 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 156 | ) 157 | handler1.setFormatter(formatter) 158 | logger = logging.getLogger() 159 | logger.addHandler(handler1) 160 | logger.setLevel(level) 161 | 162 | # Sample from the model 163 | if args.sample: 164 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 165 | if args.fid: 166 | args.image_folder = os.path.join( 167 | args.exp, "image_samples", args.doc, "images_fid") 168 | if args.interpolation: 169 | args.image_folder = os.path.join( 170 | args.exp, "image_samples", args.doc, "images_interpolation") 171 | 172 | if not os.path.exists(args.image_folder): 173 | os.makedirs(args.image_folder) 174 | else: 175 | if not (args.fid or args.interpolation): 176 | overwrite = False 177 | if args.ni: 178 | overwrite = True 179 | else: 180 | response = input( 181 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 182 | ) 183 | if response.upper() == "Y": 184 | overwrite = True 185 | 186 | if overwrite: 187 | shutil.rmtree(args.image_folder) 188 | os.makedirs(args.image_folder) 189 | else: 190 | print("Output image folder exists. 
Program halted.") 191 | sys.exit(0) 192 | 193 | # add device 194 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 195 | logging.info("Using device: {}".format(device)) 196 | new_config.device = device 197 | 198 | # set random seed 199 | torch.manual_seed(args.seed) 200 | np.random.seed(args.seed) 201 | if torch.cuda.is_available(): 202 | torch.cuda.manual_seed_all(args.seed) 203 | 204 | torch.backends.cudnn.benchmark = True 205 | 206 | return args, new_config 207 | 208 | 209 | def dict2namespace(config): 210 | namespace = argparse.Namespace() 211 | for key, value in config.items(): 212 | if isinstance(value, dict): 213 | new_value = dict2namespace(value) 214 | else: 215 | new_value = value 216 | setattr(namespace, key, new_value) 217 | return namespace 218 | 219 | 220 | def main(): 221 | args, config = parse_args_and_config() 222 | logging.info("Writing log file to {}".format(args.log_path)) 223 | logging.info("Exp instance id = {}".format(os.getpid())) 224 | logging.info("Exp comment = {}".format(args.comment)) 225 | 226 | try: 227 | runner = Diffusion(args, config) 228 | if args.sample: 229 | if args.dataset=='PMUB': 230 | runner.sr_sample() 231 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS': 232 | runner.sg_sample() 233 | else: 234 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as sampling dataset. Feel free to add your own.") 235 | elif args.test: 236 | runner.test() 237 | else: 238 | if args.dataset=='PMUB': 239 | runner.sr_train() 240 | elif args.dataset=='LDFDCT' or args.dataset=='BRATS': 241 | runner.sg_train() 242 | else: 243 | raise Exception("This script only supports LDFDCT, BRATS and PMUB as training dataset. Feel free to add your own.") 244 | except Exception: 245 | logging.error(traceback.format_exc()) 246 | 247 | return 0 248 | 249 | 250 | if __name__ == "__main__": 251 | sys.exit(main()) 252 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError( 15 | 'Optimizer {} not understood.'.format(config.optim.optimizer)) 16 | -------------------------------------------------------------------------------- /functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": 
"https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("/atlas/u/tsong/.cache")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /functions/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_alpha(beta, t): 5 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 6 | # [1, alphas_cumprod] 7 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 8 | return a 9 | 10 | 11 | def generalized_steps(x, seq, model, b, **kwargs): 12 | with torch.no_grad(): 13 | n = x.size(0) 14 | seq_next = [-1] + list(seq[:-1]) 15 | x0_preds = [] 16 | xs = [x] 17 | for i, j in zip(reversed(seq), reversed(seq_next)): 18 | t = (torch.ones(n) * i).to(x.device) 19 | next_t = (torch.ones(n) * j).to(x.device) 20 | at = compute_alpha(b, t.long()) 21 | at_next = compute_alpha(b, next_t.long()) 22 | xt = xs[-1].to('cuda') 23 | et = 
model(xt, t) 24 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 25 | x0_preds.append(x0_t.to('cpu')) 26 | # Equation (12) 27 | c1 = ( 28 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 29 | ) 30 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 31 | 32 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 33 | xs.append(xt_next.to('cpu')) 34 | 35 | return xs, x0_preds 36 | 37 | 38 | def ddpm_steps(x, seq, model, b, **kwargs): 39 | with torch.no_grad(): 40 | n = x.size(0) 41 | seq_next = [-1] + list(seq[:-1]) 42 | xs = [x] 43 | x0_preds = [] 44 | betas = b 45 | for i, j in zip(reversed(seq), reversed(seq_next)): 46 | t = (torch.ones(n) * i).to(x.device) 47 | next_t = (torch.ones(n) * j).to(x.device) 48 | at = compute_alpha(betas, t.long()) 49 | atm1 = compute_alpha(betas, next_t.long()) 50 | beta_t = 1 - at / atm1 51 | x = xs[-1].to('cuda') 52 | 53 | output = model(x, t.float()) 54 | e = output 55 | 56 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 57 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 58 | x0_preds.append(x0_from_e.to('cpu')) 59 | mean_eps = ( 60 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 61 | ) / (1.0 - at) 62 | 63 | mean = mean_eps 64 | noise = torch.randn_like(x) 65 | mask = 1 - (t == 0).float() 66 | mask = mask.view(-1, 1, 1, 1) 67 | logvar = beta_t.log() 68 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 69 | xs.append(sample.to('cpu')) 70 | return xs, x0_preds 71 | 72 | 73 | def sr_generalized_steps(x, x_bw, x_fw, seq, model, b, **kwargs): 74 | with torch.no_grad(): 75 | n = x.size(0) 76 | seq_next = [-1] + list(seq[:-1]) 77 | x0_preds = [] 78 | xs = [x] 79 | 80 | for i, j in zip(reversed(seq), reversed(seq_next)): 81 | t = (torch.ones(n) * i).to(x.device) 82 | next_t = (torch.ones(n) * j).to(x.device) 83 | at = compute_alpha(b, t.long()) 84 | at_next = compute_alpha(b, next_t.long()) 85 | xt = xs[-1].to('cuda') 86 | et = model(torch.cat([x_bw, x_fw, xt], dim=1), t) 87 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 88 | x0_preds.append(x0_t.to('cpu')) 89 | # Equation (12) 90 | c1 = ( 91 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 92 | ) 93 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 94 | 95 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 96 | xs.append(xt_next.to('cpu')) 97 | 98 | return xs, x0_preds 99 | 100 | 101 | def sr_ddpm_steps(x, x_bw, x_fw, seq, model, b, **kwargs): 102 | with torch.no_grad(): 103 | n = x.size(0) 104 | seq_next = [-1] + list(seq[:-1]) 105 | xs = [x] 106 | x0_preds = [] 107 | betas = b 108 | for i, j in zip(reversed(seq), reversed(seq_next)): 109 | t = (torch.ones(n) * i).to(x.device) 110 | next_t = (torch.ones(n) * j).to(x.device) 111 | at = compute_alpha(betas, t.long()) 112 | atm1 = compute_alpha(betas, next_t.long()) 113 | beta_t = 1 - at / atm1 114 | x = xs[-1].to('cuda') 115 | 116 | output = model(torch.cat([x_bw, x_fw, x], dim=1), t.float()) 117 | e = output 118 | 119 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 120 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 121 | x0_preds.append(x0_from_e.to('cpu')) 122 | mean_eps = ( 123 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 124 | ) / (1.0 - at) 125 | 126 | mean = mean_eps 127 | noise = torch.randn_like(x) 128 | mask = 1 - (t == 0).float() 129 | mask = mask.view(-1, 1, 1, 1) 130 | logvar = beta_t.log() 131 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 132 | 
xs.append(sample.to('cpu')) 133 | return xs, x0_preds 134 | 135 | 136 | def sg_generalized_steps(x, x_img, seq, model, b, **kwargs): 137 | with torch.no_grad(): 138 | n = x.size(0) 139 | seq_next = [-1] + list(seq[:-1]) 140 | x0_preds = [] 141 | xs = [x] 142 | 143 | for i, j in zip(reversed(seq), reversed(seq_next)): 144 | t = (torch.ones(n) * i).to(x.device) 145 | next_t = (torch.ones(n) * j).to(x.device) 146 | at = compute_alpha(b, t.long()) 147 | at_next = compute_alpha(b, next_t.long()) 148 | xt = xs[-1].to('cuda') 149 | et = model(torch.cat([x_img, xt], dim=1), t) 150 | 151 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 152 | x0_preds.append(x0_t.to('cpu')) 153 | # Equation (12) 154 | c1 = ( 155 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 156 | ) 157 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 158 | 159 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 160 | xs.append(xt_next.to('cpu')) 161 | 162 | return xs, x0_preds 163 | 164 | 165 | def sg_ddpm_steps(x, x_img, seq, model, b, **kwargs): 166 | with torch.no_grad(): 167 | n = x.size(0) 168 | seq_next = [-1] + list(seq[:-1]) 169 | xs = [x] 170 | x0_preds = [] 171 | betas = b 172 | for i, j in zip(reversed(seq), reversed(seq_next)): 173 | t = (torch.ones(n) * i).to(x.device) 174 | next_t = (torch.ones(n) * j).to(x.device) 175 | at = compute_alpha(betas, t.long()) 176 | atm1 = compute_alpha(betas, next_t.long()) 177 | beta_t = 1 - at / atm1 178 | x = xs[-1].to('cuda') 179 | 180 | output = model(torch.cat([x_img, x], dim=1), t.float()) 181 | e = output 182 | 183 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 184 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 185 | x0_preds.append(x0_from_e.to('cpu')) 186 | mean_eps = ( 187 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 188 | ) / (1.0 - at) 189 | 190 | mean = mean_eps 191 | noise = torch.randn_like(x) 192 | mask = 1 - (t == 0).float() 193 | mask = mask.view(-1, 1, 1, 1) 194 | logvar = beta_t.log() 195 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 196 | xs.append(sample.to('cpu')) 197 | return xs, x0_preds 198 | -------------------------------------------------------------------------------- /functions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | from medpy import metric 5 | import numpy as np 6 | np.bool = np.bool_ 7 | 8 | 9 | def calculate_psnr(img1, img2): 10 | # img1: img 11 | # img2: gt 12 | # img1 and img2 have range [0, 255] 13 | img1 = img1.astype(np.float64) 14 | img2 = img2.astype(np.float64) 15 | 16 | mse = np.mean((img1 - img2)**2) 17 | psnr = 20 * math.log10(255.0 / math.sqrt(mse)) 18 | 19 | return psnr 20 | 21 | 22 | def noise_estimation_loss(model, 23 | x0: torch.Tensor, 24 | t: torch.LongTensor, 25 | e: torch.Tensor, 26 | b: torch.Tensor, keepdim=False): 27 | # a: a_T in DDIM 28 | # 1-a: 1-a_T in DDIM 29 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 30 | # X_T 31 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt() 32 | output = model(x, t.float()) 33 | if keepdim: 34 | return (e - output).square().sum(dim=(1, 2, 3)) 35 | else: 36 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 37 | 38 | 39 | 40 | def sr_noise_estimation_loss(model, 41 | x_bw: torch.Tensor, 42 | x_md: torch.Tensor, 43 | x_fw: torch.Tensor, 44 | t: torch.LongTensor, 45 | e: torch.Tensor, 46 | b: torch.Tensor, keepdim=False): 47 | # a: a_T in DDIM 48 | # 1-a: 1-a_T in DDIM 49 | a = 
(1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 50 | # X_T 51 | x = x_md * a.sqrt() + e * (1.0 - a).sqrt() 52 | 53 | output = model(torch.cat([x_bw, x_fw, x], dim=1), t.float()) 54 | if keepdim: 55 | return (e - output).square().sum(dim=(1, 2, 3)) 56 | else: 57 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 58 | 59 | 60 | 61 | def sg_noise_estimation_loss(model, 62 | x_img: torch.Tensor, 63 | x_gt: torch.Tensor, 64 | t: torch.LongTensor, 65 | e: torch.Tensor, 66 | b: torch.Tensor, keepdim=False): 67 | # a: a_T in DDIM 68 | # 1-a: 1-a_T in DDIM 69 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 70 | # X_T 71 | x = x_gt * a.sqrt() + e * (1.0 - a).sqrt() 72 | output = model(torch.cat([x_img, x], dim=1), t.float()) 73 | 74 | if keepdim: 75 | return (e - output).square().sum(dim=(1, 2, 3)) 76 | else: 77 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 78 | 79 | 80 | loss_registry = { 81 | 'simple': noise_estimation_loss, 82 | 'sr': sr_noise_estimation_loss, 83 | 'sg': sg_noise_estimation_loss 84 | } -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x*torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate( 49 | x, scale_factor=2.0, mode="nearest") 50 | if self.with_conv: 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0, 1, 0, 1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | 77 | class ResnetBlock(nn.Module): 78 | def __init__(self, *, in_channels, out_channels=None, 
conv_shortcut=False, 79 | dropout, temb_channels=512): 80 | super().__init__() 81 | self.in_channels = in_channels 82 | out_channels = in_channels if out_channels is None else out_channels 83 | self.out_channels = out_channels 84 | self.use_conv_shortcut = conv_shortcut 85 | 86 | self.norm1 = Normalize(in_channels) 87 | self.conv1 = torch.nn.Conv2d(in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x+h 135 | 136 | 137 | class AttnBlock(nn.Module): 138 | def __init__(self, in_channels): 139 | super().__init__() 140 | self.in_channels = in_channels 141 | 142 | self.norm = Normalize(in_channels) 143 | self.q = torch.nn.Conv2d(in_channels, 144 | in_channels, 145 | kernel_size=1, 146 | stride=1, 147 | padding=0) 148 | self.k = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.v = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.proj_out = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b, c, h, w = q.shape 173 | q = q.reshape(b, c, h*w) 174 | q = q.permute(0, 2, 1) # b,hw,c 175 | k = k.reshape(b, c, h*w) # b,c,hw 176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c)**(-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b, c, h*w) 182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = torch.bmm(v, w_) 185 | h_ = h_.reshape(b, c, h, w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x+h_ 190 | 191 | 192 | class Model(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.config = config 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 197 | num_res_blocks = config.model.num_res_blocks 198 | attn_resolutions = config.model.attn_resolutions 199 | dropout = config.model.dropout 200 | in_channels = config.model.in_channels 201 | resolution = config.data.image_size 202 | resamp_with_conv = config.model.resamp_with_conv 203 | num_timesteps = 
config.diffusion.num_diffusion_timesteps 204 | 205 | if config.model.type == 'bayesian': 206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps)) 207 | 208 | self.ch = ch 209 | self.temb_ch = self.ch*4 210 | self.num_resolutions = len(ch_mult) 211 | self.num_res_blocks = num_res_blocks 212 | self.resolution = resolution 213 | self.in_channels = in_channels 214 | 215 | # timestep embedding 216 | self.temb = nn.Module() 217 | self.temb.dense = nn.ModuleList([ 218 | torch.nn.Linear(self.ch, 219 | self.temb_ch), 220 | torch.nn.Linear(self.temb_ch, 221 | self.temb_ch), 222 | ]) 223 | 224 | # downsampling 225 | self.conv_in = torch.nn.Conv2d(in_channels, 226 | self.ch, 227 | kernel_size=3, 228 | stride=1, 229 | padding=1) 230 | 231 | curr_res = resolution 232 | in_ch_mult = (1,)+ch_mult 233 | self.down = nn.ModuleList() 234 | block_in = None 235 | for i_level in range(self.num_resolutions): 236 | block = nn.ModuleList() 237 | attn = nn.ModuleList() 238 | block_in = ch*in_ch_mult[i_level] 239 | block_out = ch*ch_mult[i_level] 240 | for i_block in range(self.num_res_blocks): 241 | block.append(ResnetBlock(in_channels=block_in, 242 | out_channels=block_out, 243 | temb_channels=self.temb_ch, 244 | dropout=dropout)) 245 | block_in = block_out 246 | if curr_res in attn_resolutions: 247 | attn.append(AttnBlock(block_in)) 248 | down = nn.Module() 249 | down.block = block 250 | down.attn = attn 251 | if i_level != self.num_resolutions-1: 252 | down.downsample = Downsample(block_in, resamp_with_conv) 253 | curr_res = curr_res // 2 254 | self.down.append(down) 255 | 256 | # middle 257 | self.mid = nn.Module() 258 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 259 | out_channels=block_in, 260 | temb_channels=self.temb_ch, 261 | dropout=dropout) 262 | self.mid.attn_1 = AttnBlock(block_in) 263 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 264 | out_channels=block_in, 265 | temb_channels=self.temb_ch, 266 | dropout=dropout) 267 | 268 | # upsampling 269 | self.up = nn.ModuleList() 270 | for i_level in reversed(range(self.num_resolutions)): 271 | block = nn.ModuleList() 272 | attn = nn.ModuleList() 273 | block_out = ch*ch_mult[i_level] 274 | skip_in = ch*ch_mult[i_level] 275 | for i_block in range(self.num_res_blocks+1): 276 | if i_block == self.num_res_blocks: 277 | skip_in = ch*in_ch_mult[i_level] 278 | block.append(ResnetBlock(in_channels=block_in+skip_in, 279 | out_channels=block_out, 280 | temb_channels=self.temb_ch, 281 | dropout=dropout)) 282 | block_in = block_out 283 | if curr_res in attn_resolutions: 284 | attn.append(AttnBlock(block_in)) 285 | up = nn.Module() 286 | up.block = block 287 | up.attn = attn 288 | if i_level != 0: 289 | up.upsample = Upsample(block_in, resamp_with_conv) 290 | curr_res = curr_res * 2 291 | self.up.insert(0, up) # prepend to get consistent order 292 | 293 | # end 294 | self.norm_out = Normalize(block_in) 295 | self.conv_out = torch.nn.Conv2d(block_in, 296 | out_ch, 297 | kernel_size=3, 298 | stride=1, 299 | padding=1) 300 | 301 | def forward(self, x, t): 302 | assert x.shape[2] == x.shape[3] == self.resolution 303 | 304 | # timestep embedding 305 | temb = get_timestep_embedding(t, self.ch) 306 | temb = self.temb.dense[0](temb) 307 | temb = nonlinearity(temb) 308 | temb = self.temb.dense[1](temb) 309 | 310 | # downsampling 311 | hs = [self.conv_in(x)] 312 | for i_level in range(self.num_resolutions): 313 | for i_block in range(self.num_res_blocks): 314 | h = self.down[i_level].block[i_block](hs[-1], temb) 315 | if len(self.down[i_level].attn) > 0: 316 | h 
= self.down[i_level].attn[i_block](h)
317 | hs.append(h)
318 | if i_level != self.num_resolutions-1:
319 | hs.append(self.down[i_level].downsample(hs[-1]))
320 |
321 | # middle
322 | h = hs[-1]
323 |
324 | h = self.mid.block_1(h, temb)
325 | h = self.mid.attn_1(h)
326 | h = self.mid.block_2(h, temb)
327 |
328 | # upsampling
329 | for i_level in reversed(range(self.num_resolutions)):
330 | for i_block in range(self.num_res_blocks+1):
331 | h = self.up[i_level].block[i_block](
332 | torch.cat([h, hs.pop()], dim=1), temb)
333 | if len(self.up[i_level].attn) > 0:
334 | h = self.up[i_level].attn[i_block](h)
335 | if i_level != 0:
336 | h = self.up[i_level].upsample(h)
337 |
338 | # end
339 | h = self.norm_out(h)
340 | h = nonlinearity(h)
341 | h = self.conv_out(h)
342 | return h
343 |
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | """
4 | A method that increases the stability of a model’s convergence and helps it reach a better overall solution by avoiding poor local minima.
5 | To avoid drastic changes in the model’s weights during training, a shadow copy of the weights is stored when the model is registered.
6 | After each optimization step, the shadow weights are updated to a weighted average of themselves (weight mu, 0.999 by default) and the model’s new weights.
7 | """
8 |
9 |
10 | class EMAHelper(object):
11 | def __init__(self, mu=0.999):
12 | self.mu = mu
13 | self.shadow = {}
14 |
15 | def register(self, module):
16 | if isinstance(module, nn.DataParallel):
17 | module = module.module
18 | for name, param in module.named_parameters():
19 | if param.requires_grad:
20 | self.shadow[name] = param.data.clone()
21 |
22 | def update(self, module):
23 | if isinstance(module, nn.DataParallel):
24 | module = module.module
25 | for name, param in module.named_parameters():
26 | if param.requires_grad:
27 | self.shadow[name].data = (
28 | 1.
- self.mu) * param.data + self.mu * self.shadow[name].data 29 | 30 | def ema(self, module): 31 | if isinstance(module, nn.DataParallel): 32 | module = module.module 33 | for name, param in module.named_parameters(): 34 | if param.requires_grad: 35 | param.data.copy_(self.shadow[name].data) 36 | 37 | def ema_copy(self, module): 38 | if isinstance(module, nn.DataParallel): 39 | inner_module = module.module 40 | module_copy = type(inner_module)( 41 | inner_module.config).to(inner_module.config.device) 42 | module_copy.load_state_dict(inner_module.state_dict()) 43 | module_copy = nn.DataParallel(module_copy) 44 | else: 45 | module_copy = type(module)(module.config).to(module.config.device) 46 | module_copy.load_state_dict(module.state_dict()) 47 | # module_copy = copy.deepcopy(module) 48 | self.ema(module_copy) 49 | return module_copy 50 | 51 | def state_dict(self): 52 | return self.shadow 53 | 54 | def load_state_dict(self, state_dict): 55 | self.shadow = state_dict 56 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/Fast-DDPM/649a14a6093d14f4286a6b6f9963dd208ce07928/runners/__init__.py -------------------------------------------------------------------------------- /runners/diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import glob 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import math 9 | import tqdm 10 | import torch 11 | import torch.utils.data as data 12 | 13 | from models.diffusion import Model 14 | from models.ema import EMAHelper 15 | from functions import get_optimizer 16 | from functions.losses import loss_registry, calculate_psnr 17 | from datasets import data_transform, inverse_data_transform 18 | from datasets.pmub import PMUB 19 | from datasets.LDFDCT import LDFDCT 20 | from datasets.BRATS import BRATS 21 | from functions.ckpt_util import get_ckpt_path 22 | from skimage.metrics import structural_similarity as ssim 23 | import torchvision.utils as tvu 24 | import torchvision 25 | from PIL import Image 26 | 27 | 28 | def torch2hwcuint8(x, clip=False): 29 | if clip: 30 | x = torch.clamp(x, -1, 1) 31 | x = (x + 1.0) / 2.0 32 | return x 33 | 34 | 35 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 36 | def sigmoid(x): 37 | return 1 / (np.exp(x) + 1) 38 | def tanh(x): 39 | return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) 40 | 41 | if beta_schedule == "quad": 42 | betas = ( 43 | np.linspace( 44 | beta_start ** 0.5, 45 | beta_end ** 0.5, 46 | num_diffusion_timesteps, 47 | dtype=np.float64, 48 | ) 49 | ** 2 50 | ) 51 | elif beta_schedule == "linear": 52 | betas = np.linspace( 53 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 54 | ) 55 | elif beta_schedule == "sigmoid": 56 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 57 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 58 | elif beta_schedule =='alpha_cosine': 59 | s = 0.008 60 | timesteps = np.arange(0, num_diffusion_timesteps+1, dtype=np.float64)/num_diffusion_timesteps 61 | alphas_cumprod = np.cos((timesteps + s) / (1 + s) * math.pi * 0.5) ** 2 62 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 63 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 64 | betas = np.clip(betas, a_min=None, a_max=0.999) 65 | elif beta_schedule == 'alpha_sigmoid': 66 
| x = np.linspace(-6, 6, 1001) 67 | alphas_cumprod = sigmoid(x) 68 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 69 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 70 | betas = np.clip(betas, a_min=None, a_max=0.999) 71 | elif beta_schedule == 'alpha_linear': 72 | timesteps = np.arange(0, num_diffusion_timesteps+1, dtype=np.float64)/num_diffusion_timesteps 73 | alphas_cumprod = -timesteps+1 74 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 75 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 76 | betas = np.clip(betas, a_min=None, a_max=0.999) 77 | 78 | else: 79 | raise NotImplementedError(beta_schedule) 80 | assert betas.shape == (num_diffusion_timesteps,) 81 | return betas 82 | 83 | 84 | class Diffusion(object): 85 | def __init__(self, args, config, device=None): 86 | self.args = args 87 | self.config = config 88 | if device is None: 89 | device = ( 90 | torch.device("cuda") 91 | if torch.cuda.is_available() 92 | else torch.device("cpu") 93 | ) 94 | self.device = device 95 | 96 | self.model_var_type = config.model.var_type 97 | betas = get_beta_schedule( 98 | beta_schedule=config.diffusion.beta_schedule, 99 | beta_start=config.diffusion.beta_start, 100 | beta_end=config.diffusion.beta_end, 101 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 102 | ) 103 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 104 | self.num_timesteps = betas.shape[0] 105 | 106 | alphas = 1.0 - betas 107 | alphas_cumprod = alphas.cumprod(dim=0) 108 | alphas_cumprod_prev = torch.cat( 109 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 110 | ) 111 | posterior_variance = ( 112 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 113 | ) 114 | if self.model_var_type == "fixedlarge": 115 | self.logvar = betas.log() 116 | elif self.model_var_type == "fixedsmall": 117 | self.logvar = posterior_variance.clamp(min=1e-20).log() 118 | 119 | 120 | # Training Fast-DDPM for tasks that have only one condition: image translation and CT denoising. 121 | def sg_train(self): 122 | args, config = self.args, self.config 123 | tb_logger = self.config.tb_logger 124 | 125 | if self.args.dataset=='LDFDCT': 126 | # LDFDCT for CT image denoising 127 | dataset = LDFDCT(self.config.data.train_dataroot, self.config.data.image_size, split='train') 128 | print('Start training your Fast-DDPM model on LDFDCT dataset.') 129 | elif self.args.dataset=='BRATS': 130 | # BRATS for brain image translation 131 | dataset = BRATS(self.config.data.train_dataroot, self.config.data.image_size, split='train') 132 | print('Start training your Fast-DDPM model on BRATS dataset.') 133 | print('The scheduler sampling type is {}. 
The number of involved time steps is {} out of 1000.'.format(self.args.scheduler_type, self.args.timesteps)) 134 | 135 | train_loader = data.DataLoader( 136 | dataset, 137 | batch_size=config.training.batch_size, 138 | shuffle=True, 139 | num_workers=config.data.num_workers, 140 | pin_memory=True) 141 | 142 | model = Model(config) 143 | model = model.to(self.device) 144 | model = torch.nn.DataParallel(model) 145 | 146 | optimizer = get_optimizer(self.config, model.parameters()) 147 | 148 | if self.config.model.ema: 149 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 150 | ema_helper.register(model) 151 | else: 152 | ema_helper = None 153 | 154 | start_epoch, step = 0, 0 155 | if self.args.resume_training: 156 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth")) 157 | model.load_state_dict(states[0]) 158 | 159 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps 160 | optimizer.load_state_dict(states[1]) 161 | start_epoch = states[2] 162 | step = states[3] 163 | if self.config.model.ema: 164 | ema_helper.load_state_dict(states[4]) 165 | 166 | for epoch in range(start_epoch, self.config.training.n_epochs): 167 | for i, x in enumerate(train_loader): 168 | n = x['LD'].size(0) 169 | model.train() 170 | step += 1 171 | 172 | x_img = x['LD'].to(self.device) 173 | x_gt = x['FD'].to(self.device) 174 | 175 | e = torch.randn_like(x_gt) 176 | b = self.betas 177 | 178 | if self.args.scheduler_type == 'uniform': 179 | skip = self.num_timesteps // self.args.timesteps 180 | t_intervals = torch.arange(-1, self.num_timesteps, skip) 181 | t_intervals[0] = 0 182 | elif self.args.scheduler_type == 'non-uniform': 183 | t_intervals = torch.tensor([0, 199, 399, 599, 699, 799, 849, 899, 949, 999]) 184 | 185 | if self.args.timesteps != 10: 186 | num_1 = int(self.args.timesteps*0.4) 187 | num_2 = int(self.args.timesteps*0.6) 188 | stage_1 = torch.linspace(0, 699, num_1+1)[:-1] 189 | stage_2 = torch.linspace(699, 999, num_2) 190 | stage_1 = torch.ceil(stage_1).long() 191 | stage_2 = torch.ceil(stage_2).long() 192 | t_intervals = torch.cat((stage_1, stage_2)) 193 | else: 194 | raise Exception("The scheduler type is either uniform or non-uniform.") 195 | 196 | # antithetic sampling 197 | idx_1 = torch.randint(0, len(t_intervals), size=(n // 2 + 1,)) 198 | idx_2 = len(t_intervals)-idx_1-1 199 | idx = torch.cat([idx_1, idx_2], dim=0)[:n] 200 | t = t_intervals[idx].to(self.device) 201 | 202 | loss = loss_registry[config.model.type](model, x_img, x_gt, t, e, b) 203 | 204 | tb_logger.add_scalar("loss", loss, global_step=step) 205 | 206 | logging.info( 207 | f"step: {step}, loss: {loss.item()}" 208 | ) 209 | 210 | optimizer.zero_grad() 211 | loss.backward() 212 | 213 | try: 214 | torch.nn.utils.clip_grad_norm_( 215 | model.parameters(), config.optim.grad_clip 216 | ) 217 | except Exception: 218 | pass 219 | optimizer.step() 220 | 221 | if self.config.model.ema: 222 | ema_helper.update(model) 223 | 224 | if step % self.config.training.snapshot_freq == 0 or step == 1: 225 | states = [ 226 | model.state_dict(), 227 | optimizer.state_dict(), 228 | epoch, 229 | step, 230 | ] 231 | if self.config.model.ema: 232 | states.append(ema_helper.state_dict()) 233 | 234 | torch.save( 235 | states, 236 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)), 237 | ) 238 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth")) 239 | 240 | 241 | # Training Fast-DDPM for tasks that have two conditions: multi image super-resolution. 
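# NOTE: a worked example of the schedule logic used by sg_train above (and
# repeated verbatim in sr_train below), assuming the defaults
# num_timesteps=1000 and --timesteps 10:
#   uniform:     skip = 1000 // 10 = 100, so torch.arange(-1, 1000, 100) with
#                t_intervals[0] = 0 gives [0, 99, 199, ..., 899, 999] -- note
#                that this is 11 candidate values, since 999 is included too.
#   non-uniform: the hard-coded [0, 199, 399, 599, 699, 799, 849, 899, 949, 999];
#                for --timesteps != 10, 40% of the values are spread over [0, 699)
#                and 60% over [699, 999], so the spacing shrinks toward t = 999.
# Antithetic sampling then pairs each random index idx with
# len(t_intervals) - idx - 1, so every batch covers both ends of the schedule.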
242 | def sr_train(self): 243 | args, config = self.args, self.config 244 | tb_logger = self.config.tb_logger 245 | 246 | dataset = PMUB(self.config.data.train_dataroot, self.config.data.image_size, split='train') 247 | print('Start training your Fast-DDPM model on PMUB dataset.') 248 | print('The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.scheduler_type, self.args.timesteps)) 249 | train_loader = data.DataLoader( 250 | dataset, 251 | batch_size=config.training.batch_size, 252 | shuffle=True, 253 | num_workers=config.data.num_workers, 254 | pin_memory=True) 255 | 256 | model = Model(config) 257 | model = model.to(self.device) 258 | model = torch.nn.DataParallel(model) 259 | 260 | optimizer = get_optimizer(self.config, model.parameters()) 261 | 262 | if self.config.model.ema: 263 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 264 | ema_helper.register(model) 265 | else: 266 | ema_helper = None 267 | 268 | start_epoch, step = 0, 0 269 | if self.args.resume_training: 270 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth")) 271 | model.load_state_dict(states[0]) 272 | 273 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps 274 | optimizer.load_state_dict(states[1]) 275 | start_epoch = states[2] 276 | step = states[3] 277 | if self.config.model.ema: 278 | ema_helper.load_state_dict(states[4]) 279 | 280 | for epoch in range(start_epoch, self.config.training.n_epochs): 281 | for i, x in enumerate(train_loader): 282 | n = x['BW'].size(0) 283 | model.train() 284 | step += 1 285 | 286 | x_bw = x['BW'].to(self.device) 287 | x_md = x['MD'].to(self.device) 288 | x_fw = x['FW'].to(self.device) 289 | 290 | e = torch.randn_like(x_md) 291 | b = self.betas 292 | 293 | if self.args.scheduler_type == 'uniform': 294 | skip = self.num_timesteps // self.args.timesteps 295 | t_intervals = torch.arange(-1, self.num_timesteps, skip) 296 | t_intervals[0] = 0 297 | elif self.args.scheduler_type == 'non-uniform': 298 | t_intervals = torch.tensor([0, 199, 399, 599, 699, 799, 849, 899, 949, 999]) 299 | 300 | if self.args.timesteps != 10: 301 | num_1 = int(self.args.timesteps*0.4) 302 | num_2 = int(self.args.timesteps*0.6) 303 | stage_1 = torch.linspace(0, 699, num_1+1)[:-1] 304 | stage_2 = torch.linspace(699, 999, num_2) 305 | stage_1 = torch.ceil(stage_1).long() 306 | stage_2 = torch.ceil(stage_2).long() 307 | t_intervals = torch.cat((stage_1, stage_2)) 308 | else: 309 | raise Exception("The scheduler type is either uniform or non-uniform.") 310 | 311 | # antithetic sampling 312 | idx_1 = torch.randint(0, len(t_intervals), size=(n // 2 + 1,)) 313 | idx_2 = len(t_intervals)-idx_1-1 314 | idx = torch.cat([idx_1, idx_2], dim=0)[:n] 315 | t = t_intervals[idx].to(self.device) 316 | 317 | loss = loss_registry[config.model.type](model, x_bw, x_md, x_fw, t, e, b) 318 | 319 | tb_logger.add_scalar("loss", loss, global_step=step) 320 | 321 | logging.info( 322 | f"step: {step}, loss: {loss.item()}" 323 | ) 324 | 325 | optimizer.zero_grad() 326 | loss.backward() 327 | 328 | try: 329 | torch.nn.utils.clip_grad_norm_( 330 | model.parameters(), config.optim.grad_clip 331 | ) 332 | except Exception: 333 | pass 334 | optimizer.step() 335 | 336 | if self.config.model.ema: 337 | ema_helper.update(model) 338 | 339 | if step % self.config.training.snapshot_freq == 0 or step == 1: 340 | states = [ 341 | model.state_dict(), 342 | optimizer.state_dict(), 343 | epoch, 344 | step, 345 | ] 346 | if self.config.model.ema: 347 | 
states.append(ema_helper.state_dict()) 348 | 349 | torch.save( 350 | states, 351 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)), 352 | ) 353 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth")) 354 | 355 | 356 | # Training original DDPM for tasks that have only one condition: image translation and CT denoising. 357 | def sg_ddpm_train(self): 358 | args, config = self.args, self.config 359 | tb_logger = self.config.tb_logger 360 | 361 | if self.args.dataset=='LDFDCT': 362 | # LDFDCT for CT image denoising 363 | dataset = LDFDCT(self.config.data.train_dataroot, self.config.data.image_size, split='train') 364 | print('Start training DDPM model on LDFDCT dataset.') 365 | elif self.args.dataset=='BRATS': 366 | # BRATS for brain image translation 367 | dataset = BRATS(self.config.data.train_dataroot, self.config.data.image_size, split='train') 368 | print('Start training DDPM model on BRATS dataset.') 369 | 370 | print('The number of involved time steps is {} out of 1000.'.format(self.args.timesteps)) 371 | train_loader = data.DataLoader( 372 | dataset, 373 | batch_size=config.training.batch_size, 374 | shuffle=True, 375 | num_workers=config.data.num_workers, 376 | pin_memory=True) 377 | 378 | model = Model(config) 379 | model = model.to(self.device) 380 | model = torch.nn.DataParallel(model) 381 | 382 | optimizer = get_optimizer(self.config, model.parameters()) 383 | 384 | if self.config.model.ema: 385 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 386 | ema_helper.register(model) 387 | else: 388 | ema_helper = None 389 | 390 | start_epoch, step = 0, 0 391 | if self.args.resume_training: 392 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth")) 393 | model.load_state_dict(states[0]) 394 | 395 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps 396 | optimizer.load_state_dict(states[1]) 397 | start_epoch = states[2] 398 | step = states[3] 399 | if self.config.model.ema: 400 | ema_helper.load_state_dict(states[4]) 401 | 402 | for epoch in range(start_epoch, self.config.training.n_epochs): 403 | for i, x in enumerate(train_loader): 404 | n = x['LD'].size(0) 405 | model.train() 406 | step += 1 407 | 408 | x_img = x['LD'].to(self.device) 409 | x_gt = x['FD'].to(self.device) 410 | 411 | e = torch.randn_like(x_gt) 412 | b = self.betas 413 | 414 | t = torch.randint( 415 | low=0, high=self.num_timesteps, size=(n // 2 + 1,) 416 | ).to(self.device) 417 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n] 418 | 419 | loss = loss_registry[config.model.type](model, x_img, x_gt, t, e, b) 420 | 421 | tb_logger.add_scalar("loss", loss, global_step=step) 422 | 423 | logging.info( 424 | f"step: {step}, loss: {loss.item()}" 425 | ) 426 | 427 | optimizer.zero_grad() 428 | loss.backward() 429 | 430 | try: 431 | torch.nn.utils.clip_grad_norm_( 432 | model.parameters(), config.optim.grad_clip 433 | ) 434 | except Exception: 435 | pass 436 | optimizer.step() 437 | 438 | if self.config.model.ema: 439 | ema_helper.update(model) 440 | 441 | if step % self.config.training.snapshot_freq == 0 or step == 1: 442 | states = [ 443 | model.state_dict(), 444 | optimizer.state_dict(), 445 | epoch, 446 | step, 447 | ] 448 | if self.config.model.ema: 449 | states.append(ema_helper.state_dict()) 450 | 451 | torch.save( 452 | states, 453 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)), 454 | ) 455 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth")) 456 | 457 | 458 | # Training original DDPM for tasks that have two conditions: multi image 
super-resolution. 459 | def sr_ddpm_train(self): 460 | args, config = self.args, self.config 461 | tb_logger = self.config.tb_logger 462 | 463 | dataset = PMUB(self.config.data.train_dataroot, self.config.data.image_size, split='train') 464 | print('Start training DDPM model on PMUB dataset.') 465 | print('The number of involved time steps is {} out of 1000.'.format(self.args.timesteps)) 466 | 467 | train_loader = data.DataLoader( 468 | dataset, 469 | batch_size=config.training.batch_size, 470 | shuffle=True, 471 | num_workers=config.data.num_workers, 472 | pin_memory=True) 473 | 474 | model = Model(config) 475 | model = model.to(self.device) 476 | model = torch.nn.DataParallel(model) 477 | 478 | optimizer = get_optimizer(self.config, model.parameters()) 479 | 480 | if self.config.model.ema: 481 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 482 | ema_helper.register(model) 483 | else: 484 | ema_helper = None 485 | 486 | start_epoch, step = 0, 0 487 | if self.args.resume_training: 488 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth")) 489 | model.load_state_dict(states[0]) 490 | 491 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps 492 | optimizer.load_state_dict(states[1]) 493 | start_epoch = states[2] 494 | step = states[3] 495 | if self.config.model.ema: 496 | ema_helper.load_state_dict(states[4]) 497 | 498 | time_start = time.time() 499 | total_time = 0 500 | for epoch in range(start_epoch, self.config.training.n_epochs): 501 | for i, x in enumerate(train_loader): 502 | n = x['BW'].size(0) 503 | model.train() 504 | step += 1 505 | 506 | x_bw = x['BW'].to(self.device) 507 | x_md = x['MD'].to(self.device) 508 | x_fw = x['FW'].to(self.device) 509 | 510 | e = torch.randn_like(x_md) 511 | b = self.betas 512 | 513 | # antithetic sampling 514 | t = torch.randint( 515 | low=0, high=self.num_timesteps, size=(n // 2 + 1,) 516 | ).to(self.device) 517 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n] 518 | loss = loss_registry[config.model.type](model, x_bw, x_md, x_fw, t, e, b) 519 | 520 | tb_logger.add_scalar("loss", loss, global_step=step) 521 | 522 | logging.info( 523 | f"step: {step}, loss: {loss.item()}" 524 | ) 525 | 526 | optimizer.zero_grad() 527 | loss.backward() 528 | 529 | try: 530 | torch.nn.utils.clip_grad_norm_( 531 | model.parameters(), config.optim.grad_clip 532 | ) 533 | except Exception: 534 | pass 535 | optimizer.step() 536 | 537 | if self.config.model.ema: 538 | ema_helper.update(model) 539 | 540 | if step % self.config.training.snapshot_freq == 0 or step == 1: 541 | states = [ 542 | model.state_dict(), 543 | optimizer.state_dict(), 544 | epoch, 545 | step, 546 | ] 547 | if self.config.model.ema: 548 | states.append(ema_helper.state_dict()) 549 | 550 | torch.save( 551 | states, 552 | os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)), 553 | ) 554 | torch.save(states, os.path.join(self.args.log_path, "ckpt.pth")) 555 | 556 | 557 | # Sampling for tasks that have two conditions: multi image super-resolution. 
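# NOTE: checkpoint layout consumed by the samplers below, as written by the
# training loops above:
#   states[0]  -> model.state_dict()     (saved after the DataParallel wrap, so
#                 the samplers also wrap the model before load_state_dict)
#   states[1]  -> optimizer.state_dict()
#   states[2]  -> epoch,  states[3] -> step
#   states[-1] -> EMA shadow parameters  (appended only when config.model.ema)
# When EMA is enabled, ema_helper.ema(model) overwrites the live weights with
# the shadow copy before model.eval(), so inference runs on the EMA weights.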
558 | def sr_sample(self):
559 | ckpt_list = self.config.sampling.ckpt_id
560 | for ckpt_idx in ckpt_list:
561 | self.ckpt_idx = ckpt_idx
562 | model = Model(self.config)
563 | print('Start inference with the model checkpoint saved at step {}'.format(ckpt_idx))
564 |
565 | if not self.args.use_pretrained:
566 | states = torch.load(
567 | os.path.join(
568 | self.args.log_path, f"ckpt_{ckpt_idx}.pth"
569 | ),
570 | map_location=self.config.device,
571 | )
572 | model = model.to(self.device)
573 | model = torch.nn.DataParallel(model)
574 | model.load_state_dict(states[0], strict=True)
575 |
576 | if self.config.model.ema:
577 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
578 | ema_helper.register(model)
579 | ema_helper.load_state_dict(states[-1])
580 | ema_helper.ema(model)
581 | else:
582 | ema_helper = None
583 | else:
584 | # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
585 | if self.config.data.dataset == "CIFAR10":
586 | name = "cifar10"
587 | elif self.config.data.dataset == "LSUN":
588 | name = f"lsun_{self.config.data.category}"
589 | else:
590 | raise ValueError
591 | ckpt = get_ckpt_path(f"ema_{name}")
592 | print("Loading checkpoint {}".format(ckpt))
593 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
594 | model.to(self.device)
595 | model = torch.nn.DataParallel(model)
596 |
597 | model.eval()
598 |
599 | if self.args.fid:
600 | self.sr_sample_fid(model)
601 | elif self.args.interpolation:
602 | self.sr_sample_interpolation(model)
603 | elif self.args.sequence:
604 | self.sample_sequence(model)
605 | else:
606 | raise NotImplementedError("Sample procedure not defined")
607 |
608 |
609 | # Sampling for tasks that have only one condition: image translation and CT denoising.
610 | def sg_sample(self):
611 | ckpt_list = self.config.sampling.ckpt_id
612 | for ckpt_idx in ckpt_list:
613 | self.ckpt_idx = ckpt_idx
614 | model = Model(self.config)
615 | print('Start inference with the model checkpoint saved at step {}'.format(ckpt_idx))
616 |
617 | if not self.args.use_pretrained:
618 | states = torch.load(
619 | os.path.join(
620 | self.args.log_path, f"ckpt_{ckpt_idx}.pth"
621 | ),
622 | map_location=self.config.device,
623 | )
624 | model = model.to(self.device)
625 | model = torch.nn.DataParallel(model)
626 | model.load_state_dict(states[0], strict=True)
627 |
628 | if self.config.model.ema:
629 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
630 | ema_helper.register(model)
631 | ema_helper.load_state_dict(states[-1])
632 | ema_helper.ema(model)
633 | else:
634 | ema_helper = None
635 | else:
636 | # This uses the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
637 | if self.config.data.dataset == "CIFAR10":
638 | name = "cifar10"
639 | elif self.config.data.dataset == "LSUN":
640 | name = f"lsun_{self.config.data.category}"
641 | else:
642 | raise ValueError
643 | ckpt = get_ckpt_path(f"ema_{name}")
644 | print("Loading checkpoint {}".format(ckpt))
645 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
646 | model.to(self.device)
647 | model = torch.nn.DataParallel(model)
648 |
649 | model.eval()
650 |
651 | if self.args.fid:
652 | self.sg_sample_fid(model)
653 | elif self.args.interpolation:
654 | self.sr_sample_interpolation(model)
655 | elif self.args.sequence:
656 | self.sample_sequence(model)
657 | else:
658 | raise NotImplementedError("Sample procedure not defined")
659 |
660 |
661 | def sr_sample_fid(self, model):
662 | config = self.config
663 | img_id =
len(glob.glob(f"{self.args.image_folder}/*")) 664 | print(f"starting from image {img_id}") 665 | 666 | sample_dataset = PMUB(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate') 667 | print('Start sampling model on PMUB dataset.') 668 | print('The inference sample type is {}. The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.sample_type, self.args.scheduler_type, self.args.timesteps)) 669 | 670 | sample_loader = data.DataLoader( 671 | sample_dataset, 672 | batch_size=config.sampling_fid.batch_size, 673 | shuffle=False, 674 | num_workers=config.data.num_workers) 675 | 676 | with torch.no_grad(): 677 | data_num = len(sample_dataset) 678 | print('The length of test set is:', data_num) 679 | avg_psnr = 0.0 680 | avg_ssim = 0.0 681 | time_list = [] 682 | psnr_list = [] 683 | ssim_list = [] 684 | 685 | for batch_idx, img in tqdm.tqdm(enumerate(sample_loader), desc="Generating image samples for FID evaluation."): 686 | n = img['BW'].shape[0] 687 | 688 | x = torch.randn( 689 | n, 690 | config.data.channels, 691 | config.data.image_size, 692 | config.data.image_size, 693 | device=self.device, 694 | ) 695 | x_bw = img['BW'].to(self.device) 696 | x_md = img['MD'].to(self.device) 697 | x_fw = img['FW'].to(self.device) 698 | case_name = img['case_name'][0] 699 | 700 | time_start = time.time() 701 | x = self.sr_sample_image(x, x_bw, x_fw, model) 702 | time_end = time.time() 703 | 704 | x = inverse_data_transform(config, x) 705 | x_md = inverse_data_transform(config, x_md) 706 | x_tensor = x 707 | x_md_tensor = x_md 708 | x_md = x_md.squeeze().float().cpu().numpy() 709 | x = x.squeeze().float().cpu().numpy() 710 | x_md = (x_md*255.0).round() 711 | x = (x*255.0).round() 712 | 713 | PSNR = 0.0 714 | SSIM = 0.0 715 | for i in range(x.shape[0]): 716 | psnr_temp = calculate_psnr(x[i,:,:], x_md[i,:,:]) 717 | ssim_temp = ssim(x_md[i,:,:], x[i,:,:], data_range=255) 718 | PSNR += psnr_temp 719 | SSIM += ssim_temp 720 | psnr_list.append(psnr_temp) 721 | ssim_list.append(ssim_temp) 722 | 723 | PSNR_print = PSNR/x.shape[0] 724 | SSIM_print = SSIM/x.shape[0] 725 | 726 | case_time = time_end-time_start 727 | time_list.append(case_time) 728 | 729 | avg_psnr += PSNR 730 | avg_ssim += SSIM 731 | logging.info('Case {}: PSNR {}, SSIM {}, time {}'.format(case_name, PSNR_print, SSIM_print, case_time)) 732 | 733 | for i in range(0, n): 734 | # image:(0-1) 735 | tvu.save_image( 736 | x_tensor[i], os.path.join(self.args.image_folder, "{}_{}_pt.png".format(self.ckpt_idx, img_id)) 737 | ) 738 | tvu.save_image( 739 | x_md_tensor[i], os.path.join(self.args.image_folder, "{}_{}_gt.png".format(self.ckpt_idx, img_id)) 740 | ) 741 | img_id += 1 742 | 743 | avg_psnr = avg_psnr / data_num 744 | avg_ssim = avg_ssim / data_num 745 | # Drop first and last for time calculation. 
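# NOTE: the first batch typically carries one-time CUDA/model warm-up cost and
# the last batch may be smaller than the others, so both are excluded from the
# timing average below; this assumes the sample loader yields at least three
# batches, otherwise len(time_list) - 2 is zero or negative and the division fails.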
746 | avg_time = sum(time_list[1:-1])/(len(time_list)-2)
747 | logging.info('Average: PSNR {}, SSIM {}, time {}'.format(avg_psnr, avg_ssim, avg_time))
748 |
749 |
750 | def sg_sample_fid(self, model):
751 | config = self.config
752 | img_id = len(glob.glob(f"{self.args.image_folder}/*"))
753 | print(f"starting from image {img_id}")
754 |
755 |
756 | if self.args.dataset=='LDFDCT':
757 | # LDFDCT for CT image denoising
758 | sample_dataset = LDFDCT(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate')
759 | print('Start sampling model on LDFDCT dataset.')
760 | elif self.args.dataset=='BRATS':
761 | # BRATS for brain image translation
762 | sample_dataset = BRATS(self.config.data.sample_dataroot, self.config.data.image_size, split='calculate')
763 | print('Start sampling model on BRATS dataset.')
764 | print('The inference sample type is {}. The scheduler sampling type is {}. The number of involved time steps is {} out of 1000.'.format(self.args.sample_type, self.args.scheduler_type, self.args.timesteps))
765 |
766 | sample_loader = data.DataLoader(
767 | sample_dataset,
768 | batch_size=config.sampling_fid.batch_size,
769 | shuffle=False,
770 | num_workers=config.data.num_workers)
771 |
772 | with torch.no_grad():
773 | data_num = len(sample_dataset)
774 | print('The length of test set is:', data_num)
775 | avg_psnr = 0.0
776 | avg_ssim = 0.0
777 | time_list = []
778 | psnr_list = []
779 | ssim_list = []
780 |
781 | for batch_idx, sample in tqdm.tqdm(enumerate(sample_loader), desc="Generating image samples for FID evaluation."):
782 | n = sample['LD'].shape[0]
783 |
784 | x = torch.randn(
785 | n,
786 | config.data.channels,
787 | config.data.image_size,
788 | config.data.image_size,
789 | device=self.device,
790 | )
791 | x_img = sample['LD'].to(self.device)
792 | x_gt = sample['FD'].to(self.device)
793 | case_name = sample['case_name']
794 |
795 | time_start = time.time()
796 | x = self.sg_sample_image(x, x_img, model)
797 | time_end = time.time()
798 |
799 | x = inverse_data_transform(config, x)
800 | x_gt = inverse_data_transform(config, x_gt)
801 | x_tensor = x
802 | x_gt_tensor = x_gt
803 | x_gt = x_gt.squeeze().float().cpu().numpy()
804 | x = x.squeeze().float().cpu().numpy()
805 | x_gt = x_gt*255
806 | x = x*255
807 |
808 | PSNR = 0.0
809 | SSIM = 0.0
810 | for i in range(x.shape[0]):
811 | psnr_temp = calculate_psnr(x[i,:,:], x_gt[i,:,:])
812 | ssim_temp = ssim(x_gt[i,:,:], x[i,:,:], data_range=255)
813 | PSNR += psnr_temp
814 | SSIM += ssim_temp
815 | psnr_list.append(psnr_temp)
816 | ssim_list.append(ssim_temp)
817 |
818 | PSNR_print = PSNR/x.shape[0]
819 | SSIM_print = SSIM/x.shape[0]
820 |
821 | case_time = time_end-time_start
822 | time_list.append(case_time)
823 |
824 | avg_psnr += PSNR
825 | avg_ssim += SSIM
826 | logging.info('Case {}: PSNR {}, SSIM {}, time {}'.format(case_name[0], PSNR_print, SSIM_print, case_time))
827 |
828 | for i in range(0, n):
829 | # image:(0-1)
830 | tvu.save_image(
831 | x_tensor[i], os.path.join(self.args.image_folder, "{}_{}_pt.png".format(self.ckpt_idx, img_id))
832 | )
833 | tvu.save_image(
834 | x_gt_tensor[i], os.path.join(self.args.image_folder, "{}_{}_gt.png".format(self.ckpt_idx, img_id))
835 | )
836 | img_id += 1
837 |
838 | avg_psnr = avg_psnr / data_num
839 | avg_ssim = avg_ssim / data_num
840 | # Drop first and last for time calculation.
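# NOTE: same timing convention as sr_sample_fid above -- the first and last
# batches are dropped, so at least three test batches are needed for a valid avg_time.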
841 | avg_time = sum(time_list[1:-1])/(len(time_list)-2) 842 | logging.info('Average: PSNR {}, SSIM {}, time {}'.format(avg_psnr, avg_ssim, avg_time)) 843 | 844 | 845 | def sr_sample_image(self, x, x_bw, x_fw, model, last=True): 846 | try: 847 | skip = self.args.skip 848 | except Exception: 849 | skip = 1 850 | 851 | if self.args.sample_type == "generalized": 852 | if self.args.scheduler_type == 'uniform': 853 | skip = self.num_timesteps // self.args.timesteps 854 | seq = range(-1, self.num_timesteps, skip) 855 | seq = list(seq) 856 | seq[0] = 0 857 | elif self.args.scheduler_type == 'non-uniform': 858 | seq = [0, 199, 399, 599, 699, 799, 849, 899, 949, 999] 859 | 860 | if self.args.timesteps != 10: 861 | num_1 = int(self.args.timesteps*0.4) 862 | num_2 = int(self.args.timesteps*0.6) 863 | stage_1 = np.linspace(0, 699, num_1+1)[:-1] 864 | stage_2 = np.linspace(699, 999, num_2) 865 | stage_1 = np.ceil(stage_1).astype(int) 866 | stage_2 = np.ceil(stage_2).astype(int) 867 | seq = np.concatenate((stage_1, stage_2)) 868 | else: 869 | raise Exception("The scheduler type is either uniform or non-uniform.") 870 | 871 | from functions.denoising import generalized_steps, sr_generalized_steps 872 | 873 | xs = sr_generalized_steps(x, x_bw, x_fw, seq, model, self.betas, eta=self.args.eta) 874 | x = xs 875 | 876 | elif self.args.sample_type == "ddpm_noisy": 877 | skip = self.num_timesteps // self.args.timesteps 878 | seq = range(0, self.num_timesteps, skip) 879 | 880 | from functions.denoising import ddpm_steps, sr_ddpm_steps 881 | 882 | x = sr_ddpm_steps(x, x_bw, x_fw, seq, model, self.betas) 883 | else: 884 | raise NotImplementedError 885 | if last: 886 | x = x[0][-1] 887 | return x 888 | 889 | 890 | def sg_sample_image(self, x, x_img, model, last=True): 891 | try: 892 | skip = self.args.skip 893 | except Exception: 894 | skip = 1 895 | 896 | if self.args.sample_type == "generalized": 897 | if self.args.scheduler_type == 'uniform': 898 | skip = self.num_timesteps // self.args.timesteps 899 | seq = range(-1, self.num_timesteps, skip) 900 | seq = list(seq) 901 | seq[0] = 0 902 | elif self.args.scheduler_type == 'non-uniform': 903 | seq = [0, 199, 399, 599, 699, 799, 849, 899, 949, 999] 904 | 905 | if self.args.timesteps != 10: 906 | num_1 = int(self.args.timesteps*0.4) 907 | num_2 = int(self.args.timesteps*0.6) 908 | stage_1 = np.linspace(0, 699, num_1+1)[:-1] 909 | stage_2 = np.linspace(699, 999, num_2) 910 | stage_1 = np.ceil(stage_1).astype(int) 911 | stage_2 = np.ceil(stage_2).astype(int) 912 | seq = np.concatenate((stage_1, stage_2)) 913 | else: 914 | raise Exception("The scheduler type is either uniform or non-uniform.") 915 | 916 | from functions.denoising import generalized_steps, sr_generalized_steps, sg_generalized_steps 917 | 918 | xs = sg_generalized_steps(x, x_img, seq, model, self.betas, eta=self.args.eta) 919 | x = xs 920 | 921 | elif self.args.sample_type == "ddpm_noisy": 922 | skip = self.num_timesteps // self.args.timesteps 923 | seq = range(0, self.num_timesteps, skip) 924 | 925 | from functions.denoising import ddpm_steps, sr_ddpm_steps, sg_ddpm_steps 926 | 927 | x = sg_ddpm_steps(x, x_img, seq, model, self.betas) 928 | else: 929 | raise NotImplementedError 930 | if last: 931 | x = x[0][-1] 932 | return x 933 | 934 | 935 | def test(self): 936 | pass 937 | --------------------------------------------------------------------------------
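The `uniform`/`non-uniform` schedule logic above is duplicated across `sg_train`, `sr_train`, `sr_sample_image` and `sg_sample_image`. For reference, here is a minimal, self-contained sketch of that shared logic (an editorial illustration under the default T = 1000, not code from the repository; `fast_ddpm_schedule` is a hypothetical helper name):

```python
# Sketch of the timestep-schedule logic shared by the training loops and the
# *_sample_image helpers in runners/diffusion.py (illustrative, not repo code).
import numpy as np


def fast_ddpm_schedule(scheduler_type="non-uniform", timesteps=10, T=1000):
    if scheduler_type == "uniform":
        skip = T // timesteps
        seq = list(range(-1, T, skip))
        seq[0] = 0                    # replace the leading -1 with t = 0
        return np.array(seq)          # 11 values for timesteps=10, since 999 is included
    elif scheduler_type == "non-uniform":
        if timesteps == 10:           # the hard-coded default
            return np.array([0, 199, 399, 599, 699, 799, 849, 899, 949, 999])
        num_1 = int(timesteps * 0.4)  # 40% of the steps over [0, 699)
        num_2 = int(timesteps * 0.6)  # 60% of the steps over [699, 999]
        stage_1 = np.ceil(np.linspace(0, 699, num_1 + 1)[:-1]).astype(int)
        stage_2 = np.ceil(np.linspace(699, 999, num_2)).astype(int)
        return np.concatenate((stage_1, stage_2))
    raise ValueError("The scheduler type is either uniform or non-uniform.")


print(fast_ddpm_schedule("uniform"))      # [  0  99 199 ... 899 999]
print(fast_ddpm_schedule("non-uniform"))  # [  0 199 399 599 699 799 849 899 949 999]
```

In the `generalized` branch, `sr_sample_image`/`sg_sample_image` hand this sequence to `sr_generalized_steps`/`sg_generalized_steps` in `functions/denoising.py`; the `ddpm_noisy` branch always uses the uniform spacing.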
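Similarly, a compact sketch of the closed-form forward step that all three losses in `functions/losses.py` share (again an editorial illustration; the linear-schedule endpoints 1e-4 and 2e-2 are assumed typical DDPM values rather than read from the repository's configs):

```python
# Illustration of the noised-input construction used by noise_estimation_loss,
# sr_noise_estimation_loss and sg_noise_estimation_loss: the network is trained
# to regress the injected noise e from x_t = sqrt(a_bar_t)*x0 + sqrt(1-a_bar_t)*e.
import torch

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)       # assumed linear schedule endpoints
a_bar = (1 - betas).cumprod(dim=0)          # cumulative product of alphas, shape (T,)

x0 = torch.rand(4, 1, 128, 128) * 2 - 1     # a fake batch of targets in [-1, 1]
t = torch.tensor([0, 199, 599, 999])        # one (Fast-DDPM-style) timestep per sample
e = torch.randn_like(x0)                    # the noise the model must predict

a = a_bar.index_select(0, t).view(-1, 1, 1, 1)
x_t = x0 * a.sqrt() + e * (1.0 - a).sqrt()  # same expression as in losses.py

# With a real model the loss would be (e - model(x_t, t)).square().sum((1, 2, 3)).mean();
# the sr/sg variants first concatenate their conditioning images with x_t along dim=1.
print(x_t.shape, a.view(-1))
```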