├── dog.jpg ├── gifs ├── diffusion2.gif ├── diffusion26.gif └── diffusion34.gif ├── models ├── util.py ├── embedding.py ├── attention.py ├── block.py ├── diffusion.py └── unet.py ├── LICENSE ├── .gitignore └── README.md /dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/dog.jpg -------------------------------------------------------------------------------- /gifs/diffusion2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion2.gif -------------------------------------------------------------------------------- /gifs/diffusion26.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion26.gif -------------------------------------------------------------------------------- /gifs/diffusion34.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion34.gif -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | 3 | def exists(x): 4 | return x is not None 5 | 6 | def default(val, d): 7 | if exists(val): 8 | return val 9 | return d() if isfunction(d) else d 10 | 11 | def extract(a, t, x_shape): 12 | batch_size = t.shape[0] 13 | out = a.gather(-1, t.cpu()) 14 | return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) -------------------------------------------------------------------------------- /models/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | # Location encoding 6 | class SinusoidalPositionEmbeddings(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | self.dim = dim 10 | 11 | def forward(self, time): 12 | device = time.device 13 | half_dim = self.dim // 2 14 | embeddings = math.log(10000) / (half_dim - 1) 15 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 16 | embeddings = time[:, None] * embeddings[None, :] 17 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 18 | return embeddings 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Justin Pan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | from torch import nn, einsum 2 | from einops import rearrange 3 | import torch 4 | 5 | # Attention 6 | class Attention(nn.Module): 7 | def __init__(self, dim, heads=4, dim_head=32): 8 | super().__init__() 9 | self.scale = dim_head**-0.5 10 | self.heads = heads 11 | hidden_dim = dim_head * heads 12 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 13 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 14 | 15 | def forward(self, x): 16 | b, c, h, w = x.shape 17 | qkv = self.to_qkv(x).chunk(3, dim=1) 18 | q, k, v = map( 19 | lambda t : rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 20 | ) 21 | q = q * self.scale 22 | 23 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 24 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 25 | attn = sim.softmax(dim=-1) 26 | 27 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 28 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 29 | return self.to_out(out) 30 | 31 | class LinearAttention(nn.Module): 32 | def __init__(self, dim, heads=4, dim_head=32): 33 | super().__init__() 34 | self.scale = dim_head**-0.5 35 | self.heads = heads 36 | hidden_dim = dim_head * heads 37 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 38 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) 39 | 40 | def forward(self, x): 41 | b, c, h, w = x.shape 42 | qkv = self.to_qkv(x).chunk(3, dim=1) 43 | q, k, v = map( 44 | lambda t : rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 45 | ) 46 | 47 | q = q.softmax(dim=-2) 48 | k = k.softmax(dim=-1) 49 | 50 | q = q * self.scale 51 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 52 | 53 | out = einsum("b h d e, b h d n -> b h e n", context, q) 54 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 55 | return self.to_out(out) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /models/block.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from einops import rearrange 3 | 4 | def exists(x): 5 | return x is not None 6 | 7 | class Block(nn.Module): 8 | def __init__(self, dim, dim_out, groups=8): 9 | super().__init__() 10 | self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) 11 | self.norm = nn.GroupNorm(groups, dim_out) 12 | self.act = nn.SiLU() 13 | 14 | def forward(self, x, scale_shift=None): 15 | x = self.proj(x) 16 | x = self.norm(x) 17 | 18 | if exists(scale_shift): 19 | scale, shift = scale_shift 20 | x = x * (scale + 1) + shift 21 | 22 | x = self.act(x) 23 | return x 24 | 25 | class ResnetBlock(nn.Module): 26 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 27 | super().__init__() 28 | self.mlp = ( 29 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) 30 | if exists(time_emb_dim) 31 | else None 32 | ) 33 | 34 | self.block1 = Block(dim, dim_out, groups) 35 | self.block2 = Block(dim_out, dim_out, groups) 36 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 37 | 38 | def forward(self, x, time_emb=None): 39 | h = self.block1(x) 40 | 41 | if exists(self.mlp) and exists(time_emb): 42 | time_emb = self.mlp(time_emb) 43 | h = rearrange(time_emb, "b c -> b c 1 1") + h 44 | h = self.block2(h) 45 | return h + self.res_conv(x) 46 | 47 | class ConvNextBlock(nn.Module): 48 | def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True): 49 | super().__init__() 50 | self.mlp = ( 51 | nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim)) 52 | if exists(time_emb_dim) 53 | else None 54 | ) 55 | 56 | self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim) 57 | 58 | self.net = nn.Sequential( 59 | nn.GroupNorm(1, dim) if norm else nn.Identity(), 60 | nn.Conv2d(dim, dim_out * mult, 3, padding=1), 61 | nn.GELU(), 62 | nn.GroupNorm(1, dim_out * mult), 63 | nn.Conv2d(dim_out * mult, dim_out, 3, padding=1) 64 | ) 65 | 66 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 67 | 68 | def forward(self, x, time_emb=None): 69 | h = self.ds_conv(x) 70 | 71 | if exists(self.mlp) and exists(time_emb): 72 | assert exists(time_emb), "time embedding must be passed in" 73 | condition = self.mlp(time_emb) 74 | h = h + rearrange(condition, "b c -> b c 1 1") 75 | 76 | h = self.net(h) 77 | return h + self.res_conv(x) -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm.auto import tqdm 4 | 5 | from .util import extract 6 | 7 | # Forward Diffusion Process 8 | def cosine_beta_schedule(timesteps, s=0.008): 9 | steps = timesteps + 1 10 | x = torch.linspace(0, timesteps, steps) 11 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 12 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 13 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 14 | return torch.clip(betas, 0.0001, 0.9999) 15 | 16 | def linear_beta_schedule(timesteps): 17 | beta_start = 0.0001 18 | beta_end = 0.02 19 | return torch.linspace(beta_start, beta_end, timesteps) 20 | 21 | def quadratic_beta_schedule(timesteps): 22 | beta_start = 0.0001 23 | beta_end = 0.02 24 | return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2 25 | 26 | def sigmoid_beta_schedule(timesteps): 27 | beta_start = 0.0001 28 | beta_end = 0.02 29 | betas = torch.linspace(-6, 6, timesteps) 30 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 31 | 32 | # Forward Diffusion 33 | def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None): 34 | if noise is None: 35 | noise = torch.randn_like(x_start) 36 | 37 | sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) 38 | sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) 39 | return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 40 | 41 | def p_losses(denoise_model, x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None, loss_type="l1"): 42 | if noise is None: 43 | noise = torch.randn_like(x_start) 44 | 45 | x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise) 46 | predicted_noise = denoise_model(x_noisy, t) 47 | 48 | if loss_type == "l1": 49 | return F.l1_loss(noise, predicted_noise) 50 | if loss_type == "l2": 51 | return F.mse_loss(noise, predicted_noise) 52 | if loss_type == "huber": 53 | return F.smooth_l1_loss(noise, predicted_noise) 54 | else: 55 | raise NotImplementedError() 56 | 57 | # Reverse Process 58 | @torch.no_grad() 59 | def p_sample(model, x, t, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, t_index): 60 | betas_t = extract(betas, t, x.shape) 61 | sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape) 62 | sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape) 63 | 64 | # use noise predictor to predict the mean 65 | model_mean = sqrt_recip_alphas_t * ( 66 | x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t 67 | ) 68 | 69 | if t_index == 0: 70 | return model_mean 71 | else: 72 | posterior_variance_t = extract(posterior_variance, t, x.shape) 73 | noise = torch.randn_like(x) 74 | return model_mean + torch.sqrt(posterior_variance_t) * noise #Algorithm 2 line 4 75 | 76 | # Algorithm 2 but save all images 77 | @torch.no_grad() 78 | def p_sample_loop(model, shape, timesteps, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance): 79 | device = next(model.parameters()).device 80 | 81 | b = shape[0] 82 | # start from pure noise 83 | img = torch.randn(shape, device=device) 84 | imgs = [] 85 | 86 | for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps): 87 | img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, i) 88 | imgs.append(img.cpu().numpy()) 89 | return imgs 90 | 91 | @torch.no_grad() 92 | def sample(model, timesteps, image_size, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, batch_size=16, channels=3): 93 | return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size), timesteps=timesteps, betas=betas, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas=sqrt_recip_alphas, posterior_variance=posterior_variance) -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from functools import partial 4 | 5 | from .embedding import SinusoidalPositionEmbeddings 6 | from .block import ConvNextBlock, ResnetBlock 7 | from .attention import Attention, LinearAttention 8 | from .util import exists, default 9 | 10 | class Residual(nn.Module): 11 | def __init__(self, fn): 12 | super().__init__() 13 | self.fn = fn 14 | 15 | def forward(self, x, *args, **kwargs): 16 | return self.fn(x, *args, **kwargs) + x 17 | 18 | def up_sample(dim): 19 | return nn.ConvTranspose2d(dim, dim, 4, 2, 1) 20 | 21 | def down_sample(dim): 22 | return nn.Conv2d(dim, dim, 4, 2, 1) 23 | 24 | # Group normalization 25 | class PreNorm(nn.Module): 26 | def __init__(self, dim, fn): 27 | super().__init__() 28 | self.fn = fn 29 | self.norm = nn.GroupNorm(1, dim) 30 | 31 | def forward(self, x): 32 | x = self.norm(x) 33 | return self.fn(x) 34 | 35 | # U-Net 36 | class Unet(nn.Module): 37 | def __init__( 38 | self, 39 | dim, 40 | init_dim=None, 41 | out_dim=None, 42 | dim_mults=(1, 2, 4, 8), 43 | channels=3, 44 | with_time_emb=True, 45 | resnet_block_groups=8, 46 | use_convnext=True, 47 | convnext_mult=2 48 | ): 49 | super().__init__() 50 | 51 | # determine dimensions 52 | self.channels = channels 53 | 54 | init_dim = default(init_dim, dim // 3 * 2) 55 | self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) 56 | 57 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 58 | in_out = list(zip(dims[:-1], dims[1:])) 59 | 60 | if use_convnext: 61 | block_klass = partial(ConvNextBlock, mult=convnext_mult) 62 | else: 63 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 64 | 65 | # time embeddings 66 | if with_time_emb: 67 | time_dim = dim * 4 68 | self.time_mlp = nn.Sequential( 69 | SinusoidalPositionEmbeddings(dim), 70 | nn.Linear(dim, time_dim), 71 | nn.GELU(), 72 | nn.Linear(time_dim, time_dim) 73 | ) 74 | else: 75 | time_dim = None 76 | self.time_mlp = None 77 | 78 | # layers 79 | self.downs = nn.ModuleList([]) 80 | self.ups = nn.ModuleList([]) 81 | num_resolutions = len(in_out) 82 | 83 | for ind, (dim_in, dim_out) in enumerate(in_out): 84 | is_last = ind >= (num_resolutions - 1) 85 | 86 | self.downs.append( 87 | nn.ModuleList( 88 | [ 89 | block_klass(dim_in, dim_out, time_emb_dim=time_dim), 90 | block_klass(dim_out, dim_out, time_emb_dim=time_dim), 91 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 92 | down_sample(dim_out) if not is_last else nn.Identity(), 93 | ] 94 | ) 95 | ) 96 | 97 | mid_dim = dims[-1] 98 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 99 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 100 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 101 | 102 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 103 | is_last = ind >= (num_resolutions - 1) 104 | 105 | self.ups.append( 106 | nn.ModuleList( 107 | [ 108 | block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), 109 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 110 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 111 | up_sample(dim_in) if not is_last else nn.Identity(), 112 | ] 113 | ) 114 | ) 115 | 116 | out_dim = default(out_dim, channels) 117 | self.final_conv = nn.Sequential( 118 | block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1) 119 | ) 120 | 121 | def forward(self, x, time): 122 | x = self.init_conv(x) 123 | 124 | t = self.time_mlp(time) if exists(self.time_mlp) else None 125 | 126 | h = [] 127 | 128 | # down sampling 129 | for block1, block2, attn, downsample in self.downs: 130 | x = block1(x, t) 131 | x = block2(x, t) 132 | x = attn(x) 133 | h.append(x) 134 | x = downsample(x) 135 | 136 | # bottleneck 137 | x = self.mid_block1(x, t) 138 | x = self.mid_attn(x) 139 | x = self.mid_block2(x, t) 140 | 141 | # up sampling 142 | for block1, block2, attn, upsample in self.ups: 143 | x = torch.cat((x, h.pop()), dim=1) 144 | x = block1(x, t) 145 | x = block2(x, t) 146 | x = attn(x) 147 | x = upsample(x) 148 | 149 | return self.final_conv(x) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Structured Denoising Diffusion Models in Discrete State-Spaces 2 | 3 | ## Getting Started 4 | 5 | See [annotated_diffusion.ipynb](./annotated_diffusion.ipynb) for a walkthrough about the forward diffusion process. 6 | 7 | See [annotated_traning.ipynb](./annotated_training.ipynb) for a detailed guide to train a model and generate samples from the model. 8 | 9 | Training with `fashion_mnist` with 20 epochs, here are a few example GIFs of the image generation process: 10 | 11 |

