├── README.md
├── eval_3DCT_blendcond.py
├── eval_3D_blend_cond.sh
├── eval_3D_blend_cond_lidc.py
├── guided_diffusion
│   ├── CTDataset.py
│   ├── __pycache__
│   │   ├── CTDataset.cpython-38.pyc
│   │   ├── dist_util.cpython-38.pyc
│   │   ├── fp16_util.cpython-38.pyc
│   │   ├── fp16_util.cpython-39.pyc
│   │   ├── gaussian_diffusion.cpython-38.pyc
│   │   ├── gaussian_diffusion.cpython-39.pyc
│   │   ├── logger.cpython-38.pyc
│   │   ├── logger.cpython-39.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── losses.cpython-39.pyc
│   │   ├── models.cpython-38.pyc
│   │   ├── models.cpython-39.pyc
│   │   ├── nn.cpython-38.pyc
│   │   ├── nn.cpython-39.pyc
│   │   ├── resample.cpython-38.pyc
│   │   ├── respace.cpython-38.pyc
│   │   ├── respace.cpython-39.pyc
│   │   ├── script_util.cpython-38.pyc
│   │   ├── script_util.cpython-39.pyc
│   │   ├── train_util.cpython-38.pyc
│   │   ├── unet.cpython-38.pyc
│   │   ├── unet.cpython-39.pyc
│   │   ├── utils.cpython-38.pyc
│   │   └── utils.cpython-39.pyc
│   ├── diffusion.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── losses.py
│   ├── models.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   ├── training_script.py
│   ├── training_triplane_script.py
│   ├── unet.py
│   └── utils.py
├── main.py
├── train_SVCT_3D_triplane.sh
├── training_triplane_lidc.py
├── training_triplane_script.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction
2 |
3 | This repository contains the official implementation of **"DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction"**, published at **NeurIPS 2024**.
4 |
5 | Paper link: https://openreview.net/forum?id=h3Kv6sdTWO
6 |
7 |
8 | ---
9 |
10 | ## Overview
11 |
12 | DiffusionBlend introduces a novel method for 3D computed tomography (CT) reconstruction using position-aware diffusion score blending. By leveraging position-specific priors, the framework achieves enhanced reconstruction accuracy while maintaining computational efficiency.
13 |
14 |
15 |
16 |
17 | ---
18 |
19 | ## Features
20 |
21 | - **Position-aware Diffusion Blending:** Incorporates cross-slice position information to refine 3D reconstruction quality (see the sketch below).
22 | - **Triplane-based 3D Representation:** Uses positional encoding of slice triplets to model 3D patch priors efficiently.
23 | - **Scalable and Generalizable:** Designed for both synthetic and real-world CT reconstruction tasks.
24 |
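For intuition, here is a minimal sketch of the score-blending idea, simplified from `blendscore` in `eval_3DCT_blendcond.py`: split the volume into slice triplets, score each triplet with the 2D diffusion model, and stitch the per-triplet noise estimates back together. It assumes a `model(x, t)` callable returning a 3-channel noise estimate; the actual implementation additionally passes a class label identifying the partition type.

```python
import torch

def blend_scores(model, xt, t, offset=0):
    """Blend per-triplet scores into a score for the whole volume.

    xt: (1, 3*N, H, W) noisy slice stack; offset in {0, 1, 2} shifts the
    triplet partition so successive steps blend different slice groupings.
    """
    et = torch.zeros_like(xt)
    # Always cover the two boundary triplets.
    et[:, :3] = model(xt[:, :3], t)
    et[:, -3:] = model(xt[:, -3:], t)
    # Score the (possibly shifted) interior triplets.
    for j in range(offset, xt.shape[1] - 2, 3):
        et[:, j:j + 3] = model(xt[:, j:j + 3], t)
    return et
```

Alternating this adjacent-triplet partition with the "slice-jump" partition in `vps_blend` (slices i, i+3, i+6) is what propagates information across the whole volume.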
25 | ---
26 |
27 | ## Requirements
28 |
29 | The code is implemented in Python and requires the following dependencies:
30 |
31 | - `torch` (>=1.9.0)
32 | - `torchvision`
33 | - `numpy`, `scipy`, `scikit-image`, `matplotlib`, `tqdm`, `blobfile`, `mpi4py`
34 |
35 | You can install the dependencies via:
36 |
37 | ```bash
38 | pip install torch torchvision numpy scipy scikit-image matplotlib tqdm blobfile mpi4py
39 | ```
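Note that the evaluation scripts also import a `physics.ct` module providing the CT forward operator (`CT`); make sure this module is available on your `PYTHONPATH`.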
40 |
41 | ## Training
42 | To train the triplane diffusion model for sparse-view CT (SVCT), use the following script:
43 |
44 | ```bash
45 | bash train_SVCT_3D_triplane.sh
46 | ```
47 |
48 | ## Inference
49 | To perform inference and evaluate 3D reconstruction using diffusion score blending, use:
50 |
51 | ```bash
52 | bash eval_3D_blend_cond.sh
53 | ```
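The main flags (see `eval_3D_blend_cond.sh`) are `--Nview` (number of projection views), `--deg` (`SV-CT` for sparse-view or `LA-CT` for limited-angle CT), `--eta` (DDIM stochasticity), `--T_sampling` (number of sampling steps), and `--sigma_y` (measurement noise level).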
54 |
55 | ## Results
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | ## Citation
69 | If you find this work useful in your research, please cite:
70 |
71 | ```
72 | @inproceedings{diffusionblend2024,
73 | title={DiffusionBlend: Learning 3D Image Prior through Position-aware Diffusion Score Blending for 3D Computed Tomography Reconstruction},
74 | author={Song, Bowen and Hu, Jason and Luo, Zhaoxu and Fessler, Jeffrey A and Shen, Liyue},
75 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
76 | year={2024}
77 | }
78 | ```
79 |
80 | ## Acknowledgements
81 | We thank the contributors and the NeurIPS community for their valuable feedback and discussions.
82 |
83 | ## License
84 | This project is licensed under the MIT License. See the LICENSE file for details.
85 |
86 |
87 |
--------------------------------------------------------------------------------
/eval_3DCT_blendcond.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | import glob
5 | import json
6 | import sys
7 | import math
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import tqdm
11 | import torch
12 | import torch.utils.data as data
13 | import torchvision.utils as tvu
14 |
15 | from guided_diffusion.models import Model
16 | from guided_diffusion.script_util import create_model, classifier_defaults, args_to_dict, create_gaussian_diffusion
17 | from guided_diffusion.utils import get_alpha_schedule
18 | import random
19 |
20 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
21 | from scipy.linalg import orth
22 | from pathlib import Path
23 |
24 | from physics.ct import CT
25 | from time import time
26 | from utils import shrink, CG, clear, batchfy, _Dz, _DzT, get_beta_schedule
27 |
28 |
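# compute_alpha pads beta with a leading zero so that t = -1 maps to abar = 1,
# then returns abar_t = prod_{s<=t} (1 - beta_s) for each entry of t.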
29 | def compute_alpha(beta, t):
30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
32 | return a
33 |
34 |
35 | class Diffusion(object):
36 | def __init__(self, args, config, device=None):
37 | self.args = args
38 | self.args.image_folder = Path(self.args.image_folder)
39 | for t in ["input", "recon", "label"]:
40 | if t == "recon":
41 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
42 | else:
43 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
44 | self.config = config
45 | print(self.config)
46 | if device is None:
47 | device = (
48 | torch.device("cuda")
49 | if torch.cuda.is_available()
50 | else torch.device("cpu")
51 | )
52 | self.device = device
53 |
54 | self.model_var_type = config.model.var_type
55 | betas = get_beta_schedule(
56 | beta_schedule=config.diffusion.beta_schedule,
57 | beta_start=config.diffusion.beta_start,
58 | beta_end=config.diffusion.beta_end,
59 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
60 | )
61 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
62 | self.num_timesteps = betas.shape[0]
63 |
64 | alphas = 1.0 - betas
65 | alphas_cumprod = alphas.cumprod(dim=0)
66 | alphas_cumprod_prev = torch.cat(
67 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
68 | )
69 | self.alphas_cumprod_prev = alphas_cumprod_prev
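        # Variance of the DDPM posterior q(x_{t-1} | x_t, x_0):
        #     beta_t * (1 - abar_{t-1}) / (1 - abar_t)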
70 | posterior_variance = (
71 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
72 | )
73 | if self.model_var_type == "fixedlarge":
74 | self.logvar = betas.log()
75 | elif self.model_var_type == "fixedsmall":
76 | self.logvar = posterior_variance.clamp(min=1e-20).log()
77 |
78 | def sample(self):
79 | config_dict = vars(self.config.model)
80 | config_dict['use_spacecode'] = False
81 | config_dict["class_cond"] = True
82 | model = create_model(**config_dict)
83 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter65099_cond.ckpt"
84 |
85 | model.load_state_dict(torch.load(ckpt, map_location=self.device)["state_dict"])
86 | print(f"Model ckpt loaded from {ckpt}")
87 | model.to(self.device)
88 | model.eval()
89 | model = torch.nn.DataParallel(model)
90 |
91 | print('Run 3D DDS + DiffusionMBIR.',
92 | f'{self.args.T_sampling} sampling steps.',
93 | f'Task: {self.args.deg}.'
94 | )
95 | self.dds3d(model)
96 |
97 |
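    # blendscore: partition the volume into adjacent slice triplets (optionally
    # shifting the partition start via start_ind/start_head) and stitch the
    # per-triplet noise predictions into a single score for the whole volume.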
98 |     def blendscore(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
99 |         model_kwargs = {}
100 |         y = torch.ones(1)  # class-conditioning label (1) used for adjacent triplets
101 |         y = y.to(xt.device).to(torch.long)
102 |         model_kwargs["y"] = y
103 |
104 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
105 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
106 | et[:,:3,:,:] = model(xt[:,:3,:,:], t, **model_kwargs)[:,:3]
107 | et[:,xt.shape[1]-3:, :,:] = model(xt[:,xt.shape[1]-3:,:,:], t, **model_kwargs)[:,:3]
108 |
109 | if start_ind is None:
110 | start_ind = np.random.randint(start_head,3)
111 | for j in range(start_ind, xt.shape[1]-2, 3):
112 | #####randomly select instead of summing
113 | et_sing = model(xt[:,j:(j+3),:,:], t, **model_kwargs)[:,:3] #####1 x 3 x 256 x 256
114 | et[:,j:(j+3), :,:] = et_sing
115 | return et
116 |
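    # vps_blend: score "slice-jump" triplets instead of adjacent ones; within
    # each block of 9 slices it groups (i, i+3, i+6) (the "147258369" pattern),
    # so alternating it with blendscore propagates information across slices.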
117 |     def vps_blend(self, xt, model, t, start_ind=None, start_head=0, num_batches=180):
118 |         model_kwargs = {}
119 |         y = torch.ones(1) * 3  # class-conditioning label (3) used for slice-jump triplets
120 |         y = y.to(xt.device).to(torch.long)
121 |         model_kwargs["y"] = y
122 |
123 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
124 |
125 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
126 | # 147258369
127 | for i in range(0,xt.shape[1], 9):
128 | for m in range(3):
129 | et[:,[i+m,i+m+3, i+m+6],:,:] = model(xt[:,[i+m,i+m+3, i+m+6],:,:], t, **model_kwargs)[:,:3]
130 |
131 | return et
132 |
133 |
134 | def dds3d(self, model):
135 | args, config = self.args, self.config
136 | print(f"Dataset path: {self.args.dataset_path}")
137 | root = Path(self.args.dataset_path)
138 |
139 | noise, noise_flag = self.args.sigma_y, False
140 | if noise > 0:
141 | noise_flag = True
142 |
143 | # parameters to be moved to args
144 | Nview = self.args.Nview
145 |         rho = self.args.rho
146 |         rho = 0.001  # hard-coded override of args.rho used in the reported experiments
147 |         lamb = self.args.lamb
148 |         lamb = 0.05 * 1e-3  # hard-coded override of args.lamb
149 |         n_ADMM = 1
150 |         n_CG = self.args.CG_iter
151 |         print(n_CG)
152 |
153 |         blend = True  # enable position-aware score blending
154 |
155 |
156 |         time_travel = False  # if True, run a second "time-travel" pass per timestep
157 |         vps_scale = 0.03
158 |         vps = False  # optional stochastic refinement pass (disabled by default)
159 |
160 |         ddimsteps = 200  # number of DDIM steps
161 |
162 |
163 | # Specify save directory for saving generated samples
164 | save_root = Path(self.args.image_folder)
165 | save_root.mkdir(parents=True, exist_ok=True)
166 |
167 | irl_types = ['vol', 'input', 'recon', 'label']
168 | for t in irl_types:
169 | save_root_f = save_root / t
170 | save_root_f.mkdir(parents=True, exist_ok=True)
171 |
172 |
173 | ##################################new data##################################
174 |         fname_list = os.listdir("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation")
175 |         root = "/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation"  # hard-coded override of args.dataset_path
176 |
177 | fname_list.sort()
178 |
179 | pre_slices = 0
180 | num_batches = 168
181 | fname_list = fname_list[pre_slices:(pre_slices+ 3 * num_batches)]
182 | ######################################################################################################
183 |
184 | print(fname_list)
185 | all_img = []
186 | batch_size = 3
187 | print("Loading all data")
188 | if time_travel:
189 | tot_iters = 2
190 | else:
191 | tot_iters = 1
192 | for fname in fname_list:
193 | just_name = fname.split('.')[0]
194 | img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))
195 | h, w = img.shape
196 | img = img.view(1, 1, h, w)
197 | all_img.append(img)
198 | all_img = torch.cat(all_img, dim=0)
199 | x_orig = all_img
200 | print(f"Data loaded shape : {all_img.shape}")
201 | x_orig = x_orig.to(torch.float32)
202 | print("Data type is :", x_orig.dtype)
203 | img_shape = (x_orig.shape[0], config.data.channels, config.data.image_size, config.data.image_size)
204 | if self.args.deg == "SV-CT":
205 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=True, circle=False, device=config.device)
206 | elif self.args.deg == "LA-CT":
207 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=False, circle=False, device=config.device)
208 | A = lambda z: A_funcs.A(z)
209 | Ap = lambda z: A_funcs.A_dagger(z)
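        # Data-consistency (ADMM) step: approximately solve
        #     min_x ||A x - y||^2 + lamb * ||D_z x||_1   (TV along the slice axis),
        # with the x-update solved by conjugate gradients (CG) and the auxiliary
        # z-update by soft-thresholding (shrink).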
210 | def Acg_TV(x):
211 | return A_funcs.AT(A_funcs.A(x)) + rho * _DzT(_Dz(x))
212 | def ADMM(x, ATy, n_ADMM=n_ADMM):
213 | nonlocal del_z, udel_z
214 | for _ in range(n_ADMM):
215 | bcg_TV = ATy + rho * (_DzT(del_z) - _DzT(udel_z))
216 | x = CG(Acg_TV, bcg_TV, x, n_inner=n_CG)
217 | del_z = shrink(_Dz(x) + udel_z, lamb / rho)
218 | udel_z = _Dz(x) - del_z + udel_z
219 | return x
220 | del_z = torch.zeros(img_shape, device=self.device)
221 | udel_z = torch.zeros(img_shape, device=self.device)
222 | x_orig = x_orig.to(self.device) ######n x 1 x 256 x 256
223 | print(x_orig.min(), x_orig.max(), "xorig")
224 | y = A(x_orig)
225 | print(y.shape, "projection shape")
226 |
227 | ###########################adding noise to projection#######################################
228 | if noise_flag:
229 | print("adding noise to projections")
230 | I0 = 1.11e6
231 | # y = (-(torch.log(1e4 * torch.exp(-y/256) + torch.randn_like(y) * 5) - math.log(1e4))*256) ###gaussian noise
232 | y = -(torch.log(torch.poisson(I0 * torch.exp(-y/18)) + torch.randn_like(y) * 5) - math.log(I0))*18 ##poisson gaussian noise
233 |
234 | Apy = Ap(y)
235 | print(Apy.shape, "Apy backprojection shape")
236 | ATy = A_funcs.AT(y)
237 | ##########################original####################################################
238 | # x = torch.randn(20, 3, 256, 256, device = self.device) ####initial noise
239 |
240 | ########forward init############################
241 |         t = (torch.ones(1) * 500).to(self.device)  # forward-diffuse the ground truth to t = 500 for initialization
242 |         at = compute_alpha(self.betas, t.long())
243 |         at = at[0, 0, 0, 0]
244 |         init_noise = at.sqrt() * x_orig + torch.randn_like(x_orig) * (1 - at).sqrt()
245 |         x = torch.reshape(init_noise, (num_batches, 3, 256, 256))
246 |
247 |
248 |
249 | diffusion = create_gaussian_diffusion(
250 | steps=1000,
251 | learn_sigma=True,
252 | noise_schedule="linear",
253 | use_kl=False,
254 | predict_xstart=False,
255 | rescale_timesteps=False,
256 | rescale_learned_sigmas=False,
257 | timestep_respacing="",
258 | )
259 | xt = None
260 | with torch.no_grad():
261 | skip = config.diffusion.num_diffusion_timesteps//ddimsteps
262 | n = x.size(0)
263 | x0_preds = []
264 | xt = x ###20 x 3 x 256 x 256
265 |
266 | # generate time schedule
267 | times = range(0, 1000, skip) #########0, 1, 2, ....
268 | times_next = [-1] + list(times[:-1])
269 | times_pair = zip(reversed(times), reversed(times_next))
270 |
271 | ct = 0
272 | ###################################start reverse sampling############################################
273 | for i, j in tqdm.tqdm(times_pair, total=len(times)):
274 | t = (torch.ones(n) * i).to(x.device)
275 | next_t = (torch.ones(n) * j).to(x.device)
276 | ########if time travel do two passes, otherwise do one pass#########
277 | travels = tot_iters
278 | ct += 1
279 |
280 | for zhoumu in range(travels):
281 | print("zhoumu: ", zhoumu)
282 | t = (torch.ones(n) * i).to(x.device)
283 | next_t = (torch.ones(n) * j).to(x.device)
284 | at = compute_alpha(self.betas, t.long())
285 | bt = torch.index_select(self.betas,0,t.long())
286 | at_next = compute_alpha(self.betas, next_t.long())
287 | at = at[0,0,0,0]
288 | at_next = at_next[0,0,0,0]
289 | #################################reverse with consistency########################################
290 | et_agg = list() ###initialize a list of scores
291 |
292 | ###########################################ADJ slices#############################################
293 | if ct % 2 != 1:
294 | if vps:
295 | for M in range(1): ####number of VPS iterations
296 | noise = torch.randn_like(xt)
297 | ####################added by bowen 3/24/2024####################
298 | et = self.vps_blend(xt, model, t, num_batches = num_batches)
299 | ####################################################################
300 | # if blend:
301 | # et = self.blendscore(xt, model, t, start_ind = 0)
302 | # else:
303 | # et = self.blendscore(xt, model,t, start_head = 1) ###1 x 60 x 256 x 256
304 | et = torch.reshape(et, (num_batches, 3, 256, 256))
305 | lam_ = vps_scale
306 | xt = xt - lam_ * (1 - at).sqrt() * et
307 | xt = xt + ((lam_ * (2-lam_))*(1-at)).sqrt() * noise * 1
308 | if blend:
309 | et = self.blendscore(xt, model,t,num_batches = num_batches)
310 | else:
311 | y = torch.ones(1) * 1
312 | y=y.to(xt.device).to(torch.long)
313 | model_kwargs = {}
314 | model_kwargs["y"] = y
315 |                             for k in range(xt.shape[0]):  # score each triplet independently (no blending)
316 |                                 et_sing = model(xt[k:(k + 1)], t, **model_kwargs)  ####1 x 6 x 256 x 256
317 |                                 et_agg.append(et_sing)
318 | et = torch.cat(et_agg, dim=0) ####20 x 6 x 256 x 256
319 | et = et[:, :3] ####20 x 3 x 256 x 256
320 | ###reshape xt and et
321 | et_ = torch.reshape(et, ((num_batches * 3), 1, 256, 256))
322 | #######################################SLICE JUMP#############################################
323 | if ct % 2 == 1: ###147258369 4710
324 | print(ct, "changing et to slice jumping")
325 | et_ = self.vps_blend(xt, model, t, num_batches = num_batches)
326 | et_ = torch.reshape(et_, ((num_batches * 3), 1, 256, 256))
327 | xt_ = torch.reshape(xt, ((num_batches * 3), 1, 256, 256))
328 | x0_t = (xt_ - et_ * (1 - at).sqrt()) / at.sqrt() ###60 x 1 x 256 x 256 scale [-1, 1]
329 |
330 | ###########################if inverse problem solving ######################################################
331 | x0_t = torch.clip(x0_t, -1, 1) ####clip to [-1, 1]
332 | x0_t = (x0_t +1)/2 ###rescale to [0, 1]
333 | x0_t_hat = None
334 | eta = self.args.eta
335 | if zhoumu == 0:
336 | x0_t_hat = ADMM(x0_t, ATy, n_ADMM=n_ADMM) ######[0,1]
337 | # x0_t_hat = torch.clip(x0_t_hat, 0, 1)
338 | x0_t_hat = x0_t_hat * 2 - 1 #######rescale back to [-1, 1]
339 | else:
340 | x0_t_hat = x0_t * 2 - 1 #######rescale back to [-1, 1]
341 | ############################################################################################################
342 |
346 |
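                    # DDIM update: x_{t-1} = sqrt(abar_next) * x0_hat + c1 * z + c2 * et,
                    # with c1 = eta * sqrt(1 - abar_next) and c2 = sqrt((1 - eta^2) * (1 - abar_next)),
                    # so that c1^2 + c2^2 = 1 - abar_next.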
347 | c1 = (1 - at_next).sqrt() * eta
348 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
349 | if j != 0:
350 | xt_ = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et_
351 | else:
352 | xt_ = x0_t_hat
353 | xt = torch.reshape(xt_, (num_batches, 3, 256, 256)) ####reshape back
354 |
355 | ######################################################################################################
356 | if noise_flag:
357 | print("added noise")
358 | np.save(f"ctrecon_jump_200NFE_{self.args.Nview}projs_pgnoise.npy", xt.detach().cpu().numpy())
359 | else:
360 | np.save(f"ctrecon_jump_200NFE_{self.args.Nview}projs.npy", xt.detach().cpu().numpy())
361 |
362 | if self.args.deg == "SV-CT":
363 | np.save("x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_view6_47_jump_ful.npy", xt.detach().cpu().numpy())
364 | if self.args.deg == "LA-CT":
365 | np.save("x_sample_ddim" + str(ddimsteps) + f"_iter65000_lactL67_blend3_half{pre_slices}_full_view90.npy", xt.detach().cpu().numpy())
366 |
367 |
368 | if blend:
369 | if vps:
370 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho"+str(rho)+"ttnew"+str(tot_iters)+"_vps_"+ str(vps_scale)+"_full_skip2.npy", xt.detach().cpu().numpy())
371 | else:
372 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_jump_skip2_view6.npy", xt.detach().cpu().numpy())
373 | else:
374 | if vps:
375 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_rho" + str(rho) + "ttnew" + str(tot_iters) +"_vps_"+ str(vps_scale)+ "_full_skip2_view6.npy", xt.detach().cpu().numpy())
376 | else:
377 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/blendDDS/apr7/x_sample_ddim" + str(ddimsteps) + "_iter65000_reconstructionL67_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_jump_skip2_view6.npy", xt.detach().cpu().numpy())
378 |
379 |
380 |
381 |
--------------------------------------------------------------------------------
/eval_3D_blend_cond.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | Nview=8
4 | T_sampling=50
5 | eta=0.85
6 |
7 | python main.py \
8 | --type '3dblendcond' \
9 | --config AAPM_256_lsun.yaml \
10 | --dataset_path "/nfs/turbo/coe-liyues/bowenbw/DDS/indist_samples/CT/L067" \
11 | --ckpt_load_name "/nfs/turbo/coe-liyues/bowenbw/DDS/checkpoints/AAPM256_1M.pth" \
12 | --Nview $Nview \
13 | --eta $eta \
14 | --deg "SV-CT" \
15 | --sigma_y 0.00 \
16 |     --T_sampling $T_sampling \
18 | -i ./results
19 |
--------------------------------------------------------------------------------
/eval_3D_blend_cond_lidc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | import glob
5 | import json
6 | import sys
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import tqdm
11 | import torch
12 | import torch.utils.data as data
13 | import torchvision.utils as tvu
14 |
15 | from guided_diffusion.models import Model
16 | from guided_diffusion.script_util import create_model, classifier_defaults, args_to_dict, create_gaussian_diffusion
17 | from guided_diffusion.utils import get_alpha_schedule
18 | import random
19 |
20 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
21 | from scipy.linalg import orth
22 | from pathlib import Path
23 |
24 | from physics.ct import CT
25 | from time import time
26 | from utils import shrink, CG, clear, batchfy, _Dz, _DzT, get_beta_schedule
27 |
28 |
29 | def compute_alpha(beta, t):
30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
32 | return a
33 |
34 |
35 | class Diffusion(object):
36 | def __init__(self, args, config, device=None):
37 | self.args = args
38 | self.args.image_folder = Path(self.args.image_folder)
39 | for t in ["input", "recon", "label"]:
40 | if t == "recon":
41 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
42 | else:
43 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
44 | self.config = config
45 | print(self.config)
46 | if device is None:
47 | device = (
48 | torch.device("cuda")
49 | if torch.cuda.is_available()
50 | else torch.device("cpu")
51 | )
52 | self.device = device
53 |
54 | self.model_var_type = config.model.var_type
55 | betas = get_beta_schedule(
56 | beta_schedule=config.diffusion.beta_schedule,
57 | beta_start=config.diffusion.beta_start,
58 | beta_end=config.diffusion.beta_end,
59 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
60 | )
61 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
62 | self.num_timesteps = betas.shape[0]
63 |
64 | alphas = 1.0 - betas
65 | alphas_cumprod = alphas.cumprod(dim=0)
66 | alphas_cumprod_prev = torch.cat(
67 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
68 | )
69 | self.alphas_cumprod_prev = alphas_cumprod_prev
70 | posterior_variance = (
71 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
72 | )
73 | if self.model_var_type == "fixedlarge":
74 | self.logvar = betas.log()
75 | elif self.model_var_type == "fixedsmall":
76 | self.logvar = posterior_variance.clamp(min=1e-20).log()
77 |
78 | def sample(self):
79 | config_dict = vars(self.config.model)
80 | config_dict['use_spacecode'] = False
81 | config_dict["class_cond"] = True
82 | model = create_model(**config_dict)
83 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/LIDC_triplane3D_finetunelarge_4232024_iter62099_cond.ckpt"
84 |
85 | model.load_state_dict(torch.load(ckpt, map_location=self.device)["state_dict"])
86 | print(f"Model ckpt loaded from {ckpt}")
87 | model.to(self.device)
88 | model.eval()
89 | model = torch.nn.DataParallel(model)
90 |
91 | print('Run 3D DDS + DiffusionMBIR.',
92 | f'{self.args.T_sampling} sampling steps.',
93 | f'Task: {self.args.deg}.'
94 | )
95 | self.dds3d(model)
96 |
97 |
98 | def blendscore(self, xt, model,t, start_ind = None, start_head = 0, num_batches = 180):
99 | model_kwargs = {}
100 | y = torch.ones(1) * 1
101 | y=y.to(xt.device).to(torch.long)
102 | model_kwargs["y"] = y
103 |
104 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
105 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
106 | et[:,:3,:,:] = model(xt[:,:3,:,:], t, **model_kwargs)[:,:3]
107 | et[:,xt.shape[1]-3:, :,:] = model(xt[:,xt.shape[1]-3:,:,:], t, **model_kwargs)[:,:3]
108 |
109 | if start_ind is None:
110 | start_ind = np.random.randint(start_head,3)
111 | for j in range(start_ind, xt.shape[1]-2, 3):
112 | #####randomly select instead of summing
113 | et_sing = model(xt[:,j:(j+3),:,:], t, **model_kwargs)[:,:3] #####1 x 3 x 256 x 256
114 | et[:,j:(j+3), :,:] = et_sing
115 | return et
116 |
117 | def vps_blend(self, xt, model,t, start_ind = None, start_head = 0, num_batches = 180):
118 | model_kwargs = {}
119 | y = torch.ones(1) * 3
120 | y=y.to(xt.device).to(torch.long)
121 | model_kwargs["y"] = y
122 |
123 | et = torch.zeros((1, num_batches * 3, 256, 256)).to(xt.device).to(torch.float32)
124 |
125 | xt = torch.reshape(xt, (1, num_batches * 3, 256, 256))
126 | # 147258369
127 | for i in range(0,xt.shape[1], 9):
128 | for m in range(3):
129 | et[:,[i+m,i+m+3, i+m+6],:,:] = model(xt[:,[i+m,i+m+3, i+m+6],:,:], t, **model_kwargs)[:,:3]
130 |
131 | return et
132 |
133 |
134 | def dds3d(self, model):
135 | args, config = self.args, self.config
136 | print(f"Dataset path: {self.args.dataset_path}")
137 | root = Path(self.args.dataset_path)
138 |
139 | # parameters to be moved to args
140 | Nview = self.args.Nview
141 |         rho = 0.001  # hard-coded ADMM penalty parameter
142 |         lamb = 0.05 / 1000  # TV regularization weight
143 | n_ADMM = 1
144 | n_CG = self.args.CG_iter
145 | print(n_CG)
146 |
147 |         blend = True  # enable position-aware score blending
148 |
149 |         time_travel = False  # if True, run a second "time-travel" pass per timestep
150 |
151 |         vps = False  # optional stochastic refinement pass (disabled by default)
152 |         vps_scale = 0.03  # refinement step size (same value as eval_3DCT_blendcond.py; used when vps=True)
153 |
154 |         ddimsteps = 200  # number of DDIM steps
155 |
156 |
157 |
158 |
159 |
160 | # Specify save directory for saving generated samples
161 | save_root = Path(self.args.image_folder)
162 | save_root.mkdir(parents=True, exist_ok=True)
163 |
164 | irl_types = ['vol', 'input', 'recon', 'label']
165 | for t in irl_types:
166 | save_root_f = save_root / t
167 | save_root_f.mkdir(parents=True, exist_ok=True)
168 |
169 | fname_list = os.listdir("/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation_LIDC")
170 | root = "/nfs/turbo/coe-liyues/bowenbw/3DCT/benchmark/validation_LIDC"
171 |
172 | fname_list.sort()
173 |         num_batches = 87  # number of slice triplets (3 * 87 = 261 slices)
174 |         fname_list = fname_list[:(3 * num_batches)]
175 |
176 | ######################################################################################################
177 |
178 | print(fname_list)
179 | all_img = []
180 | batch_size = 3
181 | print("Loading all data")
182 | if time_travel:
183 | tot_iters = 2
184 | else:
185 | tot_iters = 1
186 | for fname in fname_list:
187 | just_name = fname.split('.')[0]
188 | img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))
189 | h, w = img.shape
190 | img = img.view(1, 1, h, w)
191 | all_img.append(img)
192 | all_img = torch.cat(all_img, dim=0)
193 | x_orig = all_img
194 | print(f"Data loaded shape : {all_img.shape}")
195 | x_orig = x_orig.to(torch.float32)
196 | print("Data type is :", x_orig.dtype)
197 | img_shape = (x_orig.shape[0], config.data.channels, config.data.image_size, config.data.image_size)
198 | if self.args.deg == "SV-CT":
199 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=True, circle=False, device=config.device)
200 | elif self.args.deg == "LA-CT":
201 | A_funcs = CT(img_width=256, radon_view=self.args.Nview, uniform=False, circle=False, device=config.device)
202 | A = lambda z: A_funcs.A(z)
203 | Ap = lambda z: A_funcs.A_dagger(z)
204 | def Acg_TV(x):
205 | return A_funcs.AT(A_funcs.A(x)) + rho * _DzT(_Dz(x))
206 | def ADMM(x, ATy, n_ADMM=n_ADMM):
207 | nonlocal del_z, udel_z
208 | for _ in range(n_ADMM):
209 | bcg_TV = ATy + rho * (_DzT(del_z) - _DzT(udel_z))
210 | x = CG(Acg_TV, bcg_TV, x, n_inner=n_CG)
211 | del_z = shrink(_Dz(x) + udel_z, lamb / rho)
212 | udel_z = _Dz(x) - del_z + udel_z
213 | return x
214 | del_z = torch.zeros(img_shape, device=self.device)
215 | udel_z = torch.zeros(img_shape, device=self.device)
216 | x_orig = x_orig.to(self.device) ######n x 1 x 256 x 256
217 | print(x_orig.min(), x_orig.max(), "xorig")
218 | y = A(x_orig)
219 | print(y.shape, "projection shape")
220 | Apy = Ap(y)
221 | print(Apy.shape, "Apy backprojection shape")
222 | ATy = A_funcs.AT(y)
223 | ##########################original####################################################
224 | x = torch.randn(num_batches, 3, 256, 256, device = self.device) ####initial noise
225 | ##################################################################################################
226 |
227 | diffusion = create_gaussian_diffusion(
228 | steps=1000,
229 | learn_sigma=True,
230 | noise_schedule="linear",
231 | use_kl=False,
232 | predict_xstart=False,
233 | rescale_timesteps=False,
234 | rescale_learned_sigmas=False,
235 | timestep_respacing="",
236 | )
237 | xt = None
238 | with torch.no_grad():
239 | skip = config.diffusion.num_diffusion_timesteps//ddimsteps
240 | n = x.size(0)
241 | x0_preds = []
242 | xt = x ###20 x 3 x 256 x 256
243 |
244 | # generate time schedule
245 | times = range(0, 1000, skip) #########0, 1, 2, ....
246 | times_next = [-1] + list(times[:-1])
247 | times_pair = zip(reversed(times), reversed(times_next))
248 |
249 |             n = 1  # one time index per model call, with or without blending
250 |
251 |
252 |
253 |
254 | ct = 0
255 | ###################################start reverse sampling############################################
256 | for i, j in tqdm.tqdm(times_pair, total=len(times)):
257 | t = (torch.ones(n) * i).to(x.device)
258 | next_t = (torch.ones(n) * j).to(x.device)
259 | ########if time travel do two passes, otherwise do one pass#########
260 | travels = tot_iters
261 | ct += 1
262 |
263 | for zhoumu in range(travels):
264 | print("zhoumu: ", zhoumu)
265 | t = (torch.ones(n) * i).to(x.device)
266 | next_t = (torch.ones(n) * j).to(x.device)
267 | at = compute_alpha(self.betas, t.long())
268 | bt = torch.index_select(self.betas,0,t.long())
269 | at_next = compute_alpha(self.betas, next_t.long())
270 | at = at[0,0,0,0]
271 | at_next = at_next[0,0,0,0]
272 | #################################reverse with consistency########################################
273 | et_agg = list() ###initialize a list of scores
274 | ###########################################ADJ slices#############################################
275 | if ct % 2 == 0:
276 | if vps:
277 | for M in range(1): ####number of VPS iterations
278 | noise = torch.randn_like(xt)
279 | ####################added by bowen 3/24/2024####################
280 | et = self.vps_blend(xt, model, t, num_batches = num_batches)
281 | ####################################################################
282 | et = torch.reshape(et, (num_batches, 3, 256, 256))
283 | lam_ = vps_scale
284 | xt = xt - lam_ * (1 - at).sqrt() * et
285 | xt = xt + ((lam_ * (2-lam_))*(1-at)).sqrt() * noise * 1
286 | if blend:
287 | et = self.blendscore(xt, model,t,num_batches = num_batches)
288 | else:
289 | y = torch.ones(1) * 1
290 | y=y.to(xt.device).to(torch.long)
291 | model_kwargs = {}
292 | model_kwargs["y"] = y
293 |                             for k in range(xt.shape[0]):  # score each triplet independently (no blending)
294 |                                 et_sing = model(xt[k:(k + 1)], t, **model_kwargs)  ####1 x 6 x 256 x 256
295 |                                 et_agg.append(et_sing)
296 | et = torch.cat(et_agg, dim=0) ####20 x 6 x 256 x 256
297 | et = et[:, :3] ####20 x 3 x 256 x 256
298 | ###reshape xt and et
299 | et_ = torch.reshape(et, ((num_batches * 3), 1, 256, 256))
300 | #######################################SLICE JUMP#############################################
301 | if ct % 2 == 1: ###147258369 4710
302 | print(ct, "changing et to slice jumping")
303 | et_ = self.vps_blend(xt, model, t, num_batches = num_batches)
304 | et_ = torch.reshape(et_, ((num_batches * 3), 1, 256, 256))
305 | xt_ = torch.reshape(xt, ((num_batches * 3), 1, 256, 256))
306 | x0_t = (xt_ - et_ * (1 - at).sqrt()) / at.sqrt() ###60 x 1 x 256 x 256 scale [-1, 1]
307 | # x0_t = torch.clip(x0_t, -1, 1) ####clip to [-1, 1]
308 | x0_t = (x0_t +1)/2 ###rescale to [0, 1]
309 | x0_t_hat = None
310 | eta = self.args.eta
311 | if zhoumu == 0:
312 | x0_t_hat = ADMM(x0_t, ATy, n_ADMM=n_ADMM) ######[0,1]
313 | # x0_t_hat = torch.clip(x0_t_hat, 0, 1)
314 | x0_t_hat = x0_t_hat * 2 - 1 #######rescale back to [-1, 1]
315 | else:
316 | x0_t_hat = x0_t * 2 - 1 #######rescale back to [-1, 1]
317 |
318 | c1 = (1 - at_next).sqrt() * eta
319 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
320 | if j != 0:
321 | xt_ = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et_
322 | else:
323 | xt_ = x0_t_hat
324 | xt = torch.reshape(xt_, (num_batches, 3, 256, 256)) ####reshape back
325 |
326 | if self.args.deg == "SV-CT":
327 | np.save("x_sample_ddim" + str(ddimsteps) + "_iter62000large_reconstructionL67_blend3_rho" + str(rho) + "ttnew" + str(tot_iters) + "_full_view4_424_jump_LIDC_ftldct_half_retest.npy", xt.detach().cpu().numpy()) #############imnet pretrain 40000 iters
328 |
329 |
330 | if self.args.deg == "LA-CT":
331 | np.save("x_sample_ddim" + str(ddimsteps) + "_iter62000large_lactlidc_blend3_full_view90.npy", xt.detach().cpu().numpy())
332 |
333 |
--------------------------------------------------------------------------------
/guided_diffusion/CTDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torchvision import datasets
4 | from torchvision.transforms import ToTensor
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
8 | from glob import glob
9 |
10 |
11 | class CTDataset(Dataset):
12 | def __init__(self, metadata=None, img_dir=None, transform=None, target_transform=lambda x: x, patient_num = -1):
13 |
14 | self.training_paths = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*')
15 | self.transform = transform
16 | self.target_transform = target_transform
17 | self.patient_num = patient_num
18 | print("length of training data", len(self.training_paths))
19 |
20 |
21 | def __len__(self):
22 | return len(self.training_paths)
23 |
24 |     def __getitem__(self, idx):
25 |         # Each .npy file is assumed to hold an H x W x 3 slice triplet scaled to [0, 1].
26 |         image = np.load(self.training_paths[idx])
27 |         image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
28 |         image = np.clip(image * 2 - 1, -1, 1)  # rescale to [-1, 1]
29 |
30 |         return torch.from_numpy(image)
31 |
32 |
33 | class CTCondDataset(Dataset):
34 | def __init__(self, metadata=None, img_dir=None, transform=None, target_transform=lambda x: x, patient_num = -1):
35 |
36 | self.training_paths = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*')
37 | self.transform = transform
38 | self.target_transform = target_transform
39 | self.patient_num = patient_num
40 | print("length of training data", len(self.training_paths))
41 |
42 | def __len__(self):
43 | return len(self.training_paths)
44 |
45 |
46 |     def __getitem__(self, idx):
47 |         # Stub: the conditional dataset variant is not implemented yet.
48 |         raise NotImplementedError("CTCondDataset.__getitem__ is not implemented")
49 |
50 |
51 |
52 | if __name__ == "__main__":
53 | ds = CTDataset()
54 | params = {'batch_size': 2}
55 | training_generator = torch.utils.data.DataLoader(ds, **params)
56 | ct = 0
57 | for local_batch in training_generator:
58 | print(local_batch.shape)
59 | ct += 1
60 | if ct > 4:
61 | break
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/CTDataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/CTDataset.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/dist_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/dist_util.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/logger.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/losses.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/losses.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/models.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/nn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/nn.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/nn.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/nn.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/resample.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/resample.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/respace.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/respace.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/respace.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/respace.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/script_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/script_util.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/script_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/script_util.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/train_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/train_util.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/unet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/unet.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/guided_diffusion/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/efzero/DiffusionBlend/05dc7fc48259888b9ec9dfddd6ebf8907f2d9730/guided_diffusion/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 |     Set up a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 |         return th.device("cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 |     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
88 |     try:
89 |         s.bind(("", 0))
90 |         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 |         return s.getsockname()[1]
92 |     finally:
93 |         s.close()
94 |
--------------------------------------------------------------------------------
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
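# Dynamic loss scaling: the scale starts at 2 ** INITIAL_LOG_LOSS_SCALE, is
# halved whenever gradients overflow, and otherwise grows by fp16_scale_growth
# per step (see MixedPrecisionTrainer._optimize_fp16 below).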
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 |         for p in self.master_params:
203 |             p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))  # undo the fp16 loss scaling on every master gradient
204 |         opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | from mpi4py import MPI
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | def load_data(
12 | *,
13 | data_dir,
14 | batch_size,
15 | image_size,
16 | class_cond=False,
17 | deterministic=False,
18 | random_crop=False,
19 | random_flip=True,
20 | use_mps=False,
21 | ):
22 | if not data_dir:
23 | raise ValueError("unspecified data directory")
24 | all_files = _list_image_files_recursively(data_dir, use_mps=use_mps)
25 | classes = None
26 | if class_cond:
27 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
28 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
29 | classes = [sorted_classes[x] for x in class_names]
30 | dataset = NpyDataset(
31 | image_size,
32 | all_files,
33 | classes=classes,
34 | shard=MPI.COMM_WORLD.Get_rank(),
35 | num_shards=MPI.COMM_WORLD.Get_size(),
36 | random_crop=random_crop,
37 | random_flip=random_flip,
38 | )
39 | if deterministic:
40 | loader = DataLoader(
41 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
42 | )
43 | else:
44 | loader = DataLoader(
45 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
46 | )
47 | while True:
48 | yield from loader
49 |
50 |
51 | def _list_image_files_recursively(data_dir, use_mps=False):
52 | results = []
53 | for entry in sorted(bf.listdir(data_dir)):
54 | if use_mps:
55 | if entry == "mps":
56 | continue
57 | else:
58 | if "img" in entry:
59 | continue
60 | full_path = bf.join(data_dir, entry)
61 | ext = entry.split(".")[-1]
62 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif", "npy"]:
63 | results.append(full_path)
64 | elif bf.isdir(full_path):
65 |             results.extend(_list_image_files_recursively(full_path, use_mps=use_mps))
66 | return results
67 |
68 |
69 | class ImageDataset(Dataset):
70 | def __init__(
71 | self,
72 | resolution,
73 | image_paths,
74 | classes=None,
75 | shard=0,
76 | num_shards=1,
77 | random_crop=False,
78 | random_flip=True,
79 | ):
80 | super().__init__()
81 | self.resolution = resolution
82 | self.local_images = image_paths[shard:][::num_shards]
83 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
84 | self.random_crop = random_crop
85 | self.random_flip = random_flip
86 |
87 | def __len__(self):
88 | return len(self.local_images)
89 |
90 | def __getitem__(self, idx):
91 | path = self.local_images[idx]
92 | with bf.BlobFile(path, "rb") as f:
93 | pil_image = Image.open(f)
94 | pil_image.load()
95 | pil_image = pil_image.convert("RGB")
96 |
97 | if self.random_crop:
98 | arr = random_crop_arr(pil_image, self.resolution)
99 | else:
100 | arr = center_crop_arr(pil_image, self.resolution)
101 |
102 | if self.random_flip and random.random() < 0.5:
103 | arr = arr[:, ::-1]
104 |
105 | arr = arr.astype(np.float32) / 127.5 - 1
106 |
107 | out_dict = {}
108 | if self.local_classes is not None:
109 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
110 | return np.transpose(arr, [2, 0, 1]), out_dict
111 |
112 |
113 | class NpyDataset(Dataset):
114 | def __init__(
115 | self,
116 | resolution,
117 | image_paths,
118 | classes=None,
119 | shard=0,
120 | num_shards=1,
121 | random_crop=False,
122 | random_flip=True,
123 | ):
124 | super().__init__()
125 | self.resolution = resolution
126 | self.local_images = image_paths[shard:][::num_shards]
127 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
128 | self.random_crop = random_crop
129 | self.random_flip = random_flip
130 |
131 | def __len__(self):
132 | return len(self.local_images)
133 |
134 | def __getitem__(self, idx):
135 | path = self.local_images[idx]
136 | arr = np.load(path)
137 | arr = arr[np.newaxis, :, :]
138 |         out_dict = {}
139 |         if self.local_classes is not None:
140 |             out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
141 |         return arr, out_dict
142 | def center_crop_arr(pil_image, image_size):
143 | # We are not on a new enough PIL to support the `reducing_gap`
144 | # argument, which uses BOX downsampling at powers of two first.
145 | # Thus, we do it by hand to improve downsample quality.
146 | while min(*pil_image.size) >= 2 * image_size:
147 | pil_image = pil_image.resize(
148 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
149 | )
150 |
151 | scale = image_size / min(*pil_image.size)
152 | pil_image = pil_image.resize(
153 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
154 | )
155 |
156 | arr = np.array(pil_image)
157 | crop_y = (arr.shape[0] - image_size) // 2
158 | crop_x = (arr.shape[1] - image_size) // 2
159 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
160 |
161 |
162 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
163 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
164 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
165 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
166 |
167 | # We are not on a new enough PIL to support the `reducing_gap`
168 | # argument, which uses BOX downsampling at powers of two first.
169 | # Thus, we do it by hand to improve downsample quality.
170 | while min(*pil_image.size) >= 2 * smaller_dim_size:
171 | pil_image = pil_image.resize(
172 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
173 | )
174 |
175 | scale = smaller_dim_size / min(*pil_image.size)
176 | pil_image = pil_image.resize(
177 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
178 | )
179 |
180 | arr = np.array(pil_image)
181 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
182 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
183 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
184 |
--------------------------------------------------------------------------------
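
`load_data` above is an infinite generator over `(batch, cond_dict)` pairs; note the module imports `mpi4py`, so an MPI installation is required even for single-process use. A minimal consumption sketch, assuming `./data/train` is a placeholder directory of 2-D `.npy` slices:

```python
from guided_diffusion.image_datasets import load_data

data = load_data(data_dir="./data/train", batch_size=4, image_size=256)
batch, cond = next(data)   # the generator loops over the shuffled dataset forever
print(batch.shape)         # (4, 1, 256, 256): NpyDataset prepends a channel axis
```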
/guided_diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log a value of some diagnostic
215 | Call this once for each diagnostic quantity, each iteration
216 | If called many times, last value will be used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 | The same as logkv(), but if called many times, values averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start)
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452 | )
453 | assert isinstance(dir, str)
454 | dir = os.path.expanduser(dir)
455 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
456 |
457 | rank = get_rank_without_mpi_import()
458 | if rank > 0:
459 | log_suffix = log_suffix + "-rank%03i" % rank
460 |
461 | if format_strs is None:
462 | if rank == 0:
463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464 | else:
465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466 | format_strs = filter(None, format_strs)
467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468 |
469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470 | if output_formats:
471 | log("Logging to %s" % dir)
472 |
473 |
474 | def _configure_default_logger():
475 | configure()
476 | Logger.DEFAULT = Logger.CURRENT
477 |
478 |
479 | def reset():
480 | if Logger.CURRENT is not Logger.DEFAULT:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = Logger.DEFAULT
483 | log("Reset logger")
484 |
485 |
486 | @contextmanager
487 | def scoped_configure(dir=None, format_strs=None, comm=None):
488 | prevlogger = Logger.CURRENT
489 | configure(dir=dir, format_strs=format_strs, comm=comm)
490 | try:
491 | yield
492 | finally:
493 | Logger.CURRENT.close()
494 | Logger.CURRENT = prevlogger
495 |
496 |
--------------------------------------------------------------------------------
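
A minimal usage sketch for the logger API above; the log directory and format list are illustrative:

```python
from guided_diffusion import logger

logger.configure(dir="./logs", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)                    # last value before a dump wins
    logger.logkv_mean("loss", 1.0 / (step + 1))   # averaged over calls until dumped
    logger.dumpkvs()                              # flush one row of diagnostics
```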
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
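
Both helpers above broadcast their arguments, so scalars can stand in for full tensors. A quick sanity-check sketch: the KL between identical Gaussians is zero, and the discretized log-likelihood preserves the input's shape:

```python
import torch as th
from guided_diffusion.losses import normal_kl, discretized_gaussian_log_likelihood

mean = th.zeros(2, 3)
print(normal_kl(mean, th.zeros_like(mean), 0.0, 0.0))  # all zeros: KL(N(0,1) || N(0,1)) = 0

x = th.zeros(1, 1, 4, 4)  # "images" rescaled to [-1, 1]
ll = discretized_gaussian_log_likelihood(x, means=th.zeros_like(x), log_scales=th.zeros_like(x))
print(ll.shape)  # torch.Size([1, 1, 4, 4]), log-probabilities in nats
```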
/guided_diffusion/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def get_timestep_embedding(timesteps, embedding_dim):
7 | """
8 |     Build sinusoidal timestep embeddings, matching the implementation in
9 |     Denoising Diffusion Probabilistic Models (adapted from Fairseq).
10 |
11 |     This matches the implementation in tensor2tensor, but differs slightly
12 |     from the description in Section 3.5 of "Attention Is All You Need".
13 | """
14 | assert len(timesteps.shape) == 1
15 |
16 | half_dim = embedding_dim // 2
17 | emb = math.log(10000) / (half_dim - 1)
18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
19 | emb = emb.to(device=timesteps.device)
20 | emb = timesteps.float()[:, None] * emb[None, :]
21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
22 | if embedding_dim % 2 == 1: # zero pad
23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
24 | return emb
25 |
26 |
27 | def nonlinearity(x):
28 | # swish
29 | return x*torch.sigmoid(x)
30 |
31 |
32 | def Normalize(in_channels):
33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34 |
35 |
36 | class Upsample(nn.Module):
37 | def __init__(self, in_channels, with_conv):
38 | super().__init__()
39 | self.with_conv = with_conv
40 | if self.with_conv:
41 | self.conv = torch.nn.Conv2d(in_channels,
42 | in_channels,
43 | kernel_size=3,
44 | stride=1,
45 | padding=1)
46 |
47 | def forward(self, x):
48 | x = torch.nn.functional.interpolate(
49 | x, scale_factor=2.0, mode="nearest")
50 | if self.with_conv:
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class Downsample(nn.Module):
56 | def __init__(self, in_channels, with_conv):
57 | super().__init__()
58 | self.with_conv = with_conv
59 | if self.with_conv:
60 | # no asymmetric padding in torch conv, must do it ourselves
61 | self.conv = torch.nn.Conv2d(in_channels,
62 | in_channels,
63 | kernel_size=3,
64 | stride=2,
65 | padding=0)
66 |
67 | def forward(self, x):
68 | if self.with_conv:
69 | pad = (0, 1, 0, 1)
70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71 | x = self.conv(x)
72 | else:
73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74 | return x
75 |
76 |
77 | class ResnetBlock(nn.Module):
78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
79 | dropout, temb_channels=512):
80 | super().__init__()
81 | self.in_channels = in_channels
82 | out_channels = in_channels if out_channels is None else out_channels
83 | self.out_channels = out_channels
84 | self.use_conv_shortcut = conv_shortcut
85 |
86 | self.norm1 = Normalize(in_channels)
87 | self.conv1 = torch.nn.Conv2d(in_channels,
88 | out_channels,
89 | kernel_size=3,
90 | stride=1,
91 | padding=1)
92 | self.temb_proj = torch.nn.Linear(temb_channels,
93 | out_channels)
94 | self.norm2 = Normalize(out_channels)
95 | self.dropout = torch.nn.Dropout(dropout)
96 | self.conv2 = torch.nn.Conv2d(out_channels,
97 | out_channels,
98 | kernel_size=3,
99 | stride=1,
100 | padding=1)
101 | if self.in_channels != self.out_channels:
102 | if self.use_conv_shortcut:
103 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
104 | out_channels,
105 | kernel_size=3,
106 | stride=1,
107 | padding=1)
108 | else:
109 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
110 | out_channels,
111 | kernel_size=1,
112 | stride=1,
113 | padding=0)
114 |
115 | def forward(self, x, temb):
116 | h = x
117 | h = self.norm1(h)
118 | h = nonlinearity(h)
119 | h = self.conv1(h)
120 |
121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
122 |
123 | h = self.norm2(h)
124 | h = nonlinearity(h)
125 | h = self.dropout(h)
126 | h = self.conv2(h)
127 |
128 | if self.in_channels != self.out_channels:
129 | if self.use_conv_shortcut:
130 | x = self.conv_shortcut(x)
131 | else:
132 | x = self.nin_shortcut(x)
133 |
134 | return x+h
135 |
136 |
137 | class AttnBlock(nn.Module):
138 | def __init__(self, in_channels):
139 | super().__init__()
140 | self.in_channels = in_channels
141 |
142 | self.norm = Normalize(in_channels)
143 | self.q = torch.nn.Conv2d(in_channels,
144 | in_channels,
145 | kernel_size=1,
146 | stride=1,
147 | padding=0)
148 | self.k = torch.nn.Conv2d(in_channels,
149 | in_channels,
150 | kernel_size=1,
151 | stride=1,
152 | padding=0)
153 | self.v = torch.nn.Conv2d(in_channels,
154 | in_channels,
155 | kernel_size=1,
156 | stride=1,
157 | padding=0)
158 | self.proj_out = torch.nn.Conv2d(in_channels,
159 | in_channels,
160 | kernel_size=1,
161 | stride=1,
162 | padding=0)
163 |
164 | def forward(self, x):
165 | h_ = x
166 | h_ = self.norm(h_)
167 | q = self.q(h_)
168 | k = self.k(h_)
169 | v = self.v(h_)
170 |
171 | # compute attention
172 | b, c, h, w = q.shape
173 | q = q.reshape(b, c, h*w)
174 | q = q.permute(0, 2, 1) # b,hw,c
175 | k = k.reshape(b, c, h*w) # b,c,hw
176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
177 | w_ = w_ * (int(c)**(-0.5))
178 | w_ = torch.nn.functional.softmax(w_, dim=2)
179 |
180 | # attend to values
181 | v = v.reshape(b, c, h*w)
182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
184 | h_ = torch.bmm(v, w_)
185 | h_ = h_.reshape(b, c, h, w)
186 |
187 | h_ = self.proj_out(h_)
188 |
189 | return x+h_
190 |
191 |
192 | class Model(nn.Module):
193 | def __init__(self, config):
194 | super().__init__()
195 | self.config = config
196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
197 | num_res_blocks = config.model.num_res_blocks
198 | attn_resolutions = config.model.attn_resolutions
199 | dropout = config.model.dropout
200 | in_channels = config.model.in_channels
201 | resolution = config.data.image_size
202 | resamp_with_conv = config.model.resamp_with_conv
203 | num_timesteps = config.diffusion.num_diffusion_timesteps
204 |
205 | if config.model.type == 'bayesian':
206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps))
207 |
208 | self.ch = ch
209 | self.temb_ch = self.ch*4
210 | self.num_resolutions = len(ch_mult)
211 | self.num_res_blocks = num_res_blocks
212 | self.resolution = resolution
213 | self.in_channels = in_channels
214 |
215 | # timestep embedding
216 | self.temb = nn.Module()
217 | self.temb.dense = nn.ModuleList([
218 | torch.nn.Linear(self.ch,
219 | self.temb_ch),
220 | torch.nn.Linear(self.temb_ch,
221 | self.temb_ch),
222 | ])
223 |
224 | # downsampling
225 | self.conv_in = torch.nn.Conv2d(in_channels,
226 | self.ch,
227 | kernel_size=3,
228 | stride=1,
229 | padding=1)
230 |
231 | curr_res = resolution
232 | in_ch_mult = (1,)+ch_mult
233 | self.down = nn.ModuleList()
234 | block_in = None
235 | for i_level in range(self.num_resolutions):
236 | block = nn.ModuleList()
237 | attn = nn.ModuleList()
238 | block_in = ch*in_ch_mult[i_level]
239 | block_out = ch*ch_mult[i_level]
240 | for i_block in range(self.num_res_blocks):
241 | block.append(ResnetBlock(in_channels=block_in,
242 | out_channels=block_out,
243 | temb_channels=self.temb_ch,
244 | dropout=dropout))
245 | block_in = block_out
246 | if curr_res in attn_resolutions:
247 | attn.append(AttnBlock(block_in))
248 | down = nn.Module()
249 | down.block = block
250 | down.attn = attn
251 | if i_level != self.num_resolutions-1:
252 | down.downsample = Downsample(block_in, resamp_with_conv)
253 | curr_res = curr_res // 2
254 | self.down.append(down)
255 |
256 | # middle
257 | self.mid = nn.Module()
258 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
259 | out_channels=block_in,
260 | temb_channels=self.temb_ch,
261 | dropout=dropout)
262 | self.mid.attn_1 = AttnBlock(block_in)
263 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
264 | out_channels=block_in,
265 | temb_channels=self.temb_ch,
266 | dropout=dropout)
267 |
268 | # upsampling
269 | self.up = nn.ModuleList()
270 | for i_level in reversed(range(self.num_resolutions)):
271 | block = nn.ModuleList()
272 | attn = nn.ModuleList()
273 | block_out = ch*ch_mult[i_level]
274 | skip_in = ch*ch_mult[i_level]
275 | for i_block in range(self.num_res_blocks+1):
276 | if i_block == self.num_res_blocks:
277 | skip_in = ch*in_ch_mult[i_level]
278 | block.append(ResnetBlock(in_channels=block_in+skip_in,
279 | out_channels=block_out,
280 | temb_channels=self.temb_ch,
281 | dropout=dropout))
282 | block_in = block_out
283 | if curr_res in attn_resolutions:
284 | attn.append(AttnBlock(block_in))
285 | up = nn.Module()
286 | up.block = block
287 | up.attn = attn
288 | if i_level != 0:
289 | up.upsample = Upsample(block_in, resamp_with_conv)
290 | curr_res = curr_res * 2
291 | self.up.insert(0, up) # prepend to get consistent order
292 |
293 | # end
294 | self.norm_out = Normalize(block_in)
295 | self.conv_out = torch.nn.Conv2d(block_in,
296 | out_ch,
297 | kernel_size=3,
298 | stride=1,
299 | padding=1)
300 |
301 | def forward(self, x, t):
302 | assert x.shape[2] == x.shape[3] == self.resolution
303 |
304 | # timestep embedding
305 | temb = get_timestep_embedding(t, self.ch)
306 | temb = self.temb.dense[0](temb)
307 | temb = nonlinearity(temb)
308 | temb = self.temb.dense[1](temb)
309 |
310 | # downsampling
311 | hs = [self.conv_in(x)]
312 | for i_level in range(self.num_resolutions):
313 | for i_block in range(self.num_res_blocks):
314 | h = self.down[i_level].block[i_block](hs[-1], temb)
315 | if len(self.down[i_level].attn) > 0:
316 | h = self.down[i_level].attn[i_block](h)
317 | hs.append(h)
318 | if i_level != self.num_resolutions-1:
319 | hs.append(self.down[i_level].downsample(hs[-1]))
320 |
321 | # middle
322 | h = hs[-1]
323 | h = self.mid.block_1(h, temb)
324 | h = self.mid.attn_1(h)
325 | h = self.mid.block_2(h, temb)
326 |
327 | # upsampling
328 | for i_level in reversed(range(self.num_resolutions)):
329 | for i_block in range(self.num_res_blocks+1):
330 | h = self.up[i_level].block[i_block](
331 | torch.cat([h, hs.pop()], dim=1), temb)
332 | if len(self.up[i_level].attn) > 0:
333 | h = self.up[i_level].attn[i_block](h)
334 | if i_level != 0:
335 | h = self.up[i_level].upsample(h)
336 |
337 | # end
338 | h = self.norm_out(h)
339 | h = nonlinearity(h)
340 | h = self.conv_out(h)
341 | return h
342 |
--------------------------------------------------------------------------------
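
Instantiating `Model` above requires a full config namespace, so the quickest smoke test exercises the embedding helper it builds on:

```python
import torch
from guided_diffusion.models import get_timestep_embedding

t = torch.arange(8)                   # one integer timestep per batch element
emb = get_timestep_embedding(t, 128)  # sin/cos pairs over 64 log-spaced frequencies
print(emb.shape)                      # torch.Size([8, 128])
```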
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
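
A few of the helpers above in action; the shapes and the EMA rate are illustrative:

```python
import torch as th
from guided_diffusion.nn import timestep_embedding, update_ema, mean_flat

print(timestep_embedding(th.tensor([0, 10, 100]), dim=64).shape)  # torch.Size([3, 64])

target, source = [th.zeros(5)], [th.ones(5)]
update_ema(target, source, rate=0.99)  # target <- 0.99 * target + 0.01 * source
print(target[0])                       # tensor of 0.01s

print(mean_flat(th.ones(2, 3, 4)))     # shape (2,): mean over all non-batch dims
```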
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
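
`UniformSampler` only reads `diffusion.num_timesteps`, so a lightweight stand-in is enough to see the sampling contract (a real `GaussianDiffusion` object works the same way):

```python
import types
import torch as th
from guided_diffusion.resample import UniformSampler

diffusion = types.SimpleNamespace(num_timesteps=1000)  # stand-in for illustration
sampler = UniformSampler(diffusion)
ts, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
print(ts)       # four timestep indices drawn uniformly from [0, 1000)
print(weights)  # all ones, so the training loss is left unweighted
```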
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
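
A short sketch of `space_timesteps`, reusing the numbers from its own docstring:

```python
from guided_diffusion.respace import space_timesteps

# 300 original steps, three 100-step sections strided to 10, 15, and 20 steps.
steps = space_timesteps(300, [10, 15, 20])
print(len(steps))  # 45

# DDIM-style striding: 50 evenly spaced steps out of 1000 (stride 20).
print(sorted(space_timesteps(1000, "ddim50"))[:5])  # [0, 20, 40, 60, 80]
```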
/guided_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel
7 |
8 | NUM_CLASSES = 1000
9 |
10 |
11 | def diffusion_defaults():
12 | """
13 | Defaults for image and classifier training.
14 | """
15 | return dict(
16 | learn_sigma=False,
17 | diffusion_steps=1000,
18 | noise_schedule="linear",
19 | timestep_respacing="",
20 | use_kl=False,
21 | predict_xstart=False,
22 | rescale_timesteps=False,
23 | rescale_learned_sigmas=False,
24 | )
25 |
26 |
27 | def classifier_defaults():
28 | """
29 | Defaults for classifier models.
30 | """
31 | return dict(
32 | image_size=64,
33 | classifier_use_fp16=False,
34 | classifier_width=128,
35 | classifier_depth=2,
36 | classifier_attention_resolutions="32,16,8", # 16
37 | classifier_use_scale_shift_norm=True, # False
38 | classifier_resblock_updown=True, # False
39 | classifier_pool="attention",
40 | )
41 |
42 |
43 | def model_and_diffusion_defaults():
44 | """
45 | Defaults for image training.
46 | """
47 | res = dict(
48 | image_size=64,
49 | num_channels=128,
50 | num_res_blocks=2,
51 | num_heads=4,
52 | num_heads_upsample=-1,
53 | num_head_channels=-1,
54 | attention_resolutions="16,8",
55 | channel_mult="",
56 | dropout=0.0,
57 | class_cond=False,
58 | use_checkpoint=False,
59 | use_scale_shift_norm=True,
60 | resblock_updown=False,
61 | use_fp16=False,
62 | use_new_attention_order=False,
63 | )
64 | res.update(diffusion_defaults())
65 | return res
66 |
67 |
68 | def classifier_and_diffusion_defaults():
69 | res = classifier_defaults()
70 | res.update(diffusion_defaults())
71 | return res
72 |
73 |
74 | def create_model_and_diffusion(
75 | image_size,
76 | class_cond,
77 | learn_sigma,
78 | num_channels,
79 | num_res_blocks,
80 | channel_mult,
81 | num_heads,
82 | num_head_channels,
83 | num_heads_upsample,
84 | attention_resolutions,
85 | dropout,
86 | diffusion_steps,
87 | noise_schedule,
88 | timestep_respacing,
89 | use_kl,
90 | predict_xstart,
91 | rescale_timesteps,
92 | rescale_learned_sigmas,
93 | use_checkpoint,
94 | use_scale_shift_norm,
95 | resblock_updown,
96 | use_fp16,
97 | use_new_attention_order,
98 | ):
99 | model = create_model(
100 | image_size,
101 | num_channels,
102 | num_res_blocks,
103 | channel_mult=channel_mult,
104 | learn_sigma=learn_sigma,
105 | class_cond=class_cond,
106 | use_checkpoint=use_checkpoint,
107 | attention_resolutions=attention_resolutions,
108 | num_heads=num_heads,
109 | num_head_channels=num_head_channels,
110 | num_heads_upsample=num_heads_upsample,
111 | use_scale_shift_norm=use_scale_shift_norm,
112 | dropout=dropout,
113 | resblock_updown=resblock_updown,
114 | use_fp16=use_fp16,
115 | use_new_attention_order=use_new_attention_order,
116 | )
117 | diffusion = create_gaussian_diffusion(
118 | steps=diffusion_steps,
119 | learn_sigma=learn_sigma,
120 | noise_schedule=noise_schedule,
121 | use_kl=use_kl,
122 | predict_xstart=predict_xstart,
123 | rescale_timesteps=rescale_timesteps,
124 | rescale_learned_sigmas=rescale_learned_sigmas,
125 | timestep_respacing=timestep_respacing,
126 | )
127 | return model, diffusion
128 |
129 |
130 | def create_model(
131 | image_size,
132 | num_channels,
133 | num_res_blocks,
134 | channel_mult="",
135 | learn_sigma=False,
136 | class_cond=False,
137 | use_checkpoint=False,
138 | attention_resolutions="16",
139 | num_heads=1,
140 | num_head_channels=-1,
141 | num_heads_upsample=-1,
142 | use_scale_shift_norm=False,
143 | dropout=0,
144 | resblock_updown=False,
145 | use_fp16=False,
146 | use_new_attention_order=False,
147 | **kwargs
148 | ):
149 |     """
150 |     Construct a UNetModel; `kwargs` must supply `in_channels` and `use_spacecode`.
151 |     """
152 | if channel_mult == "":
153 | if image_size == 512:
154 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
155 | elif image_size == 256 or image_size == 320:
156 | channel_mult = (1, 1, 2, 2, 4, 4)
157 | elif image_size == 128:
158 | channel_mult = (1, 1, 2, 3, 4)
159 | elif image_size == 64:
160 | channel_mult = (1, 2, 3, 4)
161 | else:
162 | raise ValueError(f"unsupported image size: {image_size}")
163 | else:
164 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
165 |
166 | attention_ds = []
167 | for res in attention_resolutions.split(","):
168 | attention_ds.append(image_size // int(res))
169 |
170 | return UNetModel(
171 | image_size=image_size,
172 | in_channels=kwargs["in_channels"],
173 | model_channels=num_channels,
174 | out_channels=(kwargs["in_channels"] if not learn_sigma else kwargs["in_channels"] * 2),
175 | num_res_blocks=num_res_blocks,
176 | attention_resolutions=tuple(attention_ds),
177 | dropout=dropout,
178 | channel_mult=channel_mult,
179 | num_classes=(NUM_CLASSES if class_cond else None),
180 | use_checkpoint=use_checkpoint,
181 | use_fp16=use_fp16,
182 | num_heads=num_heads,
183 | num_head_channels=num_head_channels,
184 | num_heads_upsample=num_heads_upsample,
185 | use_scale_shift_norm=use_scale_shift_norm,
186 | resblock_updown=resblock_updown,
187 | use_new_attention_order=use_new_attention_order,
188 |         use_spacecode=kwargs["use_spacecode"],
189 | )
190 |
191 |
192 | def create_classifier_and_diffusion(
193 | image_size,
194 | classifier_use_fp16,
195 | classifier_width,
196 | classifier_depth,
197 | classifier_attention_resolutions,
198 | classifier_use_scale_shift_norm,
199 | classifier_resblock_updown,
200 | classifier_pool,
201 | learn_sigma,
202 | diffusion_steps,
203 | noise_schedule,
204 | timestep_respacing,
205 | use_kl,
206 | predict_xstart,
207 | rescale_timesteps,
208 | rescale_learned_sigmas,
209 | ):
210 | classifier = create_classifier(
211 | image_size,
212 | classifier_use_fp16,
213 | classifier_width,
214 | classifier_depth,
215 | classifier_attention_resolutions,
216 | classifier_use_scale_shift_norm,
217 | classifier_resblock_updown,
218 | classifier_pool,
219 | )
220 | diffusion = create_gaussian_diffusion(
221 | steps=diffusion_steps,
222 | learn_sigma=learn_sigma,
223 | noise_schedule=noise_schedule,
224 | use_kl=use_kl,
225 | predict_xstart=predict_xstart,
226 | rescale_timesteps=rescale_timesteps,
227 | rescale_learned_sigmas=rescale_learned_sigmas,
228 | timestep_respacing=timestep_respacing,
229 | )
230 | return classifier, diffusion
231 |
232 |
233 | def create_classifier(
234 | image_size,
235 | classifier_use_fp16,
236 | classifier_width,
237 | classifier_depth,
238 | classifier_attention_resolutions,
239 | classifier_use_scale_shift_norm,
240 | classifier_resblock_updown,
241 | classifier_pool,
242 | ):
243 | if image_size == 512:
244 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
245 | elif image_size == 256:
246 | channel_mult = (1, 1, 2, 2, 4, 4)
247 | elif image_size == 128:
248 | channel_mult = (1, 1, 2, 3, 4)
249 | elif image_size == 64:
250 | channel_mult = (1, 2, 3, 4)
251 | else:
252 | raise ValueError(f"unsupported image size: {image_size}")
253 |
254 | attention_ds = []
255 | for res in classifier_attention_resolutions.split(","):
256 | attention_ds.append(image_size // int(res))
257 |
258 | return EncoderUNetModel(
259 | image_size=image_size,
260 | in_channels=3,
261 | model_channels=classifier_width,
262 | out_channels=1000,
263 | num_res_blocks=classifier_depth,
264 | attention_resolutions=tuple(attention_ds),
265 | channel_mult=channel_mult,
266 | use_fp16=classifier_use_fp16,
267 | num_head_channels=64,
268 | use_scale_shift_norm=classifier_use_scale_shift_norm,
269 | resblock_updown=classifier_resblock_updown,
270 | pool=classifier_pool,
271 | )
272 |
273 |
274 | def sr_model_and_diffusion_defaults():
275 | res = model_and_diffusion_defaults()
276 | res["large_size"] = 256
277 | res["small_size"] = 64
278 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
279 | for k in res.copy().keys():
280 | if k not in arg_names:
281 | del res[k]
282 | return res
283 |
284 |
285 | def sr_create_model_and_diffusion(
286 | large_size,
287 | small_size,
288 | class_cond,
289 | learn_sigma,
290 | num_channels,
291 | num_res_blocks,
292 | num_heads,
293 | num_head_channels,
294 | num_heads_upsample,
295 | attention_resolutions,
296 | dropout,
297 | diffusion_steps,
298 | noise_schedule,
299 | timestep_respacing,
300 | use_kl,
301 | predict_xstart,
302 | rescale_timesteps,
303 | rescale_learned_sigmas,
304 | use_checkpoint,
305 | use_scale_shift_norm,
306 | resblock_updown,
307 | use_fp16,
308 | ):
309 | model = sr_create_model(
310 | large_size,
311 | small_size,
312 | num_channels,
313 | num_res_blocks,
314 | learn_sigma=learn_sigma,
315 | class_cond=class_cond,
316 | use_checkpoint=use_checkpoint,
317 | attention_resolutions=attention_resolutions,
318 | num_heads=num_heads,
319 | num_head_channels=num_head_channels,
320 | num_heads_upsample=num_heads_upsample,
321 | use_scale_shift_norm=use_scale_shift_norm,
322 | dropout=dropout,
323 | resblock_updown=resblock_updown,
324 | use_fp16=use_fp16,
325 | )
326 | diffusion = create_gaussian_diffusion(
327 | steps=diffusion_steps,
328 | learn_sigma=learn_sigma,
329 | noise_schedule=noise_schedule,
330 | use_kl=use_kl,
331 | predict_xstart=predict_xstart,
332 | rescale_timesteps=rescale_timesteps,
333 | rescale_learned_sigmas=rescale_learned_sigmas,
334 | timestep_respacing=timestep_respacing,
335 | )
336 | return model, diffusion
337 |
338 |
339 | def sr_create_model(
340 | large_size,
341 | small_size,
342 | num_channels,
343 | num_res_blocks,
344 | learn_sigma,
345 | class_cond,
346 | use_checkpoint,
347 | attention_resolutions,
348 | num_heads,
349 | num_head_channels,
350 | num_heads_upsample,
351 | use_scale_shift_norm,
352 | dropout,
353 | resblock_updown,
354 | use_fp16,
355 | ):
356 | _ = small_size # hack to prevent unused variable
357 |
358 | if large_size == 512:
359 | channel_mult = (1, 1, 2, 2, 4, 4)
360 | elif large_size == 256:
361 | channel_mult = (1, 1, 2, 2, 4, 4)
362 | elif large_size == 64:
363 | channel_mult = (1, 2, 3, 4)
364 | else:
365 | raise ValueError(f"unsupported large size: {large_size}")
366 |
367 | attention_ds = []
368 | for res in attention_resolutions.split(","):
369 | attention_ds.append(large_size // int(res))
370 |
371 | return SuperResModel(
372 | image_size=large_size,
373 | in_channels=3,
374 | model_channels=num_channels,
375 | out_channels=(3 if not learn_sigma else 6),
376 | num_res_blocks=num_res_blocks,
377 | attention_resolutions=tuple(attention_ds),
378 | dropout=dropout,
379 | channel_mult=channel_mult,
380 | num_classes=(NUM_CLASSES if class_cond else None),
381 | use_checkpoint=use_checkpoint,
382 | num_heads=num_heads,
383 | num_head_channels=num_head_channels,
384 | num_heads_upsample=num_heads_upsample,
385 | use_scale_shift_norm=use_scale_shift_norm,
386 | resblock_updown=resblock_updown,
387 | use_fp16=use_fp16,
388 | )
389 |
390 |
391 | def create_gaussian_diffusion(
392 | *,
393 | steps=1000,
394 | learn_sigma=False,
395 | sigma_small=False,
396 | noise_schedule="linear",
397 | use_kl=False,
398 | predict_xstart=False,
399 | rescale_timesteps=False,
400 | rescale_learned_sigmas=False,
401 | timestep_respacing="",
402 | ):
403 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
404 | if use_kl:
405 | loss_type = gd.LossType.RESCALED_KL
406 | elif rescale_learned_sigmas:
407 | loss_type = gd.LossType.RESCALED_MSE
408 | else:
409 | loss_type = gd.LossType.MSE
410 | if not timestep_respacing:
411 | timestep_respacing = [steps]
412 | return SpacedDiffusion(
413 | use_timesteps=space_timesteps(steps, timestep_respacing),
414 | betas=betas,
415 | model_mean_type=(
416 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
417 | ),
418 | model_var_type=(
419 | (
420 | gd.ModelVarType.FIXED_LARGE
421 | if not sigma_small
422 | else gd.ModelVarType.FIXED_SMALL
423 | )
424 | if not learn_sigma
425 | else gd.ModelVarType.LEARNED_RANGE
426 | ),
427 | loss_type=loss_type,
428 | rescale_timesteps=rescale_timesteps,
429 | )
430 |
431 |
432 | def add_dict_to_argparser(parser, default_dict):
433 | for k, v in default_dict.items():
434 | v_type = type(v)
435 | if v is None:
436 | v_type = str
437 | elif isinstance(v, bool):
438 | v_type = str2bool
439 | parser.add_argument(f"--{k}", default=v, type=v_type)
440 |
441 |
442 | def args_to_dict(args, keys):
443 | return {k: getattr(args, k) for k in keys}
444 |
445 |
446 | def str2bool(v):
447 | """
448 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
449 | """
450 | if isinstance(v, bool):
451 | return v
452 | if v.lower() in ("yes", "true", "t", "y", "1"):
453 | return True
454 | elif v.lower() in ("no", "false", "f", "n", "0"):
455 | return False
456 | else:
457 | raise argparse.ArgumentTypeError("boolean value expected")
458 |
--------------------------------------------------------------------------------
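As a quick orientation, the helpers above compose as follows: `add_dict_to_argparser` turns a defaults dict into CLI flags (booleans routed through `str2bool`), and `create_gaussian_diffusion` builds the `SpacedDiffusion` object from the parsed values. Below is a minimal sketch, assuming the repository root is importable; the default values shown are illustrative, not the project's settings:

```python
import argparse
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    create_gaussian_diffusion,
)

defaults = dict(
    steps=1000,
    learn_sigma=True,
    noise_schedule="linear",
    timestep_respacing="",  # "" keeps all steps; a value like "ddim50" respaces
)
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)  # each key becomes a --flag
args = parser.parse_args([])             # fall back to the defaults
diffusion = create_gaussian_diffusion(**args_to_dict(args, defaults.keys()))
print(diffusion.num_timesteps)           # 1000
```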
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import blobfile as bf
7 | import torch as th
8 | import torch.distributed as dist
9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
10 | from torch.optim import AdamW
11 | from pathlib import Path
12 |
13 | from . import dist_util, logger
14 | from .fp16_util import MixedPrecisionTrainer, make_master_params, master_params_to_model_params, model_grads_to_master_grads, unflatten_master_params, zero_grad
15 | from .nn import update_ema
16 | from .resample import LossAwareSampler, UniformSampler
17 | from utils import clear
18 |
19 | # For ImageNet experiments, this was a good default value.
20 | # We found that the lg_loss_scale quickly climbed to
21 | # 20-21 within the first ~1K steps of training.
22 | INITIAL_LOG_LOSS_SCALE = 20.0
23 |
24 | class TrainLoop2:
25 | def __init__(
26 | self,
27 | *,
28 | model,
29 | diffusion,
30 | data,
31 | batch_size,
32 | microbatch,
33 | lr,
34 | ema_rate,
35 | log_interval,
36 | save_interval,
37 | resume_checkpoint,
38 | use_fp16=False,
39 | fp16_scale_growth=1e-3,
40 | schedule_sampler=None,
41 | weight_decay=0.0,
42 | lr_anneal_steps=0,
43 | ):
44 | self.model = model
45 | self.diffusion = diffusion
46 | self.data = data
47 | self.batch_size = batch_size
48 | self.microbatch = microbatch if microbatch > 0 else batch_size
49 | self.lr = lr
50 | self.ema_rate = (
51 | [ema_rate]
52 | if isinstance(ema_rate, float)
53 | else [float(x) for x in ema_rate.split(",")]
54 | )
55 | self.log_interval = log_interval
56 | self.save_interval = save_interval
57 | self.resume_checkpoint = resume_checkpoint
58 | self.use_fp16 = use_fp16
59 | self.fp16_scale_growth = fp16_scale_growth
60 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
61 | self.weight_decay = weight_decay
62 | self.lr_anneal_steps = lr_anneal_steps
63 |
64 | self.step = 0
65 | self.resume_step = 0
66 | self.global_batch = self.batch_size * dist.get_world_size()
67 |
68 | self.model_params = list(self.model.parameters())
69 | self.master_params = self.model_params
70 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
71 | self.sync_cuda = th.cuda.is_available()
72 |
73 | self._load_and_sync_parameters()
74 | if self.use_fp16:
75 | self._setup_fp16()
76 |
77 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
78 | if self.resume_step:
79 | self._load_optimizer_state()
80 | # Model was resumed, either due to a restart or a checkpoint
81 | # being specified at the command line.
82 | self.ema_params = [
83 | self._load_ema_parameters(rate) for rate in self.ema_rate
84 | ]
85 | else:
86 | self.ema_params = [
87 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
88 | ]
89 |
90 | if False:  # DDP intentionally disabled in this legacy loop (was: th.cuda.is_available() and 1==2)
91 | self.use_ddp = True
92 | self.ddp_model = DDP(
93 | self.model,
94 | device_ids=[dist_util.dev()],
95 | output_device=dist_util.dev(),
96 | broadcast_buffers=False,
97 | bucket_cap_mb=128,
98 | find_unused_parameters=False,
99 | )
100 | else:
101 | if dist.get_world_size() > 1:
102 | logger.warn(
103 | "Distributed training requires CUDA. "
104 | "Gradients will not be synchronized properly!"
105 | )
106 | self.use_ddp = False
107 | self.ddp_model = self.model
108 |
109 | def _load_and_sync_parameters(self):
110 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
111 |
112 | if resume_checkpoint:
113 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
114 | if dist.get_rank() == 0:
115 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
116 | self.model.load_state_dict(
117 | dist_util.load_state_dict(
118 | resume_checkpoint, map_location=dist_util.dev()
119 | )
120 | )
121 |
122 | dist_util.sync_params(self.model.parameters())
123 |
124 | def _load_ema_parameters(self, rate):
125 | ema_params = copy.deepcopy(self.master_params)
126 |
127 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
128 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
129 | if ema_checkpoint:
130 | if dist.get_rank() == 0:
131 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
132 | state_dict = dist_util.load_state_dict(
133 | ema_checkpoint, map_location=dist_util.dev()
134 | )
135 | ema_params = self._state_dict_to_master_params(state_dict)
136 |
137 | dist_util.sync_params(ema_params)
138 | return ema_params
139 |
140 | def _load_optimizer_state(self):
141 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
142 | opt_checkpoint = bf.join(
143 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
144 | )
145 | if bf.exists(opt_checkpoint):
146 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
147 | state_dict = dist_util.load_state_dict(
148 | opt_checkpoint, map_location=dist_util.dev()
149 | )
150 | self.opt.load_state_dict(state_dict)
151 |
152 | def _setup_fp16(self):
153 | self.master_params = make_master_params(self.model_params)
154 | self.model.convert_to_fp16()
155 |
156 | def run_loop(self):
157 | while (
158 | not self.lr_anneal_steps
159 | or self.step + self.resume_step < self.lr_anneal_steps
160 | ):
161 | batch, cond = next(self.data)
162 | self.run_step(batch, cond)
163 | if self.step % self.log_interval == 0:
164 | logger.dumpkvs()
165 | if self.step % self.save_interval == 0:
166 | self.save()
167 | # Hard cap on the total number of training steps.
168 | if self.step > 300000:
169 | return
170 | self.step += 1
171 | # Save the last checkpoint if it wasn't already saved.
172 | if (self.step - 1) % self.save_interval != 0:
173 | self.save()
174 |
175 | def run_step(self, batch, cond):
176 | self.forward_backward(batch, cond)
177 | if self.use_fp16:
178 | self.optimize_fp16()
179 | else:
180 | self.optimize_normal()
181 | self.log_step()
182 |
183 | def forward_backward(self, batch, cond):
184 | zero_grad(self.model_params)
185 | for i in range(0, batch.shape[0], self.microbatch):
186 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
187 | micro_cond = {
188 | k: v[i : i + self.microbatch].to(dist_util.dev())
189 | for k, v in cond.items()
190 | }
191 | last_batch = (i + self.microbatch) >= batch.shape[0]
192 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
193 |
194 | compute_losses = functools.partial(
195 | self.diffusion.training_losses,
196 | self.ddp_model,
197 | micro,
198 | t,
199 | model_kwargs=micro_cond,
200 | )
201 |
202 | if last_batch or not self.use_ddp:
203 | losses = compute_losses()
204 | else:
205 | with self.ddp_model.no_sync():
206 | losses = compute_losses()
207 |
208 | if isinstance(self.schedule_sampler, LossAwareSampler):
209 | self.schedule_sampler.update_with_local_losses(
210 | t, losses["loss"].detach()
211 | )
212 |
213 | loss = (losses["loss"] * weights).mean()
214 | log_loss_dict(
215 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
216 | )
217 | if self.use_fp16:
218 | loss_scale = 2 ** self.lg_loss_scale
219 | (loss * loss_scale).backward()
220 | else:
221 | loss.backward()
222 |
223 | def optimize_fp16(self):
224 | if any(not th.isfinite(p.grad).all() for p in self.model_params):
225 | self.lg_loss_scale -= 1
226 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
227 | return
228 |
229 | model_grads_to_master_grads(self.model_params, self.master_params)
230 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
231 | self._log_grad_norm()
232 | self._anneal_lr()
233 | self.opt.step()
234 | for rate, params in zip(self.ema_rate, self.ema_params):
235 | update_ema(params, self.master_params, rate=rate)
236 | master_params_to_model_params(self.model_params, self.master_params)
237 | self.lg_loss_scale += self.fp16_scale_growth
238 |
239 | def optimize_normal(self):
240 | self._log_grad_norm()
241 | self._anneal_lr()
242 | self.opt.step()
243 | for rate, params in zip(self.ema_rate, self.ema_params):
244 | update_ema(params, self.master_params, rate=rate)
245 |
246 | def _log_grad_norm(self):
247 | sqsum = 0.0
248 | for p in self.master_params:
249 | sqsum += (p.grad ** 2).sum().item()
250 | logger.logkv_mean("grad_norm", np.sqrt(sqsum))
251 |
252 | def _anneal_lr(self):
253 | if not self.lr_anneal_steps:
254 | return
255 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
256 | lr = self.lr * (1 - frac_done)
257 | for param_group in self.opt.param_groups:
258 | param_group["lr"] = lr
259 |
260 | def log_step(self):
261 | logger.logkv("step", self.step + self.resume_step)
262 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
263 | if self.use_fp16:
264 | logger.logkv("lg_loss_scale", self.lg_loss_scale)
265 |
266 | def save(self):
267 | def save_checkpoint(rate, params):
268 | state_dict = self._master_params_to_state_dict(params)
269 | if dist.get_rank() == 0:
270 | logger.log(f"saving model {rate}...")
271 | if not rate:
272 | filename = f"model{(self.step+self.resume_step):06d}.pt"
273 | else:
274 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
275 | with bf.BlobFile(bf.join("/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/", filename), "wb") as f:
276 | th.save(state_dict, f)
277 |
278 | save_checkpoint(0, self.master_params)
279 | for rate, params in zip(self.ema_rate, self.ema_params):
280 | save_checkpoint(rate, params)
281 |
282 | if dist.get_rank() == 0:
283 | with bf.BlobFile(
284 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
285 | "wb",
286 | ) as f:
287 | th.save(self.opt.state_dict(), f)
288 |
289 | dist.barrier()
290 |
291 | def _master_params_to_state_dict(self, master_params):
292 | if self.use_fp16:
293 | master_params = unflatten_master_params(
294 | self.model.parameters(), master_params
295 | )
296 | state_dict = self.model.state_dict()
297 | for i, (name, _value) in enumerate(self.model.named_parameters()):
298 | assert name in state_dict
299 | state_dict[name] = master_params[i]
300 | return state_dict
301 |
302 | def _state_dict_to_master_params(self, state_dict):
303 | params = [state_dict[name] for name, _ in self.model.named_parameters()]
304 | if self.use_fp16:
305 | return make_master_params(params)
306 | else:
307 | return params
308 |
309 | class TrainLoop:
310 | def __init__(
311 | self,
312 | *,
313 | model,
314 | diffusion,
315 | data,
316 | batch_size,
317 | microbatch,
318 | lr,
319 | ema_rate,
320 | log_interval,
321 | save_interval,
322 | resume_checkpoint,
323 | use_fp16=False,
324 | fp16_scale_growth=1e-3,
325 | schedule_sampler=None,
326 | weight_decay=0.0,
327 | lr_anneal_steps=0,
328 | ):
329 | self.model = model
330 | self.diffusion = diffusion
331 | self.data = data
332 | self.batch_size = batch_size
333 | self.microbatch = microbatch if microbatch > 0 else batch_size
334 | self.lr = lr
335 | self.ema_rate = (
336 | [ema_rate]
337 | if isinstance(ema_rate, float)
338 | else [float(x) for x in ema_rate.split(",")]
339 | )
340 | self.log_interval = log_interval
341 | self.save_interval = save_interval
342 | self.resume_checkpoint = resume_checkpoint
343 | self.use_fp16 = use_fp16
344 | self.fp16_scale_growth = fp16_scale_growth
345 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
346 | self.weight_decay = weight_decay
347 | self.lr_anneal_steps = lr_anneal_steps
348 |
349 | self.step = 0
350 | self.resume_step = 0
351 | self.global_batch = self.batch_size * dist.get_world_size()
352 |
353 | self.sync_cuda = th.cuda.is_available()
354 |
355 | self._load_and_sync_parameters()
356 | self.mp_trainer = MixedPrecisionTrainer(
357 | model=self.model,
358 | use_fp16=self.use_fp16,
359 | fp16_scale_growth=fp16_scale_growth,
360 | )
361 |
362 | self.opt = AdamW(
363 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
364 | )
365 | if self.resume_step:
366 | self._load_optimizer_state()
367 | # Model was resumed, either due to a restart or a checkpoint
368 | # being specified at the command line.
369 | self.ema_params = [
370 | self._load_ema_parameters(rate) for rate in self.ema_rate
371 | ]
372 | else:
373 | self.ema_params = [
374 | copy.deepcopy(self.mp_trainer.master_params)
375 | for _ in range(len(self.ema_rate))
376 | ]
377 |
378 | if th.cuda.is_available():
379 | self.use_ddp = True
380 | self.ddp_model = DDP(
381 | self.model,
382 | device_ids=[dist_util.dev()],
383 | output_device=dist_util.dev(),
384 | broadcast_buffers=False,
385 | bucket_cap_mb=128,
386 | find_unused_parameters=False,
387 | )
388 | else:
389 | if dist.get_world_size() > 1:
390 | logger.warn(
391 | "Distributed training requires CUDA. "
392 | "Gradients will not be synchronized properly!"
393 | )
394 | self.use_ddp = False
395 | self.ddp_model = self.model
396 |
397 | def _load_and_sync_parameters(self):
398 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
399 |
400 | if resume_checkpoint:
401 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
402 | if dist.get_rank() == 0:
403 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
404 | self.model.load_state_dict(
405 | dist_util.load_state_dict(
406 | resume_checkpoint, map_location=dist_util.dev()
407 | )
408 | )
409 |
410 | dist_util.sync_params(self.model.parameters())
411 |
412 | def _load_ema_parameters(self, rate):
413 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
414 |
415 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
416 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
417 | if ema_checkpoint:
418 | if dist.get_rank() == 0:
419 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
420 | state_dict = dist_util.load_state_dict(
421 | ema_checkpoint, map_location=dist_util.dev()
422 | )
423 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
424 |
425 | dist_util.sync_params(ema_params)
426 | return ema_params
427 |
428 | def _load_optimizer_state(self):
429 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
430 | opt_checkpoint = bf.join(
431 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
432 | )
433 | if bf.exists(opt_checkpoint):
434 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
435 | state_dict = dist_util.load_state_dict(
436 | opt_checkpoint, map_location=dist_util.dev()
437 | )
438 | self.opt.load_state_dict(state_dict)
439 |
440 | def run_loop(self):
441 | while (
442 | not self.lr_anneal_steps
443 | or self.step + self.resume_step < self.lr_anneal_steps
444 | ):
445 | batch, cond = next(self.data)
446 | self.run_step(batch, cond)
447 | if self.step % self.log_interval == 0:
448 | logger.dumpkvs()
449 | if self.step % self.save_interval == 0:
450 | self.save()
451 | # After saving, sample unconditionally
452 | sample = self.diffusion.p_sample_loop(
453 | self.model,
454 | (1, 1, 256, 256),
455 | model_kwargs={},
456 | )
457 | plt.imsave(str(Path(get_blob_logdir()) / f"{self.step:05}.png"), clear(sample))
458 | # Run for a finite amount of time in integration tests.
459 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
460 | return
461 | self.step += 1
462 | # Save the last checkpoint if it wasn't already saved.
463 | if (self.step - 1) % self.save_interval != 0:
464 | self.save()
465 |
466 | def run_step(self, batch, cond):
467 | self.forward_backward(batch, cond)
468 | took_step = self.mp_trainer.optimize(self.opt)
469 | if took_step:
470 | self._update_ema()
471 | self._anneal_lr()
472 | self.log_step()
473 |
474 | def forward_backward(self, batch, cond):
475 | self.mp_trainer.zero_grad()
476 | for i in range(0, batch.shape[0], self.microbatch):
477 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
478 | micro_cond = {
479 | k: v[i : i + self.microbatch].to(dist_util.dev())
480 | for k, v in cond.items()
481 | }
482 | last_batch = (i + self.microbatch) >= batch.shape[0]
483 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
484 |
485 | compute_losses = functools.partial(
486 | self.diffusion.training_losses,
487 | self.ddp_model,
488 | micro,
489 | t,
490 | model_kwargs=micro_cond,
491 | )
492 |
493 | if last_batch or not self.use_ddp:
494 | losses = compute_losses()
495 | else:
496 | with self.ddp_model.no_sync():
497 | losses = compute_losses()
498 |
499 | if isinstance(self.schedule_sampler, LossAwareSampler):
500 | self.schedule_sampler.update_with_local_losses(
501 | t, losses["loss"].detach()
502 | )
503 |
504 | loss = (losses["loss"] * weights).mean()
505 | log_loss_dict(
506 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
507 | )
508 | self.mp_trainer.backward(loss)
509 |
510 | def _update_ema(self):
511 | for rate, params in zip(self.ema_rate, self.ema_params):
512 | update_ema(params, self.mp_trainer.master_params, rate=rate)
513 |
514 | def _anneal_lr(self):
515 | if not self.lr_anneal_steps:
516 | return
517 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
518 | lr = self.lr * (1 - frac_done)
519 | for param_group in self.opt.param_groups:
520 | param_group["lr"] = lr
521 |
522 | def log_step(self):
523 | logger.logkv("step", self.step + self.resume_step)
524 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
525 |
526 | def save(self):
527 | def save_checkpoint(rate, params):
528 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
529 | if dist.get_rank() == 0:
530 | logger.log(f"saving model {rate}...")
531 | if not rate:
532 | filename = f"model{(self.step+self.resume_step):06d}.pt"
533 | else:
534 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
535 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
536 | th.save(state_dict, f)
537 |
538 | save_checkpoint(0, self.mp_trainer.master_params)
539 | for rate, params in zip(self.ema_rate, self.ema_params):
540 | save_checkpoint(rate, params)
541 |
542 | if dist.get_rank() == 0:
543 | with bf.BlobFile(
544 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
545 | "wb",
546 | ) as f:
547 | th.save(self.opt.state_dict(), f)
548 |
549 | dist.barrier()
550 |
551 |
552 |
553 | def parse_resume_step_from_filename(filename):
554 | """
555 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
556 | checkpoint's number of steps.
557 | """
558 | split = filename.split("model")
559 | if len(split) < 2:
560 | return 0
561 | split1 = split[-1].split(".")[0]
562 | try:
563 | return int(split1)
564 | except ValueError:
565 | return 0
566 |
567 |
568 | def get_blob_logdir():
569 | # You can change this to be a separate path to save checkpoints to
570 | # a blobstore or some external drive.
571 | return logger.get_dir()
572 |
573 |
574 | def find_resume_checkpoint():
575 | # On your infrastructure, you may want to override this to automatically
576 | # discover the latest checkpoint on your blob storage, etc.
577 | return None
578 |
579 |
580 | def find_ema_checkpoint(main_checkpoint, step, rate):
581 | if main_checkpoint is None:
582 | return None
583 | filename = f"ema_{rate}_{(step):06d}.pt"
584 | path = bf.join(bf.dirname(main_checkpoint), filename)
585 | if bf.exists(path):
586 | return path
587 | return None
588 |
589 |
590 | def log_loss_dict(diffusion, ts, losses):
591 | for key, values in losses.items():
592 | logger.logkv_mean(key, values.mean().item())
593 | # Log the quantiles (four quartiles, in particular).
594 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
595 | quartile = int(4 * sub_t / diffusion.num_timesteps)
596 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
597 |
--------------------------------------------------------------------------------
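One detail worth noting in the file above: resuming relies on the checkpoint naming convention `model{step:06d}.pt`, which `parse_resume_step_from_filename` parses; any non-matching name falls back to step 0. A small illustrative check, assuming the repository root is on `PYTHONPATH`:

```python
from guided_diffusion.train_util import parse_resume_step_from_filename

# "model" + zero-padded step + ".pt" yields the step count
assert parse_resume_step_from_filename("ckpts/model050000.pt") == 50000
# non-matching names (e.g. EMA files) fall back to 0
assert parse_resume_step_from_filename("ckpts/ema_0.9999_050000.pt") == 0
```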
/guided_diffusion/training_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
8 | import argparse
9 | import traceback
10 | import logging
11 | import yaml
12 | import sys
13 | import os
14 | import torch
15 | import numpy as np
16 |
17 | from pathlib import Path
18 |
19 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion
20 | from skimage.metrics import peak_signal_noise_ratio
21 | from pathlib import Path
22 | from physics.ct import CT
23 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI
24 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule
25 | from functools import partial
26 |
27 | torch.set_printoptions(sci_mode=False)
28 | def compute_alpha(beta, t):
29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
31 | return a
32 |
33 |
34 | def dict2namespace(config):
35 | namespace = argparse.Namespace()
36 | for key, value in config.items():
37 | if isinstance(value, dict):
38 | new_value = dict2namespace(value)
39 | else:
40 | new_value = value
41 | setattr(namespace, key, new_value)
42 | return namespace
43 |
44 |
45 | def parse_args_and_config():
46 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
47 |
48 | parser.add_argument(
49 | "--config", type=str, required=True, help="Path to the config file"
50 | )
51 | parser.add_argument(
52 | "--type", type=str, required=True, help="Either [2d, 3d]"
53 | )
54 | parser.add_argument(
55 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
56 | )
57 | parser.add_argument(
58 | "--Nview", type=int, default=16, help="number of projections for CT"
59 | )
60 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
61 | parser.add_argument(
62 | "--exp", type=str, default="./exp", help="Path for saving running related data."
63 | )
64 | parser.add_argument(
65 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
66 | )
67 | parser.add_argument(
68 | "--deg", type=str, required=True, help="Degradation"
69 | )
70 | parser.add_argument(
71 | "--sigma_y", type=float, default=0., help="sigma_y"
72 | )
73 | parser.add_argument(
74 | "--eta", type=float, default=0.85, help="Eta"
75 | )
76 | parser.add_argument(
77 | "--rho", type=float, default=10.0, help="rho"
78 | )
79 | parser.add_argument(
80 | "--lamb", type=float, default=0.04, help="lambda for TV"
81 | )
82 | parser.add_argument(
83 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
84 | )
85 | parser.add_argument(
86 | "--T_sampling", type=int, default=50, help="Total number of sampling steps"
87 | )
88 | parser.add_argument(
89 | "-i",
90 | "--image_folder",
91 | type=str,
92 | default="./results",
93 | help="The folder name of samples",
94 | )
95 | parser.add_argument(
96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
97 | )
98 |
99 | # MRI-exp arguments
100 | parser.add_argument(
101 | "--mask_type", type=str, default="uniform1d", help="Undersampling type"
102 | )
103 | parser.add_argument(
104 | "--acc_factor", type=int, default=4, help="acceleration factor"
105 | )
106 | parser.add_argument(
107 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
108 | )
109 | parser.add_argument(
110 | "--center_fraction", type=float, default=0.08, help="ACS region"
111 | )
112 |
113 |
114 | args = parser.parse_args()
115 |
116 | # parse config file
117 | with open(os.path.join("configs/vp", args.config), "r") as f:
118 | config = yaml.safe_load(f)
119 | new_config = dict2namespace(config)
120 |
121 | if "CT" in args.deg:
122 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
123 | elif "MRI" in args.deg:
124 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
125 |
126 | args.image_folder.mkdir(exist_ok=True, parents=True)
127 | if not os.path.exists(args.image_folder):
128 | os.makedirs(args.image_folder)
129 |
130 | # add device
131 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
132 | logging.info("Using device: {}".format(device))
133 | new_config.device = device
134 |
135 | # set random seed
136 | torch.manual_seed(args.seed)
137 | np.random.seed(args.seed)
138 | if torch.cuda.is_available():
139 | torch.cuda.manual_seed_all(args.seed)
140 |
141 | torch.backends.cudnn.benchmark = True
142 |
143 | return args, new_config
144 |
145 | class Diffusion(object):
146 | def __init__(self, args, config, device=None):
147 | self.args = args
148 | self.args.image_folder = Path(self.args.image_folder)
149 | for t in ["input", "recon", "label"]:
150 | if t == "recon":
151 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
152 | else:
153 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
154 | self.config = config
155 | if device is None:
156 | device = (
157 | torch.device("cuda")
158 | if torch.cuda.is_available()
159 | else torch.device("cpu")
160 | )
161 | self.device = device
162 |
163 | self.model_var_type = config.model.var_type
164 | betas = get_beta_schedule(
165 | beta_schedule=config.diffusion.beta_schedule,
166 | beta_start=config.diffusion.beta_start,
167 | beta_end=config.diffusion.beta_end,
168 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
169 | )
170 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
171 | self.num_timesteps = betas.shape[0]
172 |
173 | alphas = 1.0 - betas
174 | alphas_cumprod = alphas.cumprod(dim=0)
175 | alphas_cumprod_prev = torch.cat(
176 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
177 | )
178 | self.alphas_cumprod_prev = alphas_cumprod_prev
179 | posterior_variance = (
180 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
181 | )
182 | if self.model_var_type == "fixedlarge":
183 | self.logvar = betas.log()
184 | elif self.model_var_type == "fixedsmall":
185 | self.logvar = posterior_variance.clamp(min=1e-20).log()
186 |
187 | def train(self):
188 | config_dict = vars(self.config.model)
189 | model = create_model(**config_dict)
190 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name)
191 |
192 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
193 | print(f"Model ckpt loaded from {ckpt}")
194 | model.to("cuda")
195 | model.train()
196 |
197 |
198 | # model.eval()
199 |
200 | # print('Run DDS.',
201 | # f'{self.args.T_sampling} sampling steps.',
202 | # f'Task: {self.args.deg}.'
203 | # )
204 | # self.dds(model)
205 |
206 |
207 |
208 | def main():
209 | args, config = parse_args_and_config()
210 | diffusion_model = Diffusion(args, config)
211 | diffusion_model.train()
212 |
213 |
214 | if __name__ == "__main__":
215 | print("running training")
216 | main()
--------------------------------------------------------------------------------
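The `compute_alpha` helper above evaluates the cumulative product alpha_bar_t = prod_{s <= t} (1 - beta_s), prepending a zero beta so that `t = -1` maps cleanly to alpha_bar_{-1} = 1. A standalone check, with a constant beta chosen purely for illustration:

```python
import torch

def compute_alpha(beta, t):
    # identical to the helper above: prepend beta_0 = 0 so that index t+1 works
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    return (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)

beta = torch.full((1000,), 0.01)                        # constant schedule
print(compute_alpha(beta, torch.tensor([-1])).item())   # 1.0  (boundary case)
print(compute_alpha(beta, torch.tensor([0])).item())    # 0.99 (= 1 - beta_0)
```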
/guided_diffusion/training_triplane_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
8 | import argparse
9 | import traceback
10 | import logging
11 | import yaml
12 | import sys
13 | import os
14 | import torch
15 | import numpy as np
16 |
17 | from pathlib import Path
18 |
19 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion
20 | from skimage.metrics import peak_signal_noise_ratio
21 | from pathlib import Path
22 | from physics.ct import CT
23 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI
24 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule
25 | from functools import partial
26 |
27 | torch.set_printoptions(sci_mode=False)
28 | def compute_alpha(beta, t):
29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
31 | return a
32 |
33 |
34 | def dict2namespace(config):
35 | namespace = argparse.Namespace()
36 | for key, value in config.items():
37 | if isinstance(value, dict):
38 | new_value = dict2namespace(value)
39 | else:
40 | new_value = value
41 | setattr(namespace, key, new_value)
42 | return namespace
43 |
44 |
45 | def parse_args_and_config():
46 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
47 |
48 | parser.add_argument(
49 | "--config", type=str, required=True, help="Path to the config file"
50 | )
51 | parser.add_argument(
52 | "--type", type=str, required=True, help="Either [2d, 3d]"
53 | )
54 | parser.add_argument(
55 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
56 | )
57 | parser.add_argument(
58 | "--Nview", type=int, default=16, help="number of projections for CT"
59 | )
60 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
61 | parser.add_argument(
62 | "--exp", type=str, default="./exp", help="Path for saving running related data."
63 | )
64 | parser.add_argument(
65 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
66 | )
67 | parser.add_argument(
68 | "--deg", type=str, required=True, help="Degradation"
69 | )
70 | parser.add_argument(
71 | "--sigma_y", type=float, default=0., help="sigma_y"
72 | )
73 | parser.add_argument(
74 | "--eta", type=float, default=0.85, help="Eta"
75 | )
76 | parser.add_argument(
77 | "--rho", type=float, default=10.0, help="rho"
78 | )
79 | parser.add_argument(
80 | "--lamb", type=float, default=0.04, help="lambda for TV"
81 | )
82 | parser.add_argument(
83 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
84 | )
85 | parser.add_argument(
86 | "--T_sampling", type=int, default=50, help="Total number of sampling steps"
87 | )
88 | parser.add_argument(
89 | "-i",
90 | "--image_folder",
91 | type=str,
92 | default="./results",
93 | help="The folder name of samples",
94 | )
95 | parser.add_argument(
96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
97 | )
98 |
99 | # MRI-exp arguments
100 | parser.add_argument(
101 | "--mask_type", type=str, default="uniform1d", help="Undersampling type"
102 | )
103 | parser.add_argument(
104 | "--acc_factor", type=int, default=4, help="acceleration factor"
105 | )
106 | parser.add_argument(
107 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
108 | )
109 | parser.add_argument(
110 | "--center_fraction", type=float, default=0.08, help="ACS region"
111 | )
112 |
113 |
114 | args = parser.parse_args()
115 |
116 | # parse config file
117 | with open(os.path.join("configs/vp", args.config), "r") as f:
118 | config = yaml.safe_load(f)
119 | new_config = dict2namespace(config)
120 |
121 | if "CT" in args.deg:
122 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
123 | elif "MRI" in args.deg:
124 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
125 |
126 | args.image_folder.mkdir(exist_ok=True, parents=True)
127 | if not os.path.exists(args.image_folder):
128 | os.makedirs(args.image_folder)
129 |
130 | # add device
131 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
132 | logging.info("Using device: {}".format(device))
133 | new_config.device = device
134 |
135 | # set random seed
136 | torch.manual_seed(args.seed)
137 | np.random.seed(args.seed)
138 | if torch.cuda.is_available():
139 | torch.cuda.manual_seed_all(args.seed)
140 |
141 | torch.backends.cudnn.benchmark = True
142 |
143 | return args, new_config
144 |
145 | class Diffusion(object):
146 | def __init__(self, args, config, device=None):
147 | self.args = args
148 | self.args.image_folder = Path(self.args.image_folder)
149 | for t in ["input", "recon", "label"]:
150 | if t == "recon":
151 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
152 | else:
153 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
154 | self.config = config
155 | if device is None:
156 | device = (
157 | torch.device("cuda")
158 | if torch.cuda.is_available()
159 | else torch.device("cpu")
160 | )
161 | self.device = device
162 |
163 | self.model_var_type = config.model.var_type
164 | betas = get_beta_schedule(
165 | beta_schedule=config.diffusion.beta_schedule,
166 | beta_start=config.diffusion.beta_start,
167 | beta_end=config.diffusion.beta_end,
168 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
169 | )
170 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
171 | self.num_timesteps = betas.shape[0]
172 |
173 | alphas = 1.0 - betas
174 | alphas_cumprod = alphas.cumprod(dim=0)
175 | alphas_cumprod_prev = torch.cat(
176 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
177 | )
178 | self.alphas_cumprod_prev = alphas_cumprod_prev
179 | posterior_variance = (
180 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
181 | )
182 | if self.model_var_type == "fixedlarge":
183 | self.logvar = betas.log()
184 | elif self.model_var_type == "fixedsmall":
185 | self.logvar = posterior_variance.clamp(min=1e-20).log()
186 |
187 | def train(self):
188 | config_dict = vars(self.config.model)
189 | config_dict["use_spacecode"] = True
190 | print(config_dict)
191 |
192 | model = create_model(**config_dict)
193 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name)
194 |
195 | model.load_state_dict(torch.load(ckpt, map_location=self.device))
196 | print(f"Model ckpt loaded from {ckpt}")
197 | # model.to("cuda")
198 | # model.train()
199 |
200 |
201 | # model.eval()
202 |
203 | # print('Run DDS.',
204 | # f'{self.args.T_sampling} sampling steps.',
205 | # f'Task: {self.args.deg}.'
206 | # )
207 | # self.dds(model)
208 |
209 |
210 |
211 | def main():
212 | args, config = parse_args_and_config()
213 | diffusion_model = Diffusion(args, config)
214 | diffusion_model.train()
215 |
216 |
217 | if __name__ == "__main__":
218 | print("running training")
219 | main()
--------------------------------------------------------------------------------
/guided_diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from tqdm import tqdm
5 |
6 |
7 | def get_alpha_schedule(type="linear", start=0.4, end=0.015, total=1000):
8 | if type == "linear":
9 | schedule = np.linspace(start, end, total)
10 | elif type == "const":
11 | assert start == end, f"For const schedule, start and end should match. Got start:{start}, end:{end}"
12 | schedule = np.full(total, start)
13 | else: raise ValueError(f"unsupported alpha schedule type: {type}")
14 | return np.flip(schedule)
15 |
16 | def CG(A_fn, b_cg, x, n_inner=10, eps=1e-8):
17 | r = b_cg - A_fn(x)
18 | p = r.clone()
19 | rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)
20 | for _ in range(n_inner):
21 | Ap = A_fn(p)
22 | a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)
23 |
24 | x += a * p
25 | r -= a * Ap
26 |
27 | rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
28 |
29 | if torch.sqrt(rs_new) < eps:
30 | break
31 | p = r + (rs_new / rs_old) * p
32 | rs_old = rs_new
33 | return x
34 |
35 | # x0_t_hat = x0_t - A_funcs.A_pinv(
36 | # A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
37 | # ).reshape(*x0_t.size())
38 | # returns vectorized A^T(A(x))
39 |
40 |
41 | def Acg(x, A_func):
42 | x_vec = x.reshape(x.size(0), -1)
43 | tmp = A_func.At(A_func.A(x_vec))
44 | return tmp.reshape(*x.size())
45 |
46 |
47 | def clear_color(x):
48 | x = x.detach().cpu().squeeze().numpy()
49 | return normalize_np(np.transpose(x, (1, 2, 0)))
50 |
51 |
52 | def clear(x):
53 | x = x.detach().cpu().squeeze().numpy()
54 | return x
55 |
56 |
57 | def normalize_np(img):
58 | """ Normalize img in arbitrary range to [0, 1] """
59 | img -= np.min(img)
60 | img /= np.max(img)
61 | return img
62 |
63 |
64 | def clip(img):
65 | return torch.clip(img, -1.0, 1.0)
--------------------------------------------------------------------------------
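The `CG` routine above implements plain conjugate gradients, so it assumes the operator is symmetric positive definite; the `Acg` wrapper applies A^T A, which satisfies this. A tiny self-contained check on an SPD system, with sizes and tolerance chosen only for illustration:

```python
import torch
from guided_diffusion.utils import CG

torch.manual_seed(0)
M = torch.randn(4, 4)
A = M.T @ M + torch.eye(4)   # symmetric positive definite by construction
b = torch.randn(4, 1)

x = CG(lambda v: A @ v, b.clone(), torch.zeros(4, 1), n_inner=50)
print(torch.allclose(A @ x, b, atol=1e-4))  # True: residual is tiny
```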
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import logging
4 | import yaml
5 | import sys
6 | import os
7 | import torch
8 | import numpy as np
9 |
10 | from solver_2d import Diffusion as Diffusion_2d
11 | from solver_3d import Diffusion as Diffusion_3d
12 | from solver_3d_baselines import Diffusion as Diffusion_baseline
13 | from solver_3D_blend import Diffusion as Diffusion_blend
14 | from solver_3d_blend_timetravel import Diffusion as Diffusion_blendtt
15 | from eval_3DCT_blendcond import Diffusion as Diffusion_blendcond
16 | from eval_3D_blend_cond_lidc import Diffusion as Diffusion_blendcond_lidc
17 | from solver_3D_diffusionmbir import Diffusion as Diffusionmbir
18 | from pathlib import Path
19 |
20 | torch.set_printoptions(sci_mode=False)
21 |
22 | def parse_args_and_config():
23 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
24 |
25 | parser.add_argument(
26 | "--config", type=str, required=True, help="Path to the config file"
27 | )
28 | parser.add_argument(
29 | "--type", type=str, required=True, help="Either [2d, 3d]"
30 | )
31 | parser.add_argument(
32 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
33 | )
34 | parser.add_argument(
35 | "--Nview", type=int, default=16, help="number of projections for CT"
36 | )
37 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
38 | parser.add_argument(
39 | "--exp", type=str, default="./exp", help="Path for saving running related data."
40 | )
41 | parser.add_argument(
42 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
43 | )
44 | parser.add_argument(
45 | "--deg", type=str, required=True, help="Degradation"
46 | )
47 | parser.add_argument(
48 | "--sigma_y", type=float, default=0., help="sigma_y"
49 | )
50 | parser.add_argument(
51 | "--eta", type=float, default=0.85, help="Eta"
52 | )
53 | parser.add_argument(
54 | "--rho", type=float, default=10.0, help="rho"
55 | )
56 | parser.add_argument(
57 | "--lamb", type=float, default=0.04, help="lambda for TV"
58 | )
59 | parser.add_argument(
60 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
61 | )
62 | parser.add_argument(
63 | "--T_sampling", type=int, default=50, help="Total number of sampling steps"
64 | )
65 | parser.add_argument(
66 | "-i",
67 | "--image_folder",
68 | type=str,
69 | default="./results",
70 | help="The folder name of samples",
71 | )
72 | parser.add_argument(
73 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
74 | )
75 |
76 | # MRI-exp arguments
77 | parser.add_argument(
78 | "--mask_type", type=str, default="uniform1d", help="Undersampling type"
79 | )
80 | parser.add_argument(
81 | "--acc_factor", type=int, default=4, help="acceleration factor"
82 | )
83 | parser.add_argument(
84 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
85 | )
86 | parser.add_argument(
87 | "--center_fraction", type=float, default=0.08, help="ACS region"
88 | )
89 |
90 |
91 | args = parser.parse_args()
92 |
93 | # parse config file
94 | with open(os.path.join("configs/vp", args.config), "r") as f:
95 | config = yaml.safe_load(f)
96 | new_config = dict2namespace(config)
97 |
98 | if "CT" in args.deg:
99 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
100 | elif "MRI" in args.deg:
101 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
102 |
103 | args.image_folder.mkdir(exist_ok=True, parents=True)
104 | if not os.path.exists(args.image_folder):
105 | os.makedirs(args.image_folder)
106 |
107 | # add device
108 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
109 | logging.info("Using device: {}".format(device))
110 | new_config.device = device
111 |
112 | # set random seed
113 | torch.manual_seed(args.seed)
114 | np.random.seed(args.seed)
115 | if torch.cuda.is_available():
116 | torch.cuda.manual_seed_all(args.seed)
117 |
118 | torch.backends.cudnn.benchmark = True
119 |
120 | return args, new_config
121 |
122 |
123 | def dict2namespace(config):
124 | namespace = argparse.Namespace()
125 | for key, value in config.items():
126 | if isinstance(value, dict):
127 | new_value = dict2namespace(value)
128 | else:
129 | new_value = value
130 | setattr(namespace, key, new_value)
131 | return namespace
132 |
133 |
134 | def main():
135 | args, config = parse_args_and_config()
136 |
137 | try:
138 | if args.type == "2d":
139 | runner = Diffusion_2d(args, config)
140 | elif args.type == "3d":
141 | runner = Diffusion_3d(args, config)
142 | elif args.type == "3dblend":
143 | runner = Diffusion_blend(args, config)
144 | elif args.type == "3dblendtt":
145 | runner = Diffusion_blendtt(args, config)
146 | elif args.type == "3dblendcond":
147 | runner = Diffusion_blendcond(args, config)
148 | elif args.type == "3dblendcond_lidc":
149 | runner = Diffusion_blendcond_lidc(args, config)
150 | elif args.type == "3dbaseline":
151 | runner = Diffusion_baseline(args, config)
152 | elif args.type == "diffusionmbir":
153 | runner = Diffusionmbir(args, config)
154 | else:
155 | raise ValueError(f"unknown runner type: {args.type}")
156 |
157 | runner.sample()
158 | except Exception:
159 | logging.error(traceback.format_exc())
160 |
161 | return 0
162 |
163 |
164 | if __name__ == "__main__":
165 | sys.exit(main())
166 |
--------------------------------------------------------------------------------
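For reference, a hypothetical invocation of `main.py` selecting the blended conditional solver; the config name mirrors the one used in `train_SVCT_3D_triplane.sh` below, and all values are examples rather than recommended settings:

```bash
python main.py --type 3dblendcond \
    --config AAPM_256_lsun.yaml \
    --deg SV-CT --Nview 8 \
    --T_sampling 50 --eta 0.85 \
    -i ./results
```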
/train_SVCT_3D_triplane.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | Nview=8
4 | T_sampling=50
5 | eta=0.85
6 |
7 | python training_triplane_script.py \
8 | --type '2d' \
9 | --config AAPM_256_lsun.yaml \
10 | --dataset_path "/nfs/turbo/coe-liyues/bowenbw/DDS/indist_samples/CT/L067" \
11 | --ckpt_load_name "/nfs/turbo/coe-liyues/bowenbw/DDS/checkpoints/AAPM256_1M.pth" \
12 | --Nview $Nview \
13 | --eta $eta \
14 | --deg "SV-CT" \
15 | --sigma_y 0.01 \
16 | --T_sampling $T_sampling \
17 | --resume_checkpoint true \
18 | -i ./results
--------------------------------------------------------------------------------
/training_triplane_lidc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from glob import glob
8 | import argparse
9 | import traceback
10 | import logging
11 | import yaml
12 | import sys
13 | import os
14 | import torch
15 | import numpy as np
16 |
17 | from pathlib import Path
18 | from guided_diffusion.CTDataset import *
19 | from guided_diffusion.train_util import *
20 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion, str2bool
21 | from skimage.metrics import peak_signal_noise_ratio
22 | from pathlib import Path
23 | from physics.ct import CT
24 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI
25 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule
26 | from functools import partial
27 |
28 | def compute_alpha(beta, t):
29 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
30 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
31 | return a
32 |
33 | def dict2namespace(config):
34 | namespace = argparse.Namespace()
35 | for key, value in config.items():
36 | if isinstance(value, dict):
37 | new_value = dict2namespace(value)
38 | else:
39 | new_value = value
40 | setattr(namespace, key, new_value)
41 | return namespace
42 |
43 | def parse_args_and_config():
44 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
45 | parser.add_argument(
46 | "--config", type=str, required=True, help="Path to the config file"
47 | )
48 | parser.add_argument(
49 | "--type", type=str, required=True, help="Either [2d, 3d]"
50 | )
51 | parser.add_argument(
52 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
53 | )
54 | parser.add_argument(
55 | "--Nview", type=int, default=16, help="number of projections for CT"
56 | )
57 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
58 | parser.add_argument(
59 | "--exp", type=str, default="./exp", help="Path for saving running related data."
60 | )
61 | parser.add_argument(
62 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
63 | )
64 | parser.add_argument(
65 | "--deg", type=str, required=True, help="Degradation"
66 | )
67 | parser.add_argument(
68 | "--sigma_y", type=float, default=0., help="sigma_y"
69 | )
70 | parser.add_argument(
71 | "--eta", type=float, default=0.85, help="Eta"
72 | )
73 | parser.add_argument(
74 | "--rho", type=float, default=10.0, help="rho"
75 | )
76 | parser.add_argument(
77 | "--lamb", type=float, default=0.04, help="lambda for TV"
78 | )
79 | parser.add_argument(
80 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
81 | )
82 | parser.add_argument(
83 | "--T_sampling", type=int, default=50, help="Total number of sampling steps"
84 | )
85 | parser.add_argument(
86 | "--resume_checkpoint", type=bool, default=False, help="resume training from a previous checkpoint"
87 | )
88 | parser.add_argument(
89 | "-i",
90 | "--image_folder",
91 | type=str,
92 | default="./results",
93 | help="The folder name of samples",
94 | )
95 | parser.add_argument(
96 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
97 | )
98 | # MRI-exp arguments
99 | parser.add_argument(
100 | "--mask_type", type=str, default="uniform1d", help="Undersampling type"
101 | )
102 | parser.add_argument(
103 | "--acc_factor", type=int, default=4, help="acceleration factor"
104 | )
105 | parser.add_argument(
106 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
107 | )
108 | parser.add_argument(
109 | "--center_fraction", type=float, default=0.08, help="ACS region"
110 | )
111 | args = parser.parse_args()
112 |
113 | # parse config file
114 | with open(os.path.join("configs/vp", args.config), "r") as f:
115 | config = yaml.safe_load(f)
116 | new_config = dict2namespace(config)
117 |
118 | if "CT" in args.deg:
119 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
120 | elif "MRI" in args.deg:
121 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
122 |
123 | args.image_folder.mkdir(exist_ok=True, parents=True)
124 | if not os.path.exists(args.image_folder):
125 | os.makedirs(args.image_folder)
126 | # add device
127 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
128 | logging.info("Using device: {}".format(device))
129 | new_config.device = device
130 | # set random seed
131 | torch.manual_seed(args.seed)
132 | np.random.seed(args.seed)
133 | if torch.cuda.is_available():
134 | torch.cuda.manual_seed_all(args.seed)
135 | torch.backends.cudnn.benchmark = True
136 | return args, new_config
137 |
138 | class Diffusion(object):
139 | def __init__(self, args, config, device=None):
140 | self.args = args
141 | self.args.image_folder = Path(self.args.image_folder)
142 | for t in ["input", "recon", "label"]:
143 | if t == "recon":
144 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
145 | else:
146 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
147 | self.config = config
148 | if device is None:
149 | device = (
150 | torch.device("cuda")
151 | if torch.cuda.is_available()
152 | else torch.device("cpu")
153 | )
154 | self.device = device
155 |
156 | self.model_var_type = config.model.var_type
157 | betas = get_beta_schedule(
158 | beta_schedule=config.diffusion.beta_schedule,
159 | beta_start=config.diffusion.beta_start,
160 | beta_end=config.diffusion.beta_end,
161 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
162 | )
163 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
164 | self.num_timesteps = betas.shape[0]
165 |
166 | alphas = 1.0 - betas
167 | alphas_cumprod = alphas.cumprod(dim=0)
168 | alphas_cumprod_prev = torch.cat(
169 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
170 | )
171 | self.alphas_cumprod_prev = alphas_cumprod_prev
172 | posterior_variance = (
173 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
174 | )
175 | if self.model_var_type == "fixedlarge":
176 | self.logvar = betas.log()
177 | elif self.model_var_type == "fixedsmall":
178 | self.logvar = posterior_variance.clamp(min=1e-20).log()
179 |
180 | def train(self):
181 | config_dict = vars(self.config.model)
182 | config_dict["class_cond"] = True
183 | config_dict["use_spacecode"] = False
184 | print(config_dict)
185 | model = create_model(**config_dict)
186 | print(model.use_spacecode, "using spacecode")
187 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name)
188 |
189 | pretrainsteps = 0
190 | if self.args.resume_checkpoint:
191 | print("resuming training")
192 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter50099_cond.ckpt"
193 | pretrainsteps = 50000
194 | else:
195 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/256x256_diffusion_uncond.pt"
196 |
197 | loaded = torch.load(ckpt, map_location=self.device)
198 | model.load_state_dict(loaded['state_dict'])  #### if using last checkpoint (reuses the tensor loaded above instead of loading the file twice)
199 |
200 | # model.load_state_dict(torch.load(ckpt, map_location=self.device), strict = False) ######if using imagenet checkpoint
201 |
202 | print(f"Model ckpt loaded from {ckpt}")
203 | model.to("cuda")
204 | model.train()
205 |
206 | diffusion = create_gaussian_diffusion(
207 | steps=1000,
208 | learn_sigma=True,
209 | noise_schedule="linear",
210 | use_kl=False,
211 | predict_xstart=False,
212 | rescale_timesteps=False,
213 | rescale_learned_sigmas=False,
214 | timestep_respacing="",
215 | )
216 | print(diffusion.training_losses, "training loss")
217 |
218 | lr = 1e-5
219 | params = list(model.parameters())
220 |
221 | opt = torch.optim.AdamW(params, lr=lr)
222 |
223 | #############################testing feed numpy matrix into the training script########################
224 |
225 | #########################use the given trainer in improved_diffusion#########################
226 |
227 | """use the given trainer in improved_diffusion
228 | ds = CTDataset()
229 | params = {'batch_size': 4}
230 | training_generator = torch.utils.data.DataLoader(ds, **params)
231 | def load_data(loader):
232 | while True:
233 | yield from loader
234 |
235 | data = load_data(training_generator)
236 |
237 | TrainLoop2(
238 | model=model,
239 | diffusion=diffusion,
240 | data=data,
241 | batch_size=4,
242 | microbatch=-1,
243 | lr=3e-4,
244 | ema_rate="0.9999",
245 | log_interval=10,
246 | save_interval=2500,
247 | resume_checkpoint="",
248 | use_fp16=False,
249 | fp16_scale_growth=1e-3,
250 | schedule_sampler="uniform",
251 | weight_decay=0.0,
252 | lr_anneal_steps=0,
253 | ).run_loop()
254 | """
255 | ################################train manually####################################
256 | # files = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*')
257 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training_cond/*")
258 | # files2 = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/slice_fusion_training/*")
259 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond/*")
260 | # files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond_small/*")
261 | files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/LIDC_fusion_training_cond_large/*")
262 | for m in range(100100):
263 | x_train = np.zeros((4, 3, 256, 256))
264 | y = torch.randint(0, 3, (4,))
265 | for l in range(4):
266 | filename = np.random.choice(files)
267 | x_raw = np.transpose(np.load(filename), (2,0,1))[0:3]
268 | x_raw = np.clip(x_raw*2-1, -1, 1)
269 | x_train[l] = x_raw.copy()
270 | y_val = int((filename.split(".")[0]).split("_")[-1])
271 | y[l] = y_val
272 | print(y_val, filename)
273 | x_orig = torch.from_numpy(x_train).to("cuda").to(torch.float)
274 | i = torch.randint(0, 1000, (4,))
275 | t = i.to("cuda").long()
276 | y = y.to("cuda").long()
277 |
278 | model_kwargs = {}
279 | model_kwargs["y"] = y
280 |
281 | if m % 1000 == 0:
282 | x_sample = diffusion.ddim_sample_loop_progressive(model, (4,3,256,256), task = "None", progress= True, model_kwargs = model_kwargs)
283 | np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/x_sample_ddim_iter" + str(m + pretrainsteps) + "_finetune_4232024_cond_lidc.npy", x_sample.detach().cpu().numpy())
284 |
285 | loss = diffusion.training_losses(model, x_orig, t, model_kwargs=model_kwargs)["loss"]
286 | loss = loss.mean()
287 | loss.backward()
288 | opt.step()
289 | opt.zero_grad()
290 | # print(x_orig.dtype)
291 | # loss = diffusion.training_losses(model, x_orig, t)["loss"]
292 | # loss= loss.mean()
293 | # loss.backward()
294 | # opt.step()
295 | # opt.zero_grad()
296 | if m % 10 == 0:
297 |     print(f"loss {loss.item():.6f} at iteration {m}")
298 | # ####################################################################################################################
299 | if m % 2000 == 99:
300 | torch.save({'iterations':m,'state_dict': model.state_dict()}, "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/LIDC_triplane3D_finetunelarge_4232024_iter" + str(m + pretrainsteps) + "_cond.ckpt")
301 |
302 | ####################################################################################
303 | # model.eval()
304 | # print('Run DDS.',
305 | # f'{self.args.T_sampling} sampling steps.',
306 | # f'Task: {self.args.deg}.'
307 | # )
308 | # self.dds(model)
309 | def main():
310 | args, config = parse_args_and_config()
311 | diffusion_model = Diffusion(args, config)
312 | diffusion_model.train()
313 |
314 |
315 | if __name__ == "__main__":
316 | print("running training")
317 | main()
--------------------------------------------------------------------------------
/training_triplane_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from glob import glob
8 | import argparse
9 | import traceback
10 | import logging
11 | import yaml
12 | import sys
16 | from datetime import datetime
17 |
18 | from pathlib import Path
19 | from guided_diffusion.CTDataset import *
20 | from guided_diffusion.train_util import *
21 | from guided_diffusion.script_util import create_model, create_gaussian_diffusion
22 | from skimage.metrics import peak_signal_noise_ratio
24 | from physics.ct import CT
25 | from physics.mri import SinglecoilMRI_comp, MulticoilMRI
26 | from utils import CG, clear, get_mask, nchw_comp_to_real, real_to_nchw_comp, normalize_np, get_beta_schedule
27 | from functools import partial
28 |
29 | def compute_alpha(beta, t):
30 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
31 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
32 | return a
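
# Sketch: with a length-T beta schedule, compute_alpha(betas, t) returns the cumulative
# product alpha_bar_t = prod_{s<=t} (1 - beta_s), shaped (B, 1, 1, 1) for broadcasting;
# the zero prepended above makes t = -1 return 1. For example:
#   betas = torch.linspace(1e-4, 2e-2, 1000)
#   a_bar = compute_alpha(betas, torch.tensor([999]))  # ~ (1 - betas).cumprod(0)[-1]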
33 |
34 | def dict2namespace(config):
35 | namespace = argparse.Namespace()
36 | for key, value in config.items():
37 | if isinstance(value, dict):
38 | new_value = dict2namespace(value)
39 | else:
40 | new_value = value
41 | setattr(namespace, key, new_value)
42 | return namespace
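
# e.g. dict2namespace({"model": {"var_type": "fixedsmall"}}).model.var_type == "fixedsmall"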
43 |
44 | def parse_args_and_config():
45 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
46 | parser.add_argument(
47 | "--config", type=str, required=True, help="Path to the config file"
48 | )
49 | parser.add_argument(
50 | "--type", type=str, required=True, help="Either [2d, 3d]"
51 | )
52 | parser.add_argument(
53 | "--CG_iter", type=int, default=5, help="Inner number of iterations for CG"
54 | )
55 | parser.add_argument(
56 | "--Nview", type=int, default=16, help="number of projections for CT"
57 | )
58 | parser.add_argument("--seed", type=int, default=1234, help="Set different seeds for diverse results")
59 | parser.add_argument(
60 | "--exp", type=str, default="./exp", help="Path for saving running related data."
61 | )
62 | parser.add_argument(
63 | "--ckpt_load_name", type=str, default="AAPM256_1M.pt", help="Load pre-trained ckpt"
64 | )
65 | parser.add_argument(
66 | "--deg", type=str, required=True, help="Degradation"
67 | )
68 | parser.add_argument(
69 | "--sigma_y", type=float, default=0., help="sigma_y"
70 | )
71 | parser.add_argument(
72 | "--eta", type=float, default=0.85, help="Eta"
73 | )
74 | parser.add_argument(
75 | "--rho", type=float, default=10.0, help="rho"
76 | )
77 | parser.add_argument(
78 | "--lamb", type=float, default=0.04, help="lambda for TV"
79 | )
80 | parser.add_argument(
81 | "--gamma", type=float, default=1.0, help="regularizer for noisy recon"
82 | )
83 | parser.add_argument(
84 | "--T_sampling", type=int, default=50, help="Total number of sampling steps"
85 | )
86 | parser.add_argument(
87 | "--resume_checkpoint", type=bool, default=False, help="resume training from a previous checkpoint"
88 | )
89 | parser.add_argument(
90 | "-i",
91 | "--image_folder",
92 | type=str,
93 | default="./results",
94 | help="The folder name of samples",
95 | )
96 | parser.add_argument(
97 | "--dataset_path", type=str, default="/media/harry/tomo/AAPM_data_vol/256_sorted/L067", help="The folder of the dataset"
98 | )
99 | # MRI-exp arguments
100 | parser.add_argument(
101 | "--mask_type", type=str, default="uniform1d", help="Undersampling type"
102 | )
103 | parser.add_argument(
104 | "--acc_factor", type=int, default=4, help="acceleration factor"
105 | )
106 | parser.add_argument(
107 | "--nspokes", type=int, default=30, help="Number of sampled lines in radial trajectory"
108 | )
109 | parser.add_argument(
110 | "--center_fraction", type=float, default=0.08, help="ACS region"
111 | )
112 | args = parser.parse_args()
113 |
114 | # parse config file
115 | with open(os.path.join("configs/vp", args.config), "r") as f:
116 | config = yaml.safe_load(f)
117 | new_config = dict2namespace(config)
118 |
119 | if "CT" in args.deg:
120 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"view{args.Nview}"
121 | elif "MRI" in args.deg:
122 | args.image_folder = Path(args.image_folder) / f"{args.deg}" / f"{args.mask_type}_acc{args.acc_factor}"
123 |
124 |     args.image_folder.mkdir(exist_ok=True, parents=True)
127 | # add device
128 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
129 | logging.info("Using device: {}".format(device))
130 | new_config.device = device
131 | # set random seed
132 | torch.manual_seed(args.seed)
133 | np.random.seed(args.seed)
134 | if torch.cuda.is_available():
135 | torch.cuda.manual_seed_all(args.seed)
136 | torch.backends.cudnn.benchmark = True
137 | return args, new_config
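
# Typical invocation (a sketch; the config filename and task name are placeholders,
# not values confirmed by this repository):
#   python training_triplane_script.py --config AAPM_256.yaml --type 3d --deg SV-CT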
138 |
139 | class Diffusion(object):
140 | def __init__(self, args, config, device=None):
141 | self.args = args
142 | self.args.image_folder = Path(self.args.image_folder)
143 | for t in ["input", "recon", "label"]:
144 | if t == "recon":
145 | (self.args.image_folder / t / "progress").mkdir(exist_ok=True, parents=True)
146 | else:
147 | (self.args.image_folder / t).mkdir(exist_ok=True, parents=True)
148 | self.config = config
149 | if device is None:
150 | device = (
151 | torch.device("cuda")
152 | if torch.cuda.is_available()
153 | else torch.device("cpu")
154 | )
155 | self.device = device
156 |
157 | self.model_var_type = config.model.var_type
158 | betas = get_beta_schedule(
159 | beta_schedule=config.diffusion.beta_schedule,
160 | beta_start=config.diffusion.beta_start,
161 | beta_end=config.diffusion.beta_end,
162 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
163 | )
164 | betas = self.betas = torch.from_numpy(betas).float().to(self.device)
165 | self.num_timesteps = betas.shape[0]
166 |
167 | alphas = 1.0 - betas
168 | alphas_cumprod = alphas.cumprod(dim=0)
169 | alphas_cumprod_prev = torch.cat(
170 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
171 | )
172 | self.alphas_cumprod_prev = alphas_cumprod_prev
173 | posterior_variance = (
174 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
175 | )
176 | if self.model_var_type == "fixedlarge":
177 | self.logvar = betas.log()
178 | elif self.model_var_type == "fixedsmall":
179 | self.logvar = posterior_variance.clamp(min=1e-20).log()
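        # Standard DDPM posterior: Var[x_{t-1} | x_t, x_0] = beta_t * (1 - abar_{t-1}) / (1 - abar_t);
        # "fixedlarge" uses log(beta_t), "fixedsmall" the clamped posterior log-variance.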
180 |
181 | def train(self):
182 | config_dict = vars(self.config.model)
183 | config_dict["class_cond"] = True
184 | config_dict["use_spacecode"] = False
185 | print(config_dict)
186 | model = create_model(**config_dict)
187 | print(model.use_spacecode, "using spacecode")
188 | ckpt = os.path.join(self.args.exp, "vp", self.args.ckpt_load_name)
189 |
190 | pretrainsteps = 0
191 |         if self.args.resume_checkpoint:
192 | print("resuming training")
193 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter50099_cond.ckpt"
194 | pretrainsteps = 50000
195 | else:
196 | ckpt = "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/256x256_diffusion_uncond.pt"
197 |
198 |         loaded = torch.load(ckpt, map_location=self.device)
199 |         if 'state_dict' in loaded:  # checkpoint saved by this training script
200 |             model.load_state_dict(loaded['state_dict'])
201 |         else:  # raw guided-diffusion checkpoint (e.g. 256x256_diffusion_uncond.pt)
202 |             model.load_state_dict(loaded, strict=False)
203 | print(f"Model ckpt loaded from {ckpt}")
204 | model.to("cuda")
205 | model.train()
206 |
207 | diffusion = create_gaussian_diffusion(
208 | steps=1000,
209 | learn_sigma=True,
210 | noise_schedule="linear",
211 | use_kl=False,
212 | predict_xstart=False,
213 | rescale_timesteps=False,
214 | rescale_learned_sigmas=False,
215 | timestep_respacing="",
216 | )
217 | print(diffusion.training_losses, "training loss")
218 |
219 | lr = 1e-5
220 | params = list(model.parameters())
221 |
222 | opt = torch.optim.AdamW(params, lr=lr)
223 |
224 | #############################testing feed numpy matrix into the training script########################
225 |
226 | #########################use the given trainer in improved_diffusion#########################
227 |
228 | """use the given trainer in improved_diffusion
229 | ds = CTDataset()
230 | params = {'batch_size': 4}
231 | training_generator = torch.utils.data.DataLoader(ds, **params)
232 | def load_data(loader):
233 | while True:
234 | yield from loader
235 |
236 | data = load_data(training_generator)
237 |
238 | TrainLoop2(
239 | model=model,
240 | diffusion=diffusion,
241 | data=data,
242 | batch_size=4,
243 | microbatch=-1,
244 | lr=3e-4,
245 | ema_rate="0.9999",
246 | log_interval=10,
247 | save_interval=2500,
248 | resume_checkpoint="",
249 | use_fp16=False,
250 | fp16_scale_growth=1e-3,
251 | schedule_sampler="uniform",
252 | weight_decay=0.0,
253 | lr_anneal_steps=0,
254 | ).run_loop()
255 | """
256 | ################################train manually####################################
257 | # files = glob('/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training/*')
258 | files = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/AAPM_fusion_training_cond/*")
259 | files2 = glob("/nfs/turbo/coe-liyues/bowenbw/3DCT/slice_fusion_training/*")
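        # Each batch element is drawn from one of two pools: conditional triplane blocks
        # (label parsed from the trailing "_<y>" of the filename) or plain adjacent-slice
        # blocks assigned the extra class 3, presumably acting as a "no triplane
        # position" label.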
260 |
261 |         t0 = datetime.now()  # start the interval timer outside the loop so elapsed times are meaningful
262 | for m in range(25100):
263 |
264 | x_train = np.zeros((4, 3, 256, 256))
265 | y = torch.randint(0, 3, (4,))
266 | for l in range(4):
267 | luck = np.random.randint(0, 2)
268 | if luck == 0:
269 | filename = np.random.choice(files)
270 | x_raw = np.transpose(np.load(filename), (2,0,1))[0:3]
271 | x_raw = np.clip(x_raw*2-1, -1, 1)
272 | x_train[l] = x_raw.copy()
273 | y_val = int((filename.split(".")[0]).split("_")[-1])
274 | y[l] = y_val
275 | print(y_val)
276 | else:
277 | filename = np.random.choice(files2)
278 | x_raw = np.transpose(np.load(filename), (2,0,1))[0:3]
279 | x_raw = np.clip(x_raw*2-1, -1, 1)
280 | x_train[l] = x_raw.copy()
281 | y[l] = 3
282 | x_orig = torch.from_numpy(x_train).to("cuda").to(torch.float)
283 | i = torch.randint(0, 1000, (4,))
284 | t = i.to("cuda").long()
285 | y = y.to("cuda").long()
286 |
287 |
288 | model_kwargs = {}
289 | model_kwargs["y"] = y
290 |
291 | # if m % 1000 == 0:
292 | # x_sample = diffusion.ddim_sample_loop_progressive(model, (4,3,256,256), task = "None", progress= True, model_kwargs = model_kwargs)
293 | # np.save("/nfs/turbo/coe-liyues/bowenbw/3DCT/x_sample_ddim_iter" + str(m + pretrainsteps) + "_finetune_452024_cond.npy", x_sample.detach().cpu().numpy())
294 |
295 | loss = diffusion.training_losses(model, x_orig, t, model_kwargs=model_kwargs)["loss"]
296 | loss = loss.mean()
297 | loss.backward()
298 | opt.step()
299 | opt.zero_grad()
300 | # print(x_orig.dtype)
301 | # loss = diffusion.training_losses(model, x_orig, t)["loss"]
302 | # loss= loss.mean()
303 | # loss.backward()
304 | # opt.step()
305 | # opt.zero_grad()
306 |             if m % 20 == 0:
307 |                 print(f"loss {loss.item():.6f} at iteration {m}")
308 |                 t1 = datetime.now()
309 |                 print(t1 - t0, "elapsed since last report")
310 |                 t0 = t1
312 | # ####################################################################################################################
313 | # if m % 5000 == 99:
314 | # torch.save({'iterations':m,'state_dict': model.state_dict()}, "/nfs/turbo/coe-liyues/bowenbw/3DCT/checkpoints/triplane3D_finetune_452024_iter" + str(m + pretrainsteps) + "_cond.ckpt")
315 |
316 | ####################################################################################
317 | # model.eval()
318 | # print('Run DDS.',
319 | # f'{self.args.T_sampling} sampling steps.',
320 | # f'Task: {self.args.deg}.'
321 | # )
322 | # self.dds(model)
323 | def main():
324 | args, config = parse_args_and_config()
325 | diffusion_model = Diffusion(args, config)
326 | diffusion_model.train()
327 |
328 |
329 | if __name__ == "__main__":
330 | print("running training")
331 | main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import os
5 | import logging
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from statistics import mean, stdev
9 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
10 | from scipy.ndimage import gaussian_laplace
11 | import functools
12 | from physics.fastmri_utils import fft2c_new, ifft2c_new
13 |
14 |
15 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
16 | def sigmoid(x):
17 | return 1 / (np.exp(-x) + 1)
18 |
19 | if beta_schedule == "quad":
20 | betas = (
21 | np.linspace(
22 | beta_start ** 0.5,
23 | beta_end ** 0.5,
24 | num_diffusion_timesteps,
25 | dtype=np.float64,
26 | )
27 | ** 2
28 | )
29 | elif beta_schedule == "linear":
30 | betas = np.linspace(
31 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
32 | )
33 | elif beta_schedule == "const":
34 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
35 | elif beta_schedule == "jsd":
36 | betas = 1.0 / np.linspace(
37 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
38 | )
39 | elif beta_schedule == "sigmoid":
40 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
41 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
42 | else:
43 | raise NotImplementedError(beta_schedule)
44 | assert betas.shape == (num_diffusion_timesteps,)
45 | return betas
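
# Usage sketch (the values match common VP-diffusion defaults; adjust to your config):
#   betas = get_beta_schedule("linear", beta_start=1e-4, beta_end=2e-2,
#                             num_diffusion_timesteps=1000)
#   betas.shape  # (1000,)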
46 |
47 |
48 | def get_sigma(t, sde):
49 | """ VE-SDE """
50 | sigma_t = sde.sigma_min * (sde.sigma_max / sde.sigma_min) ** t
51 | return sigma_t
52 |
53 |
54 | def pred_x0_from_s(xt, s, t, sde):
55 | """ Tweedie's formula for denoising. Assumes VE-SDE """
56 | sigma_t = get_sigma(t, sde)
57 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1)
58 | pred_x0 = xt + (tmp ** 2) * s
59 | return pred_x0
60 |
61 |
62 | def recover_xt_from_x0(x0_t, s, t, sde):
63 | sigma_t = get_sigma(t, sde)
64 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1)
65 | xt = x0_t - (tmp ** 2) * s
66 | return xt
67 |
68 |
69 | def pred_eps_from_s(s, t, sde):
70 | sigma_t = get_sigma(t, sde)
71 | tmp = sigma_t.view(sigma_t.shape[0], 1, 1, 1)
72 | pred_eps = -tmp * s
73 | return pred_eps
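
# VE-SDE identities used above (sketch): given a score estimate s ~ grad log p(x_t),
# Tweedie's formula gives x0_hat = x_t + sigma_t^2 * s and eps_hat = -sigma_t * s,
# so recover_xt_from_x0 is the exact inverse of pred_x0_from_s.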
74 |
75 |
76 | def _Dz(x): # Batch direction
77 | y = torch.zeros_like(x)
78 | y[:-1] = x[1:]
79 | y[-1] = x[0]
80 | return y - x
81 |
82 | def _Dzx(x):
83 | y = torch.zeros_like(x)
84 | y[:,:,:,:-1] = x[:,:,:,1:]
85 | y[:,:,:,-1] = x[:,:,:,0]
86 | return y - x
87 |
88 | def _DzxT(x):
89 | y = torch.zeros_like(x)
90 | y[:,:,:,:-1] = x[:,:,:,1:]
91 | y[:,:,:,-1] = x[:,:,:,0]
92 |
93 | tempt = -(y-x)
94 | difft = tempt[:,:,:,:-1]
95 | y[:,:,:,1:] = difft
96 | y[:,:,:,0] = x[:,:,:,-1] - x[:,:,:,0]
97 |
98 |     return y  # adjoint of _Dzx; matches _DxT/_DyT/_DzT, which return y (not y - x)
99 |
100 |
101 | def _Dzy(x):
102 | y = torch.zeros_like(x)
103 | y[:,:,:-1,:] = x[:,:,1:,:]
104 | y[:,:,-1,:] = x[:,:,0,:]
105 | return y - x
106 |
107 | def _DzyT(x):
108 | y = torch.zeros_like(x)
109 | y[:,:,:-1,:] = x[:,:,1:,:]
110 | y[:,:,-1,:] = x[:,:,0,:]
111 |
112 | tempt = -(y-x)
113 | difft = tempt[:,:,:-1,:]
114 | y[:,:,1:,:] = difft
115 | y[:,:,0,:] = x[:,:,-1,:] - x[:,:,0,:]
116 |
117 |     return y  # adjoint of _Dzy; matches _DxT/_DyT/_DzT, which return y (not y - x)
118 |
119 |
120 | def _DzT(x): # Batch direction
121 | y = torch.zeros_like(x)
122 | y[:-1] = x[1:]
123 | y[-1] = x[0]
124 |
125 | tempt = -(y-x)
126 | difft = tempt[:-1]
127 | y[1:] = difft
128 | y[0] = x[-1] - x[0]
129 |
130 | return y
131 |
132 |
133 |
134 | def _Dx(x): # Batch direction
135 | y = torch.zeros_like(x)
136 | y[:, :, :-1, :] = x[:, :, 1:, :]
137 | y[:, :, -1, :] = x[:, :, 0, :]
138 | return y - x
139 |
140 |
141 | def _DxT(x): # Batch direction
142 | y = torch.zeros_like(x)
143 | y[:, :, :-1, :] = x[:, :, 1:, :]
144 | y[:, :, -1, :] = x[:, :, 0, :]
145 | tempt = -(y - x)
146 | difft = tempt[:, :, :-1, :]
147 | y[:, :, 1:, :] = difft
148 | y[:, :, 0, :] = x[:, :, -1, :] - x[:, :, 0, :]
149 | return y
150 |
151 |
152 | def _Dy(x): # Batch direction
153 | y = torch.zeros_like(x)
154 | y[:, :, :, :-1] = x[:, :, :, 1:]
155 | y[:, :, :, -1] = x[:, :, :, 0]
156 | return y - x
157 |
158 |
159 | def _DyT(x): # Batch direction
160 | y = torch.zeros_like(x)
161 | y[:, :, :, :-1] = x[:, :, :, 1:]
162 | y[:, :, :, -1] = x[:, :, :, 0]
163 | tempt = -(y - x)
164 | difft = tempt[:, :, :, :-1]
165 | y[:, :, :, 1:] = difft
166 | y[:, :, :, 0] = x[:, :, :, -1] - x[:, :, :, 0]
167 | return y
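
# The _D*/_D*T pairs are circular (periodic) forward differences and their adjoints,
# used for TV-style regularization. Adjointness can be spot-checked numerically:
#   x, y = torch.randn(2, 1, 8, 8), torch.randn(2, 1, 8, 8)
#   (_Dx(x) * y).sum()  # approximately equals (x * _DxT(y)).sum()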
168 |
169 |
170 |
171 | def CG(A, b, x, n_inner=5, eps=1e-5):
172 | r = b - A(x)
173 | p = r.clone()
174 | rsold = torch.matmul(r.view(1, -1), r.view(1, -1).T)
175 |
176 | for i in range(n_inner):
177 | Ap = A(p)
178 | a = rsold / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)
179 |
180 | x = x + a * p
181 | r = r - a * Ap
182 |
183 | rsnew = torch.matmul(r.view(1, -1), r.view(1, -1).T)
184 |         if torch.sqrt(rsnew) < eps:  # rsnew is a squared norm, so no abs needed
185 | break
186 | p = r + (rsnew / rsold) * p
187 | rsold = rsnew
188 | return x
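
# Usage sketch: A must be a symmetric positive-definite operator given as a callable.
# For the (hypothetical) system 2*I x = b, CG converges in one step to b / 2:
#   b = torch.ones(1, 1, 4, 4)
#   x = CG(lambda v: 2.0 * v, b, torch.zeros_like(b))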
189 |
190 |
191 | def shrink(src, lamb):
192 | return torch.sign(src) * torch.max(torch.abs(src)-lamb, torch.zeros_like(src))
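
# Soft-thresholding, the proximal operator of the l1 norm, e.g.:
#   shrink(torch.tensor([-2., -0.5, 0.5, 2.]), 1.0)  # -> [-1., 0., 0., 1.]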
193 |
194 |
195 | def clear_color(x):
196 | x = x.detach().cpu().squeeze().numpy()
197 | return np.transpose(x, (1, 2, 0))
198 |
199 |
200 | def clear(x):
201 | x = x.detach().cpu().squeeze().numpy()
202 | return x
203 |
204 |
205 | def restore_checkpoint(ckpt_dir, state, device, skip_sigma=False, skip_optimizer=False):
206 | ckpt_dir = Path(ckpt_dir)
207 |     if not ckpt_dir.exists():
208 |         ckpt_dir.parent.mkdir(parents=True, exist_ok=True)
209 |         logging.error(f"No checkpoint found at {ckpt_dir}. "
210 |                       f"Returned the same state as input")
211 |         return state
213 | else:
214 | loaded_state = torch.load(ckpt_dir, map_location=device)
215 | if not skip_optimizer:
216 | state['optimizer'].load_state_dict(loaded_state['optimizer'])
217 | loaded_model_state = loaded_state['model']
218 | if skip_sigma:
219 | loaded_model_state.pop('module.sigmas')
220 |
221 | state['model'].load_state_dict(loaded_model_state, strict=False)
222 | state['ema'].load_state_dict(loaded_state['ema'])
223 | state['step'] = loaded_state['step']
224 | print(f'loaded checkpoint dir from {ckpt_dir}')
225 | return state
226 |
227 |
228 | def save_checkpoint(ckpt_dir, state):
229 | saved_state = {
230 | 'optimizer': state['optimizer'].state_dict(),
231 | 'model': state['model'].state_dict(),
232 | 'ema': state['ema'].state_dict(),
233 | 'step': state['step']
234 | }
235 | torch.save(saved_state, ckpt_dir)
236 |
237 |
238 | """
239 | Helper functions for new types of inverse problems
240 | """
241 |
242 |
243 |
244 | def fft2(x):
245 | """ FFT with shifting DC to the center of the image"""
246 | return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2])
247 |
248 |
249 | def ifft2(x):
250 | """ IFFT with shifting DC to the corner of the image prior to transform"""
251 | return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2]))
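
# Note: fft2 shifts DC to the image center and ifft2 undoes that shift first,
# so ifft2(fft2(x)) recovers x (up to floating-point error).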
252 |
253 |
254 | def fft2_m(x):
255 | """ FFT for multi-coil """
256 | return torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
257 |
258 |
259 | def ifft2_m(x):
260 | """ IFFT for multi-coil """
261 | return torch.view_as_complex(ifft2c_new(torch.view_as_real(x)))
262 |
263 |
264 | def crop_center(img, cropx, cropy):
265 | c, y, x = img.shape
266 | startx = x // 2 - (cropx // 2)
267 | starty = y // 2 - (cropy // 2)
268 | return img[:, starty:starty + cropy, startx:startx + cropx]
269 |
270 |
271 | def normalize(img):
272 | """ Normalize img in arbitrary range to [0, 1] """
273 | img -= torch.min(img)
274 | img /= torch.max(img)
275 | return img
276 |
277 |
278 | def normalize_np(img):
279 | """ Normalize img in arbitrary range to [0, 1] """
280 | img -= np.min(img)
281 | img /= np.max(img)
282 | return img
283 |
284 |
285 | def normalize_np_kwarg(img, maxv=1.0, minv=0.0):
286 |     """ Normalize img from the range [minv, maxv] to [0, 1] """
287 |     img -= minv
288 |     img /= (maxv - minv)
289 |     return img
290 |
291 |
292 | def normalize_complex(img):
293 | """ normalizes the magnitude of complex-valued image to range [0, 1] """
294 | abs_img = normalize(torch.abs(img))
295 | # ang_img = torch.angle(img)
296 | ang_img = normalize(torch.angle(img))
297 | return abs_img * torch.exp(1j * ang_img)
298 |
299 |
300 | def batchfy(tensor, batch_size):
301 |     n = len(tensor)
302 |     num_batches = -(-n // batch_size)  # ceil division: each chunk holds at most batch_size items
303 |     return tensor.chunk(num_batches, dim=0)
304 |
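# e.g. batchfy(torch.randn(10, 3), 4) yields 3 chunks of sizes (4, 4, 2)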
305 |
306 | def img_wise_min_max(img):
307 | img_flatten = img.view(img.shape[0], -1)
308 | img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
309 | img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
310 |
311 | return (img - img_min) / (img_max - img_min)
312 |
313 |
314 | def patient_wise_min_max(img):
315 | std_upper = 3
316 | img_flatten = img.view(img.shape[0], -1)
317 |
318 | std = torch.std(img)
319 | mean = torch.mean(img)
320 |
321 | img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
322 | img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)
323 |
324 | min_max_scaled = (img - img_min) / (img_max - img_min)
325 | min_max_scaled_std = (std - img_min) / (img_max - img_min)
326 | min_max_scaled_mean = (mean - img_min) / (img_max - img_min)
327 |
328 | min_max_scaled[min_max_scaled > min_max_scaled_mean +
329 | std_upper * min_max_scaled_std] = 1
330 |
331 | return min_max_scaled
332 |
333 |
334 | def create_sphere(cx, cy, cz, r, resolution=256):
335 | '''
336 | create sphere with center (cx, cy, cz) and radius r
337 | '''
338 | phi = np.linspace(0, 2 * np.pi, 2 * resolution)
339 | theta = np.linspace(0, np.pi, resolution)
340 |
341 | theta, phi = np.meshgrid(theta, phi)
342 |
343 | r_xy = r * np.sin(theta)
344 | x = cx + np.cos(phi) * r_xy
345 | y = cy + np.sin(phi) * r_xy
346 | z = cz + r * np.cos(theta)
347 |
348 | return np.stack([x, y, z])
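
# e.g. a unit sphere at the origin: create_sphere(0, 0, 0, 1.0).shape == (3, 512, 256)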
349 |
350 |
351 | class lambda_schedule:
352 | def __init__(self, total=2000):
353 | self.total = total
354 |
355 | def get_current_lambda(self, i):
356 | pass
357 |
358 |
359 | class lambda_schedule_linear(lambda_schedule):
360 | def __init__(self, start_lamb=1.0, end_lamb=0.0):
361 | super().__init__()
362 | self.start_lamb = start_lamb
363 | self.end_lamb = end_lamb
364 |
365 | def get_current_lambda(self, i):
366 | return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total)
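
# e.g. with the default total of 2000 steps, lambda decays linearly:
#   lambda_schedule_linear(1.0, 0.0).get_current_lambda(1000)  # -> 0.5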
367 |
368 |
369 | class lambda_schedule_const(lambda_schedule):
370 | def __init__(self, lamb=1.0):
371 | super().__init__()
372 | self.lamb = lamb
373 |
374 | def get_current_lambda(self, i):
375 | return self.lamb
376 |
377 |
378 | def image_grid(x, sz=32):
379 | size = sz
380 | channels = 3
381 | img = x.reshape(-1, size, size, channels)
382 | w = int(np.sqrt(img.shape[0]))
383 | img = img.reshape((w, w, size, size, channels)).transpose(
384 | (0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
385 | return img
386 |
387 |
388 | def show_samples(x, sz=32):
389 | x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
390 | img = image_grid(x, sz)
391 | plt.figure(figsize=(8, 8))
392 | plt.axis('off')
393 | plt.imshow(img)
394 | plt.show()
395 |
396 |
397 | def image_grid_gray(x, size=32):
398 | img = x.reshape(-1, size, size)
399 | w = int(np.sqrt(img.shape[0]))
400 | img = img.reshape((w, w, size, size)).transpose(
401 | (0, 2, 1, 3)).reshape((w * size, w * size))
402 | return img
403 |
404 |
405 | def show_samples_gray(x, size=32, save=False, save_fname=None):
406 | x = x.detach().cpu().numpy()
407 | img = image_grid_gray(x, size=size)
408 | plt.figure(figsize=(8, 8))
409 | plt.axis('off')
410 | plt.imshow(img, cmap='gray')
411 | plt.show()
412 | if save:
413 | plt.imsave(save_fname, img, cmap='gray')
414 |
415 |
416 | def get_mask(img, size, batch_size, type='gaussian2d', acc_factor=8, center_fraction=0.04, fix=False):
417 | mux_in = size ** 2
418 | if type.endswith('2d'):
419 | Nsamp = mux_in // acc_factor
420 | elif type.endswith('1d'):
421 | Nsamp = size // acc_factor
422 | if type == 'gaussian2d':
423 | mask = torch.zeros_like(img)
424 | cov_factor = size * (1.5 / 128)
425 | mean = [size // 2, size // 2]
426 | cov = [[size * cov_factor, 0], [0, size * cov_factor]]
427 | if fix:
428 | samples = np.random.multivariate_normal(mean, cov, int(Nsamp))
429 | int_samples = samples.astype(int)
430 | int_samples = np.clip(int_samples, 0, size - 1)
431 | mask[..., int_samples[:, 0], int_samples[:, 1]] = 1
432 | else:
433 | for i in range(batch_size):
434 | # sample different masks for batch
435 | samples = np.random.multivariate_normal(mean, cov, int(Nsamp))
436 | int_samples = samples.astype(int)
437 | int_samples = np.clip(int_samples, 0, size - 1)
438 | mask[i, :, int_samples[:, 0], int_samples[:, 1]] = 1
439 | elif type == 'uniformrandom2d':
440 | mask = torch.zeros_like(img)
441 | if fix:
442 | mask_vec = torch.zeros([1, size * size])
443 | samples = np.random.choice(size * size, int(Nsamp))
444 | mask_vec[:, samples] = 1
445 | mask_b = mask_vec.view(size, size)
446 | mask[:, ...] = mask_b
447 | else:
448 | for i in range(batch_size):
449 | # sample different masks for batch
450 | mask_vec = torch.zeros([1, size * size])
451 | samples = np.random.choice(size * size, int(Nsamp))
452 | mask_vec[:, samples] = 1
453 | mask_b = mask_vec.view(size, size)
454 | mask[i, ...] = mask_b
455 | elif type == 'gaussian1d':
456 | mask = torch.zeros_like(img)
457 | mean = size // 2
458 | std = size * (15.0 / 128)
459 | Nsamp_center = int(size * center_fraction)
460 | if fix:
461 | samples = np.random.normal(
462 | loc=mean, scale=std, size=int(Nsamp * 1.2))
463 | int_samples = samples.astype(int)
464 | int_samples = np.clip(int_samples, 0, size - 1)
465 | mask[..., int_samples] = 1
466 | c_from = size // 2 - Nsamp_center // 2
467 | mask[..., c_from:c_from + Nsamp_center] = 1
468 | else:
469 | for i in range(batch_size):
470 | samples = np.random.normal(
471 | loc=mean, scale=std, size=int(Nsamp*1.2))
472 | int_samples = samples.astype(int)
473 | int_samples = np.clip(int_samples, 0, size - 1)
474 | mask[i, :, :, int_samples] = 1
475 | c_from = size // 2 - Nsamp_center // 2
476 | mask[i, :, :, c_from:c_from + Nsamp_center] = 1
477 | elif type == 'uniform1d':
478 | mask = torch.zeros_like(img)
479 | if fix:
480 | Nsamp_center = int(size * center_fraction)
481 | samples = np.random.choice(size, int(Nsamp - Nsamp_center))
482 | mask[..., samples] = 1
483 | # ACS region
484 | c_from = size // 2 - Nsamp_center // 2
485 | mask[..., c_from:c_from + Nsamp_center] = 1
486 | else:
487 | for i in range(batch_size):
488 | Nsamp_center = int(size * center_fraction)
489 | samples = np.random.choice(size, int(Nsamp - Nsamp_center))
490 | mask[i, :, :, samples] = 1
491 | # ACS region
492 | c_from = size // 2 - Nsamp_center // 2
493 | mask[i, :, :, c_from:c_from+Nsamp_center] = 1
494 |     else:
495 |         raise NotImplementedError(f'Mask type {type} is currently not supported.')
496 |
497 | return mask
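
# Usage sketch: a fixed 1D uniform undersampling mask at 4x acceleration with an
# 8%-wide ACS region (the mask shape follows the (B, C, H, W) image passed in):
#   img = torch.zeros(1, 1, 256, 256)
#   mask = get_mask(img, size=256, batch_size=1, type='uniform1d',
#                   acc_factor=4, center_fraction=0.08, fix=True)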
498 |
499 |
500 | def nchw_comp_to_real(x):
501 | """
502 | [1, 1, 320, 320] comp --> [1, 2, 320, 320] real
503 | """
504 | x = torch.view_as_real(x)
505 | x = x.squeeze(dim=1)
506 | x = x.permute(0, 3, 1, 2)
507 | return x
508 |
509 | def real_to_nchw_comp(x):
510 | """
511 | [1, 2, 320, 320] real --> [1, 1, 320, 320] comp
512 | """
513 | if len(x.shape) == 4:
514 | x = x[:, 0:1, :, :] + x[:, 1:2, :, :] * 1j
515 | elif len(x.shape) == 3:
516 | x = x[0:1, :, :] + x[1:2, :, :] * 1j
517 | return x
518 |
519 |
520 | def kspace_to_nchw(tensor):
521 | """
522 | Convert torch tensor in (Slice, Coil, Height, Width, Complex) 5D format to
523 | (N, C, H, W) 4D format for processing by 2D CNNs.
524 |
525 | Complex indicates (real, imag) as 2 channels, the complex data format for Pytorch.
526 |
527 | C is the coils interleaved with real and imaginary values as separate channels.
528 | C is therefore always 2 * Coil.
529 |
530 | Singlecoil data is assumed to be in the 5D format with Coil = 1
531 |
532 | Args:
533 | tensor (torch.Tensor): Input data in 5D kspace tensor format.
534 | Returns:
535 | tensor (torch.Tensor): tensor in 4D NCHW format to be fed into a CNN.
536 | """
537 | assert isinstance(tensor, torch.Tensor)
538 | assert tensor.dim() == 5
539 | s = tensor.shape
540 | assert s[-1] == 2
541 | tensor = tensor.permute(dims=(0, 1, 4, 2, 3)).reshape(
542 | shape=(s[0], 2 * s[1], s[2], s[3]))
543 | return tensor
544 |
545 |
546 | def nchw_to_kspace(tensor):
547 | """
548 | Convert a torch tensor in (N, C, H, W) format to the (Slice, Coil, Height, Width, Complex) format.
549 |
550 | This function assumes that the real and imaginary values of a coil are always adjacent to one another in C.
551 | If the coil dimension is not divisible by 2, the function assumes that the input data is 'real' data,
552 | and thus pads the imaginary dimension as 0.
553 | """
554 | assert isinstance(tensor, torch.Tensor)
555 | assert tensor.dim() == 4
556 | s = tensor.shape
557 | if tensor.shape[1] == 1:
558 | imag_tensor = torch.zeros(s, device=tensor.device)
559 | tensor = torch.cat((tensor, imag_tensor), dim=1)
560 | s = tensor.shape
561 | tensor = tensor.view(
562 | size=(s[0], s[1] // 2, 2, s[2], s[3])).permute(dims=(0, 1, 3, 4, 2))
563 | return tensor
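
# Round-trip sketch: 5D k-space (Slice, Coil, H, W, 2) <-> 4D NCHW with 2*Coil channels:
#   k = torch.randn(2, 4, 64, 64, 2)
#   nchw_to_kspace(kspace_to_nchw(k)).shape  # (2, 4, 64, 64, 2)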
564 |
565 |
566 | def root_sum_of_squares(data, dim=0):
567 | """
568 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor.
569 | Args:
570 | data (torch.Tensor): The input tensor
571 | dim (int): The dimensions along which to apply the RSS transform
572 | Returns:
573 | torch.Tensor: The RSS value
574 | """
575 | return torch.sqrt((data ** 2).sum(dim))
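
# e.g. coil-combine a (Coil, H, W) magnitude stack: root_sum_of_squares(x, dim=0)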
576 |
577 |
578 | def save_data(fname, arr):
579 | """ Save data as .npy and .png """
580 | np.save(fname + '.npy', arr)
581 | plt.imsave(fname + '.png', arr, cmap='gray')
582 |
583 |
584 | def mean_std(vals: list):
585 | return mean(vals), stdev(vals)
--------------------------------------------------------------------------------