├── LICENSE
├── README.md
├── gif_mnist_01.gif
├── guided_mnist.png
├── pretrained_model.zip
└── script.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Tim Pearce

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Conditional Diffusion MNIST

[script.py](script.py) is a minimal, self-contained implementation of a conditional diffusion model. It learns to generate MNIST digits conditioned on a class label. The neural network architecture is a small U-Net (pretrained weights are also available in this repo). This code is modified from [this excellent repo](https://github.com/cloneofsimo/minDiffusion), which does unconditional generation. The diffusion model is a [Denoising Diffusion Probabilistic Model (DDPM)](https://arxiv.org/abs/2006.11239).

![Samples generated from the model](gif_mnist_01.gif)

*Samples generated from the model.*

The conditioning roughly follows the method described in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598) (also used in [Imagen](https://arxiv.org/abs/2205.11487)). The model infuses timestep embeddings $t_e$ and context embeddings $c_e$ into the U-Net activations at a certain layer $a_L$ via

$$a_{L+1} = c_e \, a_L + t_e.$$
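
Concretely, this is how the injection appears in `ContextUnet.forward` of [script.py](script.py) (lightly trimmed):

```python
# from ContextUnet.forward in script.py: the embeddings are reshaped to
# (batch, channels, 1, 1) so they broadcast over the feature map's spatial dims
cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)  # context embedding c_e
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)     # timestep embedding t_e
up2 = self.up1(cemb1 * up1 + temb1, down2)  # a_{L+1} = c_e * a_L + t_e, fed to the next up-block
```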

(In our experimentation, we found variants of this also work, e.g. concatenating the embeddings together.)

At training time, $c_e$ is randomly set to zero with probability $0.1$, so the model learns both unconditional generation (say $\psi(z_t)$ for noise $z_t$ at timestep $t$) and conditional generation (say $\psi(z_t, c)$ for context $c$). This matters because at generation time we choose a weight $w \geq 0$ and guide the model with

$$\hat{\epsilon}_{t} = (1+w)\,\psi(z_t, c) - w\,\psi(z_t).$$

Increasing $w$ produces images that are more typical but less diverse.
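
In `DDPM.sample` of [script.py](script.py), this is done with one forward pass over a doubled batch: the second half has its context masked out, and the two noise estimates are mixed with the guidance weight (trimmed excerpt):

```python
# trimmed from DDPM.sample in script.py: one forward pass over the doubled batch,
# then mix the conditional (first half) and unconditional (second half) predictions
eps = self.nn_model(x_i, c_i, t_is, context_mask)
eps1 = eps[:n_sample]                        # psi(z_t, c): context kept
eps2 = eps[n_sample:]                        # psi(z_t): context masked out
eps = (1 + guide_w) * eps1 - guide_w * eps2  # guided noise estimate
```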

![Samples produced with varying guidance strength](guided_mnist.png)

*Samples produced with varying guidance strength, $w$.*

Training the above models took around 20 epochs (~20 minutes).

`pretrained_model.zip` contains pretrained weights.

--------------------------------------------------------------------------------

/gif_mnist_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeaPearce/Conditional_Diffusion_MNIST/657b49854bbf844caa5580177243ae607c79339f/gif_mnist_01.gif
--------------------------------------------------------------------------------

/guided_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeaPearce/Conditional_Diffusion_MNIST/657b49854bbf844caa5580177243ae607c79339f/guided_mnist.png
--------------------------------------------------------------------------------

/pretrained_model.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeaPearce/Conditional_Diffusion_MNIST/657b49854bbf844caa5580177243ae607c79339f/pretrained_model.zip
--------------------------------------------------------------------------------

/script.py:
--------------------------------------------------------------------------------
'''
This script does conditional image generation on MNIST, using a diffusion model

This code is modified from,
https://github.com/cloneofsimo/minDiffusion

Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239

The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598

This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487

'''

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on the correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414  # divide by sqrt(2) to keep the activation scale roughly constant
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_classes=10):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1 * n_feat)

        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),  # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        # x is (noisy) image, c is context label, t is timestep,
        # context_mask says which samples to block the context on

        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # convert context to one hot embedding
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # mask out context where context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1, self.n_classes)
        context_mask = -1 * (1 - context_mask)  # dropped samples -> 0, kept samples -> -1 (the sign is absorbed by the embedding layer)
        c = c * context_mask

        # embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # could concatenate the context embedding here instead of adaGN
        # hiddenvec = torch.cat((hiddenvec, temb1, cemb1), 1)

        up1 = self.up0(hiddenvec)
        # up2 = self.up1(up1, down2)  # if want to avoid add and multiply embeddings
        up2 = self.up1(cemb1 * up1 + temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


def ddpm_schedules(beta1, beta2, T):
    """
    Returns pre-computed schedules for DDPM sampling, training process.
    """
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """
        this method is used in training, so samples t and noise randomly
        """

        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)  # t ~ Uniform({1, ..., n_T})
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
        # We should predict the "error term" from this x_t. Loss is what we return.

        # dropout context with some probability
        context_mask = torch.bernoulli(torch.zeros_like(c) + self.drop_prob).to(self.device)

        # return MSE between added noise, and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, n_sample, size, device, guide_w=0.0):
        # we follow the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'
        # to make the fwd passes efficient, we concat two versions of the dataset,
        # one with context_mask=0 and the other context_mask=1
        # we then mix the outputs with the guidance scale, w
        # where w>0 means more guidance

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        c_i = torch.arange(0, 10).to(device)  # context for us just cycles through the mnist labels
        c_i = c_i.repeat(int(n_sample / c_i.shape[0]))

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.  # makes second half of batch context free

        x_i_store = []  # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}', end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)

            # double batch
            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1 + guide_w) * eps1 - guide_w * eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )
            if i % 20 == 0 or i == self.n_T or i < 8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store


def train_mnist():

    # hardcoding these here
    n_epoch = 20
    batch_size = 256
    n_T = 400  # 500
    device = "cuda:0"
    n_classes = 10
    n_feat = 128  # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = False
    save_dir = './data/diffusion_outputs10/'
    ws_test = [0.0, 0.5, 2.0]  # strength of generative guidance

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    # optionally load a model
    # ddpm.load_state_dict(torch.load("./data/diffusion_outputs/ddpm_unet01_mnist_9.pth"))

    tf = transforms.Compose([transforms.ToTensor()])  # mnist is already normalised 0 to 1

    dataset = MNIST("./data", train=True, download=True, transform=tf)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            c = c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        ddpm.eval()
        with torch.no_grad():
            n_sample = 4 * n_classes
            for w_i, w in enumerate(ws_test):
                x_gen, x_gen_store = ddpm.sample(n_sample, (1, 28, 28), device, guide_w=w)

                # append some real images at bottom, order by class also
                x_real = torch.Tensor(x_gen.shape).to(device)
                for k in range(n_classes):
                    for j in range(int(n_sample / n_classes)):
                        try:
                            idx = torch.squeeze((c == k).nonzero())[j]
                        except:
                            idx = 0
                        x_real[k + (j * n_classes)] = x[idx]

                x_all = torch.cat([x_gen, x_real])
                grid = make_grid(x_all * -1 + 1, nrow=10)
                save_image(grid, save_dir + f"image_ep{ep}_w{w}.png")
                print('saved image at ' + save_dir + f"image_ep{ep}_w{w}.png")

                if ep % 5 == 0 or ep == int(n_epoch - 1):
                    # create gif of images evolving over time, based on x_gen_store
                    fig, axs = plt.subplots(nrows=int(n_sample / n_classes), ncols=n_classes, sharex=True, sharey=True, figsize=(8, 3))
                    def animate_diff(i, x_gen_store):
                        print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
                        plots = []
                        for row in range(int(n_sample / n_classes)):
                            for col in range(n_classes):
                                axs[row, col].clear()
                                axs[row, col].set_xticks([])
                                axs[row, col].set_yticks([])
                                # plots.append(axs[row, col].imshow(x_gen_store[i,(row*n_classes)+col,0],cmap='gray'))
                                plots.append(axs[row, col].imshow(-x_gen_store[i, (row * n_classes) + col, 0], cmap='gray', vmin=(-x_gen_store[i]).min(), vmax=(-x_gen_store[i]).max()))
                        return plots
                    ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store], interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
                    ani.save(save_dir + f"gif_ep{ep}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
                    print('saved gif at ' + save_dir + f"gif_ep{ep}_w{w}.gif")
        # optionally save model
        if save_model and ep == int(n_epoch - 1):
            torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")


if __name__ == "__main__":
    train_mnist()

--------------------------------------------------------------------------------
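
For convenience, below is a minimal sketch of sampling from a trained model with the classes defined in `script.py` above. It assumes `pretrained_model.zip` extracts to a state_dict checkpoint compatible with the architecture above; the checkpoint file name `model_19.pth` and `n_feat=128` are assumptions and may need adjusting to match the extracted file.

```python
# Minimal sampling sketch (not one of the repo files above). Assumes pretrained_model.zip
# extracts to a state_dict checkpoint; "model_19.pth" and n_feat=128 are assumptions.
import torch
from script import DDPM, ContextUnet

device = "cuda:0" if torch.cuda.is_available() else "cpu"
ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=128, n_classes=10),
            betas=(1e-4, 0.02), n_T=400, device=device, drop_prob=0.1)
ddpm.load_state_dict(torch.load("model_19.pth", map_location=device))
ddpm.to(device)
ddpm.eval()

with torch.no_grad():
    # n_sample should be a multiple of 10, since sample() cycles through the digit labels
    x_gen, x_gen_store = ddpm.sample(40, (1, 28, 28), device, guide_w=2.0)
```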