├── .gitignore
├── LICENSE.md
├── README.md
├── assets
│   └── demo.gif
├── model.py
├── requirements.txt
├── train_mnist.py
├── unet.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Guocheng Tan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MNIST Diffusion
![60 epochs training from scratch](assets/demo.gif "60 epochs training from scratch")

Only simple depthwise convolutions, shortcuts, and a naive timestep embedding; there you have it! A fully functional denoising diffusion probabilistic model that stays ultra lightweight at **4.55MB** (the saved checkpoint is 9.1MB because it also stores an EMA copy of the model, doubling the size).

## Training
Install packages
```bash
pip install -r requirements.txt
```
Start training with the default settings
```bash
python train_mnist.py
```
Feel free to tune the training parameters; run `python train_mnist.py -h` for a description of all arguments.
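## Sampling
Once training has written a checkpoint under `results/`, you can sample from it directly. A minimal sketch, assuming the default training configuration (28x28 single-channel images, `--model_base_dim 64`, 1000 timesteps) and the checkpoint layout produced by `train_mnist.py`; the checkpoint filename is a placeholder:
```python
import torch
from torchvision.utils import save_image
from model import MNISTDiffusion

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MNISTDiffusion(timesteps=1000, image_size=28, in_channels=1,
                       base_dim=64, dim_mults=[2, 4]).to(device)
ckpt = torch.load("results/steps_00000000.pt", map_location=device)  # your checkpoint path
model.load_state_dict(ckpt["model"])
model.eval()  # BatchNorm layers should run in eval mode when sampling

samples = model.sampling(36, clipped_reverse_diffusion=True, device=device)
save_image(samples, "samples.png", nrow=6)
```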
## Reference
A neat blog post explaining how diffusion models work (a must-read!): https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

The Denoising Diffusion Probabilistic Models paper: https://arxiv.org/pdf/2006.11239.pdf

A PyTorch implementation of DDPM: https://github.com/lucidrains/denoising-diffusion-pytorch
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bot66/MNISTDiffusion/c7ba8e09174cbb88b9cc314db3bf2e514668681c/assets/demo.gif
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
from tqdm import tqdm

from unet import Unet


class MNISTDiffusion(nn.Module):
    def __init__(self, image_size, in_channels, time_embedding_dim=256, timesteps=1000,
                 base_dim=32, dim_mults=[1, 2, 4, 8]):
        super().__init__()
        self.timesteps = timesteps
        self.in_channels = in_channels
        self.image_size = image_size

        betas = self._cosine_variance_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - alphas_cumprod))

        self.model = Unet(timesteps, time_embedding_dim, in_channels, in_channels, base_dim, dim_mults)

    def forward(self, x, noise):
        # x: NCHW
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)
        pred_noise = self.model(x_t, t)

        return pred_noise

    @torch.no_grad()
    def sampling(self, n_samples, clipped_reverse_diffusion=True, device="cuda"):
        x_t = torch.randn((n_samples, self.in_channels, self.image_size, self.image_size)).to(device)
        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Sampling"):
            noise = torch.randn_like(x_t).to(device)
            t = torch.tensor([i for _ in range(n_samples)]).to(device)

            if clipped_reverse_diffusion:
                x_t = self._reverse_diffusion_with_clip(x_t, t, noise)
            else:
                x_t = self._reverse_diffusion(x_t, t, noise)

        x_t = (x_t + 1.) / 2.  # [-1,1] to [0,1]

        return x_t
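    # Note on the two reverse samplers below: _reverse_diffusion applies the
    # standard DDPM posterior step directly to the predicted noise, while
    # _reverse_diffusion_with_clip first recovers x_0 from the noise prediction,
    # clamps it to the valid pixel range [-1, 1], and only then forms the
    # posterior q(x_{t-1}|x_t, x_0). Clipping keeps early, inaccurate
    # predictions from driving samples out of range.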
    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

        return betas

    def _forward_diffusion(self, x_0, t, noise):
        assert x_0.shape == noise.shape
        # q(x_t|x_0): closed-form forward diffusion from the clean image x_0
        return self.sqrt_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * x_0 + \
               self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * noise

    @torch.no_grad()
    def _reverse_diffusion(self, x_t, t, noise):
        '''
        p(x_{t-1}|x_{t}) -> mean, std

        pred_noise -> pred_mean and pred_std
        '''
        pred = self.model(x_t, t)

        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        mean = (1. / torch.sqrt(alpha_t)) * (x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            std = 0.0  # no noise is added at the final (t=0) step

        return mean + std * noise

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, t, noise):
        '''
        p(x_{0}|x_{t}), q(x_{t-1}|x_{0},x_{t}) -> mean, std

        pred_noise -> pred_x_0 (clip to [-1.0,1.0]) -> pred_mean and pred_std
        '''
        pred = self.model(x_t, t)
        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)

        x_0_pred = torch.sqrt(1. / alpha_t_cumprod) * x_t - torch.sqrt(1. / alpha_t_cumprod - 1.) * pred
        x_0_pred.clamp_(-1., 1.)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            mean = (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod)) * x_0_pred + \
                   ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod)) * x_t

            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            mean = (beta_t / (1. - alpha_t_cumprod)) * x_0_pred  # at t=0, alpha_t_cumprod_prev is the empty product, i.e. 1
            std = 0.0

        return mean + std * noise
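A quick numerical sanity check of the cosine schedule above; this is a standalone sketch that re-implements `_cosine_variance_schedule` outside the class so it can be run in isolation:
```python
import math
import torch

def cosine_variance_schedule(timesteps, epsilon=0.008):
    # mirrors MNISTDiffusion._cosine_variance_schedule
    steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
    f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
    return torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

betas = cosine_variance_schedule(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=-1)
print(betas[0].item(), betas[-1].item())  # small near t=0, clipped at 0.999 near t=T-1
print(alphas_cumprod[-1].item())          # near 0: x_T is essentially pure noise
```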
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
tqdm
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
import argparse
import math
import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

from model import MNISTDiffusion
from utils import ExponentialMovingAverage


def create_mnist_dataloaders(batch_size, image_size=28, num_workers=4):
    preprocess = transforms.Compose([transforms.Resize(image_size),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5], [0.5])])  # [0,1] to [-1,1]

    train_dataset = MNIST(root="./mnist_data",
                          train=True,
                          download=True,
                          transform=preprocess)
    test_dataset = MNIST(root="./mnist_data",
                         train=False,
                         download=True,
                         transform=preprocess)

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers), \
           DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def parse_args():
    parser = argparse.ArgumentParser(description="Training MNISTDiffusion")
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--ckpt', type=str, help='checkpoint path to resume from', default='')
    parser.add_argument('--n_samples', type=int, help='number of images sampled after each training epoch', default=36)
    parser.add_argument('--model_base_dim', type=int, help='base dim of Unet', default=64)
    parser.add_argument('--timesteps', type=int, help='sampling steps of DDPM', default=1000)
    parser.add_argument('--model_ema_steps', type=int, help='ema model update interval (in steps)', default=10)
    parser.add_argument('--model_ema_decay', type=float, help='ema model decay', default=0.995)
    parser.add_argument('--log_freq', type=int, help='training log printing frequency (in steps)', default=10)
    parser.add_argument('--no_clip', action='store_true', help='disable x_0 clipping during sampling (plain reverse diffusion, which can yield unstable samples)')
    parser.add_argument('--cpu', action='store_true', help='train on cpu')

    args = parser.parse_args()

    return args


def main(args):
    device = "cpu" if args.cpu else "cuda"
    train_dataloader, test_dataloader = create_mnist_dataloaders(batch_size=args.batch_size, image_size=28)
    model = MNISTDiffusion(timesteps=args.timesteps,
                           image_size=28,
                           in_channels=1,
                           base_dim=args.model_base_dim,
                           dim_mults=[2, 4]).to(device)

    # torchvision ema setting
    # https://github.com/pytorch/vision/blob/main/references/classification/train.py#L317
    adjust = 1 * args.batch_size * args.model_ema_steps / args.epochs
    alpha = 1.0 - args.model_ema_decay
    alpha = min(1.0, alpha * adjust)
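    # With the defaults (batch_size=128, model_ema_steps=10, epochs=100) this gives
    # adjust = 12.8 and an effective per-update decay of 1 - 0.005 * 12.8 = 0.936:
    # the EMA is only updated every model_ema_steps steps, so each update is made
    # correspondingly faster-moving.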
    model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)

    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = OneCycleLR(optimizer, args.lr, total_steps=args.epochs * len(train_dataloader),
                           pct_start=0.25, anneal_strategy='cos')
    loss_fn = nn.MSELoss(reduction='mean')

    # load checkpoint
    if args.ckpt:
        ckpt = torch.load(args.ckpt, map_location=device)  # map_location so cuda checkpoints load on cpu too
        model_ema.load_state_dict(ckpt["model_ema"])
        model.load_state_dict(ckpt["model"])

    global_steps = 0
    for i in range(args.epochs):
        model.train()
        for j, (image, target) in enumerate(train_dataloader):
            noise = torch.randn_like(image).to(device)
            image = image.to(device)
            pred = model(image, noise)
            loss = loss_fn(pred, noise)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            if global_steps % args.model_ema_steps == 0:
                model_ema.update_parameters(model)
            global_steps += 1
            if j % args.log_freq == 0:
                print("Epoch[{}/{}],Step[{}/{}],loss:{:.5f},lr:{:.5f}".format(i + 1, args.epochs, j, len(train_dataloader),
                                                                              loss.detach().cpu().item(), scheduler.get_last_lr()[0]))
        ckpt = {"model": model.state_dict(),
                "model_ema": model_ema.state_dict()}

        os.makedirs("results", exist_ok=True)
        torch.save(ckpt, "results/steps_{:0>8}.pt".format(global_steps))

        model_ema.eval()
        samples = model_ema.module.sampling(args.n_samples, clipped_reverse_diffusion=not args.no_clip, device=device)
        save_image(samples, "results/steps_{:0>8}.png".format(global_steps), nrow=int(math.sqrt(args.n_samples)))


if __name__ == "__main__":
    args = parse_args()
    main(args)
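The loop above optimizes the simplified DDPM objective: sample a timestep and a noise tensor epsilon, form x_t, and regress the network output onto epsilon with MSE. A self-contained sketch of one such step using the `MNISTDiffusion` API above (the random tensor stands in for a normalized MNIST batch, and the model config matches the training defaults):
```python
import torch
import torch.nn as nn
from model import MNISTDiffusion

model = MNISTDiffusion(image_size=28, in_channels=1, base_dim=64, dim_mults=[2, 4])
x0 = torch.randn(8, 1, 28, 28)   # stand-in for a batch of MNIST images scaled to [-1, 1]
noise = torch.randn_like(x0)     # epsilon ~ N(0, I)
pred = model(x0, noise)          # model draws t internally, returns eps_theta(x_t, t)
loss = nn.MSELoss()(pred, noise) # simplified DDPM loss ||eps - eps_theta||^2
loss.backward()
```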
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        n, c, h, w = x.shape
        x = x.view(n, self.groups, c // self.groups, h, w)    # group
        x = x.transpose(1, 2).contiguous().view(n, -1, h, w)  # shuffle

        return x


class ConvBnSiLu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.module = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
                                    nn.BatchNorm2d(out_channels),
                                    nn.SiLU(inplace=True))

    def forward(self, x):
        return self.module(x)


class ResidualBottleneck(nn.Module):
    '''
    shufflenet_v2 basic unit (https://arxiv.org/pdf/1807.11164.pdf)
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.branch1 = nn.Sequential(nn.Conv2d(in_channels // 2, in_channels // 2, 3, 1, 1, groups=in_channels // 2),
                                     nn.BatchNorm2d(in_channels // 2),
                                     ConvBnSiLu(in_channels // 2, out_channels // 2, 1, 1, 0))
        self.branch2 = nn.Sequential(ConvBnSiLu(in_channels // 2, in_channels // 2, 1, 1, 0),
                                     nn.Conv2d(in_channels // 2, in_channels // 2, 3, 1, 1, groups=in_channels // 2),
                                     nn.BatchNorm2d(in_channels // 2),
                                     ConvBnSiLu(in_channels // 2, out_channels // 2, 1, 1, 0))
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        x = torch.cat([self.branch1(x1), self.branch2(x2)], dim=1)
        x = self.channel_shuffle(x)  # shuffle two branches

        return x


class ResidualDownsample(nn.Module):
    '''
    shufflenet_v2 unit for spatial downsampling (https://arxiv.org/pdf/1807.11164.pdf)
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, 2, 1, groups=in_channels),
                                     nn.BatchNorm2d(in_channels),
                                     ConvBnSiLu(in_channels, out_channels // 2, 1, 1, 0))
        self.branch2 = nn.Sequential(ConvBnSiLu(in_channels, out_channels // 2, 1, 1, 0),
                                     nn.Conv2d(out_channels // 2, out_channels // 2, 3, 2, 1, groups=out_channels // 2),
                                     nn.BatchNorm2d(out_channels // 2),
                                     ConvBnSiLu(out_channels // 2, out_channels // 2, 1, 1, 0))
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        x = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        x = self.channel_shuffle(x)  # shuffle two branches

        return x


class TimeMLP(nn.Module):
    '''
    naively introduce timestep information to feature maps with an MLP, plus a shortcut
    '''
    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                 nn.SiLU(),
                                 nn.Linear(hidden_dim, out_dim))
        self.act = nn.SiLU()

    def forward(self, x, t):
        t_emb = self.mlp(t).unsqueeze(-1).unsqueeze(-1)
        x = x + t_emb

        return self.act(x)


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        self.conv0 = nn.Sequential(*[ResidualBottleneck(in_channels, in_channels) for i in range(3)],
                                   ResidualBottleneck(in_channels, out_channels // 2))

        self.time_mlp = TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=out_channels, out_dim=out_channels // 2)
        self.conv1 = ResidualDownsample(out_channels // 2, out_channels)

    def forward(self, x, t=None):
        x_shortcut = self.conv0(x)
        if t is not None:
            x = self.time_mlp(x_shortcut, t)
        x = self.conv1(x)

        return [x, x_shortcut]


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv0 = nn.Sequential(*[ResidualBottleneck(in_channels, in_channels) for i in range(3)],
                                   ResidualBottleneck(in_channels, in_channels // 2))

        self.time_mlp = TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=in_channels, out_dim=in_channels // 2)
        self.conv1 = ResidualBottleneck(in_channels // 2, out_channels // 2)

    def forward(self, x, x_shortcut, t=None):
        x = self.upsample(x)
        x = torch.cat([x, x_shortcut], dim=1)
        x = self.conv0(x)
        if t is not None:
            x = self.time_mlp(x, t)
        x = self.conv1(x)

        return x


class Unet(nn.Module):
    '''
    simple unet design without attention
    '''
    def __init__(self, timesteps, time_embedding_dim, in_channels=3, out_channels=2, base_dim=32, dim_mults=[2, 4, 8, 16]):
        super().__init__()
        assert isinstance(dim_mults, (list, tuple))
        assert base_dim % 2 == 0

        channels = self._cal_channels(base_dim, dim_mults)

        self.init_conv = ConvBnSiLu(in_channels, base_dim, 3, 1, 1)
        self.time_embedding = nn.Embedding(timesteps, time_embedding_dim)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(c[0], c[1], time_embedding_dim) for c in channels])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(c[1], c[0], time_embedding_dim) for c in channels[::-1]])

        self.mid_block = nn.Sequential(*[ResidualBottleneck(channels[-1][1], channels[-1][1]) for i in range(2)],
                                       ResidualBottleneck(channels[-1][1], channels[-1][1] // 2))

        self.final_conv = nn.Conv2d(in_channels=channels[0][0] // 2, out_channels=out_channels, kernel_size=1)

    def forward(self, x, t=None):
        x = self.init_conv(x)
        if t is not None:
            t = self.time_embedding(t)
        encoder_shortcuts = []
        for encoder_block in self.encoder_blocks:
            x, x_shortcut = encoder_block(x, t)
            encoder_shortcuts.append(x_shortcut)
        x = self.mid_block(x)
        encoder_shortcuts.reverse()
        for decoder_block, shortcut in zip(self.decoder_blocks, encoder_shortcuts):
            x = decoder_block(x, shortcut, t)
        x = self.final_conv(x)

        return x

    def _cal_channels(self, base_dim, dim_mults):
        dims = [base_dim * x for x in dim_mults]
        dims.insert(0, base_dim)
        channels = []
        for i in range(len(dims) - 1):
            channels.append((dims[i], dims[i + 1]))  # (in_channel, out_channel)

        return channels


if __name__ == "__main__":
    x = torch.randn(3, 3, 224, 224)
    t = torch.randint(0, 1000, (3,))
    model = Unet(1000, 128)
    y = model(x, t)
    print(y.shape)
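For the configuration used in `train_mnist.py` (`base_dim=64`, `dim_mults=[2, 4]`), a quick sketch of what `_cal_channels` produces, i.e. the (in, out) channel pairs for the two encoder/decoder stages:
```python
base_dim, dim_mults = 64, [2, 4]
dims = [base_dim] + [base_dim * m for m in dim_mults]  # [64, 128, 256]
channels = list(zip(dims[:-1], dims[1:]))              # [(64, 128), (128, 256)]
print(channels)
```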
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch


# torchvision ema implementation
# https://github.com/pytorch/vision/blob/main/references/classification/utils.py#L159
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    ``torch.optim.swa_utils.AveragedModel`` is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)
--------------------------------------------------------------------------------
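A minimal usage sketch of `ExponentialMovingAverage` on a toy module; the model and hyperparameters here are illustrative, not part of the repository:
```python
import torch
import torch.nn as nn
from utils import ExponentialMovingAverage

net = nn.Linear(4, 2)
ema = ExponentialMovingAverage(net, decay=0.99, device="cpu")

opt = torch.optim.SGD(net.parameters(), lr=0.1)
for _ in range(10):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update_parameters(net)  # ema_param = 0.99 * ema_param + 0.01 * param

print(ema.module.weight)  # ema.module holds the smoothed copy of the model
```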