12 | 13 | 14 | 15 |

16 | 17 | Read Papers to learn more: 18 | > [Structured Denoising Diffusion Models in Discrete State-Spaces 19 | ](https://arxiv.org/abs/2107.03006) 20 | 21 | > [Understanding Diffusion Models: A Unified Perspective 22 | ](https://arxiv.org/abs/2208.11970) 23 | 24 | ## Variational Autoencoder (VAE) 25 | 26 | ### Parameters 27 | 28 | - Joint distribution: $p(x,z)$ 29 | 30 | - Posterior: $q_\phi(z|x)$ 31 | 32 | ### Evidence Lower Bound 33 | 34 | $$ 35 | \begin{align*} 36 | \log p(x) 37 | &\geq \mathbb{E}_{q_\phi(z|x)}[\log \frac{p(x,z)}{q_\phi(z|x)}] \\ 38 | &= \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) \ || \ p(z)) 39 | \end{align*} 40 | $$ 41 | 42 | - $\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$ measures the reconstruction likelihood of the decoder from the variational distribution. (Monte Carlo estimate) 43 | 44 | - $D_{KL}(q_\phi(z|x) \ || \ p(z))$ measures how similar the learned variational distribution is to a prior belief held over latent variables. (Analytical calculation) 45 | 46 | ### Objective 47 | 48 | $$ 49 | \begin{align*} 50 | \arg\max_{\phi,\theta} \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) \ || \ p(z)) \\ 51 | \approx \arg\max_{\phi,\theta} \sum^L_{l=1}[\log p_\theta(x|z^{(l)})] - D_{KL}(q_\phi(z|x) \ || \ p(z)) 52 | \end{align*} 53 | $$ 54 | 55 | where latents $\{ z^{(l)} \}^L_{l=1}$ are sampled from $q_\phi(z|x)$. 56 | 57 | ## Hierarchical Variational Autoencoder (HVAE) 58 | 59 | Stacking VAEs on top of each other. 60 | 61 | ### Parameters 62 | 63 | - Joint distribution: $p(x,z_{1:T}) = p(z_T)p_\theta(x|z_1)\prod^T_{t=2}p_\theta(z_{t-1}|z_t)$ 64 | 65 | - Posterior: $q_\phi(z|x) = q_\theta(z_1|x)\prod^T_{t=2}q_\theta(z_t|z_{t-1})$ 66 | 67 | ### Evidence Lower Bound 68 | 69 | $$ 70 | \begin{align*} 71 | \log p(x) 72 | &\geq \mathbb{E}_{q_\phi(z_{1:T}|x)}[\log \frac{p(x,z_{1:T})}{q_\phi(z_{1:T}|x)}] \\ 73 | &= \mathbb{E}_{q_\phi(z_{1:T}|x)}[\log \frac{p(z_T)p_\theta(x|z_1)\prod^T_{t=2}p_\theta(z_{t-1}|z_t)}{q_\theta(z_1|x)\prod^T_{t=2}q_\theta(z_t|z_{t-1})}] 74 | \end{align*} 75 | $$ 76 | 77 | ### Objective 78 | 79 | Similar to VAE. 80 | 81 | ## Variational Diffusion Models (VDM) 82 | 83 | HVAE but with three key restrictions: 84 | 85 | - The latent dimension is exactly equal to the data dimension 86 | $\implies q_\phi(z_{1:T}|x)= q(z_{1:T}|x_0) = \prod^T_{t=1}q(x_t|x_{t-1})$ 87 | 88 | - The structure of the latent encoder at each timestep is not learned; it is pre-defined as a linear Gaussian model 89 | $\implies$ The latent encoder is a Gaussian distribution centered around the output of the previous timestep $\implies q(x_t|x_{t-1}) = \mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)\pmb{I})$ 90 | 91 | - The Gaussian parameters of the latent encoders vary over time in such a way that the distribution of the latent at final timestep T is a standard Gaussian $\implies p(x_T)=\mathcal{N}(x_T;0,\pmb{I})$, which is pure noise 92 | 93 | ### Parameters 94 | 95 | - Joint distribution: $p(x_{0:T})$ 96 | 97 | - Posterior: $q(z_{1:T}|x_0) = \prod^T_{t=1}q(x_t|x_{t-1})$ 98 | 99 | ### Evidence Lower Bound 100 | 101 | $$ 102 | \begin{align*} 103 | \log p(x) 104 | \geq \mathbb{E}_{q(x_1|x_0)}[\log p_\theta(x_0|x_1)] - D_{KL}(q(x_T|x_0) \ || \ p(x_T)) \\ 105 | - \sum^T_{t=2} \mathbb{E}_{q(x_t|x_0)}[D_{KL}(q(x_{t-1}|x_t,x_0) \ || \ p_\theta(x_{t-1}|x_t))] 106 | \end{align*} 107 | $$ 108 | 109 | - The first term measures the reconstruction likelihood of the decoder from the variational distribution. (Monte Carlo estimate) 110 | 111 | - The second term measures how close the distribution of the final nosisified input is to the standard Gaussian prior. 112 | > Note that it has no trainable parameters, and is also equal to zero under the assumptions. 113 | 114 | - The third term is for *denoising matching*. We learn desired denoising transition step $p_\theta(x_{t-1}|x_t)$ as an approximation to tracable, ground-truth denoising transition step $q(x_{t-1}|x_t, x_0)$. 115 | > Note that when $T=1$, VDM's ELBO falls back into VAE's. 116 | 117 | > Note that the *denoising matching term* dominates the overall optimization cost because of the summation term. 118 | 119 | ### Objective 120 | For learning a neural network to predict the original ground truth image from an arbitrarily noisified version of it, minimize the summation term of the derived ELBO objective across all noise levels, which can be approximated by minimizing the expectation over all timesteps: 121 | 122 | $$ 123 | \arg\min_\theta \mathbb{E}_{t \sim U{2,T}}[\mathbb{E}_{q(x_t|x_0)}D_{KL}(q(x_{t-1}|x_t,x_0) \ || \ p_\theta(x_{t-1}|x_t))] 124 | $$ 125 | 126 | which can be optimized using stochastic samples over timesteps. 127 | 128 | For generating a novel $x_0$, sample Gaussian noise from $p(x_T) and iteratively running the denoising transitions $p_\theta(x_{t-1} | x_t)$ for T steps. 129 | --------------------------------------------------------------------------------