├── ddpm
│   ├── __init__.py
│   ├── utils.py
│   ├── ema.py
│   ├── script_utils.py
│   ├── diffusion.py
│   └── unet.py
├── .gitignore
├── resources
│   ├── samples_linear_200.png
│   ├── diffusion_models_report.pdf
│   ├── diffusion_sequence_mnist.gif
│   └── diffusion_models_talk_slides.pdf
├── setup.py
├── scripts
│   ├── sample_images.py
│   └── train_cifar.py
└── README.md
/ddpm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ddpm_env/
2 | .idea/
3 | .vscode/
4 | *.egg-info/
--------------------------------------------------------------------------------
/resources/samples_linear_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/samples_linear_200.png
--------------------------------------------------------------------------------
/resources/diffusion_models_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_models_report.pdf
--------------------------------------------------------------------------------
/resources/diffusion_sequence_mnist.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_sequence_mnist.gif
--------------------------------------------------------------------------------
/resources/diffusion_models_talk_slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abarankab/DDPM/HEAD/resources/diffusion_models_talk_slides.pdf
--------------------------------------------------------------------------------
/ddpm/utils.py:
--------------------------------------------------------------------------------
1 | def extract(a, t, x_shape):
2 | # Gather a[t] for each batch element and reshape it so it broadcasts over an image-shaped tensor.
3 | b, *_ = t.shape
4 | out = a.gather(-1, t)
5 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="ddpm",
5 | packages=["ddpm"],
6 | install_requires=["torch", "torchvision", "einops", "wandb", "joblib"],
7 | )
--------------------------------------------------------------------------------
/ddpm/ema.py:
--------------------------------------------------------------------------------
1 | class EMA():
2 | def __init__(self, decay):
3 | self.decay = decay
4 |
5 | def update_average(self, old, new):
6 | if old is None:
7 | return new
8 | return old * self.decay + (1 - self.decay) * new
9 |
10 | def update_model_average(self, ema_model, current_model):
11 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
12 | old, new = ema_params.data, current_params.data
13 | ema_params.data = self.update_average(old, new)
--------------------------------------------------------------------------------
/scripts/sample_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torchvision
4 |
5 | from ddpm import script_utils
6 |
7 |
8 | def main():
9 | args = create_argparser().parse_args()
10 | device = args.device
11 |
12 | try:
13 | diffusion = script_utils.get_diffusion_from_args(args).to(device)
14 | diffusion.load_state_dict(torch.load(args.model_path, map_location=device))
15 |
16 | if args.use_labels:
17 | for label in range(10):
18 | y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label
19 | samples = diffusion.sample(args.num_images // 10, device, y=y)
20 |
21 | for image_id in range(len(samples)):
22 | image = ((samples[image_id] + 1) / 2).clip(0, 1)
23 | torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png")
24 | else:
25 | samples = diffusion.sample(args.num_images, device)
26 |
27 | for image_id in range(len(samples)):
28 | image = ((samples[image_id] + 1) / 2).clip(0, 1)
29 | torchvision.utils.save_image(image, f"{args.save_dir}/{image_id}.png")
30 | except KeyboardInterrupt:
31 | print("Keyboard interrupt, generation finished early")
32 |
33 |
34 | def create_argparser():
35 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
36 | defaults = dict(num_images=10000, device=device, schedule_low=1e-4, schedule_high=0.02)  # schedule endpoints needed by the linear schedule
37 | defaults.update(script_utils.diffusion_defaults())
38 |
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument("--model_path", type=str)
41 | parser.add_argument("--save_dir", type=str)
42 | script_utils.add_dict_to_argparser(parser, defaults)
43 | return parser
44 |
45 |
46 | if __name__ == "__main__":
47 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Denoising Diffusion Probabilistic Models
2 |
3 | An implementation of Denoising Diffusion Probabilistic Models for image generation written in PyTorch. This roughly follows the original code by Ho et al. Unlike their implementation, however, my model allows for class conditioning through a learned per-class bias added in the residual blocks (sketched below).
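
Roughly, the conditioning works by looking up a learned embedding for each class label and adding it as a per-channel bias to the feature map after a residual block's first convolution. Here is a minimal, standalone sketch of that idea (the channel and class counts are illustrative; this is not the full `ResidualBlock` from `ddpm/unet.py`):

```python
import torch
import torch.nn as nn

# One learned bias vector (length = number of feature channels) per class.
class_bias = nn.Embedding(10, 128)      # 10 classes, 128 channels (illustrative sizes)

h = torch.randn(4, 128, 16, 16)         # feature map after the first convolution
y = torch.tensor([0, 3, 7, 9])          # class label for each image in the batch

# Broadcast the per-class bias over the spatial dimensions, like the time-embedding bias.
h = h + class_bias(y)[:, :, None, None]
```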
4 |
5 | ## Experiments
6 |
7 | I have trained the model on the MNIST and CIFAR-10 datasets. The model seemed to converge well on MNIST, producing realistic samples. However, I have not yet reproduced the CIFAR-10 sample quality that Ho et al. report in their paper. Here are the samples generated with a linear schedule after 2000 epochs:
8 |
9 | 
10 |
11 | Here is a sample of a diffusion sequence on MNIST:
12 |
13 | ![Diffusion sequence on MNIST](resources/diffusion_sequence_mnist.gif)
14 |
15 |
16 |
17 |
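Such a sequence can be generated with `GaussianDiffusion.sample_diffusion_sequence`, which returns the intermediate denoising steps. Below is a rough sketch (continuing from the sampling snippet in the Experiments section, so `diffusion` and `device` are assumed to exist) that writes every tenth step to a GIF using Pillow, which is installed alongside torchvision. Note that the scripts in this repo build a 3-channel 32x32 model, so an MNIST variant would need matching `img_size`/`img_channels`:

```python
import numpy as np
from PIL import Image

# List of (batch, channels, H, W) tensors, one per step from pure noise to the final sample.
sequence = diffusion.sample_diffusion_sequence(1, device)

frames = []
for x in sequence[::10]:                                 # keep every 10th step to limit GIF size
    img = ((x[0] + 1) / 2).clamp(0, 1)                   # rescale from [-1, 1] to [0, 1]
    arr = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8).squeeze()
    frames.append(Image.fromarray(arr))

frames[0].save("diffusion_sequence.gif", save_all=True, append_images=frames[1:], duration=40, loop=0)
```
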
18 | ## Resources
19 |
20 | I gave a talk about diffusion models, NCSNs, and their applications in audio generation. The [slides are available here](resources/diffusion_models_talk_slides.pdf).
21 |
22 | I also compiled a report with what are, in my opinion, the most crucial findings on the topic of denoising diffusion models. It is also [available in this repository](resources/diffusion_models_report.pdf).
23 |
24 |
25 | ## Acknowledgements
26 |
27 | I used [Phil Wang's implementation](https://github.com/lucidrains/denoising-diffusion-pytorch) and [the official Tensorflow repo](https://github.com/hojonathanho/diffusion) as a reference for my work.
28 |
29 | ## Citations
30 |
31 | ```bibtex
32 | @misc{ho2020denoising,
33 | title = {Denoising Diffusion Probabilistic Models},
34 | author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
35 | year = {2020},
36 | eprint = {2006.11239},
37 | archivePrefix = {arXiv},
38 | primaryClass = {cs.LG}
39 | }
40 | ```
41 |
42 | ```bibtex
43 | @inproceedings{anonymous2021improved,
44 | title = {Improved Denoising Diffusion Probabilistic Models},
45 | author = {Anonymous},
46 | booktitle = {Submitted to International Conference on Learning Representations},
47 | year = {2021},
48 | url = {https://openreview.net/forum?id=-NEXDKk8gZ},
49 | note = {under review}
50 | }
51 | ```
--------------------------------------------------------------------------------
/ddpm/script_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchvision
3 | import torch.nn.functional as F
4 |
5 | from .unet import UNet
6 | from .diffusion import (
7 | GaussianDiffusion,
8 | generate_linear_schedule,
9 | generate_cosine_schedule,
10 | )
11 |
12 |
13 | def cycle(dl):
14 | """
15 | https://github.com/lucidrains/denoising-diffusion-pytorch/
16 | """
17 | while True:
18 | for data in dl:
19 | yield data
20 |
21 | def get_transform():
22 | class RescaleChannels(object):
23 | def __call__(self, sample):
24 | return 2 * sample - 1
25 |
26 | return torchvision.transforms.Compose([
27 | torchvision.transforms.ToTensor(),
28 | RescaleChannels(),
29 | ])
30 |
31 |
32 | def str2bool(v):
33 | """
34 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
35 | """
36 | if isinstance(v, bool):
37 | return v
38 | if v.lower() in ("yes", "true", "t", "y", "1"):
39 | return True
40 | elif v.lower() in ("no", "false", "f", "n", "0"):
41 | return False
42 | else:
43 | raise argparse.ArgumentTypeError("boolean value expected")
44 |
45 |
46 | def add_dict_to_argparser(parser, default_dict):
47 | """
48 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
49 | """
50 | for k, v in default_dict.items():
51 | v_type = type(v)
52 | if v is None:
53 | v_type = str
54 | elif isinstance(v, bool):
55 | v_type = str2bool
56 | parser.add_argument(f"--{k}", default=v, type=v_type)
57 |
58 |
59 | def diffusion_defaults():
60 | defaults = dict(
61 | num_timesteps=1000,
62 | schedule="linear",
63 | loss_type="l2",
64 | use_labels=False,
65 |
66 | base_channels=128,
67 | channel_mults=(1, 2, 2, 2),
68 | num_res_blocks=2,
69 | time_emb_dim=128 * 4,
70 | norm="gn",
71 | dropout=0.1,
72 | activation="silu",
73 | attention_resolutions=(1,),
74 |
75 | ema_decay=0.9999,
76 | ema_update_rate=1,
77 | )
78 |
79 | return defaults
80 |
81 |
82 | def get_diffusion_from_args(args):
83 | activations = {
84 | "relu": F.relu,
85 | "mish": F.mish,
86 | "silu": F.silu,
87 | }
88 |
89 | model = UNet(
90 | img_channels=3,
91 | base_channels=args.base_channels,
92 | channel_mults=args.channel_mults,
93 | num_res_blocks=args.num_res_blocks,
94 | time_emb_dim=args.time_emb_dim,
95 | norm=args.norm,
96 | dropout=args.dropout,
97 | activation=activations[args.activation],
98 | attention_resolutions=args.attention_resolutions,
99 |
100 | num_classes=None if not args.use_labels else 10,
101 | initial_pad=0,
102 | )
103 |
104 | if args.schedule == "cosine":
105 | betas = generate_cosine_schedule(args.num_timesteps)
106 | else:
107 | betas = generate_linear_schedule(
108 | args.num_timesteps,
109 | args.schedule_low * 1000 / args.num_timesteps,
110 | args.schedule_high * 1000 / args.num_timesteps,
111 | )
112 |
113 | diffusion = GaussianDiffusion(
114 | model, (32, 32), 3, 10,
115 | betas,
116 | ema_decay=args.ema_decay,
117 | ema_update_rate=args.ema_update_rate,
118 | ema_start=2000,
119 | loss_type=args.loss_type,
120 | )
121 |
122 | return diffusion
--------------------------------------------------------------------------------
/scripts/train_cifar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import torch
4 | import wandb
5 |
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets
8 | from ddpm import script_utils
9 |
10 |
11 | def main():
12 | args = create_argparser().parse_args()
13 | device = args.device
14 |
15 | try:
16 | diffusion = script_utils.get_diffusion_from_args(args).to(device)
17 | optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)
18 |
19 | if args.model_checkpoint is not None:
20 | diffusion.load_state_dict(torch.load(args.model_checkpoint))
21 | if args.optim_checkpoint is not None:
22 | optimizer.load_state_dict(torch.load(args.optim_checkpoint))
23 |
24 | if args.log_to_wandb:
25 | if args.project_name is None:
26 | raise ValueError("args.log_to_wandb set to True but args.project_name is None")
27 |
28 | run = wandb.init(
29 | project=args.project_name,
30 | entity='treaptofun',
31 | config=vars(args),
32 | name=args.run_name,
33 | )
34 | wandb.watch(diffusion)
35 |
36 | batch_size = args.batch_size
37 |
38 | train_dataset = datasets.CIFAR10(
39 | root='./cifar_train',
40 | train=True,
41 | download=True,
42 | transform=script_utils.get_transform(),
43 | )
44 |
45 | test_dataset = datasets.CIFAR10(
46 | root='./cifar_test',
47 | train=False,
48 | download=True,
49 | transform=script_utils.get_transform(),
50 | )
51 |
52 | train_loader = script_utils.cycle(DataLoader(
53 | train_dataset,
54 | batch_size=batch_size,
55 | shuffle=True,
56 | drop_last=True,
57 | num_workers=2,
58 | ))
59 | test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
60 |
61 | acc_train_loss = 0
62 |
63 | for iteration in range(1, args.iterations + 1):
64 | diffusion.train()
65 |
66 | x, y = next(train_loader)
67 | x = x.to(device)
68 | y = y.to(device)
69 |
70 | if args.use_labels:
71 | loss = diffusion(x, y)
72 | else:
73 | loss = diffusion(x)
74 |
75 | acc_train_loss += loss.item()
76 |
77 | optimizer.zero_grad()
78 | loss.backward()
79 | optimizer.step()
80 |
81 | diffusion.update_ema()
82 |
83 | if iteration % args.log_rate == 0:
84 | test_loss = 0
85 | with torch.no_grad():
86 | diffusion.eval()
87 | for x, y in test_loader:
88 | x = x.to(device)
89 | y = y.to(device)
90 |
91 | if args.use_labels:
92 | loss = diffusion(x, y)
93 | else:
94 | loss = diffusion(x)
95 |
96 | test_loss += loss.item()
97 |
98 | if args.use_labels:
99 | samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
100 | else:
101 | samples = diffusion.sample(10, device)
102 |
103 | samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()
104 |
105 | test_loss /= len(test_loader)
106 | acc_train_loss /= args.log_rate
107 |
108 | wandb.log({
109 | "test_loss": test_loss,
110 | "train_loss": acc_train_loss,
111 | "samples": [wandb.Image(sample) for sample in samples],
112 | })
113 |
114 | acc_train_loss = 0
115 |
116 | if iteration % args.checkpoint_rate == 0:
117 | model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
118 | optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"
119 |
120 | torch.save(diffusion.state_dict(), model_filename)
121 | torch.save(optimizer.state_dict(), optim_filename)
122 |
123 | if args.log_to_wandb:
124 | run.finish()
125 | except KeyboardInterrupt:
126 | if args.log_to_wandb:
127 | run.finish()
128 | print("Keyboard interrupt, run finished early")
129 |
130 |
131 | def create_argparser():
132 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
133 | run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
134 | defaults = dict(
135 | learning_rate=2e-4,
136 | batch_size=128,
137 | iterations=800000,
138 |
139 | log_to_wandb=True,
140 | log_rate=1000,
141 | checkpoint_rate=1000,
142 | log_dir="./ddpm_logs",
143 | project_name=None,
144 | run_name=run_name,
145 |
146 | model_checkpoint=None,
147 | optim_checkpoint=None,
148 |
149 | schedule_low=1e-4,
150 | schedule_high=0.02,
151 |
152 | device=device,
153 | )
154 | defaults.update(script_utils.diffusion_defaults())
155 |
156 | parser = argparse.ArgumentParser()
157 | script_utils.add_dict_to_argparser(parser, defaults)
158 | return parser
159 |
160 |
161 | if __name__ == "__main__":
162 | main()
--------------------------------------------------------------------------------
/ddpm/diffusion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from functools import partial
7 | from copy import deepcopy
8 |
9 | from .ema import EMA
10 | from .utils import extract
11 |
12 | class GaussianDiffusion(nn.Module):
13 | __doc__ = r"""Gaussian Diffusion model. Calling the module computes the noise-prediction loss for a batch of images and returns it as a scalar tensor.
14 |
15 | Input:
16 | x: tensor of shape (N, img_channels, *img_size)
17 | y: tensor of shape (N)
18 | Output:
19 | scalar loss tensor
20 | Args:
21 | model (nn.Module): model which estimates diffusion noise
22 | img_size (tuple): image size tuple (H, W)
23 | img_channels (int): number of image channels
24 | betas (np.ndarray): numpy array of diffusion betas
25 | loss_type (string): loss type, "l1" or "l2"
26 | ema_decay (float): model weights exponential moving average decay
27 | ema_start (int): number of steps before EMA updates begin
28 | ema_update_rate (int): number of steps before each EMA update
29 | """
30 | def __init__(
31 | self,
32 | model,
33 | img_size,
34 | img_channels,
35 | num_classes,
36 | betas,
37 | loss_type="l2",
38 | ema_decay=0.9999,
39 | ema_start=5000,
40 | ema_update_rate=1,
41 | ):
42 | super().__init__()
43 |
44 | self.model = model
45 | self.ema_model = deepcopy(model)
46 |
47 | self.ema = EMA(ema_decay)
48 | self.ema_decay = ema_decay
49 | self.ema_start = ema_start
50 | self.ema_update_rate = ema_update_rate
51 | self.step = 0
52 |
53 | self.img_size = img_size
54 | self.img_channels = img_channels
55 | self.num_classes = num_classes
56 |
57 | if loss_type not in ["l1", "l2"]:
58 | raise ValueError("__init__() got unknown loss type")
59 |
60 | self.loss_type = loss_type
61 | self.num_timesteps = len(betas)
62 |
63 | alphas = 1.0 - betas
64 | alphas_cumprod = np.cumprod(alphas)
65 |
66 | to_torch = partial(torch.tensor, dtype=torch.float32)
67 |
68 | self.register_buffer("betas", to_torch(betas))
69 | self.register_buffer("alphas", to_torch(alphas))
70 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
71 |
72 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
73 | self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))
74 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))
75 |
76 | self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))
77 | self.register_buffer("sigma", to_torch(np.sqrt(betas)))
78 |
79 | def update_ema(self):
80 | self.step += 1
81 | if self.step % self.ema_update_rate == 0:
82 | if self.step < self.ema_start:
83 | self.ema_model.load_state_dict(self.model.state_dict())
84 | else:
85 | self.ema.update_model_average(self.ema_model, self.model)
86 |
87 | @torch.no_grad()
88 | def remove_noise(self, x, t, y, use_ema=True):
89 | if use_ema:
90 | return (
91 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) *
92 | extract(self.reciprocal_sqrt_alphas, t, x.shape)
93 | )
94 | else:
95 | return (
96 | (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *
97 | extract(self.reciprocal_sqrt_alphas, t, x.shape)
98 | )
99 |
100 | @torch.no_grad()
101 | def sample(self, batch_size, device, y=None, use_ema=True):
102 | if y is not None and batch_size != len(y):
103 | raise ValueError("sample batch size different from length of given y")
104 |
105 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
106 |
107 | for t in range(self.num_timesteps - 1, -1, -1):
108 | t_batch = torch.tensor([t], device=device).repeat(batch_size)
109 | x = self.remove_noise(x, t_batch, y, use_ema)
110 |
111 | if t > 0:
112 | x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
113 |
114 | return x.cpu().detach()
115 |
116 | @torch.no_grad()
117 | def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
118 | if y is not None and batch_size != len(y):
119 | raise ValueError("sample batch size different from length of given y")
120 |
121 | x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
122 | diffusion_sequence = [x.cpu().detach()]
123 |
124 | for t in range(self.num_timesteps - 1, -1, -1):
125 | t_batch = torch.tensor([t], device=device).repeat(batch_size)
126 | x = self.remove_noise(x, t_batch, y, use_ema)
127 |
128 | if t > 0:
129 | x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
130 |
131 | diffusion_sequence.append(x.cpu().detach())
132 |
133 | return diffusion_sequence
134 |
135 | def perturb_x(self, x, t, noise):
136 | return (
137 | extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
138 | extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
139 | )
140 |
141 | def get_losses(self, x, t, y):
142 | noise = torch.randn_like(x)
143 |
144 | perturbed_x = self.perturb_x(x, t, noise)
145 | estimated_noise = self.model(perturbed_x, t, y)
146 |
147 | if self.loss_type == "l1":
148 | loss = F.l1_loss(estimated_noise, noise)
149 | elif self.loss_type == "l2":
150 | loss = F.mse_loss(estimated_noise, noise)
151 |
152 | return loss
153 |
154 | def forward(self, x, y=None):
155 | b, c, h, w = x.shape
156 | device = x.device
157 |
158 | if h != self.img_size[0]:
159 | raise ValueError("image height does not match diffusion parameters")
160 | if w != self.img_size[1]:
161 | raise ValueError("image width does not match diffusion parameters")
162 |
163 | t = torch.randint(0, self.num_timesteps, (b,), device=device)
164 | return self.get_losses(x, t, y)
165 |
166 |
167 | def generate_cosine_schedule(T, s=0.008):
168 | def f(t, T):
169 | return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2
170 |
171 | alphas = []
172 | f0 = f(0, T)
173 |
174 | for t in range(T + 1):
175 | alphas.append(f(t, T) / f0)
176 |
177 | betas = []
178 |
179 | for t in range(1, T + 1):
180 | betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))
181 |
182 | return np.array(betas)
183 |
184 |
185 | def generate_linear_schedule(T, low, high):
186 | return np.linspace(low, high, T)
--------------------------------------------------------------------------------
/ddpm/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.nn.modules.normalization import GroupNorm
7 |
8 |
9 | def get_norm(norm, num_channels, num_groups):
10 | if norm == "in":
11 | return nn.InstanceNorm2d(num_channels, affine=True)
12 | elif norm == "bn":
13 | return nn.BatchNorm2d(num_channels)
14 | elif norm == "gn":
15 | return nn.GroupNorm(num_groups, num_channels)
16 | elif norm is None:
17 | return nn.Identity()
18 | else:
19 | raise ValueError("unknown normalization type")
20 |
21 |
22 | class PositionalEmbedding(nn.Module):
23 | __doc__ = r"""Computes a positional embedding of timesteps.
24 |
25 | Input:
26 | x: tensor of shape (N)
27 | Output:
28 | tensor of shape (N, dim)
29 | Args:
30 | dim (int): embedding dimension
31 | scale (float): linear scale to be applied to timesteps. Default: 1.0
32 | """
33 |
34 | def __init__(self, dim, scale=1.0):
35 | super().__init__()
36 | assert dim % 2 == 0
37 | self.dim = dim
38 | self.scale = scale
39 |
40 | def forward(self, x):
41 | device = x.device
42 | half_dim = self.dim // 2
43 | emb = math.log(10000) / half_dim
44 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
45 | emb = torch.outer(x * self.scale, emb)
46 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
47 | return emb
48 |
49 |
50 | class Downsample(nn.Module):
51 | __doc__ = r"""Downsamples a given tensor by a factor of 2. Uses strided convolution. Assumes even height and width.
52 |
53 | Input:
54 | x: tensor of shape (N, in_channels, H, W)
55 | time_emb: ignored
56 | y: ignored
57 | Output:
58 | tensor of shape (N, in_channels, H // 2, W // 2)
59 | Args:
60 | in_channels (int): number of input channels
61 | """
62 |
63 | def __init__(self, in_channels):
64 | super().__init__()
65 |
66 | self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)
67 |
68 | def forward(self, x, time_emb, y):
69 | if x.shape[2] % 2 == 1:
70 | raise ValueError("downsampling tensor height should be even")
71 | if x.shape[3] % 2 == 1:
72 | raise ValueError("downsampling tensor width should be even")
73 |
74 | return self.downsample(x)
75 |
76 |
77 | class Upsample(nn.Module):
78 | __doc__ = r"""Upsamples a given tensor by a factor of 2. Uses resize convolution to avoid checkerboard artifacts.
79 |
80 | Input:
81 | x: tensor of shape (N, in_channels, H, W)
82 | time_emb: ignored
83 | y: ignored
84 | Output:
85 | tensor of shape (N, in_channels, H * 2, W * 2)
86 | Args:
87 | in_channels (int): number of input channels
88 | """
89 |
90 | def __init__(self, in_channels):
91 | super().__init__()
92 |
93 | self.upsample = nn.Sequential(
94 | nn.Upsample(scale_factor=2, mode="nearest"),
95 | nn.Conv2d(in_channels, in_channels, 3, padding=1),
96 | )
97 |
98 | def forward(self, x, time_emb, y):
99 | return self.upsample(x)
100 |
101 |
102 | class AttentionBlock(nn.Module):
103 | __doc__ = r"""Applies QKV self-attention with a residual connection.
104 |
105 | Input:
106 | x: tensor of shape (N, in_channels, H, W)
107 | Output:
108 | tensor of shape (N, in_channels, H, W)
109 | Args:
110 | in_channels (int): number of input channels
111 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
112 | num_groups (int): number of groups used in group normalization. Default: 32
113 | """
114 | def __init__(self, in_channels, norm="gn", num_groups=32):
115 | super().__init__()
116 |
117 | self.in_channels = in_channels
118 | self.norm = get_norm(norm, in_channels, num_groups)
119 | self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
120 | self.to_out = nn.Conv2d(in_channels, in_channels, 1)
121 |
122 | def forward(self, x):
123 | b, c, h, w = x.shape
124 | q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
125 |
126 | q = q.permute(0, 2, 3, 1).view(b, h * w, c)
127 | k = k.view(b, c, h * w)
128 | v = v.permute(0, 2, 3, 1).view(b, h * w, c)
129 |
130 | dot_products = torch.bmm(q, k) * (c ** (-0.5))
131 | assert dot_products.shape == (b, h * w, h * w)
132 |
133 | attention = torch.softmax(dot_products, dim=-1)
134 | out = torch.bmm(attention, v)
135 | assert out.shape == (b, h * w, c)
136 | out = out.view(b, h, w, c).permute(0, 3, 1, 2)
137 |
138 | return self.to_out(out) + x
139 |
140 |
141 | class ResidualBlock(nn.Module):
142 | __doc__ = r"""Applies two conv blocks with a residual connection. Adds time and class conditioning as a per-channel bias after the first convolution.
143 |
144 | Input:
145 | x: tensor of shape (N, in_channels, H, W)
146 | time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
147 | y: classes tensor of shape (N) or None if the block doesn't use class conditioning
148 | Output:
149 | tensor of shape (N, out_channels, H, W)
150 | Args:
151 | in_channels (int): number of input channels
152 | out_channels (int): number of output channels
153 | time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
154 | num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
155 | activation (function): activation function. Default: torch.nn.functional.relu
156 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
157 | num_groups (int): number of groups used in group normalization. Default: 32
158 | use_attention (bool): if True applies AttentionBlock to the output. Default: False
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels,
164 | out_channels,
165 | dropout,
166 | time_emb_dim=None,
167 | num_classes=None,
168 | activation=F.relu,
169 | norm="gn",
170 | num_groups=32,
171 | use_attention=False,
172 | ):
173 | super().__init__()
174 |
175 | self.activation = activation
176 |
177 | self.norm_1 = get_norm(norm, in_channels, num_groups)
178 | self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
179 |
180 | self.norm_2 = get_norm(norm, out_channels, num_groups)
181 | self.conv_2 = nn.Sequential(
182 | nn.Dropout(p=dropout),
183 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
184 | )
185 |
186 | self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
187 | self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None
188 |
189 | self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
190 | self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
191 |
192 | def forward(self, x, time_emb=None, y=None):
193 | out = self.activation(self.norm_1(x))
194 | out = self.conv_1(out)
195 |
196 | if self.time_bias is not None:
197 | if time_emb is None:
198 | raise ValueError("time conditioning was specified but time_emb is not passed")
199 | out += self.time_bias(self.activation(time_emb))[:, :, None, None]
200 |
201 | if self.class_bias is not None:
202 | if y is None:
203 | raise ValueError("class conditioning was specified but y is not passed")
204 |
205 | out += self.class_bias(y)[:, :, None, None]
206 |
207 | out = self.activation(self.norm_2(out))
208 | out = self.conv_2(out) + self.residual_connection(x)
209 | out = self.attention(out)
210 |
211 | return out
212 |
213 |
214 | class UNet(nn.Module):
215 | __doc__ = """UNet model used to estimate noise.
216 |
217 | Input:
218 | x: tensor of shape (N, img_channels, H, W)
219 | time: timestep tensor of shape (N) or None if the model doesn't use time conditioning
220 | y: classes tensor of shape (N) or None if the model doesn't use class conditioning
221 | Output:
222 | tensor of shape (N, img_channels, H, W)
223 | Args:
224 | img_channels (int): number of image channels
225 | base_channels (int): number of base channels (after first convolution)
226 | channel_mults (tuple): tuple of channel multipliers. Default: (1, 2, 4, 8)
227 | time_emb_dim (int or None): time embedding dimension or None if the model doesn't use time conditioning. Default: None
228 | time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
229 | num_classes (int or None): number of classes or None if the model doesn't use class conditioning. Default: None
230 | activation (function): activation function. Default: torch.nn.functional.relu
231 | dropout (float): dropout rate at the end of each residual block
232 | attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: ()
233 | norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
234 | num_groups (int): number of groups used in group normalization. Default: 32
235 | initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
236 | """
237 |
238 | def __init__(
239 | self,
240 | img_channels,
241 | base_channels,
242 | channel_mults=(1, 2, 4, 8),
243 | num_res_blocks=2,
244 | time_emb_dim=None,
245 | time_emb_scale=1.0,
246 | num_classes=None,
247 | activation=F.relu,
248 | dropout=0.1,
249 | attention_resolutions=(),
250 | norm="gn",
251 | num_groups=32,
252 | initial_pad=0,
253 | ):
254 | super().__init__()
255 |
256 | self.activation = activation
257 | self.initial_pad = initial_pad
258 |
259 | self.num_classes = num_classes
260 | self.time_mlp = nn.Sequential(
261 | PositionalEmbedding(base_channels, time_emb_scale),
262 | nn.Linear(base_channels, time_emb_dim),
263 | nn.SiLU(),
264 | nn.Linear(time_emb_dim, time_emb_dim),
265 | ) if time_emb_dim is not None else None
266 |
267 | self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)
268 |
269 | self.downs = nn.ModuleList()
270 | self.ups = nn.ModuleList()
271 |
272 | channels = [base_channels]
273 | now_channels = base_channels
274 |
275 | for i, mult in enumerate(channel_mults):
276 | out_channels = base_channels * mult
277 |
278 | for _ in range(num_res_blocks):
279 | self.downs.append(ResidualBlock(
280 | now_channels,
281 | out_channels,
282 | dropout,
283 | time_emb_dim=time_emb_dim,
284 | num_classes=num_classes,
285 | activation=activation,
286 | norm=norm,
287 | num_groups=num_groups,
288 | use_attention=i in attention_resolutions,
289 | ))
290 | now_channels = out_channels
291 | channels.append(now_channels)
292 |
293 | if i != len(channel_mults) - 1:
294 | self.downs.append(Downsample(now_channels))
295 | channels.append(now_channels)
296 |
297 |
298 | self.mid = nn.ModuleList([
299 | ResidualBlock(
300 | now_channels,
301 | now_channels,
302 | dropout,
303 | time_emb_dim=time_emb_dim,
304 | num_classes=num_classes,
305 | activation=activation,
306 | norm=norm,
307 | num_groups=num_groups,
308 | use_attention=True,
309 | ),
310 | ResidualBlock(
311 | now_channels,
312 | now_channels,
313 | dropout,
314 | time_emb_dim=time_emb_dim,
315 | num_classes=num_classes,
316 | activation=activation,
317 | norm=norm,
318 | num_groups=num_groups,
319 | use_attention=False,
320 | ),
321 | ])
322 |
323 | for i, mult in reversed(list(enumerate(channel_mults))):
324 | out_channels = base_channels * mult
325 |
326 | for _ in range(num_res_blocks + 1):
327 | self.ups.append(ResidualBlock(
328 | channels.pop() + now_channels,
329 | out_channels,
330 | dropout,
331 | time_emb_dim=time_emb_dim,
332 | num_classes=num_classes,
333 | activation=activation,
334 | norm=norm,
335 | num_groups=num_groups,
336 | use_attention=i in attention_resolutions,
337 | ))
338 | now_channels = out_channels
339 |
340 | if i != 0:
341 | self.ups.append(Upsample(now_channels))
342 |
343 | assert len(channels) == 0
344 |
345 | self.out_norm = get_norm(norm, base_channels, num_groups)
346 | self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)
347 |
348 | def forward(self, x, time=None, y=None):
349 | ip = self.initial_pad
350 | if ip != 0:
351 | x = F.pad(x, (ip,) * 4)
352 |
353 | if self.time_mlp is not None:
354 | if time is None:
355 | raise ValueError("time conditioning was specified but time is not passed")
356 |
357 | time_emb = self.time_mlp(time)
358 | else:
359 | time_emb = None
360 |
361 | if self.num_classes is not None and y is None:
362 | raise ValueError("class conditioning was specified but y is not passed")
363 |
364 | x = self.init_conv(x)
365 |
366 | skips = [x]
367 |
368 | for layer in self.downs:
369 | x = layer(x, time_emb, y)
370 | skips.append(x)
371 |
372 | for layer in self.mid:
373 | x = layer(x, time_emb, y)
374 |
375 | for layer in self.ups:
376 | if isinstance(layer, ResidualBlock):
377 | x = torch.cat([x, skips.pop()], dim=1)
378 | x = layer(x, time_emb, y)
379 |
380 | x = self.activation(self.out_norm(x))
381 | x = self.out_conv(x)
382 |
383 | if self.initial_pad != 0:
384 | return x[:, :, ip:-ip, ip:-ip]
385 | else:
386 | return x
--------------------------------------------------------------------------------