├── ddpm
│   ├── __init__.py
│   ├── utils.py
│   ├── ema.py
│   ├── script_utils.py
│   ├── diffusion.py
│   └── unet.py
├── .gitignore
├── resources
│   ├── samples_linear_200.png
│   ├── diffusion_models_report.pdf
│   ├── diffusion_sequence_mnist.gif
│   └── diffusion_models_talk_slides.pdf
├── setup.py
├── scripts
│   ├── sample_images.py
│   └── train_cifar.py
└── README.md
/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ddpm_env/ 2 | .idea/ 3 | .vscode/ 4 | *.egg-info/ -------------------------------------------------------------------------------- /resources/samples_linear_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/samples_linear_200.png -------------------------------------------------------------------------------- /resources/diffusion_models_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_models_report.pdf -------------------------------------------------------------------------------- /resources/diffusion_sequence_mnist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_sequence_mnist.gif -------------------------------------------------------------------------------- /resources/diffusion_models_talk_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_models_talk_slides.pdf -------------------------------------------------------------------------------- /ddpm/utils.py: -------------------------------------------------------------------------------- 1 | def extract(a, t, x_shape): 2 | b, *_ = t.shape 3 | out = a.gather(-1, t) 4 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="ddpm", 5 | py_modules=["ddpm"], 6 | install_requires=["torch", "torchvision", "einops", "wandb", "joblib"], 7 | ) -------------------------------------------------------------------------------- /ddpm/ema.py: -------------------------------------------------------------------------------- 1 | class EMA(): 2 | def __init__(self, decay): 3 | self.decay = decay 4 | 5 | def update_average(self, old, new): 6 | if old is None: 7 | return new 8 | return old * self.decay + (1 - self.decay) * new 9 | 10 | def update_model_average(self, ema_model, current_model): 11 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): 12 | old, new = ema_params.data, current_params.data 13 | ema_params.data = self.update_average(old, new) -------------------------------------------------------------------------------- /scripts/sample_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision 4 | 5 | from ddpm import script_utils 6 | 7 | 8 | def main(): 9 | args = create_argparser().parse_args() 10 | device = args.device 11 | 12 | try: 13 |
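# Rebuild the model and diffusion wrapper from the parsed arguments, then load the trained weights.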
diffusion = script_utils.get_diffusion_from_args(args).to(device) 14 | diffusion.load_state_dict(torch.load(args.model_path)) 15 | 16 | if args.use_labels: 17 | for label in range(10): 18 | y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label 19 | samples = diffusion.sample(args.num_images // 10, device, y=y) 20 | 21 | for image_id in range(len(samples)): 22 | image = ((samples[image_id] + 1) / 2).clip(0, 1) 23 | torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png") 24 | else: 25 | samples = diffusion.sample(args.num_images, device) 26 | 27 | for image_id in range(len(samples)): 28 | image = ((samples[image_id] + 1) / 2).clip(0, 1) 29 | torchvision.utils.save_image(image, f"{args.save_dir}/{image_id}.png") 30 | except KeyboardInterrupt: 31 | print("Keyboard interrupt, generation finished early") 32 | 33 | 34 | def create_argparser(): 35 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 36 | defaults = dict(num_images=10000, device=device) 37 | defaults.update(script_utils.diffusion_defaults()) 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--model_path", type=str) 41 | parser.add_argument("--save_dir", type=str) 42 | script_utils.add_dict_to_argparser(parser, defaults) 43 | return parser 44 | 45 | 46 | if __name__ == "__main__": 47 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Denoising Diffusion Probabilistic Models 2 | 3 | An implementation of Denoising Diffusion Probabilistic Models for image generation written in PyTorch. This roughly follows the original code by Ho et al. Unlike their implementation, however, my model allows for class conditioning through bias in residual blocks. 4 | 5 | ## Experiments 6 | 7 | I have trained the model on the MNIST and CIFAR-10 datasets. The model seemed to converge well on the MNIST dataset, producing realistic samples. However, I have not yet reproduced the CIFAR-10 sample quality that Ho et al. report in their paper. Here are the samples generated with a linear schedule after 2000 epochs: 8 | 9 | ![Samples after 2000 epochs](resources/samples_linear_200.png) 10 | 11 | Here is a sample of a diffusion sequence on MNIST: 12 | 13 |

14 | ![Diffusion sequence on MNIST](resources/diffusion_sequence_mnist.gif) 15 |
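For reference, a sequence like the one above can be generated with `GaussianDiffusion.sample_diffusion_sequence`. A minimal sketch, assuming the `ddpm` package is installed and a trained checkpoint is available (the checkpoint path is a placeholder, and the hyperparameters reuse the repository defaults plus the linear-schedule endpoints from `scripts/train_cifar.py`):

```python
import argparse
import torch
import torchvision

from ddpm import script_utils

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Namespace mirroring the scripts' default hyperparameters (schedule endpoints as in train_cifar.py).
args = argparse.Namespace(schedule_low=1e-4, schedule_high=0.02, **script_utils.diffusion_defaults())

diffusion = script_utils.get_diffusion_from_args(args).to(device)
diffusion.load_state_dict(torch.load("path/to/model.pth", map_location=device))  # placeholder path

# One tensor per reverse step, from pure noise down to the final samples.
sequence = diffusion.sample_diffusion_sequence(10, device)
frames = [((x + 1) / 2).clip(0, 1) for x in sequence]
torchvision.utils.save_image(frames[-1], "samples.png", nrow=5)
```

Each entry of `frames` is a batch of images in [0, 1], so the full list can be written out with any GIF writer to produce an animation like the one above.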
16 | 17 | 18 | ## Resources 19 | 20 | I gave a talk about diffusion models, NCSNs, and their applications in audio generation. The [slides are available here](resources/diffusion_models_talk_slides.pdf). 21 | 22 | I also compiled a report with what are, in my opinion, the most crucial findings on the topic of denoising diffusion models. It is also [available in this repository](resources/diffusion_models_report.pdf). 23 | 24 | 25 | ## Acknowledgements 26 | 27 | I used [Phil Wang's implementation](https://github.com/lucidrains/denoising-diffusion-pytorch) and [the official Tensorflow repo](https://github.com/hojonathanho/diffusion) as a reference for my work. 28 | 29 | ## Citations 30 | 31 | ```bibtex 32 | @misc{ho2020denoising, 33 | title = {Denoising Diffusion Probabilistic Models}, 34 | author = {Jonathan Ho and Ajay Jain and Pieter Abbeel}, 35 | year = {2020}, 36 | eprint = {2006.11239}, 37 | archivePrefix = {arXiv}, 38 | primaryClass = {cs.LG} 39 | } 40 | ``` 41 | 42 | ```bibtex 43 | @inproceedings{anonymous2021improved, 44 | title = {Improved Denoising Diffusion Probabilistic Models}, 45 | author = {Anonymous}, 46 | booktitle = {Submitted to International Conference on Learning Representations}, 47 | year = {2021}, 48 | url = {https://openreview.net/forum?id=-NEXDKk8gZ}, 49 | note = {under review} 50 | } 51 | -------------------------------------------------------------------------------- /ddpm/script_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchvision 3 | import torch.nn.functional as F 4 | 5 | from .unet import UNet 6 | from .diffusion import ( 7 | GaussianDiffusion, 8 | generate_linear_schedule, 9 | generate_cosine_schedule, 10 | ) 11 | 12 | 13 | def cycle(dl): 14 | """ 15 | https://github.com/lucidrains/denoising-diffusion-pytorch/ 16 | """ 17 | while True: 18 | for data in dl: 19 | yield data 20 | 21 | def get_transform(): 22 | class RescaleChannels(object): 23 | def __call__(self, sample): 24 | return 2 * sample - 1 25 | 26 | return torchvision.transforms.Compose([ 27 | torchvision.transforms.ToTensor(), 28 | RescaleChannels(), 29 | ]) 30 | 31 | 32 | def str2bool(v): 33 | """ 34 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 35 | """ 36 | if isinstance(v, bool): 37 | return v 38 | if v.lower() in ("yes", "true", "t", "y", "1"): 39 | return True 40 | elif v.lower() in ("no", "false", "f", "n", "0"): 41 | return False 42 | else: 43 | raise argparse.ArgumentTypeError("boolean value expected") 44 | 45 | 46 | def add_dict_to_argparser(parser, default_dict): 47 | """ 48 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py 49 | """ 50 | for k, v in default_dict.items(): 51 | v_type = type(v) 52 | if v is None: 53 | v_type = str 54 | elif isinstance(v, bool): 55 | v_type = str2bool 56 | parser.add_argument(f"--{k}", default=v, type=v_type) 57 | 58 | 59 | def diffusion_defaults(): 60 | defaults = dict( 61 | num_timesteps=1000, 62 | schedule="linear", 63 | loss_type="l2", 64 | use_labels=False, 65 | 66 | base_channels=128, 67 | channel_mults=(1, 2, 2, 2), 68 | num_res_blocks=2, 69 | time_emb_dim=128 * 4, 70 | norm="gn", 71 | dropout=0.1, 72 | activation="silu", 73 | attention_resolutions=(1,), 74 | 75 | ema_decay=0.9999, 76 | ema_update_rate=1, 77 | ) 78 | 79 | return defaults 80 | 81 | 82 | def get_diffusion_from_args(args): 83 | activations = { 84 | "relu": F.relu, 85 | "mish": F.mish, 86 | "silu": F.silu, 87 | } 88 | 89 | model = 
UNet( 90 | img_channels=3, 91 | 92 | base_channels=args.base_channels, 93 | channel_mults=args.channel_mults, 94 | time_emb_dim=args.time_emb_dim, 95 | norm=args.norm, 96 | dropout=args.dropout, 97 | activation=activations[args.activation], 98 | attention_resolutions=args.attention_resolutions, 99 | 100 | num_classes=None if not args.use_labels else 10, 101 | initial_pad=0, 102 | ) 103 | 104 | if args.schedule == "cosine": 105 | betas = generate_cosine_schedule(args.num_timesteps) 106 | else: 107 | betas = generate_linear_schedule( 108 | args.num_timesteps, 109 | args.schedule_low * 1000 / args.num_timesteps, 110 | args.schedule_high * 1000 / args.num_timesteps, 111 | ) 112 | 113 | diffusion = GaussianDiffusion( 114 | model, (32, 32), 3, 10, 115 | betas, 116 | ema_decay=args.ema_decay, 117 | ema_update_rate=args.ema_update_rate, 118 | ema_start=2000, 119 | loss_type=args.loss_type, 120 | ) 121 | 122 | return diffusion -------------------------------------------------------------------------------- /scripts/train_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import torch 4 | import wandb 5 | 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets 8 | from ddpm import script_utils 9 | 10 | 11 | def main(): 12 | args = create_argparser().parse_args() 13 | device = args.device 14 | 15 | try: 16 | diffusion = script_utils.get_diffusion_from_args(args).to(device) 17 | optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate) 18 | 19 | if args.model_checkpoint is not None: 20 | diffusion.load_state_dict(torch.load(args.model_checkpoint)) 21 | if args.optim_checkpoint is not None: 22 | optimizer.load_state_dict(torch.load(args.optim_checkpoint)) 23 | 24 | if args.log_to_wandb: 25 | if args.project_name is None: 26 | raise ValueError("args.log_to_wandb set to True but args.project_name is None") 27 | 28 | run = wandb.init( 29 | project=args.project_name, 30 | entity='treaptofun', 31 | config=vars(args), 32 | name=args.run_name, 33 | ) 34 | wandb.watch(diffusion) 35 | 36 | batch_size = args.batch_size 37 | 38 | train_dataset = datasets.CIFAR10( 39 | root='./cifar_train', 40 | train=True, 41 | download=True, 42 | transform=script_utils.get_transform(), 43 | ) 44 | 45 | test_dataset = datasets.CIFAR10( 46 | root='./cifar_test', 47 | train=False, 48 | download=True, 49 | transform=script_utils.get_transform(), 50 | ) 51 | 52 | train_loader = script_utils.cycle(DataLoader( 53 | train_dataset, 54 | batch_size=batch_size, 55 | shuffle=True, 56 | drop_last=True, 57 | num_workers=2, 58 | )) 59 | test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2) 60 | 61 | acc_train_loss = 0 62 | 63 | for iteration in range(1, args.iterations + 1): 64 | diffusion.train() 65 | 66 | x, y = next(train_loader) 67 | x = x.to(device) 68 | y = y.to(device) 69 | 70 | if args.use_labels: 71 | loss = diffusion(x, y) 72 | else: 73 | loss = diffusion(x) 74 | 75 | acc_train_loss += loss.item() 76 | 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | diffusion.update_ema() 82 | 83 | if iteration % args.log_rate == 0: 84 | test_loss = 0 85 | with torch.no_grad(): 86 | diffusion.eval() 87 | for x, y in test_loader: 88 | x = x.to(device) 89 | y = y.to(device) 90 | 91 | if args.use_labels: 92 | loss = diffusion(x, y) 93 | else: 94 | loss = diffusion(x) 95 | 96 | test_loss += loss.item() 97 | 98 | if args.use_labels: 99 | samples = 
diffusion.sample(10, device, y=torch.arange(10, device=device)) 100 | else: 101 | samples = diffusion.sample(10, device) 102 | 103 | samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy() 104 | 105 | test_loss /= len(test_loader) 106 | acc_train_loss /= args.log_rate 107 | 108 | wandb.log({ 109 | "test_loss": test_loss, 110 | "train_loss": acc_train_loss, 111 | "samples": [wandb.Image(sample) for sample in samples], 112 | }) 113 | 114 | acc_train_loss = 0 115 | 116 | if iteration % args.checkpoint_rate == 0: 117 | model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth" 118 | optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth" 119 | 120 | torch.save(diffusion.state_dict(), model_filename) 121 | torch.save(optimizer.state_dict(), optim_filename) 122 | 123 | if args.log_to_wandb: 124 | run.finish() 125 | except KeyboardInterrupt: 126 | if args.log_to_wandb: 127 | run.finish() 128 | print("Keyboard interrupt, run finished early") 129 | 130 | 131 | def create_argparser(): 132 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 133 | run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M") 134 | defaults = dict( 135 | learning_rate=2e-4, 136 | batch_size=128, 137 | iterations=800000, 138 | 139 | log_to_wandb=True, 140 | log_rate=1000, 141 | checkpoint_rate=1000, 142 | log_dir="~/ddpm_logs", 143 | project_name=None, 144 | run_name=run_name, 145 | 146 | model_checkpoint=None, 147 | optim_checkpoint=None, 148 | 149 | schedule_low=1e-4, 150 | schedule_high=0.02, 151 | 152 | device=device, 153 | ) 154 | defaults.update(script_utils.diffusion_defaults()) 155 | 156 | parser = argparse.ArgumentParser() 157 | script_utils.add_dict_to_argparser(parser, defaults) 158 | return parser 159 | 160 | 161 | if __name__ == "__main__": 162 | main() -------------------------------------------------------------------------------- /ddpm/diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | from copy import deepcopy 8 | 9 | from .ema import EMA 10 | from .utils import extract 11 | 12 | class GaussianDiffusion(nn.Module): 13 | __doc__ = r"""Gaussian Diffusion model. Forwarding through the module returns diffusion reversal scalar loss tensor. 
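The loss is the L1 or L2 distance between the Gaussian noise injected at a uniformly sampled timestep and the noise predicted by the model.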
14 | 15 | Input: 16 | x: tensor of shape (N, img_channels, *img_size) 17 | y: tensor of shape (N) 18 | Output: 19 | scalar loss tensor 20 | Args: 21 | model (nn.Module): model which estimates diffusion noise 22 | img_size (tuple): image size tuple (H, W) 23 | img_channels (int): number of image channels 24 | betas (np.ndarray): numpy array of diffusion betas 25 | loss_type (string): loss type, "l1" or "l2" 26 | ema_decay (float): model weights exponential moving average decay 27 | ema_start (int): number of steps before EMA 28 | ema_update_rate (int): number of steps before each EMA update 29 | """ 30 | def __init__( 31 | self, 32 | model, 33 | img_size, 34 | img_channels, 35 | num_classes, 36 | betas, 37 | loss_type="l2", 38 | ema_decay=0.9999, 39 | ema_start=5000, 40 | ema_update_rate=1, 41 | ): 42 | super().__init__() 43 | 44 | self.model = model 45 | self.ema_model = deepcopy(model) 46 | 47 | self.ema = EMA(ema_decay) 48 | self.ema_decay = ema_decay 49 | self.ema_start = ema_start 50 | self.ema_update_rate = ema_update_rate 51 | self.step = 0 52 | 53 | self.img_size = img_size 54 | self.img_channels = img_channels 55 | self.num_classes = num_classes 56 | 57 | if loss_type not in ["l1", "l2"]: 58 | raise ValueError("__init__() got unknown loss type") 59 | 60 | self.loss_type = loss_type 61 | self.num_timesteps = len(betas) 62 | 63 | alphas = 1.0 - betas 64 | alphas_cumprod = np.cumprod(alphas) 65 | 66 | to_torch = partial(torch.tensor, dtype=torch.float32) 67 | 68 | self.register_buffer("betas", to_torch(betas)) 69 | self.register_buffer("alphas", to_torch(alphas)) 70 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 71 | 72 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 73 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod))) 74 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas))) 75 | 76 | self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod))) 77 | self.register_buffer("sigma", to_torch(np.sqrt(betas))) 78 | 79 | def update_ema(self): 80 | self.step += 1 81 | if self.step % self.ema_update_rate == 0: 82 | if self.step < self.ema_start: 83 | self.ema_model.load_state_dict(self.model.state_dict()) 84 | else: 85 | self.ema.update_model_average(self.ema_model, self.model) 86 | 87 | @torch.no_grad() 88 | def remove_noise(self, x, t, y, use_ema=True): 89 | if use_ema: 90 | return ( 91 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) * 92 | extract(self.reciprocal_sqrt_alphas, t, x.shape) 93 | ) 94 | else: 95 | return ( 96 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) * 97 | extract(self.reciprocal_sqrt_alphas, t, x.shape) 98 | ) 99 | 100 | @torch.no_grad() 101 | def sample(self, batch_size, device, y=None, use_ema=True): 102 | if y is not None and batch_size != len(y): 103 | raise ValueError("sample batch size different from length of given y") 104 | 105 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device) 106 | 107 | for t in range(self.num_timesteps - 1, -1, -1): 108 | t_batch = torch.tensor([t], device=device).repeat(batch_size) 109 | x = self.remove_noise(x, t_batch, y, use_ema) 110 | 111 | if t > 0: 112 | x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x) 113 | 114 | return x.cpu().detach() 115 | 116 | @torch.no_grad() 117 | def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True): 118 | if y is not None and 
batch_size != len(y): 119 | raise ValueError("sample batch size different from length of given y") 120 | 121 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device) 122 | diffusion_sequence = [x.cpu().detach()] 123 | 124 | for t in range(self.num_timesteps - 1, -1, -1): 125 | t_batch = torch.tensor([t], device=device).repeat(batch_size) 126 | x = self.remove_noise(x, t_batch, y, use_ema) 127 | 128 | if t > 0: 129 | x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x) 130 | 131 | diffusion_sequence.append(x.cpu().detach()) 132 | 133 | return diffusion_sequence 134 | 135 | def perturb_x(self, x, t, noise): 136 | return ( 137 | extract(self.sqrt_alphas_cumprod, t, x.shape) * x + 138 | extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise 139 | ) 140 | 141 | def get_losses(self, x, t, y): 142 | noise = torch.randn_like(x) 143 | 144 | perturbed_x = self.perturb_x(x, t, noise) 145 | estimated_noise = self.model(perturbed_x, t, y) 146 | 147 | if self.loss_type == "l1": 148 | loss = F.l1_loss(estimated_noise, noise) 149 | elif self.loss_type == "l2": 150 | loss = F.mse_loss(estimated_noise, noise) 151 | 152 | return loss 153 | 154 | def forward(self, x, y=None): 155 | b, c, h, w = x.shape 156 | device = x.device 157 | 158 | if h != self.img_size[0]: 159 | raise ValueError("image height does not match diffusion parameters") 160 | if w != self.img_size[0]: 161 | raise ValueError("image width does not match diffusion parameters") 162 | 163 | t = torch.randint(0, self.num_timesteps, (b,), device=device) 164 | return self.get_losses(x, t, y) 165 | 166 | 167 | def generate_cosine_schedule(T, s=0.008): 168 | def f(t, T): 169 | return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2 170 | 171 | alphas = [] 172 | f0 = f(0, T) 173 | 174 | for t in range(T + 1): 175 | alphas.append(f(t, T) / f0) 176 | 177 | betas = [] 178 | 179 | for t in range(1, T + 1): 180 | betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999)) 181 | 182 | return np.array(betas) 183 | 184 | 185 | def generate_linear_schedule(T, low, high): 186 | return np.linspace(low, high, T) -------------------------------------------------------------------------------- /ddpm/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn.modules.normalization import GroupNorm 7 | 8 | 9 | def get_norm(norm, num_channels, num_groups): 10 | if norm == "in": 11 | return nn.InstanceNorm2d(num_channels, affine=True) 12 | elif norm == "bn": 13 | return nn.BatchNorm2d(num_channels) 14 | elif norm == "gn": 15 | return nn.GroupNorm(num_groups, num_channels) 16 | elif norm is None: 17 | return nn.Identity() 18 | else: 19 | raise ValueError("unknown normalization type") 20 | 21 | 22 | class PositionalEmbedding(nn.Module): 23 | __doc__ = r"""Computes a positional embedding of timesteps. 24 | 25 | Input: 26 | x: tensor of shape (N) 27 | Output: 28 | tensor of shape (N, dim) 29 | Args: 30 | dim (int): embedding dimension 31 | scale (float): linear scale to be applied to timesteps. 
Default: 1.0 32 | """ 33 | 34 | def __init__(self, dim, scale=1.0): 35 | super().__init__() 36 | assert dim % 2 == 0 37 | self.dim = dim 38 | self.scale = scale 39 | 40 | def forward(self, x): 41 | device = x.device 42 | half_dim = self.dim // 2 43 | emb = math.log(10000) / half_dim 44 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 45 | emb = torch.outer(x * self.scale, emb) 46 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 47 | return emb 48 | 49 | 50 | class Downsample(nn.Module): 51 | __doc__ = r"""Downsamples a given tensor by a factor of 2. Uses strided convolution. Assumes even height and width. 52 | 53 | Input: 54 | x: tensor of shape (N, in_channels, H, W) 55 | time_emb: ignored 56 | y: ignored 57 | Output: 58 | tensor of shape (N, in_channels, H // 2, W // 2) 59 | Args: 60 | in_channels (int): number of input channels 61 | """ 62 | 63 | def __init__(self, in_channels): 64 | super().__init__() 65 | 66 | self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1) 67 | 68 | def forward(self, x, time_emb, y): 69 | if x.shape[2] % 2 == 1: 70 | raise ValueError("downsampling tensor height should be even") 71 | if x.shape[3] % 2 == 1: 72 | raise ValueError("downsampling tensor width should be even") 73 | 74 | return self.downsample(x) 75 | 76 | 77 | class Upsample(nn.Module): 78 | __doc__ = r"""Upsamples a given tensor by a factor of 2. Uses resize convolution to avoid checkerboard artifacts. 79 | 80 | Input: 81 | x: tensor of shape (N, in_channels, H, W) 82 | time_emb: ignored 83 | y: ignored 84 | Output: 85 | tensor of shape (N, in_channels, H * 2, W * 2) 86 | Args: 87 | in_channels (int): number of input channels 88 | """ 89 | 90 | def __init__(self, in_channels): 91 | super().__init__() 92 | 93 | self.upsample = nn.Sequential( 94 | nn.Upsample(scale_factor=2, mode="nearest"), 95 | nn.Conv2d(in_channels, in_channels, 3, padding=1), 96 | ) 97 | 98 | def forward(self, x, time_emb, y): 99 | return self.upsample(x) 100 | 101 | 102 | class AttentionBlock(nn.Module): 103 | __doc__ = r"""Applies QKV self-attention with a residual connection. 104 | 105 | Input: 106 | x: tensor of shape (N, in_channels, H, W) 107 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn" 108 | num_groups (int): number of groups used in group normalization. Default: 32 109 | Output: 110 | tensor of shape (N, in_channels, H, W) 111 | Args: 112 | in_channels (int): number of input channels 113 | """ 114 | def __init__(self, in_channels, norm="gn", num_groups=32): 115 | super().__init__() 116 | 117 | self.in_channels = in_channels 118 | self.norm = get_norm(norm, in_channels, num_groups) 119 | self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1) 120 | self.to_out = nn.Conv2d(in_channels, in_channels, 1) 121 | 122 | def forward(self, x): 123 | b, c, h, w = x.shape 124 | q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1) 125 | 126 | q = q.permute(0, 2, 3, 1).view(b, h * w, c) 127 | k = k.view(b, c, h * w) 128 | v = v.permute(0, 2, 3, 1).view(b, h * w, c) 129 | 130 | dot_products = torch.bmm(q, k) * (c ** (-0.5)) 131 | assert dot_products.shape == (b, h * w, h * w) 132 | 133 | attention = torch.softmax(dot_products, dim=-1) 134 | out = torch.bmm(attention, v) 135 | assert out.shape == (b, h * w, c) 136 | out = out.view(b, h, w, c).permute(0, 3, 1, 2) 137 | 138 | return self.to_out(out) + x 139 | 140 | 141 | class ResidualBlock(nn.Module): 142 | __doc__ = r"""Applies two conv blocks with resudual connection. 
Adds time and class conditioning by adding bias after first convolution. 143 | 144 | Input: 145 | x: tensor of shape (N, in_channels, H, W) 146 | time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning 147 | y: classes tensor of shape (N) or None if the block doesn't use class conditioning 148 | Output: 149 | tensor of shape (N, out_channels, H, W) 150 | Args: 151 | in_channels (int): number of input channels 152 | out_channels (int): number of output channels 153 | time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None 154 | num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None 155 | activation (function): activation function. Default: torch.nn.functional.relu 156 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn" 157 | num_groups (int): number of groups used in group normalization. Default: 32 158 | use_attention (bool): if True applies AttentionBlock to the output. Default: False 159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | out_channels, 165 | dropout, 166 | time_emb_dim=None, 167 | num_classes=None, 168 | activation=F.relu, 169 | norm="gn", 170 | num_groups=32, 171 | use_attention=False, 172 | ): 173 | super().__init__() 174 | 175 | self.activation = activation 176 | 177 | self.norm_1 = get_norm(norm, in_channels, num_groups) 178 | self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) 179 | 180 | self.norm_2 = get_norm(norm, out_channels, num_groups) 181 | self.conv_2 = nn.Sequential( 182 | nn.Dropout(p=dropout), 183 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 184 | ) 185 | 186 | self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None 187 | self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None 188 | 189 | self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() 190 | self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups) 191 | 192 | def forward(self, x, time_emb=None, y=None): 193 | out = self.activation(self.norm_1(x)) 194 | out = self.conv_1(out) 195 | 196 | if self.time_bias is not None: 197 | if time_emb is None: 198 | raise ValueError("time conditioning was specified but time_emb is not passed") 199 | out += self.time_bias(self.activation(time_emb))[:, :, None, None] 200 | 201 | if self.class_bias is not None: 202 | if y is None: 203 | raise ValueError("class conditioning was specified but y is not passed") 204 | 205 | out += self.class_bias(y)[:, :, None, None] 206 | 207 | out = self.activation(self.norm_2(out)) 208 | out = self.conv_2(out) + self.residual_connection(x) 209 | out = self.attention(out) 210 | 211 | return out 212 | 213 | 214 | class UNet(nn.Module): 215 | __doc__ = """UNet model used to estimate noise. 
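Features pass through a downsampling path of residual blocks, two middle blocks (the first with self-attention), and an upsampling path that concatenates skip connections from the matching resolutions.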
216 | 217 | Input: 218 | x: tensor of shape (N, in_channels, H, W) 219 | time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning 220 | y: classes tensor of shape (N) or None if the block doesn't use class conditioning 221 | Output: 222 | tensor of shape (N, out_channels, H, W) 223 | Args: 224 | img_channels (int): number of image channels 225 | base_channels (int): number of base channels (after first convolution) 226 | channel_mults (tuple): tuple of channel multiplers. Default: (1, 2, 4, 8) 227 | time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None 228 | time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0 229 | num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None 230 | activation (function): activation function. Default: torch.nn.functional.relu 231 | dropout (float): dropout rate at the end of each residual block 232 | attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: () 233 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn" 234 | num_groups (int): number of groups used in group normalization. Default: 32 235 | initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0 236 | """ 237 | 238 | def __init__( 239 | self, 240 | img_channels, 241 | base_channels, 242 | channel_mults=(1, 2, 4, 8), 243 | num_res_blocks=2, 244 | time_emb_dim=None, 245 | time_emb_scale=1.0, 246 | num_classes=None, 247 | activation=F.relu, 248 | dropout=0.1, 249 | attention_resolutions=(), 250 | norm="gn", 251 | num_groups=32, 252 | initial_pad=0, 253 | ): 254 | super().__init__() 255 | 256 | self.activation = activation 257 | self.initial_pad = initial_pad 258 | 259 | self.num_classes = num_classes 260 | self.time_mlp = nn.Sequential( 261 | PositionalEmbedding(base_channels, time_emb_scale), 262 | nn.Linear(base_channels, time_emb_dim), 263 | nn.SiLU(), 264 | nn.Linear(time_emb_dim, time_emb_dim), 265 | ) if time_emb_dim is not None else None 266 | 267 | self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1) 268 | 269 | self.downs = nn.ModuleList() 270 | self.ups = nn.ModuleList() 271 | 272 | channels = [base_channels] 273 | now_channels = base_channels 274 | 275 | for i, mult in enumerate(channel_mults): 276 | out_channels = base_channels * mult 277 | 278 | for _ in range(num_res_blocks): 279 | self.downs.append(ResidualBlock( 280 | now_channels, 281 | out_channels, 282 | dropout, 283 | time_emb_dim=time_emb_dim, 284 | num_classes=num_classes, 285 | activation=activation, 286 | norm=norm, 287 | num_groups=num_groups, 288 | use_attention=i in attention_resolutions, 289 | )) 290 | now_channels = out_channels 291 | channels.append(now_channels) 292 | 293 | if i != len(channel_mults) - 1: 294 | self.downs.append(Downsample(now_channels)) 295 | channels.append(now_channels) 296 | 297 | 298 | self.mid = nn.ModuleList([ 299 | ResidualBlock( 300 | now_channels, 301 | now_channels, 302 | dropout, 303 | time_emb_dim=time_emb_dim, 304 | num_classes=num_classes, 305 | activation=activation, 306 | norm=norm, 307 | num_groups=num_groups, 308 | use_attention=True, 309 | ), 310 | ResidualBlock( 311 | now_channels, 312 | now_channels, 313 | dropout, 314 | time_emb_dim=time_emb_dim, 315 | num_classes=num_classes, 316 | activation=activation, 317 | norm=norm, 
318 | num_groups=num_groups, 319 | use_attention=False, 320 | ), 321 | ]) 322 | 323 | for i, mult in reversed(list(enumerate(channel_mults))): 324 | out_channels = base_channels * mult 325 | 326 | for _ in range(num_res_blocks + 1): 327 | self.ups.append(ResidualBlock( 328 | channels.pop() + now_channels, 329 | out_channels, 330 | dropout, 331 | time_emb_dim=time_emb_dim, 332 | num_classes=num_classes, 333 | activation=activation, 334 | norm=norm, 335 | num_groups=num_groups, 336 | use_attention=i in attention_resolutions, 337 | )) 338 | now_channels = out_channels 339 | 340 | if i != 0: 341 | self.ups.append(Upsample(now_channels)) 342 | 343 | assert len(channels) == 0 344 | 345 | self.out_norm = get_norm(norm, base_channels, num_groups) 346 | self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1) 347 | 348 | def forward(self, x, time=None, y=None): 349 | ip = self.initial_pad 350 | if ip != 0: 351 | x = F.pad(x, (ip,) * 4) 352 | 353 | if self.time_mlp is not None: 354 | if time is None: 355 | raise ValueError("time conditioning was specified but time is not passed") 356 | 357 | time_emb = self.time_mlp(time) 358 | else: 359 | time_emb = None 360 | 361 | if self.num_classes is not None and y is None: 362 | raise ValueError("class conditioning was specified but y is not passed") 363 | 364 | x = self.init_conv(x) 365 | 366 | skips = [x] 367 | 368 | for layer in self.downs: 369 | x = layer(x, time_emb, y) 370 | skips.append(x) 371 | 372 | for layer in self.mid: 373 | x = layer(x, time_emb, y) 374 | 375 | for layer in self.ups: 376 | if isinstance(layer, ResidualBlock): 377 | x = torch.cat([x, skips.pop()], dim=1) 378 | x = layer(x, time_emb, y) 379 | 380 | x = self.activation(self.out_norm(x)) 381 | x = self.out_conv(x) 382 | 383 | if self.initial_pad != 0: 384 | return x[:, :, ip:-ip, ip:-ip] 385 | else: 386 | return x --------------------------------------------------------------------------------