├── data └── .gitkeep ├── ckpts └── .gitkeep ├── .gitignore ├── pyproject.toml ├── examples ├── ex_2d.png ├── ex_coins.png ├── ex_mnist.png ├── ex_2d_quiver.png ├── ex_mnist_crop.png ├── ex_coins.py ├── ex_2d.py ├── ex_mnist_simple.py ├── ex_mnist.py ├── ex_cifar.py └── ex_moving_mnist.py ├── figures ├── fig_tail.py ├── fig_noisy_img.py ├── fig_lambdas_discrete.py ├── fig_karras_param.py ├── fig_lambdas_continuous.py ├── fig_schedules.py └── fig_shifts.py ├── README.md └── src ├── schedules.py ├── samplers.py ├── blocks_3d.py ├── score_matching.py ├── simple └── diffusion.py ├── experimental └── consistency.py └── blocks.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckpts/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | figures/*.png 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length=120 3 | -------------------------------------------------------------------------------- /examples/ex_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_2d.png -------------------------------------------------------------------------------- /examples/ex_coins.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/diffusion/HEAD/examples/ex_coins.png -------------------------------------------------------------------------------- /examples/ex_mnist.png: 
if __name__ == "__main__":

    # Draws an exponential density next to its mirror image, shading the tail
    # region of each, and writes the result to ./figures/fig_tail.png.
    T = 1000
    linspace = np.linspace(0, 6, T)
    tail = np.linspace(2.5, 6, T)

    # Evaluate the density once; the right panel reuses it, reflected.
    density = sp.stats.expon.pdf(linspace)
    tail_density = sp.stats.expon.pdf(tail)

    panels = (
        (1, linspace, density, tail, tail_density, "$f(x)$"),
        (2, -linspace[::-1], density[::-1], -tail[::-1], tail_density[::-1], "$f(y)$"),
    )

    plt.figure(figsize=(7, 2.0), dpi=300)
    for position, xs, ys, tail_xs, tail_ys, label in panels:
        plt.subplot(1, 2, position)
        plt.plot(xs, ys, color="black", label=label)
        plt.fill_between(tail_xs, tail_ys, color="grey")
        plt.legend(loc="upper right")
        plt.yticks([])
        plt.xticks([])
        plt.ylim((0, 1))
    plt.tight_layout()
    plt.savefig("./figures/fig_tail.png")
    plt.show()
def add_noise(img, scale):
    """
    Return a copy of `img` with Gaussian noise added, clipped to [-1, 1].

    The noise is drawn from numpy's global random state with the same shape as
    `img` and multiplied by `scale` before being added.
    """
    noise = np.random.randn(*img.shape)
    noisy = img + scale * noise
    return np.clip(noisy, a_min=-1, a_max=1)
def c_skip(x, sigma_data):
    """
    Skip-connection coefficient as a function of x = log(sigma).

    Computes sigma_data^2 / (sigma^2 + sigma_data^2) with sigma = exp(x).
    """
    sigma = np.exp(x)
    data_var = sigma_data ** 2
    return data_var / (sigma ** 2 + data_var)


def c_out(x, sigma_data):
    """
    Output-scaling coefficient as a function of x = log(sigma).

    Computes sigma * sigma_data / sqrt(sigma^2 + sigma_data^2) with sigma = exp(x).
    """
    sigma = np.exp(x)
    total_sd = (sigma ** 2 + sigma_data ** 2) ** 0.5
    return sigma * sigma_data / total_sd
plt.legend(loc="upper left") 31 | plt.xlabel("$\\log\\sigma$") 32 | plt.title("$c_\\text{out}$") 33 | plt.tick_params(bottom=False, top=False) 34 | plt.tight_layout() 35 | plt.savefig("./figures/fig_karras_param.png") 36 | plt.show() 37 | 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Diffusion Models From Scratch 2 | 3 | March 2023. 4 | 5 | --- 6 | 7 | My notes for this repository ended up longer than expected, too long to be rendered by GitHub. 8 | 9 | So instead of putting notes here, they've been moved to my website. 10 | 11 | [[**This blog post**]](https://www.tonyduan.com/diffusion/index.html) explains the intuition and derivations behind diffusion. 12 | 13 | --- 14 | 15 | This codebase provides a *minimalist* reproduction of the MNIST example below. 16 | 17 | It clocks in at well under 500 LOC. 18 | 19 |

20 | 21 |

def cosine_pdf(x, kappa):
    """
    Density of log(sigma) under the cosine schedule, evaluated at x.

    This is a hyperbolic secant density, 1 / (pi * cosh(x + log(kappa))), so it
    peaks at x = -log(kappa). Truncation to [sigma_min, sigma_max] is ignored.
    """
    return 1 / np.pi / np.cosh(x + np.log(kappa))


def normal_pdf(x, loc, scale, sigma_min, sigma_max):
    """
    Density of log(sigma) under a normal(loc, scale) truncated to
    [log(sigma_min), log(sigma_max)], evaluated at x.

    scipy's truncnorm takes its clip points `a` and `b` in standardised units,
    i.e. relative to `loc` and in multiples of `scale`, so the log-sigma bounds
    are converted before the call. Passing them unconverted truncates the
    density at loc + scale * log(sigma_*) instead of at log(sigma_*).
    """
    lower = (np.log(sigma_min) - loc) / scale
    upper = (np.log(sigma_max) - loc) / scale
    return sp.stats.truncnorm.pdf(x, a=lower, b=upper, loc=loc, scale=scale)
def get_linear_schedule(num_timesteps):
    """
    Return alpha_t for t = 0..num_timesteps under a linear beta schedule.

    beta_t is spaced linearly from 1e-4 to 2e-2 and alpha_t is the square root
    of the running product of (1 - beta_t).
    """
    betas = np.linspace(1e-4, 2e-2, num_timesteps + 1)
    signal_var = np.cumprod(1 - betas, axis=0)
    return signal_var ** 0.5


def get_cosine_schedule(num_timesteps):
    """
    Return alpha_t for t = 0..num_timesteps under the cosine schedule.

    The squared-cosine curve (offset 0.008) is normalised by its first value,
    converted to per-step betas clipped to [0, 0.999] with beta_0 = 0, and then
    accumulated back into alpha_t, so alpha_0 is exactly 1.
    """
    grid = np.linspace(0, 1, num_timesteps + 1)
    f_t = np.cos((grid + 0.008) / (1 + 0.008) * math.pi / 2) ** 2
    bar_alpha_t = f_t / f_t[0]
    step_ratio = bar_alpha_t[1:] / bar_alpha_t[:-1]
    # No noise is added at t = 0, hence the leading zero.
    beta_t = np.concatenate([np.zeros(1), np.clip(1 - step_ratio, a_min=0, a_max=0.999)])
    return np.cumprod(1 - beta_t, axis=0) ** 0.5
@dataclass(frozen=True)
class NoiseSchedule(ABC):
    """
    Base class for noise-level distributions described by their inverse CDF.
    """

    @abstractmethod
    def get_sigma_ppf(self, p: torch.Tensor) -> torch.Tensor:
        """
        Transform p in [0, 1] to sigma between [sigma_min, sigma_max] via icdf.

        The bounds are fields of the concrete schedule, not call arguments; this
        matches the signature every implementation actually provides.
        """
        pass


@dataclass(frozen=True)
class CosineSchedule(NoiseSchedule):
    """
    Cosine schedule: theta is uniform and sigma = tan(theta) / kappa.

    Equivalently log(sigma) follows a hyperbolic secant density centred at
    -log(kappa), truncated to [log(sigma_min), log(sigma_max)].
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    kappa: float = 1.0

    @cached_property
    def theta_min(self):
        # Inverse of sigma = tan(theta) / kappa at sigma_min, so ppf(0) == sigma_min
        # for every kappa (not only kappa == 1).
        return math.atan(self.kappa * self.sigma_min)

    @cached_property
    def theta_max(self):
        # Inverse of sigma = tan(theta) / kappa at sigma_max, so ppf(1) == sigma_max.
        return math.atan(self.kappa * self.sigma_max)

    def get_sigma_ppf(self, p: torch.Tensor) -> torch.Tensor:
        """Map p in [0, 1] to sigma by interpolating theta linearly."""
        return torch.tan(self.theta_min + p * (self.theta_max - self.theta_min)) / self.kappa


@dataclass(frozen=True)
class ExponentialSchedule(NoiseSchedule):
    """
    Schedule in which sigma ** (1 / rho) is interpolated linearly in p.
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0

    @cached_property
    def inv_sigma_min(self):
        # sigma_min warped by the 1/rho power.
        return self.sigma_min ** (1 / self.rho)

    @cached_property
    def inv_sigma_max(self):
        # sigma_max warped by the 1/rho power.
        return self.sigma_max ** (1 / self.rho)

    def get_sigma_ppf(self, p: torch.Tensor) -> torch.Tensor:
        """Map p in [0, 1] to sigma; p=0 gives sigma_min and p=1 gives sigma_max."""
        return (self.inv_sigma_min + p * (self.inv_sigma_max - self.inv_sigma_min)) ** self.rho


@dataclass(frozen=True)
class NormalSchedule(NoiseSchedule):
    """
    Schedule in which log(sigma) is normal(mean, std), truncated to
    [log(sigma_min), log(sigma_max)].
    """

    sigma_min: float = 0.002
    sigma_max: float = 80.0
    mean: float = -1.2
    std: float = 1.2

    def _erf_of_bound(self, sigma: float) -> float:
        # The normal CDF is 0.5 * (1 + erf(z / sqrt(2))) with z the standardised
        # value, so the bound must be standardised before erf is applied.
        z = (math.log(sigma) - self.mean) / self.std
        return math.erf(z / math.sqrt(2))

    @cached_property
    def erf_sigma_min(self):
        return self._erf_of_bound(self.sigma_min)

    @cached_property
    def erf_sigma_max(self):
        return self._erf_of_bound(self.sigma_max)

    def get_sigma_ppf(self, p: torch.Tensor) -> torch.Tensor:
        """Map p in [0, 1] to sigma; p=0 gives sigma_min and p=1 gives sigma_max."""
        # Interpolating linearly in erf is the same as interpolating in the CDF;
        # sqrt(2) * erfinv then recovers the standardised quantile.
        erf_value = self.erf_sigma_min + p * (self.erf_sigma_max - self.erf_sigma_min)
        z = math.sqrt(2) * torch.special.erfinv(erf_value)
        return torch.exp(z * self.std + self.mean)
def gen_data(n=512, d=48):
    """
    Generate `n` rows of `d` coin flips encoded as +1 (heads) / -1 (tails).

    Every row is an independent random permutation of the same template with
    exactly 3 * d // 4 heads, so the number of heads per row is constant.
    Returns a float32 array of shape (n, d).
    """
    num_heads = 3 * d // 4
    # Fill the remainder with tails so each row has exactly d entries even when
    # d is not a multiple of 4 (3 * d // 4 + d // 4 can fall short of d).
    template = np.r_[np.ones(num_heads), np.zeros(d - num_heads)]
    rows = [np.random.permutation(template) for _ in range(n)]
    # np.vstack rejects an empty list, so handle n == 0 explicitly.
    x = np.vstack(rows) if rows else np.empty((0, d))
    x = x * 2 - 1
    return x.astype(np.float32)
def draw_lines(logsnr, color="black", stride=10):
    """
    Plot one line from the origin for every `stride`-th entry of `logsnr`.

    Each line ends at (sqrt(expit(logsnr)), sqrt(expit(-logsnr))) and is drawn
    on the current matplotlib axes in `color` at 50% opacity.
    """
    picked = logsnr[::stride]
    end_xs = expit(picked) ** 0.5
    end_ys = expit(-picked) ** 0.5
    for end_x, end_y in zip(end_xs, end_ys):
        plt.plot([0, end_x], [0, end_y], color=color, alpha=0.5)
def gen_mixture_data(n=512):
    """
    Draw `n` points from a two-component 2D Gaussian mixture and standardise them.

    The components are unit-variance Gaussians centred at (-5, 0) and (5, 0).
    The result is shifted and scaled per coordinate to zero mean and unit sample
    standard deviation (ddof=1), and returned as a float32 array of shape (n, 2).
    With n < 2 the sample standard deviation is undefined.
    """
    # Give the second component the remainder so that exactly n points are
    # returned when n is odd (two halves of n // 2 would yield n - 1).
    n_left = n // 2
    n_right = n - n_left
    x = np.r_[np.random.randn(n_left, 2) + np.array([-5, 0]),
              np.random.randn(n_right, 2) + np.array([5, 0])]
    centre = np.mean(x, axis=0, keepdims=True)
    spread = np.std(x, axis=0, ddof=1, keepdims=True)
    x = (x - centre) / spread
    return x.astype(np.float32)
torch.linspace(-6, 6, 20), indexing="ij") 82 | x = torch.from_numpy(np.c_[x_grid.flatten(), y_grid.flatten()]) 83 | bsz, _ = x.shape 84 | 85 | fig, axes = plt.subplots(3, 4, figsize=(12, 9)) 86 | axes[0, 0].axis("off") 87 | 88 | for scalar_t in range(0, model.num_timesteps + 1): 89 | 90 | t = torch.full((bsz,), fill_value=scalar_t, device=x.device) 91 | 92 | with torch.no_grad(): 93 | pred_eps = model.nn_module(x, t) 94 | 95 | pred_eps = pred_eps.numpy().reshape(30, 20, 2) 96 | one_axes = axes[(scalar_t + 1) // 4, (scalar_t + 1) % 4] 97 | one_axes.quiver(x_grid.numpy(), y_grid.numpy(), -pred_eps[..., 0], -pred_eps[..., 1]) 98 | one_axes.axis("off") 99 | one_axes.set_title(f"Score {scalar_t}") 100 | 101 | fig.tight_layout() 102 | fig.savefig("./examples/ex_2d_quiver.png") 103 | 104 | -------------------------------------------------------------------------------- /examples/ex_mnist_simple.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import logging 3 | 4 | from einops import rearrange 5 | from matplotlib import pyplot as plt 6 | import numpy as np 7 | from sklearn.datasets import fetch_openml 8 | import torch 9 | import torch.optim as optim 10 | 11 | from src.blocks import UNet 12 | from src.simple.diffusion import DiffusionModel, DiffusionModelConfig 13 | 14 | 15 | if __name__ == "__main__": 16 | 17 | argparser = ArgumentParser() 18 | argparser.add_argument("--iterations", default=2000, type=int) 19 | argparser.add_argument("--batch-size", default=512, type=int) 20 | argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps")) 21 | args = argparser.parse_args() 22 | 23 | logging.basicConfig(level=logging.INFO) 24 | logger = logging.getLogger(__name__) 25 | 26 | # Load data from https://www.openml.org/d/554 27 | # (70000, 784) values between 0-255 28 | x, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False, cache=True) 29 | 30 | # 
Reshape to 32x32 31 | x = rearrange(x, "b (h w) -> b h w", h=28, w=28) 32 | x = np.pad(x, pad_width=((0, 0), (2, 2), (2, 2))) 33 | x = rearrange(x, "b h w -> b (h w)") 34 | 35 | # Standardize to [-1, 1] 36 | input_mean = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32) 37 | input_sd = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32) 38 | x = ((x - input_mean) / input_sd).astype(np.float32) 39 | 40 | nn_module = UNet(1, 128, (1, 2, 4, 8)) 41 | model = DiffusionModel( 42 | nn_module=nn_module, 43 | input_shape=(1, 32, 32,), 44 | config=DiffusionModelConfig( 45 | num_timesteps=500, 46 | target_type="pred_x_0", 47 | gamma_type="ddim", 48 | noise_schedule_type="cosine", 49 | ), 50 | ) 51 | model = model.to(args.device) 52 | 53 | optimizer = optim.Adam(model.parameters(), lr=0.001) 54 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations) 55 | 56 | for i in range(args.iterations): 57 | x_batch = x[np.random.choice(len(x), args.batch_size)] 58 | x_batch = torch.from_numpy(x_batch).to(args.device) 59 | x_batch = rearrange(x_batch, "b (h w) -> b () h w", h=32, w=32) 60 | optimizer.zero_grad() 61 | loss = model.loss(x_batch).mean() 62 | loss.backward() 63 | optimizer.step() 64 | scheduler.step() 65 | if i % 100 == 0: 66 | logger.info(f"Iter: {i}\t" + f"Loss: {loss.data:.2f}\t") 67 | 68 | model.eval() 69 | 70 | samples = model.sample(bsz=64, num_sampling_timesteps=None, device=args.device).cpu().numpy() 71 | samples = rearrange(samples, "t b () h w -> t b (h w)") 72 | samples = samples * input_sd + input_mean 73 | x_vis = x[:64] * input_sd + input_mean 74 | 75 | nrows, ncols = 10, 2 76 | percents = (100, 75, 50, 25, 0) 77 | raster = np.zeros((nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.float32) 78 | 79 | for i in range(nrows * ncols): 80 | row, col = i // ncols, i % ncols 81 | raster[32 * row : 32 * (row + 1), 32 * col : 32 * (col + 1)] = x_vis[i].reshape(32, 32) 82 | for percent_idx, percent in enumerate(percents): 83 | 
@dataclass
class InstantaneousPrediction:
    """
    A model prediction bundled with the point at which it was made.
    """

    # Value the difference (x_t - pred_x_0) is divided by in pred_eps.
    sigma: float
    # Point the sampler steps away from.
    x_t: torch.Tensor
    # Prediction paired with x_t.
    pred_x_0: torch.Tensor

    @cached_property
    def pred_eps(self):
        """(x_t - pred_x_0) / sigma, computed on first access and then cached."""
        return (self.x_t - self.pred_x_0) / self.sigma


class Sampler(ABC):
    """
    Interface for samplers that advance x from one sigma to the next.
    """

    @abstractmethod
    def reset(self):
        """Clear any state carried between steps."""
        pass

    @abstractmethod
    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """
        Take a step from sigma_t to sigma_t_plus_1.
        """
        pass


class EulerSampler(Sampler):
    """
    Stateless first-order sampler.
    """

    def reset(self):
        pass

    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """Return x_t + pred_eps * (sigma_t_plus_1 - sigma_t)."""
        return pred_at_x_t.x_t + pred_at_x_t.pred_eps * (sigma_t_plus_1 - sigma_t)


class MultistepDPMSampler(Sampler):
    """
    Multistep sampler that adds finite-difference derivative terms of pred_eps,
    estimated from up to two previous predictions, to the first-order update.
    """

    def __init__(self):
        # Create the history up front so step() works on a fresh instance
        # without requiring the caller to invoke reset() first.
        self.reset()

    def reset(self):
        """Forget all previous predictions."""
        # Two past predictions plus the current one give the order-3 update.
        self.history: deque[InstantaneousPrediction] = deque(maxlen=2)  # Order k=3

    def step(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """
        Take a step from sigma_t to sigma_t_plus_1.

        Uses the highest order the stored history allows (first, second, then
        third) and records `pred_at_x_t` for use by later steps.
        """
        if len(self.history) == 0:
            x_t_plus_1 = self.step_first_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        elif len(self.history) == 1:
            x_t_plus_1 = self.step_second_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        else:
            x_t_plus_1 = self.step_third_order(sigma_t, sigma_t_plus_1, pred_at_x_t)
        self.history.append(pred_at_x_t)
        return x_t_plus_1

    @staticmethod
    def _slope(newer: InstantaneousPrediction, older: InstantaneousPrediction):
        # Finite-difference estimate of d(pred_eps)/d(sigma) between two predictions.
        return (newer.pred_eps - older.pred_eps) / (newer.sigma - older.sigma)

    def step_first_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """First-order update; needs no history."""
        d_sigma = sigma_t_plus_1 - sigma_t
        return pred_at_x_t.x_t + pred_at_x_t.pred_eps * d_sigma

    def step_second_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """Second-order update; needs one previous prediction in the history."""
        pred_at_x_t_minus_1 = self.history[-1]
        d_sigma = sigma_t_plus_1 - sigma_t
        pred_first_derivative = self._slope(pred_at_x_t, pred_at_x_t_minus_1)
        return (
            pred_at_x_t.x_t +
            pred_at_x_t.pred_eps * d_sigma +
            (1/2) * pred_first_derivative * d_sigma ** 2
        )

    def step_third_order(self, sigma_t: float, sigma_t_plus_1: float, pred_at_x_t: InstantaneousPrediction):
        """Third-order update; needs two previous predictions in the history."""
        pred_at_x_t_minus_1 = self.history[-1]
        pred_at_x_t_minus_2 = self.history[-2]
        d_sigma = sigma_t_plus_1 - sigma_t
        pred_first_derivative = self._slope(pred_at_x_t, pred_at_x_t_minus_1)
        pred_first_derivative_past = self._slope(pred_at_x_t_minus_1, pred_at_x_t_minus_2)
        # Each slope is attributed to the midpoint of its sigma interval, so the
        # second difference is divided by the distance between the two midpoints.
        pred_second_derivative = (
            (pred_first_derivative - pred_first_derivative_past) /
            (0.5 * (pred_at_x_t.sigma + pred_at_x_t_minus_1.sigma) -
             0.5 * (pred_at_x_t_minus_1.sigma + pred_at_x_t_minus_2.sigma))
        )
        return (
            pred_at_x_t.x_t +
            pred_at_x_t.pred_eps * d_sigma +
            (1/2) * pred_first_derivative * d_sigma ** 2 +
            (1/6) * pred_second_derivative * d_sigma ** 3
        )
np.pad(x, pad_width=((0, 0), (2, 2), (2, 2))) 34 | x = rearrange(x, "b h w -> b (h w)") 35 | 36 | # Standardize to [-1, 1] 37 | input_mean = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32) 38 | input_sd = np.full((1, 32 ** 2), fill_value=127.5, dtype=np.float32) 39 | x = ((x - input_mean) / input_sd).astype(np.float32) 40 | 41 | nn_module = UNet(1, 128, (1, 2, 4, 8)) 42 | model = ScoreMatchingModel( 43 | nn_module=nn_module, 44 | input_shape=(1, 32, 32,), 45 | config=ScoreMatchingModelConfig( 46 | sigma_min=0.002, 47 | sigma_max=80.0, 48 | sigma_data=1.0, 49 | ), 50 | ) 51 | model = model.to(args.device) 52 | 53 | optimizer = optim.Adam(model.parameters(), lr=0.001) 54 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations) 55 | 56 | if args.load_trained: 57 | model.load_state_dict(torch.load("./ckpts/mnist_trained.pt")) 58 | else: 59 | for step_num in range(args.iterations): 60 | x_batch = x[np.random.choice(len(x), args.batch_size)] 61 | x_batch = torch.from_numpy(x_batch).to(args.device) 62 | x_batch = rearrange(x_batch, "b (h w) -> b () h w", h=32, w=32) 63 | optimizer.zero_grad() 64 | loss = model.loss(x_batch).mean() 65 | loss.backward() 66 | optimizer.step() 67 | scheduler.step() 68 | if step_num % 100 == 0: 69 | logger.info(f"Iter: {step_num}\t" + f"Loss: {loss.data:.2f}\t") 70 | torch.save(model.state_dict(), "./ckpts/mnist_trained.pt") 71 | 72 | model.eval() 73 | 74 | samples = model.sample(bsz=64, num_sampling_timesteps=20, device=args.device).cpu().numpy() 75 | samples = rearrange(samples, "t b () h w -> t b (h w)") 76 | samples = samples * input_sd + input_mean 77 | x_vis = x[:64] * input_sd + input_mean 78 | 79 | nrows, ncols = 10, 2 80 | percents = (100, 75, 50, 25, 0) 81 | raster = np.zeros((nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.float32) 82 | 83 | for i in range(nrows * ncols): 84 | row, col = i // ncols, i % ncols 85 | raster[32 * row : 32 * (row + 1), 32 * col : 32 * (col + 1)] = 
x_vis[i].reshape(32, 32) 86 | for percent_idx, percent in enumerate(percents): 87 | itr_num = int(round(0.01 * percent * (len(samples) - 1))) 88 | for i in range(nrows * ncols): 89 | row, col = i // ncols, i % ncols 90 | offset = 32 * ncols * (percent_idx + 1) 91 | raster[32 * row : 32 * (row + 1), offset + 32 * col : offset + 32 * (col + 1)] = samples[itr_num][i].reshape(32, 32) 92 | 93 | plt.imsave("./examples/ex_mnist.png", raster, vmin=0, vmax=255) 94 | -------------------------------------------------------------------------------- /examples/ex_cifar.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import logging 3 | 4 | from einops import rearrange 5 | from matplotlib import pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torchvision import datasets, transforms 11 | 12 | from src.blocks import UNet 13 | from src.score_matching import ScoreMatchingModel, ScoreMatchingModelConfig 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | argparser = ArgumentParser() 19 | argparser.add_argument("--iterations", default=2000, type=int) 20 | argparser.add_argument("--batch-size", default=256, type=int) 21 | argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps")) 22 | argparser.add_argument("--load-trained", default=0, type=int, choices=(0, 1)) 23 | args = argparser.parse_args() 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | nn_module = UNet(3, 128, (1, 2, 4, 8)) 29 | model = ScoreMatchingModel( 30 | nn_module=nn_module, 31 | input_shape=(3, 32, 32,), 32 | config=ScoreMatchingModelConfig( 33 | sigma_min=0.002, 34 | sigma_max=80.0, 35 | sigma_data=1.0, 36 | ), 37 | ) 38 | model = model.to(args.device) 39 | 40 | # Standardize to [-1, +1] 41 | dataset_mean, dataset_sd = np.asarray([0.5, 0.5, 0.5]), np.asarray([0.5, 
0.5, 0.5]) 42 | dataset = datasets.CIFAR10( 43 | "./data/cifar_10", 44 | train=True, 45 | download=True, 46 | transform=transforms.Compose([ 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize(dataset_mean, dataset_sd), 50 | ]) 51 | ) 52 | 53 | optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) 54 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations) 55 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 56 | 57 | if args.load_trained: 58 | model.load_state_dict(torch.load("./ckpts/cifar_trained.pt")) 59 | else: 60 | iterator = iter(dataloader) 61 | for step_num in range(args.iterations): 62 | try: 63 | x_batch, _ = next(iterator) 64 | except StopIteration: 65 | iterator = iter(dataloader) 66 | x_batch, _ = next(iterator) 67 | x_batch = x_batch.to(args.device) 68 | optimizer.zero_grad() 69 | loss = model.loss(x_batch).mean() 70 | loss.backward() 71 | optimizer.step() 72 | scheduler.step() 73 | if step_num % 100 == 0: 74 | logger.info(f"Iter: {step_num}\t" + f"Loss: {loss.data:.2f}\t") 75 | torch.save(model.state_dict(), "./ckpts/cifar_trained.pt") 76 | 77 | model.eval() 78 | 79 | samples = model.sample(bsz=64, num_sampling_timesteps=20, device=args.device).cpu().numpy() 80 | samples = samples * rearrange(dataset_sd, "c -> 1 1 c 1 1") + rearrange(dataset_mean, "c -> 1 1 c 1 1") 81 | samples = np.rint(samples * 255).clip(min=0, max=255).astype(np.uint8) 82 | 83 | x_vis = rearrange(dataset.data[:64], "b h w c -> b c h w") 84 | 85 | nrows, ncols = 10, 2 86 | percents = (100, 75, 50, 25, 0) 87 | raster = np.zeros((3, nrows * 32, ncols * 32 * (len(percents) + 1)), dtype=np.uint8) 88 | 89 | for i in range(nrows * ncols): 90 | row, col = i // ncols, i % ncols 91 | raster[:, 32 * row : 32 * (row + 1), 32 * col : 32 * (col + 1)] = x_vis[i] 92 | for percent_idx, percent in enumerate(percents): 93 | itr_num = int(round(0.01 * percent * (len(samples) - 1))) 94 | for 
class BasicBlock3d(nn.Module):
    """
    Residual block built from two 3x3x3 convolutions, each followed by batch norm.
    [He et al. CVPR 2016]

        h   = ReLU( BN( Conv(x) ) + MLP(t) )
        out = ReLU( BN( Conv(h) ) + shortcut(x) )

    The shortcut is the identity when in_c == out_c and a 1x1x1 convolution plus
    batch norm otherwise. MLP(t) is an additive shift parameterized by time.
    """
    def __init__(self, in_c, out_c, time_c):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        # Submodule names and creation order are kept stable so existing checkpoints load.
        self.conv1 = nn.Conv3d(in_c, out_c, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_c)
        self.conv2 = nn.Conv3d(out_c, out_c, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.mlp_time = nn.Sequential(
            nn.Linear(time_c, time_c),
            nn.ReLU(),
            nn.Linear(time_c, out_c),
        )
        # A projection is only needed when the channel count changes.
        self.shortcut = (
            nn.Identity()
            if in_c == out_c
            else nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm3d(out_c),
            )
        )

    def forward(self, x, t):
        # assumes x is (bsz, in_c, frames, h, w) and t is (bsz, time_c) -- TODO confirm against callers
        hidden = self.bn1(self.conv1(x))
        # unsqueeze_as appends trailing singleton dims so the shift broadcasts over x's remaining axes.
        time_shift = unsqueeze_as(self.mlp_time(t), x)
        hidden = F.relu(hidden + time_shift)
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))
49 | """ 50 | def __init__(self, in_dim, embed_dim, dim_scales): 51 | super().__init__() 52 | 53 | self.init_embed = nn.Conv3d(in_dim, embed_dim, 1) 54 | self.time_embed = PositionalEmbedding(embed_dim) 55 | 56 | self.down_blocks = nn.ModuleList() 57 | self.up_blocks = nn.ModuleList() 58 | 59 | # Example: 60 | # in_dim=1, embed_dim=32, dim_scales=(1, 2, 4, 8) => all_dims=(32, 32, 64, 128, 256) 61 | all_dims = (embed_dim, *[embed_dim * s for s in dim_scales]) 62 | 63 | for idx, (in_c, out_c) in enumerate(zip( 64 | all_dims[:-1], 65 | all_dims[1:], 66 | )): 67 | is_last = idx == len(all_dims) - 2 68 | self.down_blocks.extend(nn.ModuleList([ 69 | BasicBlock3d(in_c, in_c, embed_dim), 70 | BasicBlock3d(in_c, in_c, embed_dim), 71 | nn.Conv3d(in_c, out_c, (1, 3, 3), (1, 2, 2), (0, 1, 1)) if not is_last else nn.Conv3d(in_c, out_c, 1), 72 | ])) 73 | 74 | for idx, (in_c, out_c, skip_c) in enumerate(zip( 75 | all_dims[::-1][:-1], 76 | all_dims[::-1][1:], 77 | all_dims[:-1][::-1], 78 | )): 79 | is_last = idx == len(all_dims) - 2 80 | self.up_blocks.extend(nn.ModuleList([ 81 | BasicBlock3d(in_c + skip_c, in_c, embed_dim), 82 | BasicBlock3d(in_c + skip_c, in_c, embed_dim), 83 | nn.ConvTranspose3d(in_c, out_c, (1, 2, 2), (1, 2, 2)) if not is_last else nn.Conv3d(in_c, out_c, 1), 84 | ])) 85 | 86 | self.mid_blocks = nn.ModuleList([ 87 | BasicBlock3d(all_dims[-1], all_dims[-1], embed_dim), 88 | SelfAttention2d(all_dims[-1]), 89 | BasicBlock3d(all_dims[-1], all_dims[-1], embed_dim), 90 | ]) 91 | self.out_blocks = nn.ModuleList([ 92 | BasicBlock3d(embed_dim, embed_dim, embed_dim), 93 | nn.Conv3d(embed_dim, in_dim, 1, bias=True), 94 | ]) 95 | 96 | def forward(self, x, t): 97 | _, _, num_frames, *_ = x.shape 98 | x = self.init_embed(x) 99 | t = self.time_embed(t) 100 | skip_conns = [] 101 | residual = x.clone() 102 | 103 | for block in self.down_blocks: 104 | if isinstance(block, BasicBlock3d): 105 | x = block(x, t) 106 | skip_conns.append(x) 107 | else: 108 | x = block(x) 109 | for 
class MovingMNISTDataset(Dataset):
    """
    Synthetic clips of a single MNIST digit sliding across a blank canvas.

    An index enumerates every combination of (direction, starting row offset,
    starting column offset, MNIST image). Each item is a float32 array of shape
    (1, num_frames, height, width): background -1, a 1-pixel frame border of +1,
    and the digit moving `velocity` pixels per frame, reversing at the edges.
    """

    # Standardize to [-1, 1]
    input_mean = 127.5
    input_sd = 127.5

    def __init__(self, num_frames, height, width, velocity=5):

        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.velocity = velocity
        self.mnist_L = 28  # 28 x 28 images

        # Load data from https://www.openml.org/d/554
        # (70000, 784) values between 0-255
        raw, _ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False, cache=True)

        standardized = (raw - MovingMNISTDataset.input_mean) / MovingMNISTDataset.input_sd
        self.x = rearrange(standardized.astype(np.float32), "b (h w) -> b h w", h=28, w=28)

        # Four axis-aligned directions, expressed as per-frame (dy, dx) displacements.
        step = self.velocity
        self.dir_to_dy_dx = {
            0: (-step, 0),
            1: (step, 0),
            2: (0, -step),
            3: (0, step),
        }
        # Number of distinct top-left positions along each axis.
        row_offsets = self.height - self.mnist_L
        col_offsets = self.width - self.mnist_L
        self.index_shape = (len(self.dir_to_dy_dx), row_offsets, col_offsets, len(self.x))

    def __len__(self):
        # One item per (direction, row offset, column offset, digit) combination.
        return np.prod(self.index_shape)

    def __getitem__(self, idx):
        """Render the clip for flat index `idx`; returns (1, num_frames, height, width) float32."""
        direction, row_offset, col_offset, digit_idx = np.unravel_index(idx, self.index_shape)

        half = self.mnist_L // 2
        digit = self.x[digit_idx]

        # Track the digit by its center coordinates.
        center_y, center_x = row_offset + half, col_offset + half
        vel_y, vel_x = self.dir_to_dy_dx[direction]

        frames = np.full((self.num_frames, self.height, self.width), dtype=np.float32, fill_value=-1)
        for frame in frames:
            frame[center_y - half : center_y + half, center_x - half : center_x + half] = digit

            # Decide both reversals from the pre-move position, then move.
            bounce_y = center_y + vel_y - half <= 0 or center_y + vel_y + half >= self.height
            bounce_x = center_x + vel_x - half <= 0 or center_x + vel_x + half >= self.width
            if bounce_y:
                vel_y *= -1
            if bounce_x:
                vel_x *= -1
            center_y += vel_y
            center_x += vel_x

        # Draw a one-pixel border around every frame (this may overwrite digit pixels).
        frames[:, 0, :] = 1.0
        frames[:, -1, :] = 1.0
        frames[:, :, 0] = 1.0
        frames[:, :, -1] = 1.0

        # Prepend the channel axis.
        return frames[np.newaxis]

    @staticmethod
    def visualize_one(x):
        """Lay the frames of one (1, num_frames, height, width) clip side by side as a uint8 image."""
        num_channels, num_frames, height, width = x.shape
        assert num_channels == 1
        pixels = x * MovingMNISTDataset.input_sd + MovingMNISTDataset.input_mean
        pixels = np.clip(pixels, a_min=0, a_max=255).astype(np.uint8)
        # (t, h, w) -> (h, t, w) -> (h, t * w): frame t occupies columns [t * width, (t + 1) * width).
        return pixels[0].transpose(1, 0, 2).reshape(height, width * num_frames)
argparser.add_argument("--device", default="cuda", type=str, choices=("cuda", "cpu", "mps")) 97 | argparser.add_argument("--load-trained", default=0, type=int, choices=(0, 1)) 98 | argparser.add_argument("--num-frames", default=4, type=int) 99 | argparser.add_argument("--velocity", default=4, type=int) 100 | args = argparser.parse_args() 101 | 102 | logging.basicConfig(level=logging.INFO) 103 | logger = logging.getLogger(__name__) 104 | 105 | nn_module = UNet3d(1, 128, (1, 2, 4, 8)) 106 | model = ScoreMatchingModel( 107 | nn_module=nn_module, 108 | input_shape=(1, args.num_frames, 64, 64), 109 | config=ScoreMatchingModelConfig( 110 | sigma_min=0.002, 111 | sigma_max=80.0, 112 | sigma_data=1.0, 113 | train_sigma_schedule=CosineSchedule(kappa=0.25), 114 | test_sigma_schedule=CosineSchedule(kappa=0.25), 115 | ), 116 | ) 117 | model = model.to(args.device) 118 | 119 | dataset = MovingMNISTDataset(args.num_frames, 64, 64, velocity=args.velocity) 120 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 121 | 122 | optimizer = optim.AdamW(model.parameters(), lr=0.001) 123 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.iterations) 124 | 125 | if args.load_trained: 126 | model.load_state_dict(torch.load("./ckpts/moving_mnist_trained.pt")) 127 | else: 128 | iterator = iter(dataloader) 129 | for step_num in range(args.iterations): 130 | try: 131 | x_batch = next(iterator) 132 | except StopIteration: 133 | iterator = iter(dataloader) 134 | x_batch = next(iterator) 135 | x_batch = x_batch.to(args.device) 136 | optimizer.zero_grad() 137 | loss = model.loss(x_batch).mean() 138 | loss.backward() 139 | optimizer.step() 140 | scheduler.step() 141 | if step_num % 100 == 0: 142 | logger.info(f"Iter: {step_num}\t" + f"Loss: {loss.data:.2f}\t") 143 | torch.save(model.state_dict(), "./ckpts/moving_mnist_trained.pt") 144 | 145 | model.eval() 146 | 147 | NUM_SHOWN = 4 148 | 149 | samples = model.sample(bsz=NUM_SHOWN, num_sampling_timesteps=20, 
class ScoreMatchingModel(nn.Module):
    """
    Denoising model over continuous noise levels sigma.

    Wraps a user-supplied network with the Karras et al. 2022 pre-conditioning,
    provides a per-element training loss, and samples by stepping from sigma_max
    down the test-time sigma schedule with the configured sampler.
    """

    def __init__(
        self,
        input_shape: tuple[int, ...],
        nn_module: nn.Module,
        config: ScoreMatchingModelConfig,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.nn_module = nn_module

        # Input shape must be either (c,) or (c, h, w) or (c, t, h, w)
        assert len(input_shape) in (1, 3, 4)

        self.sigma_data = config.sigma_data
        self.sigma_min = config.sigma_min
        self.sigma_max = config.sigma_max
        self.loss_type = config.loss_type
        self.loss_weighting_type = config.loss_weighting_type
        self.train_sigma_schedule = config.train_sigma_schedule
        self.test_sigma_schedule = config.test_sigma_schedule
        self.sampler = config.sampler

    def nn_module_wrapper(self, x, sigma, num_discrete_chunks=10000):
        """
        This function does two things:
        1. Implements Karras et al. 2022 pre-conditioning.
        2. Converts sigma in range [sigma_min, sigma_max] into a discrete input.

        Parameters
        ----------
        x: (bsz, *self.input_shape)
        sigma: (bsz,)

        Returns
        -------
        c_out * nn_module(c_in * x, discrete_sigma) + c_skip * x, same shape as x
        """
        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_out = sigma * self.sigma_data / (self.sigma_data ** 2 + sigma ** 2) ** 0.5
        c_skip, c_in, c_out = unsqueeze_as(c_skip, x), unsqueeze_as(c_in, x), unsqueeze_as(c_out, x)

        # Position of log(sigma) within [log(sigma_min), log(sigma_max)], as a fraction.
        log_sigmas_percentile = (
            (torch.log(sigma) - math.log(self.sigma_min)) /
            (math.log(self.sigma_max) - math.log(self.sigma_min))
        )
        # Bucket the fraction into an integer index; the clamp keeps both end points in range.
        sigmas_discrete = torch.floor(num_discrete_chunks * log_sigmas_percentile)
        sigmas_discrete = sigmas_discrete.clamp_(min=0, max=num_discrete_chunks - 1).long()
        return c_out * self.nn_module(c_in * x, sigmas_discrete) + c_skip * x

    def loss(self, x):
        """
        Parameters
        ----------
        x: (bsz, *input_shape) clean data

        Returns
        -------
        loss: (bsz, *input_shape)
        """
        bsz, *_ = x.shape

        # Draw one noise level per example by pushing uniform samples through the schedule.
        rng = torch.rand((bsz,), device=x.device)
        sigma = self.train_sigma_schedule.get_sigma_ppf(rng)

        x_t = x + unsqueeze_as(sigma, x) * torch.randn_like(x)
        pred = self.nn_module_wrapper(x_t, sigma)

        if self.loss_type == "l2":
            loss = (0.5 * (x - pred) ** 2)
        elif self.loss_type == "l1":
            loss = (x - pred).abs()
        else:
            raise AssertionError(f"Invalid {self.loss_type=}.")

        if self.loss_weighting_type == "ones":
            loss_weights = torch.ones_like(sigma)
        elif self.loss_weighting_type == "snr":
            loss_weights = self.sigma_data ** 2 / sigma ** 2
        elif self.loss_weighting_type == "min_snr":
            loss_weights = torch.clamp(self.sigma_data ** 2 / sigma ** 2, max=5.0)
        elif self.loss_weighting_type == "karras":
            loss_weights = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        else:
            raise AssertionError(f"Invalid {self.loss_weighting_type=}.")

        loss *= unsqueeze_as(loss_weights, loss)
        return loss

    @torch.no_grad()
    def sample(self, bsz, device, num_sampling_timesteps: int):
        """
        Parameters
        ----------
        num_sampling_timesteps: number of steps to take between sigma_max to sigma_min

        Returns
        -------
        samples: (sampling_timesteps + 1, bsz, *self.input_shape)
            index 0 corresponds to x_0
            index t corresponds to x_t
            last index corresponds to random noise

        Notes
        -----
        This is deterministic for now, solving the probability flow ODE.
        Every stored x_t is rescaled by sigma_data / sqrt(sigma_t^2 + sigma_data^2).
        """
        assert num_sampling_timesteps >= 1

        linspace = torch.linspace(1.0, 0.0, num_sampling_timesteps + 1, device=device)
        sigmas = self.test_sigma_schedule.get_sigma_ppf(linspace)

        # Per-example copy of the current noise level, refilled on every step.
        # (The network only ever sees the starting sigma of a step, so no buffer
        # is kept for the end sigma.)
        sigma_start = torch.empty((bsz,), device=device)

        x = torch.randn((bsz, *self.input_shape), device=device) * self.sigma_max
        samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device)
        samples[-1] = x * self.sigma_data / (self.sigma_max ** 2 + self.sigma_data ** 2) ** 0.5

        # Samplers may carry history between steps; start every run from a clean state.
        self.sampler.reset()

        for idx, (scalar_sigma_start, scalar_sigma_end) in enumerate(zip(sigmas[:-1], sigmas[1:])):

            sigma_start.fill_(scalar_sigma_start)

            pred_x_0 = self.nn_module_wrapper(x, sigma_start)
            pred_at_x_t = InstantaneousPrediction(scalar_sigma_start, x, pred_x_0)
            x = self.sampler.step(scalar_sigma_start, scalar_sigma_end, pred_at_x_t)

            normalization_factor = (
                self.sigma_data / (scalar_sigma_end ** 2 + self.sigma_data ** 2) ** 0.5
            )
            # Fill from the back so that index 0 ends up holding the final sample.
            samples[-1 - idx - 1] = x * normalization_factor

        return samples
@dataclass(frozen=True)
class DiffusionModelConfig:
    """
    Immutable configuration for DiffusionModel; every field is validated on construction.

    Invalid values raise AssertionError.
    """

    # Number of diffusion steps; must be positive.
    num_timesteps: int
    # What the network predicts: "pred_x_0", "pred_eps" or "pred_v".
    target_type: str = "pred_eps"
    # Noise schedule: "linear" or "cosine".
    noise_schedule_type: str = "cosine"
    # Elementwise loss: "l1" or "l2".
    loss_type: str = "l2"
    # Sampling variant: "ddim" or "ddpm". This is a string selector, not a number.
    gamma_type: str = "ddim"

    def __post_init__(self):
        assert self.num_timesteps > 0, f"Invalid {self.num_timesteps=}."
        assert self.target_type in ("pred_x_0", "pred_eps", "pred_v"), f"Invalid {self.target_type=}."
        assert self.noise_schedule_type in ("linear", "cosine"), f"Invalid {self.noise_schedule_type=}."
        assert self.loss_type in ("l1", "l2"), f"Invalid {self.loss_type=}."
        assert self.gamma_type in ("ddim", "ddpm"), f"Invalid {self.gamma_type=}."
These tensors are shape (num_timesteps + 1, *self.input_shape) 65 | # For example, 2D: (num_timesteps + 1, 1, 1, 1) 66 | # 1D: (num_timesteps + 1, 1) 67 | alpha_t = unsqueeze_to(alpha_t, len(self.input_shape) + 1) 68 | sigma_t = (1 - alpha_t ** 2).clamp(min=0) ** 0.5 69 | self.register_buffer("alpha_t", alpha_t) 70 | self.register_buffer("sigma_t", sigma_t) 71 | 72 | def loss(self, x: torch.Tensor): 73 | """ 74 | Returns 75 | ------- 76 | loss: (bsz, *input_shape) 77 | """ 78 | bsz, *_ = x.shape 79 | t_sample = torch.randint(1, self.num_timesteps + 1, size=(bsz,), device=x.device) 80 | eps = torch.randn_like(x) 81 | x_t = self.alpha_t[t_sample] * x + self.sigma_t[t_sample] * eps 82 | pred_target = self.nn_module(x_t, t_sample) 83 | 84 | if self.target_type == "pred_x_0": 85 | gt_target = x 86 | elif self.target_type == "pred_eps": 87 | gt_target = eps 88 | elif self.target_type == "pred_v": 89 | gt_target = self.alpha_t[t_sample] * eps - self.sigma_t[t_sample] * x 90 | else: 91 | raise AssertionError(f"Invalid {self.target_type=}.") 92 | 93 | if self.loss_type == "l2": 94 | loss = 0.5 * (gt_target - pred_target) ** 2 95 | elif self.loss_type == "l1": 96 | loss = torch.abs(gt_target - pred_target) 97 | else: 98 | raise AssertionError(f"Invalid {self.loss_type=}.") 99 | 100 | return loss 101 | 102 | @torch.no_grad() 103 | def sample(self, bsz: int, device: str, num_sampling_timesteps: int | None = None): 104 | """ 105 | Parameters 106 | ---------- 107 | num_sampling_timesteps: int. If unspecified, defaults to self.num_timesteps. 
108 | 109 | Returns 110 | ------- 111 | samples: (num_sampling_timesteps + 1, bsz, *self.input_shape) 112 | index 0 corresponds to x_0 113 | index t corresponds to x_t 114 | last index corresponds to random noise 115 | """ 116 | num_sampling_timesteps = num_sampling_timesteps or self.num_timesteps 117 | assert 1 <= num_sampling_timesteps <= self.num_timesteps 118 | 119 | x = torch.randn((bsz, *self.input_shape), device=device) 120 | t_start = torch.empty((bsz,), dtype=torch.int64, device=device) 121 | t_end = torch.empty((bsz,), dtype=torch.int64, device=device) 122 | 123 | subseq = torch.linspace(self.num_timesteps, 0, num_sampling_timesteps + 1).round() 124 | samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device) 125 | samples[-1] = x 126 | 127 | # Note that t_start > t_end we're traversing pairwise down subseq. 128 | # For example, subseq here could be [500, 400, 300, 200, 100, 0] 129 | for idx, (scalar_t_start, scalar_t_end) in enumerate(zip(subseq[:-1], subseq[1:])): 130 | 131 | t_start.fill_(scalar_t_start) 132 | t_end.fill_(scalar_t_end) 133 | noise = torch.zeros_like(x) if scalar_t_end == 0 else torch.randn_like(x) 134 | 135 | if self.gamma_type == "ddim": 136 | gamma_t = 0.0 137 | elif self.gamma_type == "ddpm": 138 | gamma_t = ( 139 | self.sigma_t[t_end] / self.sigma_t[t_start] * 140 | (1 - self.alpha_t[t_start] ** 2 / self.alpha_t[t_end] ** 2) ** 0.5 141 | ) 142 | else: 143 | raise AssertionError(f"Invalid {self.gamma_type=}.") 144 | 145 | nn_out = self.nn_module(x, t_start) 146 | if self.target_type == "pred_x_0": 147 | pred_x_0 = nn_out 148 | pred_eps = (x - self.alpha_t[t_start] * nn_out) / self.sigma_t[t_start] 149 | elif self.target_type == "pred_eps": 150 | pred_x_0 = (x - self.sigma_t[t_start] * nn_out) / self.alpha_t[t_start] 151 | pred_eps = nn_out 152 | elif self.target_type == "pred_v": 153 | pred_x_0 = self.alpha_t[t_start] * x - self.sigma_t[t_start] * nn_out 154 | pred_eps = self.sigma_t[t_start] * x + 
class ConsistencyModel(nn.Module):
    """
    Consistency model trained by comparing the network's outputs at two adjacent
    noise levels of the same noised example, with the number of discrete noise
    levels growing over the course of training.
    """

    def __init__(
        self,
        input_shape: tuple[int, ...],
        nn_module: nn.Module,
        config: ConsistencyModelConfig,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.nn_module = nn_module

        # Input shape must be either (c,) or (c, h, w) or (c, t, h, w)
        assert len(input_shape) in (1, 3, 4)

        # Unpack config and pre-compute a few relevant constants
        self.p_mean = config.p_mean
        self.p_std = config.p_std
        self.s_0 = config.s_0
        self.s_1 = config.s_1
        self.sigma_data = config.sigma_data
        self.sigma_min = config.sigma_min
        self.sigma_max = config.sigma_max
        self.rho = config.rho
        self.loss_type = config.loss_type
        self.train_steps_limit = config.train_steps_limit
        # Sigmas are discretized uniformly in sigma ** (1 / rho) space.
        self.sigma_min_root = (self.sigma_min) ** (1 / self.rho)
        self.sigma_max_root = (self.sigma_max) ** (1 / self.rho)
        # Number of training steps between doublings of the discretization.
        self.k_prime = math.floor(self.train_steps_limit / (math.log2(self.s_1 / self.s_0) + 1))
        # NOTE: stored but not read by any method of this class.
        self.pseudo_huber_c = 0.00054 * math.prod(input_shape) ** 0.5

    def nn_module_wrapper(self, x, sigma, num_discrete_chunks=10000):
        """
        This function does two things:
        1. Implements Karras et al. 2022 pre-conditioning
        2. Converts sigma which have range (sigma_min, sigma_max) into a discrete input

        Parameters
        ----------
        x: (bsz, *self.input_shape)
        sigma: (bsz,)
        """
        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_out = sigma * self.sigma_data / (self.sigma_data ** 2 + sigma ** 2) ** 0.5
        c_skip, c_in, c_out = unsqueeze_as(c_skip, x), unsqueeze_as(c_in, x), unsqueeze_as(c_out, x)
        sigmas_percentile = (
            ((sigma ** (1 / self.rho)) - self.sigma_min_root) / (self.sigma_max_root - self.sigma_min_root)
        )
        sigmas_discrete = torch.floor(num_discrete_chunks * sigmas_percentile)
        # Clamp on both sides: round-off at sigma == sigma_min can push the percentile just
        # below zero, and an index of -1 would silently select the last embedding.
        sigmas_discrete = sigmas_discrete.clamp(min=0, max=num_discrete_chunks - 1).long()
        return c_out * self.nn_module(c_in * x, sigmas_discrete) + c_skip * x

    def loss(self, x, train_step_number: int):
        """
        Parameters
        ----------
        x: (bsz, *input_shape) clean data
        train_step_number: current training step, used to pick the discretization

        Returns
        -------
        loss: (bsz, *input_shape)
        """
        bsz, *_ = x.shape

        # First compute the amount of discretization (number of sigmas needed)
        train_step_number = min(train_step_number, self.train_steps_limit)
        num_sigmas = min(self.s_0 * 2 ** math.floor(train_step_number / self.k_prime), self.s_1) + 1

        # Discretize the sigma space
        linspace = torch.linspace(0, 1, num_sigmas, device=x.device)
        sigmas = (self.sigma_min_root + linspace * (self.sigma_max_root - self.sigma_min_root)) ** self.rho

        # Draw a sample of sigma with importance sampling: each interval is weighted by
        # the difference of erf((log(sigma) - p_mean) / (sqrt(2) * p_std)) at its end points.
        sigmas_weights = (
            torch.erf((torch.log(sigmas[1:]) - self.p_mean) / (2 ** 0.5 * self.p_std)) -
            torch.erf((torch.log(sigmas[:-1]) - self.p_mean) / (2 ** 0.5 * self.p_std))
        )
        sampled_idxs = torch.multinomial(sigmas_weights, num_samples=bsz, replacement=True)
        sigma_t = sigmas[sampled_idxs]
        sigma_t_plus_1 = sigmas[sampled_idxs + 1]

        # Forward the student and teacher, ensuring dropout is identical for both:
        # fork_rng restores the RNG state on exit, so both passes start from the same state.
        eps = torch.randn_like(x)
        with torch.random.fork_rng():
            x_t_plus_1 = x + unsqueeze_as(sigma_t_plus_1, x) * eps
            pred_t_plus_1 = self.nn_module_wrapper(x_t_plus_1, sigma_t_plus_1)
        with torch.no_grad(), torch.random.fork_rng():
            x_t = x + unsqueeze_as(sigma_t, x) * eps
            pred_t = self.nn_module_wrapper(x_t, sigma_t)

        # Compute loss and corresponding weights
        if self.loss_type == "l2":
            loss = (0.5 * (pred_t_plus_1 - pred_t) ** 2)
        elif self.loss_type == "l1":
            loss = (pred_t_plus_1 - pred_t).abs()
        else:
            raise AssertionError(f"Invalid {self.loss_type=}.")

        loss_weights = 1 / (sigma_t_plus_1 - sigma_t)
        loss *= unsqueeze_as(loss_weights, loss)
        return loss

    @torch.no_grad()
    def sample(self, bsz, device, num_sampling_timesteps: int):
        """
        Parameters
        ----------
        num_sampling_timesteps: number of steps to take between sigma_max to sigma_min

        Returns
        -------
        samples: (sampling_timesteps + 1, bsz, *self.input_shape)
            index 0 corresponds to x_0
            index t corresponds to x_t
            last index corresponds to random noise
        """
        assert num_sampling_timesteps >= 1

        linspace = torch.linspace(1, 0, num_sampling_timesteps + 1, device=device)
        sigmas = (self.sigma_min_root + linspace * (self.sigma_max_root - self.sigma_min_root)) ** self.rho

        # Noise levels are real-valued, so these buffers must be floating point;
        # an integer buffer would truncate every sigma written into it.
        sigma_start = torch.empty((bsz,), device=device)
        sigma_end = torch.empty((bsz,), device=device)

        # Start from noise at level sigma_max, matching both the normalization below
        # and the first network call, which is made at sigma_start == sigma_max.
        x = torch.randn((bsz, *self.input_shape), device=device) * self.sigma_max
        samples = torch.empty((num_sampling_timesteps + 1, bsz, *self.input_shape), device=device)
        samples[-1] = x * self.sigma_data / (self.sigma_max ** 2 + self.sigma_data ** 2) ** 0.5

        for idx, (scalar_sigma_start, scalar_sigma_end) in enumerate(zip(sigmas[:-1], sigmas[1:])):

            sigma_start.fill_(scalar_sigma_start)
            sigma_end.fill_(scalar_sigma_end)

            # Map x through the network, then add fresh noise to bring it to level sigma_end.
            x = self.nn_module_wrapper(x, sigma_start)
            eps = torch.randn_like(x)
            x += unsqueeze_as((sigma_end ** 2 - self.sigma_min ** 2).clamp(min=0) ** 0.5, x) * eps

            normalization_factor = self.sigma_data / (scalar_sigma_end ** 2 + self.sigma_data ** 2) ** 0.5
            # Fill from the back so that index 0 ends up holding the final sample.
            samples[-1 - idx - 1] = x * normalization_factor

        return samples
class FourierEmbedding(nn.Module):
    """
    Sinusoidal embedding of a continuous scalar per example.

    A fixed bank of dim / 2 frequencies is stored as the non-trainable buffer `freqs`;
    the embedding concatenates the cosines and then the sines of `x * freqs`.
    """
    def __init__(self, dim, max_period=10000):
        super().__init__()
        # Half of the output channels are cosines and half are sines
        assert dim % 2 == 0
        self.register_buffer("freqs", self.make_freqs(dim, max_period))

    def forward(self, x):
        # Parameters
        #   x: (bsz,) continuous
        # Returns
        #   (bsz, 2 * len(freqs)) tensor laid out as [cos | sin]
        phases = torch.outer(x, self.freqs)
        return torch.cat([phases.cos(), phases.sin()], dim=-1)

    @staticmethod
    def make_freqs(dim, max_period=10000):
        # Geometric progression exp(-log(max_period) * i / dim) over i = 0, 2, 4, ...
        even_idxs = torch.arange(0, dim, 2)
        scaled = even_idxs * -math.log(max_period)
        return (scaled / dim).exp()
class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map, with a residual.

    Only implements the MultiHeadAttention component, not the PositionwiseFFN component.
    Queries, keys, values and the output projection are all 1x1 convolutions.
    """
    def __init__(self, dim, num_heads=8, dropout_prob=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.q_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.k_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.v_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.o_conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # Fold the heads into the batch axis and flatten (h, w) into a single token axis
        to_heads = "b (g c) h w -> (b g) c (h w)"
        queries, keys, values = (
            rearrange(proj(x), to_heads, g=self.num_heads)
            for proj in (self.q_conv, self.k_conv, self.v_conv)
        )

        # Logits are scaled by the square root of the full channel count `dim`
        scale = self.dim ** 0.5
        weights = torch.einsum("b c s, b c t -> b s t", queries, keys) / scale
        weights = self.dropout(torch.softmax(weights, dim=-1))

        # Weighted sum of values, then restore the (b, channels, h, w) layout
        attended = torch.einsum("b s t, b c t -> b c s", weights, values)
        attended = rearrange(attended, "(b g) c (h w) -> b (g c) h w", g=self.num_heads, w=x.shape[-1])
        return x + self.o_conv(attended)
154 | """ 155 | def __init__(self, in_dim, embed_dim, dim_scales): 156 | super().__init__() 157 | 158 | self.init_embed = nn.Conv2d(in_dim, embed_dim, 1) 159 | self.time_embed = PositionalEmbedding(embed_dim) 160 | 161 | self.down_blocks = nn.ModuleList() 162 | self.up_blocks = nn.ModuleList() 163 | 164 | # Example: 165 | # in_dim=1, embed_dim=32, dim_scales=(1, 2, 4, 8) => all_dims=(32, 32, 64, 128, 256) 166 | all_dims = (embed_dim, *[embed_dim * s for s in dim_scales]) 167 | 168 | for idx, (in_c, out_c) in enumerate(zip( 169 | all_dims[:-1], 170 | all_dims[1:], 171 | )): 172 | is_last = idx == len(all_dims) - 2 173 | self.down_blocks.extend(nn.ModuleList([ 174 | BasicBlock(in_c, in_c, embed_dim), 175 | BasicBlock(in_c, in_c, embed_dim), 176 | nn.Conv2d(in_c, out_c, 3, 2, 1) if not is_last else nn.Conv2d(in_c, out_c, 1), 177 | ])) 178 | 179 | for idx, (in_c, out_c, skip_c) in enumerate(zip( 180 | all_dims[::-1][:-1], 181 | all_dims[::-1][1:], 182 | all_dims[:-1][::-1], 183 | )): 184 | is_last = idx == len(all_dims) - 2 185 | self.up_blocks.extend(nn.ModuleList([ 186 | BasicBlock(in_c + skip_c, in_c, embed_dim), 187 | BasicBlock(in_c + skip_c, in_c, embed_dim), 188 | nn.ConvTranspose2d(in_c, out_c, 2, 2) if not is_last else nn.Conv2d(in_c, out_c, 1), 189 | ])) 190 | 191 | self.mid_blocks = nn.ModuleList([ 192 | BasicBlock(all_dims[-1], all_dims[-1], embed_dim), 193 | SelfAttention2d(all_dims[-1]), 194 | BasicBlock(all_dims[-1], all_dims[-1], embed_dim), 195 | ]) 196 | self.out_blocks = nn.ModuleList([ 197 | BasicBlock(embed_dim, embed_dim, embed_dim), 198 | nn.Conv2d(embed_dim, in_dim, 1, bias=True), 199 | ]) 200 | 201 | def forward(self, x, t): 202 | x = self.init_embed(x) 203 | t = self.time_embed(t) 204 | skip_conns = [] 205 | residual = x.clone() 206 | 207 | for block in self.down_blocks: 208 | if isinstance(block, BasicBlock): 209 | x = block(x, t) 210 | skip_conns.append(x) 211 | else: 212 | x = block(x) 213 | for block in self.mid_blocks: 214 | if 
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention [Vaswani et al. NeurIPS 2017].
    Scaled dot-product attention is performed over V, using K as keys and Q as queries.
        MultiHeadAttention(Q, V) = FC(SoftMax(1/√d QKᵀ) V) (concatenated over multiple heads),
    Notes
    -----
    (1) Q, K, V can be of different dimensions. Q and K are projected to dim_a and V to dim_o.
    (2) We assume the last and second last dimensions correspond to the feature (i.e. embedding)
        and token (i.e. words) dimensions respectively.
    (3) Heads are processed by stacking them along the batch dimension.
    """
    def __init__(self, dim_q, dim_k, dim_v, num_heads=8, dropout_prob=0.1, dim_a=None, dim_o=None):
        super().__init__()
        # Both the attention width and the output width default to the query width
        dim_a = dim_q if dim_a is None else dim_a
        dim_o = dim_q if dim_o is None else dim_o
        self.dim_a = dim_a
        self.dim_o = dim_o
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_q, dim_a, bias=True)
        self.fc_k = nn.Linear(dim_k, dim_a, bias=True)
        self.fc_v = nn.Linear(dim_v, dim_o, bias=True)
        self.fc_o = nn.Linear(dim_o, dim_o, bias=True)
        self.dropout = nn.Dropout(dropout_prob)
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.)

    def forward(self, q, k, v, mask=None):
        """
        Perform multi-head attention with given queries and values.
        Parameters
        ----------
        q: (bsz, tsz, dim_q)
        k: (bsz, tsz, dim_k)
        v: (bsz, tsz, dim_v)
        mask: (bsz, tsz) or (bsz, tsz, tsz), where 1 denotes keep and 0 denotes remove
        Returns
        -------
        O: (bsz, tsz, dim_o)
        """
        bsz, tsz, _ = q.shape

        def stack_heads(projected, width):
            # (bsz, tsz, heads * width) -> (heads * bsz, tsz, width)
            return torch.cat(projected.split(width, dim=-1), dim=0)

        queries = stack_heads(self.fc_q(q), self.dim_a // self.num_heads)
        keys = stack_heads(self.fc_k(k), self.dim_a // self.num_heads)
        values = stack_heads(self.fc_v(v), self.dim_o // self.num_heads)

        # Logits are scaled by the square root of the full attention width dim_a
        scores = queries @ keys.transpose(-1, -2) / self.dim_a ** 0.5

        if mask is not None:
            assert mask.ndim in (2, 3)
            if mask.ndim == 2:
                # Per-key mask: broadcast over query positions, then tile over heads
                mask = mask.unsqueeze(-2).repeat(self.num_heads, tsz, 1)
            else:
                # Full (query, key) mask: tile over heads only
                mask = mask.repeat(self.num_heads, 1, 1)
            # Large negative logit so removed positions get zero weight after softmax
            scores.masked_fill_(mask == 0, -65504)

        weights = self.dropout(torch.softmax(scores, dim=-1))
        # Undo the head stacking: (heads * bsz, tsz, width) -> (bsz, tsz, dim_o)
        merged = torch.cat((weights @ values).split(bsz, dim=0), dim=-1)
        return self.fc_o(merged)
class ViT(nn.Module):
    """
    Simple version of Vision Transformer [Dosovitskiy et al. 2020].

    The image is cut into non-overlapping patches of shape `patch_shape`, each patch is
    linearly embedded, processed by `num_layers` time-conditioned encoder blocks, and
    projected back to pixels, so the output has the same shape as the input.

    `time_embed` and `pos_embed` are trainable parameters initialized from
    `PositionalEmbedding.make_embedding(embed_dim)`.
    """
    def __init__(self, in_dim, embed_dim, num_layers, patch_shape):
        super().__init__()

        self.patch_h, self.patch_w = patch_shape
        self.patch_embed = nn.Linear(in_dim * math.prod(patch_shape), embed_dim, bias=True)
        self.time_embed = nn.Parameter(PositionalEmbedding.make_embedding(embed_dim))
        self.pos_embed = nn.Parameter(PositionalEmbedding.make_embedding(embed_dim))

        self.blocks = nn.ModuleList([
            EncoderBlock(embed_dim, 4 * embed_dim) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.out_embed = nn.Linear(embed_dim, in_dim * math.prod(patch_shape), bias=True)

    def forward(self, x, t):
        """
        Parameters
        ----------
        x: (bsz, c, h, w), with h and w divisible by the patch height and width
        t: (bsz,) discrete, used to index rows of `time_embed`

        Returns
        -------
        (bsz, c, h, w)
        """
        shape_info = parse_shape(x, "b c h w")

        # Cut into patches and flatten each patch into one token
        x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (ph pw c)", ph=self.patch_h, pw=self.patch_w)
        # Add one positional row per token. The parameter is named `pos_embed`;
        # the previous `self.pos_encoding` does not exist and raised AttributeError.
        x = self.patch_embed(x) + rearrange(self.pos_embed[:x.shape[1]], "t c -> () t c")
        t = self.time_embed[t]

        for block in self.blocks:
            x = block(x, t)

        x = self.ln(x)
        x = self.out_embed(x)

        # Reassemble the tokens into an image of the original spatial size
        x = rearrange(x, "b (h w) (ph pw c) -> b c (h ph) (w pw)",
            h=shape_info["h"] // self.patch_h,
            w=shape_info["w"] // self.patch_w,
            ph=self.patch_h,
            pw=self.patch_w,
        )
        return x