├── dataset ├── __init__.py └── toy.py ├── modules ├── __init__.py ├── common.py └── network.py ├── README.md └── scripts ├── train_toy.py └── train_mnist.py /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal code and simple experiments to play with Denoising Diffusion Probabilistic Models (DDPMs) 2 | 3 | All experiments have tensorboard visualizations for samples / train curves etc. 4 | 5 | 1. To run the toy data experiments: 6 | ``` 7 | python scripts/train_toy.py --dataset swissroll --save_path logs/swissroll 8 | ``` 9 | 10 | 2. To run the discrete mode collapse experiment: 11 | ``` 12 | python scripts/train_mnist.py --save_path logs/mnist_3 --n_stack 3 13 | ``` 14 | 15 | This requires the pretrained mnist classifier: 16 | ``` 17 | python scripts/train_mnist_classifier.py 18 | ``` 19 | 20 | 3. To run the CIFAR image generation experiment: 21 | ``` 22 | python scripts/train_cifar.py --save_path logs/cifar 23 | ``` 24 | 25 | 4. 
To run the CelebA image generation experiments: 26 | ``` 27 | python scripts/train_celeba.py --save_path logs/celeba 28 | ``` 29 | -------------------------------------------------------------------------------- /dataset/toy.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import make_swiss_roll 2 | import numpy as np 3 | import random 4 | 5 | 6 | def inf_train_gen(dataset, batch_size): 7 | if dataset == "25gaussians": 8 | dataset = [] 9 | for i in range(100000 // 25): 10 | for x in range(-2, 3): 11 | for y in range(-2, 3): 12 | point = np.random.randn(2) * 0.05 13 | point[0] += 2 * x 14 | point[1] += 2 * y 15 | dataset.append(point) 16 | dataset = np.array(dataset, dtype="float32") 17 | np.random.shuffle(dataset) 18 | dataset /= 2.828 # stdev 19 | while True: 20 | for i in range(len(dataset) // batch_size): 21 | yield dataset[i * batch_size : (i + 1) * batch_size] 22 | 23 | elif dataset == "swissroll": 24 | while True: 25 | data = make_swiss_roll(n_samples=batch_size, noise=0.25)[0] 26 | data = data.astype("float32")[:, [0, 2]] 27 | data /= 7.5 # stdev plus a little 28 | yield data / 2 29 | 30 | elif dataset == "8gaussians": 31 | scale = 2.0 32 | centers = [ 33 | (1, 0), 34 | (-1, 0), 35 | (0, 1), 36 | (0, -1), 37 | (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 38 | (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 39 | (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 40 | (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 41 | ] 42 | centers = [(scale * x, scale * y) for x, y in centers] 43 | while True: 44 | dataset = [] 45 | for i in range(batch_size): 46 | point = np.random.randn(2) * 0.02 47 | center = random.choice(centers) 48 | point[0] += center[0] 49 | point[1] += center[1] 50 | dataset.append(point) 51 | dataset = np.array(dataset, dtype="float32") 52 | dataset /= 1.414 # stdev 53 | yield dataset 54 | -------------------------------------------------------------------------------- /modules/common.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | 7 | def simple_schedule(beta_1, beta_T, T, device): 8 | betas = torch.linspace(beta_1, beta_T, T, device=device) 9 | alpha_bars = (1 - betas).cumprod(-1) 10 | 11 | return {"alphas": 1 - betas, "betas": betas, "alpha_bars": alpha_bars} 12 | 13 | 14 | def cosine_schedule(T, device, s=0.008): 15 | t = torch.arange(0, T + 1, dtype=torch.float32, device=device) 16 | f_t = torch.cos(0.5 * np.pi * (t / T + s) / (1 + s)) ** 2 17 | 18 | alpha_bars = f_t / f_t[0] 19 | betas = (1 - alpha_bars[1:] / alpha_bars[:-1]).clip(min=0, max=0.999) 20 | 21 | return {"alphas": 1 - betas, "betas": betas, "alpha_bars": alpha_bars[1:]} 22 | 23 | 24 | def forward_process(x_0, t, params): 25 | # Sample eps from N(0, 1) 26 | z_t = torch.randn_like(x_0) 27 | 28 | # Index alpha_bar_t 29 | shp = (x_0.size(0),) + (1,) * (x_0.dim() - 1) 30 | alpha_bar_t = params["alpha_bars"][t].view(*shp) 31 | 32 | # Compute x_t as interpolation of x_0 and x_T 33 | x_t = alpha_bar_t.pow(0.5) * x_0 + (1 - alpha_bar_t).pow(0.5) * z_t 34 | 35 | return z_t, x_t 36 | 37 | 38 | def reverse_process(z_t, x_t, t, params): 39 | # Index the diffusion params 40 | shp = (1,) * x_t.dim() 41 | alpha_t, beta_t, alpha_bar_t = [ 42 | params[key][t].view(*shp) for key in ["alphas", "betas", "alpha_bars"] 43 | ] 44 | 45 | # Sample noise from N(0, 1) 46 | eps = torch.randn_like(x_t) if t > 0 else 0 47 | 48 | # Compute the reverse process at step t by removing the predicted noise from x_t 49 | x_tm1 = ( 50 | alpha_t.pow(-0.5) * (x_t - z_t * beta_t * (1 - alpha_bar_t).pow(-0.5)) 51 | + beta_t * eps 52 | ) 53 | 54 | return x_tm1 55 | 56 | 57 | def train_loop(data, model, opt, params, args): 58 | # Sample diffusion time-step 59 | t = torch.randint(0, args.diffusion_steps, (data.size(0),), device=data.device) 60 | 61 | # Run the forward process 62 | 
z_t, x_t = forward_process(data, t, params) 63 | 64 | # Run reverse process network to predict the noise added 65 | pred_z_t = model(x_t, t) 66 | 67 | # Compute the simple diffusion loss 68 | loss_simple = F.mse_loss(pred_z_t, z_t) 69 | 70 | # Perform backward pass 71 | opt.zero_grad() 72 | loss_simple.backward() 73 | opt.step() 74 | 75 | return {"diffusion_loss": loss_simple.item()} 76 | 77 | 78 | def test_loop(x_shp, model, params, args): 79 | # Start from random noise in N(0, 1) 80 | x_t = torch.randn(*x_shp, device=args.device) 81 | 82 | # Iteratively run the reverse process to convert noise to data 83 | for t in tqdm(reversed(range(args.diffusion_steps))): 84 | tensor_t = torch.tensor([t] * x_t.size(0), device=args.device) 85 | z_t = model(x_t, tensor_t) 86 | x_t = reverse_process(z_t, x_t, t, params) 87 | 88 | return x_t 89 | -------------------------------------------------------------------------------- /modules/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | 6 | class FiLM1d(nn.Module): 7 | def __init__(self, hidden_dim, n_classes): 8 | super().__init__() 9 | self.emb = nn.Embedding(n_classes, hidden_dim * 2) 10 | self.norm = nn.BatchNorm1d(hidden_dim, affine=False) 11 | 12 | def forward(self, x, y): 13 | out = self.norm(x) 14 | alpha, beta = self.emb(y).chunk(2, dim=-1) 15 | return alpha + out * (1 + beta) 16 | 17 | 18 | class FiLM2d(nn.Module): 19 | def __init__(self, hidden_dim, n_classes): 20 | super().__init__() 21 | self.emb = nn.Embedding(n_classes, hidden_dim * 2) 22 | self.norm = nn.BatchNorm2d(hidden_dim, affine=False) 23 | 24 | def forward(self, x, y): 25 | out = self.norm(x) 26 | alpha, beta = self.emb(y)[..., None, None].chunk(2, dim=1) 27 | return alpha + out * (1 + beta) 28 | 29 | 30 | class MySequential(nn.Sequential): 31 | def forward(self, x, y): 32 | for module in self: 33 | if ( 34 | isinstance(module, 
FiLM1d) 35 | or isinstance(module, FiLM2d) 36 | or isinstance(module, ResBlock) 37 | ): 38 | x = module(x, y) 39 | else: 40 | x = module(x) 41 | return x 42 | 43 | 44 | class ToyNet(nn.Module): 45 | def __init__(self, input_dim, hidden_dim, n_classes): 46 | super().__init__() 47 | self.model = MySequential( 48 | nn.Linear(input_dim, hidden_dim), 49 | FiLM1d(hidden_dim, n_classes), 50 | nn.ReLU(), 51 | nn.Linear(hidden_dim, hidden_dim), 52 | FiLM1d(hidden_dim, n_classes), 53 | nn.ReLU(), 54 | nn.Linear(hidden_dim, input_dim), 55 | ) 56 | 57 | def forward(self, x, y): 58 | return self.model(x, y) 59 | 60 | 61 | class ResBlock(nn.Module): 62 | def __init__(self, hidden_dim, n_classes): 63 | super().__init__() 64 | self.conv1 = MySequential( 65 | FiLM2d(hidden_dim, n_classes), 66 | nn.ReLU(), 67 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), 68 | ) 69 | 70 | def forward(self, x, y): 71 | return x + self.conv1(x, y) 72 | 73 | 74 | class MnistNet(nn.Module): 75 | def __init__(self, input_dim, n_downsample, n_resblocks, ngf, n_classes): 76 | super().__init__() 77 | # Add initial layer 78 | model = [ 79 | nn.Conv2d(input_dim, ngf, kernel_size=7, padding=3), 80 | FiLM2d(ngf, n_classes), 81 | nn.ReLU(), 82 | ] 83 | 84 | # Add downsampling layers 85 | for i in range(n_downsample): 86 | mult = 2 ** i 87 | model += [ 88 | nn.Conv2d( 89 | ngf * mult, ngf * mult * 2, kernel_size=4, stride=2, padding=1 90 | ), 91 | FiLM2d(ngf * mult * 2, n_classes), 92 | nn.ReLU(), 93 | ] 94 | 95 | # Add ResNet layers 96 | mult = 2 ** n_downsample 97 | for i in range(n_resblocks): 98 | model += [ResBlock(ngf * mult, n_classes)] 99 | 100 | # Add upsampling layers 101 | for i in range(n_downsample): 102 | mult = 2 ** (n_downsample - i) 103 | model += [ 104 | nn.ConvTranspose2d( 105 | ngf * mult, ngf * mult // 2, kernel_size=4, stride=2, padding=1 106 | ), 107 | FiLM2d(ngf * mult // 2, n_classes), 108 | nn.ReLU(), 109 | ] 110 | 111 | # Add output layers 112 | model += [nn.Conv2d(ngf, 
input_dim, kernel_size=7, padding=3), nn.Tanh()] 113 | 114 | # Store as sequential layer 115 | self.model = MySequential(*model) 116 | 117 | def forward(self, x, y): 118 | return self.model(x, y) 119 | -------------------------------------------------------------------------------- /scripts/train_toy.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pathlib import Path 3 | import argparse 4 | import time 5 | import yaml 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard.writer import SummaryWriter 11 | 12 | from dataset.toy import inf_train_gen 13 | from modules.network import ToyNet 14 | from modules.common import simple_schedule, cosine_schedule, train_loop, test_loop 15 | 16 | 17 | def make_toy_figure(x): 18 | fig = plt.Figure() 19 | ax = fig.add_subplot(111) 20 | ax.scatter(x[:, 0], x[:, 1]) 21 | return fig 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--save_path", required=True) 27 | parser.add_argument("--dataset", required=True) 28 | 29 | parser.add_argument("--input_dim", type=int, default=2) 30 | parser.add_argument("--hidden_dim", type=int, default=256) 31 | parser.add_argument("--diffusion_steps", type=int, default=20) 32 | parser.add_argument("--beta_1", type=int, default=1e-4) 33 | parser.add_argument("--beta_T", type=int, default=0.02) 34 | parser.add_argument("--device", type=str, default="cpu") 35 | parser.add_argument("--schedule", type=str, default="simple") 36 | 37 | parser.add_argument("--batch_size", type=int, default=512) 38 | parser.add_argument("--iters", type=int, default=100000) 39 | parser.add_argument("--n_test_points", type=int, default=768) 40 | parser.add_argument("--print_interval", type=int, default=50) 41 | parser.add_argument("--plot_interval", type=int, default=250) 42 | parser.add_argument("--save_interval", type=int, default=1000) 43 
| 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | 51 | # set the seed 52 | seed = 111 53 | np.random.seed(seed) 54 | torch.manual_seed(seed) 55 | 56 | root = Path(args.save_path) 57 | root.mkdir(parents=True, exist_ok=True) 58 | (root / "imgs").mkdir(exist_ok=True) 59 | 60 | ####################### 61 | # Create data loaders # 62 | ####################### 63 | data_itr = inf_train_gen(args.dataset, args.batch_size) 64 | 65 | #################################### 66 | # Dump arguments and create logger # 67 | #################################### 68 | writer = SummaryWriter(str(root)) 69 | with open(root / "args.yml", "w") as f: 70 | yaml.dump(args, f) 71 | 72 | ######################### 73 | # Create PyTorch Models # 74 | ######################### 75 | model = ToyNet(args.input_dim, args.hidden_dim, args.diffusion_steps).to( 76 | args.device 77 | ) 78 | opt = torch.optim.Adam(model.parameters(), lr=3e-4) 79 | 80 | ###################### 81 | # Dump Original Data # 82 | ###################### 83 | orig_data = inf_train_gen(args.dataset, args.n_test_points).__next__() 84 | fig = make_toy_figure(orig_data) 85 | writer.add_figure("original", fig, 0) 86 | 87 | ############################### 88 | # Create diffusion parameters # 89 | ############################### 90 | if args.schedule == "simple": 91 | params = simple_schedule( 92 | args.beta_1, args.beta_T, args.diffusion_steps, args.device 93 | ) 94 | elif args.schedule == "cosine": 95 | params = cosine_schedule(args.diffusion_steps, args.device) 96 | 97 | ################################ 98 | 99 | costs = { 100 | "diffusion_loss": [], 101 | } 102 | start = time.time() 103 | for iters in range(args.iters): 104 | model.train() 105 | 106 | # Sampe data 107 | data = torch.from_numpy(data_itr.__next__()).to(args.device) 108 | 109 | # Run train loop 110 | metrics = train_loop(data, model, opt, params, args) 111 | 112 | # Update tensorboard 113 | for key, val in 
metrics.items(): 114 | writer.add_scalar(key, val, iters) 115 | costs[key].append(val) 116 | 117 | if iters % args.print_interval == 0: 118 | mean_costs = [f"{key}: {np.mean(val):.3f}" for key, val in costs.items()] 119 | log = ( 120 | f"Steps {iters} | " 121 | f"ms/batch {1e3 * (time.time() - start) / args.print_interval:5.2f} | " 122 | f"loss {mean_costs}" 123 | ) 124 | print(log) 125 | 126 | costs = {key: [] for key in costs.keys()} 127 | start = time.time() 128 | 129 | if iters % args.plot_interval == 0: 130 | model.eval() 131 | st = time.time() 132 | print("#" * 30) 133 | print("Generating samples") 134 | 135 | with torch.no_grad(): 136 | samples = test_loop(orig_data.shape, model, params, args) 137 | 138 | fig = make_toy_figure(samples.cpu()) 139 | writer.add_figure("generated", fig, iters) 140 | fig.savefig(root / "imgs" / ("sample_%05d.png" % iters)) 141 | 142 | print(f"Completed in {time.time() - st:.2f}") 143 | print("#" * 30) 144 | 145 | if iters % args.save_interval == 0: 146 | torch.save(model.state_dict(), root / "model.pt") 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /scripts/train_mnist.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pathlib import Path 3 | import argparse 4 | import time 5 | import yaml 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torchvision.datasets import MNIST 11 | from torchvision.utils import make_grid 12 | import torchvision.transforms as tf 13 | from torch.utils.tensorboard.writer import SummaryWriter 14 | 15 | from modules.network import MnistNet 16 | from modules.common import simple_schedule, cosine_schedule, train_loop, test_loop 17 | 18 | 19 | def make_figure(x): 20 | fig = plt.Figure() 21 | ax = fig.add_subplot(111) 22 | x = make_grid(x, nrow=int(x.size(0) ** 0.5)) 23 | ax.imshow(x.permute(1, 
2, 0)) 24 | return fig 25 | 26 | 27 | def make_infinite_iterator(loader): 28 | while True: 29 | for data in loader: 30 | yield data 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--save_path", required=True) 36 | 37 | parser.add_argument("--input_dim", type=int, default=1) 38 | parser.add_argument("--n_downsample", type=int, default=2) 39 | parser.add_argument("--n_resblocks", type=int, default=5) 40 | parser.add_argument("--ngf", type=int, default=16) 41 | 42 | parser.add_argument("--diffusion_steps", type=int, default=20) 43 | parser.add_argument("--beta_1", type=int, default=1e-4) 44 | parser.add_argument("--beta_T", type=int, default=0.02) 45 | parser.add_argument("--device", type=str, default="cpu") 46 | parser.add_argument("--schedule", type=str, default="simple") 47 | 48 | parser.add_argument("--batch_size", type=int, default=8) 49 | parser.add_argument("--iters", type=int, default=100000) 50 | parser.add_argument("--n_test_points", type=int, default=16) 51 | parser.add_argument("--print_interval", type=int, default=50) 52 | parser.add_argument("--plot_interval", type=int, default=250) 53 | parser.add_argument("--save_interval", type=int, default=1000) 54 | 55 | args = parser.parse_args() 56 | return args 57 | 58 | 59 | def main(): 60 | args = parse_args() 61 | 62 | # set the seed 63 | seed = 111 64 | np.random.seed(seed) 65 | torch.manual_seed(seed) 66 | 67 | root = Path(args.save_path) 68 | root.mkdir(parents=True, exist_ok=True) 69 | (root / "imgs").mkdir(exist_ok=True) 70 | 71 | ####################### 72 | # Create data loaders # 73 | ####################### 74 | train_loader = torch.utils.data.DataLoader( 75 | MNIST( 76 | "./files/", 77 | train=True, 78 | download=True, 79 | transform=tf.Compose( 80 | [ 81 | tf.ToTensor(), 82 | tf.Normalize((0.1307,), (0.3081,)), 83 | ] 84 | ), 85 | ), 86 | batch_size=args.batch_size, 87 | shuffle=True, 88 | ) 89 | data_itr = make_infinite_iterator(train_loader) 90 | 
91 | test_loader = torch.utils.data.DataLoader( 92 | MNIST( 93 | "./files/", 94 | train=False, 95 | download=True, 96 | transform=tf.Compose( 97 | [ 98 | tf.ToTensor(), 99 | tf.Normalize((0.1307,), (0.3081,)), 100 | ] 101 | ), 102 | ), 103 | batch_size=args.n_test_points, 104 | shuffle=True, 105 | ) 106 | 107 | #################################### 108 | # Dump arguments and create logger # 109 | #################################### 110 | writer = SummaryWriter(str(root)) 111 | with open(root / "args.yml", "w") as f: 112 | yaml.dump(args, f) 113 | 114 | ######################### 115 | # Create PyTorch Models # 116 | ######################### 117 | model = MnistNet( 118 | args.input_dim, 119 | args.n_downsample, 120 | args.n_resblocks, 121 | args.ngf, 122 | args.diffusion_steps, 123 | ).to(args.device) 124 | opt = torch.optim.Adam(model.parameters(), lr=3e-4) 125 | 126 | print(model) 127 | 128 | ###################### 129 | # Dump Original Data # 130 | ###################### 131 | orig_data = next(iter(test_loader))[0] 132 | fig = make_figure(orig_data) 133 | writer.add_figure("original", fig, 0) 134 | 135 | ############################### 136 | # Create diffusion parameters # 137 | ############################### 138 | if args.schedule == "simple": 139 | params = simple_schedule( 140 | args.beta_1, args.beta_T, args.diffusion_steps, args.device 141 | ) 142 | elif args.schedule == "cosine": 143 | params = cosine_schedule(args.diffusion_steps, args.device) 144 | 145 | ################################ 146 | 147 | costs = { 148 | "diffusion_loss": [], 149 | } 150 | start = time.time() 151 | for iters in range(args.iters): 152 | model.train() 153 | 154 | # Sampe data 155 | data = next(data_itr)[0].to(args.device) 156 | 157 | # Run train loop 158 | metrics = train_loop(data, model, opt, params, args) 159 | 160 | # Update tensorboard 161 | for key, val in metrics.items(): 162 | writer.add_scalar(key, val, iters) 163 | costs[key].append(val) 164 | 165 | if iters % 
args.print_interval == 0: 166 | mean_costs = [f"{key}: {np.mean(val):.3f}" for key, val in costs.items()] 167 | log = ( 168 | f"Steps {iters} | " 169 | f"ms/batch {1e3 * (time.time() - start) / args.print_interval:5.2f} | " 170 | f"loss {mean_costs}" 171 | ) 172 | print(log) 173 | 174 | costs = {key: [] for key in costs.keys()} 175 | start = time.time() 176 | 177 | if iters % args.plot_interval == 0: 178 | model.eval() 179 | st = time.time() 180 | print("#" * 30) 181 | print("Generating samples") 182 | 183 | with torch.no_grad(): 184 | samples = test_loop(orig_data.shape, model, params, args) 185 | 186 | fig = make_figure(samples) 187 | writer.add_figure("generated", fig, iters) 188 | fig.savefig(root / "imgs" / ("sample_%05d.png" % iters)) 189 | 190 | print(f"Completed in {time.time() - st:.2f}") 191 | print("#" * 30) 192 | 193 | if iters % args.save_interval == 0: 194 | torch.save(model.state_dict(), root / "model.pt") 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | --------------------------------------------------------------------------------