├── model ├── __init__.py ├── models.py ├── block.py ├── wideresnet_noise_song.py ├── DDPM.py ├── EDM.py ├── unet.py └── augment.py ├── config ├── DDPM_ddpmpp.yaml ├── DDPM_ddpm.yaml ├── EDM_ddpmpp.yaml ├── EDM_ddpm.yaml └── EDM_ddpmpp_aug.yaml ├── extract_cifar10_pngs.ipynb ├── DiT ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── respace.py │ ├── timestep_sampler.py │ └── gaussian_diffusion.py ├── download.py ├── README.md ├── vae_preprocessing.py ├── linear.py └── models.py ├── utils.py ├── datasets.py ├── sample.py ├── contrastive.py ├── train.py ├── noisy_classifier_DDAE.py ├── linear.py ├── README.md └── noisy_classifier_WRN.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | from .DDPM import DDPM 2 | from .EDM import EDM 3 | from .unet import UNet 4 | 5 | CLASSES = { 6 | cls.__name__: cls 7 | for cls in [DDPM, EDM, UNet] 8 | } 9 | 10 | 11 | def get_models_class(model_type, net_type): 12 | return CLASSES[model_type], CLASSES[net_type] 13 | -------------------------------------------------------------------------------- /config/DDPM_ddpmpp.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'DDPM' 7 | net_type: 'UNet' 8 | diffusion: 9 | n_T: 1000 10 | betas: [1.0e-4, 0.02] 11 | network: 12 | image_shape: [3, 32, 32] 13 | n_channels: 128 14 | ch_mults: [2, 2, 2] 15 | is_attn: [False, True, False] 16 | dropout: 0.1 17 | n_blocks: 4 18 | use_res_for_updown: True 19 | 20 | # training params 21 | n_epoch: 2000 22 | batch_size: 64 23 | lrate: 1.0e-4 24 | warm_epoch: 13 25 | load_epoch: -1 26 | flip: True 27 | ema: 0.9999 28 | 29 | # testing params 30 | n_sample: 30 31 | save_dir: './output_DDPM_ddpmpp' 32 | save_model: True 33 | 34 | # linear probe 35 | linear: 36 | n_epoch: 15 37 | batch_size: 128 38 | lrate: 1.0e-3 39 | timestep: 11 40 | blockname: 'out_6' 41 | -------------------------------------------------------------------------------- /config/DDPM_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'DDPM' 7 | net_type: 'UNet' 8 | diffusion: 9 | n_T: 1000 10 | betas: [1.0e-4, 0.02] 11 | network: 12 | image_shape: [3, 32, 32] 13 | n_channels: 128 14 | ch_mults: [1, 2, 2, 2] 15 | is_attn: [False, True, False, False] 16 | dropout: 0.1 17 | n_blocks: 2 18 | use_res_for_updown: False 19 | 20 | # training params 21 | n_epoch: 2000 22 | batch_size: 128 23 | lrate: 1.0e-4 24 | warm_epoch: 13 25 | load_epoch: -1 26 | flip: True 27 | ema: 0.9999 28 | 29 | # testing params 30 | n_sample: 30 31 | save_dir: './output_DDPM_ddpm' 32 | save_model: True 33 | 34 | # linear probe 35 | linear: 36 | n_epoch: 15 37 | batch_size: 128 38 | lrate: 1.0e-3 39 | timestep: 11 40 | blockname: 'out_6' 41 | -------------------------------------------------------------------------------- /config/EDM_ddpmpp.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNet' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 
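# The diffusion parameters in this config follow the EDM formulation (Karras et al., 2022):
# sigma_min/sigma_max and rho set the sampling noise range and step spacing,
# p_mean/p_std define the log-normal training noise distribution, and
# S_min/S_max/S_noise control the stochastic sampler.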
13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [2, 2, 2] 22 | is_attn: [False, True, False] 23 | dropout: 0.13 24 | n_blocks: 4 25 | use_res_for_updown: True 26 | 27 | # training params 28 | n_epoch: 2000 29 | batch_size: 64 30 | lrate: 1.0e-4 31 | warm_epoch: 200 32 | load_epoch: -1 33 | flip: True 34 | ema: 0.9993 35 | 36 | # testing params 37 | n_sample: 30 38 | save_dir: './output_EDM_ddpmpp' 39 | save_model: True 40 | 41 | # linear probe 42 | linear: 43 | n_epoch: 15 44 | batch_size: 128 45 | lrate: 1.0e-3 46 | timestep: 4 47 | blockname: 'out_7' 48 | -------------------------------------------------------------------------------- /config/EDM_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNet' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, True, False, False] 23 | dropout: 0.13 24 | n_blocks: 2 25 | use_res_for_updown: False 26 | 27 | # training params 28 | n_epoch: 2000 29 | batch_size: 128 30 | lrate: 1.0e-4 31 | warm_epoch: 200 32 | load_epoch: -1 33 | flip: True 34 | ema: 0.9993 35 | 36 | # testing params 37 | n_sample: 30 38 | save_dir: './output_EDM_ddpm' 39 | save_model: True 40 | 41 | # linear probe 42 | linear: 43 | n_epoch: 15 44 | batch_size: 128 45 | lrate: 1.0e-3 46 | timestep: 4 47 | blockname: 'out_7' 48 | -------------------------------------------------------------------------------- /config/EDM_ddpmpp_aug.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNet' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | augment_prob: 0.12 19 | network: 20 | image_shape: [3, 32, 32] 21 | n_channels: 128 22 | ch_mults: [2, 2, 2] 23 | is_attn: [False, True, False] 24 | dropout: 0.13 25 | n_blocks: 4 26 | use_res_for_updown: True 27 | augment_dim: 9 28 | 29 | # training params 30 | n_epoch: 4000 31 | batch_size: 64 32 | lrate: 1.0e-4 33 | warm_epoch: 200 34 | load_epoch: -1 35 | flip: True 36 | ema: 0.9993 37 | 38 | # testing params 39 | n_sample: 30 40 | save_dir: './output_EDM_ddpmpp_aug' 41 | save_model: True 42 | 43 | # linear probe 44 | linear: 45 | n_epoch: 15 46 | batch_size: 128 47 | lrate: 1.0e-3 48 | timestep: 4 49 | blockname: 'out_7' 50 | -------------------------------------------------------------------------------- /extract_cifar10_pngs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torchvision.datasets import CIFAR10\n", 11 | "import os\n", 12 | "\n", 13 | "train_set = CIFAR10(\"./data\", train=True, download=True)\n", 14 | "print(\"CIFAR10 train dataset:\", len(train_set))\n", 15 | "\n", 16 | "images = []\n", 17 | "labels = []\n", 18 | "for img, label in train_set:\n", 19 
| " images.append(img)\n", 20 | " labels.append(label)\n", 21 | "\n", 22 | "labels = torch.tensor(labels)\n", 23 | "for i in range(10):\n", 24 | " assert (labels == i).sum() == 5000\n", 25 | "\n", 26 | "output_dir = \"./data/cifar10-pngs/\"\n", 27 | "for i, pil in enumerate(images):\n", 28 | " pil.save(os.path.join(output_dir, \"{:05d}.png\".format(i)))" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "Python 3.8.5 ('gan')", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "codemirror_mode": { 40 | "name": "ipython", 41 | "version": 3 42 | }, 43 | "file_extension": ".py", 44 | "mimetype": "text/x-python", 45 | "name": "python", 46 | "nbconvert_exporter": "python", 47 | "pygments_lexer": "ipython3", 48 | "version": "3.8.5" 49 | }, 50 | "orig_nbformat": 4, 51 | "vscode": { 52 | "interpreter": { 53 | "hash": "da18559f301618e6e9fab00c6d05e566e4e63dfec8a595f965f0c783b8f75048" 54 | } 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | -------------------------------------------------------------------------------- /DiT/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /DiT/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Functions for downloading pre-trained DiT models 9 | """ 10 | from torchvision.datasets.utils import download_url 11 | import torch 12 | import os 13 | 14 | 15 | pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} 16 | 17 | 18 | def find_model(model_name): 19 | """ 20 | Finds a pre-trained DiT model, downloading it if necessary. 
Alternatively, loads a model from a local path. 21 | """ 22 | if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints 23 | return download_model(model_name) 24 | else: # Load a custom DiT checkpoint: 25 | assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}' 26 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 27 | if "ema" in checkpoint: # supports checkpoints from train.py 28 | checkpoint = checkpoint["ema"] 29 | return checkpoint 30 | 31 | 32 | def download_model(model_name): 33 | """ 34 | Downloads a pre-trained DiT model from the web. 35 | """ 36 | assert model_name in pretrained_models 37 | local_path = f'pretrained_models/{model_name}' 38 | if not os.path.isfile(local_path): 39 | os.makedirs('pretrained_models', exist_ok=True) 40 | web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' 41 | download_url(web_path, 'pretrained_models') 42 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 43 | return model 44 | 45 | 46 | if __name__ == "__main__": 47 | # Download all DiT checkpoints 48 | for model in pretrained_models: 49 | download_model(model) 50 | print('Done.') 51 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data import DataLoader 7 | 8 | # ===== Configs ===== 9 | 10 | class Config(object): 11 | def __init__(self, dic): 12 | for key in dic: 13 | setattr(self, key, dic[key]) 14 | 15 | def get_optimizer(parameters, opt, lr): 16 | if not hasattr(opt, 'optim'): 17 | return torch.optim.Adam(parameters, lr=lr) 18 | elif opt.optim == 'AdamW': 19 | return torch.optim.AdamW(parameters, **opt.optim_args, lr=lr) 20 | else: 21 | raise NotImplementedError() 22 | 23 | # ===== Multi-GPU training ===== 24 | 25 | def init_seeds(RANDOM_SEED=1337, no=0): 26 | RANDOM_SEED += no 27 | print("local_rank = {}, seed = {}".format(no, RANDOM_SEED)) 28 | random.seed(RANDOM_SEED) 29 | np.random.seed(RANDOM_SEED) 30 | torch.manual_seed(RANDOM_SEED) 31 | torch.cuda.manual_seed_all(RANDOM_SEED) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | 36 | def reduce_tensor(tensor): 37 | rt = tensor.clone() 38 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 39 | rt /= dist.get_world_size() 40 | return rt 41 | 42 | 43 | def gather_tensor(tensor): 44 | tensor_list = [tensor.clone() for _ in range(dist.get_world_size())] 45 | dist.all_gather(tensor_list, tensor) 46 | tensor_list = torch.cat(tensor_list, dim=0) 47 | return tensor_list 48 | 49 | 50 | def DataLoaderDDP(dataset, batch_size, shuffle=True): 51 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) 52 | dataloader = DataLoader( 53 | dataset, 54 | batch_size=batch_size, 55 | sampler=sampler, 56 | num_workers=1, 57 | ) 58 | return dataloader, sampler 59 | 60 | def print0(*args, **kwargs): 61 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 62 | print(*args, **kwargs) 63 | -------------------------------------------------------------------------------- /DiT/README.md: -------------------------------------------------------------------------------- 1 | ## DDAE/DiT 2 | 3 | This subfolder contains transfer learning evaluation for ImageNet-256 pre-trained 
[DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint, by:
4 | - evaluating:
5 |   - [x] Linear probing
6 |   - [ ] Fine-tuning
7 | - performance on these datasets:
8 |   - [x] CIFAR-10
9 |   - [x] Tiny-ImageNet
10 | 
11 | This implementation uses very small batch sizes, lightweight data augmentations, and a standard Adam optimizer, without advanced optimizers (e.g., LARS) or large-batch training. Incorporating these modern tricks may further improve performance.
12 | 
13 | ## Main results
14 | The pre-trained DiT-XL/2 is expected to achieve $85.73$ % linear probing accuracy on CIFAR-10, and $66.57$ % on Tiny-ImageNet.
15 | 
16 | ## Usage
17 | ### Data pre-processing
18 | Since DiT operates in the latent space, we need to resize the images to $256\times256$ and generate their latent codes (shape: $(4,32,32)$) through the VAE encoder.
19 | 
20 | To reduce the computational cost during training, we use `vae_preprocessing.py` to pre-compute and cache the latent codes to files. Since data augmentations are essential for effective discriminative learning, we generate multiple versions (by default, 10) of the latent codes to cover different variations of the augmented images. Please refer to `vae_preprocessing.py` for more details.
21 | 
22 | ```sh
23 | python -m torch.distributed.launch --nproc_per_node=4
24 | # pre-processing with VAE encoding
25 | vae_preprocessing.py --dataset cifar --use_amp
26 | vae_preprocessing.py --dataset tiny --use_amp
27 | ```
28 | 
29 | ### Linear probing
30 | To linear-probe the features produced by the pre-trained DiT, run, for example:
31 | ```sh
32 | python -m torch.distributed.launch --nproc_per_node=4
33 | # linear probing with default layer-noise combination
34 | linear.py --dataset cifar --use_amp
35 | linear.py --dataset tiny --use_amp
36 | ```
37 | Note that this implementation loads ALL versions of the augmented dataset (by default, 10) into memory, and hence requires A LOT OF memory to run (e.g., 50 GB for CIFAR-10, 80 GB for Tiny-ImageNet).
38 | If you don't have enough memory, you can instead dump each latent code into a standalone numpy file and load it only when needed.
39 | 
40 | ## Acknowledgments
41 | Except for `vae_preprocessing.py` and `linear.py`, all code is retrieved or modified from the official [DiT](https://github.com/facebookresearch/DiT) repository.
42 | 
--------------------------------------------------------------------------------
/DiT/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | 
6 | import torch as th
7 | import numpy as np
8 | 
9 | 
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 |     """
12 |     Compute the KL divergence between two gaussians.
13 |     Shapes are automatically broadcasted, so batches can be compared to
14 |     scalars, among other use cases.
15 |     """
16 |     tensor = None
17 |     for obj in (mean1, logvar1, mean2, logvar2):
18 |         if isinstance(obj, th.Tensor):
19 |             tensor = obj
20 |             break
21 |     assert tensor is not None, "at least one argument must be a Tensor"
22 | 
23 |     # Force variances to be Tensors. Broadcasting helps convert scalars to
24 |     # Tensors, but it does not work for th.exp().
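    # The expression returned below is the closed-form KL between two diagonal Gaussians:
    #   KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
    #     = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
    #              + (mean1 - mean2)^2 * exp(-logvar2) - 1)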
25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from torchvision.datasets import CIFAR10, CIFAR100 7 | 8 | 9 | class TinyImageNet(Dataset): 10 | def __init__(self, root, train=True, transform=None): 11 | if not root.endswith("tiny-imagenet-200"): 12 | root = os.path.join(root, "tiny-imagenet-200") 13 | self.train_dir = os.path.join(root, "train") 14 | self.val_dir = os.path.join(root, "val") 15 | self.transform = transform 16 | if train: 17 | self._scan_train() 18 | else: 19 | self._scan_val() 20 | 21 | def _scan_train(self): 22 | classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()] 23 | classes = sorted(classes) 24 | assert len(classes) == 200 25 | 26 | self.data = [] 27 | for idx, name in enumerate(classes): 28 | this_dir = os.path.join(self.train_dir, name) 29 | for root, _, files in sorted(os.walk(this_dir)): 30 | for fname in sorted(files): 31 | if fname.endswith(".JPEG"): 32 | path = os.path.join(root, fname) 33 | item = (path, idx) 34 | self.data.append(item) 35 | 
self.labels_dict = {i: classes[i] for i in range(len(classes))} 36 | 37 | def _scan_val(self): 38 | self.file_to_class = {} 39 | classes = set() 40 | with open(os.path.join(self.val_dir, "val_annotations.txt"), 'r') as f: 41 | lines = f.readlines() 42 | for line in lines: 43 | words = line.split("\t") 44 | self.file_to_class[words[0]] = words[1] 45 | classes.add(words[1]) 46 | classes = sorted(list(classes)) 47 | assert len(classes) == 200 48 | 49 | class_to_idx = {classes[i]: i for i in range(len(classes))} 50 | self.data = [] 51 | this_dir = os.path.join(self.val_dir, "images") 52 | for root, _, files in sorted(os.walk(this_dir)): 53 | for fname in sorted(files): 54 | if fname.endswith(".JPEG"): 55 | path = os.path.join(root, fname) 56 | idx = class_to_idx[self.file_to_class[fname]] 57 | item = (path, idx) 58 | self.data.append(item) 59 | self.labels_dict = {i: classes[i] for i in range(len(classes))} 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, idx): 65 | path, label = self.data[idx] 66 | image = Image.open(path) 67 | image = image.convert("RGB") 68 | 69 | if self.transform: 70 | image = self.transform(image) 71 | 72 | return image, label 73 | 74 | 75 | def get_dataset(name, root="./data", train=True, flip=False, crop=False, resize=None): 76 | if name == 'cifar': 77 | DATASET = CIFAR10 78 | RES = 32 79 | elif name == 'cifar100': 80 | DATASET = CIFAR100 81 | RES = 32 82 | elif name == 'tiny': 83 | DATASET = TinyImageNet 84 | RES = 64 85 | else: 86 | raise NotImplementedError 87 | 88 | tf = [transforms.ToTensor()] 89 | if resize is not None: 90 | tf = [transforms.Resize(resize)] + tf 91 | if train: 92 | if crop: 93 | tf = [transforms.RandomCrop(RES, 4)] + tf 94 | if flip: 95 | tf = [transforms.RandomHorizontalFlip()] + tf 96 | 97 | return DATASET(root=root, train=train, transform=transforms.Compose(tf)) 98 | -------------------------------------------------------------------------------- /DiT/vae_preprocessing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.cuda.amp import autocast as autocast 9 | from torchvision.utils import save_image 10 | 11 | from diffusers.models import AutoencoderKL 12 | 13 | import sys 14 | sys.path.append("..") 15 | from datasets import get_dataset 16 | from utils import init_seeds, gather_tensor, DataLoaderDDP, print0 17 | 18 | 19 | def show(imgs, title="debug.png"): 20 | save_image(imgs, title, normalize=True, value_range=(0, 1)) 21 | 22 | 23 | def main(opt): 24 | name = opt.dataset 25 | local_rank = opt.local_rank 26 | num_copies = opt.num_copies 27 | use_amp = opt.use_amp 28 | 29 | save_dir = os.path.join('./latent_codes', name) 30 | if local_rank == 0: 31 | os.makedirs(save_dir, exist_ok=False) 32 | 33 | device = "cuda:%d" % local_rank 34 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) 35 | 36 | def encode(img): 37 | with torch.no_grad(): 38 | code = vae.encode(img.to(device) * 2 - 1) 39 | return 0.18215 * code.latent_dist.sample() 40 | 41 | def decode(code): 42 | with torch.no_grad(): 43 | recon = vae.decode(code / 0.18215).sample.cpu() 44 | return (recon + 1) / 2 45 | 46 | train_dataset = get_dataset(name, root="../data", train=True, resize=256, flip=True, crop=True) 47 | test_dataset = get_dataset(name, root="../data", train=False, resize=256) 48 | for dataset, epochs, string in 
[(train_dataset, num_copies, 'train'), (test_dataset, 1, 'test')]: 49 | loader, sampler = DataLoaderDDP( 50 | dataset, 51 | batch_size=1, 52 | shuffle=False, 53 | ) 54 | 55 | for ep in range(epochs): 56 | sampler.set_epoch(ep) 57 | data = [] 58 | label = [] 59 | for i, (x, y) in enumerate(tqdm(loader, disable=(local_rank != 0))): 60 | x = x.to(device) 61 | y = y.to(device) 62 | with autocast(enabled=use_amp): 63 | code = encode(x).float() 64 | if local_rank == 0 and i == 0: 65 | # for visualization and debugging 66 | recon = decode(code).float() 67 | show(x, f"{string}_debug_original_{ep}.png") 68 | show(recon, f"{string}_debug_reconstruct_{ep}.png") 69 | 70 | dist.barrier() 71 | code = gather_tensor(code).cpu() 72 | data.append(code) 73 | if ep == 0: 74 | y = gather_tensor(y).cpu() 75 | label.append(y) 76 | 77 | if local_rank == 0: 78 | data = torch.cat(data) 79 | with open(os.path.join(save_dir, f"{string}_code_{ep}.npy"), 'wb') as f: 80 | np.save(f, data.numpy()) 81 | if ep == 0: 82 | label = torch.cat(label) 83 | with open(os.path.join(save_dir, f"{string}_label.npy"), 'wb') as f: 84 | np.save(f, label.numpy()) 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--dataset", default='cifar', type=str, choices=['cifar', 'tiny']) 90 | parser.add_argument('--num_copies', default=10, type=int, 91 | help='number of training data copies, higher = more augmentation variations') 92 | parser.add_argument('--local_rank', default=-1, type=int, 93 | help='node rank for distributed training') 94 | parser.add_argument("--use_amp", action='store_true', default=False) 95 | opt = parser.parse_args() 96 | print0(opt) 97 | 98 | init_seeds(no=opt.local_rank) 99 | dist.init_process_group(backend='nccl') 100 | torch.cuda.set_device(opt.local_rank) 101 | main(opt) 102 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import yaml 7 | from torchvision.utils import make_grid, save_image 8 | from ema_pytorch import EMA 9 | 10 | from model.models import get_models_class 11 | from utils import Config, init_seeds, gather_tensor, print0 12 | 13 | 14 | def get_default_steps(model_type, steps): 15 | if steps is not None: 16 | return steps 17 | else: 18 | return {'DDPM': 100, 'EDM': 18}[model_type] 19 | 20 | 21 | # ===== sampling ===== 22 | 23 | def sample(opt): 24 | yaml_path = opt.config 25 | local_rank = opt.local_rank 26 | use_amp = opt.use_amp 27 | mode = opt.mode 28 | steps = opt.steps 29 | eta = opt.eta 30 | batches = opt.batches 31 | ep = opt.epoch 32 | 33 | with open(yaml_path, 'r') as f: 34 | opt = yaml.full_load(f) 35 | print0(opt) 36 | opt = Config(opt) 37 | if ep == -1: 38 | ep = opt.n_epoch - 1 39 | 40 | device = "cuda:%d" % local_rank 41 | steps = get_default_steps(opt.model_type, steps) 42 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 43 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 44 | **opt.diffusion, 45 | device=device, 46 | ) 47 | diff.to(device) 48 | 49 | target = os.path.join(opt.save_dir, "ckpts", f"model_{ep}.pth") 50 | print0("loading model at", target) 51 | checkpoint = torch.load(target, map_location=device) 52 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 53 | ema.to(device) 54 | ema.load_state_dict(checkpoint['EMA']) 55 | model = ema.ema_model 56 | model.eval() 57 | 
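    # Each rank generates 400 // world_size images per batch below; rank 0 creates the
    # output directories, gathers every rank's samples, saves a preview grid per batch,
    # and finally writes out the individual PNGs.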
58 | if local_rank == 0: 59 | if opt.model_type == 'EDM': 60 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_edm_steps{steps}_eta{eta}") 61 | else: 62 | if mode == 'DDPM': 63 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_ddpm") 64 | else: 65 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_ddim_steps{steps}_eta{eta}") 66 | os.makedirs(gen_dir) 67 | gen_dir_png = os.path.join(gen_dir, "pngs") 68 | os.makedirs(gen_dir_png) 69 | res = [] 70 | 71 | for batch in range(batches): 72 | with torch.no_grad(): 73 | assert 400 % dist.get_world_size() == 0 74 | samples_per_process = 400 // dist.get_world_size() 75 | args = dict(n_sample=samples_per_process, size=opt.network['image_shape'], notqdm=(local_rank != 0), use_amp=use_amp) 76 | if opt.model_type == 'EDM': 77 | x_gen = model.edm_sample(**args, steps=steps, eta=eta) 78 | else: 79 | if mode == 'DDPM': 80 | x_gen = model.sample(**args) 81 | else: 82 | x_gen = model.ddim_sample(**args, steps=steps, eta=eta) 83 | dist.barrier() 84 | x_gen = gather_tensor(x_gen).cpu() 85 | if local_rank == 0: 86 | res.append(x_gen) 87 | grid = make_grid(x_gen, nrow=20) 88 | png_path = os.path.join(gen_dir, f"grid_{batch}.png") 89 | save_image(grid, png_path) 90 | 91 | if local_rank == 0: 92 | res = torch.cat(res) 93 | for no, img in enumerate(res): 94 | png_path = os.path.join(gen_dir_png, f"{no}.png") 95 | save_image(img, png_path) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--config", type=str) 101 | parser.add_argument('--local_rank', default=-1, type=int, 102 | help='node rank for distributed training') 103 | parser.add_argument("--use_amp", action='store_true', default=False) 104 | parser.add_argument("--mode", type=str, choices=['DDPM', 'DDIM'], default='DDIM') 105 | parser.add_argument("--steps", type=int, default=None) 106 | parser.add_argument("--eta", type=float, default=0.0) 107 | parser.add_argument("--batches", type=int, default=125) 108 | parser.add_argument("--epoch", type=int, default=-1) 109 | opt = parser.parse_args() 110 | print0(opt) 111 | 112 | init_seeds(no=opt.local_rank) 113 | dist.init_process_group(backend='nccl') 114 | torch.cuda.set_device(opt.local_rank) 115 | sample(opt) 116 | -------------------------------------------------------------------------------- /model/block.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def GroupNorm32(channels): 8 | return nn.GroupNorm(32, channels) 9 | 10 | 11 | class TimeEmbedding(nn.Module): 12 | def __init__(self, n_channels, augment_dim): 13 | """ 14 | * `n_channels` is the number of dimensions in the embedding 15 | """ 16 | super().__init__() 17 | self.n_channels = n_channels 18 | self.aug_emb = nn.Linear(augment_dim, self.n_channels // 4, bias=False) if augment_dim > 0 else None 19 | 20 | self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels) 21 | self.act = nn.SiLU() 22 | self.lin2 = nn.Linear(self.n_channels, self.n_channels) 23 | 24 | def forward(self, t, aug_label): 25 | # Create sinusoidal position embeddings (same as those from the transformer) 26 | half_dim = self.n_channels // 8 27 | emb = math.log(10_000) / (half_dim - 1) 28 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -emb) 29 | emb = t.float()[:, None] * emb[None, :] 30 | emb = torch.cat((emb.sin(), emb.cos()), dim=1) 31 | 32 | if self.aug_emb is not None and aug_label is not 
None: 33 | emb += self.aug_emb(aug_label) 34 | 35 | # Transform with the MLP 36 | emb = self.act(self.lin1(emb)) 37 | emb = self.lin2(emb) 38 | return emb 39 | 40 | 41 | class AttentionBlock(nn.Module): 42 | def __init__(self, n_channels, d_k): 43 | """ 44 | * `n_channels` is the number of channels in the input 45 | * `n_heads` is the number of heads in multi-head attention 46 | * `d_k` is the number of dimensions in each head 47 | """ 48 | super().__init__() 49 | 50 | # Default `d_k` 51 | if d_k is None: 52 | d_k = n_channels 53 | n_heads = n_channels // d_k 54 | 55 | self.norm = GroupNorm32(n_channels) 56 | # Projections for query, key and values 57 | self.projection = nn.Linear(n_channels, n_heads * d_k * 3) 58 | # Linear layer for final transformation 59 | self.output = nn.Linear(n_heads * d_k, n_channels) 60 | 61 | self.scale = 1 / math.sqrt(math.sqrt(d_k)) 62 | self.n_heads = n_heads 63 | self.d_k = d_k 64 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 65 | print(f"{self.n_heads} heads, {self.d_k} channels per head") 66 | 67 | def forward(self, x): 68 | """ 69 | * `x` has shape `[batch_size, in_channels, height, width]` 70 | """ 71 | batch_size, n_channels, height, width = x.shape 72 | # Normalize and rearrange to `[batch_size, seq, n_channels]` 73 | h = self.norm(x).view(batch_size, n_channels, -1).permute(0, 2, 1) 74 | 75 | # {q, k, v} all have a shape of `[batch_size, seq, n_heads, d_k]` 76 | qkv = self.projection(h).view(batch_size, -1, self.n_heads, 3 * self.d_k) 77 | q, k, v = torch.chunk(qkv, 3, dim=-1) 78 | 79 | attn = torch.einsum('bihd,bjhd->bijh', q * self.scale, k * self.scale) # More stable with f16 than dividing afterwards 80 | attn = attn.softmax(dim=2) 81 | res = torch.einsum('bijh,bjhd->bihd', attn, v) 82 | 83 | # Reshape to `[batch_size, seq, n_heads * d_k]` and transform to `[batch_size, seq, n_channels]` 84 | res = res.reshape(batch_size, -1, self.n_heads * self.d_k) 85 | res = self.output(res) 86 | res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width) 87 | return res + x 88 | 89 | 90 | class Upsample(nn.Module): 91 | def __init__(self, n_channels, use_conv=True): 92 | super().__init__() 93 | self.use_conv = use_conv 94 | if use_conv: 95 | self.conv = nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=1, padding=1) 96 | 97 | def forward(self, x): 98 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 99 | if self.use_conv: 100 | return self.conv(x) 101 | else: 102 | return x 103 | 104 | 105 | class Downsample(nn.Module): 106 | def __init__(self, n_channels, use_conv=True): 107 | super().__init__() 108 | self.use_conv = use_conv 109 | if use_conv: 110 | self.conv = nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=2, padding=1) 111 | else: 112 | self.pool = nn.AvgPool2d(2) 113 | 114 | def forward(self, x): 115 | if self.use_conv: 116 | return self.conv(x) 117 | else: 118 | return self.pool(x) 119 | 120 | -------------------------------------------------------------------------------- /contrastive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import yaml 7 | from datasets import get_dataset 8 | from tqdm import tqdm 9 | from ema_pytorch import EMA 10 | 11 | from model.models import get_models_class 12 | from utils import Config, init_seeds, gather_tensor, DataLoaderDDP, print0 13 | 14 | 15 | def get_model(opt, load_epoch): 16 | DIFFUSION, NETWORK = 
get_models_class(opt.model_type, opt.net_type) 17 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 18 | **opt.diffusion, 19 | device=device, 20 | ) 21 | diff.to(device) 22 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth") 23 | print0("loading model at", target) 24 | checkpoint = torch.load(target, map_location=device) 25 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 26 | ema.to(device) 27 | ema.load_state_dict(checkpoint['EMA']) 28 | model = ema.ema_model 29 | model.eval() 30 | return model 31 | 32 | 33 | def alignment(x, y, alpha=2): 34 | return (x - y).norm(p=2, dim=1).pow(alpha).mean().item() 35 | 36 | def uniformity(x, t=2): 37 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log().item() 38 | 39 | 40 | class NamedMeter: 41 | def __init__(self): 42 | self.sum = {} 43 | self.count = {} 44 | self.history = {} 45 | 46 | def update(self, name, val, n=1): 47 | if name not in self.sum: 48 | self.sum[name] = 0 49 | self.count[name] = 0 50 | self.history[name] = [] 51 | 52 | self.sum[name] += val * n 53 | self.count[name] += n 54 | self.history[name].append("%.4f" % val) 55 | 56 | def get_avg(self, name): 57 | return self.sum[name] / self.count[name] 58 | 59 | def get_names(self): 60 | return self.sum.keys() 61 | 62 | 63 | def metrics(opt): 64 | yaml_path = opt.config 65 | interval = opt.epoch_interval 66 | use_amp = opt.use_amp 67 | with open(yaml_path, 'r') as f: 68 | opt = yaml.full_load(f) 69 | print0(opt) 70 | opt = Config(opt) 71 | timestep = opt.linear['timestep'] 72 | 73 | train_set_raw = get_dataset(name=opt.dataset, root="./data", train=True) 74 | train_loader_raw, _ = DataLoaderDDP( 75 | train_set_raw, 76 | batch_size=128, 77 | shuffle=False, 78 | ) 79 | 80 | check_epochs = list(range(interval, opt.n_epoch, interval)) + [opt.n_epoch - 1] 81 | align_evolving = NamedMeter() 82 | uniform_evolving = NamedMeter() 83 | 84 | print0("Using timestep =", timestep) 85 | print0("Checking epochs:", check_epochs) 86 | 87 | for load_epoch in check_epochs: 88 | model = get_model(opt, load_epoch) 89 | align_cur_epoch = NamedMeter() 90 | uniform_cur_epoch = NamedMeter() 91 | 92 | for image, _ in tqdm(train_loader_raw, disable=(local_rank!=0)): 93 | with torch.no_grad(): 94 | x = model.get_feature(image.to(device), timestep, norm=True, use_amp=use_amp) 95 | y = model.get_feature(image.to(device), timestep, norm=True, use_amp=use_amp) 96 | dist.barrier() 97 | x = {name: gather_tensor(x[name]).cpu() for name in x} 98 | y = {name: gather_tensor(y[name]).cpu() for name in y} 99 | 100 | for blockname in x: 101 | align = alignment(x[blockname].detach(), y[blockname].detach()) 102 | uniform = (uniformity(x[blockname]) + uniformity(y[blockname])) / 2 103 | # calculate metrics for a small batch 104 | align_cur_epoch.update(blockname, align, n=image.shape[0]) 105 | uniform_cur_epoch.update(blockname, uniform, n=image.shape[0]) 106 | 107 | # gather metrics for the complete dataset 108 | for blockname in align_cur_epoch.get_names(): 109 | align = align_cur_epoch.get_avg(blockname) 110 | uniform = uniform_cur_epoch.get_avg(blockname) 111 | # record metrics for each checkpoint 112 | align_evolving.update(blockname, align) 113 | uniform_evolving.update(blockname, uniform) 114 | 115 | if local_rank == 0: 116 | print(align_evolving.history.keys()) 117 | print('align metric:') 118 | for blockname in align_evolving.history: 119 | align = align_evolving.history[blockname] 120 | print("'%s': [%s]" % (blockname, ', '.join(align))) 121 | 122 | print('uniform 
metric:') 123 | for blockname in uniform_evolving.history: 124 | uniform = uniform_evolving.history[blockname] 125 | print("'%s': [%s]" % (blockname, ', '.join(uniform))) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--config", type=str) 131 | parser.add_argument('--epoch_interval', type=int, default=400) 132 | parser.add_argument('--local_rank', default=-1, type=int, 133 | help='node rank for distributed training') 134 | parser.add_argument("--use_amp", action='store_true', default=False) 135 | opt = parser.parse_args() 136 | print0(opt) 137 | 138 | local_rank = opt.local_rank 139 | init_seeds(no=local_rank) 140 | dist.init_process_group(backend='nccl') 141 | torch.cuda.set_device(local_rank) 142 | device = "cuda:%d" % local_rank 143 | 144 | metrics(opt) 145 | -------------------------------------------------------------------------------- /DiT/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
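    For example, space_timesteps(1000, "ddim50") keeps 50 evenly strided steps
    (stride 20) out of the original 1000 steps.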
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
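        # (_WrappedModel below maps the spaced timestep indices back to the original
        # schedule via timestep_map, so no extra scaling is needed here.)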
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import yaml 7 | from datasets import get_dataset 8 | from torchvision.utils import make_grid, save_image 9 | from tqdm import tqdm 10 | from ema_pytorch import EMA 11 | 12 | from model.models import get_models_class 13 | from utils import Config, get_optimizer, init_seeds, reduce_tensor, DataLoaderDDP, print0 14 | 15 | 16 | # ===== training ===== 17 | 18 | def train(opt): 19 | yaml_path = opt.config 20 | local_rank = opt.local_rank 21 | use_amp = opt.use_amp 22 | 23 | with open(yaml_path, 'r') as f: 24 | opt = yaml.full_load(f) 25 | print0(opt) 26 | opt = Config(opt) 27 | model_dir = os.path.join(opt.save_dir, "ckpts") 28 | vis_dir = os.path.join(opt.save_dir, "visual") 29 | if local_rank == 0: 30 | os.makedirs(model_dir, exist_ok=True) 31 | os.makedirs(vis_dir, exist_ok=True) 32 | 33 | device = "cuda:%d" % local_rank 34 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 35 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 36 | **opt.diffusion, 37 | device=device, 38 | ) 39 | diff.to(device) 40 | if local_rank == 0: 41 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 42 | ema.to(device) 43 | 44 | diff = torch.nn.SyncBatchNorm.convert_sync_batchnorm(diff) 45 | diff = torch.nn.parallel.DistributedDataParallel( 46 | diff, device_ids=[local_rank], output_device=local_rank) 47 | 48 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=opt.flip) 49 | print0("train dataset:", len(train_set)) 50 | 51 | train_loader, sampler = DataLoaderDDP(train_set, 52 | batch_size=opt.batch_size, 53 | shuffle=True) 54 | 55 | lr = opt.lrate 56 | DDP_multiplier = dist.get_world_size() 57 | print0("Using DDP, lr = %f * %d" % (lr, DDP_multiplier)) 58 | lr *= DDP_multiplier 59 | optim = get_optimizer(diff.parameters(), opt, lr=lr) 60 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 61 | 62 | if opt.load_epoch != -1: 63 | target = os.path.join(model_dir, f"model_{opt.load_epoch}.pth") 64 | print0("loading model at", target) 65 | checkpoint = torch.load(target, map_location=device) 66 | diff.load_state_dict(checkpoint['MODEL']) 67 | if local_rank == 0: 68 | ema.load_state_dict(checkpoint['EMA']) 69 | optim.load_state_dict(checkpoint['opt']) 70 | 71 | for ep in range(opt.load_epoch + 1, opt.n_epoch): 72 | for g in optim.param_groups: 73 | g['lr'] = lr * min((ep + 1.0) / opt.warm_epoch, 1.0) # warmup 74 | sampler.set_epoch(ep) 75 | dist.barrier() 76 | # training 77 | diff.train() 78 | if local_rank == 0: 79 | now_lr = optim.param_groups[0]['lr'] 80 | print(f'epoch {ep}, lr {now_lr:f}') 81 | loss_ema = None 82 | pbar = tqdm(train_loader) 83 | else: 84 | pbar = train_loader 85 | for x, c in pbar: 86 | 
optim.zero_grad() 87 | x = x.to(device) 88 | loss = diff(x, use_amp=use_amp) 89 | scaler.scale(loss).backward() 90 | scaler.unscale_(optim) 91 | torch.nn.utils.clip_grad_norm_(parameters=diff.parameters(), max_norm=1.0) 92 | scaler.step(optim) 93 | scaler.update() 94 | 95 | # logging 96 | dist.barrier() 97 | loss = reduce_tensor(loss) 98 | if local_rank == 0: 99 | ema.update() 100 | if loss_ema is None: 101 | loss_ema = loss.item() 102 | else: 103 | loss_ema = 0.95 * loss_ema + 0.05 * loss.item() 104 | pbar.set_description(f"loss: {loss_ema:.4f}") 105 | 106 | # testing 107 | if local_rank == 0: 108 | if ep % 100 == 0 or ep == opt.n_epoch - 1: 109 | pass 110 | else: 111 | continue 112 | 113 | if opt.model_type == 'DDPM': 114 | ema_sample_method = ema.ema_model.ddim_sample 115 | elif opt.model_type == 'EDM': 116 | ema_sample_method = ema.ema_model.edm_sample 117 | 118 | ema.ema_model.eval() 119 | with torch.no_grad(): 120 | x_gen = ema_sample_method(opt.n_sample, x.shape[1:]) 121 | # save an image of currently generated samples (top rows) 122 | # followed by real images (bottom rows) 123 | x_real = x[:opt.n_sample] 124 | x_all = torch.cat([x_gen.cpu(), x_real.cpu()]) 125 | grid = make_grid(x_all, nrow=10) 126 | 127 | save_path = os.path.join(vis_dir, f"image_ep{ep}_ema.png") 128 | save_image(grid, save_path) 129 | print('saved image at', save_path) 130 | 131 | # optionally save model 132 | if opt.save_model: 133 | checkpoint = { 134 | 'MODEL': diff.state_dict(), 135 | 'EMA': ema.state_dict(), 136 | 'opt': optim.state_dict(), 137 | } 138 | save_path = os.path.join(model_dir, f"model_{ep}.pth") 139 | torch.save(checkpoint, save_path) 140 | print('saved model at', save_path) 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--config", type=str) 146 | parser.add_argument('--local_rank', default=-1, type=int, 147 | help='node rank for distributed training') 148 | parser.add_argument("--use_amp", action='store_true', default=False) 149 | opt = parser.parse_args() 150 | print0(opt) 151 | 152 | init_seeds(no=opt.local_rank) 153 | dist.init_process_group(backend='nccl') 154 | torch.cuda.set_device(opt.local_rank) 155 | train(opt) 156 | -------------------------------------------------------------------------------- /DiT/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 
31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 |         """
118 | 
119 | 
120 | class LossSecondMomentResampler(LossAwareSampler):
121 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 |         self.diffusion = diffusion
123 |         self.history_per_term = history_per_term
124 |         self.uniform_prob = uniform_prob
125 |         self._loss_history = np.zeros(
126 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 |         )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 | 
130 |     def weights(self):
131 |         if not self._warmed_up():
132 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 |         weights /= np.sum(weights)
135 |         weights *= 1 - self.uniform_prob
136 |         weights += self.uniform_prob / len(weights)
137 |         return weights
138 | 
139 |     def update_with_all_losses(self, ts, losses):
140 |         for t, loss in zip(ts, losses):
141 |             if self._loss_counts[t] == self.history_per_term:
142 |                 # Shift out the oldest loss term.
143 |                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 |                 self._loss_history[t, -1] = loss
145 |             else:
146 |                 self._loss_history[t, self._loss_counts[t]] = loss
147 |                 self._loss_counts[t] += 1
148 | 
149 |     def _warmed_up(self):
150 |         return (self._loss_counts == self.history_per_term).all()
151 | 
--------------------------------------------------------------------------------
/model/wideresnet_noise_song.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/yang-song/score_sde/blob/main/models/wideresnet_noise_conditional.py
2 | # A PyTorch version of the noise-conditional classifier
3 | # proposed in https://arxiv.org/abs/2011.13456, Appendix I.1
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.nn.init as init
10 | 
11 | 
12 | def _weights_init(m):
13 |     if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
14 |         init.kaiming_normal_(m.weight)
15 | 
16 | 
17 | def activation(channels, apply_relu=True):
18 |     gn = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-5)
19 |     if apply_relu:
20 |         return nn.Sequential(gn, nn.ReLU(inplace=True))
21 |     return gn
22 | 
23 | 
24 | def _output_add(block_x, orig_x):
25 |     """Add two tensors, padding them with zeros or pooling them if necessary.
26 | 
27 |     Args:
28 |         block_x: Output of a resnet block.
29 |         orig_x: Residual branch to add to the output of the resnet block.
30 | 
31 |     Returns:
32 |         The sum of block_x and orig_x. If necessary, orig_x will be average pooled
33 |         or zero padded so that its shape matches block_x.
34 | """ 35 | stride = orig_x.shape[-2] // block_x.shape[-2] 36 | strides = (stride, stride) 37 | if block_x.shape[1] != orig_x.shape[1]: 38 | orig_x = F.avg_pool2d(orig_x, strides, strides) 39 | channels_to_add = block_x.shape[1] - orig_x.shape[1] 40 | orig_x = F.pad(orig_x, (0, 0, 0, 0, 0, channels_to_add)) 41 | return block_x + orig_x 42 | 43 | 44 | class GaussianFourierProjection(nn.Module): 45 | """Gaussian Fourier embeddings for noise levels.""" 46 | 47 | def __init__(self, embedding_size=256, scale=1.0): 48 | super().__init__() 49 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 50 | 51 | def forward(self, x): 52 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 53 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 54 | 55 | 56 | class WideResnetBlock(nn.Module): 57 | """Defines a single WideResnetBlock.""" 58 | 59 | def __init__(self, in_planes, planes, time_channels, stride=1, activate_before_residual=False): 60 | super().__init__() 61 | self.activate_before_residual = activate_before_residual 62 | 63 | self.init_bn = activation(in_planes) 64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 65 | self.bn_2 = activation(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 67 | 68 | # Linear layer for embeddings 69 | self.time_emb = nn.Sequential( 70 | nn.SiLU(), 71 | nn.Linear(time_channels, planes) 72 | ) 73 | 74 | def forward(self, x, temb): 75 | if self.activate_before_residual: 76 | x = self.init_bn(x) 77 | orig_x = x 78 | else: 79 | orig_x = x 80 | 81 | block_x = x 82 | if not self.activate_before_residual: 83 | block_x = self.init_bn(block_x) 84 | 85 | block_x = self.conv1(block_x) 86 | block_x += self.time_emb(temb)[:, :, None, None] 87 | 88 | block_x = self.bn_2(block_x) 89 | block_x = self.conv2(block_x) 90 | 91 | return _output_add(block_x, orig_x) 92 | 93 | 94 | class WideResnetGroup(nn.Module): 95 | """Defines a WideResnetGroup.""" 96 | 97 | def __init__(self, blocks_per_group, in_planes, planes, time_channels, stride=1, activate_before_residual=False): 98 | super().__init__() 99 | self.blocks_per_group = blocks_per_group 100 | 101 | self.blocks = nn.ModuleList() 102 | for i in range(self.blocks_per_group): 103 | if i == 0: 104 | blk = WideResnetBlock(in_planes, planes, time_channels, stride, activate_before_residual) 105 | else: 106 | blk = WideResnetBlock(planes, planes, time_channels, 1, False) 107 | self.blocks.append(blk) 108 | 109 | def forward(self, x, temb): 110 | for b in self.blocks: 111 | x = b(x, temb) 112 | return x 113 | 114 | 115 | class WideResnet(nn.Module): 116 | """Defines the WideResnet Model.""" 117 | 118 | def __init__(self, blocks_per_group, channel_multiplier, in_channels=3, num_classes=10): 119 | super().__init__() 120 | time_channels = 128 * 4 121 | self.time_emb = GaussianFourierProjection(embedding_size=time_channels // 4, scale=16) 122 | self.time_emb_mlp = nn.Sequential( 123 | nn.Linear(time_channels // 2, time_channels), 124 | nn.SiLU(), 125 | nn.Linear(time_channels, time_channels), 126 | ) 127 | self.init_conv = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 128 | self.group1 = WideResnetGroup(blocks_per_group, 129 | 16, 16 * channel_multiplier, 130 | time_channels, 131 | activate_before_residual=True) 132 | self.group2 = WideResnetGroup(blocks_per_group, 133 | 16 * channel_multiplier, 32 * channel_multiplier, 134 | time_channels, 135 | stride=2) 136 | self.group3 = 
WideResnetGroup(blocks_per_group, 137 | 32 * channel_multiplier, 64 * channel_multiplier, 138 | time_channels, 139 | stride=2) 140 | self.pre_pool_bn = activation(64 * channel_multiplier) 141 | self.final_linear = nn.Linear(64 * channel_multiplier, num_classes) 142 | 143 | self.apply(_weights_init) 144 | 145 | def forward(self, x, t): 146 | # per image standardization 147 | N = np.prod(x.shape[1:]) 148 | x = (x - x.mean(dim=(1,2,3), keepdim=True)) / torch.maximum(torch.std(x, dim=(1,2,3), keepdim=True), 1. / torch.tensor(np.sqrt(N))) 149 | 150 | temb = self.time_emb(t) 151 | temb = self.time_emb_mlp(temb) 152 | 153 | x = self.init_conv(x) 154 | x = self.group1(x, temb) 155 | x = self.group2(x, temb) 156 | x = self.group3(x, temb) 157 | x = self.pre_pool_bn(x) 158 | x = F.avg_pool2d(x, x.shape[-1]) 159 | x = x.view(x.shape[0], -1) 160 | x = self.final_linear(x) 161 | return x 162 | 163 | 164 | def test(net): 165 | import numpy as np 166 | total_params = 0 167 | 168 | for x in filter(lambda p: p.requires_grad, net.parameters()): 169 | total_params += np.prod(x.data.numpy().shape) 170 | print("Total number of params", total_params) 171 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 172 | 173 | 174 | def wide_28_10_song(in_channels=3, num_classes=10): 175 | net = WideResnet(blocks_per_group=4, channel_multiplier=10, in_channels=in_channels, num_classes=num_classes) 176 | test(net) 177 | return net 178 | -------------------------------------------------------------------------------- /noisy_classifier_DDAE.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from functools import partial 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import yaml 8 | import torch.nn as nn 9 | from datasets import get_dataset 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from tqdm import tqdm 12 | from ema_pytorch import EMA 13 | 14 | from model.models import get_models_class 15 | from model.block import TimeEmbedding 16 | from utils import Config, init_seeds, reduce_tensor, gather_tensor, DataLoaderDDP, print0 17 | 18 | 19 | def get_model(opt, load_epoch): 20 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 21 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 22 | **opt.diffusion, 23 | device=device, 24 | ) 25 | diff.to(device) 26 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth") 27 | print0("loading model at", target) 28 | checkpoint = torch.load(target, map_location=device) 29 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 30 | ema.to(device) 31 | ema.load_state_dict(checkpoint['EMA']) 32 | model = ema.ema_model 33 | model.eval() 34 | return model 35 | 36 | ''' Train a two-layer noise-conditional MLP classifier. 37 | This training script is similar to `linear.py` which performs linear probing test. 
38 | ''' 39 | 40 | class Classifier(nn.Module): 41 | def __init__(self, feat_func, blockname, dim, num_classes): 42 | super(Classifier, self).__init__() 43 | self.feat_func = feat_func 44 | self.blockname = blockname 45 | self.time_emb = TimeEmbedding(dim, augment_dim=0) 46 | self.cls = nn.Sequential( 47 | nn.Linear(dim, 2 * dim), 48 | nn.SiLU(), 49 | nn.Linear(2 * dim, num_classes) 50 | ) 51 | 52 | def forward(self, x, t): 53 | with torch.no_grad(): 54 | x = self.feat_func(x.to(device), t=t) 55 | x = x[self.blockname].detach() 56 | return self.cls(x + self.time_emb(t, aug_label=None)) 57 | 58 | 59 | class DDPM: 60 | def __init__(self, device, n_T=1000, steps=20): 61 | self.device = device 62 | self.n_T = n_T 63 | self.test_timesteps = (torch.arange(0, self.n_T, self.n_T // steps) + 1).long().tolist() 64 | 65 | def train(self, x): 66 | _t = torch.randint(1, self.n_T + 1, (x.shape[0], )) 67 | return x, _t.to(self.device) 68 | 69 | def test(self, x, t): 70 | _t = torch.full((x.shape[0], ), t) 71 | return x, _t.to(self.device) 72 | 73 | 74 | class EDM: 75 | def __init__(self, device, steps=18): 76 | self.device = device 77 | self.steps = steps 78 | self.test_timesteps = range(1, steps + 1) 79 | 80 | def train(self, x): 81 | _t = torch.randint(1, self.steps + 1, (x.shape[0], )) 82 | return x, _t.to(self.device) 83 | 84 | def test(self, x, t): 85 | _t = torch.full((x.shape[0], ), t) 86 | return x, _t.to(self.device) 87 | 88 | 89 | def train(opt): 90 | def test(t): 91 | preds = [] 92 | labels = [] 93 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)): 94 | with torch.no_grad(): 95 | model.eval() 96 | logit = model(*diff.test(image, t)) 97 | pred = logit.argmax(dim=-1) 98 | preds.append(pred) 99 | labels.append(label.to(device)) 100 | 101 | pred = torch.cat(preds) 102 | label = torch.cat(labels) 103 | dist.barrier() 104 | pred = gather_tensor(pred) 105 | label = gather_tensor(label) 106 | acc = (pred == label).sum().item() / len(label) 107 | return acc 108 | 109 | yaml_path = opt.config 110 | ep = opt.epoch 111 | use_amp = opt.use_amp 112 | with open(yaml_path, 'r') as f: 113 | opt = yaml.full_load(f) 114 | print0(opt) 115 | opt = Config(opt) 116 | if ep == -1: 117 | ep = opt.n_epoch - 1 118 | model = get_model(opt, ep) 119 | 120 | epoch = opt.linear['n_epoch'] 121 | batch_size = opt.linear['batch_size'] 122 | base_lr = opt.linear['lrate'] 123 | blockname = opt.linear['blockname'] 124 | 125 | mode = opt.model_type 126 | if mode == 'DDPM': 127 | diff = DDPM(device) 128 | elif mode == 'EDM': 129 | diff = EDM(device) 130 | else: 131 | raise NotImplementedError 132 | 133 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=True, crop=True) 134 | valid_set = get_dataset(name=opt.dataset, root="./data", train=False) 135 | train_loader, sampler = DataLoaderDDP( 136 | train_set, 137 | batch_size=batch_size, 138 | shuffle=True, 139 | ) 140 | valid_loader, _ = DataLoaderDDP( 141 | valid_set, 142 | batch_size=batch_size, 143 | shuffle=False, 144 | ) 145 | 146 | # define a two-layer noise-conditional MLP classifier 147 | feat_func = partial(model.get_feature, norm=False, use_amp=use_amp) 148 | with torch.no_grad(): 149 | x = feat_func(next(iter(valid_loader))[0].to(device), t=0) 150 | print0("All block names:", x.keys()) 151 | print0("Using block:", blockname) 152 | 153 | dim = x[blockname].shape[-1] 154 | model = Classifier(feat_func, blockname, dim, opt.classes).to(device) 155 | model = torch.nn.parallel.DistributedDataParallel( 156 | model, device_ids=[local_rank], 
output_device=local_rank) 157 | 158 | # train classifier 159 | loss_fn = nn.CrossEntropyLoss() 160 | DDP_multiplier = dist.get_world_size() 161 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier)) 162 | base_lr *= DDP_multiplier 163 | optim = torch.optim.Adam(model.parameters(), lr=base_lr) 164 | scheduler = CosineAnnealingLR(optim, epoch) 165 | for e in range(epoch): 166 | sampler.set_epoch(e) 167 | pbar = tqdm(train_loader, disable=(local_rank!=0)) 168 | for i, (image, label) in enumerate(pbar): 169 | model.train() 170 | logit = model(*diff.train(image)) 171 | label = label.to(device) 172 | loss = loss_fn(logit, label) 173 | optim.zero_grad() 174 | loss.backward() 175 | optim.step() 176 | 177 | # logging 178 | dist.barrier() 179 | loss = reduce_tensor(loss) 180 | logit = gather_tensor(logit).cpu() 181 | label = gather_tensor(label).cpu() 182 | 183 | if local_rank == 0: 184 | pred = logit.argmax(dim=-1) 185 | acc = (pred == label).sum().item() / len(label) 186 | nowlr = optim.param_groups[0]['lr'] 187 | pbar.set_description("[epoch %d / iter %d]: lr %.1e loss: %.3f, acc: %.3f" % (e, i, nowlr, loss.item(), acc)) 188 | scheduler.step() 189 | 190 | accs = {} 191 | for t in diff.test_timesteps: 192 | test_acc = test(t) 193 | print0("[timestep %d]: Test acc: %.3f" % (t, test_acc)) 194 | accs[t] = test_acc 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument("--config", type=str) 200 | parser.add_argument("--epoch", type=int, default=-1) 201 | parser.add_argument('--local_rank', default=-1, type=int, 202 | help='node rank for distributed training') 203 | parser.add_argument("--use_amp", action='store_true', default=False) 204 | opt = parser.parse_args() 205 | print0(opt) 206 | 207 | local_rank = opt.local_rank 208 | init_seeds(no=local_rank) 209 | dist.init_process_group(backend='nccl') 210 | torch.cuda.set_device(local_rank) 211 | device = "cuda:%d" % local_rank 212 | 213 | train(opt) 214 | -------------------------------------------------------------------------------- /linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from functools import partial 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import yaml 8 | import torch.nn as nn 9 | from datasets import get_dataset 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from tqdm import tqdm 12 | from ema_pytorch import EMA 13 | 14 | from model.models import get_models_class 15 | from utils import Config, init_seeds, gather_tensor, DataLoaderDDP, print0 16 | 17 | 18 | def get_model(opt, load_epoch): 19 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 20 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 21 | **opt.diffusion, 22 | device=device, 23 | ) 24 | diff.to(device) 25 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth") 26 | print0("loading model at", target) 27 | checkpoint = torch.load(target, map_location=device) 28 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 29 | ema.to(device) 30 | ema.load_state_dict(checkpoint['EMA']) 31 | model = ema.ema_model 32 | model.eval() 33 | return model 34 | 35 | 36 | class ClassifierDict(nn.Module): 37 | def __init__(self, feat_func, time_list, name_list, base_lr, epoch, img_shape, local_rank, num_classes): 38 | super(ClassifierDict, self).__init__() 39 | self.feat_func = feat_func 40 | self.times = time_list 41 | self.names = name_list 42 | self.classifiers = 
nn.ModuleDict() 43 | self.optims = {} 44 | self.schedulers = {} 45 | self.loss_fn = nn.CrossEntropyLoss() 46 | 47 | for time in self.times: 48 | feats = self.feat_func(torch.zeros(1, *img_shape).to(device), time) 49 | if self.names is None: 50 | self.names = list(feats.keys()) # all available names 51 | 52 | for name in self.names: 53 | key = self.make_key(time, name) 54 | layers = nn.Linear(feats[name].shape[1], num_classes) 55 | layers = torch.nn.parallel.DistributedDataParallel( 56 | layers.to(device), device_ids=[local_rank], output_device=local_rank) 57 | optimizer = torch.optim.Adam(layers.parameters(), lr=base_lr) 58 | scheduler = CosineAnnealingLR(optimizer, epoch) 59 | self.classifiers[key] = layers 60 | self.optims[key] = optimizer 61 | self.schedulers[key] = scheduler 62 | 63 | def train(self, x, y): 64 | self.classifiers.train() 65 | for time in self.times: 66 | feats = self.feat_func(x, time) 67 | for name in self.names: 68 | key = self.make_key(time, name) 69 | representation = feats[name].detach() 70 | logit = self.classifiers[key](representation) 71 | loss = self.loss_fn(logit, y) 72 | 73 | self.optims[key].zero_grad() 74 | loss.backward() 75 | self.optims[key].step() 76 | 77 | def test(self, x): 78 | outputs = {} 79 | with torch.no_grad(): 80 | self.classifiers.eval() 81 | for time in self.times: 82 | feats = self.feat_func(x, time) 83 | for name in self.names: 84 | key = self.make_key(time, name) 85 | representation = feats[name].detach() 86 | logit = self.classifiers[key](representation) 87 | pred = logit.argmax(dim=-1) 88 | outputs[key] = pred 89 | return outputs 90 | 91 | def make_key(self, t, n): 92 | return str(t) + '/' + n 93 | 94 | def get_lr(self): 95 | key = self.make_key(self.times[0], self.names[0]) 96 | optim = self.optims[key] 97 | return optim.param_groups[0]['lr'] 98 | 99 | def schedule_step(self): 100 | for time in self.times: 101 | for name in self.names: 102 | key = self.make_key(time, name) 103 | self.schedulers[key].step() 104 | 105 | 106 | def train(opt): 107 | def test(): 108 | preds = {k: [] for k in classifiers.optims.keys()} 109 | accs = {} 110 | labels = [] 111 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)): 112 | outputs = classifiers.test(image.to(device)) 113 | for key in outputs: 114 | preds[key].append(outputs[key]) 115 | labels.append(label.to(device)) 116 | 117 | for key in preds: 118 | preds[key] = torch.cat(preds[key]) 119 | label = torch.cat(labels) 120 | dist.barrier() 121 | label = gather_tensor(label) 122 | for key in preds: 123 | pred = gather_tensor(preds[key]) 124 | accs[key] = (pred == label).sum().item() / len(label) 125 | return accs 126 | 127 | yaml_path = opt.config 128 | ep = opt.epoch 129 | use_amp = opt.use_amp 130 | grid_search = opt.grid 131 | with open(yaml_path, 'r') as f: 132 | opt = yaml.full_load(f) 133 | print0(opt) 134 | opt = Config(opt) 135 | if ep == -1: 136 | ep = opt.n_epoch - 1 137 | model = get_model(opt, ep) 138 | 139 | epoch = opt.linear['n_epoch'] 140 | batch_size = opt.linear['batch_size'] 141 | base_lr = opt.linear['lrate'] 142 | 143 | if grid_search: 144 | time_list = [1, 11, 21] if opt.model_type == 'DDPM' else [3, 4, 5] 145 | name_list = None 146 | else: 147 | time_list = [opt.linear['timestep']] 148 | name_list = [opt.linear['blockname']] 149 | 150 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=True, crop=True) 151 | valid_set = get_dataset(name=opt.dataset, root="./data", train=False) 152 | train_loader, sampler = DataLoaderDDP( 153 | train_set, 154 
| batch_size=batch_size, 155 | shuffle=True, 156 | ) 157 | valid_loader, _ = DataLoaderDDP( 158 | valid_set, 159 | batch_size=batch_size, 160 | shuffle=False, 161 | ) 162 | 163 | feat_func = partial(model.get_feature, norm=False, use_amp=use_amp) 164 | DDP_multiplier = dist.get_world_size() 165 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier)) 166 | base_lr *= DDP_multiplier 167 | classifiers = ClassifierDict(feat_func, time_list, name_list, 168 | base_lr, epoch, opt.network['image_shape'], local_rank, opt.classes).to(model.device) 169 | 170 | for e in range(epoch): 171 | sampler.set_epoch(e) 172 | pbar = tqdm(train_loader, disable=(local_rank!=0)) 173 | for i, (image, label) in enumerate(pbar): 174 | pbar.set_description("[epoch %d / iter %d]: lr: %.1e" % (e, i, classifiers.get_lr())) 175 | classifiers.train(image.to(device), label.to(device)) 176 | classifiers.schedule_step() 177 | 178 | accs = test() 179 | for key in accs: 180 | print0("[key %s]: Test acc: %.2f" % (key, accs[key] * 100)) 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--config", type=str) 186 | parser.add_argument("--epoch", type=int, default=-1) 187 | parser.add_argument('--local_rank', default=-1, type=int, 188 | help='node rank for distributed training') 189 | parser.add_argument("--use_amp", action='store_true', default=False) 190 | parser.add_argument("--grid", action='store_true', default=False) 191 | opt = parser.parse_args() 192 | print0(opt) 193 | 194 | local_rank = opt.local_rank 195 | init_seeds(no=local_rank) 196 | dist.init_process_group(backend='nccl') 197 | torch.cuda.set_device(local_rank) 198 | device = "cuda:%d" % local_rank 199 | 200 | train(opt) 201 | -------------------------------------------------------------------------------- /DiT/linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | from functools import partial 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.cuda.amp import autocast as autocast 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | from tqdm import tqdm 13 | from torch.cuda.amp import autocast as autocast 14 | 15 | from diffusion import create_diffusion 16 | from download import find_model 17 | from models import DiT_XL_2 18 | import sys 19 | sys.path.append("..") 20 | from utils import init_seeds, gather_tensor, DataLoaderDDP, print0 21 | 22 | 23 | class LatentCodeDataset(Dataset): 24 | # warning: needs A LOT OF memory to load these datasets ! 
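    # Clarifying note (an inference from the loading code below, not a claim from the
    # original authors): each of the `num_copies` files holds one full VAE encoding of
    # the training split, and __getitem__ draws a copy at random per index, presumably
    # so that the VAE encoder's sampling noise (see vae_preprocessing.py) acts as a
    # light augmentation. The test split is loaded from a single copy only.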
25 | def __init__(self, dataset, train=True, num_copies=10): 26 | if train: 27 | code_path = [f"latent_codes/{dataset}/train_code_{i}.npy" for i in range(num_copies)] 28 | label_path = f"latent_codes/{dataset}/train_label.npy" 29 | else: 30 | code_path = [f"latent_codes/{dataset}/test_code_0.npy"] 31 | label_path = f"latent_codes/{dataset}/test_label.npy" 32 | 33 | self.code = [] 34 | for p in code_path: 35 | with open(p, 'rb') as f: 36 | data = np.load(f) 37 | self.code.append(data) 38 | with open(label_path, 'rb') as f: 39 | self.label = np.load(f) 40 | 41 | print0(f"Code shape: {len(self.code)} x {self.code[0].shape}") 42 | print0("Label shape:", self.label.shape) 43 | 44 | def __getitem__(self, index): 45 | replica = random.randrange(len(self.code)) 46 | code = self.code[replica][index] 47 | label = self.label[index] 48 | return code, label 49 | 50 | def __len__(self): 51 | return len(self.code[0]) 52 | 53 | 54 | def get_model(device): 55 | model = DiT_XL_2().to(device) 56 | state_dict = find_model(f"DiT-XL-2-256x256.pt") 57 | model.load_state_dict(state_dict) 58 | model.eval() 59 | diffusion = create_diffusion(None) # 1000-len betas 60 | return model, diffusion 61 | 62 | 63 | def denoise_feature(code, model, timestep, blockname, use_amp): 64 | ''' 65 | Args: 66 | `image`: Latent codes. (-1, 4, 32, 32) tensor. 67 | `timestep`: Time step to extract features. int. 68 | `blockname`: Block to extract features. str. 69 | Returns: 70 | Collected feature map. 71 | ''' 72 | x = code.to(device) 73 | t = torch.tensor([timestep]).to(device).repeat(x.shape[0]) 74 | noise = torch.randn_like(x) 75 | x_t = diffusion.q_sample(x, t, noise=noise) 76 | y_null = torch.tensor([1000] * x.shape[0], device=device) 77 | 78 | with torch.no_grad(): 79 | with autocast(enabled=use_amp): 80 | _, acts = model(x_t, t, y_null, ret_activation=True) 81 | feat = acts[blockname].float().detach() 82 | # (-1, 256, 1152) 83 | # we average pool across the sequence dimension to extract 84 | # a 1152-dimensional vector of features per example 85 | return feat.mean(dim=1) 86 | 87 | 88 | class Classifier(nn.Module): 89 | def __init__(self, feat_func, base_lr, epoch, num_classes): 90 | super(Classifier, self).__init__() 91 | self.feat_func = feat_func 92 | self.loss_fn = nn.CrossEntropyLoss() 93 | 94 | hidden_size = feat_func(next(iter(valid_loader))[0]).shape[-1] 95 | layers = nn.Sequential( 96 | # nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6), 97 | nn.Linear(hidden_size, num_classes), 98 | ) 99 | layers = torch.nn.parallel.DistributedDataParallel( 100 | layers.to(device), device_ids=[local_rank], output_device=local_rank) 101 | self.classifier = layers 102 | self.optim = torch.optim.Adam(self.classifier.parameters(), lr=base_lr) 103 | self.scheduler = CosineAnnealingLR(self.optim, epoch) 104 | 105 | def train(self, x, y): 106 | self.classifier.train() 107 | feat = self.feat_func(x) 108 | logit = self.classifier(feat) 109 | loss = self.loss_fn(logit, y) 110 | 111 | self.optim.zero_grad() 112 | loss.backward() 113 | self.optim.step() 114 | 115 | def test(self, x): 116 | with torch.no_grad(): 117 | self.classifier.eval() 118 | feat = self.feat_func(x) 119 | logit = self.classifier(feat) 120 | pred = logit.argmax(dim=-1) 121 | return pred 122 | 123 | def get_lr(self): 124 | return self.optim.param_groups[0]['lr'] 125 | 126 | def schedule_step(self): 127 | self.scheduler.step() 128 | 129 | 130 | def train(model, timestep, blockname, epoch, base_lr, use_amp): 131 | def test(): 132 | preds = [] 133 | labels = [] 134 | 
for image, label in tqdm(valid_loader, disable=(local_rank!=0)): 135 | pred = classifier.test(image.to(device)) 136 | preds.append(pred) 137 | labels.append(label.to(device)) 138 | 139 | pred = torch.cat(preds) 140 | label = torch.cat(labels) 141 | dist.barrier() 142 | pred = gather_tensor(pred) 143 | label = gather_tensor(label) 144 | acc = (pred == label).sum().item() / len(label) 145 | return acc 146 | 147 | print0(f"Feature extraction: time = {timestep}, name = {blockname}") 148 | feat_func = partial(denoise_feature, model=model, timestep=timestep, blockname=blockname, use_amp=use_amp) 149 | DDP_multiplier = dist.get_world_size() 150 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier)) 151 | base_lr *= DDP_multiplier 152 | num_classes = 10 if opt.dataset == 'cifar' else 200 153 | 154 | classifier = Classifier(feat_func, base_lr, epoch, num_classes).to(device) 155 | 156 | for e in range(epoch): 157 | sampler.set_epoch(e) 158 | pbar = tqdm(train_loader, disable=(local_rank!=0)) 159 | for i, (image, label) in enumerate(pbar): 160 | pbar.set_description("[epoch %d / iter %d]: lr: %.1e" % (e, i, classifier.get_lr())) 161 | classifier.train(image.to(device), label.to(device)) 162 | classifier.schedule_step() 163 | 164 | acc = test() 165 | print0("Test acc: %.2f" % (acc * 100)) 166 | 167 | 168 | def get_default_time(dataset, t): 169 | if t > 0: 170 | return t 171 | else: 172 | return {'cifar': 121, 'tiny': 81}[dataset] 173 | 174 | 175 | def get_default_name(dataset, b): 176 | if b != 'layer-0': 177 | return b 178 | else: 179 | return {'cifar': 'layer-13', 'tiny': 'layer-13'}[dataset] 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser() 184 | parser.add_argument("--dataset", default='cifar', type=str, choices=['cifar', 'tiny']) 185 | parser.add_argument('--local_rank', default=-1, type=int, 186 | help='node rank for distributed training') 187 | parser.add_argument("--use_amp", action='store_true', default=False) 188 | parser.add_argument('--batch_size', default=128, type=int) 189 | parser.add_argument('--lr', default=1e-3, type=float) 190 | parser.add_argument('--epoch', default=30, type=int) 191 | parser.add_argument('--time', type=int, default=0) 192 | parser.add_argument('--name', type=str, default='layer-0') 193 | opt = parser.parse_args() 194 | 195 | local_rank = opt.local_rank 196 | init_seeds(no=local_rank) 197 | dist.init_process_group(backend='nccl') 198 | torch.cuda.set_device(local_rank) 199 | device = "cuda:%d" % local_rank 200 | model, diffusion = get_model(device) 201 | 202 | train_set = LatentCodeDataset(opt.dataset, train=True) 203 | valid_set = LatentCodeDataset(opt.dataset, train=False) 204 | train_loader, sampler = DataLoaderDDP( 205 | train_set, 206 | batch_size=opt.batch_size, 207 | shuffle=True, 208 | ) 209 | valid_loader, _ = DataLoaderDDP( 210 | valid_set, 211 | batch_size=opt.batch_size, 212 | shuffle=False, 213 | ) 214 | 215 | # default timestep & blockname values 216 | opt.time = get_default_time(opt.dataset, opt.time) 217 | opt.name = get_default_name(opt.dataset, opt.name) 218 | 219 | print0(opt) 220 | train(model, timestep=opt.time, blockname=opt.name, epoch=opt.epoch, base_lr=opt.lr, use_amp=opt.use_amp) 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 🆕 **[2025] Please check out the more recent study [DDAE++](https://github.com/FutureXiang/ddae_plus_plus) continuing this line of work.** 
2 | 3 | 4 | # Denoising Diffusion Autoencoders (DDAE) 5 | 6 |

7 | 8 |

9 | 10 | This is a multi-gpu PyTorch implementation of the paper [Denoising Diffusion Autoencoders are Unified Self-supervised Learners](https://arxiv.org/abs/2303.09769): 11 | ```bibtex 12 | @inproceedings{ddae2023, 13 | title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners}, 14 | author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong}, 15 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 16 | year={2023} 17 | } 18 | ``` 19 | :star: (News) Our paper is cited by Kaiming He's new paper [Deconstructing Denoising Diffusion Models for Self-Supervised Learning](https://arxiv.org/abs/2401.14404), check it out! :fire: 20 | 21 | ## Overview 22 | 23 | This repo contains: 24 | - [x] Pre-training, sampling and FID evaluation code for diffusion models, including 25 | - Frameworks: 26 | - [x] DDPM & DDIM 27 | - [x] EDM (w/ or w/o data augmentation) 28 | - Networks: 29 | - [x] The basic 35.7M DDPM UNet 30 | - [x] A larger 56M DDPM++ UNet 31 | - Datasets: 32 | - [x] CIFAR-10 33 | - [ ] Tiny-ImageNet 34 | - [x] Feature quality evaluation code, including 35 | - [x] Linear probing and grid searching 36 | - [x] Contrastive metrics, i.e., alignment and uniformity 37 | - [ ] Fine-tuning 38 | - [x] Noise-conditional classifier training and evaluation, including 39 | - [x] MLP classifier based on DDPM/EDM features 40 | - [x] WideResNet with VP/VE perturbation 41 | - [x] Evaluation code for ImageNet-256 pre-trained [DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint 42 | 43 | ## Requirements 44 | - In addition to PyTorch environments, please install: 45 | ```sh 46 | conda install pyyaml 47 | pip install pytorch-fid ema-pytorch 48 | ``` 49 | - We use 4 or 8 3080ti GPUs to conduct all the experiments presented in the paper. With automatic mixed precision enabled and 4 GPUs, training a basic 35.7M UNet on CIFAR-10 takes ~14 hours. 50 | - The `pytorch-fid` requires image files to calculate the FID metric. Please refer to `extract_cifar10_pngs.ipynb` to unpack the CIFAR-10 training dataset into 50000 `.png` image files. 51 | 52 | ## Main results 53 | We present the generative and discriminative evaluation results that can be obtained by this codebase. The `EDM_ddpmpp_aug.yaml` training is performed on 8 GPUs, while other models are trained on 4 GPUs. 54 | 55 | Please note that this is a *over-simplified* DDPM / EDM implementation, and some network details, initialization, and hyper-parameters may *differ from* official ones. Please refer to their respective official codebases to reproduce the *exact results* reported in the paper. 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 |
| Config | Model | Network | Best linear probe checkpoint (epoch / FID / acc) | Best FID checkpoint (epoch / FID / acc) |
|---|---|---|---|---|
| `DDPM_ddpm.yaml` | DDPM | 35.7M UNet | 800 / 4.09 / 90.05 | 1999 / 3.62 / 88.23 |
| `EDM_ddpm.yaml` | EDM | 35.7M UNet | 1200 / 3.97 / 90.44 | 1999 / 3.56 / 89.71 |
| `DDPM_ddpmpp.yaml` | DDPM | 56.5M DDPM++ | 1200 / 3.08 / 93.97 | 1999 / 2.98 / 93.03 |
| `EDM_ddpmpp.yaml` | EDM | 56.5M DDPM++ | 1200 / 2.23 / 94.50 | (same) |
| `EDM_ddpmpp_aug.yaml` | EDM + data aug | 56.5M DDPM++ | 2000 / 2.34 / 95.49 | 3200 / 2.12 / 95.19 |
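
The linear probe accuracies above are obtained by training a single linear layer on globally average-pooled UNet activations at a fixed layer-noise combination (see `linear.py` and the `linear:` section of each config). Below is a minimal single-GPU sketch of the feature-extraction step, assuming a pre-trained checkpoint has already been restored as `model` (as in `linear.py`) and `images` is a batch of CIFAR-10 images in `[0, 1]`; the `timestep` / `blockname` values are illustrative:

```python
import torch
from functools import partial

# Illustrative layer-noise combination; the best values are found by grid search
# (`linear.py --grid`) and recorded under the `linear:` section of each config.
timestep, blockname = 11, 'out_6'

# `get_feature` perturbs the clean batch to the requested timestep and returns a
# {block name: (B, C)} dict of globally average-pooled UNet activations.
feat_func = partial(model.get_feature, norm=False, use_amp=False)
with torch.no_grad():
    feats = feat_func(images.to(model.device), t=timestep)
    z = feats[blockname]                    # (B, C) representation for the probe

# The probe itself is just one linear layer trained with cross-entropy on z.
probe = torch.nn.Linear(z.shape[1], 10).to(model.device)
logits = probe(z)
```

In `linear.py` this step is wrapped in a DDP training loop with an Adam optimizer and a cosine learning-rate schedule; with `--grid`, several timesteps are swept over all available block names to pick the combination reported above.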
132 | 133 | FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps). 134 | 135 | ## Latent-space DiT 136 | We evaluate pre-trained Transformer-based diffusion networks, [DiT](https://github.com/facebookresearch/DiT), from the perspective of *transfer learning*. Please refer to the [ddae/DiT](DiT/) subfolder. 137 | 138 | ## Usage 139 | ### Diffusion pre-training 140 | To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run: 141 | ```sh 142 | python -m torch.distributed.launch --nproc_per_node=4 143 | # diffusion pre-training with AMP enabled 144 | train.py --config config/DDPM_ddpm.yaml --use_amp 145 | 146 | # deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps) 147 | sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 148 | 149 | # stochastic sampling (i.e. DDPM 1000 steps) 150 | sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM 151 | ``` 152 | To calculate the FID metric on the training set, for example, run: 153 | ```sh 154 | python -m pytorch_fid data/cifar10-pngs/ output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/ 155 | ``` 156 | 157 | ### Features produced by DDAE 158 | To evaluate the features produced by pre-trained DDAE, for example, run: 159 | ```sh 160 | python -m torch.distributed.launch --nproc_per_node=4 161 | # grid searching for proper layer-noise combination 162 | linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid 163 | 164 | # linear probing, using the layer-noise combination specified by config.yaml 165 | linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 166 | 167 | # showing the alignment-uniformity metrics with respect to different checkpoints 168 | contrastive.py --config config/DDPM_ddpm.yaml --use_amp 169 | ``` 170 | 171 | ### Noise-conditional classifier 172 | To train WideResNet-based classifiers from scratch: 173 | ```sh 174 | python -m torch.distributed.launch --nproc_per_node=4 175 | # VP (DDPM) perturbation 176 | noisy_classifier_WRN.py --mode DDPM 177 | # VE (EDM) perturbation 178 | noisy_classifier_WRN.py --mode EDM 179 | ``` 180 | and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads: 181 | ```sh 182 | python -m torch.distributed.launch --nproc_per_node=4 183 | # using DDPM DDAE encoder 184 | noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999 185 | # using EDM DDAE encoder 186 | noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200 187 | ``` 188 | 189 | ## Acknowledgments 190 | This repository is built on numerous open-source codebases such as [DDPM](https://github.com/hojonathanho/diffusion), [DDPM-pytorch](https://github.com/pesser/pytorch_diffusion), [DDIM](https://github.com/ermongroup/ddim), [EDM](https://github.com/NVlabs/edm), [Score-based SDE](https://github.com/yang-song/score_sde), [DiT](https://github.com/facebookresearch/DiT), and [align_uniform](https://github.com/SsnL/align_uniform). 
191 | -------------------------------------------------------------------------------- /noisy_classifier_WRN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import torch.nn as nn 7 | from datasets import get_dataset 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | from tqdm import tqdm 10 | 11 | from model.wideresnet_noise_song import wide_28_10_song 12 | from utils import init_seeds, reduce_tensor, gather_tensor, DataLoaderDDP, print0 13 | 14 | 15 | def normalize_to_neg_one_to_one(img): 16 | # [0.0, 1.0] -> [-1.0, 1.0] 17 | return img * 2 - 1 18 | 19 | 20 | class DDPM: 21 | def __init__(self, device, betas=[1.0e-4, 0.02], n_T=1000, steps=20): 22 | self.device = device 23 | self.n_T = n_T 24 | self.ddpm_sche = self.schedules(betas, n_T, device, 'DDPM') 25 | self.test_timesteps = (torch.arange(0, self.n_T, self.n_T // steps) + 1).long().tolist() 26 | 27 | def train(self, x): 28 | x = normalize_to_neg_one_to_one(x).to(self.device) 29 | # Perturbation 30 | _ts = torch.randint(1, self.n_T + 1, (x.shape[0], )).to(self.device) 31 | noise = torch.randn_like(x) 32 | sche = self.ddpm_sche 33 | x_noised = (sche["sqrtab"][_ts, None, None, None] * x + 34 | sche["sqrtmab"][_ts, None, None, None] * noise) 35 | return x_noised, _ts / self.n_T 36 | 37 | def test(self, x, t): 38 | x = normalize_to_neg_one_to_one(x).to(self.device) 39 | # Perturbation 40 | _ts = torch.tensor([t]).to(self.device).repeat(x.shape[0]) 41 | noise = torch.randn_like(x) 42 | sche = self.ddpm_sche 43 | x_noised = (sche["sqrtab"][_ts, None, None, None] * x + 44 | sche["sqrtmab"][_ts, None, None, None] * noise) 45 | return x_noised, _ts / self.n_T 46 | 47 | def schedules(self, betas, T, device, type='DDPM'): 48 | def linear_beta_schedule(timesteps, beta1, beta2): 49 | assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)" 50 | return torch.linspace(beta1, beta2, timesteps) 51 | 52 | beta1, beta2 = betas 53 | schedule_fn = partial(linear_beta_schedule, beta1=beta1, beta2=beta2) 54 | 55 | if type == 'DDPM': 56 | beta_t = torch.cat([torch.tensor([0.0]), schedule_fn(T)]) 57 | elif type == 'DDIM': 58 | beta_t = schedule_fn(T + 1) 59 | else: 60 | raise NotImplementedError() 61 | sqrt_beta_t = torch.sqrt(beta_t) 62 | alpha_t = 1 - beta_t 63 | log_alpha_t = torch.log(alpha_t) 64 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 65 | 66 | sqrtab = torch.sqrt(alphabar_t) 67 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 68 | 69 | sqrtmab = torch.sqrt(1 - alphabar_t) 70 | ma_over_sqrtmab = (1 - alpha_t) / sqrtmab 71 | 72 | dic = { 73 | "alpha_t": alpha_t, 74 | "oneover_sqrta": oneover_sqrta, 75 | "sqrt_beta_t": sqrt_beta_t, 76 | "alphabar_t": alphabar_t, 77 | "sqrtab": sqrtab, 78 | "sqrtmab": sqrtmab, 79 | "ma_over_sqrtmab": ma_over_sqrtmab, 80 | } 81 | return {key: dic[key].to(device) for key in dic} 82 | 83 | 84 | class EDM: 85 | def __init__(self, device, p_std=1.2, p_mean=-1.2, sigma_min=0.002, sigma_max=80, rho=7, steps=18): 86 | self.device = device 87 | self.p_std = p_std 88 | self.p_mean = p_mean 89 | self.times = self.schedules(sigma_min, sigma_max, rho, steps) 90 | self.test_timesteps = range(1, steps + 1) 91 | 92 | def train(self, x): 93 | x = normalize_to_neg_one_to_one(x).to(self.device) 94 | # Perturbation 95 | rnd_normal = torch.randn((x.shape[0], 1, 1, 1)).to(self.device) 96 | sigma = (rnd_normal * self.p_std + self.p_mean).exp() 97 | noise = 
torch.randn_like(x) 98 | x_noised = x + noise * sigma 99 | 100 | sigma = sigma.reshape(x.shape[0],) 101 | return x_noised, sigma.log() 102 | 103 | def test(self, x, t): 104 | x = normalize_to_neg_one_to_one(x).to(self.device) 105 | # Perturbation 106 | noise = torch.randn_like(x) 107 | sigma = self.times[t] 108 | x_noised = x + noise * sigma 109 | 110 | sigma = torch.full((x.shape[0], ), sigma) 111 | return x_noised, sigma.log() 112 | 113 | def schedules(self, sigma_min, sigma_max, rho, steps): 114 | times = torch.arange(steps, dtype=torch.float64, device=self.device) 115 | times = (sigma_max ** (1 / rho) + times / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 116 | times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0 117 | times = reversed(times) 118 | return times 119 | 120 | 121 | def train(opt): 122 | def test(t): 123 | preds = [] 124 | labels = [] 125 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)): 126 | with torch.no_grad(): 127 | model.eval() 128 | logit = model(*diff.test(image, t)) 129 | pred = logit.argmax(dim=-1) 130 | preds.append(pred) 131 | labels.append(label.to(device)) 132 | 133 | pred = torch.cat(preds) 134 | label = torch.cat(labels) 135 | dist.barrier() 136 | pred = gather_tensor(pred) 137 | label = gather_tensor(label) 138 | acc = (pred == label).sum().item() / len(label) 139 | return acc 140 | 141 | warm_epoch = opt.warm_epoch 142 | epoch = opt.epoch 143 | batch_size = opt.batch_size 144 | base_lr = opt.lr 145 | mode = opt.mode 146 | 147 | if mode == 'DDPM': 148 | diff = DDPM(device) 149 | elif mode == 'EDM': 150 | diff = EDM(device) 151 | else: 152 | raise NotImplementedError 153 | 154 | train_set = get_dataset(name='cifar', root="./data", train=True, flip=True, crop=True) 155 | valid_set = get_dataset(name='cifar', root="./data", train=False) 156 | train_loader, sampler = DataLoaderDDP( 157 | train_set, 158 | batch_size=batch_size, 159 | shuffle=True, 160 | ) 161 | valid_loader, _ = DataLoaderDDP( 162 | valid_set, 163 | batch_size=batch_size, 164 | shuffle=False, 165 | ) 166 | 167 | model = wide_28_10_song(num_classes=10).to(device) 168 | model = torch.nn.parallel.DistributedDataParallel( 169 | model, device_ids=[local_rank], output_device=local_rank) 170 | loss_fn = nn.CrossEntropyLoss() 171 | optim = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=5e-4) 172 | scheduler = CosineAnnealingLR(optim, epoch) 173 | for e in range(epoch): 174 | sampler.set_epoch(e) 175 | if (e + 1) <= warm_epoch: 176 | for g in optim.param_groups: 177 | g['lr'] = base_lr * (e + 1.0) / warm_epoch # warmup 178 | 179 | pbar = tqdm(train_loader, disable=(local_rank!=0)) 180 | for i, (image, label) in enumerate(pbar): 181 | model.train() 182 | logit = model(*diff.train(image)) 183 | label = label.to(device) 184 | loss = loss_fn(logit, label) 185 | optim.zero_grad() 186 | loss.backward() 187 | optim.step() 188 | 189 | # logging 190 | dist.barrier() 191 | loss = reduce_tensor(loss) 192 | logit = gather_tensor(logit).cpu() 193 | label = gather_tensor(label).cpu() 194 | 195 | if local_rank == 0: 196 | pred = logit.argmax(dim=-1) 197 | acc = (pred == label).sum().item() / len(label) 198 | nowlr = optim.param_groups[0]['lr'] 199 | pbar.set_description("[epoch %d / iter %d]: lr %.1e loss: %.3f, acc: %.3f" % (e, i, nowlr, loss.item(), acc)) 200 | scheduler.step() 201 | 202 | accs = {} 203 | for t in diff.test_timesteps: 204 | test_acc = test(t) 205 | print0("[timestep %d]: Test acc: %.3f" % (t, test_acc)) 206 | 
accs[t] = test_acc 207 | 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument('--local_rank', default=-1, type=int, 212 | help='node rank for distributed training') 213 | parser.add_argument('--batch_size', default=128, type=int) 214 | parser.add_argument('--lr', default=0.1, type=float) 215 | parser.add_argument('--epoch', default=200, type=int) 216 | parser.add_argument('--warm_epoch', default=5, type=int) 217 | parser.add_argument("--mode", type=str, choices=['DDPM', 'EDM'], default='DDPM') 218 | opt = parser.parse_args() 219 | print0(opt) 220 | 221 | local_rank = opt.local_rank 222 | init_seeds(no=local_rank) 223 | dist.init_process_group(backend='nccl') 224 | torch.cuda.set_device(local_rank) 225 | device = "cuda:%d" % local_rank 226 | 227 | train(opt) 228 | -------------------------------------------------------------------------------- /model/DDPM.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from torch.cuda.amp import autocast as autocast 6 | 7 | 8 | def normalize_to_neg_one_to_one(img): 9 | # [0.0, 1.0] -> [-1.0, 1.0] 10 | return img * 2 - 1 11 | 12 | 13 | def unnormalize_to_zero_to_one(t): 14 | # [-1.0, 1.0] -> [0.0, 1.0] 15 | return (t + 1) * 0.5 16 | 17 | 18 | def linear_beta_schedule(timesteps, beta1, beta2): 19 | assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)" 20 | return torch.linspace(beta1, beta2, timesteps) 21 | 22 | 23 | def schedules(betas, T, device, type='DDPM'): 24 | beta1, beta2 = betas 25 | schedule_fn = partial(linear_beta_schedule, beta1=beta1, beta2=beta2) 26 | 27 | if type == 'DDPM': 28 | beta_t = torch.cat([torch.tensor([0.0]), schedule_fn(T)]) 29 | elif type == 'DDIM': 30 | beta_t = schedule_fn(T + 1) 31 | else: 32 | raise NotImplementedError() 33 | sqrt_beta_t = torch.sqrt(beta_t) 34 | alpha_t = 1 - beta_t 35 | log_alpha_t = torch.log(alpha_t) 36 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 37 | 38 | sqrtab = torch.sqrt(alphabar_t) 39 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 40 | 41 | sqrtmab = torch.sqrt(1 - alphabar_t) 42 | ma_over_sqrtmab = (1 - alpha_t) / sqrtmab 43 | 44 | dic = { 45 | "alpha_t": alpha_t, 46 | "oneover_sqrta": oneover_sqrta, 47 | "sqrt_beta_t": sqrt_beta_t, 48 | "alphabar_t": alphabar_t, 49 | "sqrtab": sqrtab, 50 | "sqrtmab": sqrtmab, 51 | "ma_over_sqrtmab": ma_over_sqrtmab, 52 | } 53 | return {key: dic[key].to(device) for key in dic} 54 | 55 | 56 | class DDPM(nn.Module): 57 | def __init__(self, nn_model, betas, n_T, device): 58 | ''' DDPM proposed by "Denoising Diffusion Probabilistic Models", and \ 59 | DDIM sampler proposed by "Denoising Diffusion Implicit Models". 60 | 61 | Args: 62 | nn_model: A network (e.g. UNet) which performs same-shape mapping. 63 | device: The CUDA device that tensors run on. 64 | Parameters: 65 | betas, n_T 66 | ''' 67 | super(DDPM, self).__init__() 68 | self.nn_model = nn_model.to(device) 69 | params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6 70 | print(f"nn model # params: {params:.1f}") 71 | 72 | self.device = device 73 | self.ddpm_sche = schedules(betas, n_T, device, 'DDPM') 74 | self.ddim_sche = schedules(betas, n_T, device, 'DDIM') 75 | self.n_T = n_T 76 | self.loss = nn.MSELoss() 77 | 78 | def perturb(self, x, t=None): 79 | ''' Add noise to a clean image (diffusion process). 80 | 81 | Args: 82 | x: The normalized image tensor. 
83 | t: The specified timestep ranged in `[1, n_T]`. Type: int / torch.LongTensor / None. \ 84 | Random `t ~ U[1, n_T]` is taken if t is None. 85 | Returns: 86 | The perturbed image, the corresponding timestep, and the noise. 87 | ''' 88 | if t is None: 89 | t = torch.randint(1, self.n_T + 1, (x.shape[0], )).to(self.device) 90 | elif not isinstance(t, torch.Tensor): 91 | t = torch.tensor([t]).to(self.device).repeat(x.shape[0]) 92 | 93 | noise = torch.randn_like(x) 94 | sche = self.ddpm_sche 95 | x_noised = (sche["sqrtab"][t, None, None, None] * x + 96 | sche["sqrtmab"][t, None, None, None] * noise) 97 | return x_noised, t, noise 98 | 99 | def forward(self, x, use_amp=False): 100 | ''' Training with simple noise prediction loss. 101 | 102 | Args: 103 | x: The clean image tensor ranged in `[0, 1]`. 104 | Returns: 105 | The simple MSE loss. 106 | ''' 107 | x = normalize_to_neg_one_to_one(x) 108 | x_noised, t, noise = self.perturb(x, t=None) 109 | 110 | with autocast(enabled=use_amp): 111 | return self.loss(noise, self.nn_model(x_noised, t / self.n_T)) 112 | 113 | def get_feature(self, x, t, name=None, norm=False, use_amp=False): 114 | ''' Get network's intermediate activation in a forward pass. 115 | 116 | Args: 117 | x: The clean image tensor ranged in `[0, 1]`. 118 | t: The specified timestep ranged in `[1, n_T]`. Type: int / torch.LongTensor. 119 | norm: to normalize features to the the unit hypersphere. 120 | Returns: 121 | A {name: tensor} dict which contains global average pooled features. 122 | ''' 123 | x = normalize_to_neg_one_to_one(x) 124 | x_noised, t, noise = self.perturb(x, t) 125 | 126 | def gap_and_norm(act, norm=False): 127 | if len(act.shape) == 4: 128 | # unet (B, C, H, W) 129 | act = act.view(act.shape[0], act.shape[1], -1).float() 130 | act = torch.mean(act, dim=2) 131 | else: 132 | raise NotImplementedError 133 | if norm: 134 | act = torch.nn.functional.normalize(act) 135 | return act 136 | 137 | with autocast(enabled=use_amp): 138 | _, acts = self.nn_model(x_noised, t / self.n_T, ret_activation=True) 139 | all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts} 140 | if name is not None: 141 | return all_feats[name] 142 | else: 143 | return all_feats 144 | 145 | def sample(self, n_sample, size, notqdm=False, use_amp=False): 146 | ''' Sampling with DDPM sampler. Actual NFE is `n_T`. 147 | 148 | Args: 149 | n_sample: The batch size. 150 | size: The image shape (e.g. `(3, 32, 32)`). 151 | Returns: 152 | The sampled image tensor ranged in `[0, 1]`. 153 | ''' 154 | sche = self.ddpm_sche 155 | x_i = torch.randn(n_sample, *size).to(self.device) 156 | 157 | for i in tqdm(range(self.n_T, 0, -1), disable=notqdm): 158 | t_is = torch.tensor([i / self.n_T]).to(self.device).repeat(n_sample) 159 | 160 | z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0 161 | 162 | alpha = sche["alphabar_t"][i] 163 | eps, _ = self.pred_eps_(x_i, t_is, alpha, use_amp) 164 | 165 | mean = sche["oneover_sqrta"][i] * (x_i - sche["ma_over_sqrtmab"][i] * eps) 166 | variance = sche["sqrt_beta_t"][i] # LET variance sigma_t = sqrt_beta_t 167 | x_i = mean + variance * z 168 | 169 | return unnormalize_to_zero_to_one(x_i) 170 | 171 | def ddim_sample(self, n_sample, size, steps=100, eta=0.0, notqdm=False, use_amp=False): 172 | ''' Sampling with DDIM sampler. Actual NFE is `steps`. 173 | 174 | Args: 175 | n_sample: The batch size. 176 | size: The image shape (e.g. `(3, 32, 32)`). 177 | steps: The number of total timesteps. 178 | eta: controls stochasticity. 
Set `eta=0` for deterministic sampling. 179 | Returns: 180 | The sampled image tensor ranged in `[0, 1]`. 181 | ''' 182 | sche = self.ddim_sche 183 | x_i = torch.randn(n_sample, *size).to(self.device) 184 | 185 | times = torch.arange(0, self.n_T, self.n_T // steps) + 1 186 | times = list(reversed(times.int().tolist())) + [0] 187 | time_pairs = list(zip(times[:-1], times[1:])) 188 | # e.g. [(801, 601), (601, 401), (401, 201), (201, 1), (1, 0)] 189 | 190 | for time, time_next in tqdm(time_pairs, disable=notqdm): 191 | t_is = torch.tensor([time / self.n_T]).to(self.device).repeat(n_sample) 192 | 193 | z = torch.randn(n_sample, *size).to(self.device) if time_next > 0 else 0 194 | 195 | alpha = sche["alphabar_t"][time] 196 | eps, x0_t = self.pred_eps_(x_i, t_is, alpha, use_amp) 197 | alpha_next = sche["alphabar_t"][time_next] 198 | c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 199 | c2 = (1 - alpha_next - c1 ** 2).sqrt() 200 | x_i = alpha_next.sqrt() * x0_t + c2 * eps + c1 * z 201 | 202 | return unnormalize_to_zero_to_one(x_i) 203 | 204 | def pred_eps_(self, x, t, alpha, use_amp, clip_x=True): 205 | def pred_eps_from_x0(x0): 206 | return (x - x0 * alpha.sqrt()) / (1 - alpha).sqrt() 207 | 208 | def pred_x0_from_eps(eps): 209 | return (x - (1 - alpha).sqrt() * eps) / alpha.sqrt() 210 | 211 | # get prediction of x0 212 | with autocast(enabled=use_amp): 213 | eps = self.nn_model(x, t).float() 214 | denoised = pred_x0_from_eps(eps) 215 | 216 | # pixel-space clipping (optional) 217 | if clip_x: 218 | denoised = torch.clip(denoised, -1., 1.) 219 | eps = pred_eps_from_x0(denoised) 220 | return eps, denoised 221 | -------------------------------------------------------------------------------- /model/EDM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from torch.cuda.amp import autocast as autocast 6 | from .augment import AugmentPipe 7 | 8 | 9 | def normalize_to_neg_one_to_one(img): 10 | # [0.0, 1.0] -> [-1.0, 1.0] 11 | return img * 2 - 1 12 | 13 | 14 | def unnormalize_to_zero_to_one(t): 15 | # [-1.0, 1.0] -> [0.0, 1.0] 16 | return (t + 1) * 0.5 17 | 18 | 19 | class EDM(nn.Module): 20 | def __init__(self, nn_model, 21 | sigma_data, p_mean, p_std, 22 | sigma_min, sigma_max, rho, 23 | S_min, S_max, S_noise, 24 | device, 25 | augment_prob=0): 26 | ''' EDM proposed by "Elucidating the Design Space of Diffusion-Based Generative Models". 27 | 28 | Args: 29 | nn_model: A network (e.g. UNet) which performs same-shape mapping. 30 | device: The CUDA device that tensors run on. 
31 | Training parameters: 32 | sigma_data, p_mean, p_std 33 | augment_prob 34 | Sampling parameters: 35 | sigma_min, sigma_max, rho 36 | S_min, S_max, S_noise 37 | ''' 38 | super(EDM, self).__init__() 39 | self.nn_model = nn_model.to(device) 40 | params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6 41 | print(f"nn model # params: {params:.1f}") 42 | 43 | self.device = device 44 | 45 | def number_to_torch_device(value): 46 | return torch.tensor(value).to(device) 47 | 48 | self.sigma_data = number_to_torch_device(sigma_data) 49 | self.p_mean = number_to_torch_device(p_mean) 50 | self.p_std = number_to_torch_device(p_std) 51 | self.sigma_min = number_to_torch_device(sigma_min) 52 | self.sigma_max = number_to_torch_device(sigma_max) 53 | self.rho = number_to_torch_device(rho) 54 | self.S_min = number_to_torch_device(S_min) 55 | self.S_max = number_to_torch_device(S_max) 56 | self.S_noise = number_to_torch_device(S_noise) 57 | if augment_prob > 0: 58 | self.augpipe = AugmentPipe(p=augment_prob, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) 59 | else: 60 | self.augpipe = None 61 | 62 | def perturb(self, x, t=None, steps=None): 63 | ''' Add noise to a clean image (diffusion process). 64 | 65 | Args: 66 | x: The normalized image tensor. 67 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor / None. \ 68 | Random `ln(sigma) ~ N(P_mean, P_std)` is taken if t is None. 69 | Returns: 70 | The perturbed image, and the corresponding sigma. 71 | ''' 72 | if t is None: 73 | rnd_normal = torch.randn((x.shape[0], 1, 1, 1)).to(self.device) 74 | sigma = (rnd_normal * self.p_std + self.p_mean).exp() 75 | else: 76 | times = reversed(self.sample_schedule(steps)) 77 | sigma = times[t] 78 | if len(sigma.shape) == 1: 79 | sigma = sigma[:, None, None, None] 80 | 81 | noise = torch.randn_like(x) 82 | x_noised = x + noise * sigma 83 | return x_noised, sigma 84 | 85 | def forward(self, x, use_amp=False): 86 | ''' Training with weighted denoising loss. 87 | 88 | Args: 89 | x: The clean image tensor ranged in `[0, 1]`. 90 | Returns: 91 | The weighted MSE loss. 92 | ''' 93 | x = normalize_to_neg_one_to_one(x) 94 | x, aug_label = self.augpipe(x) if self.augpipe is not None else (x, None) 95 | x_noised, sigma = self.perturb(x, t=None) 96 | 97 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 98 | loss_4shape = weight * ((x - self.D_x(x_noised, sigma, use_amp, aug_label)) ** 2) 99 | return loss_4shape.mean() 100 | 101 | def get_feature(self, x, t, steps=18, name=None, norm=False, use_amp=False): 102 | ''' Get network's intermediate activation in a forward pass. 103 | 104 | Args: 105 | x: The clean image tensor ranged in `[0, 1]`. 106 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor. 107 | norm: to normalize features to the the unit hypersphere. 108 | Returns: 109 | A {name: tensor} dict which contains global average pooled features. 
110 | ''' 111 | x = normalize_to_neg_one_to_one(x) 112 | x_noised, sigma = self.perturb(x, t, steps) 113 | 114 | def gap_and_norm(act, norm=False): 115 | if len(act.shape) == 4: 116 | # unet (B, C, H, W) 117 | act = act.view(act.shape[0], act.shape[1], -1).float() 118 | act = torch.mean(act, dim=2) 119 | else: 120 | raise NotImplementedError 121 | if norm: 122 | act = torch.nn.functional.normalize(act) 123 | return act 124 | 125 | _, acts = self.D_x(x_noised, sigma, use_amp, ret_activation=True) 126 | all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts} 127 | if name is not None: 128 | return all_feats[name] 129 | else: 130 | return all_feats 131 | 132 | def edm_sample(self, n_sample, size, steps=18, eta=0.0, notqdm=False, use_amp=False): 133 | ''' Sampling with EDM sampler. Actual NFE is `2 * steps - 1`. 134 | 135 | Args: 136 | n_sample: The batch size. 137 | size: The image shape (e.g. `(3, 32, 32)`). 138 | steps: The number of total timesteps. 139 | eta: controls stochasticity. Set `eta=0` for deterministic sampling. 140 | Returns: 141 | The sampled image tensor ranged in `[0, 1]`. 142 | ''' 143 | S_min, S_max, S_noise = self.S_min, self.S_max, self.S_noise 144 | gamma_stochasticity = torch.tensor(np.sqrt(2) - 1) * eta # S_churn = (sqrt(2) - 1) * eta * steps 145 | 146 | times = self.sample_schedule(steps) 147 | time_pairs = list(zip(times[:-1], times[1:])) 148 | 149 | x_next = torch.randn(n_sample, *size).to(self.device).to(torch.float64) * times[0] 150 | for i, (t_cur, t_next) in enumerate(tqdm(time_pairs, disable=notqdm)): # 0, ..., N-1 151 | x_cur = x_next 152 | 153 | # Increase noise temporarily. 154 | gamma = gamma_stochasticity if S_min <= t_cur <= S_max else 0 155 | t_hat = t_cur + gamma * t_cur 156 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 157 | 158 | # Euler step. 159 | d_cur = self.pred_eps_(x_hat, t_hat, use_amp) 160 | x_next = x_hat + (t_next - t_hat) * d_cur 161 | 162 | # Apply 2nd order correction. 163 | if i < steps - 1: 164 | d_prime = self.pred_eps_(x_next, t_next, use_amp) 165 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 166 | 167 | return unnormalize_to_zero_to_one(x_next) 168 | 169 | def pred_eps_(self, x, t, use_amp, clip_x=True): 170 | denoised = self.D_x(x, t, use_amp).to(torch.float64) 171 | # pixel-space clipping (optional) 172 | if clip_x: 173 | denoised = torch.clip(denoised, -1., 1.) 174 | eps = (x - denoised) / t 175 | return eps 176 | 177 | def D_x(self, x_noised, sigma, use_amp, aug_label=None, ret_activation=False): 178 | ''' Denoising with network preconditioning. 179 | 180 | Args: 181 | x_noised: The perturbed image tensor. 182 | sigma: The variance (noise level) tensor. 183 | aug_label: The augmentation labels produced by AugmentPipe. 184 | Returns: 185 | The estimated denoised image tensor. 186 | The {name: (B, C, H, W) tensor} activation dict (if ret_activation is True). 
187 | ''' 188 | x_noised = x_noised.to(torch.float32) 189 | sigma = sigma.to(torch.float32) 190 | 191 | # Preconditioning 192 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 193 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 194 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2).sqrt() 195 | c_noise = sigma.log() / 4 196 | 197 | # Denoising 198 | with autocast(enabled=use_amp): 199 | F_x = self.nn_model(c_in * x_noised, c_noise.flatten(), aug_label, ret_activation) 200 | 201 | if ret_activation: 202 | return c_skip * x_noised + c_out * F_x[0], F_x[1] 203 | else: 204 | return c_skip * x_noised + c_out * F_x 205 | 206 | def sample_schedule(self, steps): 207 | ''' Make the variance schedule for EDM sampling. 208 | 209 | Args: 210 | steps: The number of total timesteps. Typically 18, 50 or 100. 211 | Returns: 212 | times: A decreasing tensor list such that 213 | `times[0] == sigma_max`, 214 | `times[steps-1] == sigma_min`, and 215 | `times[steps] == 0`. 216 | ''' 217 | sigma_min, sigma_max, rho = self.sigma_min, self.sigma_max, self.rho 218 | times = torch.arange(steps, dtype=torch.float64, device=self.device) 219 | times = (sigma_max ** (1 / rho) + times / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 220 | times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0 221 | return times 222 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .block import GroupNorm32, TimeEmbedding, AttentionBlock, Upsample, Downsample 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, time_channels, dropout=0.1, up=False, down=False): 8 | """ 9 | * `in_channels` is the number of input channels 10 | * `out_channels` is the number of output channels 11 | * `time_channels` is the number channels in the time step ($t$) embeddings 12 | * `dropout` is the dropout rate 13 | """ 14 | super().__init__() 15 | self.norm1 = GroupNorm32(in_channels) 16 | self.act1 = nn.SiLU() 17 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 18 | 19 | self.norm2 = GroupNorm32(out_channels) 20 | self.act2 = nn.SiLU() 21 | self.conv2 = nn.Sequential( 22 | nn.Dropout(dropout), 23 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 24 | ) 25 | 26 | if in_channels != out_channels: 27 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) 28 | else: 29 | self.shortcut = nn.Identity() 30 | 31 | # Linear layer for embeddings 32 | self.time_emb = nn.Sequential( 33 | nn.SiLU(), 34 | nn.Linear(time_channels, out_channels) 35 | ) 36 | 37 | # BigGAN style: use resblock for up/downsampling 38 | self.updown = up or down 39 | if up: 40 | self.h_upd = Upsample(in_channels, use_conv=False) 41 | self.x_upd = Upsample(in_channels, use_conv=False) 42 | elif down: 43 | self.h_upd = Downsample(in_channels, use_conv=False) 44 | self.x_upd = Downsample(in_channels, use_conv=False) 45 | else: 46 | self.h_upd = self.x_upd = nn.Identity() 47 | 48 | def forward(self, x, t): 49 | """ 50 | * `x` has shape `[batch_size, in_channels, height, width]` 51 | * `t` has shape `[batch_size, time_channels]` 52 | """ 53 | if self.updown: 54 | h = self.conv1(self.h_upd(self.act1(self.norm1(x)))) 55 | x = self.x_upd(x) 56 | else: 57 | h = self.conv1(self.act1(self.norm1(x))) 58 | 59 | # Adaptive Group Normalization 60 | t_ 
= self.time_emb(t)[:, :, None, None] 61 | h = h + t_ 62 | 63 | h = self.conv2(self.act2(self.norm2(h))) 64 | return h + self.shortcut(x) 65 | 66 | 67 | class ResAttBlock(nn.Module): 68 | def __init__(self, in_channels, out_channels, time_channels, has_attn, attn_channels_per_head, dropout): 69 | super().__init__() 70 | self.res = ResidualBlock(in_channels, out_channels, time_channels, dropout=dropout) 71 | if has_attn: 72 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 73 | else: 74 | self.attn = nn.Identity() 75 | 76 | def forward(self, x, t): 77 | x = self.res(x, t) 78 | x = self.attn(x) 79 | return x 80 | 81 | 82 | class MiddleBlock(nn.Module): 83 | def __init__(self, n_channels, time_channels, attn_channels_per_head, dropout): 84 | super().__init__() 85 | self.res1 = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout) 86 | self.attn = AttentionBlock(n_channels, attn_channels_per_head) 87 | self.res2 = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout) 88 | 89 | def forward(self, x, t): 90 | x = self.res1(x, t) 91 | x = self.attn(x) 92 | x = self.res2(x, t) 93 | return x 94 | 95 | 96 | class UpsampleRes(nn.Module): 97 | def __init__(self, n_channels, time_channels, dropout): 98 | super().__init__() 99 | self.op = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout, up=True) 100 | 101 | def forward(self, x, t): 102 | return self.op(x, t) 103 | 104 | 105 | class DownsampleRes(nn.Module): 106 | def __init__(self, n_channels, time_channels, dropout): 107 | super().__init__() 108 | self.op = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout, down=True) 109 | 110 | def forward(self, x, t): 111 | return self.op(x, t) 112 | 113 | 114 | class UNet(nn.Module): 115 | def __init__(self, image_shape = [3, 32, 32], n_channels = 128, 116 | ch_mults = (1, 2, 2, 2), 117 | is_attn = (False, True, False, False), 118 | attn_channels_per_head = None, 119 | dropout = 0.1, 120 | n_blocks = 2, 121 | use_res_for_updown = False, 122 | augment_dim = 0): 123 | """ 124 | * `image_shape` is the (channel, height, width) size of images. 125 | * `n_channels` is number of channels in the initial feature map that we transform the image into 126 | * `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `n_channels * ch_mults[i]` 127 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution 128 | * `dropout` is the dropout rate 129 | * `n_blocks` is the number of `UpDownBlocks` at each resolution 130 | * `use_res_for_updown` indicates whether to use ResBlocks for up/down sampling (BigGAN-style) 131 | * `augment_dim` indicates augmentation label dimensionality, 0 = no augmentation 132 | """ 133 | super().__init__() 134 | 135 | n_resolutions = len(ch_mults) 136 | 137 | self.image_proj = nn.Conv2d(image_shape[0], n_channels, kernel_size=3, padding=1) 138 | 139 | # Embedding layers (time & augment) 140 | time_channels = n_channels * 4 141 | self.time_emb = TimeEmbedding(time_channels, augment_dim) 142 | 143 | # Down stages 144 | down = [] 145 | in_channels = n_channels 146 | h_channels = [n_channels] 147 | for i in range(n_resolutions): 148 | # Number of output channels at this resolution 149 | out_channels = n_channels * ch_mults[i] 150 | # `n_blocks` at the same resolution 151 | down.append(ResAttBlock(in_channels, out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout)) 152 | h_channels.append(out_channels) 153 | for _ in range(n_blocks - 1): 154 | down.append(ResAttBlock(out_channels, out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout)) 155 | h_channels.append(out_channels) 156 | # Down sample at all resolutions except the last 157 | if i < n_resolutions - 1: 158 | if use_res_for_updown: 159 | down.append(DownsampleRes(out_channels, time_channels, dropout)) 160 | else: 161 | down.append(Downsample(out_channels)) 162 | h_channels.append(out_channels) 163 | in_channels = out_channels 164 | self.down = nn.ModuleList(down) 165 | 166 | # Middle block 167 | self.middle = MiddleBlock(out_channels, time_channels, attn_channels_per_head, dropout) 168 | 169 | # Up stages 170 | up = [] 171 | in_channels = out_channels 172 | for i in reversed(range(n_resolutions)): 173 | # Number of output channels at this resolution 174 | out_channels = n_channels * ch_mults[i] 175 | # `n_blocks + 1` at the same resolution 176 | for _ in range(n_blocks + 1): 177 | up.append(ResAttBlock(in_channels + h_channels.pop(), out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout)) 178 | in_channels = out_channels 179 | # Up sample at all resolutions except last 180 | if i > 0: 181 | if use_res_for_updown: 182 | up.append(UpsampleRes(out_channels, time_channels, dropout)) 183 | else: 184 | up.append(Upsample(out_channels)) 185 | assert not h_channels 186 | self.up = nn.ModuleList(up) 187 | 188 | # Final normalization and convolution layer 189 | self.norm = nn.GroupNorm(8, out_channels) 190 | self.act = nn.SiLU() 191 | self.final = nn.Conv2d(out_channels, image_shape[0], kernel_size=3, padding=1) 192 | 193 | def forward(self, x, t, aug=None, ret_activation=False): 194 | if not ret_activation: 195 | return self.forward_core(x, t, aug) 196 | 197 | activation = {} 198 | def namedHook(name): 199 | def hook(module, input, output): 200 | activation[name] = output 201 | return hook 202 | hooks = {} 203 | no = 0 204 | for blk in self.up: 205 | if isinstance(blk, ResAttBlock): 206 | no += 1 207 | name = f'out_{no}' 208 | hooks[name] = blk.register_forward_hook(namedHook(name)) 209 | 210 | result = self.forward_core(x, t, aug) 211 | for name in hooks: 212 | hooks[name].remove() 213 | return result, activation 214 | 215 | def forward_core(self, x, t, aug): 216 | """ 217 | * `x` has shape `[batch_size, in_channels, 
height, width]` 218 | * `t` has shape `[batch_size]` 219 | """ 220 | 221 | t = self.time_emb(t, aug) 222 | x = self.image_proj(x) 223 | 224 | # `h` will store outputs at each resolution for skip connection 225 | h = [x] 226 | 227 | for m in self.down: 228 | if isinstance(m, Downsample): 229 | x = m(x) 230 | elif isinstance(m, DownsampleRes): 231 | x = m(x, t) 232 | else: 233 | x = m(x, t).contiguous() 234 | h.append(x) 235 | 236 | x = self.middle(x, t).contiguous() 237 | 238 | for m in self.up: 239 | if isinstance(m, Upsample): 240 | x = m(x) 241 | elif isinstance(m, UpsampleRes): 242 | x = m(x, t) 243 | else: 244 | # Get the skip connection from first half of U-Net and concatenate 245 | s = h.pop() 246 | x = torch.cat((x, s), dim=1) 247 | x = m(x, t).contiguous() 248 | 249 | return self.final(self.act(self.norm(x))) 250 | 251 | 252 | ''' 253 | from model.unet import UNet 254 | net = UNet() 255 | import torch 256 | x = torch.zeros(1, 3, 32, 32) 257 | t = torch.zeros(1,) 258 | 259 | net(x, t).shape 260 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 261 | 262 | >>> 35.746307 M parameters for CIFAR-10 model (original DDPM) 263 | ''' 264 | -------------------------------------------------------------------------------- /DiT/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 
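For example, with dim = 4 and max_period = 10000 the frequencies are
[1, 1/100], so a timestep t is embedded as
[cos(t), cos(t/100), sin(t), sin(t/100)].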
49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 
128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 148 | """ 149 | def __init__( 150 | self, 151 | input_size=32, 152 | patch_size=2, 153 | in_channels=4, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=1000, 160 | learn_sigma=True, 161 | ): 162 | super().__init__() 163 | self.learn_sigma = learn_sigma 164 | self.in_channels = in_channels 165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 166 | self.patch_size = patch_size 167 | self.num_heads = num_heads 168 | 169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 170 | self.t_embedder = TimestepEmbedder(hidden_size) 171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 172 | num_patches = self.x_embedder.num_patches 173 | # Will use fixed sin-cos embedding: 174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 175 | 176 | self.blocks = nn.ModuleList([ 177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 178 | ]) 179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 180 | self.initialize_weights() 181 | 182 | def initialize_weights(self): 183 | # Initialize transformer layers: 184 | def _basic_init(module): 185 | if isinstance(module, nn.Linear): 186 | torch.nn.init.xavier_uniform_(module.weight) 187 | if module.bias is not None: 188 | nn.init.constant_(module.bias, 0) 189 | self.apply(_basic_init) 190 | 191 | # Initialize (and freeze) pos_embed by sin-cos embedding: 192 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 193 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 194 | 195 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 196 | w = self.x_embedder.proj.weight.data 197 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 198 | nn.init.constant_(self.x_embedder.proj.bias, 0) 199 | 200 | # Initialize label embedding table: 201 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 202 | 203 | # Initialize timestep embedding MLP: 204 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 205 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 206 | 207 | # Zero-out adaLN modulation layers in DiT blocks: 208 | for block in self.blocks: 209 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 210 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 211 | 212 | # Zero-out output layers: 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 215 | nn.init.constant_(self.final_layer.linear.weight, 0) 216 | nn.init.constant_(self.final_layer.linear.bias, 0) 217 | 218 | def unpatchify(self, x): 219 | """ 220 | x: (N, 
T, patch_size**2 * C) 221 | imgs: (N, H, W, C) 222 | """ 223 | c = self.out_channels 224 | p = self.x_embedder.patch_size[0] 225 | h = w = int(x.shape[1] ** 0.5) 226 | assert h * w == x.shape[1] 227 | 228 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 229 | x = torch.einsum('nhwpqc->nchpwq', x) 230 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 231 | return imgs 232 | 233 | def forward(self, x, t, y, ret_activation=False): 234 | if not ret_activation: 235 | return self.forward_core(x, t, y) 236 | 237 | activation = {} 238 | def namedHook(name): 239 | def hook(module, input, output): 240 | activation[name] = output 241 | return hook 242 | hooks = {} 243 | for idx, block in enumerate(self.blocks): 244 | name = f"layer-{idx}" 245 | hooks[name] = block.register_forward_hook(namedHook(name)) 246 | 247 | result = self.forward_core(x, t, y) 248 | for name in hooks: 249 | hooks[name].remove() 250 | return result, activation 251 | 252 | def forward_core(self, x, t, y): 253 | """ 254 | Forward pass of DiT. 255 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 256 | t: (N,) tensor of diffusion timesteps 257 | y: (N,) tensor of class labels 258 | """ 259 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 260 | t = self.t_embedder(t) # (N, D) 261 | y = self.y_embedder(y, self.training) # (N, D) 262 | c = t + y # (N, D) 263 | for block in self.blocks: 264 | x = block(x, c) # (N, T, D) 265 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 266 | x = self.unpatchify(x) # (N, out_channels, H, W) 267 | return x 268 | 269 | def forward_with_cfg(self, x, t, y, cfg_scale): 270 | """ 271 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 272 | """ 273 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 274 | half = x[: len(x) // 2] 275 | combined = torch.cat([half, half], dim=0) 276 | model_out = self.forward(combined, t, y) 277 | # For exact reproducibility reasons, we apply classifier-free guidance on only 278 | # three channels by default. The standard approach to cfg applies it to all channels. 279 | # This can be done by uncommenting the following line and commenting-out the line following that. 
280 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 281 | eps, rest = model_out[:, :3], model_out[:, 3:] 282 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 283 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 284 | eps = torch.cat([half_eps, half_eps], dim=0) 285 | return torch.cat([eps, rest], dim=1) 286 | 287 | 288 | ################################################################################# 289 | # Sine/Cosine Positional Embedding Functions # 290 | ################################################################################# 291 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 292 | 293 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 294 | """ 295 | grid_size: int of the grid height and width 296 | return: 297 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 298 | """ 299 | grid_h = np.arange(grid_size, dtype=np.float32) 300 | grid_w = np.arange(grid_size, dtype=np.float32) 301 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 302 | grid = np.stack(grid, axis=0) 303 | 304 | grid = grid.reshape([2, 1, grid_size, grid_size]) 305 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 306 | if cls_token and extra_tokens > 0: 307 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 308 | return pos_embed 309 | 310 | 311 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 312 | assert embed_dim % 2 == 0 313 | 314 | # use half of dimensions to encode grid_h 315 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 316 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 317 | 318 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 319 | return emb 320 | 321 | 322 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 323 | """ 324 | embed_dim: output dimension for each position 325 | pos: a list of positions to be encoded: size (M,) 326 | out: (M, D) 327 | """ 328 | assert embed_dim % 2 == 0 329 | omega = np.arange(embed_dim // 2, dtype=np.float64) 330 | omega /= embed_dim / 2. 331 | omega = 1. 
/ 10000**omega # (D/2,) 332 | 333 | pos = pos.reshape(-1) # (M,) 334 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 335 | 336 | emb_sin = np.sin(out) # (M, D/2) 337 | emb_cos = np.cos(out) # (M, D/2) 338 | 339 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 340 | return emb 341 | 342 | 343 | ################################################################################# 344 | # DiT Configs # 345 | ################################################################################# 346 | 347 | def DiT_XL_2(**kwargs): 348 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 349 | 350 | def DiT_XL_4(**kwargs): 351 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 352 | 353 | def DiT_XL_8(**kwargs): 354 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 355 | 356 | def DiT_L_2(**kwargs): 357 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 358 | 359 | def DiT_L_4(**kwargs): 360 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 361 | 362 | def DiT_L_8(**kwargs): 363 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 364 | 365 | def DiT_B_2(**kwargs): 366 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 367 | 368 | def DiT_B_4(**kwargs): 369 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 370 | 371 | def DiT_B_8(**kwargs): 372 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 373 | 374 | def DiT_S_2(**kwargs): 375 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 376 | 377 | def DiT_S_4(**kwargs): 378 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 379 | 380 | def DiT_S_8(**kwargs): 381 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 382 | 383 | 384 | DiT_models = { 385 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 386 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 387 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 388 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 389 | } 390 | -------------------------------------------------------------------------------- /model/augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Augmentation pipeline used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models". 10 | Built around the same concepts that were originally proposed in the paper 11 | "Training Generative Adversarial Networks with Limited Data".""" 12 | 13 | import numpy as np 14 | import torch 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 
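# Illustrative note: a call such as constant([0, 0, 0, 1], device=device) builds the
# tensor once; subsequent calls with identical value/shape/dtype/device/memory_format
# arguments return the cached copy instead of re-uploading it.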
19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Coefficients of various wavelet decomposition low-pass filters. 45 | 46 | wavelets = { 47 | 'haar': [0.7071067811865476, 0.7071067811865476], 48 | 'db1': [0.7071067811865476, 0.7071067811865476], 49 | 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 50 | 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 51 | 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], 52 | 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], 53 | 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], 54 | 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], 55 | 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], 56 | 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 57 | 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 58 | 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], 59 | 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], 60 | 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 
0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], 61 | 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], 62 | 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], 63 | } 64 | 65 | #---------------------------------------------------------------------------- 66 | # Helpers for constructing transformation matrices. 67 | 68 | def matrix(*rows, device=None): 69 | assert all(len(row) == len(rows[0]) for row in rows) 70 | elems = [x for row in rows for x in row] 71 | ref = [x for x in elems if isinstance(x, torch.Tensor)] 72 | if len(ref) == 0: 73 | return constant(np.asarray(rows), device=device) 74 | assert device is None or device == ref[0].device 75 | elems = [x if isinstance(x, torch.Tensor) else constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] 76 | return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) 77 | 78 | def translate2d(tx, ty, **kwargs): 79 | return matrix( 80 | [1, 0, tx], 81 | [0, 1, ty], 82 | [0, 0, 1], 83 | **kwargs) 84 | 85 | def translate3d(tx, ty, tz, **kwargs): 86 | return matrix( 87 | [1, 0, 0, tx], 88 | [0, 1, 0, ty], 89 | [0, 0, 1, tz], 90 | [0, 0, 0, 1], 91 | **kwargs) 92 | 93 | def scale2d(sx, sy, **kwargs): 94 | return matrix( 95 | [sx, 0, 0], 96 | [0, sy, 0], 97 | [0, 0, 1], 98 | **kwargs) 99 | 100 | def scale3d(sx, sy, sz, **kwargs): 101 | return matrix( 102 | [sx, 0, 0, 0], 103 | [0, sy, 0, 0], 104 | [0, 0, sz, 0], 105 | [0, 0, 0, 1], 106 | **kwargs) 107 | 108 | def rotate2d(theta, **kwargs): 109 | return matrix( 110 | [torch.cos(theta), torch.sin(-theta), 0], 111 | [torch.sin(theta), torch.cos(theta), 0], 112 | [0, 0, 1], 113 | **kwargs) 114 | 115 | def rotate3d(v, theta, **kwargs): 116 | vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] 117 | s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c 118 | return matrix( 119 | [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], 120 | [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], 121 | [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], 122 | [0, 0, 0, 1], 123 | **kwargs) 124 | 125 | def translate2d_inv(tx, ty, **kwargs): 126 | return translate2d(-tx, -ty, **kwargs) 127 | 128 | def scale2d_inv(sx, sy, **kwargs): 129 | return scale2d(1 / sx, 1 / sy, **kwargs) 130 | 131 | def rotate2d_inv(theta, **kwargs): 132 | return rotate2d(-theta, **kwargs) 133 | 134 | #---------------------------------------------------------------------------- 135 | # Augmentation pipeline main class. 136 | # All augmentations are disabled by default; individual augmentations can 137 | # be enabled by setting their probability multipliers to 1. 
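# Illustrative usage sketch (parameter values here are hypothetical, not taken from
# this repo's configs): enable a subset of augmentations and feed the returned labels
# to a network built with a matching augment_dim, e.g.
#
#   pipe = AugmentPipe(p=0.12, xflip=1, yflip=1, scale=1, rotate_frac=1,
#                      aniso=1, translate_frac=1)
#   images_aug, aug_labels = pipe(images)      # aug_labels: [N, augment label dims]
#   denoised = model(images_aug, sigma, aug_labels)   # hypothetical conditional call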
138 | 139 | class AugmentPipe: 140 | def __init__(self, p=1, 141 | xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125, 142 | scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125, 143 | brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, 144 | ): 145 | super().__init__() 146 | self.p = float(p) # Overall multiplier for augmentation probability. 147 | 148 | # Pixel blitting. 149 | self.xflip = float(xflip) # Probability multiplier for x-flip. 150 | self.yflip = float(yflip) # Probability multiplier for y-flip. 151 | self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation. 152 | self.translate_int = float(translate_int) # Probability multiplier for integer translation. 153 | self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions. 154 | 155 | # Geometric transformations. 156 | self.scale = float(scale) # Probability multiplier for isotropic scaling. 157 | self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation. 158 | self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. 159 | self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation. 160 | self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. 161 | self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle. 162 | self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. 163 | self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame. 164 | self.translate_frac_std = float(translate_frac_std) # Standard deviation of frational translation, relative to image dimensions. 165 | 166 | # Color transformations. 167 | self.brightness = float(brightness) # Probability multiplier for brightness. 168 | self.contrast = float(contrast) # Probability multiplier for contrast. 169 | self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. 170 | self.hue = float(hue) # Probability multiplier for hue rotation. 171 | self.saturation = float(saturation) # Probability multiplier for saturation. 172 | self.brightness_std = float(brightness_std) # Standard deviation of brightness. 173 | self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. 174 | self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. 175 | self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. 176 | 177 | def __call__(self, images): 178 | N, C, H, W = images.shape 179 | device = images.device 180 | labels = [torch.zeros([images.shape[0], 0], device=device)] 181 | 182 | # --------------- 183 | # Pixel blitting. 
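# The blit ops below are exact pixel rearrangements: x/y flips, 90-degree rotations,
# and integer translations with reflection at the borders. Each op appends its sampled
# parameters to `labels` so a conditional model can be told which augmentation was applied.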
184 | # --------------- 185 | 186 | if self.xflip > 0: 187 | w = torch.randint(2, [N, 1, 1, 1], device=device) 188 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w)) 189 | images = torch.where(w == 1, images.flip(3), images) 190 | labels += [w] 191 | 192 | if self.yflip > 0: 193 | w = torch.randint(2, [N, 1, 1, 1], device=device) 194 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w)) 195 | images = torch.where(w == 1, images.flip(2), images) 196 | labels += [w] 197 | 198 | if self.rotate_int > 0: 199 | w = torch.randint(4, [N, 1, 1, 1], device=device) 200 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w)) 201 | images = torch.where((w == 1) | (w == 2), images.flip(3), images) 202 | images = torch.where((w == 2) | (w == 3), images.flip(2), images) 203 | images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images) 204 | labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)] 205 | 206 | if self.translate_int > 0: 207 | w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1 208 | w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w)) 209 | tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64) 210 | ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64) 211 | b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij') 212 | x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs() 213 | y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs() 214 | images = images.flatten()[(((b * C) + c) * H + y) * W + x] 215 | labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)] 216 | 217 | # ------------------------------------------------ 218 | # Select parameters for geometric transformations. 219 | # ------------------------------------------------ 220 | 221 | I_3 = torch.eye(3, device=device) 222 | G_inv = I_3 223 | 224 | if self.scale > 0: 225 | w = torch.randn([N], device=device) 226 | w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w)) 227 | s = w.mul(self.scale_std).exp2() 228 | G_inv = G_inv @ scale2d_inv(s, s) 229 | labels += [w] 230 | 231 | if self.rotate_frac > 0: 232 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max) 233 | w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w)) 234 | G_inv = G_inv @ rotate2d_inv(-w) 235 | labels += [w.cos() - 1, w.sin()] 236 | 237 | if self.aniso > 0: 238 | w = torch.randn([N], device=device) 239 | r = (torch.rand([N], device=device) * 2 - 1) * np.pi 240 | w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w)) 241 | r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r)) 242 | s = w.mul(self.aniso_std).exp2() 243 | G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r) 244 | labels += [w * r.cos(), w * r.sin()] 245 | 246 | if self.translate_frac > 0: 247 | w = torch.randn([2, N], device=device) 248 | w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w)) 249 | G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std)) 250 | labels += [w[0], w[1]] 251 | 252 | # ---------------------------------- 253 | # Execute geometric transformations. 
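# The affine warp below is applied at 2x resolution to reduce aliasing: the image is
# reflect-padded, upsampled with the 'sym6' wavelet low-pass filter, warped with
# grid_sample using the accumulated inverse transform G_inv, then low-pass filtered,
# downsampled, and cropped back to the original H x W.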
254 | # ---------------------------------- 255 | 256 | if G_inv is not I_3: 257 | cx = (W - 1) / 2 258 | cy = (H - 1) / 2 259 | cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] 260 | cp = G_inv @ cp.t() # [batch, xyz, idx] 261 | Hz = np.asarray(wavelets['sym6'], dtype=np.float32) 262 | Hz_pad = len(Hz) // 4 263 | margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] 264 | margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] 265 | margin = margin + constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) 266 | margin = margin.max(constant([0, 0] * 2, device=device)) 267 | margin = margin.min(constant([W - 1, H - 1] * 2, device=device)) 268 | mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) 269 | 270 | # Pad image and adjust origin. 271 | images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') 272 | G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv 273 | 274 | # Upsample. 275 | conv_weight = constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 276 | conv_pad = (len(Hz) + 1) // 2 277 | images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1] 278 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad]) 279 | images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :] 280 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0]) 281 | G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) 282 | G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) 283 | 284 | # Execute transformation. 285 | shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2] 286 | G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) 287 | grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) 288 | images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False) 289 | 290 | # Downsample and crop. 291 | conv_weight = constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 292 | conv_pad = (len(Hz) - 1) // 2 293 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad] 294 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :] 295 | 296 | # -------------------------------------------- 297 | # Select parameters for color transformations. 
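# Color transforms are composed into a single 4x4 homogeneous matrix M acting on
# (R, G, B, 1): brightness is a translation, contrast a uniform scale, luma flip a
# Householder reflection along the luma axis, hue a rotation around that axis, and
# saturation a scale of the components orthogonal to it.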
298 | # -------------------------------------------- 299 | 300 | I_4 = torch.eye(4, device=device) 301 | M = I_4 302 | luma_axis = constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) 303 | 304 | if self.brightness > 0: 305 | w = torch.randn([N], device=device) 306 | w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w)) 307 | b = w * self.brightness_std 308 | M = translate3d(b, b, b) @ M 309 | labels += [w] 310 | 311 | if self.contrast > 0: 312 | w = torch.randn([N], device=device) 313 | w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w)) 314 | c = w.mul(self.contrast_std).exp2() 315 | M = scale3d(c, c, c) @ M 316 | labels += [w] 317 | 318 | if self.lumaflip > 0: 319 | w = torch.randint(2, [N, 1, 1], device=device) 320 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w)) 321 | M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M 322 | labels += [w] 323 | 324 | if self.hue > 0: 325 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max) 326 | w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w)) 327 | M = rotate3d(luma_axis, w) @ M 328 | labels += [w.cos() - 1, w.sin()] 329 | 330 | if self.saturation > 0: 331 | w = torch.randn([N, 1, 1], device=device) 332 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w)) 333 | M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M 334 | labels += [w] 335 | 336 | # ------------------------------ 337 | # Execute color transformations. 338 | # ------------------------------ 339 | 340 | if M is not I_4: 341 | images = images.reshape([N, C, H * W]) 342 | if C == 3: 343 | images = M[:, :3, :3] @ images + M[:, :3, 3:] 344 | elif C == 1: 345 | M = M[:, :3, :].mean(dim=1, keepdims=True) 346 | images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:] 347 | else: 348 | raise ValueError('Image must be RGB (3 channels) or L (1 channel)') 349 | images = images.reshape([N, C, H, W]) 350 | 351 | labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1) 352 | return images, labels 353 | 354 | #---------------------------------------------------------------------------- 355 | -------------------------------------------------------------------------------- /DiT/diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch as th 11 | import enum 12 | 13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 14 | 15 | 16 | def mean_flat(tensor): 17 | """ 18 | Take the mean over all non-batch dimensions. 19 | """ 20 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 21 | 22 | 23 | class ModelMeanType(enum.Enum): 24 | """ 25 | Which type of output the model predicts. 
26 | """ 27 | 28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 29 | START_X = enum.auto() # the model predicts x_0 30 | EPSILON = enum.auto() # the model predicts epsilon 31 | 32 | 33 | class ModelVarType(enum.Enum): 34 | """ 35 | What is used as the model's output variance. 36 | The LEARNED_RANGE option has been added to allow the model to predict 37 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 38 | """ 39 | 40 | LEARNED = enum.auto() 41 | FIXED_SMALL = enum.auto() 42 | FIXED_LARGE = enum.auto() 43 | LEARNED_RANGE = enum.auto() 44 | 45 | 46 | class LossType(enum.Enum): 47 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 48 | RESCALED_MSE = ( 49 | enum.auto() 50 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 51 | KL = enum.auto() # use the variational lower-bound 52 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 53 | 54 | def is_vb(self): 55 | return self == LossType.KL or self == LossType.RESCALED_KL 56 | 57 | 58 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 59 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 60 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 61 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 62 | return betas 63 | 64 | 65 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 66 | """ 67 | This is the deprecated API for creating beta schedules. 68 | See get_named_beta_schedule() for the new library of schedules. 69 | """ 70 | if beta_schedule == "quad": 71 | betas = ( 72 | np.linspace( 73 | beta_start ** 0.5, 74 | beta_end ** 0.5, 75 | num_diffusion_timesteps, 76 | dtype=np.float64, 77 | ) 78 | ** 2 79 | ) 80 | elif beta_schedule == "linear": 81 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 82 | elif beta_schedule == "warmup10": 83 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 84 | elif beta_schedule == "warmup50": 85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 86 | elif beta_schedule == "const": 87 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 88 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 89 | betas = 1.0 / np.linspace( 90 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 91 | ) 92 | else: 93 | raise NotImplementedError(beta_schedule) 94 | assert betas.shape == (num_diffusion_timesteps,) 95 | return betas 96 | 97 | 98 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 99 | """ 100 | Get a pre-defined beta schedule for the given name. 101 | The beta schedule library consists of beta schedules which remain similar 102 | in the limit of num_diffusion_timesteps. 103 | Beta schedules may be added, but should not be removed or changed once 104 | they are committed to maintain backwards compatibility. 105 | """ 106 | if schedule_name == "linear": 107 | # Linear schedule from Ho et al, extended to work for any number of 108 | # diffusion steps. 
109 | scale = 1000 / num_diffusion_timesteps 110 | return get_beta_schedule( 111 | "linear", 112 | beta_start=scale * 0.0001, 113 | beta_end=scale * 0.02, 114 | num_diffusion_timesteps=num_diffusion_timesteps, 115 | ) 116 | elif schedule_name == "squaredcos_cap_v2": 117 | return betas_for_alpha_bar( 118 | num_diffusion_timesteps, 119 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 120 | ) 121 | else: 122 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 123 | 124 | 125 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 126 | """ 127 | Create a beta schedule that discretizes the given alpha_t_bar function, 128 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 129 | :param num_diffusion_timesteps: the number of betas to produce. 130 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 131 | produces the cumulative product of (1-beta) up to that 132 | part of the diffusion process. 133 | :param max_beta: the maximum beta to use; use values lower than 1 to 134 | prevent singularities. 135 | """ 136 | betas = [] 137 | for i in range(num_diffusion_timesteps): 138 | t1 = i / num_diffusion_timesteps 139 | t2 = (i + 1) / num_diffusion_timesteps 140 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 141 | return np.array(betas) 142 | 143 | 144 | class GaussianDiffusion: 145 | """ 146 | Utilities for training and sampling diffusion models. 147 | Original ported from this codebase: 148 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 149 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 150 | starting at T and going to 1. 151 | """ 152 | 153 | def __init__( 154 | self, 155 | *, 156 | betas, 157 | model_mean_type, 158 | model_var_type, 159 | loss_type 160 | ): 161 | 162 | self.model_mean_type = model_mean_type 163 | self.model_var_type = model_var_type 164 | self.loss_type = loss_type 165 | 166 | # Use float64 for accuracy. 
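# The buffers derived below follow the standard DDPM notation:
#   alpha_t = 1 - beta_t,  alpha_bar_t = prod_{s<=t} alpha_s,
#   q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I),
# and the posterior q(x_{t-1} | x_t, x_0) has variance
#   beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).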
167 | betas = np.array(betas, dtype=np.float64) 168 | self.betas = betas 169 | assert len(betas.shape) == 1, "betas must be 1-D" 170 | assert (betas > 0).all() and (betas <= 1).all() 171 | 172 | self.num_timesteps = int(betas.shape[0]) 173 | 174 | alphas = 1.0 - betas 175 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 176 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 177 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 178 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 179 | 180 | # calculations for diffusion q(x_t | x_{t-1}) and others 181 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 182 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 183 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 184 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 185 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 186 | 187 | # calculations for posterior q(x_{t-1} | x_t, x_0) 188 | self.posterior_variance = ( 189 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 190 | ) 191 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 192 | self.posterior_log_variance_clipped = np.log( 193 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 194 | ) if len(self.posterior_variance) > 1 else np.array([]) 195 | 196 | self.posterior_mean_coef1 = ( 197 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 198 | ) 199 | self.posterior_mean_coef2 = ( 200 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 201 | ) 202 | 203 | def q_mean_variance(self, x_start, t): 204 | """ 205 | Get the distribution q(x_t | x_0). 206 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 207 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 208 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 209 | """ 210 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 211 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 212 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 213 | return mean, variance, log_variance 214 | 215 | def q_sample(self, x_start, t, noise=None): 216 | """ 217 | Diffuse the data for a given number of diffusion steps. 218 | In other words, sample from q(x_t | x_0). 219 | :param x_start: the initial data batch. 220 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 221 | :param noise: if specified, the split-out normal noise. 222 | :return: A noisy version of x_start. 
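        In closed form (matching the code below):
            x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise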
223 | """ 224 | if noise is None: 225 | noise = th.randn_like(x_start) 226 | assert noise.shape == x_start.shape 227 | return ( 228 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 229 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 230 | ) 231 | 232 | def q_posterior_mean_variance(self, x_start, x_t, t): 233 | """ 234 | Compute the mean and variance of the diffusion posterior: 235 | q(x_{t-1} | x_t, x_0) 236 | """ 237 | assert x_start.shape == x_t.shape 238 | posterior_mean = ( 239 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 240 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 241 | ) 242 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 243 | posterior_log_variance_clipped = _extract_into_tensor( 244 | self.posterior_log_variance_clipped, t, x_t.shape 245 | ) 246 | assert ( 247 | posterior_mean.shape[0] 248 | == posterior_variance.shape[0] 249 | == posterior_log_variance_clipped.shape[0] 250 | == x_start.shape[0] 251 | ) 252 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 253 | 254 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 255 | """ 256 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 257 | the initial x, x_0. 258 | :param model: the model, which takes a signal and a batch of timesteps 259 | as input. 260 | :param x: the [N x C x ...] tensor at time t. 261 | :param t: a 1-D Tensor of timesteps. 262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 263 | :param denoised_fn: if not None, a function which applies to the 264 | x_start prediction before it is used to sample. Applies before 265 | clip_denoised. 266 | :param model_kwargs: if not None, a dict of extra keyword arguments to 267 | pass to the model. This can be used for conditioning. 268 | :return: a dict with the following keys: 269 | - 'mean': the model mean output. 270 | - 'variance': the model variance output. 271 | - 'log_variance': the log of 'variance'. 272 | - 'pred_xstart': the prediction for x_0. 273 | """ 274 | if model_kwargs is None: 275 | model_kwargs = {} 276 | 277 | B, C = x.shape[:2] 278 | assert t.shape == (B,) 279 | model_output = model(x, t, **model_kwargs) 280 | if isinstance(model_output, tuple): 281 | model_output, extra = model_output 282 | else: 283 | extra = None 284 | 285 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 286 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 287 | model_output, model_var_values = th.split(model_output, C, dim=1) 288 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 289 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 290 | # The model_var_values is [-1, 1] for [min_var, max_var]. 291 | frac = (model_var_values + 1) / 2 292 | model_log_variance = frac * max_log + (1 - frac) * min_log 293 | model_variance = th.exp(model_log_variance) 294 | else: 295 | model_variance, model_log_variance = { 296 | # for fixedlarge, we set the initial (log-)variance like so 297 | # to get a better decoder log likelihood. 
298 | ModelVarType.FIXED_LARGE: ( 299 | np.append(self.posterior_variance[1], self.betas[1:]), 300 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 301 | ), 302 | ModelVarType.FIXED_SMALL: ( 303 | self.posterior_variance, 304 | self.posterior_log_variance_clipped, 305 | ), 306 | }[self.model_var_type] 307 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 308 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 309 | 310 | def process_xstart(x): 311 | if denoised_fn is not None: 312 | x = denoised_fn(x) 313 | if clip_denoised: 314 | return x.clamp(-1, 1) 315 | return x 316 | 317 | if self.model_mean_type == ModelMeanType.START_X: 318 | pred_xstart = process_xstart(model_output) 319 | else: 320 | pred_xstart = process_xstart( 321 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 322 | ) 323 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 324 | 325 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 326 | return { 327 | "mean": model_mean, 328 | "variance": model_variance, 329 | "log_variance": model_log_variance, 330 | "pred_xstart": pred_xstart, 331 | "extra": extra, 332 | } 333 | 334 | def _predict_xstart_from_eps(self, x_t, t, eps): 335 | assert x_t.shape == eps.shape 336 | return ( 337 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 338 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 339 | ) 340 | 341 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 342 | return ( 343 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 344 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 345 | 346 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 347 | """ 348 | Compute the mean for the previous step, given a function cond_fn that 349 | computes the gradient of a conditional log probability with respect to 350 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 351 | condition on y. 352 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 353 | """ 354 | gradient = cond_fn(x, t, **model_kwargs) 355 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 356 | return new_mean 357 | 358 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 359 | """ 360 | Compute what the p_mean_variance output would have been, should the 361 | model's score function be conditioned by cond_fn. 362 | See condition_mean() for details on cond_fn. 363 | Unlike condition_mean(), this instead uses the conditioning strategy 364 | from Song et al (2020). 365 | """ 366 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 367 | 368 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 369 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 370 | 371 | out = p_mean_var.copy() 372 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 373 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 374 | return out 375 | 376 | def p_sample( 377 | self, 378 | model, 379 | x, 380 | t, 381 | clip_denoised=True, 382 | denoised_fn=None, 383 | cond_fn=None, 384 | model_kwargs=None, 385 | ): 386 | """ 387 | Sample x_{t-1} from the model at the given timestep. 388 | :param model: the model to sample from. 389 | :param x: the current tensor at x_{t-1}. 
390 | :param t: the value of t, starting at 0 for the first diffusion step. 391 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 392 | :param denoised_fn: if not None, a function which applies to the 393 | x_start prediction before it is used to sample. 394 | :param cond_fn: if not None, this is a gradient function that acts 395 | similarly to the model. 396 | :param model_kwargs: if not None, a dict of extra keyword arguments to 397 | pass to the model. This can be used for conditioning. 398 | :return: a dict containing the following keys: 399 | - 'sample': a random sample from the model. 400 | - 'pred_xstart': a prediction of x_0. 401 | """ 402 | out = self.p_mean_variance( 403 | model, 404 | x, 405 | t, 406 | clip_denoised=clip_denoised, 407 | denoised_fn=denoised_fn, 408 | model_kwargs=model_kwargs, 409 | ) 410 | noise = th.randn_like(x) 411 | nonzero_mask = ( 412 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 413 | ) # no noise when t == 0 414 | if cond_fn is not None: 415 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 416 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 417 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 418 | 419 | def p_sample_loop( 420 | self, 421 | model, 422 | shape, 423 | noise=None, 424 | clip_denoised=True, 425 | denoised_fn=None, 426 | cond_fn=None, 427 | model_kwargs=None, 428 | device=None, 429 | progress=False, 430 | ): 431 | """ 432 | Generate samples from the model. 433 | :param model: the model module. 434 | :param shape: the shape of the samples, (N, C, H, W). 435 | :param noise: if specified, the noise from the encoder to sample. 436 | Should be of the same shape as `shape`. 437 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 438 | :param denoised_fn: if not None, a function which applies to the 439 | x_start prediction before it is used to sample. 440 | :param cond_fn: if not None, this is a gradient function that acts 441 | similarly to the model. 442 | :param model_kwargs: if not None, a dict of extra keyword arguments to 443 | pass to the model. This can be used for conditioning. 444 | :param device: if specified, the device to create the samples on. 445 | If not specified, use a model parameter's device. 446 | :param progress: if True, show a tqdm progress bar. 447 | :return: a non-differentiable batch of samples. 448 | """ 449 | final = None 450 | for sample in self.p_sample_loop_progressive( 451 | model, 452 | shape, 453 | noise=noise, 454 | clip_denoised=clip_denoised, 455 | denoised_fn=denoised_fn, 456 | cond_fn=cond_fn, 457 | model_kwargs=model_kwargs, 458 | device=device, 459 | progress=progress, 460 | ): 461 | final = sample 462 | return final["sample"] 463 | 464 | def p_sample_loop_progressive( 465 | self, 466 | model, 467 | shape, 468 | noise=None, 469 | clip_denoised=True, 470 | denoised_fn=None, 471 | cond_fn=None, 472 | model_kwargs=None, 473 | device=None, 474 | progress=False, 475 | ): 476 | """ 477 | Generate samples from the model and yield intermediate samples from 478 | each timestep of diffusion. 479 | Arguments are the same as p_sample_loop(). 480 | Returns a generator over dicts, where each dict is the return value of 481 | p_sample(). 
482 | """ 483 | if device is None: 484 | device = next(model.parameters()).device 485 | assert isinstance(shape, (tuple, list)) 486 | if noise is not None: 487 | img = noise 488 | else: 489 | img = th.randn(*shape, device=device) 490 | indices = list(range(self.num_timesteps))[::-1] 491 | 492 | if progress: 493 | # Lazy import so that we don't depend on tqdm. 494 | from tqdm.auto import tqdm 495 | 496 | indices = tqdm(indices) 497 | 498 | for i in indices: 499 | t = th.tensor([i] * shape[0], device=device) 500 | with th.no_grad(): 501 | out = self.p_sample( 502 | model, 503 | img, 504 | t, 505 | clip_denoised=clip_denoised, 506 | denoised_fn=denoised_fn, 507 | cond_fn=cond_fn, 508 | model_kwargs=model_kwargs, 509 | ) 510 | yield out 511 | img = out["sample"] 512 | 513 | def ddim_sample( 514 | self, 515 | model, 516 | x, 517 | t, 518 | clip_denoised=True, 519 | denoised_fn=None, 520 | cond_fn=None, 521 | model_kwargs=None, 522 | eta=0.0, 523 | ): 524 | """ 525 | Sample x_{t-1} from the model using DDIM. 526 | Same usage as p_sample(). 527 | """ 528 | out = self.p_mean_variance( 529 | model, 530 | x, 531 | t, 532 | clip_denoised=clip_denoised, 533 | denoised_fn=denoised_fn, 534 | model_kwargs=model_kwargs, 535 | ) 536 | if cond_fn is not None: 537 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 538 | 539 | # Usually our model outputs epsilon, but we re-derive it 540 | # in case we used x_start or x_prev prediction. 541 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 542 | 543 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 544 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 545 | sigma = ( 546 | eta 547 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 548 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 549 | ) 550 | # Equation 12. 551 | noise = th.randn_like(x) 552 | mean_pred = ( 553 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 554 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 555 | ) 556 | nonzero_mask = ( 557 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 558 | ) # no noise when t == 0 559 | sample = mean_pred + nonzero_mask * sigma * noise 560 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 561 | 562 | def ddim_reverse_sample( 563 | self, 564 | model, 565 | x, 566 | t, 567 | clip_denoised=True, 568 | denoised_fn=None, 569 | cond_fn=None, 570 | model_kwargs=None, 571 | eta=0.0, 572 | ): 573 | """ 574 | Sample x_{t+1} from the model using DDIM reverse ODE. 575 | """ 576 | assert eta == 0.0, "Reverse ODE only for deterministic path" 577 | out = self.p_mean_variance( 578 | model, 579 | x, 580 | t, 581 | clip_denoised=clip_denoised, 582 | denoised_fn=denoised_fn, 583 | model_kwargs=model_kwargs, 584 | ) 585 | if cond_fn is not None: 586 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 587 | # Usually our model outputs epsilon, but we re-derive it 588 | # in case we used x_start or x_prev prediction. 589 | eps = ( 590 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 591 | - out["pred_xstart"] 592 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 593 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 594 | 595 | # Equation 12. 
reversed 596 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 597 | 598 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 599 | 600 | def ddim_sample_loop( 601 | self, 602 | model, 603 | shape, 604 | noise=None, 605 | clip_denoised=True, 606 | denoised_fn=None, 607 | cond_fn=None, 608 | model_kwargs=None, 609 | device=None, 610 | progress=False, 611 | eta=0.0, 612 | ): 613 | """ 614 | Generate samples from the model using DDIM. 615 | Same usage as p_sample_loop(). 616 | """ 617 | final = None 618 | for sample in self.ddim_sample_loop_progressive( 619 | model, 620 | shape, 621 | noise=noise, 622 | clip_denoised=clip_denoised, 623 | denoised_fn=denoised_fn, 624 | cond_fn=cond_fn, 625 | model_kwargs=model_kwargs, 626 | device=device, 627 | progress=progress, 628 | eta=eta, 629 | ): 630 | final = sample 631 | return final["sample"] 632 | 633 | def ddim_sample_loop_progressive( 634 | self, 635 | model, 636 | shape, 637 | noise=None, 638 | clip_denoised=True, 639 | denoised_fn=None, 640 | cond_fn=None, 641 | model_kwargs=None, 642 | device=None, 643 | progress=False, 644 | eta=0.0, 645 | ): 646 | """ 647 | Use DDIM to sample from the model and yield intermediate samples from 648 | each timestep of DDIM. 649 | Same usage as p_sample_loop_progressive(). 650 | """ 651 | if device is None: 652 | device = next(model.parameters()).device 653 | assert isinstance(shape, (tuple, list)) 654 | if noise is not None: 655 | img = noise 656 | else: 657 | img = th.randn(*shape, device=device) 658 | indices = list(range(self.num_timesteps))[::-1] 659 | 660 | if progress: 661 | # Lazy import so that we don't depend on tqdm. 662 | from tqdm.auto import tqdm 663 | 664 | indices = tqdm(indices) 665 | 666 | for i in indices: 667 | t = th.tensor([i] * shape[0], device=device) 668 | with th.no_grad(): 669 | out = self.ddim_sample( 670 | model, 671 | img, 672 | t, 673 | clip_denoised=clip_denoised, 674 | denoised_fn=denoised_fn, 675 | cond_fn=cond_fn, 676 | model_kwargs=model_kwargs, 677 | eta=eta, 678 | ) 679 | yield out 680 | img = out["sample"] 681 | 682 | def _vb_terms_bpd( 683 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 684 | ): 685 | """ 686 | Get a term for the variational lower-bound. 687 | The resulting units are bits (rather than nats, as one might expect). 688 | This allows for comparison to other papers. 689 | :return: a dict with the following keys: 690 | - 'output': a shape [N] tensor of NLLs or KLs. 691 | - 'pred_xstart': the x_0 predictions. 
692 | """ 693 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 694 | x_start=x_start, x_t=x_t, t=t 695 | ) 696 | out = self.p_mean_variance( 697 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 698 | ) 699 | kl = normal_kl( 700 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 701 | ) 702 | kl = mean_flat(kl) / np.log(2.0) 703 | 704 | decoder_nll = -discretized_gaussian_log_likelihood( 705 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 706 | ) 707 | assert decoder_nll.shape == x_start.shape 708 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 709 | 710 | # At the first timestep return the decoder NLL, 711 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 712 | output = th.where((t == 0), decoder_nll, kl) 713 | return {"output": output, "pred_xstart": out["pred_xstart"]} 714 | 715 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 716 | """ 717 | Compute training losses for a single timestep. 718 | :param model: the model to evaluate loss on. 719 | :param x_start: the [N x C x ...] tensor of inputs. 720 | :param t: a batch of timestep indices. 721 | :param model_kwargs: if not None, a dict of extra keyword arguments to 722 | pass to the model. This can be used for conditioning. 723 | :param noise: if specified, the specific Gaussian noise to try to remove. 724 | :return: a dict with the key "loss" containing a tensor of shape [N]. 725 | Some mean or variance settings may also have other keys. 726 | """ 727 | if model_kwargs is None: 728 | model_kwargs = {} 729 | if noise is None: 730 | noise = th.randn_like(x_start) 731 | x_t = self.q_sample(x_start, t, noise=noise) 732 | 733 | terms = {} 734 | 735 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 736 | terms["loss"] = self._vb_terms_bpd( 737 | model=model, 738 | x_start=x_start, 739 | x_t=x_t, 740 | t=t, 741 | clip_denoised=False, 742 | model_kwargs=model_kwargs, 743 | )["output"] 744 | if self.loss_type == LossType.RESCALED_KL: 745 | terms["loss"] *= self.num_timesteps 746 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 747 | model_output = model(x_t, t, **model_kwargs) 748 | 749 | if self.model_var_type in [ 750 | ModelVarType.LEARNED, 751 | ModelVarType.LEARNED_RANGE, 752 | ]: 753 | B, C = x_t.shape[:2] 754 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 755 | model_output, model_var_values = th.split(model_output, C, dim=1) 756 | # Learn the variance using the variational bound, but don't let 757 | # it affect our mean prediction. 758 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 759 | terms["vb"] = self._vb_terms_bpd( 760 | model=lambda *args, r=frozen_out: r, 761 | x_start=x_start, 762 | x_t=x_t, 763 | t=t, 764 | clip_denoised=False, 765 | )["output"] 766 | if self.loss_type == LossType.RESCALED_MSE: 767 | # Divide by 1000 for equivalence with initial implementation. 768 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
769 | terms["vb"] *= self.num_timesteps / 1000.0 770 | 771 | target = { 772 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 773 | x_start=x_start, x_t=x_t, t=t 774 | )[0], 775 | ModelMeanType.START_X: x_start, 776 | ModelMeanType.EPSILON: noise, 777 | }[self.model_mean_type] 778 | assert model_output.shape == target.shape == x_start.shape 779 | terms["mse"] = mean_flat((target - model_output) ** 2) 780 | if "vb" in terms: 781 | terms["loss"] = terms["mse"] + terms["vb"] 782 | else: 783 | terms["loss"] = terms["mse"] 784 | else: 785 | raise NotImplementedError(self.loss_type) 786 | 787 | return terms 788 | 789 | def _prior_bpd(self, x_start): 790 | """ 791 | Get the prior KL term for the variational lower-bound, measured in 792 | bits-per-dim. 793 | This term can't be optimized, as it only depends on the encoder. 794 | :param x_start: the [N x C x ...] tensor of inputs. 795 | :return: a batch of [N] KL values (in bits), one per batch element. 796 | """ 797 | batch_size = x_start.shape[0] 798 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 799 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 800 | kl_prior = normal_kl( 801 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 802 | ) 803 | return mean_flat(kl_prior) / np.log(2.0) 804 | 805 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 806 | """ 807 | Compute the entire variational lower-bound, measured in bits-per-dim, 808 | as well as other related quantities. 809 | :param model: the model to evaluate loss on. 810 | :param x_start: the [N x C x ...] tensor of inputs. 811 | :param clip_denoised: if True, clip denoised samples. 812 | :param model_kwargs: if not None, a dict of extra keyword arguments to 813 | pass to the model. This can be used for conditioning. 814 | :return: a dict containing the following keys: 815 | - total_bpd: the total variational lower-bound, per batch element. 816 | - prior_bpd: the prior term in the lower-bound. 817 | - vb: an [N x T] tensor of terms in the lower-bound. 818 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 819 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
820 | """ 821 | device = x_start.device 822 | batch_size = x_start.shape[0] 823 | 824 | vb = [] 825 | xstart_mse = [] 826 | mse = [] 827 | for t in list(range(self.num_timesteps))[::-1]: 828 | t_batch = th.tensor([t] * batch_size, device=device) 829 | noise = th.randn_like(x_start) 830 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 831 | # Calculate VLB term at the current timestep 832 | with th.no_grad(): 833 | out = self._vb_terms_bpd( 834 | model, 835 | x_start=x_start, 836 | x_t=x_t, 837 | t=t_batch, 838 | clip_denoised=clip_denoised, 839 | model_kwargs=model_kwargs, 840 | ) 841 | vb.append(out["output"]) 842 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 843 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 844 | mse.append(mean_flat((eps - noise) ** 2)) 845 | 846 | vb = th.stack(vb, dim=1) 847 | xstart_mse = th.stack(xstart_mse, dim=1) 848 | mse = th.stack(mse, dim=1) 849 | 850 | prior_bpd = self._prior_bpd(x_start) 851 | total_bpd = vb.sum(dim=1) + prior_bpd 852 | return { 853 | "total_bpd": total_bpd, 854 | "prior_bpd": prior_bpd, 855 | "vb": vb, 856 | "xstart_mse": xstart_mse, 857 | "mse": mse, 858 | } 859 | 860 | 861 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 862 | """ 863 | Extract values from a 1-D numpy array for a batch of indices. 864 | :param arr: the 1-D numpy array. 865 | :param timesteps: a tensor of indices into the array to extract. 866 | :param broadcast_shape: a larger shape of K dimensions with the batch 867 | dimension equal to the length of timesteps. 868 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 869 | """ 870 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 871 | while len(res.shape) < len(broadcast_shape): 872 | res = res[..., None] 873 | return res + th.zeros(broadcast_shape, device=timesteps.device) 874 | --------------------------------------------------------------------------------