├── LICENSE
├── README.md
├── generate_sample.py
├── images
│   ├── ddim.jpg
│   ├── ddpm.jpg
│   ├── dpm_solver_2.jpg
│   ├── dpmsolverplusplus.jpg
│   ├── exponential_integrator.jpg
│   ├── heun.jpg
│   ├── improved_ddpm.jpg
│   └── pndm.jpg
├── samplers
│   ├── __init__.py
│   ├── ddim.py
│   ├── ddpm.py
│   ├── dpm_solver.py
│   ├── dpm_solver_plus_plus.py
│   ├── exponential_integrator.py
│   ├── heun.py
│   ├── improved_ddpm.py
│   ├── pndm.py
│   └── utils.py
└── soft_diffusion
    ├── README.md
    ├── graph_wasserstein.py
    ├── lightning.py
    ├── soft_diffusion.jpg
    └── wasserstein.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 Ozan Ciga
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # diffusion for beginners
 2 | 
 3 | - implementation of _diffusion schedulers_ with minimal code, kept as faithful to the original work as i could manage. most recent works reuse or adapt code from previous work and build on it, or transcribe code from another framework - which is great! but i found it hard to follow at times. this is an attempt at simplifying the great papers below. the trade-off made here favors brevity and simplicity over stability and correctness.
 4 | 
 5 | 
 6 | $$\large{\mathbf{{\color{green}feel\ free\ to\ contribute\ to\ the\ list\ below!}}}$$
 7 | 
 8 | - [x] [dpm-solver++(2m)](samplers/dpm_solver_plus_plus.py) (lu et al. 2022), dpm-solver++: fast solver for guided sampling of diffusion probabilistic models, https://arxiv.org/abs/2211.01095
 9 | - [x] [exponential integrator](samplers/exponential_integrator.py) (zhang et al. 2022), fast sampling of diffusion models with exponential integrator, https://arxiv.org/abs/2204.13902
10 | - [x] [dpm-solver](samplers/dpm_solver.py) (lu et al. 2022), dpm-solver: a fast ode solver for diffusion probabilistic model sampling in around 10 steps, https://arxiv.org/abs/2206.00927
11 | - [x] [heun](samplers/heun.py) (karras et al. 2022), elucidating the design space of diffusion-based generative models, https://arxiv.org/abs/2206.00364
12 | - [x] [pndm](samplers/pndm.py) (liu et al. 2022), pseudo numerical methods for diffusion models on manifolds, https://arxiv.org/abs/2202.09778
13 | - [x] [ddim](samplers/ddim.py) (song et al. 2020), denoising diffusion implicit models, https://arxiv.org/abs/2010.02502
14 | - [x] [improved ddpm](samplers/improved_ddpm.py) (nichol and dhariwal 2021), improved denoising diffusion probabilistic models, https://arxiv.org/abs/2102.09672
15 | - [x] [ddpm](samplers/ddpm.py) (ho et al. 2020), denoising diffusion probabilistic models, https://arxiv.org/abs/2006.11239
16 | 
17 | 
18 | 
19 | **prompt**: "a man eating an apple sitting on a bench"
20 | 
21 | <table>
22 | <tr>
23 | <td><img src="images/dpmsolverplusplus.jpg"></td>
24 | <td><img src="images/exponential_integrator.jpg"></td>
25 | </tr>
26 | <tr>
27 | <td>dpm-solver++</td>
28 | <td>exponential integrator</td>
29 | </tr>
30 | </table>
31 | 
32 | <table>
33 | <tr>
34 | <td><img src="images/heun.jpg"></td>
35 | <td><img src="images/dpm_solver_2.jpg"></td>
36 | </tr>
37 | <tr>
38 | <td>heun</td>
39 | <td>dpm-solver</td>
40 | </tr>
41 | </table>
42 | 
43 | <table>
44 | <tr>
45 | <td><img src="images/ddim.jpg"></td>
46 | <td><img src="images/pndm.jpg"></td>
47 | </tr>
48 | <tr>
49 | <td>ddim</td>
50 | <td>pndm</td>
51 | </tr>
52 | </table>
53 | 
54 | <table>
55 | <tr>
56 | <td><img src="images/ddpm.jpg"></td>
57 | <td><img src="images/improved_ddpm.jpg"></td>
58 | </tr>
59 | <tr>
60 | <td>ddpm</td>
61 | <td>improved ddpm</td>
62 | </tr>
63 | </table>
64 | 
65 | 
66 | 
67 | 
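ddpm, improved ddpm, ddim and pndm share the eps-prediction interface used by generate_sample.py, so swapping them in is a one-line change (a sketch - constructor defaults as defined in each sampler file):

```python
from samplers import ddim
sampler = ddim.DDIMSampler(num_sample_steps=50)
# inside the denoising loop: latents = sampler(noise_pred, latents, t)
# for ddim, eta interpolates between deterministic ddim (eta=0) and ddpm-like sampling (eta=1)
```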
68 | 
69 | ### * requirements *
70 | while this repository is intended to be educational, if you wish to run and experiment, you'll need to obtain a [token from huggingface](https://huggingface.co/docs/hub/security-tokens) (and paste it into generate_sample.py), and install their excellent [diffusers library](https://github.com/huggingface/diffusers).
71 | 
72 | 
73 | ### * modification for heun sampler *
74 | the heun sampler uses two neural function evaluations per step, and modifies the input as well as the sigma. i wanted to stay as faithful to the paper as possible, which necessitated changing the sampling code a little.
75 | initialize the sampler as:
76 | ```python
77 | sampler = HeunSampler(num_sample_steps=25, denoiser=pipe.unet, alpha_bar=pipe.scheduler.alphas_cumprod)
78 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.t0
79 | ```
80 | 
81 | and replace the inner loop of generate_sample.py with:
82 | ```python
83 | for t in tqdm(sampler.timesteps):
84 |     latents = sampler(latents, t, text_embeddings, guidance_scale)
85 | ```
86 | 
87 | similarly, for dpm-solver-2,
88 | 
89 | ```python
90 | sampler = DPMSampler(num_sample_steps=20, denoiser=pipe.unet)
91 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.lmbd(1)[1]
92 | ```
93 | 
94 | and, for the fast exponential integrator,
95 | 
96 | ```python
97 | sampler = ExponentialSampler(num_sample_steps=50, denoiser=pipe.unet)
98 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device)
99 | ```
100 | 
101 | and, for dpm-solver++ (2m),
102 | 
103 | ```python
104 | sampler = DPMPlusPlusSampler(denoiser=pipe.unet, num_sample_steps=20)
105 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.get_coeffs(sampler.t[0])[1]
106 | ```
107 | 
108 | ## soft-diffusion
109 | 
110 | a sketch/draft of google's new paper, [soft diffusion: score matching for general corruptions](https://arxiv.org/abs/2209.05442), which achieves state-of-the-art results on the celeba-64 dataset.
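the training objective in one screenful - a condensed, self-contained restatement of the training step in [lightning.py](soft_diffusion/lightning.py) (the blur kernel, its scaling with t, and the stand-in tensors are placeholder assumptions, not the paper's exact choices):

```python
import torch
import torchvision.transforms.functional as ttf

def corrupt(x, t):
    # C_t: a gaussian blur whose width grows with t (kernel size / scaling assumed here)
    return ttf.gaussian_blur(x, kernel_size=9, sigma=3. * t + 1e-2)

x0 = torch.randn(4, 3, 64, 64)             # stand-in image batch
t, sigma_t = 0.5, 0.05                     # a timestep and its noise level
residual = sigma_t * torch.randn_like(x0)  # sigma_t * eta
x_t = corrupt(x0, t) + residual            # eq. 4: x_t = C_t x_0 + sigma_t * eta
r_hat = torch.zeros_like(x0)               # stand-in for the network's residual prediction
loss = sigma_t ** -2 * ((corrupt(r_hat, t) - corrupt(residual, t)) ** 2).mean()  # eq. 12
```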
111 | 112 | details can be found [here](soft_diffusion) -------------------------------------------------------------------------------- /generate_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from diffusers import StableDiffusionPipeline 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | from samplers import ddpm 7 | from PIL import Image 8 | 9 | 10 | def sample_image(pipe, sampler, prompt, init_latents, batch_size=1, guidance_scale=8.): 11 | # sample image from stable diffusion 12 | with torch.inference_mode(): 13 | latents = init_latents 14 | # prompt conditioning 15 | cond_input = pipe.tokenizer([prompt], padding="max_length", truncation=True, max_length=pipe.tokenizer.model_max_length, return_tensors="pt") 16 | text_embeddings = pipe.text_encoder(cond_input.input_ids.to(pipe.device))[0] 17 | # unconditional (classifier free guidance) 18 | uncond_input = pipe.tokenizer([""] * batch_size, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt") 19 | uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0] 20 | # concat embeddings 21 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 22 | 23 | for t in tqdm(sampler.timesteps): 24 | latent_model_input = torch.cat([latents] * 2) 25 | noise_pred = pipe.unet(latent_model_input, t, text_embeddings).sample 26 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 27 | # @crowsonkb found out subtracting noise obtained from unconditioned input (empty prompt) 28 | # pushes the generated sample toward high density regions given prompt: 29 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 30 | latents = sampler(noise_pred, latents, t) 31 | 32 | with torch.autocast('cuda'): 33 | images = pipe.vae.decode(latents * 1 / 0.18215).sample 34 | 35 | # tensor to image 36 | images = (images / 2 + 0.5).clamp(0, 1) 37 | images = images.cpu().permute(0, 2, 3, 1).numpy() 38 | images = np.uint8(255 * images) 39 | 40 | return images 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | device = 'cuda' 46 | TOKEN = '' # add your token here 47 | model_name = 'CompVis/stable-diffusion-v1-4' 48 | pipe = StableDiffusionPipeline.from_pretrained(model_name, use_auth_token=TOKEN).to(device) 49 | 50 | sampler = ddpm.DDPMSampler(num_sample_steps=1000) 51 | 52 | prompt = "a man eating an apple sitting on a bench" 53 | batch_size = 1 54 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) 55 | images = sample_image(pipe, sampler, prompt, init_latents) 56 | # plot the image 57 | plt.figure(); plt.imshow(images[0]); plt.show() 58 | # or save 59 | Image.fromarray(images[0]).save('images/ddpm.jpg', quality=50) 60 | 61 | -------------------------------------------------------------------------------- /images/ddim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/ddim.jpg -------------------------------------------------------------------------------- /images/ddpm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/ddpm.jpg -------------------------------------------------------------------------------- /images/dpm_solver_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/dpm_solver_2.jpg -------------------------------------------------------------------------------- /images/dpmsolverplusplus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/dpmsolverplusplus.jpg -------------------------------------------------------------------------------- /images/exponential_integrator.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/exponential_integrator.jpg -------------------------------------------------------------------------------- /images/heun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/heun.jpg -------------------------------------------------------------------------------- /images/improved_ddpm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/improved_ddpm.jpg -------------------------------------------------------------------------------- /images/pndm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/pndm.jpg -------------------------------------------------------------------------------- /samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/samplers/__init__.py -------------------------------------------------------------------------------- /samplers/ddim.py: -------------------------------------------------------------------------------- 1 | from torch import sqrt 2 | import torch 3 | from . 
import utils 4 | 5 | 6 | class DDIMSampler: 7 | # https://arxiv.org/abs/2010.02502 8 | 9 | def __init__(self, num_sample_steps=50, num_train_timesteps=1000, reverse_sample=True): 10 | beta = utils.get_beta_schedule(num_train_timesteps+1) 11 | # alpha in ddim == alpha^bar in ddpm = \prod_0^t (1-beta) 12 | self.alpha = torch.cumprod(1 - beta, dim=0) 13 | 14 | self.stride = len(self.alpha) // num_sample_steps 15 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep 16 | # make timestep -> alpha mapping explicit to avoid confusion with different sampling steps 17 | self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, self.alpha)} 18 | 19 | self.timesteps = self.timesteps[::self.stride] 20 | 21 | if reverse_sample: # generating samples (T) or training the model (F) 22 | self.timesteps = reversed(self.timesteps)[:-1] 23 | 24 | self.reverse_sample = reverse_sample 25 | 26 | def __call__(self, eps_theta, x, t, eta=0): 27 | 28 | t = t.item() 29 | tprev = t - self.stride if self.reverse_sample else t + self.stride 30 | 31 | alpha, alpha_prev = self.alpha[t], self.alpha[tprev] 32 | 33 | eps = torch.randn_like(x) 34 | 35 | # Eqn. 16 36 | sigma_tau = (eta * sqrt((1 - alpha_prev) / (1 - alpha)) * sqrt(1 - alpha / alpha_prev)) if eta > 0 else 0 37 | # sigma_tau interpolates between ddim and ddpm by assigning (0=ddim, 1=ddpm) 38 | # ddim (song et al.) is notationally simpler than ddpm (ho et al.) 39 | # (although ddim uses \alpha to refer to \bar{alpha} from ddpm) 40 | 41 | # Eqn. 12 42 | predicted_x0 = (x - sqrt(1 - alpha) * eps_theta) / sqrt(alpha) 43 | dp_xt = sqrt(1 - alpha_prev - sigma_tau ** 2) # direction pointing to xt 44 | x_prev = sqrt(alpha_prev) * predicted_x0 + dp_xt * eps_theta + sigma_tau * eps 45 | 46 | return x_prev 47 | -------------------------------------------------------------------------------- /samplers/ddpm.py: -------------------------------------------------------------------------------- 1 | from torch import sqrt 2 | import torch 3 | from . 
import utils
 4 | 
 5 | 
 6 | class DDPMSampler:
 7 |     # https://arxiv.org/abs/2006.11239
 8 | 
 9 |     def __init__(self, num_sample_steps=500, num_train_timesteps=1000, reverse_sample=True):
10 | 
11 |         beta = utils.get_beta_schedule(num_train_timesteps+1)
12 |         alpha = 1 - beta
13 |         alpha_bar = torch.cumprod(alpha, dim=0)
14 | 
15 |         self.stride = len(alpha) // num_sample_steps
16 |         self.timesteps = torch.arange(num_train_timesteps+1)  # stable diffusion accepts discrete timesteps
17 | 
18 |         # the linear (stable diffusion) beta schedule above is used throughout;
19 |         # see improved_ddpm.py for a cosine alternative
20 | 
21 |         # make timestep -> alpha/beta mapping explicit
22 |         self.beta = {t.item(): beta for t, beta in zip(self.timesteps, beta)}
23 |         self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, alpha)}
24 |         self.alpha_bar = {t.item(): alpha_bar for t, alpha_bar in zip(self.timesteps, alpha_bar)}
25 | 
26 |         self.timesteps = self.timesteps[::self.stride]
27 | 
28 |         if reverse_sample:  # generating samples (T) or training the model (F)
29 |             self.timesteps = reversed(self.timesteps)[:-1]
30 | 
31 |         self.reverse_sample = reverse_sample
32 | 
33 |     def __call__(self, eps_theta, x, t):
34 | 
35 |         t = t.item()
36 |         tprev = t - self.stride if self.reverse_sample else t + self.stride
37 | 
38 |         beta, alpha = self.beta[t], self.alpha[t]
39 |         alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[tprev]
40 | 
41 |         # two alternatives for \sigma, both equally effective according to the ddpm paper
42 |         sigma = sqrt(beta)
43 |         beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * beta  # eqn (7)
44 |         # sigma = sqrt(beta_tilde)
45 |         # algorithm 2
46 |         z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
47 |         x_prev = 1 / sqrt(alpha) * (x - (1 - alpha) / (1 - alpha_bar) * eps_theta) + sigma * z
48 | 
49 |         return x_prev
50 | 
51 | 
--------------------------------------------------------------------------------
/samplers/dpm_solver.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | 
 4 | 
 5 | class DPMSampler:
 6 |     # https://arxiv.org/abs/2206.00927
 7 | 
 8 |     def __init__(self, denoiser, num_sample_steps=20, num_train_timesteps=1000, schedule='linear'):
 9 | 
10 |         self.schedule = schedule
11 |         T = 1 if self.schedule == 'linear' else 0.9946
12 | 
13 |         # discrete to continuous - section 3.4, maps t in [0, 1] to a timestep in [0, N]
14 |         self.denoiser = denoiser  # lambda x, t, embd: denoiser(x, int(num_train_timesteps * t), embd).sample
15 | 
16 |         # "a simple way for choosing time steps for lambda is uniformly
17 |         # splitting [lambda_T, lambda_eps], which is the setting in our experiments"
18 |         (_, _, lmbd_T), (_, _, lmbd_eps) = self.lmbd(T), self.lmbd(1E-3)
19 |         lmbds = torch.linspace(lmbd_T, lmbd_eps, num_train_timesteps)
20 |         self.timesteps = torch.tensor([self.t_lambda(l, self.schedule) for l in lmbds])
21 | 
22 |         self.stride = num_train_timesteps // num_sample_steps
23 |         self.timesteps = self.timesteps[::self.stride]
24 |         self.num_train_timesteps = num_train_timesteps
25 | 
26 |     def eps_theta(self, x, t, embd, guidance_scale=None):
27 |         t = int(max(0, t - 1 / self.num_train_timesteps) * self.num_train_timesteps)
28 |         with torch.inference_mode():
29 |             noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), t, embd).sample
30 |             if guidance_scale is not None:
31 |                 noise_pred_uncond, noise_pred_text = noise.chunk(2)
32 |                 noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
33 |         return noise
34 | 
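    # orientation note: lambda(t) = log(alpha_t / sigma_t) is half the log-snr, and it is
    # monotone in t, so it admits the closed-form inverse t_lambda() below. dpm-solver
    # integrates the probability-flow ode exactly in lambda; only the eps-prediction term
    # is approximated, which is why ~10-20 steps suffice.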
35 |     @staticmethod
36 |     def _alpha(t, s=0.008):
37 |         return np.exp(np.log(np.cos(np.pi / 2 * (t + s) / (1 + s))) -
38 |                       np.log(np.cos(np.pi / 2 * s / (1 + s))))
39 | 
40 |     @staticmethod
41 |     def _lmbd(alpha, sigma):
42 |         return np.log(alpha) - np.log(sigma)
43 | 
44 |     @staticmethod
45 |     def _sigma(alpha):
46 |         # only used by the commented-out (numerically unstable) path in dpm_solver_2
47 |         return np.sqrt(1 - alpha ** 2)
48 | 
49 |     @staticmethod
50 |     def t_lambda(lmbd, schedule='linear'):
51 |         # appendix d.4: closed-form inverse of lambda(t)
52 |         if schedule == 'linear':
53 |             b0, b1 = 0.1, 20.
54 |             num = 2 * np.logaddexp(-2 * lmbd, 0)  # numerically stable log(e**x + e**y)
55 |             denom = np.sqrt(b0 ** 2 + 2 * (b1 - b0) * np.logaddexp(-2 * lmbd, 0)) + b0
56 |             return num / denom
57 |         elif schedule == 'cosine':
58 |             s = 0.008
59 |             f_lambda = -1 / 2 * np.logaddexp(-2 * lmbd, 0)
60 |             logcos = np.log(np.cos((np.pi * s) / (2 * (1 + s))))
61 |             return 2 * (1 + s) / np.pi * np.arccos(np.exp(f_lambda + logcos)) - s
62 | 
63 |     def lmbd(self, t):
64 |         log_alpha = self.log_alpha(t, self.schedule)
65 |         # log_sigma = 0.5 * np.log(1. - np.exp(2. * log_alpha))
66 |         sigma = np.sqrt(1. - np.exp(2 * log_alpha))
67 |         log_sigma = np.log(sigma)
68 |         return log_alpha, sigma, (log_alpha - log_sigma)
69 | 
70 |     def log_alpha(self, t, schedule):
71 |         if schedule == 'linear':
72 |             b0, b1 = 0.1, 20
73 |             return -(b1-b0)/4 * t ** 2 - b0/2*t
74 |         elif schedule == 'cosine':
75 |             s = 0.008
76 |             return np.log(np.cos(np.pi/2*(t+s)/(1+s))) - np.log(np.cos(np.pi/2*s/(1+s)))
77 | 
78 |     def dpm_solver_2(self, x_tilde_prev, t_prev, t, text_embeddings, guidance_scale=None, r1=0.5):
79 |         # algorithm 4
80 |         # numerically _not_ stable:
81 |         '''alpha_t, alpha_t_prev = self._alpha(t), self._alpha(t_prev)
82 |         sigma_t, sigma_t_prev = self._sigma(alpha_t), self._sigma(alpha_t_prev)
83 |         lmbd_t, lmbd_t_prev = self._lmbd(alpha_t, sigma_t), self._lmbd(alpha_t_prev, sigma_t_prev)'''
84 | 
85 |         # numerically stable
86 |         (log_alpha_t, sigma_t, lmbd_t), (log_alpha_t_prev, sigma_t_prev, lmbd_t_prev) = self.lmbd(t), self.lmbd(t_prev)
87 |         h = lmbd_t - lmbd_t_prev
88 |         s = self.t_lambda(lmbd_t_prev + r1 * h, self.schedule)
89 |         log_alpha_s, sigma_s, lmbd_s = self.lmbd(s)
90 | 
91 |         eps_prev = self.eps_theta(x_tilde_prev, t_prev, text_embeddings, guidance_scale)
92 |         u = np.exp(log_alpha_s - log_alpha_t_prev) * x_tilde_prev - sigma_s * np.expm1(r1 * h) * eps_prev
93 | 
94 |         x_tilde = (
95 |             + np.exp(log_alpha_t - log_alpha_t_prev) * x_tilde_prev
96 |             - sigma_t * np.expm1(h) * eps_prev
97 |             - sigma_t / (2 * r1) * np.expm1(h) * (self.eps_theta(u, s, text_embeddings, guidance_scale) - eps_prev)
98 |         )
99 |         return x_tilde
100 | 
101 |     def __call__(self, x, t, text_embeddings, guidance_scale=8.):
102 | 
103 |         if t == self.timesteps[-1]:
104 |             return x
105 | 
106 |         t_prev = t.item()
107 |         prev_index = torch.where(self.timesteps < t)[0][0]
108 |         t = self.timesteps[prev_index].item()
109 | 
110 |         return self.dpm_solver_2(x, t_prev, t, text_embeddings, guidance_scale)
111 | 
--------------------------------------------------------------------------------
/samplers/dpm_solver_plus_plus.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from . import utils
 4 | 
 5 | 
 6 | class DPMPlusPlusSampler:
 7 |     # https://arxiv.org/abs/2211.01095
 8 | 
 9 |     def __init__(self, denoiser, num_sample_steps=20, num_train_timesteps=1000):
10 | 
11 |         self.timesteps = np.arange(1, num_train_timesteps+1)
12 |         self.t = self.timesteps / num_train_timesteps
13 | 
14 |         beta = 
utils.get_beta_schedule(num_train_timesteps+1) 15 | self.alpha = {t: v for t, v in zip(self.t, torch.cumprod(1 - beta, dim=0).sqrt())} 16 | 17 | self.q = dict() 18 | 19 | self.denoiser = denoiser 20 | 21 | self.stride = int(num_train_timesteps / num_sample_steps) 22 | self.timesteps = self.timesteps[::self.stride][::-1] 23 | self.t = self.t[::-1] 24 | self.num_train_timesteps = num_train_timesteps 25 | 26 | def x_theta(self, x, t, embd, guidance_scale=None): 27 | # data prediction model 28 | alpha_t, sigma_t, _ = self.get_coeffs(t) 29 | timestep = int(max(0, t) * self.num_train_timesteps) 30 | with torch.inference_mode(): 31 | noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), timestep, embd).sample 32 | if guidance_scale is not None: 33 | noise_pred_uncond, noise_pred_text = noise.chunk(2) 34 | noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 35 | x_pred = (x - sigma_t * noise) / alpha_t # Parameterization: noise prediction and data prediction 36 | # print(t, timestep, x_pred.max(), noise.max(), sigma_t, alpha_t) 37 | return x_pred 38 | 39 | def get_coeffs(self, t): 40 | # appendix b.1 for alphas, sigmas, lambdas 41 | '''# below is used for continuous t, uncomment & modify for your needs if you want to 42 | t_next = t + 1 / self.num_train_timesteps 43 | t_n, t_n_next = sampler.t[np.digitize(t, sampler.t)], sampler.t[np.digitize(t+1/self.num_train_timesteps, sampler.t)] 44 | t_n_next = t_n + self.stride 45 | log_alpha_n, log_alpha_n_next = self.alpha[t].log(), self.alpha[t_next].log() 46 | log_alpha_t = log_alpha_n + (log_alpha_n_next-log_alpha_n)/(t_n_next-t_n) * (t-t_n)''' 47 | log_alpha_t = self.alpha[t].log() # comment this out if above is uncommented 48 | alpha_t = log_alpha_t.exp() 49 | sigma_t = (1-alpha_t ** 2).sqrt() 50 | log_sigma_t = sigma_t.log() 51 | lambda_t = log_alpha_t - log_sigma_t 52 | return alpha_t, sigma_t, lambda_t 53 | 54 | def dpm_solver_plus_plus_2m(self, x_tilde_prev, t_prev_prev, t_prev, t, text_embeddings, guidance_scale=None): 55 | # algorithm 2 56 | alpha_t, sigma_t, lambda_t = self.get_coeffs(t) 57 | _, sigma_t_prev, lambda_t_prev = self.get_coeffs(t_prev) 58 | _, _, lambda_t_prev_prev = self.get_coeffs(t_prev_prev) 59 | h = lambda_t - lambda_t_prev 60 | h_prev = lambda_t_prev - lambda_t_prev_prev 61 | 62 | r = h_prev / h 63 | D = (1 + 1 / (2 * r)) * self.q[t_prev] - 1 / (2 * r) * self.q[t_prev_prev] 64 | x_tilde = sigma_t / sigma_t_prev * x_tilde_prev - alpha_t * torch.expm1(-h) * D 65 | self.q[t] = self.x_theta(x_tilde, t, text_embeddings, guidance_scale) 66 | 67 | return x_tilde 68 | 69 | def __call__(self, x, timestep_index, text_embeddings, guidance_scale=8.): 70 | 71 | # the notation used in this paper is really confusing 72 | # gaussian noise input gets an index of M which maps 73 | # to t_0, i.e., reversed. so 74 | # t0->t1->t2 is M->M-1->M-2. 
in other words t_prev > t 75 | 76 | t = timestep_index.item() / self.num_train_timesteps 77 | t_prev = (timestep_index.item() + self.stride) / self.num_train_timesteps 78 | t_prev_prev = (timestep_index.item() + 2 * self.stride) / self.num_train_timesteps 79 | 80 | if timestep_index == self.timesteps[0]: 81 | # first step (xt0) 82 | self.q[t] = self.x_theta(x, t, text_embeddings, guidance_scale) 83 | return x # typo in text, x0 in algorithm 2 84 | 85 | elif timestep_index == self.timesteps[1]: 86 | # second step (xt1) 87 | alpha_t, sigma_t, lambda_t = self.get_coeffs(t) 88 | _, sigma_t_prev, lambda_t_prev = self.get_coeffs(t_prev) 89 | h = lambda_t - lambda_t_prev 90 | x_tilde = sigma_t / sigma_t_prev * x - alpha_t * torch.expm1(-h) * self.q[t_prev] 91 | self.q[t] = self.x_theta(x_tilde, t, text_embeddings, guidance_scale) 92 | return x_tilde 93 | 94 | # 2 to M loop 95 | return self.dpm_solver_plus_plus_2m(x, t_prev_prev, t_prev, t, text_embeddings, guidance_scale) 96 | -------------------------------------------------------------------------------- /samplers/exponential_integrator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ExponentialSampler: 6 | # https://arxiv.org/abs/2204.13902 7 | 8 | def __init__( 9 | self, 10 | denoiser, 11 | num_sample_steps=50, 12 | num_train_timesteps=1000, 13 | r=3, 14 | ): 15 | self.r = r 16 | self.denoiser = denoiser 17 | 18 | self.num_train_timesteps, self.num_sample_steps = num_train_timesteps, num_sample_steps 19 | 20 | self.stride = num_train_timesteps // num_sample_steps 21 | self.timesteps = reversed(torch.arange(1, num_sample_steps - 1)) # see algo 1, timestep ends at 1 not zero 22 | 23 | self.t = np.linspace(1E-3, 1, self.num_sample_steps+r) 24 | self.c = self.calc_c() # calculate \mathbf{C} based on Eq. 15, also see table 1 25 | self.eps_buffer = dict() 26 | 27 | @staticmethod 28 | def log_alpha(t, schedule='linear'): 29 | if schedule == 'linear': 30 | b0, b1 = 0.1, 20 31 | return -(b1 - b0) / 4 * t ** 2 - b0 / 2 * t 32 | elif schedule == 'cosine': 33 | s = 0.008 34 | return np.log(np.cos(np.pi / 2 * (t + s) / (1 + s))) - np.log(np.cos(np.pi / 2 * (s) / (1 + s))) 35 | 36 | def calc_c(self): 37 | 38 | from scipy.interpolate import interp1d 39 | log_alpha_fn = interp1d( 40 | np.linspace(0, 1+0.01, self.num_train_timesteps+1), 41 | self.log_alpha(np.linspace(0, 1, self.num_train_timesteps+1), schedule='linear')) 42 | 43 | from scipy.misc import derivative 44 | G = lambda tau: np.sqrt(-derivative(log_alpha_fn, tau, dx=1e-5)) 45 | L = lambda tau: np.sqrt(1-np.exp(log_alpha_fn(tau))) 46 | 47 | def prod_fn(tau, i, j, r): 48 | prod = 1. 49 | for k in range(r+1): 50 | prod *= (tau - self.t[i + k]) / (self.t[i + j] - self.t[i + k]) if k != j else 1 51 | return prod 52 | 53 | def get_cij(tau, i, j, r): 54 | # Eq. 15. 
ignoring inverse and transpose since G, L are 1d
55 |         return 1 / 2 * self.psi(self.t[i - 1], tau) * G(tau) ** 2 * 1 / L(tau) * prod_fn(tau, i, j, r)
56 | 
57 |         cij = dict()
58 |         from scipy import integrate
59 |         for i in range(self.num_sample_steps):
60 |             for j in range(self.r+1):
61 |                 cij[i, j] = integrate.quad(get_cij, self.t[i], self.t[i - 1], args=(i, j, self.r), epsrel=1e-4)[0]
62 | 
63 |         return cij
64 | 
65 |     def psi(self, t, s):
66 |         return np.sqrt(np.exp((self.log_alpha(t) - self.log_alpha(s))))  # appendix d, PROOF OF PROP 2
67 | 
68 |     def eps_theta(self, x, i, embd, guidance_scale=None):
69 |         with torch.inference_mode():
70 |             noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), i, embd).sample
71 |             if guidance_scale is not None:
72 |                 noise_pred_uncond, noise_pred_text = noise.chunk(2)
73 |                 noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
74 |         return noise
75 | 
76 |     def __call__(self, x_hat, i, text_embeddings, guidance_scale=8.):
77 |         # algorithm 1.
78 |         i = int(i)
79 |         self.eps_buffer[self.t[i]] = self.eps_theta(x_hat, i*self.stride, text_embeddings, guidance_scale)
80 |         x_hat_prev = self.psi(self.t[i-1], self.t[i]) * x_hat + sum(
81 |             [self.c[i, j] * self.eps_buffer[self.t[i+j]] for j in range(self.r+1) if self.t[i+j] in self.eps_buffer]
82 |         )  # Eq. 14
83 |         return x_hat_prev
--------------------------------------------------------------------------------
/samplers/heun.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from math import sqrt as msqrt
 3 | 
 4 | 
 5 | class HeunSampler:
 6 |     # https://arxiv.org/abs/2206.00364
 7 | 
 8 |     def __init__(
 9 |         self,
10 |         num_sample_steps=50,
11 |         num_train_timesteps=1000,
12 |         denoiser=None,
13 |         ddpm=True,
14 |         alpha_bar=None,
15 |     ):
16 |         self.denoiser = denoiser
17 | 
18 |         self.num_train_timesteps, self.num_sample_steps = num_train_timesteps, num_sample_steps
19 | 
20 |         self.ddpm_sigmas = ((1-alpha_bar)/alpha_bar).sqrt()
21 |         self.t0 = self.ddpm_sigmas[-1]  # or sigma_max/max noise
22 | 
23 |         self.stride = len(alpha_bar) // num_sample_steps
24 |         self.timesteps = torch.arange(num_train_timesteps+1)
25 | 
26 |         # table 1
27 |         sigma_min, sigma_max, rho = self.ddpm_sigmas[0], self.ddpm_sigmas[-1], 7  # stable diff.
28 |         _heun_sigmas = (sigma_max ** (1 / rho) + (self.timesteps / self.num_sample_steps) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
29 |         self.heun_sigmas = _heun_sigmas
30 |         self.sigma_data = 0.5
31 |         self.timesteps = torch.arange(num_sample_steps - 1)
32 | 
33 |         self.ddpm = ddpm  # model trained with ddpm objective, used to set c_in etc
34 | 
35 |     def c_skip(self, sigma):
36 |         if self.ddpm:
37 |             return torch.tensor(1.)
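        # below (and in the other c_* fall-throughs): the general edm preconditioning of
        # karras et al., table 1; the self.ddpm branches specialize it to eps-prediction
        # models such as stable diffusion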
38 |         return self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
39 | 
40 |     def c_out(self, sigma):
41 |         if self.ddpm:
42 |             return -sigma
43 |         return sigma * self.sigma_data * (self.sigma_data**2 + sigma**2).rsqrt()
44 | 
45 |     def c_in(self, sigma):
46 |         if self.ddpm:
47 |             return (1 + sigma ** 2).rsqrt()
48 |         return (sigma ** 2 + self.sigma_data ** 2).rsqrt()
49 | 
50 |     def c_noise(self, sigma):
51 |         if self.ddpm:
52 |             return torch.abs(self.ddpm_sigmas-sigma).argmin()  # iddpm practical (c.3.4: 3rd bullet)
53 |         return 1 / 4 * sigma.log()
54 | 
55 |     def d_theta(self, x, sigma, encoder_hidden_states, guidance_scale=None):
56 | 
57 |         eps_theta = self.predict_noise(
58 |             x if guidance_scale is None else torch.cat([x] * 2), sigma, encoder_hidden_states)
59 |         # hacky bit: if guidance_scale is not None, perform classifier-free guidance;
60 |         # deviates from the original implementation to minimize code
61 |         # duplication, even though it's not good practice
62 |         if guidance_scale is not None:
63 |             noise_pred_uncond, noise_pred_text = eps_theta.chunk(2)
64 |             eps_theta = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
65 | 
66 |         return self.c_skip(sigma) * x + self.c_out(sigma) * eps_theta
67 | 
68 |     def f_theta(self, x, sigma, encoder_hidden_states):
69 |         with torch.inference_mode():
70 |             return self.denoiser(x, sigma, encoder_hidden_states).sample
71 | 
72 |     def predict_noise(self, x, sigma, encoder_hidden_states):
73 |         # helper function to predict noise for stable diffusion;
74 |         # since heun modifies both the input and the stdev, using the
75 |         # sd pipeline directly does not yield the correct f_theta output
76 |         return self.f_theta(self.c_in(sigma) * x, self.c_noise(sigma), encoder_hidden_states)
77 | 
78 |     def stochastic_sampler(self, x, i, encoder_hidden_states, guidance_scale=None, s_churn=80, s_tmin=0.05, s_tmax=50, s_noise=1.003):
79 |         # Algorithm 2 - arg defaults from table 5
80 |         x_next, x_hat, t_next, t_hat, d = self.euler_method(
81 |             x, i, encoder_hidden_states, guidance_scale, s_churn, s_tmin, s_tmax, s_noise)  # first order denoising
82 |         x_next = self.heuns_correction(x_next, x_hat, t_next, t_hat, d, encoder_hidden_states, guidance_scale)  # second order correction
83 |         return x_next
84 | 
85 |     def euler_method(self, x, i, encoder_hidden_states=None, guidance_scale=None, s_churn=80, s_tmin=0.05, s_tmax=50, s_noise=1.003):
86 |         # euler method, omitting second order correction @ lines 9-11 (algorithm 2)
87 |         t, t_next = self.heun_sigmas[i], self.heun_sigmas[i+1]
88 | 
89 |         eps = torch.randn_like(x) * s_noise
90 |         gamma = min(s_churn / self.num_sample_steps, msqrt(2) - 1) if s_tmin <= t <= s_tmax else 0
91 |         t_hat = t + gamma * t
92 |         if self.ddpm:
93 |             t_hat = self.ddpm_sigmas[torch.abs(self.ddpm_sigmas-t_hat).argmin()]  # iddpm practical, step 3
94 |         x_hat = x + msqrt(max(t_hat ** 2 - t ** 2, 0)) * eps  # gamma < 0 -> negative sqrt
95 |         d = (x_hat - self.d_theta(x_hat, t_hat, encoder_hidden_states, guidance_scale)) / t_hat
96 |         x_next = x_hat + (t_next - t_hat) * d
97 | 
98 |         return x_next, x_hat, t_next, t_hat, d
99 | 
100 |     def heuns_correction(self, x_next, x_hat, t_next, t_hat, d, encoder_hidden_states, guidance_scale=None):
101 |         # the correction improves the discretization error from o(h^2) -> o(h^3), h = step size;
102 |         # it accounts for varying step sizes between two different timesteps,
103 |         # see "discretization and higher-order integrators" in the paper
104 |         if t_next != 0:
105 |             d_prime = (x_next - self.d_theta(x_next, t_next, encoder_hidden_states, guidance_scale)) / t_next
106 |             x_next = x_hat + (t_next - t_hat)
* (1 / 2 * d + 1 / 2 * d_prime) 107 | return x_next 108 | 109 | def alpha_sampler(self, x, i, encoder_hidden_states, guidance_scale=None, alpha=1): 110 | # Algorithm 3 111 | t, t_next = self.heun_sigmas[i], self.heun_sigmas[i+1] 112 | h = t_next - t 113 | d = (x - self.d_theta(x, t, encoder_hidden_states, guidance_scale)) / t 114 | x_prime, t_prime = x + alpha * h * d, t + alpha * h 115 | if t_prime != 0: 116 | d_prime = (x_prime - self.d_theta(x_prime, t_prime, encoder_hidden_states, guidance_scale)) / t_prime 117 | x_next = x + h * ((1 - 0.5 / alpha) * d + 0.5 / alpha * d_prime) 118 | else: 119 | x_next = x + h*d 120 | 121 | return x_next 122 | 123 | def __call__(self, *args): 124 | return self.stochastic_sampler(*args) 125 | -------------------------------------------------------------------------------- /samplers/improved_ddpm.py: -------------------------------------------------------------------------------- 1 | from torch import sqrt 2 | import torch 3 | from . import utils 4 | 5 | 6 | class ImprovedDDPMSampler: 7 | # https://arxiv.org/abs/2102.09672 8 | 9 | def __init__(self, num_sample_steps=500, num_train_timesteps=1000, reverse_sample=True): 10 | 11 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep 12 | 13 | T, s = num_train_timesteps+1, 0.008 14 | alpha_bar = torch.cos((self.timesteps/T + s)/(1 + s) * torch.pi/2) ** 2 # (17), f(0) ~ 0.99999 so plug f(t)=\bar{\alpha} 15 | alpha_bar_prev = torch.hstack([alpha_bar[0], alpha_bar[:-1]]) 16 | beta = 1 - alpha_bar / alpha_bar_prev 17 | alpha = 1 - beta 18 | beta = torch.clamp(beta, -torch.inf, 0.999) 19 | 20 | self.stride = len(alpha) // num_sample_steps 21 | 22 | # make timestep -> alpha/beta mapping explicit 23 | self.beta = {t.item(): beta for t, beta in zip(self.timesteps, beta)} 24 | self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, alpha)} 25 | self.alpha_bar = {t.item(): alpha_bar for t, alpha_bar in zip(self.timesteps, alpha_bar)} 26 | 27 | self.timesteps = self.timesteps[::self.stride] 28 | 29 | if reverse_sample: # generating samples (T) or training the model (F) 30 | self.timesteps = reversed(self.timesteps)[:-1] 31 | 32 | self.reverse_sample = reverse_sample 33 | 34 | def __call__(self, eps_theta, x, t): 35 | 36 | t = t.item() 37 | tprev = t - self.stride if self.reverse_sample else t + self.stride 38 | 39 | beta, alpha = self.beta[t], self.alpha[t] 40 | alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[tprev] 41 | 42 | # eqn 3 43 | mu = alpha.rsqrt() * (x - beta * (1 - alpha_bar).rsqrt() * eps_theta) # eq (13) reciprocal sqrt should be more numerically stable 44 | # variance, should be learned but i'm just experimenting w/ plugging (15) directly w/ v=1/2 45 | beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * beta # eqn (7) 46 | v = 0.5 47 | sigma_theta = torch.exp(v*beta.log() + (1-v)*beta_tilde.log()).sqrt() # (15) 48 | 49 | z = torch.randn_like(x) if t > 1 else torch.zeros_like(x) 50 | x_prev = mu + sigma_theta * z 51 | 52 | return x_prev 53 | -------------------------------------------------------------------------------- /samplers/pndm.py: -------------------------------------------------------------------------------- 1 | from torch import sqrt 2 | import torch 3 | from . 
import utils 4 | 5 | 6 | class PNDMSampler: 7 | # https://arxiv.org/abs/2202.09778 8 | 9 | def __init__(self, num_sample_steps=50, num_train_timesteps=1000, reverse_sample=True, ddim=False): 10 | 11 | beta = utils.get_beta_schedule(num_train_timesteps+1) 12 | alpha = 1 - beta 13 | self.alpha_bar = torch.cumprod(alpha, dim=0) 14 | 15 | self.stride = len(self.alpha_bar) // num_sample_steps 16 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep 17 | self.alpha_bar = {t.item(): alpha for t, alpha in zip(self.timesteps, self.alpha_bar)} 18 | 19 | self.timesteps = self.timesteps[::self.stride] 20 | 21 | if reverse_sample: # generating samples (T) or training the model (F) 22 | self.timesteps = reversed(self.timesteps)[:-1] 23 | 24 | self.reverse_sample = reverse_sample 25 | 26 | self.ddim = ddim # alg. 1 from the paper 27 | 28 | self.et = [] 29 | 30 | def prk(self, xt, et, t, tpd): 31 | ''' 32 | # reuse et 4 times (from the paper): 33 | we find that the linear multi-step method can reuse 34 | the result of \eps four times and only compute \eps 35 | once at every step. 36 | --- 37 | so we can just return e_t for each \eps_{\theta}(.,.) call 38 | ''' 39 | def eps_theta(*args): 40 | return et 41 | delta = tpd - t 42 | tpd2 = t + delta/2 # t plus d/2 43 | # Eqn 13 44 | et1 = eps_theta(xt, t) 45 | xt1 = self.phi(xt, et1, t, tpd2) 46 | et2 = eps_theta(xt1, tpd2) 47 | xt2 = self.phi(xt, et2, t, tpd2) 48 | et3 = eps_theta(xt2, tpd2) 49 | xt3 = self.phi(xt, et3, t, tpd) # paper has a typo here 50 | et4 = eps_theta(xt3, tpd) 51 | et_d = 1/6 * (et1 + 2*et2 + 2*et3 + et4) # e_t' 52 | xtpd = self.phi(xt, et_d, t, tpd) # another typo, this was defined as xtmd 53 | return xtpd, et1 54 | 55 | def plms(self, xt, et, t, tpd): 56 | # Eqn 12 57 | etm3d, etm2d, etmd = self.et[-3:] 58 | et_d = 1 / 24 * (55 * et - 59 * etmd + 37 * etm2d - 9 * etm3d) # e_t' 59 | xtpd = self.phi(xt, et_d, t, tpd) 60 | return xtpd, et 61 | 62 | def add_to_error_queue(self, item): 63 | self.et.append(item) 64 | if len(self.et) > 3: # no need to store > 3 e_t (see eqn 12) 65 | del self.et[0] 66 | 67 | def phi(self, xt, et, t, t_prev): 68 | # i.e., the transfer part (eqn 11), t_prev = t-\delta 69 | alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[t_prev] 70 | 71 | transfer = sqrt(alpha_bar_prev)/sqrt(alpha_bar) * xt # first term in the eqn. 
72 |         # second term of the eqn, split into numerator & denominator for clarity:
73 |         num = alpha_bar_prev - alpha_bar
74 |         denom = sqrt(alpha_bar) * (sqrt((1-alpha_bar_prev)*alpha_bar) + sqrt((1-alpha_bar)*alpha_bar_prev))
75 |         transfer = transfer - num/denom*et
76 | 
77 |         return transfer
78 | 
79 |     def __call__(self, eps_theta, xtp1, t):
80 |         t = t.item()
81 |         tpd = t - self.stride  # \delta = -1 in algorithm 2 (pndms) by inspection
82 |         if self.ddim:  # algorithm 1
83 |             return self.phi(xtp1, eps_theta, t, tpd)
84 |         else:  # algorithm 2
85 |             if len(self.et) < 3:
86 |                 xtpd, et = self.prk(xtp1, eps_theta, t, tpd)
87 |             else:
88 |                 xtpd, et = self.plms(xtp1, eps_theta, t, tpd)
89 |             self.add_to_error_queue(et)
90 | 
91 |             return xtpd  # note that xtpd = x_{t-1} since \delta = -1
--------------------------------------------------------------------------------
/samplers/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def get_beta_schedule(num_diffusion_timesteps=1000, beta_start=0.00085, beta_end=0.012):
 5 |     # "scaled linear" beta schedule; start and end values are from the stable diffusion config
 6 |     betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps) ** 2
 7 |     return betas
 8 | 
--------------------------------------------------------------------------------
/soft_diffusion/README.md:
--------------------------------------------------------------------------------
 1 | # Soft Diffusion: Score Matching for General Corruptions
 2 | 
 3 | [arxiv link to paper](https://arxiv.org/abs/2209.05442)
 4 | 
 5 | <img src="soft_diffusion.jpg">
 6 | 
 7 | sketch of google's new paper, which achieves state-of-the-art results on the celeba-64 dataset.
 8 | 
 9 | this implementation is missing some details & involves some guesswork,
10 | because the paper did not define certain variables and implementation details
11 | in its current version (v2).
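for reference, the geometric noise schedule assumed in [lightning.py](lightning.py) (defaults sigma_min=0.01, sigma_max=0.10, reaching max noise at ~25% of the trajectory, fig. 6) can be sanity-checked in a few lines:

```python
import numpy as np

sigma_min, sigma_max, t_clip = 0.01, 0.10, 0.25
r = np.exp(np.log(sigma_max / sigma_min) / t_clip) - 1  # growth rate: y0 * (1 + r) ** t = yt
sigma = lambda t: sigma_min * (1 + r) ** t
print(sigma(0.), sigma(t_clip))  # 0.01, 0.1 -> max noise level is hit at t = 0.25
```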
--------------------------------------------------------------------------------
/soft_diffusion/graph_wasserstein.py:
--------------------------------------------------------------------------------
 1 | '''
 2 | graph based selection of gaussian blur levels
 3 | plug the resulting sigmas into the training script
 4 | '''
 5 | 
 6 | from scipy.sparse import csr_matrix
 7 | from scipy.sparse.csgraph import shortest_path
 8 | from soft_diffusion.wasserstein import calculate_2_wasserstein_dist
 9 | import torch
10 | import numpy as np
11 | import torchvision.transforms as tt
12 | import torchvision.transforms.functional as ttf
13 | from torch.utils.data import DataLoader
14 | import torchvision
15 | 
16 | 
17 | def get_shortest_path(Pr, start_node, end_node):
18 |     # walk the predecessor matrix back from end_node, https://stackoverflow.com/a/53078901/1696420
19 |     path = [end_node]
20 |     k = end_node
21 |     while Pr[start_node, k] != -9999:
22 |         path.append(Pr[start_node, k])
23 |         k = Pr[start_node, k]
24 |     return path[::-1]
25 | 
26 | 
27 | if __name__ == '__main__':
28 | 
29 |     gbs_min, gbs_max = 0.01, 23  # gaussian blur sigma - appendix D, blur
30 |     T = 256
31 |     sigma_blurs = np.linspace(gbs_min, gbs_max, T)
32 | 
33 |     distributions = {sigma: list() for sigma in sigma_blurs}
34 | 
35 |     data_root = ''
36 |     data_transform = tt.Compose([
37 |         tt.Resize((64, 64)),
38 |         tt.ToTensor(),
39 |         tt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
40 |     ])
41 |     dataset = torchvision.datasets.CelebA(data_root, split='train', target_type='attr', transform=data_transform, download=True)
42 |     loader = DataLoader(dataset, batch_size=64, shuffle=True)
43 |     for batch in loader:
44 |         for sigma in sigma_blurs:
45 |             # each node's distribution is the dataset corrupted at that blur level
46 |             # (the kernel size is a guess; the paper does not specify one)
47 |             distributions[sigma].append(ttf.gaussian_blur(batch[0], kernel_size=33, sigma=float(sigma)))
48 |         break  # one large batch per level keeps this tractable; remove to use the full split
49 | 
50 |     # compute wasserstein distances between each edge
51 |     # ``we start with 256 different blur levels and we tune \eps such that
52 |     # the shortest path contains 32 distributions.
53 |     # we then use linear interpolation to extend to the continuous case."
54 |     eps_dist_max = 1E-3  # dataset specific, eq 19, must set! this is arbitrary
55 | 
56 |     wasserstein_dist = torch.empty((T, T))
57 |     for index, sigma in enumerate(sigma_blurs):
58 |         for next_index, next_sigma in enumerate(sigma_blurs):
59 |             if index == next_index:
60 |                 wasserstein_dist[index, next_index] = torch.inf
61 |             else:
62 |                 # flatten image batches to the [b, n] shape calculate_2_wasserstein_dist expects
63 |                 wdist = calculate_2_wasserstein_dist(
64 |                     torch.cat(distributions[sigma]).flatten(start_dim=1),
65 |                     torch.cat(distributions[next_sigma]).flatten(start_dim=1))
66 |                 if wdist > eps_dist_max:
67 |                     wdist = torch.inf
68 |                 wasserstein_dist[index, next_index] = wdist
69 | 
70 |     wasserstein_dist = csr_matrix(wasserstein_dist.cpu().tolist())
71 |     D, Pr = shortest_path(wasserstein_dist, directed=False, method='FW', return_predecessors=True)
72 |     # get the list of sigmas giving the shortest path
73 |     sigma_index = get_shortest_path(Pr, start_node=0, end_node=T-1)
74 |     sigma_blurs = [sigma for j, sigma in enumerate(sigma_blurs) if j in sigma_index]
75 |     assert len(sigma_blurs) == 32, 'adjust eps_dist_max in eqn 19 until you get 32 nodes/sigmas'
76 | 
77 |     print(sigma_blurs)
--------------------------------------------------------------------------------
/soft_diffusion/lightning.py:
--------------------------------------------------------------------------------
 1 | from torch.nn import functional as F
 2 | from inspect import getfullargspec
 3 | import pytorch_lightning as pl
 4 | from torch import optim
 5 | import torch
 6 | import numpy as np
 7 | from functools import partial
 8 | import torchvision.transforms.functional as ttf
 9 | from scipy.interpolate import interp1d
10 | import k_diffusion as K  # using crowsonkb's k-diffusion (github.com/crowsonkb/k-diffusion)
11 | 
12 | 
13 | def blur(discrete_sigmas, half_kernel=80, image_size=64):
14 |     # gaussian blur, in place of the C matrix
15 |     pad_val = int(np.ceil(((half_kernel * 2 + 1)/2 - image_size) / 2 + 1))
16 |     padder = partial(torch.nn.functional.pad, pad=(pad_val, pad_val, pad_val, pad_val), mode='reflect')
17 | 
18 |     discretize_fn = interp1d(np.linspace(0, 1, len(discrete_sigmas)), discrete_sigmas)
19 | 
20 |     def fn(image, sigma):
21 |         sigma = discretize_fn(sigma.item())
22 |         image = ttf.gaussian_blur(padder(image), kernel_size=[half_kernel * 2 + 1, half_kernel * 2 + 1], sigma=sigma)
23 |         image = ttf.center_crop(image, [image_size, image_size])
24 |         return image
25 |     return fn
26 | 
27 | 
28 | def get_sigmas(sigma_min=0.01, sigma_max=0.10, t_clip=0.25):
29 |     # geometric scheduling
30 |     # roughly at 25%, max noise level, see fig 6
31 |     r = np.exp(np.log(sigma_max / sigma_min) / t_clip) - 1  # y0*(1+r)**t = yt
32 | 
33 |     def fn(t):
34 |         return sigma_min * (1 + r) ** t
35 |     return fn
36 | 
37 | 
38 | class LightningDiffusion(pl.LightningModule):
39 |     def __init__(self, config_path, discrete_sigmas):
40 |         # config_path is a json file, e.g.: https://github.com/crowsonkb/k-diffusion/blob/master/configs/config_32x32_small.json
41 |         # may need to update arguments depending on resolution, e.g.,
42 |         # "depths": [2, 2, 4, 4],
43 |         # "channels": [128, 256, 256, 512],
44 |         # "self_attn_depths": [false, false, true, true],
45 | 
46 |         # see graph_wasserstein.py to get discrete sigmas for your dataset
47 |         super().__init__()
48 | 
49 |         # diffusion model
50 |         config = K.config.load_config(open(config_path))
51 |         self.phi_theta = K.config.make_model(config)
52 |         self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
53 | 
54 |         self.sigma = get_sigmas()
55 |         self.dt = 0.001  # \delta_t
56 |         self.C = blur(discrete_sigmas)
57 |         # gaussian blur C is a linear operator; the paper's fixed blur is used here -
58 |         # ideally this would be a matrix (C_t * x)
59 | 
60 |     def training_step(self, batch, batch_idx):
61 |         x0 = batch[0]
62 |         # a single timestep shared across the batch, since the blur operator C takes a scalar t
63 |         t = self.rng.draw(1)[0, 0].to(x0.device)  # sample timestep
64 |         residual = self.sigma(t) * torch.randn_like(x0)
65 |         x = self.C(x0, t) + residual  # eq 4
66 |         noise_pred = self.phi_theta(x, t)  # (r_hat | x_t)
67 |         loss = (self.sigma(t) ** -2) * F.mse_loss(self.C(residual, t), self.C(noise_pred, t), reduction='none')  # eq 12
68 |         return loss.mean()
69 | 
70 |     def configure_optimizers(self):
71 |         opt = optim.Adam(self.phi_theta.parameters(), lr=2E-4, betas=(0.9, 0.999), eps=1E-8)
72 | 
73 |         def lr_sched(step):
74 |             return min(1, step / 5_000)  # depending on batch size, this should change
75 |         sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_sched)
76 | 
77 |         opt_config = {"optimizer": opt}
78 |         if sched is not None:
79 |             opt_config["lr_scheduler"] = {
80 |                 "scheduler": sched,
81 |                 "interval": "step",
82 |             }
83 |         return opt_config
84 | 
85 |     def naive_sampler(self):
86 |         x = self.sample_gauss()
87 |         for t in torch.linspace(1, 0, int(1/self.dt)):
88 |             x0_hat = self.phi_theta(x, t) + x
89 |             eta = torch.randn_like(x)
90 |             x = self.C(x0_hat, t - self.dt) + self.sigma(t - self.dt) * eta
91 |         return x
92 | 
93 |     def momentum_sampler(self):
94 |         x = self.sample_gauss()
95 |         for t in torch.linspace(1, 0, int(1/self.dt)):
96 |             x0_hat = self.phi_theta(x, t) + x
97 |             y_hat = self.C(x0_hat, t)
98 |             eta = torch.randn_like(x)
99 |             eps_hat = y_hat - x
100 |             z = x - ((self.sigma(t - self.dt) / self.sigma(t)) ** 2 - 1) * eps_hat + (self.sigma(t) ** 2 - self.sigma(t - self.dt) ** 2).sqrt() * eta
101 |             y_hat_prev = self.C(x0_hat, t - self.dt)
102 |             x = z + y_hat_prev - y_hat
103 |         return x
104 | 
105 |     def sample_gauss(self):
106 |         # distribution p_1 is not defined in the paper; this is my best guess
107 |         x = torch.rand((1, 3, 64, 64))
108 |         x = self.C(x, torch.tensor(1.))  # t=1 -> maximum blur level (sigma_blur = gbs_max)
109 |         return x
110 | 
--------------------------------------------------------------------------------
/soft_diffusion/wasserstein.py:
--------------------------------------------------------------------------------
 1 | # https://gist.github.com/Flunzmas/6e359b118b0730ab403753dcc2a447df#file-calc_2_wasserstein_dist-py
 2 | import math
 3 | import torch
 4 | import torch.linalg as linalg
 5 | 
 6 | 
 7 | def calculate_2_wasserstein_dist(X, Y):
 8 |     '''
 9 |     Calculates the two components of the 2-Wasserstein metric:
10 |     The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]
11 | 
12 |     For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),
13 |     this reduces to: d = |mu_X - mu_Y|^2 + Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
14 | 
15 |     Fast method implemented according to the following paper: https://arxiv.org/pdf/2009.14075.pdf
16 | 
17 |     Input shape: [b, n] (e.g. 
batch_size x num_features) 18 | Output shape: scalar 19 | ''' 20 | 21 | if X.shape != Y.shape: 22 | raise ValueError("Expecting equal shapes for X and Y!") 23 | 24 | # the linear algebra ops will need some extra precision -> convert to double 25 | X, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double() # [n, b] 26 | mu_X, mu_Y = torch.mean(X, dim=1, keepdim=True), torch.mean(Y, dim=1, keepdim=True) # [n, 1] 27 | n, b = X.shape 28 | fact = 1.0 if b < 2 else 1.0 / (b - 1) 29 | 30 | # Cov. Matrix 31 | E_X = X - mu_X 32 | E_Y = Y - mu_Y 33 | cov_X = torch.matmul(E_X, E_X.t()) * fact # [n, n] 34 | cov_Y = torch.matmul(E_Y, E_Y.t()) * fact 35 | 36 | # calculate Tr((cov_X * cov_Y)^(1/2)). with the method proposed in https://arxiv.org/pdf/2009.14075.pdf 37 | # The eigenvalues for M are real-valued. 38 | C_X = E_X * math.sqrt(fact) # [n, n], "root" of covariance 39 | C_Y = E_Y * math.sqrt(fact) 40 | M_l = torch.matmul(C_X.t(), C_Y) 41 | M_r = torch.matmul(C_Y.t(), C_X) 42 | M = torch.matmul(M_l, M_r) 43 | S = linalg.eigvals(M) + 1e-15 # add small constant to avoid infinite gradients from sqrt(0) 44 | sq_tr_cov = S.sqrt().abs().sum() 45 | 46 | # plug the sqrt_trace_component into Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2)) 47 | trace_term = torch.trace(cov_X + cov_Y) - 2.0 * sq_tr_cov # scalar 48 | 49 | # |mu_X - mu_Y|^2 50 | diff = mu_X - mu_Y # [n, 1] 51 | mean_term = torch.sum(torch.mul(diff, diff)) # scalar 52 | 53 | # put it together 54 | return (trace_term + mean_term).float() --------------------------------------------------------------------------------