├── README.md
├── eval_3DCT_blendcond.py
├── eval_3D_blend_cond.sh
├── eval_3D_blend_cond_lidc.py
├── guided_diffusion
│   ├── CTDataset.py
│   ├── diffusion.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── losses.py
│   ├── models.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   ├── training_script.py
│   ├── training_triplane_script.py
│   ├── unet.py
│   └── utils.py
├── main.py
├── train_SVCT_3D_triplane.sh
├── training_triplane_lidc.py
├── training_triplane_script.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction
2 | 
3 | This repository contains the official implementation of **"DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction"**, published at **NeurIPS 2024**.
4 | 
5 | Paper link: https://openreview.net/forum?id=h3Kv6sdTWO
6 | 
7 | ---
8 | 
9 | ## Overview
10 | 
11 | DiffusionBlend introduces a novel method for 3D computed tomography (CT) reconstruction using position-aware diffusion score blending. By leveraging position-specific priors, the framework achieves enhanced reconstruction accuracy while maintaining computational efficiency.
12 | 
13 | *(Figure: overview of the DiffusionBlend pipeline.)*
14 | 
15 | ---
16 | 
17 | ## Features
18 | 
19 | - **Position-aware Diffusion Blending:** Incorporates spatial information to refine 3D reconstruction quality.
20 | - **Triplane-based 3D Representation:** Utilizes a position encoding to model 3D patch priors efficiently.
21 | - **Scalable and Generalizable:** Designed for both synthetic and real-world CT reconstruction tasks.
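22 | 
23 | ## Method at a Glance
24 | 
25 | The reverse diffusion runs on 3-slice slabs of the volume. Each reverse step scores every slab with a 2.5D diffusion model, and successive steps alternate between an adjacent-slice partition ((0,1,2), (3,4,5), ...) and a strided "slice-jump" partition ((0,3,6), (1,4,7), ...), so information propagates across slab boundaries. The sketch below is a minimal, illustrative rendering of one blended score evaluation; the function and tensor names here are hypothetical simplifications (the position/class conditioning is omitted), and `blendscore` / `vps_blend` in `eval_3DCT_blendcond.py` are the reference implementations:
26 | 
27 | ```python
28 | import torch
29 | 
30 | def blended_score(model, xt, t, partition="adjacent"):
31 |     # xt: (S, H, W) stack of noisy slices; S divisible by 3 ("adjacent")
32 |     # or by 9 ("jump"). Returns the blended noise estimate, same shape.
33 |     S = xt.shape[0]
34 |     et = torch.zeros_like(xt)
35 |     if partition == "adjacent":   # slabs (0,1,2), (3,4,5), ...
36 |         groups = [list(range(i, i + 3)) for i in range(0, S, 3)]
37 |     else:                         # slabs (0,3,6), (1,4,7), (2,5,8), ...
38 |         groups = [[i + m, i + m + 3, i + m + 6]
39 |                   for i in range(0, S, 9) for m in range(3)]
40 |     for g in groups:
41 |         slab = xt[g].unsqueeze(0)  # (1, 3, H, W): three slices as channels
42 |         # the network emits extra learned-variance channels; keep the mean
43 |         et[g] = model(slab, t)[0, :3]
44 |     return et
45 | ```
46 | 
47 | In the released code the partition choice is passed to the network as a class label (`y = 1` for adjacent slabs, `y = 3` for slice-jump), and `blendscore` additionally randomizes the partition offset between steps and scores the first and last slabs separately.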
48 | 
49 | ---
50 | 
51 | ## Requirements
52 | 
53 | The code is implemented in Python and requires the following dependencies:
54 | 
55 | - `torch` (>=1.9.0)
56 | - `torchvision`
57 | - `numpy`
58 | - `scipy`
59 | - `scikit-image`
60 | - `matplotlib`
61 | - `tqdm`
62 | 
63 | You can install the dependencies via:
64 | 
65 | ```bash
66 | pip install torch torchvision numpy scipy scikit-image matplotlib tqdm
67 | ```
68 | 
69 | ## Training
70 | 
71 | To train the triplane diffusion model on volumetric CT data, use the following script:
72 | 
73 | ```bash
74 | bash train_SVCT_3D_triplane.sh
75 | ```
76 | 
77 | ## Inference
78 | 
79 | To perform inference and evaluate 3D reconstruction using diffusion score blending, use:
80 | 
81 | ```bash
82 | bash eval_3D_blend_cond.sh
83 | ```
84 | 
85 | ## Results
86 | 
87 | *(Result figures omitted; see the paper for qualitative and quantitative comparisons on sparse-view and limited-angle CT.)*
88 | 
89 | ## Citation
90 | 
91 | If you find this work useful in your research, please cite:
92 | 
93 | ```
94 | @inproceedings{diffusionblend2024,
95 |   title={DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction},
96 |   author={Song, Bowen and Hu, Jason and Luo, Zhaoxu and Fessler, Jeffrey A and Shen, Liyue},
97 |   booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
98 |   year={2024}
99 | }
100 | ```
101 | 
102 | ## Acknowledgements
103 | 
104 | We thank the contributors and the NeurIPS community for their valuable feedback and discussions.
105 | 
106 | ## License
107 | 
108 | This project is licensed under the MIT License. See the LICENSE file for details.
--------------------------------------------------------------------------------
/eval_3DCT_blendcond.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | import glob
5 | import json
6 | import sys
7 | import math
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import tqdm
11 | import torch
12 | import torch.utils.data as data
13 | import torchvision.utils as tvu
14 | 
15 | from guided_diffusion.models import Model
16 | from guided_diffusion.script_util import create_model, classifier_defaults, args_to_dict, create_gaussian_diffusion
17 | from guided_diffusion.utils import get_alpha_schedule
18 | import random
19 | 
20 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
21 | from scipy.linalg import orth
22 | from pathlib import Path
23 | 
24 | from physics.ct import CT
25 | from time import time
26 | from utils import shrink, CG, clear, batchfy, _Dz, _DzT, get_beta_schedule
27 | 
28 | 
29 | def compute_alpha(beta, t):
30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
32 | return a
33 | 
34 | 
35 | class Diffusion(object):
36 | def __init__(self, args, config, device=None):
37 | self.args = args
38 | self.args.image_folder = Path(self.args.image_folder)
39 | for t in ["input", "recon", "label"]:
40 | if t == "recon":
41 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
42 | else:
43 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
44 | self.config = config
45 | print(self.config)
46 | if device is None:
47 | device = (
48 | torch.device("cuda")
49 | if torch.cuda.is_available()
50 | else torch.device("cpu")
51 | )
52 | self.device = device
53 | 
54 | self.model_var_type = config.model.var_type
55 | betas = get_beta_schedule(
56 | beta_schedule=config.diffusion.beta_schedule,
57 | beta_start=config.diffusion.beta_start,
58 | beta_end=config.diffusion.beta_end,
59 |
num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
60 | )
61 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
62 | self.num_timesteps = betas.shape[0]
63 | 
64 | alphas = 1.0 - betas
65 | alphas_cumprod = alphas.cumprod(dim=0)
66 | alphas_cumprod_prev = torch.cat(
67 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
68 | )
69 | self.alphas_cumprod_prev = alphas_cumprod_prev
70 | posterior_variance = (
71 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
72 | )
73 | if self.model_var_type == "fixedlarge":
74 | self.logvar = betas.log()
75 | elif self.model_var_type == "fixedsmall":
76 | self.logvar = posterior_variance.clamp(min=1e-20).log()
77 | 
78 | def sample(self):
79 | config_dict = vars(self.config.model)
80 | config_dict['use_spacecode'] = False
81 | config_dict["class_cond"] = True
82 | model = create_model(**config_dict)
83 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter65099_cond.ckpt"
84 | 
85 | model.load_state_dict(torch.load(ckpt, map_location=self.device)["state_dict"])
86 | print(f"Model ckpt loaded from {ckpt}")
87 | model.to(self.device)
88 | model.eval()
89 | model = torch.nn.DataParallel(model)
90 | 
91 | print('Run 3D DDS + DiffusionMBIR.',
92 | f'{self.args.T_sampling} sampling steps.',
93 | f'Task: {self.args.deg}.'
94 | )
95 | self.dds3d(model)
96 | 
97 | 
98 | def blendscore(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
99 | model_kwargs = {}
100 | y = torch.ones(1) * 1 # class label 1 = adjacent-slice partition
101 | y = y.to(xt.device).to(torch.long)
102 | model_kwargs["y"] = y
103 | # blend over 3-slice slabs at a (possibly random) offset; patch the two ends separately
104 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
105 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
106 | et[:,:3,:,:] = model(xt[:,:3,:,:], t, **model_kwargs)[:,:3]
107 | et[:,xt.shape[1]-3:, :,:] = model(xt[:,xt.shape[1]-3:,:,:], t, **model_kwargs)[:,:3]
108 | 
109 | if start_ind is None:
110 | start_ind = np.random.randint(start_head, 3)
111 | for j in range(start_ind, xt.shape[1]-2, 3):
112 | ##### randomly select the partition offset instead of summing over all offsets
113 | et_sing = model(xt[:,j:(j+3),:,:], t, **model_kwargs)[:,:3] ##### 1 x 3 x 256 x 256
114 | et[:,j:(j+3), :,:] = et_sing
115 | return et
116 | 
117 | def vps_blend(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
118 | model_kwargs = {}
119 | y = torch.ones(1) * 3 # class label 3 = slice-jump partition
120 | y = y.to(xt.device).to(torch.long)
121 | model_kwargs["y"] = y
122 | 
123 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
124 | 
125 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
126 | # slice-jump partition: slabs (1,4,7), (2,5,8), (3,6,9) within each block of 9 slices
127 | for i in range(0, xt.shape[1], 9):
128 | for m in range(3):
129 | et[:,[i+m,i+m+3, i+m+6],:,:] = model(xt[:,[i+m,i+m+3, i+m+6],:,:], t, **model_kwargs)[:,:3]
130 | 
131 | return et
132 | 
133 | 
134 | def dds3d(self, model):
135 | args, config = self.args, self.config
136 | print(f"Dataset path: {self.args.dataset_path}")
137 | root = Path(self.args.dataset_path)
138 | 
139 | noise, noise_flag = self.args.sigma_y, False
140 | if noise > 0:
141 | noise_flag = True
142 | 
143 | # parameters to be moved to args
144 | Nview = self.args.Nview
145 | rho = self.args.rho
146 | rho = 0.001 # hard-coded override of args.rho (ADMM penalty weight)
147 | lamb = self.args.lamb
148 | lamb = 0.05 * 1e-3 # hard-coded override of args.lamb (TV shrinkage weight)
149 | n_ADMM = 1
150 | n_CG = self.args.CG_iter
151 | print(n_CG)
152 | 
153 | blend = True # blend scores across slab partitions (the DiffusionBlend sampler)
154 | 
155 | 
156 | time_travel = False # if True, each timestep is traversed twice (second pass without ADMM)
157 | vps_scale = 0.03 # step size of the optional stochastic VPS correction
158 | vps = False # enable the VPS correction before each blended score evaluation
159 | 
160 | ddimsteps = 200 # number of DDIM steps (NFEs)
161 | 
162
| 163 | # Specify save directory for saving generated samples 164 | save_root = Path(self.args.image_folder) 165 | save_root.mkdir(parents=True, exist_ok=True) 166 | 167 | irl_types = ['vol', 'input', 'recon', 'label'] 168 | for t in irl_types: 169 | save_root_f = save_root / t 170 | save_root_f.mkdir(parents=True, exist_ok=True) 171 | 172 | 173 | ##################################new data################################## 174 | fname_list = os.listdir("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation") 175 | root = "/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation" 176 | # fname_list = sorted(fname_list, key=lambda x: float(x.split(".")[0]))[:60] 177 | fname_list.sort() 178 | 179 | pre_slices = 0 180 | num_batches = 168 181 | fname_list = fname_list[pre_slices:(pre_slices+ 3 * num_batches)] 182 | ###################################################################################################### 183 | 184 | print(fname_list) 185 | all_img = [] 186 | batch_size = 3 187 | print("Loading all data") 188 | if time_travel: 189 | tot_iters = 2 190 | else: 191 | tot_iters = 1 192 | for fname in fname_list: 193 | just_name = fname.split('.')[0] 194 | img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True)) 195 | h, w = img.shape 196 | img = img.view(1, 1, h, w) 197 | all_img.append(img) 198 | all_img = torch.cat(all_img, dim=0) 199 | x_orig = all_img 200 | print(f"Data loaded shape : {all_img.shape}") 201 | x_orig = x_orig.to(torch.float32) 202 | print("Data type is :", x_orig.dtype) 203 | img_shape = (x_orig.shape[0], config.data.channels, config.data.image_size, config.data.image_size) 204 | if self.args.deg == "SV-CT": 205 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=True, circle=False, device=config.device) 206 | elif self.args.deg == "LA-CT": 207 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=False, circle=False, device=config.device) 208 | A = lambda z: A_funcs.A(z) 209 | Ap = lambda z: A_funcs.A_dagger(z) 210 | def Acg_TV(x): 211 | return A_funcs.AT(A_funcs.A(x)) + rho * _DzT(_Dz(x)) 212 | def ADMM(x, ATy, n_ADMM=n_ADMM): 213 | nonlocal del_z, udel_z 214 | for _ in range(n_ADMM): 215 | bcg_TV = ATy + rho * (_DzT(del_z) - _DzT(udel_z)) 216 | x = CG(Acg_TV, bcg_TV, x, n_inner=n_CG) 217 | del_z = shrink(_Dz(x) + udel_z, lamb / rho) 218 | udel_z = _Dz(x) - del_z + udel_z 219 | return x 220 | del_z = torch.zeros(img_shape, device=self.device) 221 | udel_z = torch.zeros(img_shape, device=self.device) 222 | x_orig = x_orig.to(self.device) ######n x 1 x 256 x 256 223 | print(x_orig.min(), x_orig.max(), "xorig") 224 | y = A(x_orig) 225 | print(y.shape, "projection shape") 226 | 227 | ###########################adding noise to projection####################################### 228 | if noise_flag: 229 | print("adding noise to projections") 230 | I0 = 1.11e6 231 | # y = (-(torch.log(1e4 * torch.exp(-y/256) + torch.randn_like(y) * 5) - math.log(1e4))*256) ###gaussian noise 232 | y = -(torch.log(torch.poisson(I0 * torch.exp(-y/18)) + torch.randn_like(y) * 5) - math.log(I0))*18 ##poisson gaussian noise 233 | 234 | Apy = Ap(y) 235 | print(Apy.shape, "Apy backprojection shape") 236 | ATy = A_funcs.AT(y) 237 | ##########################original#################################################### 238 | # x = torch.randn(20, 3, 256, 256, device = self.device) ####initial noise 239 | 240 | ########forward init############################ 241 | t = (torch.ones(500)).to(self.device) 242 | at = compute_alpha(self.betas, t.long()) 243 | # 
at_next = compute_alpha(self.betas, next_t.long())
244 | at = at[0,0,0,0]
245 | init_noise = at.sqrt() * x_orig + torch.randn_like(x_orig) * (1 - at).sqrt()
246 | x = torch.reshape(init_noise, (num_batches, 3, 256, 256))
247 | 
248 | 
249 | diffusion = create_gaussian_diffusion(
250 | steps=1000,
251 | learn_sigma=True,
252 | noise_schedule="linear",
253 | use_kl=False,
254 | predict_xstart=False,
255 | rescale_timesteps=False,
256 | rescale_learned_sigmas=False,
257 | timestep_respacing="",
258 | )
259 | xt = None
260 | with torch.no_grad():
261 | skip = config.diffusion.num_diffusion_timesteps//ddimsteps
262 | n = x.size(0)
263 | x0_preds = []
264 | xt = x ### num_batches x 3 x 256 x 256
265 | 
266 | # generate time schedule
267 | times = range(0, 1000, skip) ######### 0, skip, 2*skip, ...
268 | times_next = [-1] + list(times[:-1])
269 | times_pair = zip(reversed(times), reversed(times_next))
270 | 
271 | ct = 0
272 | ################################### start reverse sampling ############################################
273 | for i, j in tqdm.tqdm(times_pair, total=len(times)):
274 | t = (torch.ones(n) * i).to(x.device)
275 | next_t = (torch.ones(n) * j).to(x.device)
276 | ######## if time travel do two passes, otherwise do one pass #########
277 | travels = tot_iters
278 | ct += 1
279 | 
280 | for zhoumu in range(travels):
281 | print("zhoumu: ", zhoumu)
282 | t = (torch.ones(n) * i).to(x.device)
283 | next_t = (torch.ones(n) * j).to(x.device)
284 | at = compute_alpha(self.betas, t.long())
285 | bt = torch.index_select(self.betas,0,t.long())
286 | at_next = compute_alpha(self.betas, next_t.long())
287 | at = at[0,0,0,0]
288 | at_next = at_next[0,0,0,0]
289 | ################################# reverse with consistency ########################################
290 | et_agg = list() ### initialize a list of scores
291 | 
292 | ########################################### ADJ slices #############################################
293 | if ct % 2 != 1:
294 | if vps:
295 | for M in range(1): #### number of VPS iterations
296 | noise = torch.randn_like(xt)
297 | #################### stochastic (Langevin-style) correction along the slice-jump score ####################
298 | et = self.vps_blend(xt, model, t, num_batches = num_batches)
299 | ####################################################################
300 | # if blend:
301 | # et = self.blendscore(xt, model, t, start_ind = 0)
302 | # else:
303 | # et = self.blendscore(xt, model,t, start_head = 1) ### 1 x 60 x 256 x 256
304 | et = torch.reshape(et, (num_batches, 3, 256, 256))
305 | lam_ = vps_scale
306 | xt = xt - lam_ * (1 - at).sqrt() * et
307 | xt = xt + ((lam_ * (2-lam_))*(1-at)).sqrt() * noise * 1
308 | if blend:
309 | et = self.blendscore(xt, model, t, num_batches = num_batches)
310 | else:
311 | y_cls = torch.ones(1) * 1 # class label 1; renamed so it does not clobber the measurements y
312 | y_cls = y_cls.to(xt.device).to(torch.long)
313 | model_kwargs = {}
314 | model_kwargs["y"] = y_cls
315 | for k in range(xt.shape[0]//1): # slab index (renamed from j, which holds the next timestep)
316 | et_sing = model(xt[k*1:(k+1)*1], t, **model_kwargs) #### 1 x 6 x 256 x 256
317 | et_agg.append(et_sing)
318 | et = torch.cat(et_agg, dim=0) #### num_batches x 6 x 256 x 256
319 | et = et[:, :3] #### num_batches x 3 x 256 x 256
320 | ### reshape xt and et
321 | et_ = torch.reshape(et, ((num_batches * 3), 1, 256, 256))
322 | ####################################### SLICE JUMP #############################################
323 | if ct % 2 == 1: ### alternate partitions: slabs 147/258/369 on odd steps
324 | print(ct, "changing et to slice jumping")
325 | et_ = self.vps_blend(xt, model, t, num_batches = num_batches)
326 | et_ = torch.reshape(et_, ((num_batches * 3), 1, 256, 256))
327 | xt_ = torch.reshape(xt, ((num_batches * 3), 1, 256, 256))
328 | x0_t = (xt_ - et_ * (1 - at).sqrt()) / at.sqrt() ### (num_batches*3) x 1 x 256 x 256, scale [-1, 1]
329 | 
330 | ########################### inverse problem solving (data consistency) ######################################################
331 | x0_t = torch.clip(x0_t, -1, 1) #### clip to [-1, 1]
332 | x0_t = (x0_t +1)/2 ### rescale to [0, 1]
333 | x0_t_hat = None
334 | eta = self.args.eta
335 | if zhoumu == 0:
336 | x0_t_hat = ADMM(x0_t, ATy, n_ADMM=n_ADMM) ###### [0,1]
337 | # x0_t_hat = torch.clip(x0_t_hat, 0, 1)
338 | x0_t_hat = x0_t_hat * 2 - 1 ####### rescale back to [-1, 1]
339 | else:
340 | x0_t_hat = x0_t * 2 - 1 ####### rescale back to [-1, 1]
341 | ############################################################################################################
342 | 
343 | ########################### unconditional sampling (bypasses data consistency) ######################################################
344 | # x0_t_hat = x0_t # leaving this line live would discard the ADMM update and the [-1, 1] rescale
345 | # eta = self.args.eta
346 | 
347 | c1 = (1 - at_next).sqrt() * eta
348 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
349 | if j != 0: # return the clean estimate directly at the step landing on t = 0
350 | xt_ = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et_
351 | else:
352 | xt_ = x0_t_hat
353 | xt = torch.reshape(xt_, (num_batches, 3, 256, 256)) #### reshape back
354 | 
355 | ######################################################################################################
356 | if noise_flag:
357 | print("added noise")
358 | np.save(f"ctrecon_jump_200NFE_{self.args.Nview}projs_pgnoise.npy", xt.detach().cpu().numpy())
359 | else:
360 | np.save(f"ctrecon_jump_200NFE_{self.args.Nview}projs.npy", xt.detach().cpu().numpy())
361 | 
362 | if self.args.deg == "SV-CT":
363 | np.save("x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_view6_47_jump_ful.npy", xt.detach().cpu().numpy())
364 | if self.args.deg == "LA-CT":
365 | np.save("x_sample_ddim" + str(ddimsteps) + f"_iter65000_lactL67_blend3_half{pre_slices}_full_view90.npy", xt.detach().cpu().numpy())
366 | 
367 | 
368 | if blend:
369 | if vps:
370 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho"+str(rho)+"ttnew"+str(tot_iters)+"_vps_"+ str(vps_scale)+"_full_skip2.npy", xt.detach().cpu().numpy())
371 | else:
372 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_jump_skip2_view6.npy", xt.detach().cpu().numpy())
373 | else:
374 | if vps:
375 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_rho" + str(rho) + "ttnew" + str(tot_iters) +"_vps_"+ str(vps_scale)+ "_full_skip2_view6.npy", xt.detach().cpu().numpy())
376 | else:
377 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_jump_skip2_view6.npy", xt.detach().cpu().numpy())
378 | 
379 | 
380 | 
--------------------------------------------------------------------------------
/eval_3D_blend_cond.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | Nview=8
4 | T_sampling=50
5 | eta=0.85
6 | 
7 | python main.py \
8 | --type '3dblendcond' \
9 | --config AAPM_256_lsun.yaml \
10 | --dataset_path "/nfs/turbo/coe-liyues/bowenbw/DDS/indist_samples/CT/L067" \
11 | --ckpt_load_name
"/nfs/turbo/coe-liyues/bowenbw/DDS/checkpoints/AAPM256_1M.pth" \ 12 | --Nview $Nview \ 13 | --eta $eta \ 14 | --deg "SV-CT" \ 15 | --sigma_y 0.00 \ 16 | --T_sampling 100 \ 17 | --T_sampling $T_sampling \ 18 | -i ./results 19 | -------------------------------------------------------------------------------- /eval_3D_blend_cond_lidc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | import glob 5 | import json 6 | import sys 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import tqdm 11 | import torch 12 | import torch.utils.data as data 13 | import torchvision.utils as tvu 14 | 15 | from guided_diffusion.models import Model 16 | from guided_diffusion.script_util import create_model, classifier_defaults, args_to_dict, create_gaussian_diffusion 17 | from guided_diffusion.utils import get_alpha_schedule 18 | import random 19 | 20 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 21 | from scipy.linalg import orth 22 | from pathlib import Path 23 | 24 | from physics.ct import CT 25 | from time import time 26 | from utils import shrink, CG, clear, batchfy, _Dz, _DzT, get_beta_schedule 27 | 28 | 29 | def compute_alpha(beta, t): 30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 32 | return a 33 | 34 | 35 | class Diffusion(object): 36 | def __init__(self, args, config, device=None): 37 | self.args = args 38 | self.args.image_folder = Path(self.args.image_folder) 39 | for t in ["input", "recon", "label"]: 40 | if t == "recon": 41 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True) 42 | else: 43 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True) 44 | self.config = config 45 | print(self.config) 46 | if device is None: 47 | device = ( 48 | torch.device("cuda") 49 | if torch.cuda.is_available() 50 | else torch.device("cpu") 51 | ) 52 | self.device = device 53 | 54 | self.model_var_type = config.model.var_type 55 | betas = get_beta_schedule( 56 | beta_schedule=config.diffusion.beta_schedule, 57 | beta_start=config.diffusion.beta_start, 58 | beta_end=config.diffusion.beta_end, 59 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 60 | ) 61 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 62 | self.num_timesteps = betas.shape[0] 63 | 64 | alphas = 1.0 - betas 65 | alphas_cumprod = alphas.cumprod(dim=0) 66 | alphas_cumprod_prev = torch.cat( 67 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 68 | ) 69 | self.alphas_cumprod_prev = alphas_cumprod_prev 70 | posterior_variance = ( 71 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 72 | ) 73 | if self.model_var_type == "fixedlarge": 74 | self.logvar = betas.log() 75 | elif self.model_var_type == "fixedsmall": 76 | self.logvar = posterior_variance.clamp(min=1e-20).log() 77 | 78 | def sample(self): 79 | config_dict = vars(self.config.model) 80 | config_dict['use_spacecode'] = False 81 | config_dict["class_cond"] = True 82 | model = create_model(**config_dict) 83 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/LIDC_triplane3D_finetunelarge_4232024_iter62099_cond.ckpt" 84 | 85 | model.load_state_dict(torch.load(ckpt, map_location=self.device)["state_dict"]) 86 | print(f"Model ckpt loaded from {ckpt}") 87 | model.to(self.device) 88 | model.eval() 89 | model = torch.nn.DataParallel(model) 90 | 91 | print('Run 3D DDS + 
DiffusionMBIR.',
92 | f'{self.args.T_sampling} sampling steps.',
93 | f'Task: {self.args.deg}.'
94 | )
95 | self.dds3d(model)
96 | 
97 | 
98 | def blendscore(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
99 | model_kwargs = {}
100 | y = torch.ones(1) * 1 # class label 1 = adjacent-slice partition
101 | y = y.to(xt.device).to(torch.long)
102 | model_kwargs["y"] = y
103 | # blend over 3-slice slabs at a (possibly random) offset; patch the two ends separately
104 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
105 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
106 | et[:,:3,:,:] = model(xt[:,:3,:,:], t, **model_kwargs)[:,:3]
107 | et[:,xt.shape[1]-3:, :,:] = model(xt[:,xt.shape[1]-3:,:,:], t, **model_kwargs)[:,:3]
108 | 
109 | if start_ind is None:
110 | start_ind = np.random.randint(start_head, 3)
111 | for j in range(start_ind, xt.shape[1]-2, 3):
112 | ##### randomly select the partition offset instead of summing over all offsets
113 | et_sing = model(xt[:,j:(j+3),:,:], t, **model_kwargs)[:,:3] ##### 1 x 3 x 256 x 256
114 | et[:,j:(j+3), :,:] = et_sing
115 | return et
116 | 
117 | def vps_blend(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
118 | model_kwargs = {}
119 | y = torch.ones(1) * 3 # class label 3 = slice-jump partition
120 | y = y.to(xt.device).to(torch.long)
121 | model_kwargs["y"] = y
122 | 
123 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
124 | 
125 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
126 | # slice-jump partition: slabs (1,4,7), (2,5,8), (3,6,9) within each block of 9 slices
127 | for i in range(0, xt.shape[1], 9):
128 | for m in range(3):
129 | et[:,[i+m,i+m+3, i+m+6],:,:] = model(xt[:,[i+m,i+m+3, i+m+6],:,:], t, **model_kwargs)[:,:3]
130 | 
131 | return et
132 | 
133 | 
134 | def dds3d(self, model):
135 | args, config = self.args, self.config
136 | print(f"Dataset path: {self.args.dataset_path}")
137 | root = Path(self.args.dataset_path)
138 | 
139 | # parameters to be moved to args
140 | Nview = self.args.Nview
141 | rho = 0.001 # ADMM penalty weight (hard-coded)
142 | lamb = 0.05/1000 # TV shrinkage weight (hard-coded)
143 | n_ADMM = 1
144 | n_CG = self.args.CG_iter
145 | print(n_CG)
146 | 
147 | blend = True # blend scores across slab partitions (the DiffusionBlend sampler)
148 | 
149 | time_travel = False # if True, each timestep is traversed twice (second pass without ADMM)
150 | 
151 | vps = False # enable the optional stochastic VPS correction
152 | 
153 | ddimsteps = 200 # number of DDIM steps (NFEs)
154 | 
155 | vps_scale = 0.03 # VPS correction step size (required when vps=True; mirrors eval_3DCT_blendcond.py)
156 | 
157 | # print(rho, lamb, "admm params")
158 | # print("blending", blend)
159 | 
160 | # Specify save directory for saving generated samples
161 | save_root = Path(self.args.image_folder)
162 | save_root.mkdir(parents=True, exist_ok=True)
163 | 
164 | irl_types = ['vol', 'input', 'recon', 'label']
165 | for t in irl_types:
166 | save_root_f = save_root / t
167 | save_root_f.mkdir(parents=True, exist_ok=True)
168 | 
169 | fname_list = os.listdir("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation_LIDC")
170 | root = "/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation_LIDC"
171 | 
172 | fname_list.sort()
173 | num_batches = 87 # number of 3-slice slabs taken from the validation set
174 | fname_list = fname_list[:(3 * num_batches)]
175 | 
176 | ######################################################################################################
177 | 
178 | print(fname_list)
179 | all_img = []
180 | batch_size = 3
181 | print("Loading all data")
182 | if time_travel:
183 | tot_iters = 2
184 | else:
185 | tot_iters = 1
186 | for fname in fname_list:
187 | just_name = fname.split('.')[0]
188 | img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))
189 | h, w = img.shape
190 | img = img.view(1, 1, h, w)
191 | all_img.append(img)
192 | all_img = torch.cat(all_img, dim=0)
193 | x_orig = all_img
194 | print(f"Data loaded shape : {all_img.shape}")
195 | x_orig
= x_orig.to(torch.float32) 196 | print("Data type is :", x_orig.dtype) 197 | img_shape = (x_orig.shape[0], config.data.channels, config.data.image_size, config.data.image_size) 198 | if self.args.deg == "SV-CT": 199 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=True, circle=False, device=config.device) 200 | elif self.args.deg == "LA-CT": 201 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=False, circle=False, device=config.device) 202 | A = lambda z: A_funcs.A(z) 203 | Ap = lambda z: A_funcs.A_dagger(z) 204 | def Acg_TV(x): 205 | return A_funcs.AT(A_funcs.A(x)) + rho * _DzT(_Dz(x)) 206 | def ADMM(x, ATy, n_ADMM=n_ADMM): 207 | nonlocal del_z, udel_z 208 | for _ in range(n_ADMM): 209 | bcg_TV = ATy + rho * (_DzT(del_z) - _DzT(udel_z)) 210 | x = CG(Acg_TV, bcg_TV, x, n_inner=n_CG) 211 | del_z = shrink(_Dz(x) + udel_z, lamb / rho) 212 | udel_z = _Dz(x) - del_z + udel_z 213 | return x 214 | del_z = torch.zeros(img_shape, device=self.device) 215 | udel_z = torch.zeros(img_shape, device=self.device) 216 | x_orig = x_orig.to(self.device) ######n x 1 x 256 x 256 217 | print(x_orig.min(), x_orig.max(), "xorig") 218 | y = A(x_orig) 219 | print(y.shape, "projection shape") 220 | Apy = Ap(y) 221 | print(Apy.shape, "Apy backprojection shape") 222 | ATy = A_funcs.AT(y) 223 | ##########################original#################################################### 224 | x = torch.randn(num_batches, 3, 256, 256, device = self.device) ####initial noise 225 | ################################################################################################## 226 | 227 | diffusion = create_gaussian_diffusion( 228 | steps=1000, 229 | learn_sigma=True, 230 | noise_schedule="linear", 231 | use_kl=False, 232 | predict_xstart=False, 233 | rescale_timesteps=False, 234 | rescale_learned_sigmas=False, 235 | timestep_respacing="", 236 | ) 237 | xt = None 238 | with torch.no_grad(): 239 | skip = config.diffusion.num_diffusion_timesteps//ddimsteps 240 | n = x.size(0) 241 | x0_preds = [] 242 | xt = x ###20 x 3 x 256 x 256 243 | 244 | # generate time schedule 245 | times = range(0, 1000, skip) #########0, 1, 2, .... 
246 | times_next = [-1] + list(times[:-1])
247 | times_pair = zip(reversed(times), reversed(times_next))
248 | 
249 | # both branches of the original if/else set the same value; one timestep index is broadcast to every slab
250 | n = 1
251 | 
252 | 
253 | 
254 | ct = 0
255 | ################################### start reverse sampling ############################################
256 | for i, j in tqdm.tqdm(times_pair, total=len(times)):
257 | t = (torch.ones(n) * i).to(x.device)
258 | next_t = (torch.ones(n) * j).to(x.device)
259 | ######## if time travel do two passes, otherwise do one pass #########
260 | travels = tot_iters
261 | ct += 1
262 | 
263 | for zhoumu in range(travels):
264 | print("zhoumu: ", zhoumu)
265 | t = (torch.ones(n) * i).to(x.device)
266 | next_t = (torch.ones(n) * j).to(x.device)
267 | at = compute_alpha(self.betas, t.long())
268 | bt = torch.index_select(self.betas,0,t.long())
269 | at_next = compute_alpha(self.betas, next_t.long())
270 | at = at[0,0,0,0]
271 | at_next = at_next[0,0,0,0]
272 | ################################# reverse with consistency ########################################
273 | et_agg = list() ### initialize a list of scores
274 | ########################################### ADJ slices #############################################
275 | if ct % 2 == 0:
276 | if vps:
277 | for M in range(1): #### number of VPS iterations
278 | noise = torch.randn_like(xt)
279 | #################### stochastic (Langevin-style) correction along the slice-jump score ####################
280 | et = self.vps_blend(xt, model, t, num_batches = num_batches)
281 | ####################################################################
282 | et = torch.reshape(et, (num_batches, 3, 256, 256))
283 | lam_ = vps_scale
284 | xt = xt - lam_ * (1 - at).sqrt() * et
285 | xt = xt + ((lam_ * (2-lam_))*(1-at)).sqrt() * noise * 1
286 | if blend:
287 | et = self.blendscore(xt, model, t, num_batches = num_batches)
288 | else:
289 | y_cls = torch.ones(1) * 1 # class label 1; renamed so it does not clobber the measurements y
290 | y_cls = y_cls.to(xt.device).to(torch.long)
291 | model_kwargs = {}
292 | model_kwargs["y"] = y_cls
293 | for k in range(xt.shape[0]//1): # slab index (renamed from j, which holds the next timestep)
294 | et_sing = model(xt[k*1:(k+1)*1], t, **model_kwargs) #### 1 x 6 x 256 x 256
295 | et_agg.append(et_sing)
296 | et = torch.cat(et_agg, dim=0) #### num_batches x 6 x 256 x 256
297 | et = et[:, :3] #### num_batches x 3 x 256 x 256
298 | ### reshape xt and et
299 | et_ = torch.reshape(et, ((num_batches * 3), 1, 256, 256))
300 | ####################################### SLICE JUMP #############################################
301 | if ct % 2 == 1: ### alternate partitions: slabs 147/258/369 on odd steps
302 | print(ct, "changing et to slice jumping")
303 | et_ = self.vps_blend(xt, model, t, num_batches = num_batches)
304 | et_ = torch.reshape(et_, ((num_batches * 3), 1, 256, 256))
305 | xt_ = torch.reshape(xt, ((num_batches * 3), 1, 256, 256))
306 | x0_t = (xt_ - et_ * (1 - at).sqrt()) / at.sqrt() ### (num_batches*3) x 1 x 256 x 256, scale [-1, 1]
307 | # x0_t = torch.clip(x0_t, -1, 1) #### clip to [-1, 1]
308 | x0_t = (x0_t +1)/2 ### rescale to [0, 1]
309 | x0_t_hat = None
310 | eta = self.args.eta
311 | if zhoumu == 0:
312 | x0_t_hat = ADMM(x0_t, ATy, n_ADMM=n_ADMM) ###### [0,1]
313 | # x0_t_hat = torch.clip(x0_t_hat, 0, 1)
314 | x0_t_hat = x0_t_hat * 2 - 1 ####### rescale back to [-1, 1]
315 | else:
316 | x0_t_hat = x0_t * 2 - 1 ####### rescale back to [-1, 1]
317 | 
318 | c1 = (1 - at_next).sqrt() * eta
319 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
320 | if j != 0: # return the clean estimate directly at the step landing on t = 0
321 | xt_ = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et_
322 | else:
323 | xt_ = x0_t_hat
324 | xt = torch.reshape(xt_, (num_batches, 3, 256, 256)) #### reshape back
325 | 
326 | if self.args.deg == "SV-CT":
327 | np.save("x_sample_ddim" +
str(ddimsteps) + "_iter62000large_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_view4_424_jump_LIDC_ftldct_half_retest.npy", xt.detach().cpu().numpy()) #############imnet pretrain 40000 iters 328 | 329 | 330 | if self.args.deg == "LA-CT": 331 | np.save("x_sample_ddim" + str(ddimsteps) + "_iter62000large_lactlidc_blend3_full_view90.npy", xt.detach().cpu().numpy()) 332 | 333 | -------------------------------------------------------------------------------- /guided_diffusion/CTDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import datasets 4 | from torchvision.transforms import ToTensor 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from glob import glob 9 | 10 | 11 | class CTDataset(Dataset): 12 | def __init__(self, metadata=None, img_dir=None, transform=None, target_transform=lambda x: x, patient_num = -1): 13 | 14 | self.training_paths = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*') 15 | self.transform = transform 16 | self.target_transform = target_transform 17 | self.patient_num = patient_num 18 | print("length of training data", len(self.training_paths)) 19 | 20 | 21 | def __len__(self): 22 | return len(self.training_paths) 23 | 24 | def __getitem__(self, idx): 25 | 26 | image = np.load(self.training_paths[idx]) 27 | image = np.transpose(image, (2,0,1)) 28 | image = np.clip(image*2-1, -1, 1) 29 | 30 | return torch.from_numpy(image) 31 | 32 | 33 | class CTCondDataset(Dataset): 34 | def __init__(self, metadata=None, img_dir=None, transform=None, target_transform=lambda x: x, patient_num = -1): 35 | 36 | self.training_paths = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*') 37 | self.transform = transform 38 | self.target_transform = target_transform 39 | self.patient_num = patient_num 40 | print("length of training data", len(self.training_paths)) 41 | 42 | def __len__(self): 43 | return len(self.training_paths) 44 | 45 | 46 | def __getitem__(self, idx): 47 | image = None 48 | return None 49 | 50 | 51 | 52 | if __name__ == "__main__": 53 | ds = CTDataset() 54 | params = {'batch_size': 2} 55 | training_generator = torch.utils.data.DataLoader(ds, **params) 56 | ct = 0 57 | for local_batch in training_generator: 58 | print(local_batch.shape) 59 | ct += 1 60 | if ct > 4: 61 | break -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/CTDataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/CTDataset.cpython-38.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /guided_diffusion/__pycache__/fp16_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc 
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 | 
5 | import io
6 | import os
7 | import socket
8 | 
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 | 
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 | 
18 | SETUP_RETRY_COUNT = 3
19 | 
20 | 
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 | 
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 | 
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 | 
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 | 
44 | 
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 | return th.device(f"cuda")
51 | return th.device("cpu")
52 | 
53 | 
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 | 
74 | return th.load(io.BytesIO(data), **kwargs)
75 | 
76 | 
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | use_mps=False, 21 | ): 22 | if not data_dir: 23 | raise ValueError("unspecified data directory") 24 | all_files = _list_image_files_recursively(data_dir, use_mps=use_mps) 25 | classes = None 26 | if class_cond: 27 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 28 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 29 | classes = [sorted_classes[x] for x in class_names] 30 | dataset = NpyDataset( 31 | image_size, 32 | all_files, 33 | classes=classes, 34 | shard=MPI.COMM_WORLD.Get_rank(), 35 | num_shards=MPI.COMM_WORLD.Get_size(), 36 | random_crop=random_crop, 37 | random_flip=random_flip, 38 | ) 39 | if deterministic: 40 | loader = DataLoader( 41 | dataset, batch_size=batch_size, 
shuffle=False, num_workers=1, drop_last=True
42 | )
43 | else:
44 | loader = DataLoader(
45 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
46 | )
47 | while True:
48 | yield from loader
49 | 
50 | 
51 | def _list_image_files_recursively(data_dir, use_mps=False):
52 | results = []
53 | for entry in sorted(bf.listdir(data_dir)):
54 | if use_mps:
55 | if entry == "mps":
56 | continue
57 | else:
58 | if "img" in entry:
59 | continue
60 | full_path = bf.join(data_dir, entry)
61 | ext = entry.split(".")[-1]
62 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif", "npy"]:
63 | results.append(full_path)
64 | elif bf.isdir(full_path):
65 | results.extend(_list_image_files_recursively(full_path, use_mps=use_mps)) # propagate the filter when recursing
66 | return results
67 | 
68 | 
69 | class ImageDataset(Dataset):
70 | def __init__(
71 | self,
72 | resolution,
73 | image_paths,
74 | classes=None,
75 | shard=0,
76 | num_shards=1,
77 | random_crop=False,
78 | random_flip=True,
79 | ):
80 | super().__init__()
81 | self.resolution = resolution
82 | self.local_images = image_paths[shard:][::num_shards]
83 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
84 | self.random_crop = random_crop
85 | self.random_flip = random_flip
86 | 
87 | def __len__(self):
88 | return len(self.local_images)
89 | 
90 | def __getitem__(self, idx):
91 | path = self.local_images[idx]
92 | with bf.BlobFile(path, "rb") as f:
93 | pil_image = Image.open(f)
94 | pil_image.load()
95 | pil_image = pil_image.convert("RGB")
96 | 
97 | if self.random_crop:
98 | arr = random_crop_arr(pil_image, self.resolution)
99 | else:
100 | arr = center_crop_arr(pil_image, self.resolution)
101 | 
102 | if self.random_flip and random.random() < 0.5:
103 | arr = arr[:, ::-1]
104 | 
105 | arr = arr.astype(np.float32) / 127.5 - 1
106 | 
107 | out_dict = {}
108 | if self.local_classes is not None:
109 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
110 | return np.transpose(arr, [2, 0, 1]), out_dict
111 | 
112 | 
113 | class NpyDataset(Dataset):
114 | def __init__(
115 | self,
116 | resolution,
117 | image_paths,
118 | classes=None,
119 | shard=0,
120 | num_shards=1,
121 | random_crop=False,
122 | random_flip=True,
123 | ):
124 | super().__init__()
125 | self.resolution = resolution
126 | self.local_images = image_paths[shard:][::num_shards]
127 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
128 | self.random_crop = random_crop
129 | self.random_flip = random_flip
130 | 
131 | def __len__(self):
132 | return len(self.local_images)
133 | 
134 | def __getitem__(self, idx):
135 | path = self.local_images[idx]
136 | arr = np.load(path)
137 | arr = arr[np.newaxis, :, :]
138 | out_dict = {}
139 | return arr, out_dict
140 | 
141 | 
142 | def center_crop_arr(pil_image, image_size):
143 | # We are not on a new enough PIL to support the `reducing_gap`
144 | # argument, which uses BOX downsampling at powers of two first.
145 | # Thus, we do it by hand to improve downsample quality.
146 | while min(*pil_image.size) >= 2 * image_size: 147 | pil_image = pil_image.resize( 148 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 149 | ) 150 | 151 | scale = image_size / min(*pil_image.size) 152 | pil_image = pil_image.resize( 153 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 154 | ) 155 | 156 | arr = np.array(pil_image) 157 | crop_y = (arr.shape[0] - image_size) // 2 158 | crop_x = (arr.shape[1] - image_size) // 2 159 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 160 | 161 | 162 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 163 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 164 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 165 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 166 | 167 | # We are not on a new enough PIL to support the `reducing_gap` 168 | # argument, which uses BOX downsampling at powers of two first. 169 | # Thus, we do it by hand to improve downsample quality. 170 | while min(*pil_image.size) >= 2 * smaller_dim_size: 171 | pil_image = pil_image.resize( 172 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 173 | ) 174 | 175 | scale = smaller_dim_size / min(*pil_image.size) 176 | pil_image = pil_image.resize( 177 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 178 | ) 179 | 180 | arr = np.array(pil_image) 181 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 182 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 183 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 184 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write 
out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
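# (TensorBoard orders scalar points by this step value; here it is simply
# a local counter that writekvs advances once per dump.)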
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
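# No MPI launcher environment variable was found, so assume a
# single-process (rank 0) run.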
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
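For two diagonal Gaussians this evaluates, elementwise and in nats,
KL(p1 || p2) = log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2,
where m = mean and s^2 = exp(logvar).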
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided_diffusion/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 
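Concretely, with half_dim = embedding_dim // 2, each output row is
[sin(t * w_0), ..., sin(t * w_{half_dim-1}), cos(t * w_0), ...] with
frequencies w_k = 10000 ** (-k / (half_dim - 1)).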
13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x*torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate( 49 | x, scale_factor=2.0, mode="nearest") 50 | if self.with_conv: 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0, 1, 0, 1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | 77 | class ResnetBlock(nn.Module): 78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 79 | dropout, temb_channels=512): 80 | super().__init__() 81 | self.in_channels = in_channels 82 | out_channels = in_channels if out_channels is None else out_channels 83 | self.out_channels = out_channels 84 | self.use_conv_shortcut = conv_shortcut 85 | 86 | self.norm1 = Normalize(in_channels) 87 | self.conv1 = torch.nn.Conv2d(in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x+h 135 | 136 | 137 | class AttnBlock(nn.Module): 138 | def __init__(self, in_channels): 139 | super().__init__() 140 | 
self.in_channels = in_channels 141 | 142 | self.norm = Normalize(in_channels) 143 | self.q = torch.nn.Conv2d(in_channels, 144 | in_channels, 145 | kernel_size=1, 146 | stride=1, 147 | padding=0) 148 | self.k = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.v = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.proj_out = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b, c, h, w = q.shape 173 | q = q.reshape(b, c, h*w) 174 | q = q.permute(0, 2, 1) # b,hw,c 175 | k = k.reshape(b, c, h*w) # b,c,hw 176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c)**(-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b, c, h*w) 182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = torch.bmm(v, w_) 185 | h_ = h_.reshape(b, c, h, w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x+h_ 190 | 191 | 192 | class Model(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.config = config 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 197 | num_res_blocks = config.model.num_res_blocks 198 | attn_resolutions = config.model.attn_resolutions 199 | dropout = config.model.dropout 200 | in_channels = config.model.in_channels 201 | resolution = config.data.image_size 202 | resamp_with_conv = config.model.resamp_with_conv 203 | num_timesteps = config.diffusion.num_diffusion_timesteps 204 | 205 | if config.model.type == 'bayesian': 206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps)) 207 | 208 | self.ch = ch 209 | self.temb_ch = self.ch*4 210 | self.num_resolutions = len(ch_mult) 211 | self.num_res_blocks = num_res_blocks 212 | self.resolution = resolution 213 | self.in_channels = in_channels 214 | 215 | # timestep embedding 216 | self.temb = nn.Module() 217 | self.temb.dense = nn.ModuleList([ 218 | torch.nn.Linear(self.ch, 219 | self.temb_ch), 220 | torch.nn.Linear(self.temb_ch, 221 | self.temb_ch), 222 | ]) 223 | 224 | # downsampling 225 | self.conv_in = torch.nn.Conv2d(in_channels, 226 | self.ch, 227 | kernel_size=3, 228 | stride=1, 229 | padding=1) 230 | 231 | curr_res = resolution 232 | in_ch_mult = (1,)+ch_mult 233 | self.down = nn.ModuleList() 234 | block_in = None 235 | for i_level in range(self.num_resolutions): 236 | block = nn.ModuleList() 237 | attn = nn.ModuleList() 238 | block_in = ch*in_ch_mult[i_level] 239 | block_out = ch*ch_mult[i_level] 240 | for i_block in range(self.num_res_blocks): 241 | block.append(ResnetBlock(in_channels=block_in, 242 | out_channels=block_out, 243 | temb_channels=self.temb_ch, 244 | dropout=dropout)) 245 | block_in = block_out 246 | if curr_res in attn_resolutions: 247 | attn.append(AttnBlock(block_in)) 248 | down = nn.Module() 249 | down.block = block 250 | down.attn = attn 251 | if i_level != self.num_resolutions-1: 252 | down.downsample = Downsample(block_in, resamp_with_conv) 253 | curr_res = curr_res // 2 254 | self.down.append(down) 255 | 256 | # middle 257 | self.mid = nn.Module() 258 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 259 | 
out_channels=block_in, 260 | temb_channels=self.temb_ch, 261 | dropout=dropout) 262 | self.mid.attn_1 = AttnBlock(block_in) 263 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 264 | out_channels=block_in, 265 | temb_channels=self.temb_ch, 266 | dropout=dropout) 267 | 268 | # upsampling 269 | self.up = nn.ModuleList() 270 | for i_level in reversed(range(self.num_resolutions)): 271 | block = nn.ModuleList() 272 | attn = nn.ModuleList() 273 | block_out = ch*ch_mult[i_level] 274 | skip_in = ch*ch_mult[i_level] 275 | for i_block in range(self.num_res_blocks+1): 276 | if i_block == self.num_res_blocks: 277 | skip_in = ch*in_ch_mult[i_level] 278 | block.append(ResnetBlock(in_channels=block_in+skip_in, 279 | out_channels=block_out, 280 | temb_channels=self.temb_ch, 281 | dropout=dropout)) 282 | block_in = block_out 283 | if curr_res in attn_resolutions: 284 | attn.append(AttnBlock(block_in)) 285 | up = nn.Module() 286 | up.block = block 287 | up.attn = attn 288 | if i_level != 0: 289 | up.upsample = Upsample(block_in, resamp_with_conv) 290 | curr_res = curr_res * 2 291 | self.up.insert(0, up) # prepend to get consistent order 292 | 293 | # end 294 | self.norm_out = Normalize(block_in) 295 | self.conv_out = torch.nn.Conv2d(block_in, 296 | out_ch, 297 | kernel_size=3, 298 | stride=1, 299 | padding=1) 300 | 301 | def forward(self, x, t): 302 | assert x.shape[2] == x.shape[3] == self.resolution 303 | 304 | # timestep embedding 305 | temb = get_timestep_embedding(t, self.ch) 306 | temb = self.temb.dense[0](temb) 307 | temb = nonlinearity(temb) 308 | temb = self.temb.dense[1](temb) 309 | 310 | # downsampling 311 | hs = [self.conv_in(x)] 312 | for i_level in range(self.num_resolutions): 313 | for i_block in range(self.num_res_blocks): 314 | h = self.down[i_level].block[i_block](hs[-1], temb) 315 | if len(self.down[i_level].attn) > 0: 316 | h = self.down[i_level].attn[i_block](h) 317 | hs.append(h) 318 | if i_level != self.num_resolutions-1: 319 | hs.append(self.down[i_level].downsample(hs[-1])) 320 | 321 | # middle 322 | h = hs[-1] 323 | h = self.mid.block_1(h, temb) 324 | h = self.mid.attn_1(h) 325 | h = self.mid.block_2(h, temb) 326 | 327 | # upsampling 328 | for i_level in reversed(range(self.num_resolutions)): 329 | for i_block in range(self.num_res_blocks+1): 330 | h = self.up[i_level].block[i_block]( 331 | torch.cat([h, hs.pop()], dim=1), temb) 332 | if len(self.up[i_level].attn) > 0: 333 | h = self.up[i_level].attn[i_block](h) 334 | if i_level != 0: 335 | h = self.up[i_level].upsample(h) 336 | 337 | # end 338 | h = self.norm_out(h) 339 | h = nonlinearity(h) 340 | h = self.conv_out(h) 341 | return h 342 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 
25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
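When flag is True, only `inputs` are stored for the backward pass and the
activations of `func` are recomputed then, trading roughly one extra
forward pass through `func` for the memory saved.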
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
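The weights equal 1 / (num_timesteps * p[t]), so the weighted loss
remains an unbiased estimate of the uniform-timestep objective.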
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
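Duplicate timesteps are allowed and appear once per occurrence.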
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24; use a concrete dtype 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a set of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there are 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use.
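Illustrative examples: space_timesteps(100, [10]) yields
{0, 11, 22, 33, 44, 55, 66, 77, 88, 99}, while space_timesteps(1000, "ddim50")
yields the 50 evenly strided steps {0, 20, 40, ..., 980}.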
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 
46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | **kwargs 148 | ): 149 | 150 | 151 | 152 | if channel_mult == "": 153 | if image_size == 512: 154 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 155 | elif image_size == 256 or image_size == 320: 156 | channel_mult = (1, 1, 2, 2, 4, 4) 157 | elif image_size == 128: 158 | channel_mult = (1, 1, 2, 3, 4) 159 | elif image_size == 64: 160 | channel_mult = (1, 2, 3, 4) 161 | else: 162 | raise ValueError(f"unsupported image size: {image_size}") 163 | else: 164 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 165 | 166 | attention_ds = [] 167 | for res in attention_resolutions.split(","): 168 | attention_ds.append(image_size // int(res)) 169 | 170 | return UNetModel( 171 | image_size=image_size, 172 | in_channels=kwargs["in_channels"], 173 | model_channels=num_channels, 174 | out_channels=(kwargs["in_channels"] if not learn_sigma else kwargs["in_channels"] * 2), 175 | num_res_blocks=num_res_blocks, 176 | attention_resolutions=tuple(attention_ds), 177 | 
dropout=dropout, 178 | channel_mult=channel_mult, 179 | num_classes=(NUM_CLASSES if class_cond else None), 180 | use_checkpoint=use_checkpoint, 181 | use_fp16=use_fp16, 182 | num_heads=num_heads, 183 | num_head_channels=num_head_channels, 184 | num_heads_upsample=num_heads_upsample, 185 | use_scale_shift_norm=use_scale_shift_norm, 186 | resblock_updown=resblock_updown, 187 | use_new_attention_order=use_new_attention_order, 188 | use_spacecode = kwargs["use_spacecode"], 189 | ) 190 | 191 | 192 | def create_classifier_and_diffusion( 193 | image_size, 194 | classifier_use_fp16, 195 | classifier_width, 196 | classifier_depth, 197 | classifier_attention_resolutions, 198 | classifier_use_scale_shift_norm, 199 | classifier_resblock_updown, 200 | classifier_pool, 201 | learn_sigma, 202 | diffusion_steps, 203 | noise_schedule, 204 | timestep_respacing, 205 | use_kl, 206 | predict_xstart, 207 | rescale_timesteps, 208 | rescale_learned_sigmas, 209 | ): 210 | classifier = create_classifier( 211 | image_size, 212 | classifier_use_fp16, 213 | classifier_width, 214 | classifier_depth, 215 | classifier_attention_resolutions, 216 | classifier_use_scale_shift_norm, 217 | classifier_resblock_updown, 218 | classifier_pool, 219 | ) 220 | diffusion = create_gaussian_diffusion( 221 | steps=diffusion_steps, 222 | learn_sigma=learn_sigma, 223 | noise_schedule=noise_schedule, 224 | use_kl=use_kl, 225 | predict_xstart=predict_xstart, 226 | rescale_timesteps=rescale_timesteps, 227 | rescale_learned_sigmas=rescale_learned_sigmas, 228 | timestep_respacing=timestep_respacing, 229 | ) 230 | return classifier, diffusion 231 | 232 | 233 | def create_classifier( 234 | image_size, 235 | classifier_use_fp16, 236 | classifier_width, 237 | classifier_depth, 238 | classifier_attention_resolutions, 239 | classifier_use_scale_shift_norm, 240 | classifier_resblock_updown, 241 | classifier_pool, 242 | ): 243 | if image_size == 512: 244 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 245 | elif image_size == 256: 246 | channel_mult = (1, 1, 2, 2, 4, 4) 247 | elif image_size == 128: 248 | channel_mult = (1, 1, 2, 3, 4) 249 | elif image_size == 64: 250 | channel_mult = (1, 2, 3, 4) 251 | else: 252 | raise ValueError(f"unsupported image size: {image_size}") 253 | 254 | attention_ds = [] 255 | for res in classifier_attention_resolutions.split(","): 256 | attention_ds.append(image_size // int(res)) 257 | 258 | return EncoderUNetModel( 259 | image_size=image_size, 260 | in_channels=3, 261 | model_channels=classifier_width, 262 | out_channels=1000, 263 | num_res_blocks=classifier_depth, 264 | attention_resolutions=tuple(attention_ds), 265 | channel_mult=channel_mult, 266 | use_fp16=classifier_use_fp16, 267 | num_head_channels=64, 268 | use_scale_shift_norm=classifier_use_scale_shift_norm, 269 | resblock_updown=classifier_resblock_updown, 270 | pool=classifier_pool, 271 | ) 272 | 273 | 274 | def sr_model_and_diffusion_defaults(): 275 | res = model_and_diffusion_defaults() 276 | res["large_size"] = 256 277 | res["small_size"] = 64 278 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 279 | for k in res.copy().keys(): 280 | if k not in arg_names: 281 | del res[k] 282 | return res 283 | 284 | 285 | def sr_create_model_and_diffusion( 286 | large_size, 287 | small_size, 288 | class_cond, 289 | learn_sigma, 290 | num_channels, 291 | num_res_blocks, 292 | num_heads, 293 | num_head_channels, 294 | num_heads_upsample, 295 | attention_resolutions, 296 | dropout, 297 | diffusion_steps, 298 | noise_schedule, 299 | timestep_respacing, 300 | 
use_kl, 301 | predict_xstart, 302 | rescale_timesteps, 303 | rescale_learned_sigmas, 304 | use_checkpoint, 305 | use_scale_shift_norm, 306 | resblock_updown, 307 | use_fp16, 308 | ): 309 | model = sr_create_model( 310 | large_size, 311 | small_size, 312 | num_channels, 313 | num_res_blocks, 314 | learn_sigma=learn_sigma, 315 | class_cond=class_cond, 316 | use_checkpoint=use_checkpoint, 317 | attention_resolutions=attention_resolutions, 318 | num_heads=num_heads, 319 | num_head_channels=num_head_channels, 320 | num_heads_upsample=num_heads_upsample, 321 | use_scale_shift_norm=use_scale_shift_norm, 322 | dropout=dropout, 323 | resblock_updown=resblock_updown, 324 | use_fp16=use_fp16, 325 | ) 326 | diffusion = create_gaussian_diffusion( 327 | steps=diffusion_steps, 328 | learn_sigma=learn_sigma, 329 | noise_schedule=noise_schedule, 330 | use_kl=use_kl, 331 | predict_xstart=predict_xstart, 332 | rescale_timesteps=rescale_timesteps, 333 | rescale_learned_sigmas=rescale_learned_sigmas, 334 | timestep_respacing=timestep_respacing, 335 | ) 336 | return model, diffusion 337 | 338 | 339 | def sr_create_model( 340 | large_size, 341 | small_size, 342 | num_channels, 343 | num_res_blocks, 344 | learn_sigma, 345 | class_cond, 346 | use_checkpoint, 347 | attention_resolutions, 348 | num_heads, 349 | num_head_channels, 350 | num_heads_upsample, 351 | use_scale_shift_norm, 352 | dropout, 353 | resblock_updown, 354 | use_fp16, 355 | ): 356 | _ = small_size # hack to prevent unused variable 357 | 358 | if large_size == 512: 359 | channel_mult = (1, 1, 2, 2, 4, 4) 360 | elif large_size == 256: 361 | channel_mult = (1, 1, 2, 2, 4, 4) 362 | elif large_size == 64: 363 | channel_mult = (1, 2, 3, 4) 364 | else: 365 | raise ValueError(f"unsupported large size: {large_size}") 366 | 367 | attention_ds = [] 368 | for res in attention_resolutions.split(","): 369 | attention_ds.append(large_size // int(res)) 370 | 371 | return SuperResModel( 372 | image_size=large_size, 373 | in_channels=3, 374 | model_channels=num_channels, 375 | out_channels=(3 if not learn_sigma else 6), 376 | num_res_blocks=num_res_blocks, 377 | attention_resolutions=tuple(attention_ds), 378 | dropout=dropout, 379 | channel_mult=channel_mult, 380 | num_classes=(NUM_CLASSES if class_cond else None), 381 | use_checkpoint=use_checkpoint, 382 | num_heads=num_heads, 383 | num_head_channels=num_head_channels, 384 | num_heads_upsample=num_heads_upsample, 385 | use_scale_shift_norm=use_scale_shift_norm, 386 | resblock_updown=resblock_updown, 387 | use_fp16=use_fp16, 388 | ) 389 | 390 | 391 | def create_gaussian_diffusion( 392 | *, 393 | steps=1000, 394 | learn_sigma=False, 395 | sigma_small=False, 396 | noise_schedule="linear", 397 | use_kl=False, 398 | predict_xstart=False, 399 | rescale_timesteps=False, 400 | rescale_learned_sigmas=False, 401 | timestep_respacing="", 402 | ): 403 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 404 | if use_kl: 405 | loss_type = gd.LossType.RESCALED_KL 406 | elif rescale_learned_sigmas: 407 | loss_type = gd.LossType.RESCALED_MSE 408 | else: 409 | loss_type = gd.LossType.MSE 410 | if not timestep_respacing: 411 | timestep_respacing = [steps] 412 | return SpacedDiffusion( 413 | use_timesteps=space_timesteps(steps, timestep_respacing), 414 | betas=betas, 415 | model_mean_type=( 416 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 417 | ), 418 | model_var_type=( 419 | ( 420 | gd.ModelVarType.FIXED_LARGE 421 | if not sigma_small 422 | else gd.ModelVarType.FIXED_SMALL 423 | ) 424 | if 
not learn_sigma 425 | else gd.ModelVarType.LEARNED_RANGE 426 | ), 427 | loss_type=loss_type, 428 | rescale_timesteps=rescale_timesteps, 429 | ) 430 | 431 | 432 | def add_dict_to_argparser(parser, default_dict): 433 | for k, v in default_dict.items(): 434 | v_type = type(v) 435 | if v is None: 436 | v_type = str 437 | elif isinstance(v, bool): 438 | v_type = str2bool 439 | parser.add_argument(f"--{k}", default=v, type=v_type) 440 | 441 | 442 | def args_to_dict(args, keys): 443 | return {k: getattr(args, k) for k in keys} 444 | 445 | 446 | def str2bool(v): 447 | """ 448 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 449 | """ 450 | if isinstance(v, bool): 451 | return v 452 | if v.lower() in ("yes", "true", "t", "y", "1"): 453 | return True 454 | elif v.lower() in ("no", "false", "f", "n", "0"): 455 | return False 456 | else: 457 | raise argparse.ArgumentTypeError("boolean value expected") 458 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | import matplotlib.pyplot as plt 5 | 6 | import blobfile as bf 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | from pathlib import Path 12 | 13 | from . import dist_util, logger 14 | from .fp16_util import MixedPrecisionTrainer 15 | from .nn import update_ema 16 | from .resample import LossAwareSampler, UniformSampler 17 | from utils import clear 18 | 19 | # For ImageNet experiments, this was a good default value. 20 | # We found that the lg_loss_scale quickly climbed to 21 | # 20-21 within the first ~1K steps of training. 
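# NOTE (assumed fix, not in the original file): the legacy TrainLoop2 class
# below calls zero_grad, make_master_params, model_grads_to_master_grads,
# master_params_to_model_params and unflatten_master_params, and
# _log_grad_norm calls np.sqrt, yet only MixedPrecisionTrainer is imported
# above. Imports along these lines are needed for TrainLoop2 to run; beware
# that the MixedPrecisionTrainer-era fp16_util exposes these helpers with
# param-group/shape signatures, so the legacy fp16 path may need adapting.
# Note also that TrainLoop2.__init__ deliberately short-circuits its DDP
# branch with `and 1==2`, so it always runs single-process.
import numpy as np
from .fp16_util import (
    make_master_params,
    master_params_to_model_params,
    model_grads_to_master_grads,
    unflatten_master_params,
    zero_grad,
)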
22 | INITIAL_LOG_LOSS_SCALE = 20.0 23 | 24 | class TrainLoop2: 25 | def __init__( 26 | self, 27 | *, 28 | model, 29 | diffusion, 30 | data, 31 | batch_size, 32 | microbatch, 33 | lr, 34 | ema_rate, 35 | log_interval, 36 | save_interval, 37 | resume_checkpoint, 38 | use_fp16=False, 39 | fp16_scale_growth=1e-3, 40 | schedule_sampler=None, 41 | weight_decay=0.0, 42 | lr_anneal_steps=0, 43 | ): 44 | self.model = model 45 | self.diffusion = diffusion 46 | self.data = data 47 | self.batch_size = batch_size 48 | self.microbatch = microbatch if microbatch > 0 else batch_size 49 | self.lr = lr 50 | self.ema_rate = ( 51 | [ema_rate] 52 | if isinstance(ema_rate, float) 53 | else [float(x) for x in ema_rate.split(",")] 54 | ) 55 | self.log_interval = log_interval 56 | self.save_interval = save_interval 57 | self.resume_checkpoint = resume_checkpoint 58 | self.use_fp16 = use_fp16 59 | self.fp16_scale_growth = fp16_scale_growth 60 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 61 | self.weight_decay = weight_decay 62 | self.lr_anneal_steps = lr_anneal_steps 63 | 64 | self.step = 0 65 | self.resume_step = 0 66 | self.global_batch = self.batch_size * dist.get_world_size() 67 | 68 | self.model_params = list(self.model.parameters()) 69 | self.master_params = self.model_params 70 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 71 | self.sync_cuda = th.cuda.is_available() 72 | 73 | self._load_and_sync_parameters() 74 | if self.use_fp16: 75 | self._setup_fp16() 76 | 77 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 88 | ] 89 | 90 | if th.cuda.is_available() and 1==2: 91 | self.use_ddp = True 92 | self.ddp_model = DDP( 93 | self.model, 94 | device_ids=[dist_util.dev()], 95 | output_device=dist_util.dev(), 96 | broadcast_buffers=False, 97 | bucket_cap_mb=128, 98 | find_unused_parameters=False, 99 | ) 100 | else: 101 | if dist.get_world_size() > 1: 102 | logger.warn( 103 | "Distributed training requires CUDA. " 104 | "Gradients will not be synchronized properly!" 
105 | ) 106 | self.use_ddp = False 107 | self.ddp_model = self.model 108 | 109 | def _load_and_sync_parameters(self): 110 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 111 | 112 | if resume_checkpoint: 113 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 114 | if dist.get_rank() == 0: 115 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 116 | self.model.load_state_dict( 117 | dist_util.load_state_dict( 118 | resume_checkpoint, map_location=dist_util.dev() 119 | ) 120 | ) 121 | 122 | dist_util.sync_params(self.model.parameters()) 123 | 124 | def _load_ema_parameters(self, rate): 125 | ema_params = copy.deepcopy(self.master_params) 126 | 127 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 128 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 129 | if ema_checkpoint: 130 | if dist.get_rank() == 0: 131 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 132 | state_dict = dist_util.load_state_dict( 133 | ema_checkpoint, map_location=dist_util.dev() 134 | ) 135 | ema_params = self._state_dict_to_master_params(state_dict) 136 | 137 | dist_util.sync_params(ema_params) 138 | return ema_params 139 | 140 | def _load_optimizer_state(self): 141 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 142 | opt_checkpoint = bf.join( 143 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 144 | ) 145 | if bf.exists(opt_checkpoint): 146 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 147 | state_dict = dist_util.load_state_dict( 148 | opt_checkpoint, map_location=dist_util.dev() 149 | ) 150 | self.opt.load_state_dict(state_dict) 151 | 152 | def _setup_fp16(self): 153 | self.master_params = make_master_params(self.model_params) 154 | self.model.convert_to_fp16() 155 | 156 | def run_loop(self): 157 | while ( 158 | not self.lr_anneal_steps 159 | or self.step + self.resume_step < self.lr_anneal_steps 160 | ): 161 | batch, cond = next(self.data) 162 | self.run_step(batch, cond) 163 | if self.step % self.log_interval == 0: 164 | logger.dumpkvs() 165 | if self.step % self.save_interval == 0: 166 | self.save() 167 | # Run for a finite amount of time in integration tests. 168 | if self.step > 300000: 169 | return 170 | self.step += 1 171 | # Save the last checkpoint if it wasn't already saved. 
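        # (The self.step > 300000 check above acts as a hard cap on training
        #  length, despite the "integration tests" comment; the guard below
        #  skips the final save when the last step already landed on a
        #  save_interval boundary, avoiding a duplicate checkpoint.)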
172 | if (self.step - 1) % self.save_interval != 0: 173 | self.save() 174 | 175 | def run_step(self, batch, cond): 176 | self.forward_backward(batch, cond) 177 | if self.use_fp16: 178 | self.optimize_fp16() 179 | else: 180 | self.optimize_normal() 181 | self.log_step() 182 | 183 | def forward_backward(self, batch, cond): 184 | zero_grad(self.model_params) 185 | for i in range(0, batch.shape[0], self.microbatch): 186 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 187 | micro_cond = { 188 | k: v[i : i + self.microbatch].to(dist_util.dev()) 189 | for k, v in cond.items() 190 | } 191 | last_batch = (i + self.microbatch) >= batch.shape[0] 192 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 193 | 194 | compute_losses = functools.partial( 195 | self.diffusion.training_losses, 196 | self.ddp_model, 197 | micro, 198 | t, 199 | model_kwargs=micro_cond, 200 | ) 201 | 202 | if last_batch or not self.use_ddp: 203 | losses = compute_losses() 204 | else: 205 | with self.ddp_model.no_sync(): 206 | losses = compute_losses() 207 | 208 | if isinstance(self.schedule_sampler, LossAwareSampler): 209 | self.schedule_sampler.update_with_local_losses( 210 | t, losses["loss"].detach() 211 | ) 212 | 213 | loss = (losses["loss"] * weights).mean() 214 | log_loss_dict( 215 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 216 | ) 217 | if self.use_fp16: 218 | loss_scale = 2 ** self.lg_loss_scale 219 | (loss * loss_scale).backward() 220 | else: 221 | loss.backward() 222 | 223 | def optimize_fp16(self): 224 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 225 | self.lg_loss_scale -= 1 226 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 227 | return 228 | 229 | model_grads_to_master_grads(self.model_params, self.master_params) 230 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 231 | self._log_grad_norm() 232 | self._anneal_lr() 233 | self.opt.step() 234 | for rate, params in zip(self.ema_rate, self.ema_params): 235 | update_ema(params, self.master_params, rate=rate) 236 | master_params_to_model_params(self.model_params, self.master_params) 237 | self.lg_loss_scale += self.fp16_scale_growth 238 | 239 | def optimize_normal(self): 240 | self._log_grad_norm() 241 | self._anneal_lr() 242 | self.opt.step() 243 | for rate, params in zip(self.ema_rate, self.ema_params): 244 | update_ema(params, self.master_params, rate=rate) 245 | 246 | def _log_grad_norm(self): 247 | sqsum = 0.0 248 | for p in self.master_params: 249 | sqsum += (p.grad ** 2).sum().item() 250 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 251 | 252 | def _anneal_lr(self): 253 | if not self.lr_anneal_steps: 254 | return 255 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 256 | lr = self.lr * (1 - frac_done) 257 | for param_group in self.opt.param_groups: 258 | param_group["lr"] = lr 259 | 260 | def log_step(self): 261 | logger.logkv("step", self.step + self.resume_step) 262 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 263 | if self.use_fp16: 264 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 265 | 266 | def save(self): 267 | def save_checkpoint(rate, params): 268 | state_dict = self._master_params_to_state_dict(params) 269 | if dist.get_rank() == 0: 270 | logger.log(f"saving model {rate}...") 271 | if not rate: 272 | filename = f"model{(self.step+self.resume_step):06d}.pt" 273 | else: 274 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 275 | with 
bf.BlobFile(bf.join("/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/", filename), "wb") as f: 276 | th.save(state_dict, f) 277 | 278 | save_checkpoint(0, self.master_params) 279 | for rate, params in zip(self.ema_rate, self.ema_params): 280 | save_checkpoint(rate, params) 281 | 282 | if dist.get_rank() == 0: 283 | with bf.BlobFile( 284 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 285 | "wb", 286 | ) as f: 287 | th.save(self.opt.state_dict(), f) 288 | 289 | dist.barrier() 290 | 291 | def _master_params_to_state_dict(self, master_params): 292 | if self.use_fp16: 293 | master_params = unflatten_master_params( 294 | self.model.parameters(), master_params 295 | ) 296 | state_dict = self.model.state_dict() 297 | for i, (name, _value) in enumerate(self.model.named_parameters()): 298 | assert name in state_dict 299 | state_dict[name] = master_params[i] 300 | return state_dict 301 | 302 | def _state_dict_to_master_params(self, state_dict): 303 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 304 | if self.use_fp16: 305 | return make_master_params(params) 306 | else: 307 | return params 308 | 309 | class TrainLoop: 310 | def __init__( 311 | self, 312 | *, 313 | model, 314 | diffusion, 315 | data, 316 | batch_size, 317 | microbatch, 318 | lr, 319 | ema_rate, 320 | log_interval, 321 | save_interval, 322 | resume_checkpoint, 323 | use_fp16=False, 324 | fp16_scale_growth=1e-3, 325 | schedule_sampler=None, 326 | weight_decay=0.0, 327 | lr_anneal_steps=0, 328 | ): 329 | self.model = model 330 | self.diffusion = diffusion 331 | self.data = data 332 | self.batch_size = batch_size 333 | self.microbatch = microbatch if microbatch > 0 else batch_size 334 | self.lr = lr 335 | self.ema_rate = ( 336 | [ema_rate] 337 | if isinstance(ema_rate, float) 338 | else [float(x) for x in ema_rate.split(",")] 339 | ) 340 | self.log_interval = log_interval 341 | self.save_interval = save_interval 342 | self.resume_checkpoint = resume_checkpoint 343 | self.use_fp16 = use_fp16 344 | self.fp16_scale_growth = fp16_scale_growth 345 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 346 | self.weight_decay = weight_decay 347 | self.lr_anneal_steps = lr_anneal_steps 348 | 349 | self.step = 0 350 | self.resume_step = 0 351 | self.global_batch = self.batch_size * dist.get_world_size() 352 | 353 | self.sync_cuda = th.cuda.is_available() 354 | 355 | self._load_and_sync_parameters() 356 | self.mp_trainer = MixedPrecisionTrainer( 357 | model=self.model, 358 | use_fp16=self.use_fp16, 359 | fp16_scale_growth=fp16_scale_growth, 360 | ) 361 | 362 | self.opt = AdamW( 363 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 364 | ) 365 | if self.resume_step: 366 | self._load_optimizer_state() 367 | # Model was resumed, either due to a restart or a checkpoint 368 | # being specified at the command line. 369 | self.ema_params = [ 370 | self._load_ema_parameters(rate) for rate in self.ema_rate 371 | ] 372 | else: 373 | self.ema_params = [ 374 | copy.deepcopy(self.mp_trainer.master_params) 375 | for _ in range(len(self.ema_rate)) 376 | ] 377 | 378 | if th.cuda.is_available(): 379 | self.use_ddp = True 380 | self.ddp_model = DDP( 381 | self.model, 382 | device_ids=[dist_util.dev()], 383 | output_device=dist_util.dev(), 384 | broadcast_buffers=False, 385 | bucket_cap_mb=128, 386 | find_unused_parameters=False, 387 | ) 388 | else: 389 | if dist.get_world_size() > 1: 390 | logger.warn( 391 | "Distributed training requires CUDA. 
" 392 | "Gradients will not be synchronized properly!" 393 | ) 394 | self.use_ddp = False 395 | self.ddp_model = self.model 396 | 397 | def _load_and_sync_parameters(self): 398 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 399 | 400 | if resume_checkpoint: 401 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 402 | if dist.get_rank() == 0: 403 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 404 | self.model.load_state_dict( 405 | dist_util.load_state_dict( 406 | resume_checkpoint, map_location=dist_util.dev() 407 | ) 408 | ) 409 | 410 | dist_util.sync_params(self.model.parameters()) 411 | 412 | def _load_ema_parameters(self, rate): 413 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 414 | 415 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 416 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 417 | if ema_checkpoint: 418 | if dist.get_rank() == 0: 419 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 420 | state_dict = dist_util.load_state_dict( 421 | ema_checkpoint, map_location=dist_util.dev() 422 | ) 423 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 424 | 425 | dist_util.sync_params(ema_params) 426 | return ema_params 427 | 428 | def _load_optimizer_state(self): 429 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 430 | opt_checkpoint = bf.join( 431 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 432 | ) 433 | if bf.exists(opt_checkpoint): 434 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 435 | state_dict = dist_util.load_state_dict( 436 | opt_checkpoint, map_location=dist_util.dev() 437 | ) 438 | self.opt.load_state_dict(state_dict) 439 | 440 | def run_loop(self): 441 | while ( 442 | not self.lr_anneal_steps 443 | or self.step + self.resume_step < self.lr_anneal_steps 444 | ): 445 | batch, cond = next(self.data) 446 | self.run_step(batch, cond) 447 | if self.step % self.log_interval == 0: 448 | logger.dumpkvs() 449 | if self.step % self.save_interval == 0: 450 | self.save() 451 | # After saving, sample unconditionally 452 | sample = self.diffusion.p_sample_loop( 453 | self.model, 454 | (1, 1, 256, 256), 455 | model_kwargs={}, 456 | ) 457 | plt.imsave(str(Path(get_blob_logdir()) / f"{self.step:05}.png"), clear(sample)) 458 | # Run for a finite amount of time in integration tests. 459 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 460 | return 461 | self.step += 1 462 | # Save the last checkpoint if it wasn't already saved. 
463 | if (self.step - 1) % self.save_interval != 0: 464 | self.save() 465 | 466 | def run_step(self, batch, cond): 467 | self.forward_backward(batch, cond) 468 | took_step = self.mp_trainer.optimize(self.opt) 469 | if took_step: 470 | self._update_ema() 471 | self._anneal_lr() 472 | self.log_step() 473 | 474 | def forward_backward(self, batch, cond): 475 | self.mp_trainer.zero_grad() 476 | for i in range(0, batch.shape[0], self.microbatch): 477 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 478 | micro_cond = { 479 | k: v[i : i + self.microbatch].to(dist_util.dev()) 480 | for k, v in cond.items() 481 | } 482 | last_batch = (i + self.microbatch) >= batch.shape[0] 483 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 484 | 485 | compute_losses = functools.partial( 486 | self.diffusion.training_losses, 487 | self.ddp_model, 488 | micro, 489 | t, 490 | model_kwargs=micro_cond, 491 | ) 492 | 493 | if last_batch or not self.use_ddp: 494 | losses = compute_losses() 495 | else: 496 | with self.ddp_model.no_sync(): 497 | losses = compute_losses() 498 | 499 | if isinstance(self.schedule_sampler, LossAwareSampler): 500 | self.schedule_sampler.update_with_local_losses( 501 | t, losses["loss"].detach() 502 | ) 503 | 504 | loss = (losses["loss"] * weights).mean() 505 | log_loss_dict( 506 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 507 | ) 508 | self.mp_trainer.backward(loss) 509 | 510 | def _update_ema(self): 511 | for rate, params in zip(self.ema_rate, self.ema_params): 512 | update_ema(params, self.mp_trainer.master_params, rate=rate) 513 | 514 | def _anneal_lr(self): 515 | if not self.lr_anneal_steps: 516 | return 517 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 518 | lr = self.lr * (1 - frac_done) 519 | for param_group in self.opt.param_groups: 520 | param_group["lr"] = lr 521 | 522 | def log_step(self): 523 | logger.logkv("step", self.step + self.resume_step) 524 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 525 | 526 | def save(self): 527 | def save_checkpoint(rate, params): 528 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 529 | if dist.get_rank() == 0: 530 | logger.log(f"saving model {rate}...") 531 | if not rate: 532 | filename = f"model{(self.step+self.resume_step):06d}.pt" 533 | else: 534 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 535 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 536 | th.save(state_dict, f) 537 | 538 | save_checkpoint(0, self.mp_trainer.master_params) 539 | for rate, params in zip(self.ema_rate, self.ema_params): 540 | save_checkpoint(rate, params) 541 | 542 | if dist.get_rank() == 0: 543 | with bf.BlobFile( 544 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 545 | "wb", 546 | ) as f: 547 | th.save(self.opt.state_dict(), f) 548 | 549 | dist.barrier() 550 | 551 | 552 | 553 | def parse_resume_step_from_filename(filename): 554 | """ 555 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 556 | checkpoint's number of steps. 557 | """ 558 | split = filename.split("model") 559 | if len(split) < 2: 560 | return 0 561 | split1 = split[-1].split(".")[0] 562 | try: 563 | return int(split1) 564 | except ValueError: 565 | return 0 566 | 567 | 568 | def get_blob_logdir(): 569 | # You can change this to be a separate path to save checkpoints to 570 | # a blobstore or some external drive. 
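    # Example (assumed usage; the bundled OpenAI-style logger reads the
    # OPENAI_LOGDIR environment variable in logger.configure()): launching a
    # run with OPENAI_LOGDIR=/mnt/ckpts sends the checkpoint files written via
    # get_blob_logdir() above to /mnt/ckpts instead of the default run
    # directory.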
571 | return logger.get_dir() 572 | 573 | 574 | def find_resume_checkpoint(): 575 | # On your infrastructure, you may want to override this to automatically 576 | # discover the latest checkpoint on your blob storage, etc. 577 | return None 578 | 579 | 580 | def find_ema_checkpoint(main_checkpoint, step, rate): 581 | if main_checkpoint is None: 582 | return None 583 | filename = f"ema_{rate}_{(step):06d}.pt" 584 | path = bf.join(bf.dirname(main_checkpoint), filename) 585 | if bf.exists(path): 586 | return path 587 | return None 588 | 589 | 590 | def log_loss_dict(diffusion, ts, losses): 591 | for key, values in losses.items(): 592 | logger.logkv_mean(key, values.mean().item()) 593 | # Log the quantiles (four quartiles, in particular). 594 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 595 | quartile = int(4 * sub_t / diffusion.num_timesteps) 596 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 597 | -------------------------------------------------------------------------------- /guided_diffusion/training_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | import argparse 9 | import traceback 10 | import logging 11 | import yaml 12 | import sys 13 | import os 14 | import torch 15 | import numpy as np 16 | 17 | from pathlib import Path 18 | 19 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion 20 | from skimage.metrics import peak_signal_noise_ratio 21 | from pathlib import Path 22 | from physics.ct import CT 23 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI 24 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule 25 | from functools import partial 26 | 27 | torch.set_printoptions(sci_mode=False) 28 | def compute_alpha(beta, t): 29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 31 | return a 32 | 33 | 34 | def dict2namespace(config): 35 | namespace = argparse.Namespace() 36 | for key, value in config.items(): 37 | if isinstance(value, dict): 38 | new_value = dict2namespace(value) 39 | else: 40 | new_value = value 41 | setattr(namespace, key, new_value) 42 | return namespace 43 | 44 | 45 | def parse_args_and_config(): 46 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 47 | 48 | parser.add_argument( 49 | "--config", type=str, required=True, help="Path to the config file" 50 | ) 51 | parser.add_argument( 52 | "--type", type=str, required=True, help="Either [2d, 3d]" 53 | ) 54 | parser.add_argument( 55 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG" 56 | ) 57 | parser.add_argument( 58 | "--Nview", type=int, default=16, help="number of projections for CT" 59 | ) 60 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results") 61 | parser.add_argument( 62 | "--exp", type=str, default="./exp", help="Path for saving running related data." 
63 | ) 64 | parser.add_argument( 65 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt" 66 | ) 67 | parser.add_argument( 68 | "--deg", type=str, required=True, help="Degradation" 69 | ) 70 | parser.add_argument( 71 | "--sigma_y", type=float, default=0., help="sigma_y" 72 | ) 73 | parser.add_argument( 74 | "--eta", type=float, default=0.85, help="Eta" 75 | ) 76 | parser.add_argument( 77 | "--rho", type=float, default=10.0, help="rho" 78 | ) 79 | parser.add_argument( 80 | "--lamb", type=float, default=0.04, help="lambda for TV" 81 | ) 82 | parser.add_argument( 83 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon" 84 | ) 85 | parser.add_argument( 86 | "--T_sampling", type=int, default=50, help="Total number of sampling steps" 87 | ) 88 | parser.add_argument( 89 | "-i", 90 | "--image_folder", 91 | type=str, 92 | default="./results", 93 | help="The folder name of samples", 94 | ) 95 | parser.add_argument( 96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset" 97 | ) 98 | 99 | # MRI-exp arguments 100 | parser.add_argument( 101 | "--mask_type", type=str, default="uniform1d", help="Undersampling type" 102 | ) 103 | parser.add_argument( 104 | "--acc_factor", type=int, default=4, help="acceleration factor" 105 | ) 106 | parser.add_argument( 107 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory" 108 | ) 109 | parser.add_argument( 110 | "--center_fraction", type=float, default=0.08, help="ACS region" 111 | ) 112 | 113 | 114 | args = parser.parse_args() 115 | 116 | # parse config file 117 | with open(os.path.join("configs/vp", args.config), "r") as f: 118 | config = yaml.safe_load(f) 119 | new_config = dict2namespace(config) 120 | 121 | if "CT" in args.deg: 122 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}" 123 | elif "MRI" in args.deg: 124 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}" 125 | 126 | args.image_folder.mkdir(exist_ok=True, parents=True) 127 | if not os.path.exists(args.image_folder): 128 | os.makedirs(args.image_folder) 129 | 130 | # add device 131 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 132 | logging.info("Using device: {}".format(device)) 133 | new_config.device = device 134 | 135 | # set random seed 136 | torch.manual_seed(args.seed) 137 | np.random.seed(args.seed) 138 | if torch.cuda.is_available(): 139 | torch.cuda.manual_seed_all(args.seed) 140 | 141 | torch.backends.cudnn.benchmark = True 142 | 143 | return args, new_config 144 | 145 | class Diffusion(object): 146 | def __init__(self, args, config, device=None): 147 | self.args = args 148 | self.args.image_folder = Path(self.args.image_folder) 149 | for t in ["input", "recon", "label"]: 150 | if t == "recon": 151 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True) 152 | else: 153 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True) 154 | self.config = config 155 | if device is None: 156 | device = ( 157 | torch.device("cuda") 158 | if torch.cuda.is_available() 159 | else torch.device("cpu") 160 | ) 161 | self.device = device 162 | 163 | self.model_var_type = config.model.var_type 164 | betas = get_beta_schedule( 165 | beta_schedule=config.diffusion.beta_schedule, 166 | beta_start=config.diffusion.beta_start, 167 | beta_end=config.diffusion.beta_end, 168 | 
num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 169 | ) 170 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 171 | self.num_timesteps = betas.shape[0] 172 | 173 | alphas = 1.0 - betas 174 | alphas_cumprod = alphas.cumprod(dim=0) 175 | alphas_cumprod_prev = torch.cat( 176 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 177 | ) 178 | self.alphas_cumprod_prev = alphas_cumprod_prev 179 | posterior_variance = ( 180 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 181 | ) 182 | if self.model_var_type == "fixedlarge": 183 | self.logvar = betas.log() 184 | elif self.model_var_type == "fixedsmall": 185 | self.logvar = posterior_variance.clamp(min=1e-20).log() 186 | 187 | def train(self): 188 | config_dict = vars(self.config.model) 189 | model = create_model(**config_dict) 190 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name) 191 | 192 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 193 | print(f"Model ckpt loaded from {ckpt}") 194 | model.to("cuda") 195 | model.train() 196 | 197 | 198 | # model.eval() 199 | 200 | # print('Run DDS.', 201 | # f'{self.args.T_sampling} sampling steps.', 202 | # f'Task: {self.args.deg}.' 203 | # ) 204 | # self.dds(model) 205 | 206 | 207 | 208 | def main(): 209 | args, config = parse_args_and_config() 210 | diffusion_model = Diffusion(args, config) 211 | diffusion_model.train() 212 | 213 | 214 | if __name__ == "__main__": 215 | print("running training") 216 | main() -------------------------------------------------------------------------------- /guided_diffusion/training_triplane_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | import argparse 9 | import traceback 10 | import logging 11 | import yaml 12 | import sys 13 | import os 14 | import torch 15 | import numpy as np 16 | 17 | from pathlib import Path 18 | 19 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion 20 | from skimage.metrics import peak_signal_noise_ratio 21 | from pathlib import Path 22 | from physics.ct import CT 23 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI 24 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule 25 | from functools import partial 26 | 27 | torch.set_printoptions(sci_mode=False) 28 | def compute_alpha(beta, t): 29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 31 | return a 32 | 33 | 34 | def dict2namespace(config): 35 | namespace = argparse.Namespace() 36 | for key, value in config.items(): 37 | if isinstance(value, dict): 38 | new_value = dict2namespace(value) 39 | else: 40 | new_value = value 41 | setattr(namespace, key, new_value) 42 | return namespace 43 | 44 | 45 | def parse_args_and_config(): 46 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 47 | 48 | parser.add_argument( 49 | "--config", type=str, required=True, help="Path to the config file" 50 | ) 51 | parser.add_argument( 52 | "--type", type=str, required=True, help="Either [2d, 3d]" 53 | ) 54 | parser.add_argument( 55 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG" 56 | ) 57 | parser.add_argument( 58 | "--Nview", type=int, default=16, help="number of projections for CT" 59 | ) 60 | 
parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results") 61 | parser.add_argument( 62 | "--exp", type=str, default="./exp", help="Path for saving running related data." 63 | ) 64 | parser.add_argument( 65 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt" 66 | ) 67 | parser.add_argument( 68 | "--deg", type=str, required=True, help="Degradation" 69 | ) 70 | parser.add_argument( 71 | "--sigma_y", type=float, default=0., help="sigma_y" 72 | ) 73 | parser.add_argument( 74 | "--eta", type=float, default=0.85, help="Eta" 75 | ) 76 | parser.add_argument( 77 | "--rho", type=float, default=10.0, help="rho" 78 | ) 79 | parser.add_argument( 80 | "--lamb", type=float, default=0.04, help="lambda for TV" 81 | ) 82 | parser.add_argument( 83 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon" 84 | ) 85 | parser.add_argument( 86 | "--T_sampling", type=int, default=50, help="Total number of sampling steps" 87 | ) 88 | parser.add_argument( 89 | "-i", 90 | "--image_folder", 91 | type=str, 92 | default="./results", 93 | help="The folder name of samples", 94 | ) 95 | parser.add_argument( 96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset" 97 | ) 98 | 99 | # MRI-exp arguments 100 | parser.add_argument( 101 | "--mask_type", type=str, default="uniform1d", help="Undersampling type" 102 | ) 103 | parser.add_argument( 104 | "--acc_factor", type=int, default=4, help="acceleration factor" 105 | ) 106 | parser.add_argument( 107 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory" 108 | ) 109 | parser.add_argument( 110 | "--center_fraction", type=float, default=0.08, help="ACS region" 111 | ) 112 | 113 | 114 | args = parser.parse_args() 115 | 116 | # parse config file 117 | with open(os.path.join("configs/vp", args.config), "r") as f: 118 | config = yaml.safe_load(f) 119 | new_config = dict2namespace(config) 120 | 121 | if "CT" in args.deg: 122 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}" 123 | elif "MRI" in args.deg: 124 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}" 125 | 126 | args.image_folder.mkdir(exist_ok=True, parents=True) 127 | if not os.path.exists(args.image_folder): 128 | os.makedirs(args.image_folder) 129 | 130 | # add device 131 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 132 | logging.info("Using device: {}".format(device)) 133 | new_config.device = device 134 | 135 | # set random seed 136 | torch.manual_seed(args.seed) 137 | np.random.seed(args.seed) 138 | if torch.cuda.is_available(): 139 | torch.cuda.manual_seed_all(args.seed) 140 | 141 | torch.backends.cudnn.benchmark = True 142 | 143 | return args, new_config 144 | 145 | class Diffusion(object): 146 | def __init__(self, args, config, device=None): 147 | self.args = args 148 | self.args.image_folder = Path(self.args.image_folder) 149 | for t in ["input", "recon", "label"]: 150 | if t == "recon": 151 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True) 152 | else: 153 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True) 154 | self.config = config 155 | if device is None: 156 | device = ( 157 | torch.device("cuda") 158 | if torch.cuda.is_available() 159 | else torch.device("cpu") 160 | ) 161 | self.device = device 162 | 163 | self.model_var_type = config.model.var_type 164 
| betas = get_beta_schedule( 165 | beta_schedule=config.diffusion.beta_schedule, 166 | beta_start=config.diffusion.beta_start, 167 | beta_end=config.diffusion.beta_end, 168 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 169 | ) 170 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 171 | self.num_timesteps = betas.shape[0] 172 | 173 | alphas = 1.0 - betas 174 | alphas_cumprod = alphas.cumprod(dim=0) 175 | alphas_cumprod_prev = torch.cat( 176 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 177 | ) 178 | self.alphas_cumprod_prev = alphas_cumprod_prev 179 | posterior_variance = ( 180 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 181 | ) 182 | if self.model_var_type == "fixedlarge": 183 | self.logvar = betas.log() 184 | elif self.model_var_type == "fixedsmall": 185 | self.logvar = posterior_variance.clamp(min=1e-20).log() 186 | 187 | def train(self): 188 | config_dict = vars(self.config.model) 189 | config_dict["use_spacecode"] = True 190 | print(config_dict) 191 | print(config_dict) 192 | model = create_model(**config_dict) 193 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name) 194 | 195 | model.load_state_dict(torch.load(ckpt, map_location=self.device)) 196 | print(f"Model ckpt loaded from {ckpt}") 197 | # model.to("cuda") 198 | # model.train() 199 | 200 | 201 | # model.eval() 202 | 203 | # print('Run DDS.', 204 | # f'{self.args.T_sampling} sampling steps.', 205 | # f'Task: {self.args.deg}.' 206 | # ) 207 | # self.dds(model) 208 | 209 | 210 | 211 | def main(): 212 | args, config = parse_args_and_config() 213 | diffusion_model = Diffusion(args, config) 214 | diffusion_model.train() 215 | 216 | 217 | if __name__ == "__main__": 218 | print("running training") 219 | main() -------------------------------------------------------------------------------- /guided_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from tqdm import tqdm 5 | 6 | 7 | def get_alpha_schedule(type="linear", start=0.4, end=0.015, total=1000): 8 | if type == "linear": 9 | schedule = np.linspace(start, end, total) 10 | elif type == "const": 11 | assert start == end, f"For const schedule, start and end should match. 
Got start:{start}, end:{end}" 12 | schedule = np.full(total, start) 13 | return np.flip(schedule) 14 | 15 | 16 | def CG(A_fn, b_cg, x, n_inner=10, eps=1e-8): 17 | r = b_cg - A_fn(x) 18 | p = r.clone() 19 | rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) 20 | for _ in range(n_inner): 21 | Ap = A_fn(p) 22 | a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) 23 | 24 | x += a * p 25 | r -= a * Ap 26 | 27 | rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) 28 | 29 | if torch.sqrt(rs_new) < eps: 30 | break 31 | p = r + (rs_new / rs_old) * p 32 | rs_old = rs_new 33 | return x 34 | 35 | # x0_t_hat = x0_t - A_funcs.A_pinv( 36 | # A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1) 37 | # ).reshape(*x0_t.size()) 38 | # returns vectorized A^T(A(x)) 39 | 40 | 41 | def Acg(x, A_func): 42 | x_vec = x.reshape(x.size(0), -1) 43 | tmp = A_func.At(A_func.A(x_vec)) 44 | return tmp.reshape(*x.size()) 45 | 46 | 47 | def clear_color(x): 48 | x = x.detach().cpu().squeeze().numpy() 49 | return normalize_np(np.transpose(x, (1, 2, 0))) 50 | 51 | 52 | def clear(x): 53 | x = x.detach().cpu().squeeze().numpy() 54 | return x 55 | 56 | 57 | def normalize_np(img): 58 | """ Normalize img in arbitrary range to [0, 1] """ 59 | img -= np.min(img) 60 | img /= np.max(img) 61 | return img 62 | 63 | 64 | def clip(img): 65 | return torch.clip(img, -1.0, 1.0) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import logging 4 | import yaml 5 | import sys 6 | import os 7 | import torch 8 | import numpy as np 9 | 10 | from solver_2d import Diffusion as Diffusion_2d 11 | from solver_3d import Diffusion as Diffusion_3d 12 | from solver_3d_baselines import Diffusion as Diffusion_baseline 13 | from solver_3D_blend import Diffusion as Diffusion_blend 14 | from solver_3d_blend_timetravel import Diffusion as Diffusion_blendtt 15 | from eval_3DCT_blendcond import Diffusion as Diffusion_blendcond 16 | from eval_3D_blend_cond_lidc import Diffusion as Diffusion_blendcond_lidc 17 | from solver_3D_diffusionmbir import Diffusion as Diffusionmbir 18 | from pathlib import Path 19 | 20 | torch.set_printoptions(sci_mode=False) 21 | 22 | def parse_args_and_config(): 23 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 24 | 25 | parser.add_argument( 26 | "--config", type=str, required=True, help="Path to the config file" 27 | ) 28 | parser.add_argument( 29 | "--type", type=str, required=True, help="Either [2d, 3d]" 30 | ) 31 | parser.add_argument( 32 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG" 33 | ) 34 | parser.add_argument( 35 | "--Nview", type=int, default=16, help="number of projections for CT" 36 | ) 37 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results") 38 | parser.add_argument( 39 | "--exp", type=str, default="./exp", help="Path for saving running related data." 
40 | ) 41 | parser.add_argument( 42 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt" 43 | ) 44 | parser.add_argument( 45 | "--deg", type=str, required=True, help="Degradation" 46 | ) 47 | parser.add_argument( 48 | "--sigma_y", type=float, default=0., help="sigma_y" 49 | ) 50 | parser.add_argument( 51 | "--eta", type=float, default=0.85, help="Eta" 52 | ) 53 | parser.add_argument( 54 | "--rho", type=float, default=10.0, help="rho" 55 | ) 56 | parser.add_argument( 57 | "--lamb", type=float, default=0.04, help="lambda for TV" 58 | ) 59 | parser.add_argument( 60 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon" 61 | ) 62 | parser.add_argument( 63 | "--T_sampling", type=int, default=50, help="Total number of sampling steps" 64 | ) 65 | parser.add_argument( 66 | "-i", 67 | "--image_folder", 68 | type=str, 69 | default="./results", 70 | help="The folder name of samples", 71 | ) 72 | parser.add_argument( 73 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset" 74 | ) 75 | 76 | # MRI-exp arguments 77 | parser.add_argument( 78 | "--mask_type", type=str, default="uniform1d", help="Undersampling type" 79 | ) 80 | parser.add_argument( 81 | "--acc_factor", type=int, default=4, help="acceleration factor" 82 | ) 83 | parser.add_argument( 84 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory" 85 | ) 86 | parser.add_argument( 87 | "--center_fraction", type=float, default=0.08, help="ACS region" 88 | ) 89 | 90 | 91 | args = parser.parse_args() 92 | 93 | # parse config file 94 | with open(os.path.join("configs/vp", args.config), "r") as f: 95 | config = yaml.safe_load(f) 96 | new_config = dict2namespace(config) 97 | 98 | if "CT" in args.deg: 99 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}" 100 | elif "MRI" in args.deg: 101 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}" 102 | 103 | args.image_folder.mkdir(exist_ok=True, parents=True) 104 | if not os.path.exists(args.image_folder): 105 | os.makedirs(args.image_folder) 106 | 107 | # add device 108 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 109 | logging.info("Using device: {}".format(device)) 110 | new_config.device = device 111 | 112 | # set random seed 113 | torch.manual_seed(args.seed) 114 | np.random.seed(args.seed) 115 | if torch.cuda.is_available(): 116 | torch.cuda.manual_seed_all(args.seed) 117 | 118 | torch.backends.cudnn.benchmark = True 119 | 120 | return args, new_config 121 | 122 | 123 | def dict2namespace(config): 124 | namespace = argparse.Namespace() 125 | for key, value in config.items(): 126 | if isinstance(value, dict): 127 | new_value = dict2namespace(value) 128 | else: 129 | new_value = value 130 | setattr(namespace, key, new_value) 131 | return namespace 132 | 133 | 134 | def main(): 135 | args, config = parse_args_and_config() 136 | 137 | try: 138 | if args.type == "2d": 139 | runner = Diffusion_2d(args, config) 140 | elif args.type == "3d": 141 | runner = Diffusion_3d(args, config) 142 | elif args.type == "3dblend": 143 | runner = Diffusion_blend(args, config) 144 | elif args.type == "3dblendtt": 145 | runner = Diffusion_blendtt(args, config) 146 | elif args.type == "3dblendcond": 147 | runner = Diffusion_blendcond(args, config) 148 | elif args.type == "3dblendcond_lidc": 149 | runner = Diffusion_blendcond_lidc(args, config) 150 | elif 
args.type == "3dbaseline": 151 | runner = Diffusion_baseline(args, config) 152 | elif args.type == "diffusionmbir": 153 | runner = Diffusionmbir(args, config) 154 | 155 | runner.sample() 156 | except Exception: 157 | logging.error(traceback.format_exc()) 158 | 159 | return 0 160 | 161 | 162 | if __name__ == "__main__": 163 | sys.exit(main()) 164 | -------------------------------------------------------------------------------- /train_SVCT_3D_triplane.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | Nview=8 4 | T_sampling=50 5 | eta=0.85 6 | 7 | python training_triplane_script.py \ 8 | --type '2d' \ 9 | --config AAPM_256_lsun.yaml \ 10 | --dataset_path "/nfs/turbo/coe-liyues/bowenbw/DDS/indist_samples/CT/L067" \ 11 | --ckpt_load_name "/nfs/turbo/coe-liyues/bowenbw/DDS/checkpoints/AAPM256_1M.pth" \ 12 | --Nview $Nview \ 13 | --eta $eta \ 14 | --deg "SV-CT" \ 15 | --sigma_y 0.01 \ 16 | --T_sampling 100 \ 17 | --T_sampling $T_sampling \ 18 | --resume_checkpoint true \ 19 | -i ./results -------------------------------------------------------------------------------- /training_triplane_lidc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from glob import glob 8 | import argparse 9 | import traceback 10 | import logging 11 | import yaml 12 | import sys 13 | import os 14 | import torch 15 | import numpy as np 16 | 17 | from pathlib import Path 18 | from guided_diffusion.CTDataset import * 19 | from guided_diffusion.train_util import * 20 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion 21 | from skimage.metrics import peak_signal_noise_ratio 22 | from pathlib import Path 23 | from physics.ct import CT 24 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI 25 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule 26 | from functools import partial 27 | 28 | def compute_alpha(beta, t): 29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 31 | return a 32 | 33 | def dict2namespace(config): 34 | namespace = argparse.Namespace() 35 | for key, value in config.items(): 36 | if isinstance(value, dict): 37 | new_value = dict2namespace(value) 38 | else: 39 | new_value = value 40 | setattr(namespace, key, new_value) 41 | return namespace 42 | 43 | def parse_args_and_config(): 44 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 45 | parser.add_argument( 46 | "--config", type=str, required=True, help="Path to the config file" 47 | ) 48 | parser.add_argument( 49 | "--type", type=str, required=True, help="Either [2d, 3d]" 50 | ) 51 | parser.add_argument( 52 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG" 53 | ) 54 | parser.add_argument( 55 | "--Nview", type=int, default=16, help="number of projections for CT" 56 | ) 57 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results") 58 | parser.add_argument( 59 | "--exp", type=str, default="./exp", help="Path for saving running related data." 
60 | ) 61 | parser.add_argument( 62 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt" 63 | ) 64 | parser.add_argument( 65 | "--deg", type=str, required=True, help="Degradation" 66 | ) 67 | parser.add_argument( 68 | "--sigma_y", type=float, default=0., help="sigma_y" 69 | ) 70 | parser.add_argument( 71 | "--eta", type=float, default=0.85, help="Eta" 72 | ) 73 | parser.add_argument( 74 | "--rho", type=float, default=10.0, help="rho" 75 | ) 76 | parser.add_argument( 77 | "--lamb", type=float, default=0.04, help="lambda for TV" 78 | ) 79 | parser.add_argument( 80 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon" 81 | ) 82 | parser.add_argument( 83 | "--T_sampling", type=int, default=50, help="Total number of sampling steps" 84 | ) 85 | parser.add_argument( 86 | "--resume_checkpoint", type=bool, default=False, help="resume training from a previous checkpoint" 87 | ) 88 | parser.add_argument( 89 | "-i", 90 | "--image_folder", 91 | type=str, 92 | default="./results", 93 | help="The folder name of samples", 94 | ) 95 | parser.add_argument( 96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset" 97 | ) 98 | # MRI-exp arguments 99 | parser.add_argument( 100 | "--mask_type", type=str, default="uniform1d", help="Undersampling type" 101 | ) 102 | parser.add_argument( 103 | "--acc_factor", type=int, default=4, help="acceleration factor" 104 | ) 105 | parser.add_argument( 106 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory" 107 | ) 108 | parser.add_argument( 109 | "--center_fraction", type=float, default=0.08, help="ACS region" 110 | ) 111 | args = parser.parse_args() 112 | 113 | # parse config file 114 | with open(os.path.join("configs/vp", args.config), "r") as f: 115 | config = yaml.safe_load(f) 116 | new_config = dict2namespace(config) 117 | 118 | if "CT" in args.deg: 119 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}" 120 | elif "MRI" in args.deg: 121 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}" 122 | 123 | args.image_folder.mkdir(exist_ok=True, parents=True) 124 | if not os.path.exists(args.image_folder): 125 | os.makedirs(args.image_folder) 126 | # add device 127 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 128 | logging.info("Using device: {}".format(device)) 129 | new_config.device = device 130 | # set random seed 131 | torch.manual_seed(args.seed) 132 | np.random.seed(args.seed) 133 | if torch.cuda.is_available(): 134 | torch.cuda.manual_seed_all(args.seed) 135 | torch.backends.cudnn.benchmark = True 136 | return args, new_config 137 | 138 | class Diffusion(object): 139 | def __init__(self, args, config, device=None): 140 | self.args = args 141 | self.args.image_folder = Path(self.args.image_folder) 142 | for t in ["input", "recon", "label"]: 143 | if t == "recon": 144 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True) 145 | else: 146 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True) 147 | self.config = config 148 | if device is None: 149 | device = ( 150 | torch.device("cuda") 151 | if torch.cuda.is_available() 152 | else torch.device("cpu") 153 | ) 154 | self.device = device 155 | 156 | self.model_var_type = config.model.var_type 157 | betas = get_beta_schedule( 158 | beta_schedule=config.diffusion.beta_schedule, 159 | 
beta_start=config.diffusion.beta_start, 160 | beta_end=config.diffusion.beta_end, 161 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 162 | ) 163 | betas = self.betas = torch.from_numpy(betas).float().to(self.device) 164 | self.num_timesteps = betas.shape[0] 165 | 166 | alphas = 1.0 - betas 167 | alphas_cumprod = alphas.cumprod(dim=0) 168 | alphas_cumprod_prev = torch.cat( 169 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 170 | ) 171 | self.alphas_cumprod_prev = alphas_cumprod_prev 172 | posterior_variance = ( 173 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 174 | ) 175 | if self.model_var_type == "fixedlarge": 176 | self.logvar = betas.log() 177 | elif self.model_var_type == "fixedsmall": 178 | self.logvar = posterior_variance.clamp(min=1e-20).log() 179 | 180 | def train(self): 181 | config_dict = vars(self.config.model) 182 | config_dict["class_cond"] = True 183 | config_dict["use_spacecode"] = False 184 | print(config_dict) 185 | model = create_model(**config_dict) 186 | print(model.use_spacecode, "using spacecode") 187 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name) 188 | 189 | pretrainsteps = 0 190 | if self.args.resume_checkpoint is True: 191 | print("resuming training") 192 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter50099_cond.ckpt" 193 | pretrainsteps = 50000 194 | else: 195 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/256x256_diffusion_uncond.pt" 196 | 197 | loaded = torch.load(ckpt, map_location=self.device) 198 | model.load_state_dict(torch.load(ckpt, map_location=self.device)['state_dict']) ####if using last checkpoint 199 | 200 | # model.load_state_dict(torch.load(ckpt, map_location=self.device), strict = False) ######if using imagenet checkpoint 201 | 202 | print(f"Model ckpt loaded from {ckpt}") 203 | model.to("cuda") 204 | model.train() 205 | 206 | diffusion = create_gaussian_diffusion( 207 | steps=1000, 208 | learn_sigma=True, 209 | noise_schedule="linear", 210 | use_kl=False, 211 | predict_xstart=False, 212 | rescale_timesteps=False, 213 | rescale_learned_sigmas=False, 214 | timestep_respacing="", 215 | ) 216 | print(diffusion.training_losses, "training loss") 217 | 218 | lr = 1e-5 219 | params = list(model.parameters()) 220 | 221 | opt = torch.optim.AdamW(params, lr=lr) 222 | 223 | #############################testing feed numpy matrix into the training script######################## 224 | 225 | #########################use the given trainer in improved_diffusion######################### 226 | 227 | """use the given trainer in improved_diffusion 228 | ds = CTDataset() 229 | params = {'batch_size': 4} 230 | training_generator = torch.utils.data.DataLoader(ds, **params) 231 | def load_data(loader): 232 | while True: 233 | yield from loader 234 | 235 | data = load_data(training_generator) 236 | 237 | TrainLoop2( 238 | model=model, 239 | diffusion=diffusion, 240 | data=data, 241 | batch_size=4, 242 | microbatch=-1, 243 | lr=3e-4, 244 | ema_rate="0.9999", 245 | log_interval=10, 246 | save_interval=2500, 247 | resume_checkpoint="", 248 | use_fp16=False, 249 | fp16_scale_growth=1e-3, 250 | schedule_sampler="uniform", 251 | weight_decay=0.0, 252 | lr_anneal_steps=0, 253 | ).run_loop() 254 | """ 255 | ################################train manually#################################### 256 | # files = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*') 257 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training_cond/*") 258 | 
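# (Each .npy file in these folders stores a (256, 256, 3) group of adjacent
#  slices; its orientation label y is encoded as the trailing "_<label>" in
#  the filename and parsed in the loop below.)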
# files2 = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/slice_fusion_training/*") 259 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond/*") 260 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond_small/*") 261 | files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond_large/*") 262 | for m in range(100100): 263 | x_train = np.zeros((4, 3, 256, 256)) 264 | y = torch.randint(0, 3, (4,)) 265 | for l in range(4): 266 | filename = np.random.choice(files) 267 | x_raw = np.transpose(np.load(filename), (2,0,1))[0:3] 268 | x_raw = np.clip(x_raw*2-1, -1, 1) 269 | x_train[l] = x_raw.copy() 270 | y_val = int((filename.split(".")[0]).split("_")[-1]) 271 | y[l] = y_val 272 | print(y_val, filename) 273 | x_orig = torch.from_numpy(x_train).to("cuda").to(torch.float) 274 | i = torch.randint(0, 1000, (4,)) 275 | t = i.to("cuda").long() 276 | y = y.to("cuda").long() 277 | 278 | model_kwargs = {} 279 | model_kwargs["y"] = y 280 | 281 | if m % 1000 == 0: 282 | x_sample = diffusion.ddim_sample_loop_progressive(model, (4,3,256,256), task = "None", progress= True, model_kwargs = model_kwargs) 283 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/x_sample_ddim_iter" + str(m + pretrainsteps) + "_finetune_4232024_cond_lidc.npy", x_sample.detach().cpu().numpy()) 284 | 285 | loss = diffusion.training_losses(model, x_orig, t, model_kwargs=model_kwargs)["loss"] 286 | loss = loss.mean() 287 | loss.backward() 288 | opt.step() 289 | opt.zero_grad() 290 | # print(x_orig.dtype) 291 | # loss = diffusion.training_losses(model, x_orig, t)["loss"] 292 | # loss= loss.mean() 293 | # loss.backward() 294 | # opt.step() 295 | # opt.zero_grad() 296 | if m % 10 == 0: 297 | print(loss.item(), "loss", " at ", m, "th iteration") 298 | # #################################################################################################################### 299 | if m % 2000 == 99: 300 | torch.save({'iterations':m,'state_dict': model.state_dict()}, "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/LIDC_triplane3D_finetunelarge_4232024_iter" + str(m + pretrainsteps) + "_cond.ckpt") 301 | 302 | #################################################################################### 303 | # model.eval() 304 | # print('Run DDS.', 305 | # f'{self.args.T_sampling} sampling steps.', 306 | # f'Task: {self.args.deg}.' 
307 | # ) 308 | # self.dds(model) 309 | def main(): 310 | args, config = parse_args_and_config() 311 | diffusion_model = Diffusion(args, config) 312 | diffusion_model.train() 313 | 314 | 315 | if __name__ == "__main__": 316 | print("running training") 317 | main() -------------------------------------------------------------------------------- /training_triplane_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from glob import glob 8 | import argparse 9 | import traceback 10 | import logging 11 | import yaml 12 | import sys 13 | import os 14 | import torch 15 | import numpy as np 16 | from datetime import datetime 17 | 18 | from pathlib import Path 19 | from guided_diffusion.CTDataset import * 20 | from guided_diffusion.train_util import * 21 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion 22 | from skimage.metrics import peak_signal_noise_ratio 23 | from pathlib import Path 24 | from physics.ct import CT 25 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI 26 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule 27 | from functools import partial 28 | 29 | def compute_alpha(beta, t): 30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 32 | return a 33 | 34 | def dict2namespace(config): 35 | namespace = argparse.Namespace() 36 | for key, value in config.items(): 37 | if isinstance(value, dict): 38 | new_value = dict2namespace(value) 39 | else: 40 | new_value = value 41 | setattr(namespace, key, new_value) 42 | return namespace 43 | 44 | def parse_args_and_config(): 45 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 46 | parser.add_argument( 47 | "--config", type=str, required=True, help="Path to the config file" 48 | ) 49 | parser.add_argument( 50 | "--type", type=str, required=True, help="Either [2d, 3d]" 51 | ) 52 | parser.add_argument( 53 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG" 54 | ) 55 | parser.add_argument( 56 | "--Nview", type=int, default=16, help="number of projections for CT" 57 | ) 58 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results") 59 | parser.add_argument( 60 | "--exp", type=str, default="./exp", help="Path for saving running related data." 
34 | def dict2namespace(config):
35 |     namespace = argparse.Namespace()
36 |     for key, value in config.items():
37 |         if isinstance(value, dict):
38 |             new_value = dict2namespace(value)
39 |         else:
40 |             new_value = value
41 |         setattr(namespace, key, new_value)
42 |     return namespace
43 | 
44 | def parse_args_and_config():
45 |     parser = argparse.ArgumentParser(description=globals()["__doc__"])
46 |     parser.add_argument(
47 |         "--config", type=str, required=True, help="Path to the config file"
48 |     )
49 |     parser.add_argument(
50 |         "--type", type=str, required=True, help="Either [2d, 3d]"
51 |     )
52 |     parser.add_argument(
53 |         "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
54 |     )
55 |     parser.add_argument(
56 |         "--Nview", type=int, default=16, help="number of projections for CT"
57 |     )
58 |     parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
59 |     parser.add_argument(
60 |         "--exp", type=str, default="./exp", help="Path for saving running related data."
61 |     )
62 |     parser.add_argument(
63 |         "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
64 |     )
65 |     parser.add_argument(
66 |         "--deg", type=str, required=True, help="Degradation"
67 |     )
68 |     parser.add_argument(
69 |         "--sigma_y", type=float, default=0., help="sigma_y"
70 |     )
71 |     parser.add_argument(
72 |         "--eta", type=float, default=0.85, help="Eta"
73 |     )
74 |     parser.add_argument(
75 |         "--rho", type=float, default=10.0, help="rho"
76 |     )
77 |     parser.add_argument(
78 |         "--lamb", type=float, default=0.04, help="lambda for TV"
79 |     )
80 |     parser.add_argument(
81 |         "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
82 |     )
83 |     parser.add_argument(
84 |         "--T_sampling", type=int, default=50, help="Total number of sampling steps"
85 |     )
86 |     parser.add_argument(
87 |         "--resume_checkpoint", action="store_true", help="resume training from a previous checkpoint"  # boolean flag; `type=bool` would parse any non-empty string (even "False") as True
88 |     )
89 |     parser.add_argument(
90 |         "-i",
91 |         "--image_folder",
92 |         type=str,
93 |         default="./results",
94 |         help="The folder name of samples",
95 |     )
96 |     parser.add_argument(
97 |         "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
98 |     )
99 |     # MRI-exp arguments
100 |     parser.add_argument(
101 |         "--mask_type", type=str, default="uniform1d", help="Undersampling type"
102 |     )
103 |     parser.add_argument(
104 |         "--acc_factor", type=int, default=4, help="acceleration factor"
105 |     )
106 |     parser.add_argument(
107 |         "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
108 |     )
109 |     parser.add_argument(
110 |         "--center_fraction", type=float, default=0.08, help="ACS region"
111 |     )
112 |     args = parser.parse_args()
113 | 
114 |     # parse config file
115 |     with open(os.path.join("configs/vp", args.config), "r") as f:
116 |         config = yaml.safe_load(f)
117 |     new_config = dict2namespace(config)
118 | 
119 |     if "CT" in args.deg:
120 |         args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
121 |     elif "MRI" in args.deg:
122 |         args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
123 | 
124 |     args.image_folder.mkdir(exist_ok=True, parents=True)
125 | 
126 | 
127 |     # add device
128 |     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
129 |     logging.info("Using device: {}".format(device))
130 |     new_config.device = device
131 |     # set random seed
132 |     torch.manual_seed(args.seed)
133 |     np.random.seed(args.seed)
134 |     if torch.cuda.is_available():
135 |         torch.cuda.manual_seed_all(args.seed)
136 |     torch.backends.cudnn.benchmark = True
137 |     return args, new_config
138 | 
139 | class Diffusion(object):
140 |     def __init__(self, args, config, device=None):
141 |         self.args = args
142 |         self.args.image_folder = Path(self.args.image_folder)
143 |         for t in ["input", "recon", "label"]:
144 |             if t == "recon":
145 |                 (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
146 |             else:
147 |                 (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
148 |         self.config = config
149 |         if device is None:
150 |             device = (
151 |                 torch.device("cuda")
152 |                 if torch.cuda.is_available()
153 |                 else torch.device("cpu")
154 |             )
155 |         self.device = device
156 | 
157 |         self.model_var_type = config.model.var_type
158 |         betas = get_beta_schedule(
159 |             beta_schedule=config.diffusion.beta_schedule,
160 |             beta_start=config.diffusion.beta_start,
161 |             beta_end=config.diffusion.beta_end,
162 |             num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
163 |         )
164 |         betas = self.betas = torch.from_numpy(betas).float().to(self.device)
165 |         self.num_timesteps = betas.shape[0]
166 | 
167 |         alphas = 1.0 - betas
168 |         alphas_cumprod = alphas.cumprod(dim=0)
169 |         alphas_cumprod_prev = torch.cat(
170 |             [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
171 |         )
172 |         self.alphas_cumprod_prev = alphas_cumprod_prev
173 |         posterior_variance = (
174 |             betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
175 |         )
176 |         if self.model_var_type == "fixedlarge":
177 |             self.logvar = betas.log()
178 |         elif self.model_var_type == "fixedsmall":
179 |             self.logvar = posterior_variance.clamp(min=1e-20).log()
180 | 
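    # NOTE: these are the standard DDPM schedule quantities. For the forward process
    # q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I), the reverse posterior
    # q(x_{t-1} | x_t, x_0) has variance
    #     beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
    # which is exactly `posterior_variance` above (the "fixedsmall" choice);
    # "fixedlarge" uses beta_t itself as the variance.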
181 |     def train(self):
182 |         config_dict = vars(self.config.model)
183 |         config_dict["class_cond"] = True
184 |         config_dict["use_spacecode"] = False
185 |         print(config_dict)
186 |         model = create_model(**config_dict)
187 |         print(model.use_spacecode, "using spacecode")
188 |         # ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name)  # unused default; overridden in both branches below
189 | 
190 |         pretrainsteps = 0
191 |         if self.args.resume_checkpoint:
192 |             print("resuming training")
193 |             ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter50099_cond.ckpt"
194 |             pretrainsteps = 50000
195 |         else:
196 |             ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/256x256_diffusion_uncond.pt"
197 | 
198 |         loaded = torch.load(ckpt, map_location=self.device)  # assumes the checkpoint stores weights under 'state_dict'
199 |         # for key in loaded['state_dict']:
200 |         #     print(key)  # debug: list checkpoint keys
201 |         model.load_state_dict(loaded['state_dict'])  # reuse `loaded` rather than reading the file from disk a second time
202 | 
203 |         print(f"Model ckpt loaded from {ckpt}")
204 |         model.to("cuda")
205 |         model.train()
206 | 
207 |         diffusion = create_gaussian_diffusion(
208 |             steps=1000,
209 |             learn_sigma=True,
210 |             noise_schedule="linear",
211 |             use_kl=False,
212 |             predict_xstart=False,
213 |             rescale_timesteps=False,
214 |             rescale_learned_sigmas=False,
215 |             timestep_respacing="",
216 |         )
217 |         # print(diffusion.training_losses, "training loss")  # debug: printed the bound method object, not a loss value
218 | 
219 |         lr = 1e-5
220 |         params = list(model.parameters())
221 | 
222 |         opt = torch.optim.AdamW(params, lr=lr)
223 | 
224 |         #############################testing feed numpy matrix into the training script########################
225 | 
226 |         #########################use the given trainer in improved_diffusion#########################
227 | 
228 |         """use the given trainer in improved_diffusion
229 |         ds = CTDataset()
230 |         params = {'batch_size': 4}
231 |         training_generator = torch.utils.data.DataLoader(ds, **params)
232 |         def load_data(loader):
233 |             while True:
234 |                 yield from loader
235 | 
236 |         data = load_data(training_generator)
237 | 
238 |         TrainLoop2(
239 |             model=model,
240 |             diffusion=diffusion,
241 |             data=data,
242 |             batch_size=4,
243 |             microbatch=-1,
244 |             lr=3e-4,
245 |             ema_rate="0.9999",
246 |             log_interval=10,
247 |             save_interval=2500,
248 |             resume_checkpoint="",
249 |             use_fp16=False,
250 |             fp16_scale_growth=1e-3,
251 |             schedule_sampler="uniform",
252 |             weight_decay=0.0,
253 |             lr_anneal_steps=0,
254 |         ).run_loop()
255 |         """
256 |         ################################train manually####################################
257 |         # files = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*')
258 |         files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training_cond/*")
259 |         files2 = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/slice_fusion_training/*")
260 | 
261 | 
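        # NOTE on the data mixing below: `files` holds 3-slice patches whose filename
        # suffix "_<k>.npy" appears to encode a position class k in {0, 1, 2} (the
        # triplane partition used for class-conditioning), while `files2` holds generic
        # adjacent-slice triples that are assigned the extra class y = 3. Each batch
        # entry draws from the two sources with equal probability.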
262 |         for m in range(25100):
263 | 
264 |             x_train = np.zeros((4, 3, 256, 256))
265 |             y = torch.randint(0, 3, (4,))  # placeholder labels; every entry is overwritten below
266 |             for l in range(4):
267 |                 luck = np.random.randint(0, 2)  # coin flip between the two data sources
268 |                 if luck == 0:
269 |                     filename = np.random.choice(files)
270 |                     x_raw = np.transpose(np.load(filename), (2, 0, 1))[0:3]
271 |                     x_raw = np.clip(x_raw * 2 - 1, -1, 1)  # rescale [0, 1] -> [-1, 1]
272 |                     x_train[l] = x_raw.copy()
273 |                     y_val = int((filename.split(".")[0]).split("_")[-1])
274 |                     y[l] = y_val
275 |                     print(y_val)
276 |                 else:
277 |                     filename = np.random.choice(files2)
278 |                     x_raw = np.transpose(np.load(filename), (2, 0, 1))[0:3]
279 |                     x_raw = np.clip(x_raw * 2 - 1, -1, 1)
280 |                     x_train[l] = x_raw.copy()
281 |                     y[l] = 3
282 |             x_orig = torch.from_numpy(x_train).to("cuda").to(torch.float)
283 |             i = torch.randint(0, 1000, (4,))
284 |             t = i.to("cuda").long()
285 |             y = y.to("cuda").long()
286 | 
287 | 
288 |             model_kwargs = {}
289 |             model_kwargs["y"] = y
290 | 
291 |             # if m % 1000 == 0:
292 |             #     x_sample = diffusion.ddim_sample_loop_progressive(model, (4, 3, 256, 256), task="None", progress=True, model_kwargs=model_kwargs)
293 |             #     np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/x_sample_ddim_iter" + str(m + pretrainsteps) + "_finetune_452024_cond.npy", x_sample.detach().cpu().numpy())
294 | 
295 |             loss = diffusion.training_losses(model, x_orig, t, model_kwargs=model_kwargs)["loss"]
296 |             loss = loss.mean()
297 |             loss.backward()
298 |             opt.step()
299 |             opt.zero_grad()
300 |             # unconditional variant of the training step (no class labels):
301 |             # loss = diffusion.training_losses(model, x_orig, t)["loss"]
302 |             # loss = loss.mean()
303 |             # loss.backward()
304 |             # opt.step()
305 |             # opt.zero_grad()
306 |             if m == 0: t0 = datetime.now()  # start the wall-clock timer once; resetting it every iteration made "time elapsed" always print ~0
307 |             if m % 20 == 0:
308 |                 print(f"loss {loss.item():.5f} at iteration {m}")
309 |                 t1 = datetime.now()
310 |                 print(t1 - t0, "time elapsed (per 20 iterations)")
311 |                 t0 = t1
312 |             # ####################################################################################################################
313 |             # if m % 5000 == 99:
314 |             #     torch.save({'iterations':m,'state_dict': model.state_dict()}, "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter" + str(m + pretrainsteps) + "_cond.ckpt")
315 | 
316 |         ####################################################################################
317 |         # model.eval()
318 |         # print('Run DDS.',
319 |         #       f'{self.args.T_sampling} sampling steps.',
320 |         #       f'Task: {self.args.deg}.'
321 | # ) 322 | # self.dds(model) 323 | def main(): 324 | args, config = parse_args_and_config() 325 | diffusion_model = Diffusion(args, config) 326 | diffusion_model.train() 327 | 328 | 329 | if __name__ == "__main__": 330 | print("running training") 331 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import os 5 | import logging 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from statistics import mean, stdev 9 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 10 | from scipy.ndimage import gaussian_laplace 11 | import functools 12 | from physics.fastmri_utils import fft2c_new, ifft2c_new 13 | 14 | 15 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 16 | def sigmoid(x): 17 | return 1 / (np.exp(-x) + 1) 18 | 19 | if beta_schedule == "quad": 20 | betas = ( 21 | np.linspace( 22 | beta_start ** 0.5, 23 | beta_end ** 0.5, 24 | num_diffusion_timesteps, 25 | dtype=np.float64, 26 | ) 27 | ** 2 28 | ) 29 | elif beta_schedule == "linear": 30 | betas = np.linspace( 31 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 32 | ) 33 | elif beta_schedule == "const": 34 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 35 | elif beta_schedule == "jsd": 36 | betas = 1.0 / np.linspace( 37 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 38 | ) 39 | elif beta_schedule == "sigmoid": 40 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 41 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 42 | else: 43 | raise NotImplementedError(beta_schedule) 44 | assert betas.shape == (num_diffusion_timesteps,) 45 | return betas 46 | 47 | 48 | def get_sigma(t, sde): 49 | """ VE-SDE """ 50 | sigma_t = sde.sigma_min * (sde.sigma_max / sde.sigma_min) ** t 51 | return sigma_t 52 | 53 | 54 | def pred_x0_from_s(xt, s, t, sde): 55 | """ Tweedie's formula for denoising. 
Assumes VE-SDE """ 56 | sigma_t = get_sigma(t, sde) 57 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1) 58 | pred_x0 = xt + (tmp ** 2) * s 59 | return pred_x0 60 | 61 | 62 | def recover_xt_from_x0(x0_t, s, t, sde): 63 | sigma_t = get_sigma(t, sde) 64 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1) 65 | xt = x0_t - (tmp ** 2) * s 66 | return xt 67 | 68 | 69 | def pred_eps_from_s(s, t, sde): 70 | sigma_t = get_sigma(t, sde) 71 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1) 72 | pred_eps = -tmp * s 73 | return pred_eps 74 | 75 | 76 | def _Dz(x): # Batch direction 77 | y = torch.zeros_like(x) 78 | y[:-1] = x[1:] 79 | y[-1] = x[0] 80 | return y - x 81 | 82 | def _Dzx(x): 83 | y = torch.zeros_like(x) 84 | y[:,:,:,:-1] = x[:,:,:,1:] 85 | y[:,:,:,-1] = x[:,:,:,0] 86 | return y - x 87 | 88 | def _DzxT(x): 89 | y = torch.zeros_like(x) 90 | y[:,:,:,:-1] = x[:,:,:,1:] 91 | y[:,:,:,-1] = x[:,:,:,0] 92 | 93 | tempt = -(y-x) 94 | difft = tempt[:,:,:,:-1] 95 | y[:,:,:,1:] = difft 96 | y[:,:,:,0] = x[:,:,:,-1] - x[:,:,:,0] 97 | 98 | return y - x 99 | 100 | 101 | def _Dzy(x): 102 | y = torch.zeros_like(x) 103 | y[:,:,:-1,:] = x[:,:,1:,:] 104 | y[:,:,-1,:] = x[:,:,0,:] 105 | return y - x 106 | 107 | def _DzyT(x): 108 | y = torch.zeros_like(x) 109 | y[:,:,:-1,:] = x[:,:,1:,:] 110 | y[:,:,-1,:] = x[:,:,0,:] 111 | 112 | tempt = -(y-x) 113 | difft = tempt[:,:,:-1,:] 114 | y[:,:,1:,:] = difft 115 | y[:,:,0,:] = x[:,:,-1,:] - x[:,:,0,:] 116 | 117 | return y - x 118 | 119 | 120 | def _DzT(x): # Batch direction 121 | y = torch.zeros_like(x) 122 | y[:-1] = x[1:] 123 | y[-1] = x[0] 124 | 125 | tempt = -(y-x) 126 | difft = tempt[:-1] 127 | y[1:] = difft 128 | y[0] = x[-1] - x[0] 129 | 130 | return y 131 | 132 | 133 | 134 | def _Dx(x): # Batch direction 135 | y = torch.zeros_like(x) 136 | y[:, :, :-1, :] = x[:, :, 1:, :] 137 | y[:, :, -1, :] = x[:, :, 0, :] 138 | return y - x 139 | 140 | 141 | def _DxT(x): # Batch direction 142 | y = torch.zeros_like(x) 143 | y[:, :, :-1, :] = x[:, :, 1:, :] 144 | y[:, :, -1, :] = x[:, :, 0, :] 145 | tempt = -(y - x) 146 | difft = tempt[:, :, :-1, :] 147 | y[:, :, 1:, :] = difft 148 | y[:, :, 0, :] = x[:, :, -1, :] - x[:, :, 0, :] 149 | return y 150 | 151 | 152 | def _Dy(x): # Batch direction 153 | y = torch.zeros_like(x) 154 | y[:, :, :, :-1] = x[:, :, :, 1:] 155 | y[:, :, :, -1] = x[:, :, :, 0] 156 | return y - x 157 | 158 | 159 | def _DyT(x): # Batch direction 160 | y = torch.zeros_like(x) 161 | y[:, :, :, :-1] = x[:, :, :, 1:] 162 | y[:, :, :, -1] = x[:, :, :, 0] 163 | tempt = -(y - x) 164 | difft = tempt[:, :, :, :-1] 165 | y[:, :, :, 1:] = difft 166 | y[:, :, :, 0] = x[:, :, :, -1] - x[:, :, :, 0] 167 | return y 168 | 169 | 170 | 171 | def CG(A, b, x, n_inner=5, eps=1e-5): 172 | r = b - A(x) 173 | p = r.clone() 174 | rsold = torch.matmul(r.view(1, -1), r.view(1, -1).T) 175 | 176 | for i in range(n_inner): 177 | Ap = A(p) 178 | a = rsold / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) 179 | 180 | x = x + a * p 181 | r = r - a * Ap 182 | 183 | rsnew = torch.matmul(r.view(1, -1), r.view(1, -1).T) 184 | if torch.abs(torch.sqrt(rsnew)) < eps: 185 | break 186 | p = r + (rsnew / rsold) * p 187 | rsold = rsnew 188 | return x 189 | 190 | 191 | def shrink(src, lamb): 192 | return torch.sign(src) * torch.max(torch.abs(src)-lamb, torch.zeros_like(src)) 193 | 194 | 195 | def clear_color(x): 196 | x = x.detach().cpu().squeeze().numpy() 197 | return np.transpose(x, (1, 2, 0)) 198 | 199 | 200 | def clear(x): 201 | x = x.detach().cpu().squeeze().numpy() 202 | return x 203 | 204 | 205 | def 
171 | def CG(A, b, x, n_inner=5, eps=1e-5):  # conjugate gradient for A x = b, with A symmetric positive definite
172 |     r = b - A(x)
173 |     p = r.clone()
174 |     rsold = torch.matmul(r.view(1, -1), r.view(1, -1).T)
175 | 
176 |     for i in range(n_inner):
177 |         Ap = A(p)
178 |         a = rsold / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)
179 | 
180 |         x = x + a * p
181 |         r = r - a * Ap
182 | 
183 |         rsnew = torch.matmul(r.view(1, -1), r.view(1, -1).T)
184 |         if torch.abs(torch.sqrt(rsnew)) < eps:
185 |             break
186 |         p = r + (rsnew / rsold) * p
187 |         rsold = rsnew
188 |     return x
189 | 
190 | 
191 | def shrink(src, lamb):
192 |     return torch.sign(src) * torch.max(torch.abs(src) - lamb, torch.zeros_like(src))  # soft-thresholding: the proximal operator of the l1 norm
193 | 
194 | 
195 | def clear_color(x):
196 |     x = x.detach().cpu().squeeze().numpy()
197 |     return np.transpose(x, (1, 2, 0))
198 | 
199 | 
200 | def clear(x):
201 |     x = x.detach().cpu().squeeze().numpy()
202 |     return x
203 | 
204 | 
205 | def restore_checkpoint(ckpt_dir, state, device, skip_sigma=False, skip_optimizer=False):
206 |     ckpt_dir = Path(ckpt_dir)
207 |     if not ckpt_dir.exists():
208 |         ckpt_dir.mkdir(parents=True)
209 |         logging.error(f"No checkpoint found at {ckpt_dir}. "
210 |                       f"Returned the same state as input")
211 |         # no checkpoint yet: warn and fall through, returning the unmodified state
212 |         return state
213 |     else:
214 |         loaded_state = torch.load(ckpt_dir, map_location=device)
215 |         if not skip_optimizer:
216 |             state['optimizer'].load_state_dict(loaded_state['optimizer'])
217 |         loaded_model_state = loaded_state['model']
218 |         if skip_sigma:
219 |             loaded_model_state.pop('module.sigmas')
220 | 
221 |         state['model'].load_state_dict(loaded_model_state, strict=False)
222 |         state['ema'].load_state_dict(loaded_state['ema'])
223 |         state['step'] = loaded_state['step']
224 |         print(f'loaded checkpoint dir from {ckpt_dir}')
225 |         return state
226 | 
227 | 
228 | def save_checkpoint(ckpt_dir, state):
229 |     saved_state = {
230 |         'optimizer': state['optimizer'].state_dict(),
231 |         'model': state['model'].state_dict(),
232 |         'ema': state['ema'].state_dict(),
233 |         'step': state['step']
234 |     }
235 |     torch.save(saved_state, ckpt_dir)
236 | 
237 | 
238 | """
239 | Helper functions for new types of inverse problems
240 | """
241 | 
242 | 
243 | 
244 | def fft2(x):
245 |     """ FFT with shifting DC to the center of the image """
246 |     return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2])
247 | 
248 | 
249 | def ifft2(x):
250 |     """ IFFT with shifting DC to the corner of the image prior to transform """
251 |     return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2]))
252 | 
253 | 
254 | def fft2_m(x):
255 |     """ FFT for multi-coil """
256 |     return torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
257 | 
258 | 
259 | def ifft2_m(x):
260 |     """ IFFT for multi-coil """
261 |     return torch.view_as_complex(ifft2c_new(torch.view_as_real(x)))
262 | 
263 | 
264 | def crop_center(img, cropx, cropy):
265 |     c, y, x = img.shape
266 |     startx = x // 2 - (cropx // 2)
267 |     starty = y // 2 - (cropy // 2)
268 |     return img[:, starty:starty + cropy, startx:startx + cropx]
269 | 
270 | 
271 | def normalize(img):
272 |     """ Normalize img in arbitrary range to [0, 1] """
273 |     img -= torch.min(img)
274 |     img /= torch.max(img)
275 |     return img
276 | 
277 | 
278 | def normalize_np(img):
279 |     """ Normalize img in arbitrary range to [0, 1] """
280 |     img -= np.min(img)
281 |     img /= np.max(img)
282 |     return img
283 | 
284 | 
285 | def normalize_np_kwarg(img, maxv=1.0, minv=0.0):
286 |     """ Shift img by minv and scale by maxv (maps [minv, minv + maxv] to [0, 1]) """
287 |     img -= minv
288 |     img /= maxv
289 |     return img
290 | 
291 | 
292 | def normalize_complex(img):
293 |     """ normalizes the magnitude of complex-valued image to range [0, 1] """
294 |     abs_img = normalize(torch.abs(img))
295 |     # ang_img = torch.angle(img)
296 |     ang_img = normalize(torch.angle(img))
297 |     return abs_img * torch.exp(1j * ang_img)
298 | 
299 | 
300 | def batchfy(tensor, batch_size):
301 |     n = len(tensor)
302 |     num_batches = (n + batch_size - 1) // batch_size  # ceil division; `n // batch_size + 1` produced undersized chunks when batch_size divides n
303 |     return tensor.chunk(num_batches, dim=0)
304 | 
305 | 
306 | def img_wise_min_max(img):
307 |     img_flatten = img.view(img.shape[0], -1)
308 |     img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
309 |     img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
310 | 
311 |     return (img - img_min) / (img_max - img_min)
312 | 
313 | 
314 | def patient_wise_min_max(img):
315 |     std_upper = 3
316 |     img_flatten = img.view(img.shape[0], -1)
317 | 
318 |     std = torch.std(img)
319 |     mean = torch.mean(img)
320 | 
321 | img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) 322 | img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) 323 | 324 | min_max_scaled = (img - img_min) / (img_max - img_min) 325 | min_max_scaled_std = (std - img_min) / (img_max - img_min) 326 | min_max_scaled_mean = (mean - img_min) / (img_max - img_min) 327 | 328 | min_max_scaled[min_max_scaled > min_max_scaled_mean + 329 | std_upper * min_max_scaled_std] = 1 330 | 331 | return min_max_scaled 332 | 333 | 334 | def create_sphere(cx, cy, cz, r, resolution=256): 335 | ''' 336 | create sphere with center (cx, cy, cz) and radius r 337 | ''' 338 | phi = np.linspace(0, 2 * np.pi, 2 * resolution) 339 | theta = np.linspace(0, np.pi, resolution) 340 | 341 | theta, phi = np.meshgrid(theta, phi) 342 | 343 | r_xy = r * np.sin(theta) 344 | x = cx + np.cos(phi) * r_xy 345 | y = cy + np.sin(phi) * r_xy 346 | z = cz + r * np.cos(theta) 347 | 348 | return np.stack([x, y, z]) 349 | 350 | 351 | class lambda_schedule: 352 | def __init__(self, total=2000): 353 | self.total = total 354 | 355 | def get_current_lambda(self, i): 356 | pass 357 | 358 | 359 | class lambda_schedule_linear(lambda_schedule): 360 | def __init__(self, start_lamb=1.0, end_lamb=0.0): 361 | super().__init__() 362 | self.start_lamb = start_lamb 363 | self.end_lamb = end_lamb 364 | 365 | def get_current_lambda(self, i): 366 | return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total) 367 | 368 | 369 | class lambda_schedule_const(lambda_schedule): 370 | def __init__(self, lamb=1.0): 371 | super().__init__() 372 | self.lamb = lamb 373 | 374 | def get_current_lambda(self, i): 375 | return self.lamb 376 | 377 | 378 | def image_grid(x, sz=32): 379 | size = sz 380 | channels = 3 381 | img = x.reshape(-1, size, size, channels) 382 | w = int(np.sqrt(img.shape[0])) 383 | img = img.reshape((w, w, size, size, channels)).transpose( 384 | (0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) 385 | return img 386 | 387 | 388 | def show_samples(x, sz=32): 389 | x = x.permute(0, 2, 3, 1).detach().cpu().numpy() 390 | img = image_grid(x, sz) 391 | plt.figure(figsize=(8, 8)) 392 | plt.axis('off') 393 | plt.imshow(img) 394 | plt.show() 395 | 396 | 397 | def image_grid_gray(x, size=32): 398 | img = x.reshape(-1, size, size) 399 | w = int(np.sqrt(img.shape[0])) 400 | img = img.reshape((w, w, size, size)).transpose( 401 | (0, 2, 1, 3)).reshape((w * size, w * size)) 402 | return img 403 | 404 | 405 | def show_samples_gray(x, size=32, save=False, save_fname=None): 406 | x = x.detach().cpu().numpy() 407 | img = image_grid_gray(x, size=size) 408 | plt.figure(figsize=(8, 8)) 409 | plt.axis('off') 410 | plt.imshow(img, cmap='gray') 411 | plt.show() 412 | if save: 413 | plt.imsave(save_fname, img, cmap='gray') 414 | 415 | 416 | def get_mask(img, size, batch_size, type='gaussian2d', acc_factor=8, center_fraction=0.04, fix=False): 417 | mux_in = size ** 2 418 | if type.endswith('2d'): 419 | Nsamp = mux_in // acc_factor 420 | elif type.endswith('1d'): 421 | Nsamp = size // acc_factor 422 | if type == 'gaussian2d': 423 | mask = torch.zeros_like(img) 424 | cov_factor = size * (1.5 / 128) 425 | mean = [size // 2, size // 2] 426 | cov = [[size * cov_factor, 0], [0, size * cov_factor]] 427 | if fix: 428 | samples = np.random.multivariate_normal(mean, cov, int(Nsamp)) 429 | int_samples = samples.astype(int) 430 | int_samples = np.clip(int_samples, 0, size - 1) 431 | mask[..., int_samples[:, 0], int_samples[:, 1]] = 1 432 | else: 433 | for i in range(batch_size): 434 | # sample 
different masks for batch
435 |                 samples = np.random.multivariate_normal(mean, cov, int(Nsamp))
436 |                 int_samples = samples.astype(int)
437 |                 int_samples = np.clip(int_samples, 0, size - 1)
438 |                 mask[i, :, int_samples[:, 0], int_samples[:, 1]] = 1
439 |     elif type == 'uniformrandom2d':
440 |         mask = torch.zeros_like(img)
441 |         if fix:
442 |             mask_vec = torch.zeros([1, size * size])
443 |             samples = np.random.choice(size * size, int(Nsamp))
444 |             mask_vec[:, samples] = 1
445 |             mask_b = mask_vec.view(size, size)
446 |             mask[:, ...] = mask_b
447 |         else:
448 |             for i in range(batch_size):
449 |                 # sample different masks for batch
450 |                 mask_vec = torch.zeros([1, size * size])
451 |                 samples = np.random.choice(size * size, int(Nsamp))
452 |                 mask_vec[:, samples] = 1
453 |                 mask_b = mask_vec.view(size, size)
454 |                 mask[i, ...] = mask_b
455 |     elif type == 'gaussian1d':
456 |         mask = torch.zeros_like(img)
457 |         mean = size // 2
458 |         std = size * (15.0 / 128)
459 |         Nsamp_center = int(size * center_fraction)
460 |         if fix:
461 |             samples = np.random.normal(
462 |                 loc=mean, scale=std, size=int(Nsamp * 1.2))
463 |             int_samples = samples.astype(int)
464 |             int_samples = np.clip(int_samples, 0, size - 1)
465 |             mask[..., int_samples] = 1
466 |             c_from = size // 2 - Nsamp_center // 2
467 |             mask[..., c_from:c_from + Nsamp_center] = 1
468 |         else:
469 |             for i in range(batch_size):
470 |                 samples = np.random.normal(
471 |                     loc=mean, scale=std, size=int(Nsamp * 1.2))
472 |                 int_samples = samples.astype(int)
473 |                 int_samples = np.clip(int_samples, 0, size - 1)
474 |                 mask[i, :, :, int_samples] = 1
475 |                 c_from = size // 2 - Nsamp_center // 2
476 |                 mask[i, :, :, c_from:c_from + Nsamp_center] = 1
477 |     elif type == 'uniform1d':
478 |         mask = torch.zeros_like(img)
479 |         if fix:
480 |             Nsamp_center = int(size * center_fraction)
481 |             samples = np.random.choice(size, int(Nsamp - Nsamp_center))
482 |             mask[..., samples] = 1
483 |             # ACS region
484 |             c_from = size // 2 - Nsamp_center // 2
485 |             mask[..., c_from:c_from + Nsamp_center] = 1
486 |         else:
487 |             for i in range(batch_size):
488 |                 Nsamp_center = int(size * center_fraction)
489 |                 samples = np.random.choice(size, int(Nsamp - Nsamp_center))
490 |                 mask[i, :, :, samples] = 1
491 |                 # ACS region
492 |                 c_from = size // 2 - Nsamp_center // 2
493 |                 mask[i, :, :, c_from:c_from + Nsamp_center] = 1
494 |     else:
495 |         raise NotImplementedError(f'Mask type {type} is currently not supported.')  # the `raise` keyword was missing, so unknown types fell through to an undefined `mask`
496 | 
497 |     return mask
498 | 
499 | 
500 | def nchw_comp_to_real(x):
501 |     """
502 |     [1, 1, 320, 320] comp --> [1, 2, 320, 320] real
503 |     """
504 |     x = torch.view_as_real(x)
505 |     x = x.squeeze(dim=1)
506 |     x = x.permute(0, 3, 1, 2)
507 |     return x
508 | 
509 | def real_to_nchw_comp(x):
510 |     """
511 |     [1, 2, 320, 320] real --> [1, 1, 320, 320] comp
512 |     """
513 |     if len(x.shape) == 4:
514 |         x = x[:, 0:1, :, :] + x[:, 1:2, :, :] * 1j
515 |     elif len(x.shape) == 3:
516 |         x = x[0:1, :, :] + x[1:2, :, :] * 1j
517 |     return x
518 | 
519 | 
520 | def kspace_to_nchw(tensor):
521 |     """
522 |     Convert torch tensor in (Slice, Coil, Height, Width, Complex) 5D format to
523 |     (N, C, H, W) 4D format for processing by 2D CNNs.
524 | 
525 |     Complex indicates (real, imag) as 2 channels, the complex data format for Pytorch.
526 | 
527 |     C is the coils interleaved with real and imaginary values as separate channels.
528 |     C is therefore always 2 * Coil.
529 | 
530 |     Singlecoil data is assumed to be in the 5D format with Coil = 1
531 | 
532 |     Args:
533 |         tensor (torch.Tensor): Input data in 5D kspace tensor format.
534 | Returns: 535 | tensor (torch.Tensor): tensor in 4D NCHW format to be fed into a CNN. 536 | """ 537 | assert isinstance(tensor, torch.Tensor) 538 | assert tensor.dim() == 5 539 | s = tensor.shape 540 | assert s[-1] == 2 541 | tensor = tensor.permute(dims=(0, 1, 4, 2, 3)).reshape( 542 | shape=(s[0], 2 * s[1], s[2], s[3])) 543 | return tensor 544 | 545 | 546 | def nchw_to_kspace(tensor): 547 | """ 548 | Convert a torch tensor in (N, C, H, W) format to the (Slice, Coil, Height, Width, Complex) format. 549 | 550 | This function assumes that the real and imaginary values of a coil are always adjacent to one another in C. 551 | If the coil dimension is not divisible by 2, the function assumes that the input data is 'real' data, 552 | and thus pads the imaginary dimension as 0. 553 | """ 554 | assert isinstance(tensor, torch.Tensor) 555 | assert tensor.dim() == 4 556 | s = tensor.shape 557 | if tensor.shape[1] == 1: 558 | imag_tensor = torch.zeros(s, device=tensor.device) 559 | tensor = torch.cat((tensor, imag_tensor), dim=1) 560 | s = tensor.shape 561 | tensor = tensor.view( 562 | size=(s[0], s[1] // 2, 2, s[2], s[3])).permute(dims=(0, 1, 3, 4, 2)) 563 | return tensor 564 | 565 | 566 | def root_sum_of_squares(data, dim=0): 567 | """ 568 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 569 | Args: 570 | data (torch.Tensor): The input tensor 571 | dim (int): The dimensions along which to apply the RSS transform 572 | Returns: 573 | torch.Tensor: The RSS value 574 | """ 575 | return torch.sqrt((data ** 2).sum(dim)) 576 | 577 | 578 | def save_data(fname, arr): 579 | """ Save data as .npy and .png """ 580 | np.save(fname + '.npy', arr) 581 | plt.imsave(fname + '.png', arr, cmap='gray') 582 | 583 | 584 | def mean_std(vals: list): 585 | return mean(vals), stdev(vals) --------------------------------------------------------------------------------