├── results
│   └── gif.gif
├── requirements.txt
├── README.md
├── Dockerfile
├── LICENSE
├── ddpm.py
├── train.py
└── model
    └── Unet.py
/results/gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vrushank264/mini-diffusion/HEAD/results/gif.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.59.0
2 | numpy==1.21.3
3 | matplotlib==3.5.1
4 | torch==1.9.0
5 | torchvision
6 | opencv-python
7 | wandb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mini-diffusion
2 | A small, lightweight diffusion model in PyTorch.
3 |
4 | Result:
5 |
6 | 
7 |
8 |
--------------------------------------------------------------------------------
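
Quickstart note (a sketch, since the repo ships no CLI and `train.py` hard-codes Colab paths such as `/content/drive/MyDrive/Diffusion`): install the pinned dependencies with `pip install -r requirements.txt`, edit `save_dir` and the MNIST data root in `main()` of `train.py`, then run `python train.py`. Training logs to Weights & Biases, so `wandb.init` must be active (see the bottom of `train.py`).
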
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.9.0"
2 | ARG CUDA="10.2"
3 | ARG CUDNN="7"
4 |
5 | #FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 | FROM nvidia/cuda:11.6.0-cudnn8-runtime-ubuntu18.04
7 | ENV PATH="/root/miniconda3/bin:${PATH}"
8 | ARG PATH="/root/miniconda3/bin:${PATH}"
9 |
10 | RUN apt update && apt install -y htop python3-dev wget nvidia-container-toolkit
11 |
12 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
13 | && mkdir -p /root/.conda \
14 | && sh Miniconda3-latest-Linux-x86_64.sh -b \
15 | && rm -f Miniconda3-latest-Linux-x86_64.sh
16 |
17 | RUN conda create -y -n diffusion_env python=3.7
18 |
19 | COPY . src/
20 |
21 | RUN /bin/bash -c "cd src && source activate diffusion_env && pip install -r requirements.txt"
--------------------------------------------------------------------------------
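
Usage note: assuming the image builds as-is, `docker build -t mini-diffusion .` produces it and `docker run --gpus all -it mini-diffusion` starts it with GPU access. `nvidia-container-toolkit` is normally installed on the Docker host rather than inside the image, so that package in the `apt install` line is likely redundant.
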
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Vrushank Changawala
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ddpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 |
6 |
7 | def betas_for_alpha_bar(T, alpha_bar, max_beta):
8 |
9 | betas = []
10 | for i in range(T):
11 | t1 = i / T
12 | t2 = (i + 1) / T
13 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
14 |
15 | betas = np.array(betas)
16 | return torch.from_numpy(betas).float()  # cast to float32 to match the linear branch
17 |
18 |
19 | def ddpm_schedule(beta_start, beta_end, T, scheduler_type = 'cosine'):
20 |
21 | assert beta_start < beta_end < 1.0
22 |
23 | if scheduler_type == 'linear':
24 |
25 | betas = torch.linspace(beta_start, beta_end, T)
26 |
27 | elif scheduler_type == 'cosine':
28 |
29 | betas = betas_for_alpha_bar(T, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, beta_end)  # cosine schedule from Nichol & Dhariwal (2021), capped at beta_end
30 |
31 | else:
32 |
33 | raise NotImplementedError(f'Unknown Beta Schedule {scheduler_type}.')
34 |
35 | sqrt_beta_t = torch.sqrt(betas)
36 | alpha_t = 1 - betas
37 | log_alpha_t = torch.log(alpha_t)
38 | sqrt_alpha_t_inv = 1 / torch.sqrt(alpha_t)
39 |
40 | alphabar_t = torch.cumsum(log_alpha_t, dim = 0).exp()
41 | sqrt_abar_t = torch.sqrt(alphabar_t)
42 | sqrt_abar_t1 = torch.sqrt(1 - alphabar_t)  # sqrt(1 - alpha_bar_t), despite the name
43 | alpha_t_div_sqrt_abar = (1 - alpha_t) / sqrt_abar_t1  # beta_t / sqrt(1 - alpha_bar_t)
44 |
45 | return {
46 | 'sqrt_beta_t': sqrt_beta_t,
47 | 'alpha_t': alpha_t,
48 | 'sqrt_alpha_t_inv': sqrt_alpha_t_inv,
49 | 'alphabar_t': alphabar_t,
50 | 'sqrt_abar_t': sqrt_abar_t,
51 | 'sqrt_abar_t1': sqrt_abar_t1,
52 | 'alpha_t_div_sqrt_abar': alpha_t_div_sqrt_abar
53 | }
54 |
55 |
56 |
57 | class DDPM(nn.Module):
58 |
59 | def __init__(self, model, betas, T = 500, dropout_p = 0.1):
60 |
61 | super().__init__()
62 | self.model = model.cuda()
63 |
64 | for k, v in ddpm_schedule(betas[0], betas[1], T).items():
65 | self.register_buffer(k, v)
66 |
67 | self.T = T
68 | self.dropout_p = dropout_p  # probability of dropping the class label for classifier-free guidance
69 |
70 |
71 | def forward(self, x, cls):
72 |
73 | timestep = torch.randint(1, self.T, (x.shape[0], )).cuda()
74 | noise = torch.randn_like(x)
75 |
76 | x_t = (self.sqrt_abar_t[timestep, None, None, None] * x + self.sqrt_abar_t1[timestep, None, None, None] * noise)
77 |
78 | ctx_mask = torch.bernoulli(torch.zeros_like(cls) + self.dropout_p).cuda()
79 |
80 | return noise, x_t, cls, timestep / self.T, ctx_mask
81 |
82 |
83 | def sample(self, num_samples, size, num_cls, guide_w = 0.0):
84 |
85 | x_i = torch.randn(num_samples, *size).cuda()
86 | c_i = torch.arange(0, num_cls).cuda()
87 | c_i = c_i.repeat(int(num_samples / c_i.shape[0]))  # assumes num_samples is a multiple of num_cls
88 |
89 | ctx_mask = torch.zeros_like(c_i).cuda()
90 | c_i = c_i.repeat(2)  # double the batch: the second half runs with context masked out
91 | ctx_mask = ctx_mask.repeat(2)
92 | ctx_mask[num_samples:] = 1.0
93 |
94 | # To store intermediate results used to build the GIFs.
95 | x_is = []
96 |
97 | for i in range(self.T - 1, 0, -1):
98 |
99 | t_is = torch.tensor([i / self.T]).cuda()
100 | t_is = t_is.repeat(num_samples, 1, 1, 1)
101 |
102 | x_i = x_i.repeat(2, 1, 1, 1)
103 | t_is = t_is.repeat(2, 1, 1, 1)
104 |
105 | z = torch.randn(num_samples, *size).cuda() if i > 1 else 0
106 |
107 | eps = self.model(x_i, c_i, t_is, ctx_mask)
108 | eps1 = eps[:num_samples]
109 | eps2 = eps[num_samples:]
110 | eps = (1 + guide_w)*eps1 - guide_w*eps2  # classifier-free guidance mix of conditional and unconditional predictions
111 |
112 | x_i = x_i[:num_samples]
113 | x_i = (self.sqrt_alpha_t_inv[i] * (x_i - eps*self.alpha_t_div_sqrt_abar[i]) + self.sqrt_beta_t[i] * z)
114 |
115 | if i % 25 == 0 or i == self.T - 1:
116 | x_is.append(x_i.detach().cpu().numpy())
117 |
118 |
119 | return x_i, np.array(x_is)
120 |
121 |
--------------------------------------------------------------------------------
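
To make the schedule concrete, here is a small CPU-only sketch (not part of the repo) that builds the linear schedule via ddpm_schedule and applies the closed-form forward process q(x_t | x_0) exactly the way DDPM.forward does:

    import torch
    from ddpm import ddpm_schedule

    T = 500
    sched = ddpm_schedule(1e-4, 0.02, T, scheduler_type = 'linear')

    # alpha_bar_t should decay monotonically from ~1 toward 0
    assert sched['alphabar_t'][0] > sched['alphabar_t'][-1]

    # forward-diffuse a dummy batch: x_t = sqrt(abar_t)*x0 + sqrt(1 - abar_t)*noise
    x0 = torch.randn(8, 1, 28, 28)
    t = torch.randint(1, T, (8,))
    noise = torch.randn_like(x0)
    x_t = (sched['sqrt_abar_t'][t, None, None, None] * x0
           + sched['sqrt_abar_t1'][t, None, None, None] * noise)
    print(x_t.shape)  # torch.Size([8, 1, 28, 28])
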
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision as tv
4 | import torchvision.transforms as T
5 | from torch.utils.data import DataLoader
6 | import cv2
7 | import numpy as np
8 | import os
9 | import matplotlib.pyplot as plt
10 | from matplotlib.animation import FuncAnimation, PillowWriter
11 | from tqdm import tqdm
12 | import wandb
13 |
14 | from model.Unet import UNet
15 | from ddpm import DDPM
16 |
17 |
18 | class AverageMeter:
19 | def __init__(self, name=None):
20 | self.name = name
21 | self.reset()
22 |
23 | def reset(self):
24 | self.sum = self.count = self.avg = 0
25 |
26 | def update(self, val, n=1):
27 | self.sum += val * n
28 | self.count += n
29 | self.avg = self.sum / self.count
30 |
31 |
32 | @torch.no_grad()
33 | def plot(ddpm_model, num_cls, ws, save_dir, epoch):
34 |
35 | ddpm_model.eval()
36 | num_samples = 4 * num_cls
37 |
38 | for w_i, w in enumerate(ws):
39 |
40 | pred, pred_arr = ddpm_model.sample(num_samples, (1,28,28), num_cls, w)
41 | #real = torch.tensor(pred.shape).cuda()  # unused leftover: this held pred's shape, not real images
42 | #print(pred.shape, real.shape)
43 | #combined = torch.cat([real, pred])
44 | grid = tv.utils.make_grid(pred, nrow = 10)
45 |
46 | grid_arr = grid.squeeze().detach().cpu().numpy()
47 | print(f'Grid Array Shape: {grid_arr.shape}')
48 | cv2.imwrite(f'{save_dir}/pred_epoch_{epoch}_w_{w}.png', (grid_arr.transpose(1,2,0) * 255).clip(0, 255))  # scale [0, 1] floats to 8-bit before writing
49 |
50 | ddpm_model.train()
51 | return grid_arr.transpose(1,2,0)
52 |
53 |
54 | def train(unet, ddpm_model, loader, opt, criterion, scaler, num_cls, save_dir, ws, epoch, num_epochs):
55 |
56 | unet.train()
57 | ddpm_model.train()
58 |
59 | wandb.log({
60 | 'Epoch': epoch
61 | })
62 |
63 | loop = tqdm(loader, position = 0, leave = True)
64 | loss_ = AverageMeter()
65 |
66 | for idx, (img, class_lbl) in enumerate(loop):
67 |
68 | img = img.cuda(non_blocking = True)
69 | lbl = class_lbl.cuda(non_blocking = True)
70 |
71 | opt.zero_grad(set_to_none = True)
72 |
73 | with torch.cuda.amp.autocast():
74 |
75 | noise, x_t, ctx, timestep, ctx_mask = ddpm_model(img, lbl)
76 | pred = unet(x_t.half(), ctx, timestep.half(), ctx_mask.half())
77 | loss = criterion(noise, pred)
78 |
79 | scaler.scale(loss).backward()
80 | scaler.step(opt)
81 | scaler.update()
82 |
83 |
84 | loss_.update(loss.detach(), img.size(0))
85 |
86 | if idx % 200 == 0:
87 | wandb.log({
88 | 'Loss': loss_.avg
89 | })
90 | print(loss_.avg)
91 |
92 |
93 | if epoch % 2 == 0:
94 |
95 | ddpm_model.eval()
96 |
97 | with torch.no_grad():
98 | n_sample = 4*num_cls
99 | for w_i, w in enumerate(ws):
100 |
101 | x1, xis = ddpm_model.sample(n_sample, (1, 28, 28), num_cls, w)
102 |
103 |
104 | fig, ax = plt.subplots(nrows = n_sample // num_cls, ncols = num_cls, sharex = True, sharey = True, figsize = (10, 4))
105 |
106 | def animate_plot(i, xis):
107 |
108 | plots = []
109 |
110 | for row in range(n_sample // num_cls):
111 |
112 | for col in range(num_cls):
113 |
114 | ax[row, col].clear()
115 | ax[row, col].set_xticks([])
116 | ax[row, col].set_yticks([])
117 |
118 | plots.append(ax[row, col].imshow(-xis[i, (row*num_cls) + col, 0], cmap = 'gray', vmin = (-xis[i]).min(), vmax = (-xis[i]).max()))
119 |
120 | return plots
121 |
122 | ani = FuncAnimation(fig, animate_plot, fargs = [xis], interval = 200, blit = False, repeat = True, frames = xis.shape[0])
123 | ani.save(f'{save_dir}/epoch_{epoch}.gif', dpi = 100, writer = PillowWriter(fps = 5))
124 | print('GIF Saved!')
125 |
126 | torch.save(ddpm_model.state_dict(), os.path.join(save_dir, 'ddpm.pth'))
127 | torch.save(unet.state_dict(), os.path.join(save_dir, 'unet.pth'))
128 |
129 |
130 | def main():
131 |
132 | num_cls = 10
133 | num_epochs = 20
134 | save_dir = '/content/drive/MyDrive/Diffusion'
135 | unet = UNet(1, 128, num_cls).cuda()
136 | ddpm_model = DDPM(unet, (1e-4, 0.02)).cuda()
137 |
138 | #unet.load_state_dict(torch.load('/content/drive/MyDrive/Diffusion/unet.pth'))
139 | #ddpm_model.load_state_dict(torch.load('/content/drive/MyDrive/Diffusion/ddpm.pth'))
140 |
141 | tr = T.Compose([T.ToTensor()])
142 | dataset = tv.datasets.MNIST('/content/data', train = True, transform = tr, download = True)
143 | loader = DataLoader(dataset, batch_size = 64, shuffle = True, num_workers = 0)
144 |
145 | opt = torch.optim.Adam(ddpm_model.parameters(), lr = 1e-4)  # ddpm_model registers unet as a submodule, so listing both would duplicate every parameter
146 | criterion = nn.MSELoss()
147 |
148 | scaler = torch.cuda.amp.GradScaler()
149 |
150 | ws = [0.0, 0.5, 1.0]
151 |
152 | for epoch in range(num_epochs):
153 |
154 | train(unet, ddpm_model, loader, opt, criterion, scaler, num_cls, save_dir, ws, epoch, num_epochs)
155 |
156 |
157 | if __name__ == '__main__':
158 |
159 | wandb.init(project = 'MinDiffusion')  # train() calls wandb.log, so init must run first
160 | main()
161 |
--------------------------------------------------------------------------------
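
To sample from saved checkpoints without retraining, a minimal sketch (the checkpoint path is hypothetical; ddpm.pth already contains the UNet weights because DDPM registers the model as a submodule, so loading it alone is enough):

    import torch
    from model.Unet import UNet
    from ddpm import DDPM

    num_cls = 10
    unet = UNet(1, 128, num_cls).cuda()
    ddpm_model = DDPM(unet, (1e-4, 0.02)).cuda()
    ddpm_model.load_state_dict(torch.load('ddpm.pth'))  # hypothetical path
    ddpm_model.eval()

    with torch.no_grad():
        # 40 samples = 4 rows of the 10 MNIST classes, guidance weight 0.5
        samples, steps = ddpm_model.sample(4 * num_cls, (1, 28, 28), num_cls, guide_w = 0.5)
    print(samples.shape)  # torch.Size([40, 1, 28, 28])
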
/model/Unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as fun
4 | import math
5 |
6 |
7 | class ResBlock(nn.Module):
8 |
9 | def __init__(self, in_c, out_c, is_res = False):
10 |
11 | super().__init__()
12 | self.is_res = is_res
13 | self.same_c = in_c == out_c
14 | self.conv1 = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size = 3, padding = 1),
15 | nn.BatchNorm2d(out_c),
16 | nn.GELU()
17 | )
18 | self.conv2 = nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size = 3, padding = 1),
19 | nn.BatchNorm2d(out_c),
20 | nn.GELU()
21 | )
22 |
23 | def forward(self, x):
24 |
25 | x1 = self.conv1(x)
26 | x2 = self.conv2(x1)
27 |
28 | if self.is_res:
29 | if self.same_c:
30 | out = x + x2
31 | else:
32 | out = x1 + x2
33 |
34 | return out / math.sqrt(2)  # rescale so the residual sum keeps roughly unit variance
35 | else:
36 | return x2
37 |
38 |
39 | class UnetDown(nn.Module):
40 |
41 | def __init__(self, in_c, out_c):
42 | super().__init__()
43 |
44 | self.block = nn.Sequential(ResBlock(in_c, out_c),
45 | nn.Conv2d(out_c, out_c, kernel_size = 3, stride = 2, padding = 1),
46 | nn.BatchNorm2d(out_c),
47 | nn.GELU()
48 | )
49 |
50 | def forward(self, x):
51 |
52 | return self.block(x)
53 |
54 |
55 | class UnetUp(nn.Module):
56 |
57 | def __init__(self, in_c, out_c):
58 | super().__init__()
59 |
60 | self.block = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1),
61 | ResBlock(out_c, out_c),
62 | ResBlock(out_c, out_c)
63 | )
64 |
65 | def forward(self, x, skip):
66 |
67 | x = torch.cat([x, skip], dim = 1)
68 | x = fun.interpolate(x, scale_factor=2)
69 | res = self.block(x)
70 | return res
71 |
72 |
73 | class EmbeddingFC(nn.Module):
74 |
75 | def __init__(self, ip_dim, emb_dim):
76 | super().__init__()
77 |
78 | self.ip_dim = ip_dim
79 | self.block = nn.Sequential(nn.Linear(ip_dim, emb_dim),
80 | nn.GELU(),
81 | nn.Linear(emb_dim, emb_dim)
82 | )
83 |
84 |
85 | def forward(self, x):
86 |
87 | x = x.view(-1, self.ip_dim)
88 | res = self.block(x)
89 | return res
90 |
91 |
92 | class UNet(nn.Module):
93 |
94 | def __init__(self, in_c, out_c, num_cls):
95 | super().__init__()
96 |
97 | self.in_c = in_c
98 | self.out_c = out_c
99 | self.num_cls = num_cls
100 |
101 | self.first_layer = ResBlock(in_c, out_c, is_res = True)
102 | self.down1 = UnetDown(out_c, out_c)
103 | self.down2 = UnetDown(out_c, out_c * 2)
104 |
105 | self.avg_pool = nn.Sequential(nn.AvgPool2d(7), nn.GELU())  # collapse the 7x7 bottleneck map to a 1x1 vector
106 |
107 | self.time_emb1 = EmbeddingFC(1, out_c * 2)
108 | self.time_emb2 = EmbeddingFC(1, out_c)
109 | self.ctx_emb1 = EmbeddingFC(num_cls, out_c * 2)
110 | self.ctx_emb2 = EmbeddingFC(num_cls, out_c)
111 |
112 | self.up = nn.Sequential(nn.ConvTranspose2d(out_c * 2, out_c * 2, 7, 7),  # 1x1 -> 7x7
113 | nn.GroupNorm(16, out_c * 2),
114 | nn.ReLU()
115 | )
116 |
117 | self.up1 = UnetUp(out_c * 4, out_c)
118 | self.up2 = UnetUp(out_c * 2, out_c)
119 |
120 | self.penultimate_layer = nn.Sequential(nn.Conv2d(out_c * 2, out_c, kernel_size = 3, padding = 1),
121 | nn.GroupNorm(16, out_c),
122 | nn.ReLU()
123 | )
124 |
125 | self.last_layer = nn.Conv2d(out_c, self.in_c, kernel_size = 3, padding = 1)
126 |
127 |
128 | def forward(self, x, ctx, t, ctx_mask):
129 |
130 | '''
131 | x: Noisy Image
132 | ctx: Context label
133 | t: timestep
134 | ctx_mask: which samples to block context on
135 |
136 | '''
137 |
138 | x = self.first_layer(x)
139 | d1 = self.down1(x)
140 | d2 = self.down2(d1)
141 | vec = self.avg_pool(d2)
142 | ctx = fun.one_hot(ctx, num_classes = self.num_cls).type(torch.float)
143 |
144 | ctx_mask = ctx_mask[:, None]
145 | ctx_mask = ctx_mask.repeat(1, self.num_cls)
146 | ctx_mask = (-1 * (1 - ctx_mask))  # dropped samples -> 0, kept samples -> -1 (context enters as a negated one-hot)
147 | ctx = ctx * ctx_mask
148 |
149 | c_emb1 = self.ctx_emb1(ctx).view(-1, self.out_c * 2, 1, 1)
150 | c_emb2 = self.ctx_emb2(ctx).view(-1, self.out_c, 1, 1)
151 | t_emb1 = self.time_emb1(t).view(-1, self.out_c * 2, 1, 1)
152 | t_emb2 = self.time_emb2(t).view(-1, self.out_c, 1, 1)
153 |
154 | up = self.up(vec)
155 | up1 = self.up1(up*c_emb1 + t_emb1, d2)
156 | up2 = self.up2(up1*c_emb2 + t_emb2, d1)
157 | res = self.penultimate_layer(torch.cat([up2, x], dim = 1))
158 | res = self.last_layer(res)
159 |
160 | return res
161 |
--------------------------------------------------------------------------------
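
A quick shape check for the UNet (a sketch; 28x28 inputs are assumed throughout, since the fixed AvgPool2d(7) / ConvTranspose2d(..., 7, 7) pair hard-codes the 7x7 bottleneck):

    import torch
    from model.Unet import UNet

    net = UNet(in_c = 1, out_c = 128, num_cls = 10)
    x = torch.randn(4, 1, 28, 28)      # noisy images
    ctx = torch.randint(0, 10, (4,))   # class labels
    t = torch.rand(4)                  # normalized timesteps in [0, 1)
    ctx_mask = torch.zeros(4)          # 0 = keep context, 1 = drop it
    out = net(x, ctx, t, ctx_mask)
    print(out.shape)                   # torch.Size([4, 1, 28, 28])

The predicted noise matches the input image shape, as the DDPM objective requires.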