├── results
│   └── gif.gif
├── requirements.txt
├── README.md
├── Dockerfile
├── LICENSE
├── ddpm.py
├── train.py
└── model
    └── Unet.py
--------------------------------------------------------------------------------
/results/gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushank264/mini-diffusion/HEAD/results/gif.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tqdm==4.59.0
numpy==1.21.3
matplotlib==3.5.1
torch==1.9.0
torchvision
opencv-python
wandb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mini-diffusion
A small, lightweight diffusion model in PyTorch.

Result:

![GIF](https://github.com/Vrushank264/mini-diffusion/blob/main/results/gif.gif)
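
## Usage

Training is a single script: install the dependencies with `pip install -r requirements.txt`, then run `python train.py`. Note that the defaults in `train.py` (the `save_dir` checkpoint path and the MNIST data root) are Colab-style paths, so adjust them for local runs. Alternatively, the provided `Dockerfile` builds a CUDA image with a `diffusion_env` conda environment that has the same requirements installed.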
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
ARG PYTORCH="1.9.0"
ARG CUDA="10.2"
ARG CUDNN="7"

#FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
FROM nvidia/cuda:11.6.0-cudnn8-runtime-ubuntu18.04
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/miniconda3/bin:${PATH}"

# The NVIDIA container toolkit belongs on the host, not in the image.
RUN apt update && apt install -y htop python3-dev wget

RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && sh Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh

RUN conda create -y -n diffusion_env python=3.7

COPY . src/

RUN /bin/bash -c "cd src && source activate diffusion_env && pip install -r requirements.txt"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Vrushank Changawala

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/ddpm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math


def betas_for_alpha_bar(T, alpha_bar, max_beta):
    """Discretize a continuous alpha_bar function into per-step betas,
    clipped at max_beta."""

    betas = []
    for i in range(T):
        t1 = i / T
        t2 = (i + 1) / T
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))

    # .float() keeps the dtype consistent with the linear branch below.
    return torch.from_numpy(np.array(betas)).float()


def ddpm_schedule(beta_start, beta_end, T, scheduler_type = 'cosine'):
    """Precompute every schedule-derived quantity used during training
    and sampling."""

    assert beta_start < beta_end < 1.0

    if scheduler_type == 'linear':
        betas = torch.linspace(beta_start, beta_end, T)
    elif scheduler_type == 'cosine':
        betas = betas_for_alpha_bar(T, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, beta_end)
    else:
        raise NotImplementedError(f'Unknown beta schedule {scheduler_type}.')

    sqrt_beta_t = torch.sqrt(betas)
    alpha_t = 1 - betas
    log_alpha_t = torch.log(alpha_t)
    sqrt_alpha_t_inv = 1 / torch.sqrt(alpha_t)

    # alphabar_t = cumprod(alpha_t), computed in log space for stability.
    alphabar_t = torch.cumsum(log_alpha_t, dim = 0).exp()
    sqrt_abar_t = torch.sqrt(alphabar_t)
    sqrt_abar_t1 = torch.sqrt(1 - alphabar_t)
    alpha_t_div_sqrt_abar = (1 - alpha_t) / sqrt_abar_t1

    return {
        'sqrt_beta_t': sqrt_beta_t,
        'alpha_t': alpha_t,
        'sqrt_alpha_t_inv': sqrt_alpha_t_inv,
        'alphabar_t': alphabar_t,
        'sqrt_abar_t': sqrt_abar_t,
        'sqrt_abar_t1': sqrt_abar_t1,
        'alpha_t_div_sqrt_abar': alpha_t_div_sqrt_abar
    }


class DDPM(nn.Module):

    def __init__(self, model, betas, T = 500, dropout_p = 0.1):

        super().__init__()
        self.model = model.cuda()

        # Register every schedule tensor as a buffer so it moves with .cuda()
        # and is saved in the state dict.
        for k, v in ddpm_schedule(betas[0], betas[1], T).items():
            self.register_buffer(k, v)

        self.T = T
        self.dropout_p = dropout_p

    def forward(self, x, cls):
        """Forward (noising) process: returns the noise target, the noised
        image x_t, the labels, normalized timesteps and the context-dropout
        mask for classifier-free guidance."""

        timestep = torch.randint(1, self.T, (x.shape[0], )).cuda()
        noise = torch.randn_like(x)

        # q(x_t | x_0): x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps
        x_t = (self.sqrt_abar_t[timestep, None, None, None] * x + self.sqrt_abar_t1[timestep, None, None, None] * noise)

        # Drop the class context with probability dropout_p.
        ctx_mask = torch.bernoulli(torch.zeros_like(cls) + self.dropout_p).cuda()

        return noise, x_t, cls, timestep / self.T, ctx_mask

    def sample(self, num_samples, size, num_cls, guide_w = 0.0):
        """Reverse (denoising) process with classifier-free guidance weight
        guide_w. num_samples must be a multiple of num_cls."""

        x_i = torch.randn(num_samples, *size).cuda()
        c_i = torch.arange(0, num_cls).cuda()
        c_i = c_i.repeat(int(num_samples / c_i.shape[0]))

        # Double the batch: first half conditional, second half context-free.
        ctx_mask = torch.zeros_like(c_i).cuda()
        c_i = c_i.repeat(2)
        ctx_mask = ctx_mask.repeat(2)
        ctx_mask[num_samples:] = 1.0

        # Store intermediate results to create GIFs.
        x_is = []

        for i in range(self.T - 1, 0, -1):

            t_is = torch.tensor([i / self.T]).cuda()
            t_is = t_is.repeat(num_samples, 1, 1, 1)

            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            z = torch.randn(num_samples, *size).cuda() if i > 1 else 0

            # Guided noise estimate: (1 + w) * eps_cond - w * eps_uncond.
            eps = self.model(x_i, c_i, t_is, ctx_mask)
            eps1 = eps[:num_samples]
            eps2 = eps[num_samples:]
            eps = (1 + guide_w) * eps1 - guide_w * eps2

            x_i = x_i[:num_samples]
            x_i = (self.sqrt_alpha_t_inv[i] * (x_i - eps * self.alpha_t_div_sqrt_abar[i]) + self.sqrt_beta_t[i] * z)

            if i % 25 == 0 or i == self.T - 1:
                x_is.append(x_i.detach().cpu().numpy())

        return x_i, np.array(x_is)
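

# A minimal usage sketch, not part of the training pipeline: wire the schedule,
# forward (noising) pass and guided sampler together. It assumes this repo's
# UNet, MNIST-shaped inputs (1x28x28), 10 classes and an available GPU.
if __name__ == '__main__':

    from model.Unet import UNet

    num_cls = 10
    unet = UNet(1, 128, num_cls)
    ddpm = DDPM(unet, betas = (1e-4, 0.02), T = 500).cuda()

    x = torch.rand(8, 1, 28, 28).cuda()            # stand-in "clean" batch in [0, 1]
    cls = torch.randint(0, num_cls, (8, )).cuda()  # stand-in class labels

    noise, x_t, ctx, t, ctx_mask = ddpm(x, cls)
    print(x_t.shape, t.shape)                      # (8, 1, 28, 28), (8,)

    ddpm.eval()
    with torch.no_grad():
        samples, intermediates = ddpm.sample(4 * num_cls, (1, 28, 28), num_cls, guide_w = 0.5)
    print(samples.shape, intermediates.shape)      # (40, 1, 28, 28), (20, 40, 1, 28, 28)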
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from tqdm import tqdm
import wandb

from model.Unet import UNet
from ddpm import DDPM


class AverageMeter:
    """Tracks a running average of a scalar (here, the training loss)."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


@torch.no_grad()
def plot(ddpm_model, num_cls, ws, save_dir, epoch):

    ddpm_model.eval()
    num_samples = 4 * num_cls

    for w_i, w in enumerate(ws):

        pred, pred_arr = ddpm_model.sample(num_samples, (1, 28, 28), num_cls, w)
        grid = tv.utils.make_grid(pred, nrow = 10)

        grid_arr = grid.squeeze().detach().cpu().numpy()
        print(f'Grid Array Shape: {grid_arr.shape}')
        # cv2 expects a 0-255 range, so rescale the roughly [0, 1] samples.
        cv2.imwrite(f'{save_dir}/pred_epoch_{epoch}_w_{w}.png', (grid_arr.transpose(1, 2, 0) * 255).clip(0, 255))

    ddpm_model.train()
    return grid_arr.transpose(1, 2, 0)


def train(unet, ddpm_model, loader, opt, criterion, scaler, num_cls, save_dir, ws, epoch, num_epochs):

    unet.train()
    ddpm_model.train()

    wandb.log({
        'Epoch': epoch
    })

    loop = tqdm(loader, position = 0, leave = True)
    loss_ = AverageMeter()

    for idx, (img, class_lbl) in enumerate(loop):

        img = img.cuda(non_blocking = True)
        lbl = class_lbl.cuda(non_blocking = True)

        opt.zero_grad(set_to_none = True)

        # Mixed-precision forward pass; autocast handles the dtype casts,
        # so no manual .half() calls are needed.
        with torch.cuda.amp.autocast():

            noise, x_t, ctx, timestep, ctx_mask = ddpm_model(img, lbl)
            pred = unet(x_t, ctx, timestep, ctx_mask)
            loss = criterion(pred, noise)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        loss_.update(loss.detach(), img.size(0))

        if idx % 200 == 0:
            wandb.log({
                'Loss': loss_.avg
            })
            print(loss_.avg)

    if epoch % 2 == 0:

        ddpm_model.eval()

        with torch.no_grad():
            n_sample = 4 * num_cls
            for w_i, w in enumerate(ws):

                x1, xis = ddpm_model.sample(n_sample, (1, 28, 28), num_cls, w)

                fig, ax = plt.subplots(nrows = n_sample // num_cls, ncols = num_cls, sharex = True, sharey = True, figsize = (10, 4))

                def animate_plot(i, xis):

                    plots = []

                    for row in range(n_sample // num_cls):

                        for col in range(num_cls):

                            ax[row, col].clear()
                            ax[row, col].set_xticks([])
                            ax[row, col].set_yticks([])

                            plots.append(ax[row, col].imshow(-xis[i, (row * num_cls) + col, 0], cmap = 'gray', vmin = (-xis[i]).min(), vmax = (-xis[i]).max()))

                    return plots

                ani = FuncAnimation(fig, animate_plot, fargs = [xis], interval = 200, blit = False, repeat = True, frames = xis.shape[0])
                # Include the guidance weight in the filename so each w gets its own GIF.
                ani.save(f'{save_dir}/epoch_{epoch}_w_{w}.gif', dpi = 100, writer = PillowWriter(fps = 5))
                print('GIF Saved!')

    torch.save(ddpm_model.state_dict(), os.path.join(save_dir, 'ddpm.pth'))
    torch.save(unet.state_dict(), os.path.join(save_dir, 'unet.pth'))


def main():

    num_cls = 10
    num_epochs = 20
    save_dir = '/content/drive/MyDrive/Diffusion'
    unet = UNet(1, 128, num_cls).cuda()
    ddpm_model = DDPM(unet, (1e-4, 0.02)).cuda()

    #unet.load_state_dict(torch.load('/content/drive/MyDrive/Diffusion/unet.pth'))
    #ddpm_model.load_state_dict(torch.load('/content/drive/MyDrive/Diffusion/ddpm.pth'))

    tr = T.Compose([T.ToTensor()])
    dataset = tv.datasets.MNIST('/content/data', train = True, transform = tr, download = True)
    loader = DataLoader(dataset, batch_size = 64, shuffle = True, num_workers = 0)

    # ddpm_model already contains unet as a submodule, so its parameters()
    # covers both; listing unet again would duplicate every parameter.
    opt = torch.optim.Adam(ddpm_model.parameters(), lr = 1e-4)
    criterion = nn.MSELoss()

    scaler = torch.cuda.amp.GradScaler()

    ws = [0.0, 0.5, 1.0]

    for epoch in range(num_epochs):

        train(unet, ddpm_model, loader, opt, criterion, scaler, num_cls, save_dir, ws, epoch, num_epochs)


if __name__ == '__main__':

    # wandb.init must run before the wandb.log calls in train().
    wandb.init(project = 'MinDiffusion')
    main()
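

# A minimal inference sketch (an assumption, not a file in the repository):
# reload the checkpoint written by train() and export one sample grid per
# guidance weight without re-training. clamp(0, 1) is a pragmatic choice,
# since sampled pixels can fall slightly outside the data range.
@torch.no_grad()
def sample_grids(save_dir = '/content/drive/MyDrive/Diffusion'):

    num_cls = 10
    unet = UNet(1, 128, num_cls).cuda()
    ddpm_model = DDPM(unet, (1e-4, 0.02)).cuda()
    ddpm_model.load_state_dict(torch.load(os.path.join(save_dir, 'ddpm.pth')))
    ddpm_model.eval()

    for w in [0.0, 0.5, 2.0]:
        pred, _ = ddpm_model.sample(4 * num_cls, (1, 28, 28), num_cls, w)
        tv.utils.save_image(pred.clamp(0, 1), f'{save_dir}/samples_w_{w}.png', nrow = num_cls)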
--------------------------------------------------------------------------------
/model/Unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as fun
import math


class ResBlock(nn.Module):

    def __init__(self, in_c, out_c, is_res = False):

        super().__init__()
        self.is_res = is_res
        self.same_c = in_c == out_c
        self.conv1 = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size = 3, padding = 1),
                                   nn.BatchNorm2d(out_c),
                                   nn.GELU()
                                   )
        self.conv2 = nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size = 3, padding = 1),
                                   nn.BatchNorm2d(out_c),
                                   nn.GELU()
                                   )

    def forward(self, x):

        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        if self.is_res:
            # Add the residual, skipping over conv1 when the channel counts
            # differ; divide by sqrt(2) to keep the output variance stable.
            if self.same_c:
                out = x + x2
            else:
                out = x1 + x2

            return out / math.sqrt(2)
        else:
            return x2


class UnetDown(nn.Module):

    def __init__(self, in_c, out_c):
        super().__init__()

        # ResBlock followed by a strided conv that halves spatial resolution.
        self.block = nn.Sequential(ResBlock(in_c, out_c),
                                   nn.Conv2d(out_c, out_c, kernel_size = 3, stride = 2, padding = 1),
                                   nn.BatchNorm2d(out_c),
                                   nn.GELU()
                                   )

    def forward(self, x):

        return self.block(x)


class UnetUp(nn.Module):

    def __init__(self, in_c, out_c):
        super().__init__()

        self.block = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size = 1),
                                   ResBlock(out_c, out_c),
                                   ResBlock(out_c, out_c)
                                   )

    def forward(self, x, skip):

        # Concatenate the skip connection, upsample 2x, then refine.
        x = torch.cat([x, skip], dim = 1)
        x = fun.interpolate(x, scale_factor = 2)
        res = self.block(x)
        return res


class EmbeddingFC(nn.Module):

    def __init__(self, ip_dim, emb_dim):
        super().__init__()

        # Two-layer MLP that embeds timesteps / class contexts.
        self.ip_dim = ip_dim
        self.block = nn.Sequential(nn.Linear(ip_dim, emb_dim),
                                   nn.GELU(),
                                   nn.Linear(emb_dim, emb_dim)
                                   )

    def forward(self, x):

        x = x.view(-1, self.ip_dim)
        res = self.block(x)
        return res
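

# Spatial flow for the 28x28 MNIST inputs used by train.py (a reading aid,
# not original source): first_layer keeps 28x28, down1 -> 14x14, down2 -> 7x7,
# avg_pool(7) -> a 1x1 bottleneck vector, the ConvTranspose2d in self.up maps
# it back to 7x7, then up1 -> 14x14 and up2 -> 28x28, with skip connections
# concatenated from d2, d1 and the first-layer features.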
class UNet(nn.Module):

    def __init__(self, in_c, out_c, num_cls):
        super().__init__()

        self.in_c = in_c
        self.out_c = out_c
        self.num_cls = num_cls

        self.first_layer = ResBlock(in_c, out_c, is_res = True)
        self.down1 = UnetDown(out_c, out_c)
        self.down2 = UnetDown(out_c, out_c * 2)

        self.avg_pool = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.time_emb1 = EmbeddingFC(1, out_c * 2)
        self.time_emb2 = EmbeddingFC(1, out_c)
        self.ctx_emb1 = EmbeddingFC(num_cls, out_c * 2)
        self.ctx_emb2 = EmbeddingFC(num_cls, out_c)

        self.up = nn.Sequential(nn.ConvTranspose2d(out_c * 2, out_c * 2, 7, 7),
                                nn.GroupNorm(16, out_c * 2),
                                nn.ReLU()
                                )

        self.up1 = UnetUp(out_c * 4, out_c)
        self.up2 = UnetUp(out_c * 2, out_c)

        self.penultimate_layer = nn.Sequential(nn.Conv2d(out_c * 2, out_c, kernel_size = 3, padding = 1),
                                               nn.GroupNorm(16, out_c),
                                               nn.ReLU()
                                               )

        self.last_layer = nn.Conv2d(out_c, self.in_c, kernel_size = 3, padding = 1)

    def forward(self, x, ctx, t, ctx_mask):
        '''
        x: noisy image
        ctx: context (class) label
        t: timestep
        ctx_mask: which samples to block the context on
        '''

        x = self.first_layer(x)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        vec = self.avg_pool(d2)
        ctx = fun.one_hot(ctx, num_classes = self.num_cls).type(torch.float)

        # Zero out the context wherever ctx_mask == 1 (classifier-free
        # guidance); unmasked rows are scaled by -1, a constant factor the
        # context embedding simply learns.
        ctx_mask = ctx_mask[:, None]
        ctx_mask = ctx_mask.repeat(1, self.num_cls)
        ctx_mask = (-1 * (1 - ctx_mask))
        ctx = ctx * ctx_mask

        c_emb1 = self.ctx_emb1(ctx).view(-1, self.out_c * 2, 1, 1)
        c_emb2 = self.ctx_emb2(ctx).view(-1, self.out_c, 1, 1)
        t_emb1 = self.time_emb1(t).view(-1, self.out_c * 2, 1, 1)
        t_emb2 = self.time_emb2(t).view(-1, self.out_c, 1, 1)

        # Condition each decoder stage by scaling with the context embedding
        # and shifting with the time embedding.
        up = self.up(vec)
        up1 = self.up1(up * c_emb1 + t_emb1, d2)
        up2 = self.up2(up1 * c_emb2 + t_emb2, d1)
        res = self.penultimate_layer(torch.cat([up2, x], dim = 1))
        res = self.last_layer(res)

        return res
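

# A quick shape smoke test (a sketch, not original source): run a random
# MNIST-sized batch through the UNet and confirm the output matches the input.
if __name__ == '__main__':

    unet = UNet(1, 128, 10)
    x = torch.randn(4, 1, 28, 28)      # noisy images
    ctx = torch.randint(0, 10, (4, ))  # class labels
    t = torch.rand(4, 1)               # normalized timesteps
    ctx_mask = torch.zeros(4)          # keep context for every sample

    out = unet(x, ctx, t, ctx_mask)
    assert out.shape == x.shape        # (4, 1, 28, 28)
    print(out.shape)
--------------------------------------------------------------------------------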