├── ddpm
│   ├── photos
│   │   ├── mnist_1.png
│   │   ├── mnist_2.png
│   │   ├── cifar10_1.png
│   │   └── cifar10_2.png
│   ├── inference.py
│   ├── train.py
│   ├── diffusion.py
│   └── net.py
├── classifier_free_ddpm
│   ├── photos
│   │   ├── classifier_free_mnist_1.png
│   │   ├── classifier_free_mnist_2.png
│   │   ├── classifier_free_cifar10_1.png
│   │   └── classifier_free_cifar10_2.png
│   ├── train.py
│   ├── inference.py
│   ├── diffusion.py
│   └── net.py
└── README.md
/ddpm/photos/mnist_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/ddpm/photos/mnist_1.png
--------------------------------------------------------------------------------
/ddpm/photos/mnist_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/ddpm/photos/mnist_2.png
--------------------------------------------------------------------------------
/ddpm/photos/cifar10_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/ddpm/photos/cifar10_1.png
--------------------------------------------------------------------------------
/ddpm/photos/cifar10_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/ddpm/photos/cifar10_2.png
--------------------------------------------------------------------------------
/classifier_free_ddpm/photos/classifier_free_mnist_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/classifier_free_ddpm/photos/classifier_free_mnist_1.png
--------------------------------------------------------------------------------
/classifier_free_ddpm/photos/classifier_free_mnist_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/classifier_free_ddpm/photos/classifier_free_mnist_2.png
--------------------------------------------------------------------------------
/classifier_free_ddpm/photos/classifier_free_cifar10_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/classifier_free_ddpm/photos/classifier_free_cifar10_1.png
--------------------------------------------------------------------------------
/classifier_free_ddpm/photos/classifier_free_cifar10_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BastianChen/ddpm-demo-pytorch/HEAD/classifier_free_ddpm/photos/classifier_free_cifar10_2.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-DDPM-Demo

PyTorch implementation of a DDPM demo and a classifier-free DDPM demo.
Experiments were conducted on the MNIST and CIFAR-10 datasets.

## Development Environment

- Red Hat Enterprise Linux Server release 7.6 (Maipo)
- pytorch 1.6.0
- torchvision 0.7.0
- python 3.7.16

## Requirements

- Download the datasets and pretrained `.pth` files from the DDPM-Demo share:

[DDPM-Demo](https://www.aliyundrive.com/s/rpqX43VFfpT)

## Project structure

├── classifier_free_ddpm \
│  ├── models \
│  └── photos \
│  └── ... \
├── datasets \
├── ddpm \
│  ├── models \
│  └── photos \
│  └── ... \
└── README.md

## Train

MNIST:
```bash
python train.py -b 64 -d 0 -e 20 -t 500
```

CIFAR-10:
```bash
python train.py -b 64 -d 1 -e 100 -t 1000
```

Parameter meaning:

| abbreviation | full name       | meaning                              |
|--------------|-----------------|--------------------------------------|
| -b           | --batch_size    | batch size                           |
| -d           | --datasets_type | dataset type, 0: MNIST, 1: CIFAR-10  |
| -e           | --epochs        | number of training epochs            |
| -t           | --timesteps     | number of diffusion timesteps        |
| -dp          | --datasets_path | path to the datasets directory       |


## Evaluate

MNIST:
```bash
python inference.py -b 64 -d 0 -t 500 -p models/mnist-500-20-0.0005.pth
```

CIFAR-10:
```bash
python inference.py -b 64 -d 1 -t 1000 -p models/cifar10-1000-100-0.0002.pth
```

## Experimental Result

### DDPM

#### MNIST:
![](ddpm/photos/mnist_1.png)

![](ddpm/photos/mnist_2.png)

#### CIFAR-10:
![](ddpm/photos/cifar10_1.png)

![](ddpm/photos/cifar10_2.png)

### Classifier-Free DDPM
#### MNIST:
![](classifier_free_ddpm/photos/classifier_free_mnist_1.png)

![](classifier_free_ddpm/photos/classifier_free_mnist_2.png)

#### CIFAR-10:
![](classifier_free_ddpm/photos/classifier_free_cifar10_1.png)

![](classifier_free_ddpm/photos/classifier_free_cifar10_2.png)

## Reference

[PyTorch-DDPM](https://github.com/LinXueyuanStdio/PyTorch-DDPM)


--------------------------------------------------------------------------------
/ddpm/inference.py:
--------------------------------------------------------------------------------
import torch
import os
import argparse
from net import UNetModel
from diffusion import GaussianDiffusion
from torchvision.utils import save_image


def inference(args):
    batch_size = args.batch_size
    timesteps = args.timesteps
    datasets_type = args.datasets_type

    if datasets_type:
        in_channels = 3
        image_size = 32
        save_image_name = "cifar10"
    else:
        in_channels = 1
        image_size = 28
        save_image_name = "mnist"

    # define model and diffusion
    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    model = UNetModel(
        in_channels=in_channels,
        model_channels=96,
        out_channels=in_channels,
        channel_mult=(1, 2, 2),
        attention_resolutions=[]
    )

    map_location = None if torch.cuda.is_available() else lambda storage, loc: storage
    model.to(device)
    model.load_state_dict(torch.load(args.pth_path, map_location=map_location))
    model.eval()

    gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)
    generated_images = gaussian_diffusion.sample(model, image_size, batch_size=batch_size, channels=in_channels)

    # generate new images
    if not os.path.exists("photos"):
        os.mkdir("photos")
    imgs = generated_images[-1].reshape(batch_size, in_channels, image_size, image_size)
    img = torch.tensor(imgs)
    save_image(img, f'photos/{save_image_name}_1.png', 8, normalize=True, scale_each=True)

    imgs_time = []
    for n_row in range(16):
        for n_col in range(16):
            t_idx = (timesteps // 16) * n_col if n_col < 15 else -1
            img = torch.tensor(generated_images[t_idx][n_row].reshape(in_channels, image_size, image_size))
            imgs_time.append(img)

    imgs = torch.stack(imgs_time).reshape(-1, in_channels, image_size, image_size)
    save_image(imgs, f'photos/{save_image_name}_2.png', 16, normalize=True, scale_each=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', default=64, type=int, help="batch size")
    parser.add_argument('-d', '--datasets_type', default=0, type=int, help="dataset type, 0: MNIST, 1: CIFAR-10")
    parser.add_argument('-t', '--timesteps', default=1000, type=int, help="timesteps")
    parser.add_argument('-p', '--pth_path', default="models/cifar10-1000-100-0.0002.pth", type=str, help="path of the pth file")

    args = parser.parse_args()
    print(args)
    inference(args)
--------------------------------------------------------------------------------
/ddpm/train.py:
--------------------------------------------------------------------------------
import torch
import os
import argparse
from net import UNetModel
from diffusion import GaussianDiffusion
from torchvision import datasets, transforms


def run(args):
    batch_size = args.batch_size
    epochs = args.epochs
    timesteps = args.timesteps
    datasets_path = args.datasets_path
    datasets_type = args.datasets_type

    if not os.path.exists("models"):
        os.mkdir("models")

    if datasets_type:
        dataset = datasets.CIFAR10(
            root=datasets_path, train=True, download=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        lr = 2e-4
        in_channels = 3
        save_path = f"models/cifar10-{timesteps}-{epochs}-{lr}.pth"
    else:
        dataset = datasets.MNIST(root=datasets_path, train=True, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ]))
        lr = 5e-4
        in_channels = 1
        save_path = f"models/mnist-{timesteps}-{epochs}-{lr}.pth"
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # define model and diffusion
    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    model = UNetModel(
        in_channels=in_channels,
        model_channels=96,
        out_channels=in_channels,
        channel_mult=(1, 2, 2),
        attention_resolutions=[]
    )
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)

    # train
    for epoch in range(epochs):
        for step, (images, labels) in enumerate(train_loader):
            batch_size = images.shape[0]
            images = images.to(device)

            # sample t uniformly for every example in the batch
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()

            loss = gaussian_diffusion.train_losses(model, images, t)

            if step % 200 == 0:
                print("Loss:", loss.item())

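            # `train_losses` implements the simple DDPM objective (Ho et al., 2020): it draws
            # eps ~ N(0, I), forms x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps via
            # q_sample, and returns ||eps - eps_theta(x_t, t)||^2; the usual backprop/update follows.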
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', default=64, type=int, help="batch size")
    parser.add_argument('-d', '--datasets_type', default=0, type=int, help="dataset type, 0: MNIST, 1: CIFAR-10")
    parser.add_argument('-e', '--epochs', default=20, type=int)
    parser.add_argument('-t', '--timesteps', default=500, type=int, help="timesteps")
    parser.add_argument('-dp', '--datasets_path', default="../datasets", type=str, help="path of the datasets")

    args = parser.parse_args()

    run(args)
--------------------------------------------------------------------------------
/classifier_free_ddpm/train.py:
--------------------------------------------------------------------------------
import torch
import os
import argparse
from net import UNetModel
from diffusion import GaussianDiffusion
from torchvision import datasets, transforms


def run(args):
    batch_size = args.batch_size
    epochs = args.epochs
    timesteps = args.timesteps
    datasets_path = args.datasets_path
    datasets_type = args.datasets_type

    if not os.path.exists("models"):
        os.mkdir("models")

    if datasets_type:
        dataset = datasets.CIFAR10(
            root=datasets_path, train=True, download=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        lr = 2e-4
        in_channels = 3
        save_path = f"models/cifar10-{timesteps}-{epochs}-{lr}.pth"
    else:
        dataset = datasets.MNIST(root=datasets_path, train=True, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ]))
        lr = 5e-4
        in_channels = 1
        save_path = f"models/mnist-{timesteps}-{epochs}-{lr}.pth"
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # define model and diffusion
    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    model = UNetModel(
        in_channels=in_channels,
        model_channels=96,
        out_channels=in_channels,
        channel_mult=(1, 2, 2),
        attention_resolutions=[]
    )
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)

    # train
    for epoch in range(epochs):
        for step, (images, labels) in enumerate(train_loader):
            batch_size = images.shape[0]
            images = images.to(device)
            labels = labels.to(device)

            # sample t uniformly for every example in the batch
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()

            loss = gaussian_diffusion.train_losses(model, images, t, labels)

            if step % 200 == 0:
                print("Loss:", loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', default=64, type=int, help="batch size")
    parser.add_argument('-d', '--datasets_type', default=0, type=int, help="dataset type, 0: MNIST, 1: CIFAR-10")
    parser.add_argument('-e', '--epochs', default=20, type=int)
    parser.add_argument('-t', '--timesteps', default=500, type=int, help="timesteps")
    parser.add_argument('-dp', '--datasets_path', default="../datasets", type=str, help="path of the datasets")

    args = parser.parse_args()

    run(args)
--------------------------------------------------------------------------------
/classifier_free_ddpm/inference.py:
--------------------------------------------------------------------------------
import torch
import os
import argparse
import matplotlib.pyplot as plt
from net import UNetModel
from diffusion import GaussianDiffusion


def inference(args):
    batch_size = args.batch_size
    timesteps = args.timesteps
    datasets_type = args.datasets_type

    if datasets_type:
        in_channels = 3
        image_size = 32
        save_image_name = "cifar10"
    else:
        in_channels = 1
        image_size = 28
        save_image_name = "mnist"

    # define model and diffusion
    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    model = UNetModel(
        in_channels=in_channels,
        model_channels=96,
        out_channels=in_channels,
        channel_mult=(1, 2, 2),
        attention_resolutions=[]
    )

    map_location = None if torch.cuda.is_available() else lambda storage, loc: storage
    model.to(device)
    model.load_state_dict(torch.load(args.pth_path, map_location=map_location))
    model.eval()

    gaussian_diffusion = GaussianDiffusion(timesteps=timesteps)

    label = torch.randint(0, 10, (batch_size,)).to(device)
    label_cifar = ['plane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    generated_images = gaussian_diffusion.sample(model, label, image_size, batch_size=batch_size, channels=in_channels)
    # generated_images: [timesteps, batch_size=64, channels=1, height=28, width=28]

    # generate new images
    if not os.path.exists("photos"):
        os.mkdir("photos")
    fig = plt.figure(figsize=(12, 12), constrained_layout=True)
    gs = fig.add_gridspec(8, 8)

    if datasets_type:
        imgs = generated_images[-1].reshape(8, 8, 3, 32, 32)
    else:
        imgs = generated_images[-1].reshape(8, 8, 28, 28)
    for n_row in range(8):
        for n_col in range(8):
            f_ax = fig.add_subplot(gs[n_row, n_col])
            img = imgs[n_row, n_col]
            if datasets_type:
                img = img.swapaxes(0, 1)
                img = img.swapaxes(1, 2)
                f_ax.imshow(((img + 1.0) * 255 / 2) / 255)
            else:
                f_ax.imshow((img + 1.0) * 255 / 2, cmap="gray")
            f_ax.axis("off")
            plt.title(
                f"{label_cifar[label[n_row * 8 + n_col]] if datasets_type else label[n_row * 8 + n_col]}")
    f = plt.gcf()  # get the current figure
    f.savefig(f'photos/classifier_free_{save_image_name}_1.png')
    f.clear()  # release the figure's memory

    # show the denoise steps
    fig = plt.figure(figsize=(12, 12), constrained_layout=True)
    rows = 12  # len(y)
    gs = fig.add_gridspec(rows, 16)
    for n_row in range(rows):
        for n_col in range(16):
            f_ax = fig.add_subplot(gs[n_row, n_col])
            t_idx = (timesteps // 16) * n_col if n_col < 15 else -1
            img = generated_images[t_idx][n_row]
            if datasets_type:
                img = img.swapaxes(0, 1)
                img = img.swapaxes(1, 2)
                f_ax.imshow(((img + 1.0) * 255 / 2) / 255)
            else:
                img = img[0]
                f_ax.imshow((img + 1.0) * 255 / 2, cmap="gray")
            f_ax.axis("off")
            plt.title(f"{label_cifar[label[n_row]] if datasets_type else label[n_row]}")

    f = plt.gcf()  # get the current figure
    f.savefig(f'photos/classifier_free_{save_image_name}_2.png')
    f.clear()  # release the figure's memory


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', default=64, type=int, help="batch size")
    parser.add_argument('-d', '--datasets_type', default=0, type=int, help="dataset type, 0: MNIST, 1: CIFAR-10")
    parser.add_argument('-t', '--timesteps', default=1000, type=int, help="timesteps")
    parser.add_argument('-p', '--pth_path', default="models/cifar10-1000-100-0.0002.pth", type=str,
                        help="path of the pth file")

    args = parser.parse_args()
    print(args)
    inference(args)
--------------------------------------------------------------------------------
/ddpm/diffusion.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from tqdm import tqdm


def linear_beta_schedule(timesteps):
    """
    beta schedule
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion:
    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        self.timesteps = timesteps

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas

        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
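        # alphas_cumprod holds alpha_bar_t = prod_{s<=t} alpha_s and alphas_cumprod_prev holds
        # alpha_bar_{t-1} with alpha_bar_0 := 1. They give the closed-form forward marginal
        #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
        # which is why the square-root terms below are precomputed once here.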
49 | 50 | # calculations for diffusion q(x_t | x_{t-1}) and others 51 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 52 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) 53 | self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) 54 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) 55 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) 56 | 57 | # calculations for posterior q(x_{t-1} | x_t, x_0) 58 | self.posterior_variance = ( 59 | self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 60 | ) 61 | # below: log calculation clipped because the posterior variance is 0 at the beginning 62 | # of the diffusion chain 63 | self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20)) 64 | 65 | self.posterior_mean_coef1 = ( 66 | self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 67 | ) 68 | self.posterior_mean_coef2 = ( 69 | (1.0 - self.alphas_cumprod_prev) 70 | * torch.sqrt(self.alphas) 71 | / (1.0 - self.alphas_cumprod) 72 | ) 73 | 74 | def _extract(self, a, t, x_shape): 75 | # get the param of given timestep t 76 | batch_size = t.shape[0] 77 | out = a.to(t.device).gather(0, t).float() 78 | out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) 79 | return out 80 | 81 | def q_sample(self, x_start, t, noise=None): 82 | # forward diffusion (using the nice property): q(x_t | x_0) 83 | if noise is None: 84 | noise = torch.randn_like(x_start) 85 | 86 | sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) 87 | sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 88 | 89 | return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 90 | 91 | def q_mean_variance(self, x_start, t): 92 | # Get the mean and variance of q(x_t | x_0). 93 | mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 94 | variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape) 95 | log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 96 | return mean, variance, log_variance 97 | 98 | def q_posterior_mean_variance(self, x_start, x_t, t): 99 | # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) 100 | posterior_mean = ( 101 | self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start 102 | + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 103 | ) 104 | posterior_variance = self._extract(self.posterior_variance, t, x_t.shape) 105 | posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape) 106 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 107 | 108 | def predict_start_from_noise(self, x_t, t, noise): 109 | # compute x_0 from x_t and pred noise: the reverse of `q_sample` 110 | return ( 111 | self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 112 | self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 113 | ) 114 | 115 | def p_mean_variance(self, model, x_t, t, clip_denoised=True): 116 | # compute predicted mean and variance of p(x_{t-1} | x_t) 117 | # predict noise using model 118 | pred_noise = model(x_t, t) 119 | # get the predicted x_0: different from the algorithm2 in the paper 120 | x_recon = self.predict_start_from_noise(x_t, t, pred_noise) 121 | if clip_denoised: 122 | x_recon = torch.clamp(x_recon, min=-1., max=1.) 
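        # the reverse step reuses the closed-form posterior q(x_{t-1} | x_t, x_0), with the
        # predicted x_0 plugged in: mean = coef1 * x_recon + coef2 * x_t (Ho et al., 2020, Eq. 7).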
123 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_recon, x_t, t) 124 | return model_mean, posterior_variance, posterior_log_variance 125 | 126 | @torch.no_grad() 127 | def p_sample(self, model, x_t, t, clip_denoised=True): 128 | # denoise_step: sample x_{t-1} from x_t and pred_noise 129 | # predict mean and variance 130 | model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised) 131 | noise = torch.randn_like(x_t) 132 | # no noise when t == 0 133 | nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))) 134 | # compute x_{t-1} 135 | pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 136 | return pred_img 137 | 138 | @torch.no_grad() 139 | def p_sample_loop(self, model, shape): 140 | # denoise: reverse diffusion 141 | batch_size = shape[0] 142 | device = next(model.parameters()).device 143 | # start from pure noise (for each example in the batch) 144 | img = torch.randn(shape, device=device) 145 | imgs = [] 146 | for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps): 147 | img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long)) 148 | imgs.append(img.cpu().numpy()) 149 | return imgs 150 | 151 | @torch.no_grad() 152 | def sample(self, model, image_size, batch_size=8, channels=3): 153 | # sample new images 154 | return self.p_sample_loop(model, shape=(batch_size, channels, image_size, image_size)) 155 | 156 | def train_losses(self, model, x_start, t): 157 | # compute train losses 158 | # generate random noise 159 | noise = torch.randn_like(x_start) 160 | # get x_t 161 | x_noisy = self.q_sample(x_start, t, noise=noise) 162 | predicted_noise = model(x_noisy, t) 163 | loss = F.mse_loss(noise, predicted_noise) 164 | return loss 165 | -------------------------------------------------------------------------------- /classifier_free_ddpm/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | 6 | 7 | def linear_beta_schedule(timesteps): 8 | """ 9 | beta schedule 10 | """ 11 | scale = 1000 / timesteps 12 | beta_start = scale * 0.0001 13 | beta_end = scale * 0.02 14 | return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) 15 | 16 | 17 | def cosine_beta_schedule(timesteps, s=0.008): 18 | """ 19 | cosine schedule 20 | as proposed in https://arxiv.org/abs/2102.09672 21 | """ 22 | steps = timesteps + 1 23 | x = torch.linspace(0, timesteps, steps, dtype=torch.float64) 24 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 25 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 26 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 27 | return torch.clip(betas, 0, 0.999) 28 | 29 | 30 | class GaussianDiffusion: 31 | def __init__( 32 | self, 33 | timesteps=1000, 34 | beta_schedule='linear' 35 | ): 36 | self.timesteps = timesteps 37 | 38 | if beta_schedule == 'linear': 39 | betas = linear_beta_schedule(timesteps) 40 | elif beta_schedule == 'cosine': 41 | betas = cosine_beta_schedule(timesteps) 42 | else: 43 | raise ValueError(f'unknown beta schedule {beta_schedule}') 44 | self.betas = betas 45 | 46 | self.alphas = 1. - self.betas 47 | self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) 48 | self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.) 
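        # alpha_bar_{t-1} (with alpha_bar_0 := 1) feeds the posterior q(x_{t-1} | x_t, x_0) terms
        # below, e.g. its variance beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).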
49 | 50 | # calculations for diffusion q(x_t | x_{t-1}) and others 51 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 52 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) 53 | self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) 54 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) 55 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) 56 | 57 | # calculations for posterior q(x_{t-1} | x_t, x_0) 58 | self.posterior_variance = ( 59 | self.betas * (1.0 - self.alphas_cumprod_prev) 60 | / (1.0 - self.alphas_cumprod) 61 | ) 62 | # below: log calculation clipped because the posterior variance is 0 at the beginning 63 | # of the diffusion chain 64 | self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20)) 65 | 66 | self.posterior_mean_coef1 = ( 67 | self.betas * torch.sqrt(self.alphas_cumprod_prev) 68 | / (1.0 - self.alphas_cumprod) 69 | ) 70 | self.posterior_mean_coef2 = ( 71 | (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) 72 | / (1.0 - self.alphas_cumprod) 73 | ) 74 | 75 | def _extract(self, a, t, x_shape): 76 | # get the param of given timestep t 77 | batch_size = t.shape[0] 78 | out = a.to(t.device).gather(0, t).float() 79 | out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) 80 | return out 81 | 82 | def q_sample(self, x_start, t, noise=None): 83 | # forward diffusion (using the nice property): q(x_t | x_0) 84 | if noise is None: 85 | noise = torch.randn_like(x_start) 86 | 87 | sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) 88 | sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 89 | 90 | return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 91 | 92 | def q_mean_variance(self, x_start, t): 93 | # Get the mean and variance of q(x_t | x_0). 94 | mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 95 | variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape) 96 | log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 97 | return mean, variance, log_variance 98 | 99 | def q_posterior_mean_variance(self, x_start, x_t, t): 100 | # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) 101 | posterior_mean = ( 102 | self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start 103 | + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 104 | ) 105 | posterior_variance = self._extract(self.posterior_variance, t, x_t.shape) 106 | posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape) 107 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 108 | 109 | def predict_start_from_noise(self, x_t, t, noise): 110 | # compute x_0 from x_t and pred noise: the reverse of `q_sample` 111 | return ( 112 | self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 113 | self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 114 | ) 115 | 116 | def p_mean_variance(self, model, x_t, t, y, clip_denoised=True): 117 | # compute predicted mean and variance of p(x_{t-1} | x_t) 118 | # predict noise using model 119 | pred_noise = model(x_t, t, y) 120 | # get the predicted x_0: different from the algorithm2 in the paper 121 | x_recon = self.predict_start_from_noise(x_t, t, pred_noise) 122 | if clip_denoised: 123 | x_recon = torch.clamp(x_recon, min=-1., max=1.) 
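        # note: the label y is always fed to the model here; full classifier-free guidance
        # (Ho & Salimans) would additionally run an unconditional pass and blend the two noise
        # predictions, e.g. eps = (1 + w) * eps(x_t, t, y) - w * eps(x_t, t, null) for a guidance weight w.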
124 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_recon, x_t, t) 125 | return model_mean, posterior_variance, posterior_log_variance 126 | 127 | @torch.no_grad() 128 | def p_sample(self, model, x_t, t, y, clip_denoised=True): 129 | # denoise_step: sample x_{t-1} from x_t and pred_noise 130 | # predict mean and variance 131 | model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t, y, clip_denoised=clip_denoised) 132 | noise = torch.randn_like(x_t) 133 | # no noise when t == 0 134 | nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))) 135 | # compute x_{t-1} 136 | pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 137 | return pred_img 138 | 139 | @torch.no_grad() 140 | def p_sample_loop(self, model, y, shape): 141 | # denoise: reverse diffusion 142 | batch_size = shape[0] 143 | device = next(model.parameters()).device 144 | # start from pure noise (for each example in the batch) 145 | img = torch.randn(shape, device=device) 146 | imgs = [] 147 | for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps): 148 | t = torch.full((batch_size,), i, device=device, dtype=torch.long) 149 | img = self.p_sample(model, img, t, y) 150 | imgs.append(img.cpu().numpy()) 151 | return imgs 152 | 153 | @torch.no_grad() 154 | def sample(self, model, y, image_size, batch_size=8, channels=3): 155 | # sample new images 156 | return self.p_sample_loop(model, y, shape=(batch_size, channels, image_size, image_size)) 157 | 158 | def train_losses(self, model, x_start, t, y): 159 | # compute train losses 160 | # generate random noise 161 | noise = torch.randn_like(x_start) 162 | # get x_t 163 | x_noisy = self.q_sample(x_start, t, noise=noise) 164 | predicted_noise = model(x_noisy, t, y) 165 | loss = F.mse_loss(noise, predicted_noise) 166 | return loss 167 | -------------------------------------------------------------------------------- /ddpm/net.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import abstractmethod 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def timestep_embedding(timesteps, dim, max_period=10000): 10 | """Create sinusoidal timestep embeddings. 11 | 12 | Args: 13 | timesteps (Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 14 | dim (int): the dimension of the output. 15 | max_period (int, optional): controls the minimum frequency of the embeddings. Defaults to 10000. 16 | 17 | Returns: 18 | Tensor: an [N x dim] Tensor of positional embeddings. 19 | """ 20 | half = dim // 2 21 | freqs = torch.exp( 22 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 23 | ).to(device=timesteps.device) 24 | args = timesteps[:, None].float() * freqs[None] 25 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 26 | if dim % 2: 27 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 28 | return embedding 29 | 30 | 31 | class TimestepBlock(nn.Module): 32 | """ 33 | Any module where forward() takes timestep embeddings as a second argument. 34 | """ 35 | 36 | @abstractmethod 37 | def forward(self, x, t): 38 | """ 39 | Apply the module to `x` given `t` timestep embeddings. 
40 | """ 41 | 42 | 43 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 44 | """ 45 | A sequential module that passes timestep embeddings to the children that support it as an extra input. 46 | """ 47 | 48 | def forward(self, x, t): 49 | for layer in self: 50 | if isinstance(layer, TimestepBlock): 51 | x = layer(x, t) 52 | else: 53 | x = layer(x) 54 | return x 55 | 56 | 57 | def norm_layer(channels): 58 | return nn.GroupNorm(32, channels) 59 | 60 | 61 | class ResidualBlock(TimestepBlock): 62 | def __init__(self, in_channels, out_channels, time_channels, dropout): 63 | super().__init__() 64 | self.conv1 = nn.Sequential( 65 | norm_layer(in_channels), 66 | # nn.SiLU(), 67 | nn.ReLU(), 68 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 69 | ) 70 | 71 | # pojection for time step embedding 72 | self.time_emb = nn.Sequential( 73 | # nn.SiLU(), 74 | nn.ReLU(), 75 | nn.Linear(time_channels, out_channels) 76 | ) 77 | 78 | self.conv2 = nn.Sequential( 79 | norm_layer(out_channels), 80 | # nn.SiLU(), 81 | nn.ReLU(), 82 | nn.Dropout(p=dropout), 83 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 84 | ) 85 | 86 | if in_channels != out_channels: 87 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) 88 | else: 89 | self.shortcut = nn.Identity() 90 | 91 | def forward(self, x, t): 92 | """ 93 | `x` has shape `[batch_size, in_dim, height, width]` 94 | `t` has shape `[batch_size, time_dim]` 95 | """ 96 | h = self.conv1(x) 97 | # Add time step embeddings 98 | h += self.time_emb(t)[:, :, None, None] 99 | h = self.conv2(h) 100 | return h + self.shortcut(x) 101 | 102 | 103 | class AttentionBlock(nn.Module): 104 | def __init__(self, channels, num_heads=1): 105 | """ 106 | Attention block with shortcut 107 | 108 | Args: 109 | channels (int): channels 110 | num_heads (int, optional): attention heads. Defaults to 1. 111 | """ 112 | super().__init__() 113 | self.num_heads = num_heads 114 | assert channels % num_heads == 0 115 | 116 | self.norm = norm_layer(channels) 117 | self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False) 118 | self.proj = nn.Conv2d(channels, channels, kernel_size=1) 119 | 120 | def forward(self, x): 121 | B, C, H, W = x.shape 122 | qkv = self.qkv(self.norm(x)) 123 | q, k, v = qkv.reshape(B * self.num_heads, -1, H * W).chunk(3, dim=1) 124 | scale = 1. 
/ math.sqrt(math.sqrt(C // self.num_heads)) 125 | attn = torch.einsum("bct,bcs->bts", q * scale, k * scale) 126 | attn = attn.softmax(dim=-1) 127 | h = torch.einsum("bts,bcs->bct", attn, v) 128 | h = h.reshape(B, -1, H, W) 129 | h = self.proj(h) 130 | return h + x 131 | 132 | 133 | class Upsample(nn.Module): 134 | def __init__(self, channels, use_conv): 135 | super().__init__() 136 | self.use_conv = use_conv 137 | if use_conv: 138 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 139 | 140 | def forward(self, x): 141 | x = F.interpolate(x, scale_factor=2, mode="nearest") 142 | if self.use_conv: 143 | x = self.conv(x) 144 | return x 145 | 146 | 147 | class Downsample(nn.Module): 148 | def __init__(self, channels, use_conv): 149 | super().__init__() 150 | self.use_conv = use_conv 151 | if use_conv: 152 | self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1) 153 | else: 154 | self.op = nn.AvgPool2d(stride=2) 155 | 156 | def forward(self, x): 157 | return self.op(x) 158 | 159 | 160 | class UNetModel(nn.Module): 161 | """ 162 | The full UNet model with attention and timestep embedding 163 | """ 164 | 165 | def __init__( 166 | self, 167 | in_channels=3, 168 | model_channels=128, 169 | out_channels=3, 170 | num_res_blocks=2, 171 | attention_resolutions=(8, 16), 172 | dropout=0, 173 | channel_mult=(1, 2, 2, 2), 174 | conv_resample=True, 175 | num_heads=4 176 | ): 177 | super().__init__() 178 | 179 | self.in_channels = in_channels 180 | self.model_channels = model_channels 181 | self.out_channels = out_channels 182 | self.num_res_blocks = num_res_blocks 183 | self.attention_resolutions = attention_resolutions 184 | self.dropout = dropout 185 | self.channel_mult = channel_mult 186 | self.conv_resample = conv_resample 187 | self.num_heads = num_heads 188 | 189 | # time embedding 190 | time_embed_dim = model_channels * 4 191 | self.time_embed = nn.Sequential( 192 | nn.Linear(model_channels, time_embed_dim), 193 | # nn.SiLU(), 194 | nn.ReLU(), 195 | nn.Linear(time_embed_dim, time_embed_dim), 196 | ) 197 | 198 | # down blocks 199 | self.down_blocks = nn.ModuleList([ 200 | TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1)) 201 | ]) 202 | down_block_chans = [model_channels] 203 | ch = model_channels 204 | ds = 1 205 | for level, mult in enumerate(channel_mult): 206 | for _ in range(num_res_blocks): 207 | layers = [ 208 | ResidualBlock(ch, mult * model_channels, time_embed_dim, dropout) 209 | ] 210 | ch = mult * model_channels 211 | if ds in attention_resolutions: 212 | layers.append(AttentionBlock(ch, num_heads=num_heads)) 213 | self.down_blocks.append(TimestepEmbedSequential(*layers)) 214 | down_block_chans.append(ch) 215 | if level != len(channel_mult) - 1: # don't use downsample for the last stage 216 | self.down_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample))) 217 | down_block_chans.append(ch) 218 | ds *= 2 219 | 220 | # middle block 221 | self.middle_block = TimestepEmbedSequential( 222 | ResidualBlock(ch, ch, time_embed_dim, dropout), 223 | AttentionBlock(ch, num_heads=num_heads), 224 | ResidualBlock(ch, ch, time_embed_dim, dropout) 225 | ) 226 | 227 | # up blocks 228 | self.up_blocks = nn.ModuleList([]) 229 | for level, mult in list(enumerate(channel_mult))[::-1]: 230 | for i in range(num_res_blocks + 1): 231 | layers = [ 232 | ResidualBlock( 233 | ch + down_block_chans.pop(), 234 | model_channels * mult, 235 | time_embed_dim, 236 | dropout 237 | ) 238 | ] 239 | ch = model_channels * mult 240 | if 
ds in attention_resolutions: 241 | layers.append(AttentionBlock(ch, num_heads=num_heads)) 242 | if level and i == num_res_blocks: 243 | layers.append(Upsample(ch, conv_resample)) 244 | ds //= 2 245 | self.up_blocks.append(TimestepEmbedSequential(*layers)) 246 | 247 | self.out = nn.Sequential( 248 | norm_layer(ch), 249 | # nn.SiLU(), 250 | nn.ReLU(), 251 | nn.Conv2d(model_channels, out_channels, kernel_size=3, padding=1), 252 | ) 253 | 254 | def forward(self, x, timesteps): 255 | """Apply the model to an input batch. 256 | 257 | Args: 258 | x (Tensor): [N x C x H x W] 259 | timesteps (Tensor): a 1-D batch of timesteps. 260 | 261 | Returns: 262 | Tensor: [N x C x ...] 263 | """ 264 | hs = [] 265 | # time step embedding 266 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 267 | 268 | # down stage 269 | h = x 270 | for module in self.down_blocks: 271 | h = module(h, emb) 272 | hs.append(h) 273 | # middle stage 274 | h = self.middle_block(h, emb) 275 | # up stage 276 | for module in self.up_blocks: 277 | cat_in = torch.cat([h, hs.pop()], dim=1) 278 | h = module(cat_in, emb) 279 | return self.out(h) 280 | -------------------------------------------------------------------------------- /classifier_free_ddpm/net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from abc import abstractmethod 7 | from einops import rearrange 8 | from einops.layers.torch import Rearrange 9 | from einops_exts import rearrange_many, repeat_many 10 | 11 | 12 | class Identity(nn.Module): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__() 15 | 16 | def forward(self, x, *args, **kwargs): 17 | return x 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, feats, stable=True, dim=-1): 22 | super().__init__() 23 | self.stable = stable 24 | self.dim = dim 25 | 26 | self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1)))) 27 | 28 | def forward(self, x): 29 | dtype, dim = x.dtype, self.dim 30 | if self.stable: 31 | # x = x / x.amax(dim=dim, keepdim=True).detach() 32 | x = x / x.max(dim=dim, keepdim=True)[0].detach() 33 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 34 | var = torch.var(x, dim=dim, unbiased=False, keepdim=True) 35 | mean = torch.mean(x, dim=dim, keepdim=True) 36 | return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype) 37 | 38 | 39 | class Always(): 40 | def __init__(self, val): 41 | self.val = val 42 | 43 | def __call__(self, *args, **kwargs): 44 | return self.val 45 | 46 | 47 | def timestep_embedding(timesteps, dim, max_period=10000): 48 | """Create sinusoidal timestep embeddings. 49 | 50 | Args: 51 | timesteps (Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 52 | dim (int): the dimension of the output. 53 | max_period (int, optional): controls the minimum frequency of the embeddings. Defaults to 10000. 54 | 55 | Returns: 56 | Tensor: an [N x dim] Tensor of positional embeddings. 
57 | """ 58 | half = dim // 2 59 | freqs = torch.exp( 60 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 61 | ).to(device=timesteps.device) 62 | args = timesteps[:, None].float() * freqs[None] 63 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 64 | if dim % 2: 65 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 66 | return embedding 67 | 68 | 69 | class TimestepBlock(nn.Module): 70 | """ 71 | Any module where forward() takes timestep embeddings as a second argument. 72 | """ 73 | 74 | @abstractmethod 75 | def forward(self, x, t, y): 76 | """ 77 | Apply the module to `x` given `t` timestep embeddings, `y` conditional embedding same shape as t. 78 | """ 79 | pass 80 | 81 | 82 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 83 | """ 84 | A sequential module that passes timestep embeddings to the children that support it as an extra input. 85 | """ 86 | 87 | def forward(self, x, t, y): 88 | for layer in self: 89 | if isinstance(layer, TimestepBlock): 90 | x = layer(x, t, y) 91 | else: 92 | x = layer(x) 93 | return x 94 | 95 | 96 | def norm_layer(channels): 97 | return nn.GroupNorm(32, channels) 98 | 99 | 100 | class Block(nn.Module): 101 | def __init__( 102 | self, 103 | dim, 104 | dim_out, 105 | groups=8, 106 | norm=True 107 | ): 108 | super().__init__() 109 | self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() 110 | # self.activation = nn.SiLU() 111 | self.activation = nn.ReLU() 112 | self.project = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1) 113 | 114 | def forward(self, x, scale_shift=None): 115 | x = self.groupnorm(x) 116 | 117 | if scale_shift is not None: 118 | scale, shift = scale_shift 119 | x = x * (scale + 1) + shift 120 | 121 | x = self.activation(x) 122 | return self.project(x) 123 | 124 | 125 | class CrossAttention(nn.Module): 126 | def __init__( 127 | self, 128 | dim, 129 | *, 130 | context_dim=None, 131 | dim_head=64, 132 | heads=8, 133 | norm_context=False, 134 | cosine_sim_attn=False 135 | ): 136 | super().__init__() 137 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 
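        # 1/sqrt(dim_head) is the usual scaled dot-product attention temperature; with
        # cosine_sim_attn the per-query scale is disabled and a fixed cosine_sim_scale is
        # applied to the similarity logits in forward() instead. The learned null key/value
        # defined below gives the attention a fallback token when the conditioning carries no
        # information, which is what the classifier-free-guidance comment in forward() refers to.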
138 | self.cosine_sim_attn = cosine_sim_attn 139 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 140 | 141 | self.heads = heads 142 | inner_dim = dim_head * heads 143 | 144 | context_dim = dim if context_dim is None else context_dim 145 | 146 | self.norm = LayerNorm(dim) 147 | self.norm_context = LayerNorm(context_dim) if norm_context else Identity() 148 | 149 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 150 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 151 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 152 | 153 | self.to_out = nn.Sequential( 154 | nn.Linear(inner_dim, dim, bias=False), 155 | LayerNorm(dim) 156 | ) 157 | 158 | def forward(self, x, context): 159 | b, n, device = *x.shape[:2], x.device 160 | x = self.norm(x) 161 | context = self.norm_context(context) 162 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) 163 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=self.heads) 164 | # add null key / value for classifier free guidance in prior net 165 | nk, nv = repeat_many(self.null_kv.unbind(dim=-2), 'd -> b h 1 d', h=self.heads, b=b) 166 | k = torch.cat((nk, k), dim=-2) 167 | v = torch.cat((nv, v), dim=-2) 168 | q = q * self.scale 169 | # similarities 170 | sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale 171 | # masking 172 | max_neg_value = -torch.finfo(sim.dtype).max 173 | attn = sim.softmax(dim=-1, dtype=torch.float32) 174 | attn = attn.to(sim.dtype) 175 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 176 | out = rearrange(out, 'b h n d -> b n (h d)') 177 | return self.to_out(out) 178 | 179 | 180 | class GlobalContext(nn.Module): 181 | """ basically a superior form of squeeze-excitation that is attention-esque """ 182 | 183 | def __init__( 184 | self, 185 | *, 186 | dim_in, 187 | dim_out 188 | ): 189 | super().__init__() 190 | self.to_k = nn.Conv2d(dim_in, 1, 1) 191 | hidden_dim = max(3, dim_out // 2) 192 | 193 | self.net = nn.Sequential( 194 | nn.Conv2d(dim_in, hidden_dim, 1), 195 | # nn.SiLU(), 196 | nn.ReLU(), 197 | nn.Conv2d(hidden_dim, dim_out, 1), 198 | nn.Sigmoid() 199 | ) 200 | 201 | def forward(self, x): 202 | context = self.to_k(x) 203 | x, context = rearrange_many((x, context), 'b n ... -> b n (...)') 204 | out = torch.einsum('b i n, b c n -> b c i', context.softmax(dim=-1), x) 205 | out = rearrange(out, '... -> ... 
1') 206 | return self.net(out) 207 | 208 | 209 | class ResidualBlock(TimestepBlock): 210 | def __init__(self, dim_in, dim_out, time_dim, dropout, use_global_context=False, groups=8): 211 | super().__init__() 212 | self.conv1 = nn.Sequential( 213 | norm_layer(dim_in), 214 | # nn.SiLU(), 215 | nn.ReLU(), 216 | nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1) 217 | ) 218 | 219 | # pojection for time step embedding 220 | self.time_emb = nn.Sequential( 221 | # nn.SiLU(), 222 | nn.ReLU(), 223 | nn.Linear(time_dim, dim_out * 2) 224 | ) 225 | 226 | self.conv2 = nn.Sequential( 227 | norm_layer(dim_out), 228 | # nn.SiLU(), 229 | nn.ReLU(), 230 | nn.Dropout(p=dropout), 231 | nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1) 232 | ) 233 | 234 | self.block1 = Block(dim_in, dim_out, groups=groups) 235 | self.block2 = Block(dim_out, dim_out, groups=groups) 236 | 237 | if dim_in != dim_out: 238 | self.shortcut = nn.Conv2d(dim_in, dim_out, kernel_size=1) 239 | else: 240 | self.shortcut = nn.Identity() 241 | cond_dim = time_dim 242 | self.gca = GlobalContext(dim_in=dim_out, dim_out=dim_out) if use_global_context else Always(1) 243 | self.cross_attn = CrossAttention(dim=dim_out, context_dim=cond_dim, ) 244 | 245 | def forward(self, x, t, y): 246 | """ 247 | `x` has shape `[batch_size, in_dim, height, width]` 248 | `t` has shape `[batch_size, time_dim]` 249 | `y` has shape `[batch_size, num_time_tokens, cond_dim]` 250 | """ 251 | h = self.block1(x) 252 | 253 | # Add time step embeddings 254 | context = y 255 | # print("h.shape", h.shape, "x.shape", x.shape, "context.shape", context.shape, "t.shape", t.shape, "y.shape", y.shape) 256 | size = h.size(-2) 257 | hidden = rearrange(h, 'b c h w -> b (h w) c') 258 | attn = self.cross_attn(hidden, context) 259 | # print("attn.shape", attn.shape) 260 | attn = rearrange(attn, 'b (h w) c -> b c h w', h=size) 261 | h += attn 262 | 263 | t = self.time_emb(t) 264 | t = rearrange(t, 'b c -> b c 1 1') 265 | scale_shift = t.chunk(2, dim=1) 266 | h = self.block2(h, scale_shift=scale_shift) 267 | 268 | h *= self.gca(h) 269 | return h + self.shortcut(x) 270 | 271 | 272 | class AttentionBlock(nn.Module): 273 | def __init__(self, channels, num_heads=1): 274 | """ 275 | Attention block with shortcut 276 | 277 | Args: 278 | channels (int): channels 279 | num_heads (int, optional): attention heads. Defaults to 1. 280 | """ 281 | super().__init__() 282 | self.num_heads = num_heads 283 | assert channels % num_heads == 0 284 | 285 | self.norm = norm_layer(channels) 286 | self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False) 287 | self.proj = nn.Conv2d(channels, channels, kernel_size=1) 288 | 289 | def forward(self, x): 290 | B, C, H, W = x.shape 291 | qkv = self.qkv(self.norm(x)) 292 | q, k, v = qkv.reshape(B * self.num_heads, -1, H * W).chunk(3, dim=1) 293 | scale = 1. 
/ math.sqrt(math.sqrt(C // self.num_heads)) 294 | attn = torch.einsum("bct,bcs->bts", q * scale, k * scale) 295 | attn = attn.softmax(dim=-1) 296 | h = torch.einsum("bts,bcs->bct", attn, v) 297 | h = h.reshape(B, -1, H, W) 298 | h = self.proj(h) 299 | return h + x 300 | 301 | 302 | class Upsample(nn.Module): 303 | def __init__(self, channels, use_conv): 304 | super().__init__() 305 | self.use_conv = use_conv 306 | if use_conv: 307 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 308 | 309 | def forward(self, x): 310 | x = F.interpolate(x, scale_factor=2, mode="nearest") 311 | if self.use_conv: 312 | x = self.conv(x) 313 | return x 314 | 315 | 316 | class Downsample(nn.Module): 317 | def __init__(self, channels, use_conv): 318 | super().__init__() 319 | self.use_conv = use_conv 320 | if use_conv: 321 | self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1) 322 | else: 323 | self.op = nn.AvgPool2d(stride=2) 324 | 325 | def forward(self, x): 326 | return self.op(x) 327 | 328 | 329 | class UNetModel(nn.Module): 330 | """ 331 | The full UNet model with attention and timestep embedding 332 | """ 333 | 334 | def __init__( 335 | self, 336 | in_channels=3, 337 | model_channels=128, 338 | out_channels=3, 339 | num_res_blocks=2, 340 | attention_resolutions=(8, 16), 341 | dropout=0, 342 | channel_mult=(1, 2, 2, 2), 343 | conv_resample=True, 344 | num_heads=4, 345 | label_num=10, 346 | num_time_tokens=2, 347 | ): 348 | super().__init__() 349 | 350 | self.in_channels = in_channels 351 | self.model_channels = model_channels 352 | self.out_channels = out_channels 353 | self.num_res_blocks = num_res_blocks 354 | self.attention_resolutions = attention_resolutions 355 | self.dropout = dropout 356 | self.channel_mult = channel_mult 357 | self.conv_resample = conv_resample 358 | self.num_heads = num_heads 359 | 360 | # time embedding 361 | time_embed_dim = model_channels * 4 362 | cond_dim = time_embed_dim 363 | self.time_embed = nn.Sequential( 364 | nn.Linear(model_channels, time_embed_dim), 365 | # nn.SiLU(), 366 | nn.ReLU(), 367 | nn.Linear(time_embed_dim, time_embed_dim), 368 | ) 369 | self.label_embedding = nn.Embedding(label_num, time_embed_dim) 370 | self.to_time_tokens = nn.Sequential( 371 | nn.Linear(time_embed_dim, num_time_tokens * cond_dim), 372 | Rearrange('b (r d) -> b r d', r=num_time_tokens) 373 | ) 374 | 375 | # down blocks 376 | self.down_blocks = nn.ModuleList([ 377 | TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1)) 378 | ]) 379 | down_block_chans = [model_channels] 380 | ch = model_channels 381 | ds = 1 382 | for level, mult in enumerate(channel_mult): 383 | for _ in range(num_res_blocks): 384 | layers = [ 385 | ResidualBlock(ch, mult * model_channels, time_embed_dim, dropout) 386 | ] 387 | ch = mult * model_channels 388 | if ds in attention_resolutions: 389 | layers.append(AttentionBlock(ch, num_heads=num_heads)) 390 | self.down_blocks.append(TimestepEmbedSequential(*layers)) 391 | down_block_chans.append(ch) 392 | if level != len(channel_mult) - 1: # don't use downsample for the last stage 393 | self.down_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample))) 394 | down_block_chans.append(ch) 395 | ds *= 2 396 | 397 | # middle block 398 | self.middle_block = TimestepEmbedSequential( 399 | ResidualBlock(ch, ch, time_embed_dim, dropout), 400 | AttentionBlock(ch, num_heads=num_heads), 401 | ResidualBlock(ch, ch, time_embed_dim, dropout) 402 | ) 403 | 404 | # up blocks 405 | self.up_blocks = 
nn.ModuleList([]) 406 | for level, mult in list(enumerate(channel_mult))[::-1]: 407 | for i in range(num_res_blocks + 1): 408 | layers = [ 409 | ResidualBlock( 410 | ch + down_block_chans.pop(), 411 | model_channels * mult, 412 | time_embed_dim, 413 | dropout 414 | ) 415 | ] 416 | ch = model_channels * mult 417 | if ds in attention_resolutions: 418 | layers.append(AttentionBlock(ch, num_heads=num_heads)) 419 | if level and i == num_res_blocks: 420 | layers.append(Upsample(ch, conv_resample)) 421 | ds //= 2 422 | self.up_blocks.append(TimestepEmbedSequential(*layers)) 423 | 424 | self.out = nn.Sequential( 425 | norm_layer(ch), 426 | # nn.SiLU(), 427 | nn.ReLU(), 428 | nn.Conv2d(model_channels, out_channels, kernel_size=3, padding=1), 429 | ) 430 | 431 | def forward(self, x, t, y): 432 | """Apply the model to an input batch. 433 | 434 | Args: 435 | x (Tensor): [N x C x H x W] 436 | t (Tensor): [N,] a 1-D batch of timesteps. 437 | y (Tensor): [N,] LongTensor conditional labels. 438 | 439 | Returns: 440 | Tensor: [N x C x ...] 441 | """ 442 | # time step embedding 443 | t = self.time_embed(timestep_embedding(t, self.model_channels)) 444 | y = self.label_embedding(y) 445 | y = self.to_time_tokens(y) 446 | 447 | hs = [] 448 | # down stage 449 | h = x 450 | for module in self.down_blocks: 451 | h = module(h, t, y) 452 | hs.append(h) 453 | # middle stage 454 | h = self.middle_block(h, t, y) 455 | # up stage 456 | for module in self.up_blocks: 457 | cat_in = torch.cat([h, hs.pop()], dim=1) 458 | h = module(cat_in, t, y) 459 | return self.out(h) 460 | --------------------------------------------------------------------------------
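
For reference, a minimal smoke-test sketch of how the conditional pieces fit together (not part of the repository; it assumes it is run from inside `classifier_free_ddpm/` with torch, einops and einops_exts installed, and mirrors the MNIST configuration used in `train.py`):

```python
import torch

from net import UNetModel
from diffusion import GaussianDiffusion

# same MNIST configuration as classifier_free_ddpm/train.py: 1 channel, 28x28 images
model = UNetModel(
    in_channels=1,
    model_channels=96,
    out_channels=1,
    channel_mult=(1, 2, 2),
    attention_resolutions=[]
)
diffusion = GaussianDiffusion(timesteps=500)

x = torch.randn(4, 1, 28, 28)            # toy batch of "images" scaled to [-1, 1]
t = torch.randint(0, 500, (4,)).long()   # one diffusion timestep per example
y = torch.randint(0, 10, (4,))           # class labels 0-9

# the network predicts noise with the same shape as its input
assert model(x, t, y).shape == (4, 1, 28, 28)

# one training-style loss evaluation (simple eps-prediction MSE)
loss = diffusion.train_losses(model, x, t, y)
print(loss.item())
```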