├── dog.jpg
├── gifs
│   ├── diffusion2.gif
│   ├── diffusion26.gif
│   └── diffusion34.gif
├── models
│   ├── util.py
│   ├── embedding.py
│   ├── attention.py
│   ├── block.py
│   ├── diffusion.py
│   └── unet.py
├── LICENSE
├── .gitignore
└── README.md

/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/dog.jpg
--------------------------------------------------------------------------------

/gifs/diffusion2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion2.gif
--------------------------------------------------------------------------------

/gifs/diffusion26.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion26.gif
--------------------------------------------------------------------------------

/gifs/diffusion34.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpan0/denoising-diffusion/HEAD/gifs/diffusion34.gif
--------------------------------------------------------------------------------

/models/util.py:
--------------------------------------------------------------------------------
from inspect import isfunction

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def extract(a, t, x_shape):
    # Gather the schedule values a[t] for a batch of timesteps t, then
    # reshape to (batch, 1, 1, ...) so they broadcast over images of x_shape.
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
--------------------------------------------------------------------------------

/models/embedding.py:
--------------------------------------------------------------------------------
import math
import torch
from torch import nn

# Sinusoidal position embeddings for the diffusion timestep
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
--------------------------------------------------------------------------------
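
A quick shape check for the timestep embedding (an illustrative sketch, not a file in this repository; the import path assumes it is run from the repository root):

    import torch
    from models.embedding import SinusoidalPositionEmbeddings

    emb = SinusoidalPositionEmbeddings(dim=32)
    t = torch.tensor([0, 10, 299])  # a batch of three diffusion timesteps
    print(emb(t).shape)             # torch.Size([3, 32])
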
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Justin Pan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/models/attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
from einops import rearrange

# Standard (quadratic) multi-head self-attention over the spatial positions
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # subtract the per-row max for numerical stability before the softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

# Linear attention: cost grows linearly in the number of pixels instead of quadratically
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = einsum("b h d n, b h e n -> b h d e", k, v)

        out = einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
--------------------------------------------------------------------------------
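
Both modules map a feature map to a feature map of the same shape; a quick check (illustrative sketch, not a file in this repository):

    import torch
    from models.attention import Attention, LinearAttention

    x = torch.randn(2, 64, 16, 16)           # (batch, channels, height, width)
    print(Attention(dim=64)(x).shape)        # torch.Size([2, 64, 16, 16])
    print(LinearAttention(dim=64)(x).shape)  # torch.Size([2, 64, 16, 16])
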
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
*__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------

/models/block.py:
--------------------------------------------------------------------------------
from torch import nn
from einops import rearrange

from .util import exists

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            # add the projected time embedding as a per-channel bias
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)

class ConvNextBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # depthwise 7x7 convolution, as in ConvNeXt
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
--------------------------------------------------------------------------------
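
Both block types consume a feature map plus an optional time embedding and change only the channel count; a quick check (illustrative sketch, not a file in this repository; the `time_emb_dim=128` value is an arbitrary choice for the example):

    import torch
    from models.block import ConvNextBlock, ResnetBlock

    x = torch.randn(2, 32, 28, 28)
    t_emb = torch.randn(2, 128)  # a (batch, time_emb_dim) time embedding

    print(ConvNextBlock(32, 64, time_emb_dim=128)(x, t_emb).shape)  # torch.Size([2, 64, 28, 28])
    print(ResnetBlock(32, 64, time_emb_dim=128)(x, t_emb).shape)    # torch.Size([2, 64, 28, 28])
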
/models/diffusion.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from .util import extract

# Variance (beta) schedules for the forward diffusion process
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

# Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form
def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def p_losses(denoise_model, x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == "l1":
        return F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        return F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        return F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

# Reverse process: one denoising step sampling from p(x_{t-1} | x_t)
@torch.no_grad()
def p_sample(model, x, t, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # use the noise predictor to estimate the mean of p(x_{t-1} | x_t)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise  # Algorithm 2, line 4

# Algorithm 2, but save all intermediate images
@torch.no_grad()
def p_sample_loop(model, shape, timesteps, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, timesteps, image_size, betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas, posterior_variance, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size), timesteps=timesteps, betas=betas, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas=sqrt_recip_alphas, posterior_variance=posterior_variance)
--------------------------------------------------------------------------------
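
The functions above take the schedule-derived tensors as arguments rather than computing them internally. A sketch of how those buffers can be built from a schedule, following the standard DDPM definitions (illustrative only; `timesteps=300` and the image sizes are arbitrary example values):

    import torch
    import torch.nn.functional as F
    from models.diffusion import linear_beta_schedule, q_sample

    timesteps = 300
    betas = linear_beta_schedule(timesteps)

    # quantities derived from the betas; names match the signatures above
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # noise a stand-in image batch at random timesteps
    x_start = torch.randn(4, 3, 64, 64)
    t = torch.randint(0, timesteps, (4,)).long()
    x_noisy = q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
    print(x_noisy.shape)  # torch.Size([4, 3, 64, 64])
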
/models/unet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from functools import partial

from .embedding import SinusoidalPositionEmbeddings
from .block import ConvNextBlock, ResnetBlock
from .attention import Attention, LinearAttention
from .util import exists, default

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def up_sample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def down_sample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

# Apply group normalization before the wrapped function
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

# U-Net
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        down_sample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        up_sample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsampling
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsampling with skip connections
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
--------------------------------------------------------------------------------
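
The network predicts noise of the same shape as its input; a quick check (illustrative sketch, not a file in this repository; `dim=28` and `dim_mults=(1, 2, 4)` are example values chosen for 28x28 grayscale images):

    import torch
    from models.unet import Unet

    model = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
    x = torch.randn(8, 1, 28, 28)
    t = torch.randint(0, 300, (8,)).long()
    print(model(x, t).shape)  # torch.Size([8, 1, 28, 28])
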
/README.md:
--------------------------------------------------------------------------------
# Structured Denoising Diffusion Models in Discrete State-Spaces

## Getting Started

See [annotated_diffusion.ipynb](./annotated_diffusion.ipynb) for a walkthrough of the forward diffusion process.

See [annotated_training.ipynb](./annotated_training.ipynb) for a detailed guide to training a model and generating samples from it.

After training on `fashion_mnist` for 20 epochs, here are a few example GIFs of the image generation process:

![denoising diffusion sample 1](gifs/diffusion2.gif)
![denoising diffusion sample 2](gifs/diffusion26.gif)
![denoising diffusion sample 3](gifs/diffusion34.gif)
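
For orientation, here is a minimal training-step sketch wiring the modules together (illustrative only; the model size, schedule length, optimizer settings, and the random stand-in batch are assumptions, not values taken from the notebooks):

```python
import torch
from models.unet import Unet
from models.diffusion import linear_beta_schedule, p_losses

timesteps = 300
betas = linear_beta_schedule(timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

model = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch = torch.randn(8, 1, 28, 28)  # stand-in for a Fashion-MNIST batch scaled to [-1, 1]
t = torch.randint(0, timesteps, (8,)).long()

optimizer.zero_grad()
loss = p_losses(model, batch, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, loss_type="huber")
loss.backward()
optimizer.step()
```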