├── Hybrids
│   ├── birdcar.png
│   ├── birdplane.png
│   ├── carhoirse90.png
│   ├── carhorse.png
│   ├── catship.png
│   ├── dogcar.png
│   ├── dogplane.png
│   ├── everythingbutship.png
│   ├── froggycar190.png
│   ├── froghorise.png
│   ├── froghorsecar.png
│   ├── froghourse2.png
│   ├── frogmobile.png
│   ├── frogplane.png
│   ├── frogship.png
│   ├── horsetruck.png
│   ├── output_image.png
│   ├── planehoirse3.png
│   ├── planehorse.png
│   ├── shipplane.png
│   └── truckship.png
├── README.md
├── cifar_ddpm.py
└── image_generator.py

/Hybrids/birdcar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/birdcar.png
--------------------------------------------------------------------------------
/Hybrids/birdplane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/birdplane.png
--------------------------------------------------------------------------------
/Hybrids/carhoirse90.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/carhoirse90.png
--------------------------------------------------------------------------------
/Hybrids/carhorse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/carhorse.png
--------------------------------------------------------------------------------
/Hybrids/catship.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/catship.png
--------------------------------------------------------------------------------
/Hybrids/dogcar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/dogcar.png
--------------------------------------------------------------------------------
/Hybrids/dogplane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/dogplane.png
--------------------------------------------------------------------------------
/Hybrids/everythingbutship.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/everythingbutship.png
--------------------------------------------------------------------------------
/Hybrids/froggycar190.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/froggycar190.png
--------------------------------------------------------------------------------
/Hybrids/froghorise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/froghorise.png
--------------------------------------------------------------------------------
/Hybrids/froghorsecar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/froghorsecar.png
--------------------------------------------------------------------------------
/Hybrids/froghourse2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/froghourse2.png
--------------------------------------------------------------------------------
/Hybrids/frogmobile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/frogmobile.png
--------------------------------------------------------------------------------
/Hybrids/frogplane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/frogplane.png
--------------------------------------------------------------------------------
/Hybrids/frogship.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/frogship.png
--------------------------------------------------------------------------------
/Hybrids/horsetruck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/horsetruck.png
--------------------------------------------------------------------------------
/Hybrids/output_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/output_image.png
--------------------------------------------------------------------------------
/Hybrids/planehoirse3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/planehoirse3.png
--------------------------------------------------------------------------------
/Hybrids/planehorse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/planehorse.png
--------------------------------------------------------------------------------
/Hybrids/shipplane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/shipplane.png
--------------------------------------------------------------------------------
/Hybrids/truckship.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louisc-s/Conditional-Diffusion-Model/8bd69cfcf1b11f86e45dbb84fb3532e4789ed7a5/Hybrids/truckship.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Conditional-Diffusion-Model
Code for implementing a conditional DDPM trained on CIFAR-10

## Overview

This code extends a basic denoising diffusion probabilistic model (DDPM) into a conditional DDPM trained on the CIFAR-10 dataset, which is able to generate synthetic images from a combination of different input category labels (e.g. a frog-plane hybrid).

## Project Structure

1. cifar_ddpm.py - implements and trains the conditional DDPM

2. image_generator.py - generates images from the trained model according to the input category labels

## Author

Louis Chapo-Saunders
--------------------------------------------------------------------------------
/cifar_ddpm.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
import os
import wandb

device = "cpu"

# define ResNet-style convolutional block for the UNet
class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # add the correct residual in case the channel count has increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


# process and downscale the image feature maps
class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# process and upscale the image feature maps
class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x

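
# Illustrative sketch (hypothetical helper, not called anywhere in this repo):
# a quick shape check of the blocks above. UnetDown halves the spatial
# resolution with MaxPool2d, while UnetUp first concatenates the skip
# connection along the channel axis, which is why ContextUnet below builds
# its up1/up2 blocks with doubled input channels.
def _demo_unet_shapes():
    down = UnetDown(16, 32)
    up = UnetUp(64, 16)                    # 64 = 32 (from below) + 32 (skip)
    x = torch.randn(1, 16, 32, 32)
    d = down(x)                            # -> (1, 32, 16, 16)
    u = up(torch.randn(1, 32, 16, 16), d)  # concat -> (1, 64, 16, 16), output -> (1, 16, 32, 32)
    print(d.shape, u.shape)
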
# define fully-connected embedding layer for the timestep and context inputs
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


# implement the context UNet
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_classes=10):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(8), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 8, 8), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    # implement multi-hot encoding function to allow multi-category images
    def one_hot(self, c, num_classes):
        c_return = torch.zeros(len(c), num_classes)
        for i, value in enumerate(c):
            c_return[i, value] = 1.0
        return c_return

    def forward(self, x, c, t, context_mask):
        # x is (noisy) image, c is context label, t is timestep,
        # context_mask says which samples to block the context on

        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # convert context labels to a one-hot (or multi-hot) encoding
        c = self.one_hot(c, self.n_classes)

        # mask out context where context_mask == 1:
        # kept context is scaled by -1, dropped context by 0
        # (the sign is absorbed by the learned context embedding)
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1, self.n_classes)
        context_mask = (-1*(1-context_mask))
        c = c.to(x.device) * context_mask

        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out

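
# Illustrative sketch (hypothetical helper, not part of the training or sampling
# pipeline): how ContextUnet.one_hot turns class indices into the multi-hot
# conditioning vector consumed by the context embeddings. Passing several
# indices per sample sets several entries to 1, which is what lets
# image_generator.py blend CIFAR-10 categories into hybrid images.
def _demo_multi_hot():
    unet = ContextUnet(in_channels=3, n_feat=8, n_classes=10)  # tiny n_feat just for the demo
    c = torch.tensor([[1, 7]])  # automobile + horse for a single sample
    print(unet.one_hot(c, 10))  # tensor([[0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]])
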
# returns pre-computed schedules for the DDPM sampling and training process
def ddpm_schedules(beta1, beta2, T):

    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "log_alpha_t": log_alpha_t,
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


# implement the diffusion model
class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing the dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):

        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)  # t ~ Uniform(1, n_T)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )  # this is x_t = sqrt(alphabar) * x_0 + sqrt(1 - alphabar) * eps
        # we should predict the "error term" from this x_t; the loss is what we return

        # drop out context with some probability
        context_mask = torch.bernoulli(torch.zeros_like(c) + self.drop_prob).to(self.device)

        # return MSE between the added noise and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample_single(self, n_sample, size, device, c, guide_w=0.0):

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise

        # c = c.repeat(int(n_sample/c.shape[0]))

        # don't drop context at test time
        context_mask = torch.zeros(1).to(device)
        if c.dim() == 1:
            c = c.repeat(2)
        else:
            c = c.repeat(2, 1)

        context_mask = context_mask.repeat(2)

        context_mask[n_sample:] = 1.  # makes second half of batch context free
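        # classifier-free guidance: rows [:n_sample] keep the conditioning vector c,
        # rows [n_sample:] have it masked out; the two noise predictions are mixed
        # below as eps = (1 + guide_w) * eps_conditioned - guide_w * eps_unconditioned,
        # so guide_w = 0 is plain conditional sampling and larger values push the
        # sample more strongly towards the requested class combination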

        x_i_store = []  # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}', end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)

            # double the batch
            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )
            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store

    def sample(self, n_sample, size, device, guide_w=0.0):
        # we follow the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'
        # to make the forward passes efficient, we concatenate two versions of the batch,
        # one with context_mask=0 and the other with context_mask=1
        # we then mix the outputs with the guidance scale w, where w>0 means more guidance

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        c_i = torch.arange(0, 10).to(device)  # context here just cycles through the ten CIFAR-10 labels
        c_i = c_i.repeat(int(n_sample/c_i.shape[0]))

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.  # makes second half of batch context free

        x_i_store = []  # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}', end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)

            # double the batch
            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )
            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store


def train_mnist():

    # define hyperparameters
    n_epoch = 1
    batch_size = 256
    n_T = 400  # 500
    device = "cuda:0"
    n_classes = 10
    n_feat = 128
    lrate = 1e-4
    save_model = True
    save_dir = './louisdata/diffusion_outputs10/'
    ws_test = [0.0, 0.5, 2.0]  # strength of generative guidance

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # instantiate model
    ddpm = DDPM(nn_model=ContextUnet(in_channels=3, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    # optionally load a model
    # ddpm.load_state_dict(torch.load("./data/diffusion_outputs/ddpm_unet01_mnist_9.pth"))

    tf = transforms.Compose([transforms.ToTensor()])

    # load dataset
    dataset = CIFAR10("./datapics", train=True, download=True, transform=tf)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):

        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            c = c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        ddpm.eval()
        with torch.no_grad():
            n_sample = 4*n_classes
            for w_i, w in enumerate(ws_test):
                x_gen, x_gen_store = ddpm.sample(n_sample, (3, 32, 32), device, guide_w=w)

                # append some real images at bottom, ordered by class
                x_real = torch.Tensor(x_gen.shape).to(device)
                for k in range(n_classes):
                    for j in range(int(n_sample/n_classes)):
                        try:
                            idx = torch.squeeze((c == k).nonzero())[j]
                        except:
                            idx = 0
                        x_real[k+(j*n_classes)] = x[idx]

                x_all = torch.cat([x_gen, x_real])
                grid = make_grid(x_all*-1 + 1, nrow=10)
                save_image(grid, save_dir + f"image_ep{ep}_w{w}.png")
                print('saved image at ' + save_dir + f"image_ep{ep}_w{w}.png")
        # optionally save model
        if save_model and ep%1 == 0:
            torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")


if __name__ == "__main__":
    train_mnist()
--------------------------------------------------------------------------------
/image_generator.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import random
from cifar_ddpm import DDPM
import cifar_ddpm
import numpy as np
import torch

# define model parameters
n_T = 400
device = torch.device("cpu")
n_classes = 10
n_feat = 128
lrate = 1e-4
save_model = True

# load pretrained diffusion model
ddpm = DDPM(nn_model=cifar_ddpm.ContextUnet(in_channels=3, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=torch.device("cpu"), drop_prob=0.1)
ddpm.load_state_dict(torch.load("model_90.pth", map_location='cpu'))

# define desired generated image category/categories as CIFAR-10 label indices,
# e.g. [2, 1] = bird & automobile
context_label = [2, 1]
c = torch.tensor([context_label], dtype=torch.long)

# generate image
ddpm.eval()
with torch.no_grad():
    x_gen, x_gen_store = ddpm.sample_single(1, (3, 32, 32), device, c, guide_w=4.0)
    numpy_array = x_gen[0].squeeze(0).cpu().detach().numpy()
    numpy_array = np.transpose(numpy_array, (1, 2, 0))  # transpose to (32, 32, 3)
    plt.imshow(numpy_array.clip(0, 1))
    plt.axis('off')
    plt.savefig('output_image.png', bbox_inches='tight')
    plt.show()
--------------------------------------------------------------------------------