├── data
│   └── .gitkeep
├── ckpts
│   └── .gitkeep
├── .gitignore
├── pyproject.toml
├── examples
│   ├── ex_2d.png
│   ├── ex_coins.png
│   ├── ex_mnist.png
│   ├── ex_2d_quiver.png
│   ├── ex_mnist_crop.png
│   ├── ex_coins.py
│   ├── ex_2d.py
│   ├── ex_mnist_simple.py
│   ├── ex_mnist.py
│   ├── ex_cifar.py
│   └── ex_moving_mnist.py
├── figures
│   ├── fig_tail.py
│   ├── fig_noisy_img.py
│   ├── fig_lambdas_discrete.py
│   ├── fig_karras_param.py
│   ├── fig_lambdas_continuous.py
│   ├── fig_schedules.py
│   └── fig_shifts.py
├── README.md
└── src
    ├── schedules.py
    ├── samplers.py
    ├── blocks_3d.py
    ├── score_matching.py
    ├── simple
    │   └── diffusion.py
    ├── experimental
    │   └── consistency.py
    └── blocks.py
/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ckpts/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | figures/*.png
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length=120
3 |
--------------------------------------------------------------------------------
/examples/ex_2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_2d.png
--------------------------------------------------------------------------------
/examples/ex_coins.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_coins.png
--------------------------------------------------------------------------------
/examples/ex_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_mnist.png
--------------------------------------------------------------------------------
/examples/ex_2d_quiver.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_2d_quiver.png
--------------------------------------------------------------------------------
/examples/ex_mnist_crop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_mnist_crop.png
--------------------------------------------------------------------------------
/figures/fig_tail.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import numpy as np
3 | import scipy as sp
4 | import scipy.stats
5 |
6 |
if __name__ == "__main__":

    # Plot an exponential density with its tail shaded (left), and its mirror image (right).
    num_points = 1000
    grid = np.linspace(0, 6, num_points)
    tail_grid = np.linspace(2.5, 6, num_points)
    grid_pdf = sp.stats.expon.pdf(grid)
    tail_pdf = sp.stats.expon.pdf(tail_grid)

    # Each panel: (curve x, curve y, shaded x, shaded y, legend label).
    panels = (
        (grid, grid_pdf, tail_grid, tail_pdf, "$f(x)$"),
        (-grid[::-1], grid_pdf[::-1], -tail_grid[::-1], tail_pdf[::-1], "$f(y)$"),
    )

    plt.figure(figsize=(7, 2.0), dpi=300)
    for panel_idx, (xs, ys, tail_xs, tail_ys, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(xs, ys, color="black", label=label)
        plt.fill_between(tail_xs, tail_ys, color="grey")
        plt.legend(loc="upper right")
        plt.yticks([])
        plt.xticks([])
        plt.ylim((0, 1))
    plt.tight_layout()
    plt.savefig("./figures/fig_tail.png")
    plt.show()
31 |
--------------------------------------------------------------------------------
/figures/fig_noisy_img.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | from skimage import data
4 |
5 |
def add_noise(img, scale):
    """
    Add i.i.d. standard-normal noise scaled by `scale` to `img`, then clip the result to [-1, 1].
    """
    noise = scale * np.random.randn(*img.shape)
    return np.clip(img + noise, -1, 1)
8 |
9 |
if __name__ == "__main__":

    plt.figure(figsize=(7, 2.6), dpi=300)

    # Rescale the 8-bit camera image from [0, 255] to [-1, 1].
    img = data.camera()
    img = 2 * (img.astype(np.float32) / 255 - 0.5)
    assert img.shape == (512, 512)

    # Downsample by striding, then add the same noise scale at every resolution.
    # All noise is drawn before plotting, in 1x, 2x, 4x order.
    resolutions = ((1, "512 x 512"), (2, "256 x 256"), (4, "128 x 128"))
    noisy_panels = [(add_noise(img[::stride, ::stride], 0.25), title) for stride, title in resolutions]

    for panel_idx, (noisy_img, title) in enumerate(noisy_panels, start=1):
        plt.subplot(1, 3, panel_idx)
        plt.imshow(noisy_img, cmap="gray")
        plt.title(title)
        plt.axis("off")
    plt.savefig("./figures/fig_noisy_img.png")
    plt.show()
40 |
--------------------------------------------------------------------------------
/figures/fig_lambdas_discrete.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 | from figures.fig_schedules import get_cosine_schedule
5 |
6 |
if __name__ == "__main__":

    num_steps = 15

    plt.figure(figsize=(7, 2.6), dpi=300)

    # Cosine schedule: log(lambda) = log(sigma_t) - log(alpha_t) at each discrete timestep.
    # At t = 0, sigma_t is 0, so that entry is -inf.
    alpha = get_cosine_schedule(num_steps)
    sigma = (1 - alpha ** 2) ** 0.5
    cosine_log_lambd = np.log(sigma) - np.log(alpha)

    # Score matching: noise levels evenly spaced in log(sigma) between 0.002 and 80.
    score_log_lambd = np.linspace(np.log(0.002), np.log(80.0), num_steps)

    families = (
        (cosine_log_lambd, "Cosine Schedule", "black"),
        (score_log_lambd, "Score Matching", "maroon"),
    )
    for log_lambd, name, color in families:
        for idx, one_lambd in enumerate(log_lambd):
            # Label only the first line of each family, so the legend has one entry per schedule.
            plt.axvline(one_lambd, color=color, linestyle="--", label=name if idx == 0 else None, alpha=0.5)

    plt.xlim((-7, 7))
    plt.yticks([])
    plt.xlabel("$\\log\\lambda$")
    plt.legend()
    plt.tight_layout()
    plt.tick_params(bottom=False, top=False)
    plt.savefig("./figures/fig_lambdas_discrete.png")
44 |
45 |
--------------------------------------------------------------------------------
/figures/fig_karras_param.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import numpy as np
3 |
4 |
def c_skip(x, sigma_data):
    """
    Skip-connection weight sigma_data^2 / (sigma^2 + sigma_data^2), where sigma = exp(x).
    """
    total_variance = np.exp(x) ** 2 + sigma_data ** 2
    return sigma_data ** 2 / total_variance
7 |
def c_out(x, sigma_data):
    """
    Output-scaling weight sigma * sigma_data / sqrt(sigma^2 + sigma_data^2), where sigma = exp(x).
    """
    sigma = np.exp(x)
    return sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5
10 |
11 |
if __name__ == "__main__":

    log_sigma = np.linspace(-8, 6, 1000)

    # One curve per sigma_data value: (sigma_data, color, legend label).
    curve_styles = (
        (1.0, "black", "$\\sigma_\\text{data}=1.0$"),
        (0.5, "red", "$\\sigma_\\text{data}=0.5$"),
        (0.25, "blue", "$\\sigma_\\text{data}=0.25$"),
    )
    # One panel per preconditioning coefficient: (function, panel title).
    panels = ((c_skip, "$c_\\text{skip}$"), (c_out, "$c_\\text{out}$"))

    plt.figure(figsize=(7, 2.6), dpi=300)
    for panel_idx, (coefficient_fn, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        for sigma_data, color, label in curve_styles:
            plt.plot(log_sigma, coefficient_fn(log_sigma, sigma_data), color=color, label=label)
        plt.legend(loc="upper left")
        plt.xlabel("$\\log\\sigma$")
        plt.title(title)
        plt.tick_params(bottom=False, top=False)
    plt.tight_layout()
    plt.savefig("./figures/fig_karras_param.png")
    plt.show()
37 |
38 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Diffusion Models From Scratch
2 |
3 | March 2023.
4 |
5 | ---
6 |
7 | My notes for this repository ended up longer than expected, too long to be rendered by GitHub.
8 |
9 | So instead of keeping the notes here, I've moved them to my website.
10 |
11 | [[**This blog post**]](https://www.tonyduan.com/diffusion/index.html) explains the intuition and derivations behind diffusion.
12 |
13 | ---
14 |
15 | This codebase provides a *minimalist* reproduction of the MNIST example below.
16 |
17 | It clocks in at well under 500 LOC.
18 |
19 |
20 |
21 |
22 |
23 | (Left: MNIST ground truth. Right: MNIST samples generated starting from random noise.)
24 |
25 | ---
26 |
27 | **Example Usage**
28 |
29 | Code below is copied from `examples/ex_mnist_simple.py`, omitting boilerplate training code.
30 |
31 | ```python
32 | # Initialization
33 | nn_module = UNet(in_dim=1, embed_dim=128, dim_scales=(1, 2, 4, 8))
34 | model = DiffusionModel(
35 | nn_module=nn_module,
36 | input_shape=(1, 32, 32,),
37 | config=DiffusionModelConfig(
38 | num_timesteps=500,
39 | target_type="pred_x_0",
40 | gamma_type="ddim",
41 | noise_schedule_type="cosine",
42 | ),
43 | )
44 |
45 | # Training Loop
46 | for i in range(args.iterations):
47 | loss = model.loss(x_batch).mean()
48 | loss.backward()
49 |
50 | # Sampling, the number of timesteps can be less than T to accelerate
51 | samples = model.sample(bsz=64, num_sampling_timesteps=None, device="cuda")
52 | ```
53 |
--------------------------------------------------------------------------------
/figures/fig_lambdas_continuous.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import numpy as np
3 | import scipy as sp
4 | import scipy.stats
5 |
6 |
def cosine_pdf(x, kappa):
    #
    # Density of log(sigma) under the (shifted) cosine schedule: a hyperbolic secant distribution
    # centred at -log(kappa). Truncation due to sigma_min, sigma_max is ignored.
    #
    shifted = x + np.log(kappa)
    return 1 / np.pi / np.cosh(shifted)
12 |
def normal_pdf(x, loc, scale, sigma_min, sigma_max):
    #
    # Normal density on log(sigma) with mean `loc` and std `scale`, truncated to
    # [log(sigma_min), log(sigma_max)].
    # scipy's truncnorm expects its bounds a, b in standardized units, (bound - loc) / scale,
    # not in the same units as x.
    #
    a = (np.log(sigma_min) - loc) / scale
    b = (np.log(sigma_max) - loc) / scale
    return sp.stats.truncnorm.pdf(x, a=a, b=b, loc=loc, scale=scale)
18 |
def exp_pdf(x, sigma_min, sigma_max, rho):
    #
    # Density of log(sigma) when sigma ** (1 / rho) is uniform on [sigma_min ** (1/rho), sigma_max ** (1/rho)].
    # The partition constant renormalizes the exponential for the truncation to [sigma_min, sigma_max].
    #
    log_sigma_max = np.log(sigma_max)
    log_range = log_sigma_max - np.log(sigma_min)
    partition_constant = 1 - np.exp(-log_range / rho)
    unnormalized = 1 / rho * np.exp((x - log_sigma_max) / rho)
    return unnormalized / partition_constant
25 |
26 |
if __name__ == "__main__":

    log_sigma = np.linspace(-8, 6, 1000)

    plt.figure(figsize=(7, 2.6), dpi=300)

    # (density over log sigma, color, legend label), plotted in this order.
    curves = (
        (cosine_pdf(log_sigma, 1.0), "black", "Cosine($\\kappa=1.0$)"),
        (cosine_pdf(log_sigma, 0.5), "red", "Cosine($\\kappa=0.5$)"),
        (normal_pdf(log_sigma, -1.2, 1.2, 0.002, 80), "green", "N($\\mu=-1.2, \\gamma=1.2$)"),
        (exp_pdf(log_sigma, 0.002, 80, 7), "blue", "Exp($\\rho=7$)"),
    )
    for density, color, label in curves:
        plt.plot(log_sigma, density, color=color, label=label)

    plt.legend()
    plt.yticks([])
    plt.xlabel("$\\log\\sigma$")
    plt.tight_layout()
    plt.tick_params(bottom=False, top=False)
    plt.savefig("./figures/fig_lambdas_continuous.png")
    plt.show()
44 |
45 |
--------------------------------------------------------------------------------
/figures/fig_schedules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from matplotlib import pyplot as plt
4 |
5 |
def get_linear_schedule(num_timesteps):
    """
    Return alpha_t = sqrt(prod_{s <= t} (1 - beta_s)) for betas linearly spaced from 1e-4 to 2e-2.

    The result has num_timesteps + 1 entries (t = 0 .. num_timesteps).
    """
    betas = np.linspace(1e-4, 2e-2, num_timesteps + 1)
    bar_alpha = np.cumprod(1 - betas, axis=0)
    return bar_alpha ** 0.5
10 |
def get_cosine_schedule(num_timesteps):
    """
    Return alpha_t for the cosine schedule with offset s = 0.008, over num_timesteps + 1 points.

    bar_alpha_t = f(t) / f(0) with f(t) = cos((t + s) / (1 + s) * pi / 2) ** 2. Per-step betas are
    recovered as 1 - bar_alpha_t / bar_alpha_{t-1} and clipped to [0, 0.999]. Then
    alpha_t = sqrt(cumprod(1 - beta)), with beta_0 = 0, so alpha_0 = 1.
    """
    t = np.linspace(0, 1, num_timesteps + 1)
    offset = 0.008
    f_t = np.cos((t + offset) / (1 + offset) * math.pi / 2) ** 2
    bar_alpha_t = f_t / f_t[0]

    beta_t = np.zeros_like(bar_alpha_t)
    step_ratio = bar_alpha_t[1:] / bar_alpha_t[:-1]
    beta_t[1:] = np.clip(1 - step_ratio, a_min=0, a_max=0.999)

    return np.cumprod(1 - beta_t, axis=0) ** 0.5
19 |
20 |
if __name__ == "__main__":

    num_timesteps = 500
    timesteps = np.arange(num_timesteps + 1)

    # For each schedule: (alpha_t, sigma_t, log SNR_t) with sigma_t = sqrt(1 - alpha_t^2).
    curves = {}
    for name, schedule_fn in (("Linear", get_linear_schedule), ("Cosine", get_cosine_schedule)):
        alpha = schedule_fn(num_timesteps)
        sigma = (1 - alpha ** 2) ** 0.5
        curves[name] = (alpha, sigma, 2 * (np.log(alpha) - np.log(sigma)))

    colors = {"Linear": "#00204E", "Cosine": "#800000"}
    # Panel i plots curves[name][i]: (y label, y ticks, y limits or None).
    panels = (
        ("$\\alpha_t$", [0, 1], None),
        ("$\\sigma_t$", [0, 1], None),
        ("$\\log SNR_t$", [-10, 10], [-11, 11]),
    )

    plt.figure(figsize=(7, 2.6), dpi=200)
    for panel_idx, (ylabel, yticks, ylim) in enumerate(panels):
        plt.subplot(1, 3, panel_idx + 1)
        for name, color in colors.items():
            plt.plot(timesteps, curves[name][panel_idx], label=name, color=color)
        plt.ylabel(ylabel)
        plt.xlabel("Timestep $t$")
        plt.xticks([])
        plt.yticks(yticks)
        if ylim is not None:
            plt.ylim(ylim)
        plt.legend(loc="upper right")

    plt.tight_layout()
    plt.savefig("./figures/fig_schedules.png")
67 |
68 |
--------------------------------------------------------------------------------
/src/schedules.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from abc import ABC, abstractmethod
3 | from functools import cached_property
4 | import math
5 |
6 | import torch
7 |
8 |
@dataclass(frozen=True)
class NoiseSchedule(ABC):
    """
    Distribution over noise levels sigma, exposed through its inverse CDF.

    Concrete schedules store their bounds (sigma_min, sigma_max) and shape parameters as dataclass
    fields, so get_sigma_ppf only takes the probabilities to transform.
    """

    @abstractmethod
    def get_sigma_ppf(self, p: torch.Tensor) -> torch.Tensor:
        """
        Transform p in [0, 1] to sigma between [sigma_min, sigma_max] via icdf.
        """
        pass
18 |
19 |
@dataclass(frozen=True)
class CosineSchedule(NoiseSchedule):
    """
    Shifted cosine schedule: sigma = tan(theta) / kappa with theta uniform on [theta_min, theta_max].

    log(sigma) then follows a hyperbolic secant distribution centred at -log(kappa), truncated to
    [log(sigma_min), log(sigma_max)]. kappa = 1 recovers the unshifted cosine schedule.
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    kappa: float = 1.0

    @cached_property
    def theta_min(self):
        # Invert sigma = tan(theta) / kappa at sigma_min, so that get_sigma_ppf(0) == sigma_min.
        return math.atan(self.kappa * self.sigma_min)

    @cached_property
    def theta_max(self):
        # Same inversion at sigma_max, so that get_sigma_ppf(1) == sigma_max.
        return math.atan(self.kappa * self.sigma_max)

    def get_sigma_ppf(self, p: torch.Tensor):
        return torch.tan(self.theta_min + p * (self.theta_max - self.theta_min)) / self.kappa
39 |
40 |
@dataclass(frozen=True)
class ExponentialSchedule(NoiseSchedule):
    """
    Karras et al. (EDM) rho-schedule: interpolate linearly in sigma ** (1 / rho), then raise to rho.

    get_sigma_ppf(0) == sigma_min and get_sigma_ppf(1) == sigma_max.
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0

    @cached_property
    def inv_sigma_min(self):
        # sigma_min mapped into the space where interpolation is linear.
        return self.sigma_min ** (1 / self.rho)

    @cached_property
    def inv_sigma_max(self):
        return self.sigma_max ** (1 / self.rho)

    def get_sigma_ppf(self, p: torch.Tensor):
        lo, hi = self.inv_sigma_min, self.inv_sigma_max
        return (lo + p * (hi - lo)) ** self.rho
58 |
59 |
@dataclass(frozen=True)
class NormalSchedule(NoiseSchedule):
    """
    log(sigma) ~ Normal(mean, std ** 2), truncated to [log(sigma_min), log(sigma_max)].

    Sampling uses the inverse CDF in erf space. With u = (log(sigma) - mean) / (std * sqrt(2)), the
    normal CDF is (1 + erf(u)) / 2. Mapping p linearly between erf(u_min) and erf(u_max) and applying
    erfinv therefore draws from the truncated distribution, with p = 0 -> sigma_min and p = 1 -> sigma_max.
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    mean: float = -1.2
    std: float = 1.2

    @cached_property
    def erf_sigma_min(self):
        # erf of the lower bound, standardized to the unit-variance, erf-scaled coordinate u.
        return math.erf((math.log(self.sigma_min) - self.mean) / (self.std * math.sqrt(2)))

    @cached_property
    def erf_sigma_max(self):
        # erf of the standardized upper bound.
        return math.erf((math.log(self.sigma_max) - self.mean) / (self.std * math.sqrt(2)))

    def get_sigma_ppf(self, p: torch.Tensor):
        z = torch.special.erfinv(self.erf_sigma_min + p * (self.erf_sigma_max - self.erf_sigma_min))
        # Undo the standardization: log(sigma) = mean + std * sqrt(2) * u.
        return torch.exp(self.mean + self.std * math.sqrt(2) * z)
79 |
80 |
--------------------------------------------------------------------------------
/examples/ex_coins.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | import torch.optim as optim
8 |
9 | from src.blocks import FFN
10 | from src.simple.diffusion import DiffusionModel, DiffusionModelConfig
11 |
12 |
def gen_data(n=512, d=48):
    """
    Draw n rows of d coin flips where exactly 3 * d // 4 coins per row are heads.

    Each row is a random permutation of 3 * d // 4 ones and the remaining d - 3 * d // 4 zeros, mapped
    to {-1, +1}. Returns a float32 array of shape (n, d).
    """
    num_heads = 3 * d // 4
    # Use d - num_heads rather than d // 4, which under-counts when d is not a multiple of 4.
    num_tails = d - num_heads
    if n == 0:
        # np.vstack rejects an empty list.
        return np.zeros((0, d), dtype=np.float32)
    row_template = np.r_[np.ones(num_heads), np.zeros(num_tails)]
    x = np.vstack([np.random.permutation(row_template) for _ in range(n)])
    assert np.all(np.sum(x, axis=1) == num_heads)
    x = x * 2 - 1
    return x.astype(np.float32)
21 |
22 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--n", default=512, type=int)
    argparser.add_argument("--iterations", default=2000, type=int)
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Coins per example; shared by the data generator, the network, and the prior baseline.
    dim = 48

    nn_module = FFN(in_dim=dim, embed_dim=96)
    model = DiffusionModel(
        nn_module=nn_module,
        input_shape=(dim,),
        config=DiffusionModelConfig(
            num_timesteps=100,
            target_type="pred_x_0",
            gamma_type="ddim",
            noise_schedule_type="cosine",
        ),
    )

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)

    for i in range(args.iterations):
        x = torch.from_numpy(gen_data(args.n, d=dim))
        optimizer.zero_grad()
        loss = model.loss(x).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 100 == 0:
            logger.info(f"Iter: {i}\t" + f"Loss: {loss.data:.2f}\t")

    model.eval()
    samples = model.sample(bsz=512, device="cpu")
    # NOTE(review): samples[0] is presumably the fully denoised output; confirm against DiffusionModel.sample.
    # Coins decoded as > 0 count as heads.
    num_heads = (samples[0] > 0).sum(dim=1).cpu().numpy()

    plt.figure(figsize=(8, 3))
    plt.hist(num_heads, range=(10, 50), bins=20, alpha=0.5, color="grey", label="Samples")
    logger.info(f"Samples mean {np.mean(num_heads):.2f} sd {np.std(num_heads):.2f}")

    # Baseline: independent fair coins.
    random_sample = np.random.rand(args.n, dim) > 0.5
    num_heads = random_sample.sum(axis=1)

    plt.hist(num_heads, range=(10, 50), bins=20, alpha=0.5, color="black", label="Prior")
    logger.info(f"Prior mean {np.mean(num_heads):.2f} sd {np.std(num_heads):.2f}")

    # The histogram labels are only rendered once a legend is drawn.
    plt.legend()
    plt.tight_layout()
    plt.savefig("examples/ex_coins.png")
74 |
--------------------------------------------------------------------------------
/figures/fig_shifts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import expit
3 | from matplotlib import pyplot as plt
4 |
5 | from figures.fig_schedules import get_cosine_schedule
6 |
7 |
def draw_lines(logsnr, color="black", stride=10):
    """
    Draw rays from the origin to (alpha_t, sigma_t) for every `stride`-th timestep.

    alpha_t = sqrt(sigmoid(logsnr)) and sigma_t = sqrt(sigmoid(-logsnr)), so every ray has unit length.
    """
    alpha = expit(logsnr) ** 0.5
    sigma = expit(-logsnr) ** 0.5
    for endpoint in zip(alpha[::stride], sigma[::stride]):
        plt.plot([0, endpoint[0]], [0, endpoint[1]], color=color, alpha=0.5)
13 |
14 |
if __name__ == "__main__":

    T = 201
    T_rng = np.linspace(0, 1, T + 1)

    alpha = get_cosine_schedule(T)
    sigma = (1 - alpha ** 2) ** 0.5
    cosine_logsnr = 2 * (np.log(alpha) - np.log(sigma))

    # Adding 2 * log(s) to log SNR scales SNR by s ** 2.
    # s = 1/2 lowers SNR (shifted down); s = 2 raises it (shifted up).
    # The base curve is the cosine schedule, so it is labelled "Cosine".
    schedules = (
        ("Cosine", cosine_logsnr, "black"),
        ("Shifted down", cosine_logsnr + 2 * np.log(0.5), "navy"),
        ("Shifted up", cosine_logsnr + 2 * np.log(2), "maroon"),
    )

    plt.figure(figsize=(7, 5.2), dpi=200)

    # Top row: log SNR, alpha_t = sqrt(sigmoid(logsnr)) and sigma_t = sqrt(sigmoid(-logsnr)) over time.
    top_row = (
        ("$\\log SNR_t$", lambda logsnr: logsnr, [-10, 10], [-11, 11]),
        ("$\\alpha_t$", lambda logsnr: expit(logsnr) ** 0.5, [0, 1], None),
        ("$\\sigma_t$", lambda logsnr: expit(-logsnr) ** 0.5, [0, 1], None),
    )
    for col, (ylabel, transform, yticks, ylim) in enumerate(top_row, start=1):
        plt.subplot(2, 3, col)
        for label, logsnr, color in schedules:
            plt.plot(T_rng, transform(logsnr), label=label, color=color)
        plt.ylabel(ylabel)
        plt.xlabel("Timestep $t$")
        plt.xticks([])
        plt.yticks(yticks)
        if ylim is not None:
            plt.ylim(ylim)
        plt.legend(loc="lower left")

    # Bottom row: the (alpha_t, sigma_t) rays for each schedule.
    titles = (
        "Cosine Schedule",
        "Shifted down $+2\\log\\left(\\frac{1}{2}\\right)$",
        "Shifted up $+2\\log(2)$",
    )
    for col, ((_, logsnr, color), title) in enumerate(zip(schedules, titles), start=4):
        plt.subplot(2, 3, col)
        draw_lines(logsnr, color=color)
        plt.xticks([])
        plt.yticks([])
        plt.xlabel("$x_t$")
        plt.ylabel("$\\epsilon_t$")
        plt.title(title)

    plt.tight_layout()
    plt.savefig("./figures/fig_shifts.png")
86 |
87 |
--------------------------------------------------------------------------------
/examples/ex_2d.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | import torch.optim as optim
8 |
9 | from src.blocks import FFN
10 | from src.simple.diffusion import DiffusionModel, DiffusionModelConfig
11 |
12 |
def gen_mixture_data(n=512):
    """
    Sample n points from an equal mixture of two unit-variance Gaussians centred at (-5, 0) and (5, 0).

    Each coordinate is standardized to zero mean and unit sample std (ddof=1). Returns a float32 array
    of shape (n, 2). When n is odd, the extra point goes to the right component, so exactly n rows are
    returned.
    """
    n_left = n // 2
    n_right = n - n_left
    x = np.r_[np.random.randn(n_left, 2) + np.array([-5, 0]),
              np.random.randn(n_right, 2) + np.array([5, 0])]
    x = (x - np.mean(x, axis=0, keepdims=True)) / np.std(x, axis=0, ddof=1, keepdims=True)
    return x.astype(np.float32)
18 |
def plot_data(x):
    """
    Draw a 100-bin 2D histogram of the first two columns of x on fixed axes [-3, 3] x [-6, 6].

    Expects a torch tensor (columns are converted with .numpy()). Axis decorations are hidden.
    """
    xs = x[:, 0].numpy()
    ys = x[:, 1].numpy()
    plt.hist2d(xs, ys, bins=100, range=np.array([(-3, 3), (-6, 6)]))
    plt.axis("off")
22 |
23 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--n", default=512, type=int)
    argparser.add_argument("--iterations", default=2000, type=int)
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    nn_module = FFN(in_dim=2, embed_dim=16)
    model = DiffusionModel(
        nn_module=nn_module,
        input_shape=(2,),
        config=DiffusionModelConfig(
            num_timesteps=10,
            noise_schedule_type="linear",
            target_type="pred_eps",
            gamma_type="ddpm",
        ),
    )

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)

    # Train on a fresh mixture batch each iteration; `x` keeps the last batch for plotting.
    for i in range(args.iterations):
        x = torch.from_numpy(gen_mixture_data(args.n))
        optimizer.zero_grad()
        loss = model.loss(x).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 100 == 0:
            logger.info(f"Iter: {i}\t" + f"Loss: {loss.data:.2f}\t")

    model.eval()
    samples = model.sample(bsz=512, device="cpu")

    # Panel 1 shows the last training batch; panels 2-12 show the full-length sampling trajectory.
    plt.figure(figsize=(12, 12))
    panels = [(x, "Actual data")] + [(samples[t], f"Sample t={t}") for t in range(11)]
    for slot, (points, title) in enumerate(panels, start=1):
        plt.subplot(4, 4, slot)
        plot_data(points)
        plt.title(title)

    # Panels 13-16 show an accelerated trajectory with 3 sampling timesteps.
    fast_samples = model.sample(bsz=512, device="cpu", num_sampling_timesteps=3)
    for t in range(4):
        plt.subplot(4, 4, 13 + t)
        plot_data(fast_samples[t])
        plt.title(f"Accelerated Sample t={t}")

    plt.tight_layout()
    plt.savefig("./examples/ex_2d.png")

    # Evaluate the network on a 30 x 20 grid and draw the negated eps prediction (titled "Score") per timestep.
    x_grid, y_grid = torch.meshgrid(torch.linspace(-3, 3, 30), torch.linspace(-6, 6, 20), indexing="ij")
    grid = torch.from_numpy(np.c_[x_grid.flatten(), y_grid.flatten()])
    num_points = grid.shape[0]

    fig, axes = plt.subplots(3, 4, figsize=(12, 9))
    axes[0, 0].axis("off")
    flat_axes = axes.ravel()  # Row-major, so flat index k is axes[k // 4, k % 4].

    for scalar_t in range(model.num_timesteps + 1):
        t = torch.full((num_points,), fill_value=scalar_t, device=grid.device)
        with torch.no_grad():
            pred_eps = model.nn_module(grid, t)
        pred_eps = pred_eps.numpy().reshape(30, 20, 2)

        ax = flat_axes[scalar_t + 1]
        ax.quiver(x_grid.numpy(), y_grid.numpy(), -pred_eps[..., 0], -pred_eps[..., 1])
        ax.axis("off")
        ax.set_title(f"Score {scalar_t}")

    fig.tight_layout()
    fig.savefig("./examples/ex_2d_quiver.png")
103 |
104 |
--------------------------------------------------------------------------------
/examples/ex_mnist_simple.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | from einops import rearrange
5 | from matplotlib import pyplot as plt
6 | import numpy as np
7 | from sklearn.datasets import fetch_openml
8 | import torch
9 | import torch.optim as optim
10 |
11 | from src.blocks import UNet
12 | from src.simple.diffusion import DiffusionModel, DiffusionModelConfig
13 |
14 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--iterations", default=2000, type=int)
    argparser.add_argument("--batch-size", default=512, type=int)
    argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps"))
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Load data from https://www.openml.org/d/554: (70000, 784) pixel values between 0-255.
    x, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False, cache=True)

    # Zero-pad each 28x28 digit to 32x32, then flatten back to rows of 32 ** 2 pixels.
    x = np.pad(rearrange(x, "b (h w) -> b h w", h=28, w=28), pad_width=((0, 0), (2, 2), (2, 2)))
    x = rearrange(x, "b h w -> b (h w)")

    # Map [0, 255] to [-1, 1].
    input_mean = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32)
    input_sd = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32)
    x = ((x - input_mean) / input_sd).astype(np.float32)

    nn_module = UNet(1, 128, (1, 2, 4, 8))
    model = DiffusionModel(
        nn_module=nn_module,
        input_shape=(1, 32, 32,),
        config=DiffusionModelConfig(
            num_timesteps=500,
            target_type="pred_x_0",
            gamma_type="ddim",
            noise_schedule_type="cosine",
        ),
    ).to(args.device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)

    for step in range(args.iterations):
        # Minibatch sampled with replacement, reshaped to (batch, 1, 32, 32).
        batch_idx = np.random.choice(len(x), args.batch_size)
        x_batch = torch.from_numpy(x[batch_idx]).to(args.device)
        x_batch = rearrange(x_batch, "b (h w) -> b () h w", h=32, w=32)
        optimizer.zero_grad()
        loss = model.loss(x_batch).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if step % 100 == 0:
            logger.info(f"Iter: {step}\t" + f"Loss: {loss.data:.2f}\t")

    model.eval()

    # Undo the [-1, 1] standardization for the sampled trajectory and for the ground-truth digits.
    samples = model.sample(bsz=64, num_sampling_timesteps=None, device=args.device).cpu().numpy()
    samples = rearrange(samples, "t b () h w -> t b (h w)")
    samples = samples * input_sd + input_mean
    x_vis = x[:64] * input_sd + input_mean

    nrows, ncols = 10, 2
    percents = (100, 75, 50, 25, 0)
    raster = np.zeros((nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.float32)

    def _paste(panel_idx, images):
        # Tile the first nrows * ncols images into column-group `panel_idx` of the raster.
        offset = 32 * ncols * panel_idx
        for k in range(nrows * ncols):
            row, col = divmod(k, ncols)
            raster[32 * row : 32 * (row + 1), offset + 32 * col : offset + 32 * (col + 1)] = images[k].reshape(32, 32)

    # Panel 0 is ground truth; each following panel is the trajectory entry at the given percent.
    _paste(0, x_vis)
    for panel_idx, percent in enumerate(percents, start=1):
        itr_num = int(round(0.01 * percent * (len(samples) - 1)))
        _paste(panel_idx, samples[itr_num])

    plt.imsave("./examples/ex_mnist.png", raster, vmin=0, vmax=255)
90 |
--------------------------------------------------------------------------------
/src/samplers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections import deque
3 | from dataclasses import dataclass
4 | from functools import cached_property
5 |
6 | import torch
7 |
8 |
@dataclass
class InstantaneousPrediction:
    """
    Denoiser output at a single noise level.

    Holds the noisy input x_t at noise level sigma together with the predicted clean input pred_x_0.
    The implied noise estimate is derived lazily from those two values.
    """

    sigma: float
    x_t: torch.Tensor
    pred_x_0: torch.Tensor

    @cached_property
    def pred_eps(self):
        # Solve x_t = pred_x_0 + sigma * eps for eps; computed once, then cached on the instance.
        residual = self.x_t - self.pred_x_0
        return residual / self.sigma
19 |
20 |
class Sampler(ABC):
    """Interface for samplers that move x_t from one noise level to the next."""

    @abstractmethod
    def reset(self):
        """Clear any state a sampler carries between steps."""
        ...

    @abstractmethod
    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """
        Take a step from sigma_t to sigma_t_plus_1.
        """
        ...
33 |
34 |
class EulerSampler(Sampler):
    """First-order sampler: x_{t+1} = x_t + pred_eps * (sigma_{t+1} - sigma_t)."""

    def reset(self):
        # Stateless, so there is nothing to clear.
        return None

    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        d_sigma = sigma_t_plus_1 - sigma_t
        return pred_at_x_t.x_t + pred_at_x_t.pred_eps * d_sigma
42 |
43 |
class MultistepDPMSampler(Sampler):
    """
    Multistep solver for dx/dsigma = pred_eps that reuses pred_eps from up to two previous steps.

    Each step integrates, over [sigma_t, sigma_t_plus_1], the polynomial in sigma that interpolates the
    current and stored pred_eps values: degree 0 with no history, degree 1 with one past prediction,
    degree 2 with two. On a uniform sigma grid these are the Adams-Bashforth methods of order 1, 2 and 3.
    """

    def __init__(self):
        # Start with an empty history, so step() works even before an explicit reset().
        self.reset()

    def reset(self):
        """Forget stored predictions (use before sampling a new trajectory)."""
        self.history: deque[InstantaneousPrediction] = deque(maxlen=2)  # Order k=3

    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """
        Take a step from sigma_t to sigma_t_plus_1 at the highest order the history allows, then store
        pred_at_x_t for later steps.
        """
        if len(self.history) == 0:
            x_t_plus_1 = self.step_first_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        elif len(self.history) == 1:
            x_t_plus_1 = self.step_second_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        else:
            x_t_plus_1 = self.step_third_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        self.history.append(pred_at_x_t)
        return x_t_plus_1

    def step_first_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """Euler step: x_t + pred_eps * (sigma_t_plus_1 - sigma_t)."""
        d_sigma = sigma_t_plus_1 - sigma_t
        return pred_at_x_t.x_t + pred_at_x_t.pred_eps * d_sigma

    def step_second_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """Integrate the line through the current and previous pred_eps (exact when eps is linear in sigma)."""
        pred_at_x_t_minus_1 = self.history[-1]
        d_sigma = sigma_t_plus_1 - sigma_t
        pred_first_derivative = (
            (pred_at_x_t.pred_eps - pred_at_x_t_minus_1.pred_eps) /
            (pred_at_x_t.sigma - pred_at_x_t_minus_1.sigma)
        )
        return (
            pred_at_x_t.x_t +
            pred_at_x_t.pred_eps * d_sigma +
            (1/2) * pred_first_derivative * d_sigma ** 2
        )

    def step_third_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """
        Integrate the quadratic through the current and two previous pred_eps (exact when eps is quadratic).

        A backward difference of two predictions estimates the slope at the midpoint of its interval,
        not at sigma_t. The most recent slope is therefore shifted to sigma_t using the curvature estimate
        before it enters the Taylor expansion. Without that shift, the step is only second-order accurate.
        """
        pred_at_x_t_minus_1 = self.history[-1]
        pred_at_x_t_minus_2 = self.history[-2]
        d_sigma = sigma_t_plus_1 - sigma_t

        h_recent = pred_at_x_t.sigma - pred_at_x_t_minus_1.sigma
        h_past = pred_at_x_t_minus_1.sigma - pred_at_x_t_minus_2.sigma
        slope_recent = (pred_at_x_t.pred_eps - pred_at_x_t_minus_1.pred_eps) / h_recent
        slope_past = (pred_at_x_t_minus_1.pred_eps - pred_at_x_t_minus_2.pred_eps) / h_past

        # The two slopes sit at interval midpoints that are (h_recent + h_past) / 2 apart.
        pred_second_derivative = (slope_recent - slope_past) / (0.5 * (h_recent + h_past))
        # Move the slope from the midpoint of [sigma_t_minus_1, sigma_t] to sigma_t itself.
        pred_first_derivative = slope_recent + 0.5 * h_recent * pred_second_derivative

        return (
            pred_at_x_t.x_t +
            pred_at_x_t.pred_eps * d_sigma +
            (1/2) * pred_first_derivative * d_sigma ** 2 +
            (1/6) * pred_second_derivative * d_sigma ** 3
        )
99 |
--------------------------------------------------------------------------------
/examples/ex_mnist.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | from einops import rearrange
5 | from matplotlib import pyplot as plt
6 | import numpy as np
7 | from sklearn.datasets import fetch_openml
8 | import torch
9 | import torch.optim as optim
10 |
11 | from src.blocks import UNet
12 | from src.score_matching import ScoreMatchingModel, ScoreMatchingModelConfig
13 |
14 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--iterations", default=2000, type=int)
    argparser.add_argument("--batch-size", default=512, type=int)
    argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps"))
    argparser.add_argument("--load-trained", default=0, type=int, choices=(0, 1))
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Load data from https://www.openml.org/d/554
    # (70000, 784) values between 0-255
    x, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False, cache=True)

    # Zero-pad the 28x28 digits to 32x32, then flatten back to (70000, 1024)
    x = rearrange(x, "b (h w) -> b h w", h=28, w=28)
    x = np.pad(x, pad_width=((0, 0), (2, 2), (2, 2)))
    x = rearrange(x, "b h w -> b (h w)")

    # Standardize to [-1, 1]
    input_mean = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32)
    input_sd = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32)
    x = ((x - input_mean) / input_sd).astype(np.float32)

    nn_module = UNet(1, 128, (1, 2, 4, 8))
    model = ScoreMatchingModel(
        nn_module=nn_module,
        input_shape=(1, 32, 32,),
        config=ScoreMatchingModelConfig(
            sigma_min=0.002,
            sigma_max=80.0,
            sigma_data=1.0,
        ),
    )
    model = model.to(args.device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)

    if args.load_trained:
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded onto --device (e.g. CPU).
        model.load_state_dict(torch.load("./ckpts/mnist_trained.pt", map_location=args.device))
    else:
        for step_num in range(args.iterations):
            # Uniformly sampled minibatch (with replacement) of flattened images, reshaped to (b, 1, 32, 32)
            x_batch = x[np.random.choice(len(x), args.batch_size)]
            x_batch = torch.from_numpy(x_batch).to(args.device)
            x_batch = rearrange(x_batch, "b (h w) -> b () h w", h=32, w=32)
            optimizer.zero_grad()
            loss = model.loss(x_batch).mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if step_num % 100 == 0:
                logger.info("Iter: %d\tLoss: %.2f\t", step_num, loss.item())
        torch.save(model.state_dict(), "./ckpts/mnist_trained.pt")

    model.eval()

    # samples: (num_sampling_timesteps + 1, 64, 1, 32, 32); index 0 is the final sample, the last is the noise.
    samples = model.sample(bsz=64, num_sampling_timesteps=20, device=args.device).cpu().numpy()
    samples = rearrange(samples, "t b () h w -> t b (h w)")
    samples = samples * input_sd + input_mean
    x_vis = x[:64] * input_sd + input_mean

    # Layout: the first block of columns shows real digits; each following block shows the trajectory at
    # index round(p% * (len(samples) - 1)) for p in `percents` (100 = initial noise, 0 = final sample).
    nrows, ncols = 10, 2
    percents = (100, 75, 50, 25, 0)
    raster = np.zeros((nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.float32)

    for i in range(nrows * ncols):
        row, col = i // ncols, i % ncols
        raster[32 * row : 32 * (row + 1), 32 * col : 32 * (col + 1)] = x_vis[i].reshape(32, 32)
    for percent_idx, percent in enumerate(percents):
        itr_num = int(round(0.01 * percent * (len(samples) - 1)))
        for i in range(nrows * ncols):
            row, col = i // ncols, i % ncols
            offset = 32 * ncols * (percent_idx + 1)
            raster[32 * row : 32 * (row + 1), offset + 32 * col : offset + 32 * (col + 1)] = samples[itr_num][i].reshape(32, 32)

    plt.imsave("./examples/ex_mnist.png", raster, vmin=0, vmax=255)
94 |
--------------------------------------------------------------------------------
/examples/ex_cifar.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | from einops import rearrange
5 | from matplotlib import pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | from torchvision import datasets, transforms
11 |
12 | from src.blocks import UNet
13 | from src.score_matching import ScoreMatchingModel, ScoreMatchingModelConfig
14 |
15 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--iterations", default=2000, type=int)
    argparser.add_argument("--batch-size", default=256, type=int)
    argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps"))
    argparser.add_argument("--load-trained", default=0, type=int, choices=(0, 1))
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    nn_module = UNet(3, 128, (1, 2, 4, 8))
    model = ScoreMatchingModel(
        nn_module=nn_module,
        input_shape=(3, 32, 32,),
        config=ScoreMatchingModelConfig(
            sigma_min=0.002,
            sigma_max=80.0,
            sigma_data=1.0,
        ),
    )
    model = model.to(args.device)

    # Standardize to [-1, +1]
    dataset_mean, dataset_sd = np.asarray([0.5, 0.5, 0.5]), np.asarray([0.5, 0.5, 0.5])
    dataset = datasets.CIFAR10(
        "./data/cifar_10",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(dataset_mean, dataset_sd),
        ])
    )

    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    if args.load_trained:
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded onto --device (e.g. CPU).
        model.load_state_dict(torch.load("./ckpts/cifar_trained.pt", map_location=args.device))
    else:
        # Cycle through the loader, starting a fresh (reshuffled) pass whenever it is exhausted.
        iterator = iter(dataloader)
        for step_num in range(args.iterations):
            try:
                x_batch, _ = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                x_batch, _ = next(iterator)
            x_batch = x_batch.to(args.device)
            optimizer.zero_grad()
            loss = model.loss(x_batch).mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if step_num % 100 == 0:
                logger.info("Iter: %d\tLoss: %.2f\t", step_num, loss.item())
        torch.save(model.state_dict(), "./ckpts/cifar_trained.pt")

    model.eval()

    # samples: (num_sampling_timesteps + 1, 64, 3, 32, 32); undo the per-channel standardization, then quantize.
    samples = model.sample(bsz=64, num_sampling_timesteps=20, device=args.device).cpu().numpy()
    samples = samples * rearrange(dataset_sd, "c -> 1 1 c 1 1") + rearrange(dataset_mean, "c -> 1 1 c 1 1")
    samples = np.rint(samples * 255).clip(min=0, max=255).astype(np.uint8)

    # Raw uint8 training images; the pattern assumes dataset.data is (N, H, W, C).
    x_vis = rearrange(dataset.data[:64], "b h w c -> b c h w")

    # Layout: the first block of columns shows real images; each following block shows the trajectory at
    # index round(p% * (len(samples) - 1)) for p in `percents` (100 = initial noise, 0 = final sample).
    nrows, ncols = 10, 2
    percents = (100, 75, 50, 25, 0)
    raster = np.zeros((3, nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.uint8)

    for i in range(nrows * ncols):
        row, col = i // ncols, i % ncols
        raster[:, 32 * row : 32 * (row + 1), 32 * col : 32 * (col + 1)] = x_vis[i]
    for percent_idx, percent in enumerate(percents):
        itr_num = int(round(0.01 * percent * (len(samples) - 1)))
        for i in range(nrows * ncols):
            row, col = i // ncols, i % ncols
            offset = 32 * ncols * (percent_idx + 1)
            raster[:, 32 * row : 32 * (row + 1), offset + 32 * col : offset + 32 * (col + 1)] = samples[itr_num][i]

    raster = rearrange(raster, "c h w -> h w c")
    plt.imsave("./examples/ex_cifar.png", raster, vmin=0, vmax=255)
101 |
--------------------------------------------------------------------------------
/src/blocks_3d.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from src.blocks import unsqueeze_as, PositionalEmbedding, SelfAttention2d
6 |
7 |
class BasicBlock3d(nn.Module):
    """
    Residual block with two 3x3x3 convolutions, batch norm and an additive time-conditioned shift.
    [He et al. CVPR 2016]

        h   = ReLU( BN(Conv(x)) + MLP(t) )
        out = ReLU( BN(Conv(h)) + shortcut(x) )

    shortcut is the identity when in_c == out_c, otherwise a bias-free 1x1x1 conv followed by BN.
    MLP(t) maps the time embedding (last dim time_c) to out_c channels; unsqueeze_as reshapes it against x so
    it can be added to the feature map (presumably broadcasting over the frame/spatial axes).
    """
    def __init__(self, in_c, out_c, time_c):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv3d(in_c, out_c, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_c)
        self.conv2 = nn.Conv3d(out_c, out_c, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.mlp_time = nn.Sequential(
            nn.Linear(time_c, time_c),
            nn.ReLU(),
            nn.Linear(time_c, out_c),
        )
        self.shortcut = (
            nn.Identity()
            if in_c == out_c
            else nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm3d(out_c),
            )
        )

    def forward(self, x, t):
        hidden = self.bn1(self.conv1(x))
        hidden = F.relu(hidden + unsqueeze_as(self.mlp_time(t), x))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))
44 |
45 |
class UNet3d(nn.Module):
    """
    U-Net over (batch, channels, frames, height, width) volumes; closely mimics the 2D implementation by
    Phil Wang (lucidrains).

    Resampling kernels and strides are (1, k, k), so only the two trailing axes are down/upsampled while the
    frame axis keeps its length. The bottleneck self-attention is 2D and is applied to every frame separately.
    """
    def __init__(self, in_dim, embed_dim, dim_scales):
        super().__init__()

        self.init_embed = nn.Conv3d(in_dim, embed_dim, 1)
        self.time_embed = PositionalEmbedding(embed_dim)

        self.down_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()

        # Channel width per resolution level, e.g.
        # embed_dim=32, dim_scales=(1, 2, 4, 8) => widths=(32, 32, 64, 128, 256)
        widths = (embed_dim, *(embed_dim * scale for scale in dim_scales))
        last_level = len(widths) - 2

        # Encoder: two residual blocks per level, then a stride-(1, 2, 2) conv; the deepest level instead uses a
        # 1x1x1 conv that only changes the channel count.
        for level, (in_c, out_c) in enumerate(zip(widths[:-1], widths[1:])):
            block_a = BasicBlock3d(in_c, in_c, embed_dim)
            block_b = BasicBlock3d(in_c, in_c, embed_dim)
            if level == last_level:
                resample = nn.Conv3d(in_c, out_c, 1)
            else:
                resample = nn.Conv3d(in_c, out_c, (1, 3, 3), (1, 2, 2), (0, 1, 1))
            self.down_blocks.extend([block_a, block_b, resample])

        # Decoder mirrors the encoder: each residual block takes one skip connection concatenated on channels,
        # followed by a stride-(1, 2, 2) transposed conv (a 1x1x1 conv at the shallowest level).
        reversed_widths = widths[::-1]
        for level, (in_c, out_c, skip_c) in enumerate(zip(
            reversed_widths[:-1],
            reversed_widths[1:],
            widths[:-1][::-1],
        )):
            block_a = BasicBlock3d(in_c + skip_c, in_c, embed_dim)
            block_b = BasicBlock3d(in_c + skip_c, in_c, embed_dim)
            if level == last_level:
                resample = nn.Conv3d(in_c, out_c, 1)
            else:
                resample = nn.ConvTranspose3d(in_c, out_c, (1, 2, 2), (1, 2, 2))
            self.up_blocks.extend([block_a, block_b, resample])

        bottleneck_c = widths[-1]
        self.mid_blocks = nn.ModuleList([
            BasicBlock3d(bottleneck_c, bottleneck_c, embed_dim),
            SelfAttention2d(bottleneck_c),
            BasicBlock3d(bottleneck_c, bottleneck_c, embed_dim),
        ])
        self.out_blocks = nn.ModuleList([
            BasicBlock3d(embed_dim, embed_dim, embed_dim),
            nn.Conv3d(embed_dim, in_dim, 1, bias=True),
        ])

    def forward(self, x, t):
        _, _, num_frames, *_ = x.shape
        feats = self.init_embed(x)
        t_emb = self.time_embed(t)
        skips = []
        residual = feats.clone()

        for layer in self.down_blocks:
            if not isinstance(layer, BasicBlock3d):
                feats = layer(feats)
                continue
            feats = layer(feats, t_emb)
            skips.append(feats)

        for layer in self.mid_blocks:
            if isinstance(layer, BasicBlock3d):
                feats = layer(feats, t_emb)
            elif isinstance(layer, SelfAttention2d):
                # 2D attention: fold frames into the batch axis, attend within each frame, then unfold.
                feats = rearrange(feats, "b c t h w -> (b t) c h w")
                feats = layer(feats)
                feats = rearrange(feats, "(b t) c h w -> b c t h w", t=num_frames)
            else:
                feats = layer(feats)

        for layer in self.up_blocks:
            if isinstance(layer, BasicBlock3d):
                feats = layer(torch.cat((feats, skips.pop()), dim=1), t_emb)
            else:
                feats = layer(feats)

        feats = feats + residual
        for layer in self.out_blocks:
            feats = layer(feats, t_emb) if isinstance(layer, BasicBlock3d) else layer(feats)
        return feats
132 |
--------------------------------------------------------------------------------
/examples/ex_moving_mnist.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 |
4 | from einops import rearrange
5 | from matplotlib import pyplot as plt
6 | import numpy as np
7 | from sklearn.datasets import fetch_openml
8 | import torch
9 | import torch.optim as optim
10 | from torch.utils.data import Dataset, DataLoader
11 |
12 | from src.blocks_3d import UNet3d
13 | from src.schedules import CosineSchedule
14 | from src.score_matching import ScoreMatchingModel, ScoreMatchingModelConfig
15 |
16 |
class MovingMNISTDataset(Dataset):
    """
    Synthetic moving-MNIST videos: one MNIST digit translating at constant speed along one of four axis-aligned
    directions and reflecting off the frame edges.

    Every index enumerates (direction, initial top offset, initial left offset, digit), so items are
    deterministic. Items are float32 arrays of shape (1, num_frames, height, width) standardized to [-1, 1]:
    background is -1 and a 1-pixel border is drawn at +1.
    """

    # Standardize to [-1, 1]
    input_mean = 127.5
    input_sd = 127.5

    def __init__(self, num_frames, height, width, velocity=5):

        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.velocity = velocity
        self.mnist_L = 28  # 28 x 28 images

        # Validate geometry before the (slow) download. The digit's top-left corner lives in [0, free_range];
        # a reflection is only guaranteed to stay in bounds when 2 * |velocity| <= free_range. Otherwise
        # __getitem__ would write a clipped slice and fail with a broadcasting error for some indices.
        free_range = min(height, width) - self.mnist_L
        if free_range <= 0:
            raise ValueError(f"height and width must both exceed {self.mnist_L}, got {height=}, {width=}.")
        if 2 * abs(velocity) > free_range:
            raise ValueError(
                f"2 * |velocity| must not exceed min(height, width) - {self.mnist_L} = {free_range}, "
                f"got {velocity=}."
            )

        # Load data from https://www.openml.org/d/554
        # (70000, 784) values between 0-255
        x, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False, cache=True)

        x = ((x - MovingMNISTDataset.input_mean) / MovingMNISTDataset.input_sd).astype(np.float32)
        self.x = rearrange(x, "b (h w) -> b h w", h=28, w=28)

        # Enumerate all possible indices in this dataset, we allow 4 directions
        self.dir_to_dy_dx = {
            0: (-self.velocity, 0),
            1: (self.velocity, 0),
            2: (0, -self.velocity),
            3: (0, self.velocity),
        }
        self.index_shape = (len(self.dir_to_dy_dx), self.height - self.mnist_L, self.width - self.mnist_L, len(self.x))

    def __len__(self):
        # Return a plain int rather than a numpy integer.
        return int(np.prod(self.index_shape))

    def __getitem__(self, idx):
        """Render video `idx` as a float32 array of shape (1, num_frames, height, width)."""

        result = np.full((self.num_frames, self.height, self.width), dtype=np.float32, fill_value=-1)
        dir_idx, y_idx, x_idx, mnist_idx = np.unravel_index(idx, self.index_shape)

        R = self.mnist_L // 2

        # Denotes center coords
        y, x = y_idx + R, x_idx + R
        dy, dx = self.dir_to_dy_dx[dir_idx]

        for t in range(self.num_frames):
            result[t, y - R : y + R, x - R : x + R] = self.x[mnist_idx]
            # Reverse direction when the next step would touch or cross an edge.
            if y + dy - R <= 0 or y + dy + R >= self.height:
                dy *= -1
            if x + dx - R <= 0 or x + dx + R >= self.width:
                dx *= -1
            y += dy
            x += dx

        # Draw boundaries along corner of image
        result[:, :, 0].fill(1.0)
        result[:, :, -1].fill(1.0)
        result[:, 0, :].fill(1.0)
        result[:, -1, :].fill(1.0)

        return result[np.newaxis]

    @staticmethod
    def visualize_one(x):
        """Undo standardization and tile frames side by side into a (height, width * num_frames) uint8 image."""
        num_channels, num_frames, height, width = x.shape
        assert num_channels == 1
        x = x * MovingMNISTDataset.input_sd + MovingMNISTDataset.input_mean
        x = np.clip(x, a_min=0, a_max=255)
        x = x.astype(np.uint8)
        raster = np.zeros((height, width * num_frames), dtype=np.uint8)
        for t in range(num_frames):
            raster[:, width * t : width * (t + 1)] = x[:, t].squeeze(axis=0)
        return raster
89 |
90 |
if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--iterations", default=2000, type=int)
    argparser.add_argument("--batch-size", default=16, type=int)
    argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps"))
    argparser.add_argument("--load-trained", default=0, type=int, choices=(0, 1))
    argparser.add_argument("--num-frames", default=4, type=int)
    argparser.add_argument("--velocity", default=4, type=int)
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    nn_module = UNet3d(1, 128, (1, 2, 4, 8))
    model = ScoreMatchingModel(
        nn_module=nn_module,
        input_shape=(1, args.num_frames, 64, 64),
        config=ScoreMatchingModelConfig(
            sigma_min=0.002,
            sigma_max=80.0,
            sigma_data=1.0,
            train_sigma_schedule=CosineSchedule(kappa=0.25),
            test_sigma_schedule=CosineSchedule(kappa=0.25),
        ),
    )
    model = model.to(args.device)

    dataset = MovingMNISTDataset(args.num_frames, 64, 64, velocity=args.velocity)
    # The dataset enumerates every (direction, y, x, digit) combination (~3.6e8 items at 64x64). With
    # shuffle=True, RandomSampler would build a full permutation of all those indices (several GB) each time
    # the loader is iterated. Sampling with replacement draws indices lazily, and only as many as training uses.
    sampler = torch.utils.data.RandomSampler(
        dataset,
        replacement=True,
        num_samples=max(1, args.iterations * args.batch_size),
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)

    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations)

    if args.load_trained:
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded onto --device (e.g. CPU).
        model.load_state_dict(torch.load("./ckpts/moving_mnist_trained.pt", map_location=args.device))
    else:
        # Restart the loader if it is ever exhausted.
        iterator = iter(dataloader)
        for step_num in range(args.iterations):
            try:
                x_batch = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                x_batch = next(iterator)
            x_batch = x_batch.to(args.device)
            optimizer.zero_grad()
            loss = model.loss(x_batch).mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if step_num % 100 == 0:
                logger.info("Iter: %d\tLoss: %.2f\t", step_num, loss.item())
        torch.save(model.state_dict(), "./ckpts/moving_mnist_trained.pt")

    model.eval()

    NUM_SHOWN = 4

    # Index 0 of the sampling trajectory holds the final (fully denoised) videos.
    samples = model.sample(bsz=NUM_SHOWN, num_sampling_timesteps=20, device=args.device).cpu().numpy()
    samples = samples[0]

    # Top rows: ground-truth videos; bottom rows: generated videos (one row per video, frames left to right).
    gt_raster = np.concatenate([
        MovingMNISTDataset.visualize_one(dataset[i]) for i in range(NUM_SHOWN)
    ], axis=0)
    pred_raster = np.concatenate([
        MovingMNISTDataset.visualize_one(samples[i]) for i in range(NUM_SHOWN)
    ], axis=0)

    raster = np.concatenate([gt_raster, pred_raster], axis=0)
    plt.imsave("./examples/ex_moving_mnist.png", raster, vmin=0, vmax=255)
161 |
--------------------------------------------------------------------------------
/src/score_matching.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from src.blocks import unsqueeze_as
8 | from src.samplers import InstantaneousPrediction, Sampler, EulerSampler
9 | from src.schedules import CosineSchedule, NoiseSchedule
10 |
11 |
@dataclass(frozen=True)
class ScoreMatchingModelConfig:
    """
    Configuration for ScoreMatchingModel.

    sigma_min / sigma_max bound the noise levels and must match the bounds of both sigma schedules.
    sigma_data is the data scale used by the pre-conditioning and by the "snr"/"min_snr"/"karras" loss weights.
    """

    # Noise range and data scale (network pre-conditioning and loss weighting)
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    sigma_data: float = 0.5

    # Training: element-wise loss, its per-sigma weighting, and the distribution training sigmas are drawn from
    loss_type: str = "l2"
    loss_weighting_type: str = "karras"
    train_sigma_schedule: NoiseSchedule = CosineSchedule()

    # Inference: ODE sampler and the schedule that picks the sigmas it visits
    sampler: Sampler = EulerSampler()
    test_sigma_schedule: NoiseSchedule = CosineSchedule()

    def __post_init__(self):
        assert 0 <= self.sigma_min <= self.sigma_max
        schedules = (self.train_sigma_schedule, self.test_sigma_schedule)
        for schedule in schedules:
            assert self.sigma_min == schedule.sigma_min
        for schedule in schedules:
            assert self.sigma_max == schedule.sigma_max
        assert self.loss_type in ("l1", "l2")
        assert self.loss_weighting_type in ("ones", "snr", "karras", "min_snr")
37 |
38 |
39 | class ScoreMatchingModel(nn.Module):
40 |
41 | def __init__(
42 | self,
43 | input_shape: tuple[int, ...],
44 | nn_module: nn.Module,
45 | config: ScoreMatchingModelConfig,
46 | ):
47 | super().__init__()
48 | self.input_shape = input_shape
49 | self.nn_module = nn_module
50 |
51 | # Input shape must be either (c,) or (c, h, w) or (c, t, h, w)
52 | assert len(input_shape) in (1, 3, 4)
53 |
54 | self.sigma_data = config.sigma_data
55 | self.sigma_min = config.sigma_min
56 | self.sigma_max = config.sigma_max
57 | self.loss_type = config.loss_type
58 | self.loss_weighting_type = config.loss_weighting_type
59 | self.train_sigma_schedule = config.train_sigma_schedule
60 | self.test_sigma_schedule = config.test_sigma_schedule
61 | self.sampler = config.sampler
62 |
63 | def nn_module_wrapper(self, x, sigma, num_discrete_chunks=10000):
64 | """
65 | This function does two things:
66 | 1. Implements Karras et al. 2022 pre-conditioning.
67 | 2. Converts sigma in range [sigma_min, sigma_max] into a discrete input.
68 |
69 | Parameters
70 | ----------
71 | x: (bsz, *self.input_shape)
72 | sigma: (bsz,)
73 | """
74 | c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
75 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
76 | c_out = sigma * self.sigma_data / (self.sigma_data ** 2 + sigma ** 2) ** 0.5
77 | c_skip, c_in, c_out = unsqueeze_as(c_skip, x), unsqueeze_as(c_in, x), unsqueeze_as(c_out, x)
78 | log_sigmas_percentile = (
79 | (torch.log(sigma) - math.log(self.sigma_min)) /
80 | (math.log(self.sigma_max) - math.log(self.sigma_min))
81 | )
82 | sigmas_discrete = torch.floor(num_discrete_chunks * log_sigmas_percentile)
83 | sigmas_discrete = sigmas_discrete.clamp_(min=0, max=num_discrete_chunks - 1).long()
84 | return c_out * self.nn_module(c_in * x, sigmas_discrete) + c_skip * x
85 |
86 | def loss(self, x):
87 | """
88 | Returns
89 | -------
90 | loss: (bsz, *input_shape)
91 | """
92 | bsz, *_ = x.shape
93 |
94 | rng = torch.rand((bsz,), device=x.device)
95 | sigma = self.train_sigma_schedule.get_sigma_ppf(rng)
96 |
97 | x_t = x + unsqueeze_as(sigma, x) * torch.randn_like(x)
98 | pred = self.nn_module_wrapper(x_t, sigma)
99 |
100 | if self.loss_type == "l2":
101 | loss = (0.5 * (x - pred) ** 2)
102 | elif self.loss_type == "l1":
103 | loss = (x - pred).abs()
104 | else:
105 | raise AssertionError(f"Invalid {self.loss_type=}.")
106 |
107 | if self.loss_weighting_type == "ones":
108 | loss_weights = torch.ones_like(sigma)
109 | elif self.loss_weighting_type == "snr":
110 | loss_weights = self.sigma_data ** 2 / sigma ** 2
111 | elif self.loss_weighting_type == "min_snr":
112 | loss_weights = torch.clamp(self.sigma_data ** 2 / sigma ** 2, max=5.0)
113 | elif self.loss_weighting_type == "karras":
114 | loss_weights = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
115 | else:
116 | raise AssertionError(f"Invalid {self.loss_weighting_type=}.")
117 |
118 | loss *= unsqueeze_as(loss_weights, loss)
119 | return loss
120 |
121 | @torch.no_grad()
122 | def sample(self, bsz, device, num_sampling_timesteps: int):
123 | """
124 | Parameters
125 | ----------
126 | num_sampling_timesteps: number of steps to take between sigma_max to sigma_min
127 |
128 | Returns
129 | -------
130 | samples: (sampling_timesteps + 1, bsz, *self.input_shape)
131 | index 0 corresponds to x_0
132 | index t corresponds to x_t
133 | last index corresponds to random noise
134 |
135 | Notes
136 | -----
137 | This is deterministic for now, solving the probability flow ODE.
138 | """
139 | assert num_sampling_timesteps >= 1
140 |
141 | linspace = torch.linspace(1.0, 0.0, num_sampling_timesteps + 1, device=device)
142 | sigmas = self.test_sigma_schedule.get_sigma_ppf(linspace)
143 |
144 | sigma_start = torch.empty((bsz,), device=device)
145 | sigma_end = torch.empty((bsz,), device=device)
146 |
147 | x = torch.randn((bsz, *self.input_shape), device=device) * self.sigma_max
148 | samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device)
149 | samples[-1] = x * self.sigma_data / (self.sigma_max ** 2 + self.sigma_data ** 2) ** 0.5
150 |
151 | self.sampler.reset()
152 |
153 | for idx, (scalar_sigma_start, scalar_sigma_end) in enumerate(zip(sigmas[:-1], sigmas[1:])):
154 |
155 | sigma_start.fill_(scalar_sigma_start)
156 | sigma_end.fill_(scalar_sigma_end)
157 |
158 | pred_x_0 = self.nn_module_wrapper(x, sigma_start)
159 | pred_at_x_t = InstantaneousPrediction(scalar_sigma_start, x, pred_x_0)
160 | x = self.sampler.step(scalar_sigma_start, scalar_sigma_end, pred_at_x_t)
161 |
162 | normalization_factor = (
163 | self.sigma_data / (scalar_sigma_end ** 2 + self.sigma_data ** 2) ** 0.5
164 | )
165 | samples[-1 - idx - 1] = x * normalization_factor
166 |
167 | return samples
168 |
--------------------------------------------------------------------------------
/src/simple/diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | This file provides a simple, self-contained implementation of DDIM (with DDPM as a special case).
3 | """
4 | from dataclasses import dataclass
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from src.blocks import unsqueeze_to
11 |
12 |
@dataclass(frozen=True)
class DiffusionModelConfig:
    """
    Hyper-parameters for DiffusionModel.

    num_timesteps: number of discrete diffusion steps T (> 0).
    target_type: network regression target, one of "pred_x_0", "pred_eps", "pred_v".
    noise_schedule_type: "linear" or "cosine".
    loss_type: "l1" or "l2".
    gamma_type: "ddim" (no noise injected while sampling) or "ddpm" (ancestral sampling noise).
    """

    num_timesteps: int
    target_type: str = "pred_eps"
    noise_schedule_type: str = "cosine"
    loss_type: str = "l2"
    gamma_type: str = "ddim"

    def __post_init__(self):
        assert self.num_timesteps > 0
        assert self.target_type in ("pred_x_0", "pred_eps", "pred_v")
        assert self.noise_schedule_type in ("linear", "cosine")
        assert self.loss_type in ("l1", "l2")
        assert self.gamma_type in ("ddim", "ddpm")
28 |
29 |
class DiffusionModel(nn.Module):
    """
    Discrete-time diffusion model with x_t = alpha_t * x_0 + sigma_t * eps for t = 0..num_timesteps.

    The network is trained to regress config.target_type. Sampling uses DDIM updates; with gamma_type="ddpm"
    the per-step noise level is the DDIM eta = 1 (ancestral DDPM) value, with gamma_type="ddim" it is zero.
    """

    def __init__(
        self,
        input_shape: tuple[int, ...],
        nn_module: nn.Module,
        config: DiffusionModelConfig,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.nn_module = nn_module
        self.num_timesteps = config.num_timesteps
        self.target_type = config.target_type
        self.gamma_type = config.gamma_type
        self.noise_schedule_type = config.noise_schedule_type
        self.loss_type = config.loss_type

        # Input shape must be either (c,) or (c, h, w) or (c, t, h, w)
        assert len(input_shape) in (1, 3, 4)

        # Construct the noise schedule: alpha_t = sqrt(prod_{s <= t} (1 - beta_s)), decreasing in t.
        if self.noise_schedule_type == "linear":
            # betas evenly spaced in [1e-4, 2e-2]
            beta_t = torch.linspace(1e-4, 2e-2, self.num_timesteps + 1)
            alpha_t = torch.cumprod(1 - beta_t, dim=0) ** 0.5
        elif self.noise_schedule_type == "cosine":
            # bar_alpha_t = f(t) / f(0) with f(t) = cos^2(((t + 0.008) / 1.008) * pi / 2), converted to per-step
            # betas clipped to [0, 0.999]; beta_0 = 0 so alpha_0 = 1.
            linspace = torch.linspace(0, 1, self.num_timesteps + 1)
            f_t = torch.cos((linspace + 0.008) / (1 + 0.008) * math.pi / 2) ** 2
            bar_alpha_t = f_t / f_t[0]
            beta_t = torch.zeros_like(bar_alpha_t)
            beta_t[1:] = (1 - (bar_alpha_t[1:] / bar_alpha_t[:-1])).clamp(min=0, max=0.999)
            alpha_t = torch.cumprod(1 - beta_t, dim=0) ** 0.5
        else:
            raise AssertionError(f"Invalid {self.noise_schedule_type=}.")

        # These tensors are shape (num_timesteps + 1, *self.input_shape)
        # For example, 2D: (num_timesteps + 1, 1, 1, 1)
        #              1D: (num_timesteps + 1, 1)
        alpha_t = unsqueeze_to(alpha_t, len(self.input_shape) + 1)
        sigma_t = (1 - alpha_t ** 2).clamp(min=0) ** 0.5
        self.register_buffer("alpha_t", alpha_t)
        self.register_buffer("sigma_t", sigma_t)

    def loss(self, x: torch.Tensor):
        """
        Unreduced regression loss at a uniformly random timestep t in [1, num_timesteps] per example.

        Parameters
        ----------
        x: clean data, (bsz, *input_shape)

        Returns
        -------
        loss: (bsz, *input_shape)
        """
        bsz, *_ = x.shape
        t_sample = torch.randint(1, self.num_timesteps + 1, size=(bsz,), device=x.device)
        eps = torch.randn_like(x)
        x_t = self.alpha_t[t_sample] * x + self.sigma_t[t_sample] * eps
        pred_target = self.nn_module(x_t, t_sample)

        if self.target_type == "pred_x_0":
            gt_target = x
        elif self.target_type == "pred_eps":
            gt_target = eps
        elif self.target_type == "pred_v":
            gt_target = self.alpha_t[t_sample] * eps - self.sigma_t[t_sample] * x
        else:
            raise AssertionError(f"Invalid {self.target_type=}.")

        if self.loss_type == "l2":
            loss = 0.5 * (gt_target - pred_target) ** 2
        elif self.loss_type == "l1":
            loss = torch.abs(gt_target - pred_target)
        else:
            raise AssertionError(f"Invalid {self.loss_type=}.")

        return loss

    @torch.no_grad()
    def sample(self, bsz: int, device: str, num_sampling_timesteps: int | None = None):
        """
        Parameters
        ----------
        num_sampling_timesteps: int in [1, self.num_timesteps]. If None (unspecified), defaults to
            self.num_timesteps. Any other value outside the range, including 0, fails the range assertion.

        Returns
        -------
        samples: (num_sampling_timesteps + 1, bsz, *self.input_shape)
            index 0 corresponds to x_0
            index t corresponds to x_t
            last index corresponds to random noise
        """
        # Only substitute the default for a missing argument; an explicit 0 must fail the range check below
        # instead of silently running the full schedule.
        if num_sampling_timesteps is None:
            num_sampling_timesteps = self.num_timesteps
        assert 1 <= num_sampling_timesteps <= self.num_timesteps

        x = torch.randn((bsz, *self.input_shape), device=device)
        t_start = torch.empty((bsz,), dtype=torch.int64, device=device)
        t_end = torch.empty((bsz,), dtype=torch.int64, device=device)

        subseq = torch.linspace(self.num_timesteps, 0, num_sampling_timesteps + 1).round()
        samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device)
        samples[-1] = x

        # Note that t_start > t_end we're traversing pairwise down subseq.
        # For example, subseq here could be [500, 400, 300, 200, 100, 0]
        for idx, (scalar_t_start, scalar_t_end) in enumerate(zip(subseq[:-1], subseq[1:])):

            t_start.fill_(scalar_t_start)
            t_end.fill_(scalar_t_end)
            noise = torch.zeros_like(x) if scalar_t_end == 0 else torch.randn_like(x)

            # Std of the fresh noise injected by this step.
            if self.gamma_type == "ddim":
                gamma_t = 0.0
            elif self.gamma_type == "ddpm":
                gamma_t = (
                    self.sigma_t[t_end] / self.sigma_t[t_start] *
                    (1 - self.alpha_t[t_start] ** 2 / self.alpha_t[t_end] ** 2) ** 0.5
                )
            else:
                raise AssertionError(f"Invalid {self.gamma_type=}.")

            # Convert the network output into both an x_0 and an eps prediction.
            nn_out = self.nn_module(x, t_start)
            if self.target_type == "pred_x_0":
                pred_x_0 = nn_out
                pred_eps = (x - self.alpha_t[t_start] * nn_out) / self.sigma_t[t_start]
            elif self.target_type == "pred_eps":
                pred_x_0 = (x - self.sigma_t[t_start] * nn_out) / self.alpha_t[t_start]
                pred_eps = nn_out
            elif self.target_type == "pred_v":
                pred_x_0 = self.alpha_t[t_start] * x - self.sigma_t[t_start] * nn_out
                pred_eps = self.sigma_t[t_start] * x + self.alpha_t[t_start] * nn_out
            else:
                raise AssertionError(f"Invalid {self.target_type=}.")

            # DDIM update: the remaining variance sigma_t_end^2 is split between the predicted eps and fresh noise.
            x = (
                (self.alpha_t[t_end] * pred_x_0) +
                (self.sigma_t[t_end] ** 2 - gamma_t ** 2).clamp(min=0) ** 0.5 * pred_eps +
                (gamma_t * noise)
            )
            samples[-1 - idx - 1] = x

        return samples
166 |
167 |
--------------------------------------------------------------------------------
/src/experimental/consistency.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from src.blocks import unsqueeze_as
8 |
9 |
@dataclass(frozen=True)
class ConsistencyModelConfig:
    """
    Hyper-parameters for ConsistencyModel.

    The number of discretization steps grows from s_0 toward s_1 (doubling per stage) over train_steps_limit
    training steps. sigma_min / sigma_max / rho define the sigma ** (1 / rho) spacing, and sigma_data is the
    data scale used by the pre-conditioning.
    """

    train_steps_limit: int
    loss_type: str = "l2"
    s_0: int = 10
    s_1: int = 1280
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    sigma_data: float = 0.5
    rho: float = 7.0
    p_mean: float = -1.1
    p_std: float = 2.0

    def __post_init__(self):
        # s_0 == 0 would divide by zero in ConsistencyModel's log2(s_1 / s_0).
        assert 0 < self.s_0 <= self.s_1
        # ConsistencyModel divides by k_prime = floor(train_steps_limit / (log2(s_1 / s_0) + 1)); it must be >= 1.
        assert self.train_steps_limit >= math.log2(self.s_1 / self.s_0) + 1
        # ConsistencyModel computes sigma ** (1 / rho): rho must be non-zero and sigma_min non-negative
        # (a negative base would give a complex root).
        assert self.rho > 0
        assert 0 <= self.sigma_min <= self.sigma_max
        assert self.loss_type in ("l1", "l2")
28 |
29 |
class ConsistencyModel(nn.Module):
    """
    Consistency model trained with consistency training.

    The wrapped network is pre-conditioned as in Karras et al. 2022 (see `nn_module_wrapper`) and trained so that
    its outputs at adjacent noise levels sigma_t < sigma_{t+1} agree (see `loss`). Noise levels are discretized as

        sigma(u) = (sigma_min^(1/rho) + u * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho,  u in [0, 1].
    """

    def __init__(
        self,
        input_shape: tuple[int, ...],
        nn_module: nn.Module,
        config: ConsistencyModelConfig,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.nn_module = nn_module

        # Input shape must be either (c,) or (c, h, w) or (c, t, h, w)
        assert len(input_shape) in (1, 3, 4)

        # Unpack config and pre-compute a few relevant constants
        self.p_mean = config.p_mean
        self.p_std = config.p_std
        self.s_0 = config.s_0
        self.s_1 = config.s_1
        self.sigma_data = config.sigma_data
        self.sigma_min = config.sigma_min
        self.sigma_max = config.sigma_max
        self.rho = config.rho
        self.loss_type = config.loss_type
        self.train_steps_limit = config.train_steps_limit
        self.sigma_min_root = (self.sigma_min) ** (1 / self.rho)
        self.sigma_max_root = (self.sigma_max) ** (1 / self.rho)
        # Number of training steps between successive doublings of the discretization. Clamped to at least 1 so a
        # small train_steps_limit (below log2(s_1 / s_0) + 1) cannot cause a ZeroDivisionError in loss().
        self.k_prime = max(1, math.floor(self.train_steps_limit / (math.log2(self.s_1 / self.s_0) + 1)))
        # NOTE: pseudo-Huber constant; not currently used by loss(), which only implements "l1" and "l2".
        self.pseudo_huber_c = 0.00054 * math.prod(input_shape) ** 0.5

    def nn_module_wrapper(self, x, sigma, num_discrete_chunks=10000):
        """
        This function does two things:
          1. Implements Karras et al. 2022 pre-conditioning
          2. Converts sigma which have range (sigma_min, sigma_max) into a discrete input

        Parameters
        ----------
        x: (bsz, *self.input_shape)
        sigma: (bsz,) floating-point noise levels
        num_discrete_chunks: number of buckets sigma is quantized into before being passed to nn_module

        Returns
        -------
        out: c_out * nn_module(c_in * x, sigma_bucket) + c_skip * x, same shape as x
        """
        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_out = sigma * self.sigma_data / (self.sigma_data ** 2 + sigma ** 2) ** 0.5
        c_skip, c_in, c_out = unsqueeze_as(c_skip, x), unsqueeze_as(c_in, x), unsqueeze_as(c_out, x)
        # Position of sigma within [sigma_min, sigma_max], measured in sigma^(1/rho) space.
        sigmas_percentile = (
            ((sigma ** (1 / self.rho)) - self.sigma_min_root) / (self.sigma_max_root - self.sigma_min_root)
        )
        # Clamp on both sides: float round-off at sigma == sigma_min can make the percentile slightly negative, and
        # floor() would then give index -1, which an embedding-table lookup silently wraps to the last bucket.
        sigmas_discrete = (
            torch.floor(num_discrete_chunks * sigmas_percentile).clamp(min=0, max=num_discrete_chunks - 1).long()
        )
        return c_out * self.nn_module(c_in * x, sigmas_discrete) + c_skip * x

    def loss(self, x, train_step_number: int):
        """
        Consistency-training loss between the network at sigma_{t+1} (student) and a no-grad forward of the same
        network at the adjacent, smaller sigma_t (teacher), both applied to x noised with the same eps.

        Parameters
        ----------
        x: (bsz, *input_shape) clean data
        train_step_number: current training step; controls how finely the sigma range is discretized

        Returns
        -------
        loss: (bsz, *input_shape), un-reduced and weighted by 1 / (sigma_{t+1} - sigma_t)
        """
        bsz, *_ = x.shape

        # First compute the amount of discretization (number of sigmas needed)
        train_step_number = min(train_step_number, self.train_steps_limit)
        num_sigmas = min(self.s_0 * 2 ** math.floor(train_step_number / self.k_prime), self.s_1) + 1

        # Discretize the sigma space
        linspace = torch.linspace(0, 1, num_sigmas, device=x.device)
        sigmas = (self.sigma_min_root + linspace * (self.sigma_max_root - self.sigma_min_root)) ** self.rho

        # Draw a sample of sigma with importance sampling: each interval [sigma_i, sigma_{i+1}] is weighted by its
        # probability mass under a log-normal with parameters (p_mean, p_std).
        sigmas_weights = (
            torch.erf((torch.log(sigmas[1:]) - self.p_mean) / (2 ** 0.5 * self.p_std)) -
            torch.erf((torch.log(sigmas[:-1]) - self.p_mean) / (2 ** 0.5 * self.p_std))
        )
        sampled_idxs = torch.multinomial(sigmas_weights, num_samples=bsz, replacement=True)
        sigma_t = sigmas[sampled_idxs]
        sigma_t_plus_1 = sigmas[sampled_idxs + 1]

        # Forward the student and teacher, ensuring dropout is identical for both: fork_rng restores the RNG state
        # on exit, so both forward passes start from the same RNG state.
        eps = torch.randn_like(x)
        with torch.random.fork_rng():
            x_t_plus_1 = x + unsqueeze_as(sigma_t_plus_1, x) * eps
            pred_t_plus_1 = self.nn_module_wrapper(x_t_plus_1, sigma_t_plus_1)
        with torch.no_grad(), torch.random.fork_rng():
            x_t = x + unsqueeze_as(sigma_t, x) * eps
            pred_t = self.nn_module_wrapper(x_t, sigma_t)

        # Compute loss and corresponding weights
        if self.loss_type == "l2":
            loss = (0.5 * (pred_t_plus_1 - pred_t) ** 2)
        elif self.loss_type == "l1":
            loss = (pred_t_plus_1 - pred_t).abs()
        else:
            raise AssertionError(f"Invalid {self.loss_type=}.")

        loss_weights = 1 / (sigma_t_plus_1 - sigma_t)
        loss *= unsqueeze_as(loss_weights, loss)
        return loss

    @torch.no_grad()
    def sample(self, bsz, device, num_sampling_timesteps: int):
        """
        Multistep consistency sampling. Starting from N(0, sigma_max^2) noise, repeatedly map to a clean estimate with
        the network and re-noise to the next (smaller) sigma.

        Parameters
        ----------
        bsz: number of samples to draw
        device: device to allocate tensors on
        num_sampling_timesteps: number of steps to take between sigma_max to sigma_min

        Returns
        -------
        samples: (num_sampling_timesteps + 1, bsz, *self.input_shape), each entry multiplied by
            sigma_data / sqrt(sigma^2 + sigma_data^2) for its noise level sigma
            index 0 corresponds to x_0
            index t corresponds to x_t
            last index corresponds to random noise
        """
        assert num_sampling_timesteps >= 1

        linspace = torch.linspace(1, 0, num_sampling_timesteps + 1, device=device)
        sigmas = (self.sigma_min_root + linspace * (self.sigma_max_root - self.sigma_min_root)) ** self.rho

        # Per-sample noise levels. These must be floating point: an integer buffer would truncate the continuous
        # sigmas when filled (e.g. 0.5 -> 0).
        sigma_start = torch.empty((bsz,), dtype=sigmas.dtype, device=device)
        sigma_end = torch.empty((bsz,), dtype=sigmas.dtype, device=device)

        # Pure noise at the largest noise level, x ~ N(0, sigma_max^2 I), consistent with the normalization below.
        x = torch.randn((bsz, *self.input_shape), device=device) * self.sigma_max
        samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device)
        samples[-1] = x * self.sigma_data / (self.sigma_max ** 2 + self.sigma_data ** 2) ** 0.5

        for idx, (scalar_sigma_start, scalar_sigma_end) in enumerate(zip(sigmas[:-1], sigmas[1:])):

            sigma_start.fill_(scalar_sigma_start)
            sigma_end.fill_(scalar_sigma_end)

            # Jump to a clean estimate, then re-noise to sigma_end (approximately no noise once it reaches sigma_min).
            x = self.nn_module_wrapper(x, sigma_start)
            eps = torch.randn_like(x)
            x += unsqueeze_as((sigma_end ** 2 - self.sigma_min ** 2).clamp(min=0) ** 0.5, x) * eps

            normalization_factor = self.sigma_data / (scalar_sigma_end ** 2 + self.sigma_data ** 2) ** 0.5
            samples[-1 - idx - 1] = x * normalization_factor

        return samples
167 |
168 |
--------------------------------------------------------------------------------
/src/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from einops import parse_shape, rearrange
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
def unsqueeze_to(tensor, target_ndim):
    """
    Append trailing singleton dimensions to `tensor` until it has `target_ndim` dimensions.

    For example, a (bsz,) tensor with target_ndim=4 becomes (bsz, 1, 1, 1). Raises AssertionError if `tensor`
    already has more than `target_ndim` dimensions.
    """
    assert tensor.ndim <= target_ndim
    for _ in range(target_ndim - tensor.ndim):
        tensor = tensor.unsqueeze(-1)
    return tensor
14 |
def unsqueeze_as(tensor, target_tensor):
    """
    Append trailing singleton dimensions to `tensor` until it has as many dimensions as `target_tensor`.

    Useful to broadcast a per-sample (bsz,) tensor against a (bsz, ...) batch. Raises AssertionError if `tensor`
    has more dimensions than `target_tensor`.
    """
    missing_dims = target_tensor.ndim - tensor.ndim
    assert missing_dims >= 0
    for _ in range(missing_dims):
        tensor = tensor.unsqueeze(-1)
    return tensor
20 |
21 |
class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal embedding looked up by integer position (e.g. a discretized timestep).

    Even columns hold sin(pos * w_k) and odd columns cos(pos * w_k), with frequencies
    w_k = (max_length / (2 pi)) ** (-2k / dim), so the slowest frequency approaches 2 pi / max_length.
    """
    def __init__(self, dim, max_length=10000):
        super().__init__()
        self.register_buffer("embedding", self.make_embedding(dim, max_length))

    def forward(self, x):
        # Parameters
        # x: (bsz,) discrete, each entry in [-max_length, max_length)
        # Returns the rows self.embedding[x], i.e. (bsz, dim) for a (bsz,) input.
        return self.embedding[x]

    @staticmethod
    def make_embedding(dim, max_length=10000):
        """
        Build the (max_length, dim) sinusoidal table.

        Odd `dim` is supported: the table then has one more sin column than cos columns.
        """
        embedding = torch.zeros(max_length, dim)
        position = torch.arange(0, max_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(max_length / 2 / math.pi) / dim))
        embedding[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(dim / 2) sin columns but only floor(dim / 2) cos columns; slicing keeps odd dims valid.
        embedding[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        return embedding
40 |
41 |
class FourierEmbedding(nn.Module):
    """
    Sinusoidal embedding of continuous scalars.

    For frequencies f_k = max_period ** (-2k / dim), k = 0 .. dim/2 - 1, the embedding of x is
    [cos(x * f_0), ..., cos(x * f_{dim/2-1}), sin(x * f_0), ..., sin(x * f_{dim/2-1})].
    """
    def __init__(self, dim, max_period=10000):
        super().__init__()
        assert dim % 2 == 0
        freqs = self.make_freqs(dim, max_period)
        self.register_buffer("freqs", freqs)

    def forward(self, x):
        # Parameters
        # x: (bsz,) continuous
        # Returns (bsz, dim): all cosines first, then all sines.
        angles = torch.outer(x, self.freqs)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    @staticmethod
    def make_freqs(dim, max_period=10000):
        # dim // 2 geometrically spaced frequencies, starting at 1 and decreasing towards 1 / max_period.
        exponents = torch.arange(0, dim, 2)
        return torch.exp(exponents * -math.log(max_period) / dim)
58 |
59 |
class FFN(nn.Module):
    """
    Time-conditioned MLP for flat inputs.

    The input is linearly embedded to `embed_dim`, summed with a sinusoidal embedding of the discrete time index,
    passed through five (Linear, ReLU) layers, and projected back to `in_dim`.
    """
    def __init__(self, in_dim, embed_dim):
        super().__init__()
        self.init_embed = nn.Linear(in_dim, embed_dim)
        self.time_embed = PositionalEmbedding(embed_dim)
        hidden_layers = []
        for _ in range(5):
            hidden_layers.append(nn.Linear(embed_dim, embed_dim))
            hidden_layers.append(nn.ReLU())
        self.model = nn.Sequential(*hidden_layers, nn.Linear(embed_dim, in_dim))

    def forward(self, x, t):
        # x: presumably (bsz, in_dim); t: discrete indices into the positional embedding table (assumed (bsz,)).
        hidden = self.init_embed(x) + self.time_embed(t)
        return self.model(hidden)
83 |
84 |
class BasicBlock(nn.Module):
    """
    BasicBlock: two 3x3 convs with a residual connection [He et al. CVPR 2016], plus an additive time shift.

        h = ReLU( BN1(Conv3x3(x)) + MLP(t) )
        BasicBlock(x, t) = ReLU( BN2(Conv3x3(h)) + shortcut(x) )

    The shortcut is the identity when in_c == out_c, otherwise a 1x1 conv followed by BatchNorm.
    """
    def __init__(self, in_c, out_c, time_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.mlp_time = nn.Sequential(nn.Linear(time_c, time_c), nn.ReLU(), nn.Linear(time_c, out_c))
        self.shortcut = (
            nn.Identity()
            if in_c == out_c
            else nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d(out_c))
        )

    def forward(self, x, t):
        h = self.bn1(self.conv1(x))
        # The time shift gets trailing singleton dims so it broadcasts over the spatial axes of the feature map
        # (assumes t is (bsz, time_c), giving a (bsz, out_c, 1, 1) shift).
        h = F.relu(h + unsqueeze_as(self.mlp_time(t), x))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))
121 |
122 |
class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention across the spatial positions of a (b, dim, h, w) feature map, with a residual
    connection around the output projection.

    Only implements the MultiHeadAttention component, not the PositionwiseFFN component.
    """
    def __init__(self, dim, num_heads=8, dropout_prob=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # 1x1 convolutions act as per-position linear projections (query, key, value, output, in that order).
        self.q_conv, self.k_conv, self.v_conv, self.o_conv = (nn.Conv2d(dim, dim, 1, bias=True) for _ in range(4))
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # Fold heads into the batch axis: (b, g*c, h, w) -> (b*g, c, h*w).
        split_heads = "b (g c) h w -> (b g) c (h w)"
        q, k, v = (
            rearrange(proj(x), split_heads, g=self.num_heads) for proj in (self.q_conv, self.k_conv, self.v_conv)
        )
        # NOTE: logits are scaled by the full channel dim, not the per-head dim.
        logits = torch.einsum("b c s, b c t -> b s t", q, k) / self.dim ** 0.5
        weights = self.dropout(logits.softmax(dim=-1))
        attended = torch.einsum("b s t, b c t -> b c s", weights, v)
        attended = rearrange(attended, "(b g) c (h w) -> b (g c) h w", g=self.num_heads, w=x.shape[-1])
        return x + self.o_conv(attended)
149 |
150 |
class UNet(nn.Module):
    """
    Simple implementation that closely mimics the one by Phil Wang (lucidrains).

    Time-conditioned U-Net. Each encoder level has two BasicBlocks followed by a stride-2 down-sampling conv (a 1x1
    conv at the last level). The bottleneck is BasicBlock -> SelfAttention2d -> BasicBlock. The decoder mirrors the
    encoder with transposed-conv up-sampling. Every encoder BasicBlock output is concatenated channel-wise onto the
    input of the mirrored decoder BasicBlock, and the initial embedding is added back before the output blocks.
    """
    def __init__(self, in_dim, embed_dim, dim_scales):
        super().__init__()

        self.init_embed = nn.Conv2d(in_dim, embed_dim, 1)
        self.time_embed = PositionalEmbedding(embed_dim)

        self.down_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()

        # Example:
        # in_dim=1, embed_dim=32, dim_scales=(1, 2, 4, 8) => all_dims=(32, 32, 64, 128, 256)
        all_dims = (embed_dim, *(embed_dim * scale for scale in dim_scales))
        last_level = len(all_dims) - 2

        # Encoder: the last level only changes width (1x1 conv) instead of down-sampling.
        for level, (in_c, out_c) in enumerate(zip(all_dims[:-1], all_dims[1:])):
            self.down_blocks.extend([
                BasicBlock(in_c, in_c, embed_dim),
                BasicBlock(in_c, in_c, embed_dim),
                nn.Conv2d(in_c, out_c, 1) if level == last_level else nn.Conv2d(in_c, out_c, 3, 2, 1),
            ])

        # Decoder mirrors the encoder. The skip width at each decoder level equals its output width, because the
        # reversed dims shifted by one are exactly the encoder input widths in reverse order.
        reversed_dims = all_dims[::-1]
        for level, (in_c, out_c) in enumerate(zip(reversed_dims[:-1], reversed_dims[1:])):
            skip_c = out_c
            self.up_blocks.extend([
                BasicBlock(in_c + skip_c, in_c, embed_dim),
                BasicBlock(in_c + skip_c, in_c, embed_dim),
                nn.Conv2d(in_c, out_c, 1) if level == last_level else nn.ConvTranspose2d(in_c, out_c, 2, 2),
            ])

        self.mid_blocks = nn.ModuleList([
            BasicBlock(all_dims[-1], all_dims[-1], embed_dim),
            SelfAttention2d(all_dims[-1]),
            BasicBlock(all_dims[-1], all_dims[-1], embed_dim),
        ])
        self.out_blocks = nn.ModuleList([
            BasicBlock(embed_dim, embed_dim, embed_dim),
            nn.Conv2d(embed_dim, in_dim, 1, bias=True),
        ])

    @staticmethod
    def _apply_blocks(blocks, h, t_emb):
        """Run `blocks` in order, passing the time embedding only to BasicBlocks."""
        for block in blocks:
            h = block(h, t_emb) if isinstance(block, BasicBlock) else block(h)
        return h

    def forward(self, x, t):
        h = self.init_embed(x)
        t_emb = self.time_embed(t)
        skip_conns = []
        residual = h.clone()

        for block in self.down_blocks:
            if isinstance(block, BasicBlock):
                h = block(h, t_emb)
                skip_conns.append(h)
            else:
                h = block(h)

        h = self._apply_blocks(self.mid_blocks, h, t_emb)

        # Skips are consumed last-in first-out, pairing each decoder BasicBlock with its mirrored encoder block.
        for block in self.up_blocks:
            if isinstance(block, BasicBlock):
                h = block(torch.cat((h, skip_conns.pop()), dim=1), t_emb)
            else:
                h = block(h)

        return self._apply_blocks(self.out_blocks, h + residual, t_emb)
232 |
233 |
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention [Vaswani et al. NeurIPS 2017].
    Scaled dot-product attention is performed over V, using K as keys and Q as queries.
    MultiHeadAttention(Q, V) = FC(SoftMax(1/√d QKᵀ) V) (concatenated over multiple heads), where d = dim_a.
    Notes
    -----
    (1) Q, K, V can be of different dimensions. Q and K are projected to dim_a and V to dim_o.
    (2) We assume the last and second last dimensions correspond to the feature (i.e. embedding)
        and token (i.e. words) dimensions respectively.
    (3) All four projections get Xavier-normal weights and zero biases.
    """
    def __init__(self, dim_q, dim_k, dim_v, num_heads=8, dropout_prob=0.1, dim_a=None, dim_o=None):
        super().__init__()
        dim_a = dim_q if dim_a is None else dim_a
        dim_o = dim_q if dim_o is None else dim_o
        self.dim_a, self.dim_o, self.num_heads = dim_a, dim_o, num_heads
        self.fc_q = nn.Linear(dim_q, dim_a, bias=True)
        self.fc_k = nn.Linear(dim_k, dim_a, bias=True)
        self.fc_v = nn.Linear(dim_v, dim_o, bias=True)
        self.fc_o = nn.Linear(dim_o, dim_o, bias=True)
        self.dropout = nn.Dropout(dropout_prob)
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, q, k, v, mask=None):
        """
        Perform multi-head attention with given queries and values.
        Parameters
        ----------
        q: (bsz, tsz, dim_q)
        k: (bsz, tsz, dim_k)
        v: (bsz, tsz, dim_v)
        mask: (bsz, tsz) or (bsz, tsz, tsz), where 1 denotes keep and 0 denotes remove
        Returns
        -------
        O: (bsz, tsz, dim_o)
        """
        bsz, tsz, _ = q.shape
        head_dim_a = self.dim_a // self.num_heads
        head_dim_o = self.dim_o // self.num_heads

        # Project, then stack heads along the batch axis (row index = head * bsz + batch); assumes dim_a and
        # dim_o are divisible by num_heads.
        q = torch.cat(self.fc_q(q).split(head_dim_a, dim=-1), dim=0)
        k = torch.cat(self.fc_k(k).split(head_dim_a, dim=-1), dim=0)
        v = torch.cat(self.fc_v(v).split(head_dim_o, dim=-1), dim=0)

        # NOTE: logits are scaled by the full dim_a, not the per-head dim.
        logits = q @ k.transpose(-1, -2) / self.dim_a ** 0.5
        if mask is not None:
            assert mask.ndim in (2, 3)
            if mask.ndim == 2:
                # A per-key mask applies identically to every query position.
                mask = mask.unsqueeze(-2).repeat(self.num_heads, tsz, 1)
            else:
                mask = mask.repeat(self.num_heads, 1, 1)
            # -65504 is the most negative finite float16, so masked logits vanish after softmax.
            logits.masked_fill_(mask == 0, -65504)
        weights = self.dropout(logits.softmax(dim=-1))

        # Undo the head stacking: one chunk per head, concatenated back along the feature axis.
        per_head = (weights @ v).split(bsz, dim=0)
        return self.fc_o(torch.cat(per_head, dim=-1))
290 |
291 |
class PositionwiseFFN(nn.Module):
    """
    Position-wise FFN [Vaswani et al. NeurIPS 2017].

    Computes fc2(dropout(ReLU(fc1(x)))) over the last (feature) dimension, so each position is transformed
    independently. Both layers get Kaiming-normal weights and zero biases.
    """
    def __init__(self, dim, hidden_dim, dropout_prob=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = nn.Linear(hidden_dim, dim, bias=True)
        self.dropout = nn.Dropout(dropout_prob)
        for layer in (self.fc1, self.fc2):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))
307 |
308 |
class EncoderBlock(nn.Module):
    """
    Transformer encoder block [Vaswani et al. NeurIPS 2017].
    Note that this is the pre-LN version [Nguyen and Salazar 2019].

    Each sub-layer sees LayerNorm(x) shifted by a learned function of t:
        x = x + Dropout(Attn(LN1(x) + MLP1(t)))
        x = x + Dropout(FFN(LN2(x) + MLP2(t)))
    """
    def __init__(self, dim, hidden_dim, num_heads=8, dropout_prob=0.1):
        super().__init__()

        def make_time_mlp():
            return nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

        self.attn = MultiHeadAttention(dim, dim, dim, num_heads, dropout_prob)
        self.ffn = PositionwiseFFN(dim, hidden_dim, dropout_prob)
        self.dropout = nn.Dropout(dropout_prob)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp_time_1 = make_time_mlp()
        self.mlp_time_2 = make_time_mlp()

    def forward(self, x, t, mask=None):
        # The time shift gets a token axis (unsqueeze(1)) so it broadcasts over all tokens; presumably
        # x is (bsz, tsz, dim) and t is (bsz, dim).
        normed = self.ln1(x) + self.mlp_time_1(t).unsqueeze(1)
        x = x + self.dropout(self.attn(normed, normed, normed, mask))
        normed = self.ln2(x) + self.mlp_time_2(t).unsqueeze(1)
        return x + self.dropout(self.ffn(normed))
340 |
341 |
class ViT(nn.Module):
    """
    Simple version of Vision Transformer [Dosovitskiy et al. 2020].

    Images are cut into non-overlapping (patch_h, patch_w) patches, linearly embedded, summed with a learned
    positional embedding, processed by time-conditioned EncoderBlocks, and projected back to pixel space.
    """
    def __init__(self, in_dim, embed_dim, num_layers, patch_shape):
        super().__init__()

        self.patch_h, self.patch_w = patch_shape
        self.patch_embed = nn.Linear(in_dim * math.prod(patch_shape), embed_dim, bias=True)
        # Learnable tables initialized from the sinusoidal embedding (default of 10000 rows each).
        self.time_embed = nn.Parameter(PositionalEmbedding.make_embedding(embed_dim))
        self.pos_embed = nn.Parameter(PositionalEmbedding.make_embedding(embed_dim))

        self.blocks = nn.ModuleList([
            EncoderBlock(embed_dim, 4 * embed_dim) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.out_embed = nn.Linear(embed_dim, in_dim * math.prod(patch_shape), bias=True)

    def forward(self, x, t):
        """
        Parameters
        ----------
        x: (bsz, in_dim, h, w), with h and w divisible by the patch shape
        t: integer indices into the time embedding table (presumably (bsz,))

        Returns
        -------
        (bsz, in_dim, h, w)
        """
        shape_info = parse_shape(x, "b c h w")

        x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (ph pw c)", ph=self.patch_h, pw=self.patch_w)
        # Add the learned positional embedding of each patch index (the attribute is `pos_embed`).
        x = self.patch_embed(x) + rearrange(self.pos_embed[:x.shape[1]], "t c -> () t c")
        t = self.time_embed[t]

        for block in self.blocks:
            x = block(x, t)

        x = self.ln(x)
        x = self.out_embed(x)

        x = rearrange(x, "b (h w) (ph pw c) -> b c (h ph) (w pw)",
            h=shape_info["h"] // self.patch_h,
            w=shape_info["w"] // self.patch_w,
            ph=self.patch_h,
            pw=self.patch_w,
        )
        return x
381 |
--------------------------------------------------------------------------------