├── LICENSE
├── README.md
├── generate_sample.py
├── images
│   ├── ddim.jpg
│   ├── ddpm.jpg
│   ├── dpm_solver_2.jpg
│   ├── dpmsolverplusplus.jpg
│   ├── exponential_integrator.jpg
│   ├── heun.jpg
│   ├── improved_ddpm.jpg
│   └── pndm.jpg
├── samplers
│   ├── __init__.py
│   ├── ddim.py
│   ├── ddpm.py
│   ├── dpm_solver.py
│   ├── dpm_solver_plus_plus.py
│   ├── exponential_integrator.py
│   ├── heun.py
│   ├── improved_ddpm.py
│   ├── pndm.py
│   └── utils.py
└── soft_diffusion
    ├── README.md
    ├── graph_wasserstein.py
    ├── lightning.py
    ├── soft_diffusion.jpg
    └── wasserstein.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ozan Ciga
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffusion for beginners
2 |
3 | - implementation of _diffusion schedulers_ with minimal code, kept as faithful to the original work as i could. most recent works reuse or adapt code from previous work and build on it, or transcribe code from another framework - which is great! but i found it hard to follow at times. this is an attempt at simplifying the great papers below. the trade-off is stability and correctness vs. brevity and simplicity.
4 |
5 |
6 | $$\large{\mathbf{{\color{green}feel\ free\ to\ contribute\ to\ the\ list\ below!}}}$$
7 |
8 | - [x] [dpm-solver++(2m)](samplers/dpm_solver_plus_plus.py) (lu et al. 2022), dpm-solver++: fast solver for guided sampling of diffusion probabilistic models, https://arxiv.org/abs/2211.01095
9 | - [x] [exponential integrator](samplers/exponential_integrator.py) (zhang and chen 2022), fast sampling of diffusion models with exponential integrator, https://arxiv.org/abs/2204.13902
10 | - [x] [dpm-solver](samplers/dpm_solver.py) (lu et al. 2022), dpm-solver: a fast ode solver for diffusion probabilistic model sampling in around 10 steps, https://arxiv.org/abs/2206.00927
11 | - [x] [heun](samplers/heun.py) (karras et al. 2022), elucidating the design space of diffusion-based generative models, https://arxiv.org/abs/2206.00364
12 | - [x] [pndm](samplers/pndm.py) (liu et al. 2022), pseudo numerical methods for diffusion models on manifolds, https://arxiv.org/abs/2202.09778
13 | - [x] [ddim](samplers/ddim.py) (song et al. 2020), denoising diffusion implicit models, https://arxiv.org/abs/2010.02502
14 | - [x] [improved ddpm](samplers/improved_ddpm.py) (nichol and dhariwal 2021), improved denoising diffusion probabilistic models, https://arxiv.org/abs/2102.09672
15 | - [x] [ddpm](samplers/ddpm.py) (ho et al. 2020), denoising diffusion probabilistic models, https://arxiv.org/abs/2006.11239
16 |
17 |
18 |
19 | **prompt**: "a man eating an apple sitting on a bench"
20 |
21 |
22 |
23 |
24 | | dpm-solver++ | exponential integrator |
25 | |:---:|:---:|
26 | | ![dpm-solver++](images/dpmsolverplusplus.jpg) | ![exponential integrator](images/exponential_integrator.jpg) |
27 |
28 | | heun | dpm-solver |
29 | |:---:|:---:|
30 | | ![heun](images/heun.jpg) | ![dpm-solver](images/dpm_solver_2.jpg) |
31 |
32 | | ddim | pndm |
33 | |:---:|:---:|
34 | | ![ddim](images/ddim.jpg) | ![pndm](images/pndm.jpg) |
35 |
36 | | ddpm | improved ddpm |
37 | |:---:|:---:|
38 | | ![ddpm](images/ddpm.jpg) | ![improved ddpm](images/improved_ddpm.jpg) |
69 | ### * requirements *
70 | while this repository is intended to be educational, if you wish to run and experiment, you'll need to obtain a [token from huggingface](https://huggingface.co/docs/hub/security-tokens) (and paste it into generate_sample.py), and install their excellent [diffusers library](https://github.com/huggingface/diffusers).
71 |
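for a quick start, any of the samplers above can be swapped into generate_sample.py the same way the ddpm sampler is used there. a minimal sketch (assuming samplers/__init__.py exposes the submodules, as the `from samplers import ddpm` line in generate_sample.py suggests):

```python
# hypothetical quick start: swap ddpm for ddim in generate_sample.py
from samplers import ddim

sampler = ddim.DDIMSampler(num_sample_steps=50)  # same constructor pattern as ddpm.DDPMSampler
# then call sample_image(pipe, sampler, prompt, init_latents) unchanged
```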
72 |
73 | ### * modification for heun sampler *
74 | the heun sampler uses two neural function evaluations per step, and modifies the input as well as the sigma. i wanted to stay as faithful to the paper as possible, which necessitated changing the sampling code a little.
75 | initialize the sampler as:
76 | ```python
77 | sampler = HeunSampler(num_sample_steps=25, denoiser=pipe.unet, alpha_bar=pipe.scheduler.alphas_cumprod)
78 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.t0
79 | ```
80 |
81 | and replace the inner loop of generate_sample.py with:
82 | ```python
83 | for t in tqdm(sampler.timesteps):
84 | latents = sampler(latents, t, text_embeddings, guidance_scale)
85 | ```
86 |
87 | similarly, for dpm-solver-2,
88 |
89 | ```python
90 | sampler = DPMSampler(num_sample_steps=20, denoiser=pipe.unet)
91 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.lmbd(1)[1]
92 | ```
93 |
94 | and, for fast exponential integrator,
95 |
96 | ```python
97 | sampler = ExponentialSampler(num_sample_steps=50, denoiser=pipe.unet)
98 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device)
99 | ```
100 |
101 | and, for dpm-solver++ (2m),
102 |
103 | ```python
104 | sampler = DPMPlusPlusSampler(denoiser=pipe.unet, num_sample_steps=20)
105 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device) * sampler.get_coeffs(sampler.t[0])[1]
106 | ```
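note that the dpm-solver, exponential integrator, and dpm-solver++ samplers share heun's calling convention, so after each initialization above you can reuse the same replacement inner loop shown in the heun section:

```python
for t in tqdm(sampler.timesteps):
    latents = sampler(latents, t, text_embeddings, guidance_scale)
```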
107 |
108 | ## soft-diffusion
109 |
110 | a sketch/draft of google's new paper, [soft diffusion: score matching for general corruptions](https://arxiv.org/abs/2209.05442), which achieves state-of-the-art results on the celeba-64 dataset.
111 |
112 | details can be found [here](soft_diffusion)
--------------------------------------------------------------------------------
/generate_sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from diffusers import StableDiffusionPipeline
4 | from tqdm import tqdm
5 | import matplotlib.pyplot as plt
6 | from samplers import ddpm
7 | from PIL import Image
8 |
9 |
10 | def sample_image(pipe, sampler, prompt, init_latents, batch_size=1, guidance_scale=8.):
11 | # sample image from stable diffusion
12 | with torch.inference_mode():
13 | latents = init_latents
14 | # prompt conditioning
15 | cond_input = pipe.tokenizer([prompt], padding="max_length", truncation=True, max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
16 | text_embeddings = pipe.text_encoder(cond_input.input_ids.to(pipe.device))[0]
17 | # unconditional (classifier free guidance)
18 | uncond_input = pipe.tokenizer([""] * batch_size, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
19 | uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]
20 | # concat embeddings
21 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
22 |
23 | for t in tqdm(sampler.timesteps):
24 | latent_model_input = torch.cat([latents] * 2)
25 | noise_pred = pipe.unet(latent_model_input, t, text_embeddings).sample
26 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
27 |             # @crowsonkb found that subtracting the noise predicted from an unconditional input (empty prompt)
28 |             # pushes the generated sample toward high-density regions for the given prompt:
29 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
30 | latents = sampler(noise_pred, latents, t)
31 |
32 | with torch.autocast('cuda'):
33 |         images = pipe.vae.decode(latents / 0.18215).sample  # undo the vae scaling factor
34 |
35 | # tensor to image
36 | images = (images / 2 + 0.5).clamp(0, 1)
37 | images = images.cpu().permute(0, 2, 3, 1).numpy()
38 | images = np.uint8(255 * images)
39 |
40 | return images
41 |
42 |
43 | if __name__ == '__main__':
44 |
45 | device = 'cuda'
46 | TOKEN = '' # add your token here
47 | model_name = 'CompVis/stable-diffusion-v1-4'
48 | pipe = StableDiffusionPipeline.from_pretrained(model_name, use_auth_token=TOKEN).to(device)
49 |
50 | sampler = ddpm.DDPMSampler(num_sample_steps=1000)
51 |
52 | prompt = "a man eating an apple sitting on a bench"
53 | batch_size = 1
54 | init_latents = torch.randn(batch_size, 4, 64, 64).to(device)
55 | images = sample_image(pipe, sampler, prompt, init_latents)
56 | # plot the image
57 | plt.figure(); plt.imshow(images[0]); plt.show()
58 | # or save
59 | Image.fromarray(images[0]).save('images/ddpm.jpg', quality=50)
60 |
61 |
--------------------------------------------------------------------------------
/images/ddim.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/ddim.jpg
--------------------------------------------------------------------------------
/images/ddpm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/ddpm.jpg
--------------------------------------------------------------------------------
/images/dpm_solver_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/dpm_solver_2.jpg
--------------------------------------------------------------------------------
/images/dpmsolverplusplus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/dpmsolverplusplus.jpg
--------------------------------------------------------------------------------
/images/exponential_integrator.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/exponential_integrator.jpg
--------------------------------------------------------------------------------
/images/heun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/heun.jpg
--------------------------------------------------------------------------------
/images/improved_ddpm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/improved_ddpm.jpg
--------------------------------------------------------------------------------
/images/pndm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/images/pndm.jpg
--------------------------------------------------------------------------------
/samplers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/samplers/__init__.py
--------------------------------------------------------------------------------
/samplers/ddim.py:
--------------------------------------------------------------------------------
1 | from torch import sqrt
2 | import torch
3 | from . import utils
4 |
5 |
6 | class DDIMSampler:
7 | # https://arxiv.org/abs/2010.02502
8 |
9 | def __init__(self, num_sample_steps=50, num_train_timesteps=1000, reverse_sample=True):
10 | beta = utils.get_beta_schedule(num_train_timesteps+1)
11 | # alpha in ddim == alpha^bar in ddpm = \prod_0^t (1-beta)
12 | self.alpha = torch.cumprod(1 - beta, dim=0)
13 |
14 | self.stride = len(self.alpha) // num_sample_steps
15 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep
16 | # make timestep -> alpha mapping explicit to avoid confusion with different sampling steps
17 | self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, self.alpha)}
18 |
19 | self.timesteps = self.timesteps[::self.stride]
20 |
21 | if reverse_sample: # generating samples (T) or training the model (F)
22 | self.timesteps = reversed(self.timesteps)[:-1]
23 |
24 | self.reverse_sample = reverse_sample
25 |
26 | def __call__(self, eps_theta, x, t, eta=0):
27 |
28 | t = t.item()
29 | tprev = t - self.stride if self.reverse_sample else t + self.stride
30 |
31 | alpha, alpha_prev = self.alpha[t], self.alpha[tprev]
32 |
33 | eps = torch.randn_like(x)
34 |
35 | # Eqn. 16
36 | sigma_tau = (eta * sqrt((1 - alpha_prev) / (1 - alpha)) * sqrt(1 - alpha / alpha_prev)) if eta > 0 else 0
37 | # sigma_tau interpolates between ddim and ddpm by assigning (0=ddim, 1=ddpm)
38 | # ddim (song et al.) is notationally simpler than ddpm (ho et al.)
39 |         # (although ddim uses \alpha to refer to \bar{\alpha} from ddpm)
40 |
41 | # Eqn. 12
42 | predicted_x0 = (x - sqrt(1 - alpha) * eps_theta) / sqrt(alpha)
43 | dp_xt = sqrt(1 - alpha_prev - sigma_tau ** 2) # direction pointing to xt
44 | x_prev = sqrt(alpha_prev) * predicted_x0 + dp_xt * eps_theta + sigma_tau * eps
45 |
46 | return x_prev
47 |
--------------------------------------------------------------------------------
/samplers/ddpm.py:
--------------------------------------------------------------------------------
1 | from torch import sqrt
2 | import torch
3 | from . import utils
4 |
5 |
6 | class DDPMSampler:
7 | # https://arxiv.org/abs/2006.11239
8 |
9 | def __init__(self, num_sample_steps=500, num_train_timesteps=1000, reverse_sample=True):
10 |
11 | beta = utils.get_beta_schedule(num_train_timesteps+1)
12 | alpha = 1 - beta
13 | alpha_bar = torch.cumprod(alpha, dim=0)
14 |
15 | self.stride = len(alpha) // num_sample_steps
16 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep
17 |
18 |         beta = torch.clamp(beta, max=0.99)  # guard the tail of the schedule
19 |         # alpha_bar stays the cumulative product of alpha above, consistent with the linear (stable diffusion) beta schedule
20 |
21 | # make timestep -> alpha/beta mapping explicit
22 | self.beta = {t.item(): beta for t, beta in zip(self.timesteps, beta)}
23 | self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, alpha)}
24 | self.alpha_bar = {t.item(): alpha_bar for t, alpha_bar in zip(self.timesteps, alpha_bar)}
25 |
26 | self.timesteps = self.timesteps[::self.stride]
27 |
28 | if reverse_sample: # generating samples (T) or training the model (F)
29 | self.timesteps = reversed(self.timesteps)[:-1]
30 |
31 | self.reverse_sample = reverse_sample
32 |
33 | def __call__(self, eps_theta, x, t):
34 |
35 | t = t.item()
36 | tprev = t - self.stride if self.reverse_sample else t + self.stride
37 |
38 | beta, alpha = self.beta[t], self.alpha[t]
39 | alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[tprev]
40 |
41 | # two alternatives for \sigma, both equally effective according to ddpm paper
42 | sigma = sqrt(beta)
43 | beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * beta # eqn (7)
44 | # sigma = sqrt(beta_tilde)
45 | # algorithm 2
46 | z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
47 | x_prev = 1 / sqrt(alpha) * (x - (1 - alpha) / (1 - alpha_bar) * eps_theta) + sigma * z
48 |
49 | return x_prev
50 |
51 |
52 |
--------------------------------------------------------------------------------
/samplers/dpm_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class DPMSampler:
6 | # https://arxiv.org/abs/2206.00927
7 |
8 | def __init__(self, denoiser, num_sample_steps=20, num_train_timesteps=1000, schedule='linear'):
9 |
10 | self.schedule = schedule
11 | T = 1 if self.schedule == 'linear' else 0.9946
12 |
13 | # discrete to continuous- section 3.4, [0,1]->[0,N]
14 | self.denoiser = denoiser # lambda x, t, embd: denoiser(x, int(num_train_timesteps * t), embd).sample
15 |
16 | # A simple way for choosing time steps for lambda is uniformly
17 | # splitting [lambda_t , lambda_eps], which is the setting in our experiments
18 |         (_, _, lmbd_T), (_, _, lmbd_eps) = self.lmbd(T), self.lmbd(1E-3)  # lambda decreases in t, so lmbd(T) is the smallest
19 |         lmbds = torch.linspace(lmbd_T, lmbd_eps, num_train_timesteps)
20 | self.timesteps = torch.tensor([self.t_lambda(l, self.schedule) for l in lmbds])
21 |
22 | self.stride = num_train_timesteps // num_sample_steps
23 | self.timesteps = self.timesteps[::self.stride]
24 | self.num_train_timesteps = num_train_timesteps
25 |
26 | def eps_theta(self, x, t, embd, guidance_scale=None):
27 | t = int(max(0, t - 1 / self.num_train_timesteps) * self.num_train_timesteps)
28 | with torch.inference_mode():
29 | noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), t, embd).sample
30 | if guidance_scale is not None:
31 | noise_pred_uncond, noise_pred_text = noise.chunk(2)
32 | noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
33 | return noise
34 |
35 | @staticmethod
36 | def _alpha(t, s=0.008):
37 | return np.exp(np.log(np.cos(np.pi / 2 * (t + s) / (1 + s))) -
38 | np.log(np.cos(np.pi / 2 * (s) / (1 + s))))
39 |
40 | @staticmethod
41 | def _lmbd(alpha, sigma):
42 | return np.log(alpha) - np.log(sigma)
43 |
44 | # @staticmethod
45 | def _sigma(self, alpha):
46 | # return 0.5 * np.log(1. - np.exp(2. * self.log_alpha(t)))
47 | return np.sqrt(1 - alpha ** 2)
48 |
49 | @staticmethod
50 | def t_lambda(lmbd, schedule='linear'):
51 | # appendix d.4
52 | if schedule == 'linear':
53 | b0, b1 = 0.1, 20.
54 | nom = 2 * np.logaddexp(-2 * lmbd, 0) # numerically stable log(e**x+e**y)
55 | denom = np.sqrt(b0 ** 2 + 2 * (b1 - b0) * np.logaddexp(-2 * lmbd, 0)) + b0
56 | return nom / denom
57 | elif schedule == 'cosine':
58 | s = 0.008
59 | f_lambda = -1 / 2 * np.logaddexp(-2 * lmbd, 0)
60 | logcos = np.log(np.cos((np.pi * s) / (2 * (1 + s))))
61 | return 2 * (1 + s) / np.pi * np.arccos(np.exp(f_lambda + logcos)) - s
62 |
63 | def lmbd(self, t):
64 | log_alpha = self.log_alpha(t, self.schedule)
65 | # log_sigma = 0.5 * np.log(1. - np.exp(2. * log_alpha))
66 | sigma = np.sqrt(1. - np.exp(2 * log_alpha))
67 | log_sigma = np.log(sigma)
68 | return log_alpha, sigma, (log_alpha-log_sigma)
69 |
70 | def log_alpha(self, t, schedule):
71 | if schedule == 'linear':
72 | b0, b1 = 0.1, 20
73 | return -(b1-b0)/4 * t ** 2 - b0/2*t
74 | elif schedule == 'cosine':
75 | s = 0.008
76 | return np.log(np.cos(np.pi/2*(t+s)/(1+s))) - np.log(np.cos(np.pi/2*(s)/(1+s)))
77 |
78 | def dpm_solver_2(self, x_tilde_prev, t_prev, t, text_embeddings, guidance_scale=None, r1=0.5):
79 | # algorithm 4
80 | # numerically _not_ stable
81 | '''alpha_t, alpha_t_prev = self._alpha(t), self._alpha(t_prev)
82 | sigma_t, sigma_t_prev = self._sigma(alpha_t), self._sigma(alpha_t_prev)
83 | lmbd_t, lmbd_t_prev = self._lmbd(alpha_t, sigma_t), self._lmbd(alpha_t_prev, sigma_t_prev)'''
84 |
85 | # numerically stable
86 | (log_alpha_t, sigma_t, lmbd_t), (log_alpha_t_prev, sigma_t_prev, lmbd_t_prev) = self.lmbd(t), self.lmbd(t_prev)
87 | h = lmbd_t - lmbd_t_prev
88 | s = self.t_lambda(lmbd_t_prev + r1 * h, self.schedule)
89 | log_alpha_s, sigma_s, lmbd_s = self.lmbd(s)
90 |
91 | eps_prev = self.eps_theta(x_tilde_prev, t_prev, text_embeddings, guidance_scale)
92 | u = np.exp(log_alpha_s - log_alpha_t_prev) * x_tilde_prev - sigma_s * np.expm1(r1 * h) * eps_prev
93 |
94 | x_tilde = (
95 | + np.exp(log_alpha_t - log_alpha_t_prev) * x_tilde_prev
96 | - sigma_t * np.expm1(h) * eps_prev
97 | - sigma_t / (2 * r1) * np.expm1(h) * (self.eps_theta(u, s, text_embeddings, guidance_scale) - eps_prev)
98 | )
99 | return x_tilde
100 |
101 | def __call__(self, x, t, text_embeddings, guidance_scale=8.):
102 |
103 | if t == self.timesteps[-1]:
104 | return x
105 |
106 | t_prev = t.item()
107 | prev_index = torch.where(self.timesteps < t)[0][0]
108 | t = self.timesteps[prev_index].item()
109 |
110 | return self.dpm_solver_2(x, t_prev, t, text_embeddings, guidance_scale)
111 |
--------------------------------------------------------------------------------
/samplers/dpm_solver_plus_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from . import utils
4 |
5 |
6 | class DPMPlusPlusSampler:
7 | # https://arxiv.org/abs/2211.01095
8 |
9 | def __init__(self, denoiser, num_sample_steps=20, num_train_timesteps=1000):
10 |
11 | self.timesteps = np.arange(1, num_train_timesteps+1)
12 | self.t = self.timesteps / num_train_timesteps
13 |
14 | beta = utils.get_beta_schedule(num_train_timesteps+1)
15 | self.alpha = {t: v for t, v in zip(self.t, torch.cumprod(1 - beta, dim=0).sqrt())}
16 |
17 | self.q = dict()
18 |
19 | self.denoiser = denoiser
20 |
21 | self.stride = int(num_train_timesteps / num_sample_steps)
22 | self.timesteps = self.timesteps[::self.stride][::-1]
23 | self.t = self.t[::-1]
24 | self.num_train_timesteps = num_train_timesteps
25 |
26 | def x_theta(self, x, t, embd, guidance_scale=None):
27 | # data prediction model
28 | alpha_t, sigma_t, _ = self.get_coeffs(t)
29 | timestep = int(max(0, t) * self.num_train_timesteps)
30 | with torch.inference_mode():
31 | noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), timestep, embd).sample
32 | if guidance_scale is not None:
33 | noise_pred_uncond, noise_pred_text = noise.chunk(2)
34 | noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
35 | x_pred = (x - sigma_t * noise) / alpha_t # Parameterization: noise prediction and data prediction
36 | # print(t, timestep, x_pred.max(), noise.max(), sigma_t, alpha_t)
37 | return x_pred
38 |
39 | def get_coeffs(self, t):
40 | # appendix b.1 for alphas, sigmas, lambdas
41 | '''# below is used for continuous t, uncomment & modify for your needs if you want to
42 | t_next = t + 1 / self.num_train_timesteps
43 | t_n, t_n_next = sampler.t[np.digitize(t, sampler.t)], sampler.t[np.digitize(t+1/self.num_train_timesteps, sampler.t)]
44 | t_n_next = t_n + self.stride
45 | log_alpha_n, log_alpha_n_next = self.alpha[t].log(), self.alpha[t_next].log()
46 | log_alpha_t = log_alpha_n + (log_alpha_n_next-log_alpha_n)/(t_n_next-t_n) * (t-t_n)'''
47 | log_alpha_t = self.alpha[t].log() # comment this out if above is uncommented
48 | alpha_t = log_alpha_t.exp()
49 | sigma_t = (1-alpha_t ** 2).sqrt()
50 | log_sigma_t = sigma_t.log()
51 | lambda_t = log_alpha_t - log_sigma_t
52 | return alpha_t, sigma_t, lambda_t
53 |
54 | def dpm_solver_plus_plus_2m(self, x_tilde_prev, t_prev_prev, t_prev, t, text_embeddings, guidance_scale=None):
55 | # algorithm 2
56 | alpha_t, sigma_t, lambda_t = self.get_coeffs(t)
57 | _, sigma_t_prev, lambda_t_prev = self.get_coeffs(t_prev)
58 | _, _, lambda_t_prev_prev = self.get_coeffs(t_prev_prev)
59 | h = lambda_t - lambda_t_prev
60 | h_prev = lambda_t_prev - lambda_t_prev_prev
61 |
62 | r = h_prev / h
63 | D = (1 + 1 / (2 * r)) * self.q[t_prev] - 1 / (2 * r) * self.q[t_prev_prev]
64 | x_tilde = sigma_t / sigma_t_prev * x_tilde_prev - alpha_t * torch.expm1(-h) * D
65 | self.q[t] = self.x_theta(x_tilde, t, text_embeddings, guidance_scale)
66 |
67 | return x_tilde
68 |
69 | def __call__(self, x, timestep_index, text_embeddings, guidance_scale=8.):
70 |
71 | # the notation used in this paper is really confusing
72 | # gaussian noise input gets an index of M which maps
73 | # to t_0, i.e., reversed. so
74 | # t0->t1->t2 is M->M-1->M-2. in other words t_prev > t
75 |
76 | t = timestep_index.item() / self.num_train_timesteps
77 | t_prev = (timestep_index.item() + self.stride) / self.num_train_timesteps
78 | t_prev_prev = (timestep_index.item() + 2 * self.stride) / self.num_train_timesteps
79 |
80 | if timestep_index == self.timesteps[0]:
81 | # first step (xt0)
82 | self.q[t] = self.x_theta(x, t, text_embeddings, guidance_scale)
83 | return x # typo in text, x0 in algorithm 2
84 |
85 | elif timestep_index == self.timesteps[1]:
86 | # second step (xt1)
87 | alpha_t, sigma_t, lambda_t = self.get_coeffs(t)
88 | _, sigma_t_prev, lambda_t_prev = self.get_coeffs(t_prev)
89 | h = lambda_t - lambda_t_prev
90 | x_tilde = sigma_t / sigma_t_prev * x - alpha_t * torch.expm1(-h) * self.q[t_prev]
91 | self.q[t] = self.x_theta(x_tilde, t, text_embeddings, guidance_scale)
92 | return x_tilde
93 |
94 | # 2 to M loop
95 | return self.dpm_solver_plus_plus_2m(x, t_prev_prev, t_prev, t, text_embeddings, guidance_scale)
96 |
--------------------------------------------------------------------------------
/samplers/exponential_integrator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ExponentialSampler:
6 | # https://arxiv.org/abs/2204.13902
7 |
8 | def __init__(
9 | self,
10 | denoiser,
11 | num_sample_steps=50,
12 | num_train_timesteps=1000,
13 | r=3,
14 | ):
15 | self.r = r
16 | self.denoiser = denoiser
17 |
18 | self.num_train_timesteps, self.num_sample_steps = num_train_timesteps, num_sample_steps
19 |
20 | self.stride = num_train_timesteps // num_sample_steps
21 | self.timesteps = reversed(torch.arange(1, num_sample_steps - 1)) # see algo 1, timestep ends at 1 not zero
22 |
23 | self.t = np.linspace(1E-3, 1, self.num_sample_steps+r)
24 | self.c = self.calc_c() # calculate \mathbf{C} based on Eq. 15, also see table 1
25 | self.eps_buffer = dict()
26 |
27 | @staticmethod
28 | def log_alpha(t, schedule='linear'):
29 | if schedule == 'linear':
30 | b0, b1 = 0.1, 20
31 | return -(b1 - b0) / 4 * t ** 2 - b0 / 2 * t
32 | elif schedule == 'cosine':
33 | s = 0.008
34 | return np.log(np.cos(np.pi / 2 * (t + s) / (1 + s))) - np.log(np.cos(np.pi / 2 * (s) / (1 + s)))
35 |
36 | def calc_c(self):
37 |
38 | from scipy.interpolate import interp1d
39 | log_alpha_fn = interp1d(
40 | np.linspace(0, 1+0.01, self.num_train_timesteps+1),
41 | self.log_alpha(np.linspace(0, 1, self.num_train_timesteps+1), schedule='linear'))
42 |
43 |         from scipy.misc import derivative  # note: scipy.misc was removed in scipy 1.12; use a manual finite difference on newer versions
44 | G = lambda tau: np.sqrt(-derivative(log_alpha_fn, tau, dx=1e-5))
45 | L = lambda tau: np.sqrt(1-np.exp(log_alpha_fn(tau)))
46 |
47 | def prod_fn(tau, i, j, r):
48 | prod = 1.
49 | for k in range(r+1):
50 | prod *= (tau - self.t[i + k]) / (self.t[i + j] - self.t[i + k]) if k != j else 1
51 | return prod
52 |
53 | def get_cij(tau, i, j, r):
54 | # Eq. 15. ignoring inverse and transpose since G, L are 1d
55 | return 1 / 2 * self.psi(self.t[i - 1], tau) * G(tau) ** 2 * 1 / L(tau) * prod_fn(tau, i, j, r)
56 |
57 | cij = dict()
58 | from scipy import integrate
59 | for i in range(self.num_sample_steps):
60 | for j in range(self.r+1):
61 | cij[i, j] = integrate.quad(get_cij, self.t[i], self.t[i - 1], args=(i, j, self.r), epsrel=1e-4)[0]
62 |
63 | return cij
64 |
65 | def psi(self, t, s):
66 | return np.sqrt(np.exp((self.log_alpha(t) - self.log_alpha(s)))) # appendix d, PROOF OF PROP 2
67 |
68 | def eps_theta(self, x, i, embd, guidance_scale=None):
69 | with torch.inference_mode():
70 | noise = self.denoiser(x if guidance_scale is None else torch.cat([x] * 2), i, embd).sample
71 | if guidance_scale is not None:
72 | noise_pred_uncond, noise_pred_text = noise.chunk(2)
73 | noise = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
74 | return noise
75 |
76 | def __call__(self, x_hat, i, text_embeddings, guidance_scale=8.):
77 | # algorithm 1.
78 | i = int(i)
79 | self.eps_buffer[self.t[i]] = self.eps_theta(x_hat, i*self.stride, text_embeddings, guidance_scale)
80 | x_hat_prev = self.psi(self.t[i-1], self.t[i]) * x_hat + sum(
81 | [self.c[i, j] * self.eps_buffer[self.t[i+j]] for j in range(self.r+1) if self.t[i+j] in self.eps_buffer]
82 | ) # Eq. 14
83 | return x_hat_prev
84 |
--------------------------------------------------------------------------------
/samplers/heun.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import sqrt as msqrt
3 |
4 |
5 | class HeunSampler:
6 |     # https://arxiv.org/abs/2206.00364
7 |
8 | def __init__(
9 | self,
10 | num_sample_steps=50,
11 | num_train_timesteps=1000,
12 | denoiser=None,
13 | ddpm=True,
14 | alpha_bar=None,
15 | ):
16 | self.denoiser = denoiser
17 |
18 | self.num_train_timesteps, self.num_sample_steps = num_train_timesteps, num_sample_steps
19 |
20 | self.ddpm_sigmas = ((1-alpha_bar)/alpha_bar).sqrt()
21 | self.t0 = self.ddpm_sigmas[-1] # or sigma_max/max noise
22 |
23 | self.stride = len(alpha_bar) // num_sample_steps
24 | self.timesteps = torch.arange(num_train_timesteps+1)
25 |
26 | # table 1
27 | sigma_min, sigma_max, rho = self.ddpm_sigmas[0], self.ddpm_sigmas[-1], 7 # stable diff.
28 | _heun_sigmas = (sigma_max ** (1 / rho) + (self.timesteps / self.num_sample_steps) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
29 | self.heun_sigmas = _heun_sigmas
30 | self.sigma_data = 0.5
31 | self.timesteps = torch.arange(num_sample_steps - 1)
32 |
33 | self.ddpm = ddpm # model trained with ddpm objective, used to set c_in etc
34 |
35 | def c_skip(self, sigma):
36 | if self.ddpm:
37 | return torch.tensor(1.)
38 | return self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
39 |
40 | def c_out(self, sigma):
41 | if self.ddpm:
42 | return -sigma
43 | return sigma * self.sigma_data * (self.sigma_data**2 + sigma**2).rsqrt()
44 |
45 | def c_in(self, sigma):
46 | if self.ddpm:
47 | return (1 + sigma ** 2).rsqrt()
48 | return (sigma ** 2 + self.sigma_data ** 2).rsqrt()
49 |
50 | def c_noise(self, sigma):
51 | if self.ddpm:
52 | return torch.abs(self.ddpm_sigmas-sigma).argmin() # iddpm practical (c.3.4: 3rd bullet)
53 | return 1 / 4 * sigma.log()
54 |
55 | def d_theta(self, x, sigma, encoder_hidden_states, guidance_scale=None):
56 |
57 | eps_theta = self.predict_noise(
58 | x if guidance_scale is None else torch.cat([x] * 2), sigma, encoder_hidden_states)
59 | # hacky bit: if guidance!=None, perform classifier free guidance
60 | # deviates from original implementation to support minimal code
61 | # duplication, even though it's not good practice
62 | if guidance_scale is not None:
63 | noise_pred_uncond, noise_pred_text = eps_theta.chunk(2)
64 | eps_theta = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
65 |
66 | return self.c_skip(sigma) * x + self.c_out(sigma) * eps_theta
67 |
68 | def f_theta(self, x, sigma, encoder_hidden_states):
69 | with torch.inference_mode():
70 | return self.denoiser(x, sigma, encoder_hidden_states).sample
71 |
72 | def predict_noise(self, x, sigma, encoder_hidden_states):
73 | # helper function to predict noise for stable diffusion
74 | # since heun modifies both input and the stdev, using
75 | # sd pipeline directly does not yield the correct f_theta output
76 | return self.f_theta(self.c_in(sigma) * x, self.c_noise(sigma), encoder_hidden_states)
77 |
78 |     def stochastic_sampler(self, x, i, encoder_hidden_states, guidance_scale=None, s_churn=80, s_tmin=0.05, s_tmax=50, s_noise=1.003):
79 |         # Algorithm 2 - arg defaults from table 5
80 |         x_next, x_hat, t_next, t_hat, d = self.euler_method(  # first order denoising
81 |             x, i, encoder_hidden_states, guidance_scale, s_churn, s_tmin, s_tmax, s_noise)
82 |         x_next = self.heuns_correction(x_next, x_hat, t_next, t_hat, d, encoder_hidden_states, guidance_scale) # second order correction
83 |         return x_next
84 |
85 | def euler_method(self, x, i, encoder_hidden_states=None, guidance_scale=None, s_churn=80, s_tmin=0.05, s_tmax=50, s_noise=1.003):
86 | # euler method, omitting second order correction @ lines 9-11 (algorithm 2)
87 | t, t_next = self.heun_sigmas[i], self.heun_sigmas[i+1]
88 |
89 | eps = torch.randn_like(x) * s_noise
90 | gamma = min(s_churn / self.num_sample_steps, msqrt(2) - 1) if s_tmin <= t <= s_tmax else 0
91 | t_hat = t + gamma * t
92 | if self.ddpm:
93 | t_hat = self.ddpm_sigmas[torch.abs(self.ddpm_sigmas-t_hat).argmin()] # iddpm practical, step 3
94 | x_hat = x + msqrt(max(t_hat ** 2 - t ** 2, 0)) * eps # gamma < 0 -> negative sqrt
95 | d = (x_hat - self.d_theta(x_hat, t_hat, encoder_hidden_states, guidance_scale)) / t_hat
96 | x_next = x_hat + (t_next - t_hat) * d
97 |
98 | return x_next, x_hat, t_next, t_hat, d
99 |
100 | def heuns_correction(self, x_next, x_hat, t_next, t_hat, d, encoder_hidden_states, guidance_scale=None):
101 | # correction improves differentiation discretization error from o(h^2) -> o(h^3), h = step size
102 | # correction addresses varying step sizes between two different timesteps
103 | # see Discretization and higher-order integrators from the paper
104 | if t_next != 0:
105 | d_prime = (x_next - self.d_theta(x_next, t_next, encoder_hidden_states, guidance_scale)) / t_next
106 | x_next = x_hat + (t_next - t_hat) * (1 / 2 * d + 1 / 2 * d_prime)
107 | return x_next
108 |
109 | def alpha_sampler(self, x, i, encoder_hidden_states, guidance_scale=None, alpha=1):
110 | # Algorithm 3
111 | t, t_next = self.heun_sigmas[i], self.heun_sigmas[i+1]
112 | h = t_next - t
113 | d = (x - self.d_theta(x, t, encoder_hidden_states, guidance_scale)) / t
114 | x_prime, t_prime = x + alpha * h * d, t + alpha * h
115 | if t_prime != 0:
116 | d_prime = (x_prime - self.d_theta(x_prime, t_prime, encoder_hidden_states, guidance_scale)) / t_prime
117 | x_next = x + h * ((1 - 0.5 / alpha) * d + 0.5 / alpha * d_prime)
118 | else:
119 | x_next = x + h*d
120 |
121 | return x_next
122 |
123 | def __call__(self, *args):
124 | return self.stochastic_sampler(*args)
125 |
--------------------------------------------------------------------------------
/samplers/improved_ddpm.py:
--------------------------------------------------------------------------------
1 | from torch import sqrt
2 | import torch
3 | from . import utils
4 |
5 |
6 | class ImprovedDDPMSampler:
7 | # https://arxiv.org/abs/2102.09672
8 |
9 | def __init__(self, num_sample_steps=500, num_train_timesteps=1000, reverse_sample=True):
10 |
11 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep
12 |
13 | T, s = num_train_timesteps+1, 0.008
14 | alpha_bar = torch.cos((self.timesteps/T + s)/(1 + s) * torch.pi/2) ** 2 # (17), f(0) ~ 0.99999 so plug f(t)=\bar{\alpha}
15 | alpha_bar_prev = torch.hstack([alpha_bar[0], alpha_bar[:-1]])
16 | beta = 1 - alpha_bar / alpha_bar_prev
17 |         beta = torch.clamp(beta, max=0.999)  # clip betas as in the paper, before computing alpha
18 |         alpha = 1 - beta
19 |
20 | self.stride = len(alpha) // num_sample_steps
21 |
22 | # make timestep -> alpha/beta mapping explicit
23 | self.beta = {t.item(): beta for t, beta in zip(self.timesteps, beta)}
24 | self.alpha = {t.item(): alpha for t, alpha in zip(self.timesteps, alpha)}
25 | self.alpha_bar = {t.item(): alpha_bar for t, alpha_bar in zip(self.timesteps, alpha_bar)}
26 |
27 | self.timesteps = self.timesteps[::self.stride]
28 |
29 | if reverse_sample: # generating samples (T) or training the model (F)
30 | self.timesteps = reversed(self.timesteps)[:-1]
31 |
32 | self.reverse_sample = reverse_sample
33 |
34 | def __call__(self, eps_theta, x, t):
35 |
36 | t = t.item()
37 | tprev = t - self.stride if self.reverse_sample else t + self.stride
38 |
39 | beta, alpha = self.beta[t], self.alpha[t]
40 | alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[tprev]
41 |
42 | # eqn 3
43 | mu = alpha.rsqrt() * (x - beta * (1 - alpha_bar).rsqrt() * eps_theta) # eq (13) reciprocal sqrt should be more numerically stable
44 | # variance, should be learned but i'm just experimenting w/ plugging (15) directly w/ v=1/2
45 | beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * beta # eqn (7)
46 | v = 0.5
47 | sigma_theta = torch.exp(v*beta.log() + (1-v)*beta_tilde.log()).sqrt() # (15)
48 |
49 | z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
50 | x_prev = mu + sigma_theta * z
51 |
52 | return x_prev
53 |
--------------------------------------------------------------------------------
/samplers/pndm.py:
--------------------------------------------------------------------------------
1 | from torch import sqrt
2 | import torch
3 | from . import utils
4 |
5 |
6 | class PNDMSampler:
7 | # https://arxiv.org/abs/2202.09778
8 |
9 | def __init__(self, num_sample_steps=50, num_train_timesteps=1000, reverse_sample=True, ddim=False):
10 |
11 | beta = utils.get_beta_schedule(num_train_timesteps+1)
12 | alpha = 1 - beta
13 | self.alpha_bar = torch.cumprod(alpha, dim=0)
14 |
15 | self.stride = len(self.alpha_bar) // num_sample_steps
16 | self.timesteps = torch.arange(num_train_timesteps+1) # stable diffusion accepts discrete timestep
17 | self.alpha_bar = {t.item(): alpha for t, alpha in zip(self.timesteps, self.alpha_bar)}
18 |
19 | self.timesteps = self.timesteps[::self.stride]
20 |
21 | if reverse_sample: # generating samples (T) or training the model (F)
22 | self.timesteps = reversed(self.timesteps)[:-1]
23 |
24 | self.reverse_sample = reverse_sample
25 |
26 | self.ddim = ddim # alg. 1 from the paper
27 |
28 | self.et = []
29 |
30 | def prk(self, xt, et, t, tpd):
31 | '''
32 | # reuse et 4 times (from the paper):
33 | we find that the linear multi-step method can reuse
34 | the result of \eps four times and only compute \eps
35 | once at every step.
36 | ---
37 | so we can just return e_t for each \eps_{\theta}(.,.) call
38 | '''
39 | def eps_theta(*args):
40 | return et
41 | delta = tpd - t
42 | tpd2 = t + delta/2 # t plus d/2
43 | # Eqn 13
44 | et1 = eps_theta(xt, t)
45 | xt1 = self.phi(xt, et1, t, tpd2)
46 | et2 = eps_theta(xt1, tpd2)
47 | xt2 = self.phi(xt, et2, t, tpd2)
48 | et3 = eps_theta(xt2, tpd2)
49 | xt3 = self.phi(xt, et3, t, tpd) # paper has a typo here
50 | et4 = eps_theta(xt3, tpd)
51 | et_d = 1/6 * (et1 + 2*et2 + 2*et3 + et4) # e_t'
52 | xtpd = self.phi(xt, et_d, t, tpd) # another typo, this was defined as xtmd
53 | return xtpd, et1
54 |
55 | def plms(self, xt, et, t, tpd):
56 | # Eqn 12
57 | etm3d, etm2d, etmd = self.et[-3:]
58 | et_d = 1 / 24 * (55 * et - 59 * etmd + 37 * etm2d - 9 * etm3d) # e_t'
59 | xtpd = self.phi(xt, et_d, t, tpd)
60 | return xtpd, et
61 |
62 | def add_to_error_queue(self, item):
63 | self.et.append(item)
64 | if len(self.et) > 3: # no need to store > 3 e_t (see eqn 12)
65 | del self.et[0]
66 |
67 | def phi(self, xt, et, t, t_prev):
68 | # i.e., the transfer part (eqn 11), t_prev = t-\delta
69 | alpha_bar, alpha_bar_prev = self.alpha_bar[t], self.alpha_bar[t_prev]
70 |
71 | transfer = sqrt(alpha_bar_prev)/sqrt(alpha_bar) * xt # first term in the eqn.
72 |         # second term of the eqn, split into numerator & denominator for clarity:
73 | nom = alpha_bar_prev-alpha_bar
74 | denom = sqrt(alpha_bar) * (sqrt((1-alpha_bar_prev)*alpha_bar) + sqrt((1-alpha_bar)*alpha_bar_prev))
75 | transfer = transfer - nom/denom*et
76 |
77 | return transfer
78 |
79 | def __call__(self, eps_theta, xtp1, t):
80 | t = t.item()
81 |         tpd = t - self.stride  # \delta = -1 in algorithm 2 (pndms); with strided sampling each step moves by -stride
82 | if self.ddim: # algorithm 1
83 | return self.phi(xtp1, eps_theta, t, tpd)
84 | else: # algorithm 2
85 | if len(self.et) < 3:
86 | xtpd, et = self.prk(xtp1, eps_theta, t, tpd)
87 | else:
88 | xtpd, et = self.plms(xtp1, eps_theta, t, tpd)
89 | self.add_to_error_queue(et)
90 |
91 |         return xtpd  # note that xtpd = x_{t-1} since \delta = -1
92 |
--------------------------------------------------------------------------------
/samplers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_beta_schedule(num_diffusion_timesteps=1000, beta_start=0.00085, beta_end=0.012):
5 | # beta schedule, start, end are from stable diffusion config
6 | betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps) ** 2
7 | return betas
8 |
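# usage sketch (how the samplers in this package consume the schedule, e.g. ddim.py/ddpm.py):
#   beta = get_beta_schedule(1000)
#   alpha_bar = torch.cumprod(1 - beta, dim=0)  # \bar{\alpha}, the cumulative signal level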
--------------------------------------------------------------------------------
/soft_diffusion/README.md:
--------------------------------------------------------------------------------
1 | # Soft Diffusion: Score Matching for General Corruptions
2 |
3 | [arxiv link to paper](https://arxiv.org/abs/2209.05442)
4 |
5 |
6 |
7 | sketch of google's new paper, which achieves state-of-the-art results on celeba-64 dataset.
8 |
9 | this implementation is missing some details & involves some guesswork,
10 | because the paper did not define certain variables and implementation details in its current version (v2).
--------------------------------------------------------------------------------
/soft_diffusion/graph_wasserstein.py:
--------------------------------------------------------------------------------
1 | '''
2 | graph based selection of gaussian blur levels
3 | plug-in the resulting sigmas to the training script
4 | '''
5 |
6 | from scipy.sparse import csr_matrix
7 | from scipy.sparse.csgraph import shortest_path
8 | from soft_diffusion.wasserstein import calculate_2_wasserstein_dist
9 | import torch
10 | import numpy as np
11 | import torchvision.transforms as tt
12 | from torch.utils.data import DataLoader
13 | import torchvision
14 |
15 |
16 | def get_shortest_path(Pr, start_node, end_node):
17 |     # walk the predecessor matrix returned by scipy's shortest_path - https://stackoverflow.com/a/53078901/1696420
18 |     path = [end_node]
19 |     k = end_node
20 |     while Pr[start_node, k] != -9999:
21 |         path.append(Pr[start_node, k])
22 |         k = Pr[start_node, k]
23 |     return path[::-1]
24 |
25 |
26 | if __name__ == '__main__':
27 |
28 | gbs_min, gbs_max = 0.01, 23 # gaussian blur sigma - appendix D, blur
29 | T = 256
30 | sigma_blurs = np.linspace(gbs_min, gbs_max, T)
31 |
32 | distributions = dict()
33 | for sigma in sigma_blurs:
34 | distributions[sigma] = list()
35 |
36 | data_root = ''
37 | data_transform = tt.Compose([
38 | tt.Resize((64, 64)),
39 | tt.ToTensor(),
40 | tt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
41 | ])
42 | dataset = torchvision.datasets.CelebA(data_root, split='train', target_type='attr', transform=data_transform, download=True)
43 | loader = DataLoader(dataset, batch_size=64, shuffle=True)
44 |     for batch in loader:
45 |         for sigma in sigma_blurs:  # blur each batch at every level (kernel 161 = 2 * 80 + 1, matching lightning.py's half_kernel)
46 |             distributions[sigma].append(tt.GaussianBlur(kernel_size=161, sigma=float(sigma))(batch[0]))
47 |
48 | # compute wasserstein distances between each edge
49 | # ``we start with 256 different blur levels and we tune \eps such that
50 | # the shortest path contains 32 distributions.
51 | # we then use linear interpolation to extend to the continuous case."
52 | eps_dist_max = 1E-3 # dataset specific, eq 19, must set! this is arbitrary
53 |
54 | wasserstein_dist = torch.empty((T, T))
55 | for index, sigma in enumerate(sigma_blurs):
56 | for next_index, next_sigma in enumerate(sigma_blurs):
57 | if index == next_index:
58 | wasserstein_dist[index, next_index] = torch.inf
59 | else:
60 |                 wdist = calculate_2_wasserstein_dist(torch.cat(distributions[sigma]).flatten(1), torch.cat(distributions[next_sigma]).flatten(1))  # expects [b, n]
61 | if wdist > eps_dist_max:
62 | wdist = torch.inf
63 | wasserstein_dist[index, next_index] = wdist
64 |
65 | wasserstein_dist = csr_matrix(wasserstein_dist.cpu().tolist())
66 | D, Pr = shortest_path(wasserstein_dist, directed=False, method='FW', return_predecessors=True)
67 | # get the list of sigmas giving the shortest path
68 |     sigma_index = get_shortest_path(Pr, start_node=0, end_node=T-1)
69 | sigma_blurs = [sigma for j, sigma in enumerate(sigma_blurs) if j in sigma_index]
70 | assert len(sigma_blurs) == 32, 'adjust eps_dist_max in eqn 19 until you get 32 nodes/sigmas'
71 |
72 | print(sigma_blurs)
--------------------------------------------------------------------------------
/soft_diffusion/lightning.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from inspect import getfullargspec
3 | import pytorch_lightning as pl
4 | from torch import optim
5 | import torch
6 | import numpy as np
7 | from functools import partial
8 | import torchvision.transforms.functional as ttf
9 | from scipy.interpolate import interp1d
10 | import k_diffusion as K # using crowsonkb's k-diffusion (github.com/crowsonkb/k-diffusion)
11 |
12 |
13 | def blur(discrete_sigmas, half_kernel=80, image_size=64):
14 | # gaussian blur, in place of C matrix
15 | pad_val = int(np.ceil(((half_kernel * 2 + 1)/2 - image_size) / 2 + 1))
16 | padder = partial(torch.nn.functional.pad, pad=(pad_val, pad_val, pad_val, pad_val), mode='reflect')
17 |
18 | discretize_fn = interp1d(np.linspace(0, 1, len(discrete_sigmas)), discrete_sigmas)
19 |
20 | def fn(image, sigma):
21 |         sigma = discretize_fn(sigma.item())  # maps t in [0, 1] to a graph-selected blur sigma; assumes a scalar (or batch-shared) t
22 | image = ttf.gaussian_blur(padder(image), kernel_size=[half_kernel * 2 + 1, half_kernel * 2 + 1], sigma=sigma)
23 | image = ttf.center_crop(image, [image_size, image_size])
24 | return image
25 | return fn
26 |
27 |
28 | def get_sigmas(sigma_min=0.01, sigma_max=0.10, t_clip=0.25):
29 | # geometric scheduling
30 | # roughly at 25%, max noise level, see fig 6
31 | r = np.exp(np.log(sigma_max / sigma_min) / t_clip) - 1 # y0*(1+r)**t = yt
32 |
33 | def fn(t):
34 | return sigma_min * (1 + r) ** t
35 | return fn
36 |
37 |
38 | class LightningDiffusion(pl.LightningModule):
39 | def __init__(self, config_path, discrete_sigmas):
40 | # config path is a json file, e.g.: https://github.com/crowsonkb/k-diffusion/blob/master/configs/config_32x32_small.json
41 | # may need to update arguments depending on resolution, e.g.,
42 | # "depths": [2, 2, 4, 4],
43 | # "channels": [128, 256, 256, 512],
44 | # "self_attn_depths": [false, false, true, true],
45 |
46 | # see graph_wasserstein.py to get discrete sigmas for your dataset
47 | super().__init__()
48 |
49 | # Diffusion model
50 | config = K.config.load_config(open(config_path))
51 | self.phi_theta = K.config.make_model(config)
52 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
53 |
54 | self.sigma = get_sigmas()
55 | self.dt = 0.001 # \delta_t
56 | self.C = blur(discrete_sigmas)
57 | # gaussian blur C is a linear operator, using the fixed blur in paper - ideally this would be a matrix (C_t * x)
58 |
59 | def training_step(self, batch, batch_idx):
60 | x0 = batch[0]
61 |         t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)  # Sample timesteps
62 |         residual = self.sigma(t).view(-1, 1, 1, 1) * torch.randn_like(x0)  # broadcast per-sample sigma over image dims
63 |         x = self.C(x0, t) + residual  # eq 4
64 |         noise_pred = self.phi_theta(x, t)  # (r_hat|x_t)
65 |         loss = (self.sigma(t).view(-1, 1, 1, 1) ** -2) * F.mse_loss(self.C(residual, t), self.C(noise_pred, t), reduction='none')  # Eq 12.
66 | return loss.mean()
67 |
68 | def configure_optimizers(self):
69 | opt = optim.Adam(self.phi_theta.parameters(), lr=2E-4, betas=(0.9, 0.999), eps=1E-8)
70 |
71 | def lr_sched(step):
72 | return min(1, step / 5_000) # depending on batch size, this should change
73 | sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_sched)
74 |
75 | opt_config = {"optimizer": opt}
76 | if sched is not None:
77 | opt_config["lr_scheduler"] = {
78 | "scheduler": sched,
79 | "interval": "step",
80 | }
81 | return opt_config
82 |
83 | def naive_sampler(self):
84 | x = self.sample_gauss()
85 | for t in torch.linspace(1, 0, int(1/self.dt)):
86 | x0_hat = self.phi_theta(x, t) + x
87 | eta = torch.randn_like(x)
88 | x = self.C(x0_hat, t - self.dt) + self.sigma(t - self.dt) * eta
89 | return x
90 |
91 | def momentum_sampler(self):
92 | x = self.sample_gauss()
93 | for t in torch.linspace(1, 0, int(1/self.dt)):
94 | x0_hat = self.phi_theta(x, t) + x
95 | y_hat = self.C(x0_hat, t)
96 | eta = torch.randn_like(x)
97 | eps_hat = y_hat - x
98 | z = x - ((self.sigma(t - self.dt) / self.sigma(t)) ** 2 - 1) * eps_hat + (self.sigma(t) ** 2 - self.sigma(t - self.dt) ** 2).sqrt() * eta
99 | y_hat_prev = self.C(x0_hat, t - self.dt)
100 | x = z + y_hat_prev - y_hat
101 | return x
102 |
103 | def sample_gauss(self):
104 |         # distribution p_1 not defined in the paper, this is my best guess
105 |         x = torch.randn((1, 3, 64, 64))  # gaussian, as the method name suggests
106 |         x = self.C(x, torch.tensor(1.))  # t = 1 maps to the maximum blur level (sigma ~ 23)
107 | return x
108 |
109 |
110 |
--------------------------------------------------------------------------------
/soft_diffusion/soft_diffusion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/diffusion-for-beginners/738029baca86a595f1adf62318b4b99f444ed360/soft_diffusion/soft_diffusion.jpg
--------------------------------------------------------------------------------
/soft_diffusion/wasserstein.py:
--------------------------------------------------------------------------------
1 | # https://gist.github.com/Flunzmas/6e359b118b0730ab403753dcc2a447df#file-calc_2_wasserstein_dist-py
2 | import math
3 | import torch
4 | import torch.linalg as linalg
5 |
6 |
7 | def calculate_2_wasserstein_dist(X, Y):
8 | '''
9 |     Calculates the two components of the 2-Wasserstein metric:
10 | The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]
11 |
12 | For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),
13 |     this reduces to: d = |mu_X - mu_Y|^2 + Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
14 |
15 | Fast method implemented according to following paper: https://arxiv.org/pdf/2009.14075.pdf
16 |
17 | Input shape: [b, n] (e.g. batch_size x num_features)
18 | Output shape: scalar
19 | '''
20 |
21 | if X.shape != Y.shape:
22 | raise ValueError("Expecting equal shapes for X and Y!")
23 |
24 | # the linear algebra ops will need some extra precision -> convert to double
25 | X, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double() # [n, b]
26 | mu_X, mu_Y = torch.mean(X, dim=1, keepdim=True), torch.mean(Y, dim=1, keepdim=True) # [n, 1]
27 | n, b = X.shape
28 | fact = 1.0 if b < 2 else 1.0 / (b - 1)
29 |
30 | # Cov. Matrix
31 | E_X = X - mu_X
32 | E_Y = Y - mu_Y
33 | cov_X = torch.matmul(E_X, E_X.t()) * fact # [n, n]
34 | cov_Y = torch.matmul(E_Y, E_Y.t()) * fact
35 |
36 | # calculate Tr((cov_X * cov_Y)^(1/2)). with the method proposed in https://arxiv.org/pdf/2009.14075.pdf
37 | # The eigenvalues for M are real-valued.
38 |     C_X = E_X * math.sqrt(fact) # [n, b], "root" of covariance
39 | C_Y = E_Y * math.sqrt(fact)
40 | M_l = torch.matmul(C_X.t(), C_Y)
41 | M_r = torch.matmul(C_Y.t(), C_X)
42 | M = torch.matmul(M_l, M_r)
43 | S = linalg.eigvals(M) + 1e-15 # add small constant to avoid infinite gradients from sqrt(0)
44 | sq_tr_cov = S.sqrt().abs().sum()
45 |
46 | # plug the sqrt_trace_component into Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
47 | trace_term = torch.trace(cov_X + cov_Y) - 2.0 * sq_tr_cov # scalar
48 |
49 | # |mu_X - mu_Y|^2
50 | diff = mu_X - mu_Y # [n, 1]
51 | mean_term = torch.sum(torch.mul(diff, diff)) # scalar
52 |
53 | # put it together
54 | return (trace_term + mean_term).float()
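# usage sketch: inputs must be [b, n], so flatten image batches first, e.g.
#   x, y = torch.randn(64, 3 * 64 * 64), torch.randn(64, 3 * 64 * 64)
#   d = calculate_2_wasserstein_dist(x, y)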
--------------------------------------------------------------------------------