├── model
│   ├── __init__.py
│   ├── models.py
│   ├── block.py
│   ├── wideresnet_noise_song.py
│   ├── DDPM.py
│   ├── EDM.py
│   ├── unet.py
│   └── augment.py
├── config
│   ├── DDPM_ddpmpp.yaml
│   ├── DDPM_ddpm.yaml
│   ├── EDM_ddpmpp.yaml
│   ├── EDM_ddpm.yaml
│   └── EDM_ddpmpp_aug.yaml
├── extract_cifar10_pngs.ipynb
├── DiT
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── respace.py
│   │   ├── timestep_sampler.py
│   │   └── gaussian_diffusion.py
│   ├── download.py
│   ├── README.md
│   ├── vae_preprocessing.py
│   ├── linear.py
│   └── models.py
├── utils.py
├── datasets.py
├── sample.py
├── contrastive.py
├── train.py
├── noisy_classifier_DDAE.py
├── linear.py
├── README.md
└── noisy_classifier_WRN.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/models.py:
--------------------------------------------------------------------------------
1 | from .DDPM import DDPM
2 | from .EDM import EDM
3 | from .unet import UNet
4 |
5 | CLASSES = {
6 | cls.__name__: cls
7 | for cls in [DDPM, EDM, UNet]
8 | }
9 |
10 |
11 | def get_models_class(model_type, net_type):
12 | return CLASSES[model_type], CLASSES[net_type]
13 |
--------------------------------------------------------------------------------
/config/DDPM_ddpmpp.yaml:
--------------------------------------------------------------------------------
1 | # dataset params
2 | dataset: 'cifar'
3 | classes: 10
4 |
5 | # model params
6 | model_type: 'DDPM'
7 | net_type: 'UNet'
8 | diffusion:
9 | n_T: 1000
10 | betas: [1.0e-4, 0.02]
11 | network:
12 | image_shape: [3, 32, 32]
13 | n_channels: 128
14 | ch_mults: [2, 2, 2]
15 | is_attn: [False, True, False]
16 | dropout: 0.1
17 | n_blocks: 4
18 | use_res_for_updown: True
19 |
20 | # training params
21 | n_epoch: 2000
22 | batch_size: 64
23 | lrate: 1.0e-4
24 | warm_epoch: 13
25 | load_epoch: -1
26 | flip: True
27 | ema: 0.9999
28 |
29 | # testing params
30 | n_sample: 30
31 | save_dir: './output_DDPM_ddpmpp'
32 | save_model: True
33 |
34 | # linear probe
35 | linear:
36 | n_epoch: 15
37 | batch_size: 128
38 | lrate: 1.0e-3
39 | timestep: 11
40 | blockname: 'out_6'
41 |
--------------------------------------------------------------------------------
/config/DDPM_ddpm.yaml:
--------------------------------------------------------------------------------
1 | # dataset params
2 | dataset: 'cifar'
3 | classes: 10
4 |
5 | # model params
6 | model_type: 'DDPM'
7 | net_type: 'UNet'
8 | diffusion:
9 | n_T: 1000
10 | betas: [1.0e-4, 0.02]
11 | network:
12 | image_shape: [3, 32, 32]
13 | n_channels: 128
14 | ch_mults: [1, 2, 2, 2]
15 | is_attn: [False, True, False, False]
16 | dropout: 0.1
17 | n_blocks: 2
18 | use_res_for_updown: False
19 |
20 | # training params
21 | n_epoch: 2000
22 | batch_size: 128
23 | lrate: 1.0e-4
24 | warm_epoch: 13
25 | load_epoch: -1
26 | flip: True
27 | ema: 0.9999
28 |
29 | # testing params
30 | n_sample: 30
31 | save_dir: './output_DDPM_ddpm'
32 | save_model: True
33 |
34 | # linear probe
35 | linear:
36 | n_epoch: 15
37 | batch_size: 128
38 | lrate: 1.0e-3
39 | timestep: 11
40 | blockname: 'out_6'
41 |
--------------------------------------------------------------------------------
/config/EDM_ddpmpp.yaml:
--------------------------------------------------------------------------------
1 | # dataset params
2 | dataset: 'cifar'
3 | classes: 10
4 |
5 | # model params
6 | model_type: 'EDM'
7 | net_type: 'UNet'
8 | diffusion:
9 | sigma_data: 0.5
10 | p_mean: -1.2
11 | p_std: 1.2
12 | sigma_min: 0.002
13 | sigma_max: 80
14 | rho: 7
15 | S_min: 0.01
16 | S_max: 1
17 | S_noise: 1.007
18 | network:
19 | image_shape: [3, 32, 32]
20 | n_channels: 128
21 | ch_mults: [2, 2, 2]
22 | is_attn: [False, True, False]
23 | dropout: 0.13
24 | n_blocks: 4
25 | use_res_for_updown: True
26 |
27 | # training params
28 | n_epoch: 2000
29 | batch_size: 64
30 | lrate: 1.0e-4
31 | warm_epoch: 200
32 | load_epoch: -1
33 | flip: True
34 | ema: 0.9993
35 |
36 | # testing params
37 | n_sample: 30
38 | save_dir: './output_EDM_ddpmpp'
39 | save_model: True
40 |
41 | # linear probe
42 | linear:
43 | n_epoch: 15
44 | batch_size: 128
45 | lrate: 1.0e-3
46 | timestep: 4
47 | blockname: 'out_7'
48 |
--------------------------------------------------------------------------------
/config/EDM_ddpm.yaml:
--------------------------------------------------------------------------------
1 | # dataset params
2 | dataset: 'cifar'
3 | classes: 10
4 |
5 | # model params
6 | model_type: 'EDM'
7 | net_type: 'UNet'
8 | diffusion:
9 | sigma_data: 0.5
10 | p_mean: -1.2
11 | p_std: 1.2
12 | sigma_min: 0.002
13 | sigma_max: 80
14 | rho: 7
15 | S_min: 0.01
16 | S_max: 1
17 | S_noise: 1.007
18 | network:
19 | image_shape: [3, 32, 32]
20 | n_channels: 128
21 | ch_mults: [1, 2, 2, 2]
22 | is_attn: [False, True, False, False]
23 | dropout: 0.13
24 | n_blocks: 2
25 | use_res_for_updown: False
26 |
27 | # training params
28 | n_epoch: 2000
29 | batch_size: 128
30 | lrate: 1.0e-4
31 | warm_epoch: 200
32 | load_epoch: -1
33 | flip: True
34 | ema: 0.9993
35 |
36 | # testing params
37 | n_sample: 30
38 | save_dir: './output_EDM_ddpm'
39 | save_model: True
40 |
41 | # linear probe
42 | linear:
43 | n_epoch: 15
44 | batch_size: 128
45 | lrate: 1.0e-3
46 | timestep: 4
47 | blockname: 'out_7'
48 |
--------------------------------------------------------------------------------
/config/EDM_ddpmpp_aug.yaml:
--------------------------------------------------------------------------------
1 | # dataset params
2 | dataset: 'cifar'
3 | classes: 10
4 |
5 | # model params
6 | model_type: 'EDM'
7 | net_type: 'UNet'
8 | diffusion:
9 | sigma_data: 0.5
10 | p_mean: -1.2
11 | p_std: 1.2
12 | sigma_min: 0.002
13 | sigma_max: 80
14 | rho: 7
15 | S_min: 0.01
16 | S_max: 1
17 | S_noise: 1.007
18 | augment_prob: 0.12
19 | network:
20 | image_shape: [3, 32, 32]
21 | n_channels: 128
22 | ch_mults: [2, 2, 2]
23 | is_attn: [False, True, False]
24 | dropout: 0.13
25 | n_blocks: 4
26 | use_res_for_updown: True
27 | augment_dim: 9
28 |
29 | # training params
30 | n_epoch: 4000
31 | batch_size: 64
32 | lrate: 1.0e-4
33 | warm_epoch: 200
34 | load_epoch: -1
35 | flip: True
36 | ema: 0.9993
37 |
38 | # testing params
39 | n_sample: 30
40 | save_dir: './output_EDM_ddpmpp_aug'
41 | save_model: True
42 |
43 | # linear probe
44 | linear:
45 | n_epoch: 15
46 | batch_size: 128
47 | lrate: 1.0e-3
48 | timestep: 4
49 | blockname: 'out_7'
50 |
--------------------------------------------------------------------------------
/extract_cifar10_pngs.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "from torchvision.datasets import CIFAR10\n",
11 | "import os\n",
12 | "\n",
13 | "train_set = CIFAR10(\"./data\", train=True, download=True)\n",
14 | "print(\"CIFAR10 train dataset:\", len(train_set))\n",
15 | "\n",
16 | "images = []\n",
17 | "labels = []\n",
18 | "for img, label in train_set:\n",
19 | " images.append(img)\n",
20 | " labels.append(label)\n",
21 | "\n",
22 | "labels = torch.tensor(labels)\n",
23 | "for i in range(10):\n",
24 | " assert (labels == i).sum() == 5000\n",
25 | "\n",
26 |        "output_dir = \"./data/cifar10-pngs/\"\nos.makedirs(output_dir, exist_ok=True)\n",
27 | "for i, pil in enumerate(images):\n",
28 | " pil.save(os.path.join(output_dir, \"{:05d}.png\".format(i)))"
29 | ]
30 | }
31 | ],
32 | "metadata": {
33 | "kernelspec": {
34 | "display_name": "Python 3.8.5 ('gan')",
35 | "language": "python",
36 | "name": "python3"
37 | },
38 | "language_info": {
39 | "codemirror_mode": {
40 | "name": "ipython",
41 | "version": 3
42 | },
43 | "file_extension": ".py",
44 | "mimetype": "text/x-python",
45 | "name": "python",
46 | "nbconvert_exporter": "python",
47 | "pygments_lexer": "ipython3",
48 | "version": "3.8.5"
49 | },
50 | "orig_nbformat": 4,
51 | "vscode": {
52 | "interpreter": {
53 | "hash": "da18559f301618e6e9fab00c6d05e566e4e63dfec8a595f965f0c783b8f75048"
54 | }
55 | }
56 | },
57 | "nbformat": 4,
58 | "nbformat_minor": 2
59 | }
60 |
--------------------------------------------------------------------------------
/DiT/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
--------------------------------------------------------------------------------
/DiT/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Functions for downloading pre-trained DiT models
9 | """
10 | from torchvision.datasets.utils import download_url
11 | import torch
12 | import os
13 |
14 |
15 | pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
16 |
17 |
18 | def find_model(model_name):
19 | """
20 | Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
21 | """
22 | if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
23 | return download_model(model_name)
24 | else: # Load a custom DiT checkpoint:
25 | assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
26 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
27 | if "ema" in checkpoint: # supports checkpoints from train.py
28 | checkpoint = checkpoint["ema"]
29 | return checkpoint
30 |
31 |
32 | def download_model(model_name):
33 | """
34 | Downloads a pre-trained DiT model from the web.
35 | """
36 | assert model_name in pretrained_models
37 | local_path = f'pretrained_models/{model_name}'
38 | if not os.path.isfile(local_path):
39 | os.makedirs('pretrained_models', exist_ok=True)
40 | web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
41 | download_url(web_path, 'pretrained_models')
42 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
43 | return model
44 |
45 |
46 | if __name__ == "__main__":
47 | # Download all DiT checkpoints
48 | for model in pretrained_models:
49 | download_model(model)
50 | print('Done.')
51 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader
7 |
8 | # ===== Configs =====
9 |
10 | class Config(object):
11 | def __init__(self, dic):
12 | for key in dic:
13 | setattr(self, key, dic[key])
14 |
15 | def get_optimizer(parameters, opt, lr):
16 | if not hasattr(opt, 'optim'):
17 | return torch.optim.Adam(parameters, lr=lr)
18 | elif opt.optim == 'AdamW':
19 | return torch.optim.AdamW(parameters, **opt.optim_args, lr=lr)
20 | else:
21 | raise NotImplementedError()
22 |
23 | # ===== Multi-GPU training =====
24 |
25 | def init_seeds(RANDOM_SEED=1337, no=0):
26 | RANDOM_SEED += no
27 | print("local_rank = {}, seed = {}".format(no, RANDOM_SEED))
28 | random.seed(RANDOM_SEED)
29 | np.random.seed(RANDOM_SEED)
30 | torch.manual_seed(RANDOM_SEED)
31 | torch.cuda.manual_seed_all(RANDOM_SEED)
32 | torch.backends.cudnn.deterministic = True
33 | torch.backends.cudnn.benchmark = False
34 |
35 |
36 | def reduce_tensor(tensor):
37 | rt = tensor.clone()
38 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
39 | rt /= dist.get_world_size()
40 | return rt
41 |
42 |
43 | def gather_tensor(tensor):
44 | tensor_list = [tensor.clone() for _ in range(dist.get_world_size())]
45 | dist.all_gather(tensor_list, tensor)
46 | tensor_list = torch.cat(tensor_list, dim=0)
47 | return tensor_list
48 |
49 |
50 | def DataLoaderDDP(dataset, batch_size, shuffle=True):
51 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
52 | dataloader = DataLoader(
53 | dataset,
54 | batch_size=batch_size,
55 | sampler=sampler,
56 | num_workers=1,
57 | )
58 | return dataloader, sampler
59 |
60 | def print0(*args, **kwargs):
61 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
62 | print(*args, **kwargs)
63 |
--------------------------------------------------------------------------------
/DiT/README.md:
--------------------------------------------------------------------------------
1 | ## DDAE/DiT
2 |
3 | This subfolder contains the transfer learning evaluation of the ImageNet-256 pre-trained [DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint, by:
4 | - evaluating:
5 |   - [x] Linear probing
6 |   - [ ] Fine-tuning
7 | - performance on these datasets:
8 |   - [x] CIFAR-10
9 |   - [x] Tiny-ImageNet
10 |
11 | This implementation uses very small batch sizes, lightweight data augmentations, and a standard Adam optimizer, rather than advanced optimizers (e.g., LARS) with large batch sizes. Incorporating these modern tricks may further improve performance.
12 |
13 | ## Main results
14 | The pre-trained DiT-XL/2 is expected to achieve $85.73$ % linear probing accuracy on CIFAR-10, and $66.57$ % on Tiny-ImageNet.
15 |
16 | ## Usage
17 | ### Data pre-processing
18 | Since DiT operates in the latent space, we need to resize the images to $256\times256$ and generate their latent codes (shape: $(4,32,32)$) through the VAE encoder.
19 |
20 | To reduce the computational cost of training, we use `vae_preprocessing.py` to pre-compute and cache the latent codes to files. Since data augmentations are essential for effective discriminative learning, we generate multiple versions (by default, 10) of the latent codes to cover different variations of the augmented images. Please refer to `vae_preprocessing.py` for more details.
21 |
22 | ```sh
23 | # pre-processing with VAE encoding,
24 | # launched with torch.distributed on 4 GPUs:
25 | python -m torch.distributed.launch --nproc_per_node=4 vae_preprocessing.py --dataset cifar --use_amp
26 | python -m torch.distributed.launch --nproc_per_node=4 vae_preprocessing.py --dataset tiny --use_amp
27 | ```
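For reference, the encoding itself is a single forward pass through the frozen Stable Diffusion VAE; a minimal sketch of the core step (mirroring the `encode` helper in `vae_preprocessing.py`, assuming the same `sd-vae-ft-mse` checkpoint and `0.18215` latent scaling) is:

```python
# Minimal sketch of the VAE encoding step used during pre-processing.
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").cuda()

@torch.no_grad()
def encode(img):                                      # img: (B, 3, 256, 256), values in [0, 1]
    posterior = vae.encode(img.cuda() * 2 - 1)        # the VAE expects inputs in [-1, 1]
    return 0.18215 * posterior.latent_dist.sample()   # latent code: (B, 4, 32, 32)
```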
28 |
29 | ### Linear probing
30 | To linear-probe the features produced by the pre-trained DiT, run, for example:
31 | ```sh
32 | # linear probing with the default layer-noise combination,
33 | # launched with torch.distributed on 4 GPUs:
34 | python -m torch.distributed.launch --nproc_per_node=4 linear.py --dataset cifar --use_amp
35 | python -m torch.distributed.launch --nproc_per_node=4 linear.py --dataset tiny --use_amp
36 | ```
37 | Note that this implementation loads ALL versions of the augmented dataset (by default, 10) into memory, and hence requires a lot of memory to run (e.g., 50 GB for CIFAR, 80 GB for Tiny-ImageNet).
38 | If you don't have enough memory, you can instead dump each latent code into a standalone numpy file and only load it when needed, as sketched below.
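A minimal sketch of such a lazy-loading dataset, assuming a hypothetical per-sample file layout (`code_{copy}_{index}.npy` plus a label array); adapt the paths and names to however you dump the codes:

```python
# Hypothetical lazy-loading alternative: keep one .npy file per (augmentation copy, sample index)
# latent code on disk and read it in __getitem__, instead of holding all copies in memory.
import os
import numpy as np
import torch
from torch.utils.data import Dataset

class LazyLatentDataset(Dataset):
    def __init__(self, root, labels_file, num_copies=10):
        self.root = root
        self.labels = np.load(labels_file)            # shape: (N,)
        self.num_copies = num_copies

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        copy = np.random.randint(self.num_copies)     # pick one augmented version at random
        path = os.path.join(self.root, f"code_{copy}_{idx:06d}.npy")
        code = torch.from_numpy(np.load(path))        # latent code, shape: (4, 32, 32)
        return code, int(self.labels[idx])
```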
39 |
40 | ## Acknowledgments
41 | Except for `vae_preprocessing.py` and `linear.py`, all code is taken or adapted from the official [DiT](https://github.com/facebookresearch/DiT) repository.
42 |
--------------------------------------------------------------------------------
/DiT/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from torchvision.datasets import CIFAR10, CIFAR100
7 |
8 |
9 | class TinyImageNet(Dataset):
10 | def __init__(self, root, train=True, transform=None):
11 | if not root.endswith("tiny-imagenet-200"):
12 | root = os.path.join(root, "tiny-imagenet-200")
13 | self.train_dir = os.path.join(root, "train")
14 | self.val_dir = os.path.join(root, "val")
15 | self.transform = transform
16 | if train:
17 | self._scan_train()
18 | else:
19 | self._scan_val()
20 |
21 | def _scan_train(self):
22 | classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
23 | classes = sorted(classes)
24 | assert len(classes) == 200
25 |
26 | self.data = []
27 | for idx, name in enumerate(classes):
28 | this_dir = os.path.join(self.train_dir, name)
29 | for root, _, files in sorted(os.walk(this_dir)):
30 | for fname in sorted(files):
31 | if fname.endswith(".JPEG"):
32 | path = os.path.join(root, fname)
33 | item = (path, idx)
34 | self.data.append(item)
35 | self.labels_dict = {i: classes[i] for i in range(len(classes))}
36 |
37 | def _scan_val(self):
38 | self.file_to_class = {}
39 | classes = set()
40 | with open(os.path.join(self.val_dir, "val_annotations.txt"), 'r') as f:
41 | lines = f.readlines()
42 | for line in lines:
43 | words = line.split("\t")
44 | self.file_to_class[words[0]] = words[1]
45 | classes.add(words[1])
46 | classes = sorted(list(classes))
47 | assert len(classes) == 200
48 |
49 | class_to_idx = {classes[i]: i for i in range(len(classes))}
50 | self.data = []
51 | this_dir = os.path.join(self.val_dir, "images")
52 | for root, _, files in sorted(os.walk(this_dir)):
53 | for fname in sorted(files):
54 | if fname.endswith(".JPEG"):
55 | path = os.path.join(root, fname)
56 | idx = class_to_idx[self.file_to_class[fname]]
57 | item = (path, idx)
58 | self.data.append(item)
59 | self.labels_dict = {i: classes[i] for i in range(len(classes))}
60 |
61 | def __len__(self):
62 | return len(self.data)
63 |
64 | def __getitem__(self, idx):
65 | path, label = self.data[idx]
66 | image = Image.open(path)
67 | image = image.convert("RGB")
68 |
69 | if self.transform:
70 | image = self.transform(image)
71 |
72 | return image, label
73 |
74 |
75 | def get_dataset(name, root="./data", train=True, flip=False, crop=False, resize=None):
76 | if name == 'cifar':
77 | DATASET = CIFAR10
78 | RES = 32
79 | elif name == 'cifar100':
80 | DATASET = CIFAR100
81 | RES = 32
82 | elif name == 'tiny':
83 | DATASET = TinyImageNet
84 | RES = 64
85 | else:
86 | raise NotImplementedError
87 |
88 | tf = [transforms.ToTensor()]
89 | if resize is not None:
90 | tf = [transforms.Resize(resize)] + tf
91 | if train:
92 | if crop:
93 | tf = [transforms.RandomCrop(RES, 4)] + tf
94 | if flip:
95 | tf = [transforms.RandomHorizontalFlip()] + tf
96 |
97 | return DATASET(root=root, train=train, transform=transforms.Compose(tf))
98 |
--------------------------------------------------------------------------------
/DiT/vae_preprocessing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | import numpy as np
5 |
6 | import torch
7 | import torch.distributed as dist
8 | from torch.cuda.amp import autocast as autocast
9 | from torchvision.utils import save_image
10 |
11 | from diffusers.models import AutoencoderKL
12 |
13 | import sys
14 | sys.path.append("..")
15 | from datasets import get_dataset
16 | from utils import init_seeds, gather_tensor, DataLoaderDDP, print0
17 |
18 |
19 | def show(imgs, title="debug.png"):
20 | save_image(imgs, title, normalize=True, value_range=(0, 1))
21 |
22 |
23 | def main(opt):
24 | name = opt.dataset
25 | local_rank = opt.local_rank
26 | num_copies = opt.num_copies
27 | use_amp = opt.use_amp
28 |
29 | save_dir = os.path.join('./latent_codes', name)
30 | if local_rank == 0:
31 | os.makedirs(save_dir, exist_ok=False)
32 |
33 | device = "cuda:%d" % local_rank
34 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
35 |
36 | def encode(img):
37 | with torch.no_grad():
38 | code = vae.encode(img.to(device) * 2 - 1)
39 | return 0.18215 * code.latent_dist.sample()
40 |
41 | def decode(code):
42 | with torch.no_grad():
43 | recon = vae.decode(code / 0.18215).sample.cpu()
44 | return (recon + 1) / 2
45 |
46 | train_dataset = get_dataset(name, root="../data", train=True, resize=256, flip=True, crop=True)
47 | test_dataset = get_dataset(name, root="../data", train=False, resize=256)
48 | for dataset, epochs, string in [(train_dataset, num_copies, 'train'), (test_dataset, 1, 'test')]:
49 | loader, sampler = DataLoaderDDP(
50 | dataset,
51 | batch_size=1,
52 | shuffle=False,
53 | )
54 |
55 | for ep in range(epochs):
56 | sampler.set_epoch(ep)
57 | data = []
58 | label = []
59 | for i, (x, y) in enumerate(tqdm(loader, disable=(local_rank != 0))):
60 | x = x.to(device)
61 | y = y.to(device)
62 | with autocast(enabled=use_amp):
63 | code = encode(x).float()
64 | if local_rank == 0 and i == 0:
65 | # for visualization and debugging
66 | recon = decode(code).float()
67 | show(x, f"{string}_debug_original_{ep}.png")
68 | show(recon, f"{string}_debug_reconstruct_{ep}.png")
69 |
70 | dist.barrier()
71 | code = gather_tensor(code).cpu()
72 | data.append(code)
73 | if ep == 0:
74 | y = gather_tensor(y).cpu()
75 | label.append(y)
76 |
77 | if local_rank == 0:
78 | data = torch.cat(data)
79 | with open(os.path.join(save_dir, f"{string}_code_{ep}.npy"), 'wb') as f:
80 | np.save(f, data.numpy())
81 | if ep == 0:
82 | label = torch.cat(label)
83 | with open(os.path.join(save_dir, f"{string}_label.npy"), 'wb') as f:
84 | np.save(f, label.numpy())
85 |
86 |
87 | if __name__ == "__main__":
88 | parser = argparse.ArgumentParser()
89 | parser.add_argument("--dataset", default='cifar', type=str, choices=['cifar', 'tiny'])
90 | parser.add_argument('--num_copies', default=10, type=int,
91 | help='number of training data copies, higher = more augmentation variations')
92 | parser.add_argument('--local_rank', default=-1, type=int,
93 | help='node rank for distributed training')
94 | parser.add_argument("--use_amp", action='store_true', default=False)
95 | opt = parser.parse_args()
96 | print0(opt)
97 |
98 | init_seeds(no=opt.local_rank)
99 | dist.init_process_group(backend='nccl')
100 | torch.cuda.set_device(opt.local_rank)
101 | main(opt)
102 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import yaml
7 | from torchvision.utils import make_grid, save_image
8 | from ema_pytorch import EMA
9 |
10 | from model.models import get_models_class
11 | from utils import Config, init_seeds, gather_tensor, print0
12 |
13 |
14 | def get_default_steps(model_type, steps):
15 | if steps is not None:
16 | return steps
17 | else:
18 | return {'DDPM': 100, 'EDM': 18}[model_type]
19 |
20 |
21 | # ===== sampling =====
22 |
23 | def sample(opt):
24 | yaml_path = opt.config
25 | local_rank = opt.local_rank
26 | use_amp = opt.use_amp
27 | mode = opt.mode
28 | steps = opt.steps
29 | eta = opt.eta
30 | batches = opt.batches
31 | ep = opt.epoch
32 |
33 | with open(yaml_path, 'r') as f:
34 | opt = yaml.full_load(f)
35 | print0(opt)
36 | opt = Config(opt)
37 | if ep == -1:
38 | ep = opt.n_epoch - 1
39 |
40 | device = "cuda:%d" % local_rank
41 | steps = get_default_steps(opt.model_type, steps)
42 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type)
43 | diff = DIFFUSION(nn_model=NETWORK(**opt.network),
44 | **opt.diffusion,
45 | device=device,
46 | )
47 | diff.to(device)
48 |
49 | target = os.path.join(opt.save_dir, "ckpts", f"model_{ep}.pth")
50 | print0("loading model at", target)
51 | checkpoint = torch.load(target, map_location=device)
52 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
53 | ema.to(device)
54 | ema.load_state_dict(checkpoint['EMA'])
55 | model = ema.ema_model
56 | model.eval()
57 |
58 | if local_rank == 0:
59 | if opt.model_type == 'EDM':
60 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_edm_steps{steps}_eta{eta}")
61 | else:
62 | if mode == 'DDPM':
63 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_ddpm")
64 | else:
65 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_ddim_steps{steps}_eta{eta}")
66 | os.makedirs(gen_dir)
67 | gen_dir_png = os.path.join(gen_dir, "pngs")
68 | os.makedirs(gen_dir_png)
69 | res = []
70 |
71 | for batch in range(batches):
72 | with torch.no_grad():
73 | assert 400 % dist.get_world_size() == 0
74 | samples_per_process = 400 // dist.get_world_size()
75 | args = dict(n_sample=samples_per_process, size=opt.network['image_shape'], notqdm=(local_rank != 0), use_amp=use_amp)
76 | if opt.model_type == 'EDM':
77 | x_gen = model.edm_sample(**args, steps=steps, eta=eta)
78 | else:
79 | if mode == 'DDPM':
80 | x_gen = model.sample(**args)
81 | else:
82 | x_gen = model.ddim_sample(**args, steps=steps, eta=eta)
83 | dist.barrier()
84 | x_gen = gather_tensor(x_gen).cpu()
85 | if local_rank == 0:
86 | res.append(x_gen)
87 | grid = make_grid(x_gen, nrow=20)
88 | png_path = os.path.join(gen_dir, f"grid_{batch}.png")
89 | save_image(grid, png_path)
90 |
91 | if local_rank == 0:
92 | res = torch.cat(res)
93 | for no, img in enumerate(res):
94 | png_path = os.path.join(gen_dir_png, f"{no}.png")
95 | save_image(img, png_path)
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument("--config", type=str)
101 | parser.add_argument('--local_rank', default=-1, type=int,
102 | help='node rank for distributed training')
103 | parser.add_argument("--use_amp", action='store_true', default=False)
104 | parser.add_argument("--mode", type=str, choices=['DDPM', 'DDIM'], default='DDIM')
105 | parser.add_argument("--steps", type=int, default=None)
106 | parser.add_argument("--eta", type=float, default=0.0)
107 | parser.add_argument("--batches", type=int, default=125)
108 | parser.add_argument("--epoch", type=int, default=-1)
109 | opt = parser.parse_args()
110 | print0(opt)
111 |
112 | init_seeds(no=opt.local_rank)
113 | dist.init_process_group(backend='nccl')
114 | torch.cuda.set_device(opt.local_rank)
115 | sample(opt)
116 |
--------------------------------------------------------------------------------
/model/block.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def GroupNorm32(channels):
8 | return nn.GroupNorm(32, channels)
9 |
10 |
11 | class TimeEmbedding(nn.Module):
12 | def __init__(self, n_channels, augment_dim):
13 | """
14 |         * `n_channels` is the number of dimensions in the embedding; `augment_dim` > 0 adds an augmentation-label embedding
15 | """
16 | super().__init__()
17 | self.n_channels = n_channels
18 | self.aug_emb = nn.Linear(augment_dim, self.n_channels // 4, bias=False) if augment_dim > 0 else None
19 |
20 | self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
21 | self.act = nn.SiLU()
22 | self.lin2 = nn.Linear(self.n_channels, self.n_channels)
23 |
24 | def forward(self, t, aug_label):
25 | # Create sinusoidal position embeddings (same as those from the transformer)
26 | half_dim = self.n_channels // 8
27 | emb = math.log(10_000) / (half_dim - 1)
28 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -emb)
29 | emb = t.float()[:, None] * emb[None, :]
30 | emb = torch.cat((emb.sin(), emb.cos()), dim=1)
31 |
32 | if self.aug_emb is not None and aug_label is not None:
33 | emb += self.aug_emb(aug_label)
34 |
35 | # Transform with the MLP
36 | emb = self.act(self.lin1(emb))
37 | emb = self.lin2(emb)
38 | return emb
39 |
40 |
41 | class AttentionBlock(nn.Module):
42 | def __init__(self, n_channels, d_k):
43 | """
44 | * `n_channels` is the number of channels in the input
45 |         * `d_k` is the number of dimensions in each head (defaults to `n_channels`)
46 |         * the number of heads is derived as `n_channels // d_k`
47 | """
48 | super().__init__()
49 |
50 | # Default `d_k`
51 | if d_k is None:
52 | d_k = n_channels
53 | n_heads = n_channels // d_k
54 |
55 | self.norm = GroupNorm32(n_channels)
56 | # Projections for query, key and values
57 | self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
58 | # Linear layer for final transformation
59 | self.output = nn.Linear(n_heads * d_k, n_channels)
60 |
61 | self.scale = 1 / math.sqrt(math.sqrt(d_k))
62 | self.n_heads = n_heads
63 | self.d_k = d_k
64 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
65 | print(f"{self.n_heads} heads, {self.d_k} channels per head")
66 |
67 | def forward(self, x):
68 | """
69 | * `x` has shape `[batch_size, in_channels, height, width]`
70 | """
71 | batch_size, n_channels, height, width = x.shape
72 | # Normalize and rearrange to `[batch_size, seq, n_channels]`
73 | h = self.norm(x).view(batch_size, n_channels, -1).permute(0, 2, 1)
74 |
75 | # {q, k, v} all have a shape of `[batch_size, seq, n_heads, d_k]`
76 | qkv = self.projection(h).view(batch_size, -1, self.n_heads, 3 * self.d_k)
77 | q, k, v = torch.chunk(qkv, 3, dim=-1)
78 |
79 | attn = torch.einsum('bihd,bjhd->bijh', q * self.scale, k * self.scale) # More stable with f16 than dividing afterwards
80 | attn = attn.softmax(dim=2)
81 | res = torch.einsum('bijh,bjhd->bihd', attn, v)
82 |
83 | # Reshape to `[batch_size, seq, n_heads * d_k]` and transform to `[batch_size, seq, n_channels]`
84 | res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
85 | res = self.output(res)
86 | res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
87 | return res + x
88 |
89 |
90 | class Upsample(nn.Module):
91 | def __init__(self, n_channels, use_conv=True):
92 | super().__init__()
93 | self.use_conv = use_conv
94 | if use_conv:
95 | self.conv = nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=1, padding=1)
96 |
97 | def forward(self, x):
98 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
99 | if self.use_conv:
100 | return self.conv(x)
101 | else:
102 | return x
103 |
104 |
105 | class Downsample(nn.Module):
106 | def __init__(self, n_channels, use_conv=True):
107 | super().__init__()
108 | self.use_conv = use_conv
109 | if use_conv:
110 | self.conv = nn.Conv2d(n_channels, n_channels, kernel_size=3, stride=2, padding=1)
111 | else:
112 | self.pool = nn.AvgPool2d(2)
113 |
114 | def forward(self, x):
115 | if self.use_conv:
116 | return self.conv(x)
117 | else:
118 | return self.pool(x)
119 |
120 |
--------------------------------------------------------------------------------
/contrastive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import yaml
7 | from datasets import get_dataset
8 | from tqdm import tqdm
9 | from ema_pytorch import EMA
10 |
11 | from model.models import get_models_class
12 | from utils import Config, init_seeds, gather_tensor, DataLoaderDDP, print0
13 |
14 |
15 | def get_model(opt, load_epoch):
16 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type)
17 | diff = DIFFUSION(nn_model=NETWORK(**opt.network),
18 | **opt.diffusion,
19 | device=device,
20 | )
21 | diff.to(device)
22 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth")
23 | print0("loading model at", target)
24 | checkpoint = torch.load(target, map_location=device)
25 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
26 | ema.to(device)
27 | ema.load_state_dict(checkpoint['EMA'])
28 | model = ema.ema_model
29 | model.eval()
30 | return model
31 |
32 |
33 | def alignment(x, y, alpha=2):
34 | return (x - y).norm(p=2, dim=1).pow(alpha).mean().item()
35 |
36 | def uniformity(x, t=2):
37 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log().item()
38 |
39 |
40 | class NamedMeter:
41 | def __init__(self):
42 | self.sum = {}
43 | self.count = {}
44 | self.history = {}
45 |
46 | def update(self, name, val, n=1):
47 | if name not in self.sum:
48 | self.sum[name] = 0
49 | self.count[name] = 0
50 | self.history[name] = []
51 |
52 | self.sum[name] += val * n
53 | self.count[name] += n
54 | self.history[name].append("%.4f" % val)
55 |
56 | def get_avg(self, name):
57 | return self.sum[name] / self.count[name]
58 |
59 | def get_names(self):
60 | return self.sum.keys()
61 |
62 |
63 | def metrics(opt):
64 | yaml_path = opt.config
65 | interval = opt.epoch_interval
66 | use_amp = opt.use_amp
67 | with open(yaml_path, 'r') as f:
68 | opt = yaml.full_load(f)
69 | print0(opt)
70 | opt = Config(opt)
71 | timestep = opt.linear['timestep']
72 |
73 | train_set_raw = get_dataset(name=opt.dataset, root="./data", train=True)
74 | train_loader_raw, _ = DataLoaderDDP(
75 | train_set_raw,
76 | batch_size=128,
77 | shuffle=False,
78 | )
79 |
80 | check_epochs = list(range(interval, opt.n_epoch, interval)) + [opt.n_epoch - 1]
81 | align_evolving = NamedMeter()
82 | uniform_evolving = NamedMeter()
83 |
84 | print0("Using timestep =", timestep)
85 | print0("Checking epochs:", check_epochs)
86 |
87 | for load_epoch in check_epochs:
88 | model = get_model(opt, load_epoch)
89 | align_cur_epoch = NamedMeter()
90 | uniform_cur_epoch = NamedMeter()
91 |
92 | for image, _ in tqdm(train_loader_raw, disable=(local_rank!=0)):
93 | with torch.no_grad():
94 | x = model.get_feature(image.to(device), timestep, norm=True, use_amp=use_amp)
95 | y = model.get_feature(image.to(device), timestep, norm=True, use_amp=use_amp)
96 | dist.barrier()
97 | x = {name: gather_tensor(x[name]).cpu() for name in x}
98 | y = {name: gather_tensor(y[name]).cpu() for name in y}
99 |
100 | for blockname in x:
101 | align = alignment(x[blockname].detach(), y[blockname].detach())
102 | uniform = (uniformity(x[blockname]) + uniformity(y[blockname])) / 2
103 | # calculate metrics for a small batch
104 | align_cur_epoch.update(blockname, align, n=image.shape[0])
105 | uniform_cur_epoch.update(blockname, uniform, n=image.shape[0])
106 |
107 | # gather metrics for the complete dataset
108 | for blockname in align_cur_epoch.get_names():
109 | align = align_cur_epoch.get_avg(blockname)
110 | uniform = uniform_cur_epoch.get_avg(blockname)
111 | # record metrics for each checkpoint
112 | align_evolving.update(blockname, align)
113 | uniform_evolving.update(blockname, uniform)
114 |
115 | if local_rank == 0:
116 | print(align_evolving.history.keys())
117 | print('align metric:')
118 | for blockname in align_evolving.history:
119 | align = align_evolving.history[blockname]
120 | print("'%s': [%s]" % (blockname, ', '.join(align)))
121 |
122 | print('uniform metric:')
123 | for blockname in uniform_evolving.history:
124 | uniform = uniform_evolving.history[blockname]
125 | print("'%s': [%s]" % (blockname, ', '.join(uniform)))
126 |
127 |
128 | if __name__ == "__main__":
129 | parser = argparse.ArgumentParser()
130 | parser.add_argument("--config", type=str)
131 | parser.add_argument('--epoch_interval', type=int, default=400)
132 | parser.add_argument('--local_rank', default=-1, type=int,
133 | help='node rank for distributed training')
134 | parser.add_argument("--use_amp", action='store_true', default=False)
135 | opt = parser.parse_args()
136 | print0(opt)
137 |
138 | local_rank = opt.local_rank
139 | init_seeds(no=local_rank)
140 | dist.init_process_group(backend='nccl')
141 | torch.cuda.set_device(local_rank)
142 | device = "cuda:%d" % local_rank
143 |
144 | metrics(opt)
145 |
--------------------------------------------------------------------------------
/DiT/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import yaml
7 | from datasets import get_dataset
8 | from torchvision.utils import make_grid, save_image
9 | from tqdm import tqdm
10 | from ema_pytorch import EMA
11 |
12 | from model.models import get_models_class
13 | from utils import Config, get_optimizer, init_seeds, reduce_tensor, DataLoaderDDP, print0
14 |
15 |
16 | # ===== training =====
17 |
18 | def train(opt):
19 | yaml_path = opt.config
20 | local_rank = opt.local_rank
21 | use_amp = opt.use_amp
22 |
23 | with open(yaml_path, 'r') as f:
24 | opt = yaml.full_load(f)
25 | print0(opt)
26 | opt = Config(opt)
27 | model_dir = os.path.join(opt.save_dir, "ckpts")
28 | vis_dir = os.path.join(opt.save_dir, "visual")
29 | if local_rank == 0:
30 | os.makedirs(model_dir, exist_ok=True)
31 | os.makedirs(vis_dir, exist_ok=True)
32 |
33 | device = "cuda:%d" % local_rank
34 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type)
35 | diff = DIFFUSION(nn_model=NETWORK(**opt.network),
36 | **opt.diffusion,
37 | device=device,
38 | )
39 | diff.to(device)
40 | if local_rank == 0:
41 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
42 | ema.to(device)
43 |
44 | diff = torch.nn.SyncBatchNorm.convert_sync_batchnorm(diff)
45 | diff = torch.nn.parallel.DistributedDataParallel(
46 | diff, device_ids=[local_rank], output_device=local_rank)
47 |
48 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=opt.flip)
49 | print0("train dataset:", len(train_set))
50 |
51 | train_loader, sampler = DataLoaderDDP(train_set,
52 | batch_size=opt.batch_size,
53 | shuffle=True)
54 |
55 | lr = opt.lrate
56 | DDP_multiplier = dist.get_world_size()
57 | print0("Using DDP, lr = %f * %d" % (lr, DDP_multiplier))
58 | lr *= DDP_multiplier
59 | optim = get_optimizer(diff.parameters(), opt, lr=lr)
60 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
61 |
62 | if opt.load_epoch != -1:
63 | target = os.path.join(model_dir, f"model_{opt.load_epoch}.pth")
64 | print0("loading model at", target)
65 | checkpoint = torch.load(target, map_location=device)
66 | diff.load_state_dict(checkpoint['MODEL'])
67 | if local_rank == 0:
68 | ema.load_state_dict(checkpoint['EMA'])
69 | optim.load_state_dict(checkpoint['opt'])
70 |
71 | for ep in range(opt.load_epoch + 1, opt.n_epoch):
72 | for g in optim.param_groups:
73 | g['lr'] = lr * min((ep + 1.0) / opt.warm_epoch, 1.0) # warmup
74 | sampler.set_epoch(ep)
75 | dist.barrier()
76 | # training
77 | diff.train()
78 | if local_rank == 0:
79 | now_lr = optim.param_groups[0]['lr']
80 | print(f'epoch {ep}, lr {now_lr:f}')
81 | loss_ema = None
82 | pbar = tqdm(train_loader)
83 | else:
84 | pbar = train_loader
85 | for x, c in pbar:
86 | optim.zero_grad()
87 | x = x.to(device)
88 | loss = diff(x, use_amp=use_amp)
89 | scaler.scale(loss).backward()
90 | scaler.unscale_(optim)
91 | torch.nn.utils.clip_grad_norm_(parameters=diff.parameters(), max_norm=1.0)
92 | scaler.step(optim)
93 | scaler.update()
94 |
95 | # logging
96 | dist.barrier()
97 | loss = reduce_tensor(loss)
98 | if local_rank == 0:
99 | ema.update()
100 | if loss_ema is None:
101 | loss_ema = loss.item()
102 | else:
103 | loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
104 | pbar.set_description(f"loss: {loss_ema:.4f}")
105 |
106 | # testing
107 | if local_rank == 0:
108 | if ep % 100 == 0 or ep == opt.n_epoch - 1:
109 | pass
110 | else:
111 | continue
112 |
113 | if opt.model_type == 'DDPM':
114 | ema_sample_method = ema.ema_model.ddim_sample
115 | elif opt.model_type == 'EDM':
116 | ema_sample_method = ema.ema_model.edm_sample
117 |
118 | ema.ema_model.eval()
119 | with torch.no_grad():
120 | x_gen = ema_sample_method(opt.n_sample, x.shape[1:])
121 | # save an image of currently generated samples (top rows)
122 | # followed by real images (bottom rows)
123 | x_real = x[:opt.n_sample]
124 | x_all = torch.cat([x_gen.cpu(), x_real.cpu()])
125 | grid = make_grid(x_all, nrow=10)
126 |
127 | save_path = os.path.join(vis_dir, f"image_ep{ep}_ema.png")
128 | save_image(grid, save_path)
129 | print('saved image at', save_path)
130 |
131 | # optionally save model
132 | if opt.save_model:
133 | checkpoint = {
134 | 'MODEL': diff.state_dict(),
135 | 'EMA': ema.state_dict(),
136 | 'opt': optim.state_dict(),
137 | }
138 | save_path = os.path.join(model_dir, f"model_{ep}.pth")
139 | torch.save(checkpoint, save_path)
140 | print('saved model at', save_path)
141 |
142 |
143 | if __name__ == "__main__":
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--config", type=str)
146 | parser.add_argument('--local_rank', default=-1, type=int,
147 | help='node rank for distributed training')
148 | parser.add_argument("--use_amp", action='store_true', default=False)
149 | opt = parser.parse_args()
150 | print0(opt)
151 |
152 | init_seeds(no=opt.local_rank)
153 | dist.init_process_group(backend='nccl')
154 | torch.cuda.set_device(opt.local_rank)
155 | train(opt)
156 |
--------------------------------------------------------------------------------
/DiT/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
/model/wideresnet_noise_song.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/yang-song/score_sde/blob/main/models/wideresnet_noise_conditional.py
2 | # As a pytorch version of the noise-conditional classifier
3 | # proposed in https://arxiv.org/abs/2011.13456, Appendix I.1
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.nn.init as init
10 |
11 |
12 | def _weights_init(m):
13 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
14 | init.kaiming_normal_(m.weight)
15 |
16 |
17 | def activation(channels, apply_relu=True):
18 | gn = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-5)
19 | if apply_relu:
20 | return nn.Sequential(gn, nn.ReLU(inplace=True))
21 | return gn
22 |
23 |
24 | def _output_add(block_x, orig_x):
25 | """Add two tensors, padding them with zeros or pooling them if necessary.
26 |
27 | Args:
28 | block_x: Output of a resnet block.
29 | orig_x: Residual branch to add to the output of the resnet block.
30 |
31 | Returns:
32 |     The sum of block_x and orig_x. If necessary, orig_x will be average pooled
33 |     or zero padded so that its shape matches block_x.
34 | """
35 | stride = orig_x.shape[-2] // block_x.shape[-2]
36 | strides = (stride, stride)
37 | if block_x.shape[1] != orig_x.shape[1]:
38 | orig_x = F.avg_pool2d(orig_x, strides, strides)
39 | channels_to_add = block_x.shape[1] - orig_x.shape[1]
40 | orig_x = F.pad(orig_x, (0, 0, 0, 0, 0, channels_to_add))
41 | return block_x + orig_x
42 |
43 |
44 | class GaussianFourierProjection(nn.Module):
45 | """Gaussian Fourier embeddings for noise levels."""
46 |
47 | def __init__(self, embedding_size=256, scale=1.0):
48 | super().__init__()
49 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
50 |
51 | def forward(self, x):
52 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
53 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
54 |
55 |
56 | class WideResnetBlock(nn.Module):
57 | """Defines a single WideResnetBlock."""
58 |
59 | def __init__(self, in_planes, planes, time_channels, stride=1, activate_before_residual=False):
60 | super().__init__()
61 | self.activate_before_residual = activate_before_residual
62 |
63 | self.init_bn = activation(in_planes)
64 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
65 | self.bn_2 = activation(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
67 |
68 | # Linear layer for embeddings
69 | self.time_emb = nn.Sequential(
70 | nn.SiLU(),
71 | nn.Linear(time_channels, planes)
72 | )
73 |
74 | def forward(self, x, temb):
75 | if self.activate_before_residual:
76 | x = self.init_bn(x)
77 | orig_x = x
78 | else:
79 | orig_x = x
80 |
81 | block_x = x
82 | if not self.activate_before_residual:
83 | block_x = self.init_bn(block_x)
84 |
85 | block_x = self.conv1(block_x)
86 | block_x += self.time_emb(temb)[:, :, None, None]
87 |
88 | block_x = self.bn_2(block_x)
89 | block_x = self.conv2(block_x)
90 |
91 | return _output_add(block_x, orig_x)
92 |
93 |
94 | class WideResnetGroup(nn.Module):
95 | """Defines a WideResnetGroup."""
96 |
97 | def __init__(self, blocks_per_group, in_planes, planes, time_channels, stride=1, activate_before_residual=False):
98 | super().__init__()
99 | self.blocks_per_group = blocks_per_group
100 |
101 | self.blocks = nn.ModuleList()
102 | for i in range(self.blocks_per_group):
103 | if i == 0:
104 | blk = WideResnetBlock(in_planes, planes, time_channels, stride, activate_before_residual)
105 | else:
106 | blk = WideResnetBlock(planes, planes, time_channels, 1, False)
107 | self.blocks.append(blk)
108 |
109 | def forward(self, x, temb):
110 | for b in self.blocks:
111 | x = b(x, temb)
112 | return x
113 |
114 |
115 | class WideResnet(nn.Module):
116 | """Defines the WideResnet Model."""
117 |
118 | def __init__(self, blocks_per_group, channel_multiplier, in_channels=3, num_classes=10):
119 | super().__init__()
120 | time_channels = 128 * 4
121 | self.time_emb = GaussianFourierProjection(embedding_size=time_channels // 4, scale=16)
122 | self.time_emb_mlp = nn.Sequential(
123 | nn.Linear(time_channels // 2, time_channels),
124 | nn.SiLU(),
125 | nn.Linear(time_channels, time_channels),
126 | )
127 | self.init_conv = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
128 | self.group1 = WideResnetGroup(blocks_per_group,
129 | 16, 16 * channel_multiplier,
130 | time_channels,
131 | activate_before_residual=True)
132 | self.group2 = WideResnetGroup(blocks_per_group,
133 | 16 * channel_multiplier, 32 * channel_multiplier,
134 | time_channels,
135 | stride=2)
136 | self.group3 = WideResnetGroup(blocks_per_group,
137 | 32 * channel_multiplier, 64 * channel_multiplier,
138 | time_channels,
139 | stride=2)
140 | self.pre_pool_bn = activation(64 * channel_multiplier)
141 | self.final_linear = nn.Linear(64 * channel_multiplier, num_classes)
142 |
143 | self.apply(_weights_init)
144 |
145 | def forward(self, x, t):
146 | # per image standardization
147 | N = np.prod(x.shape[1:])
148 | x = (x - x.mean(dim=(1,2,3), keepdim=True)) / torch.maximum(torch.std(x, dim=(1,2,3), keepdim=True), 1. / torch.tensor(np.sqrt(N)))
149 |
150 | temb = self.time_emb(t)
151 | temb = self.time_emb_mlp(temb)
152 |
153 | x = self.init_conv(x)
154 | x = self.group1(x, temb)
155 | x = self.group2(x, temb)
156 | x = self.group3(x, temb)
157 | x = self.pre_pool_bn(x)
158 | x = F.avg_pool2d(x, x.shape[-1])
159 | x = x.view(x.shape[0], -1)
160 | x = self.final_linear(x)
161 | return x
162 |
163 |
164 | def test(net):
165 | import numpy as np
166 | total_params = 0
167 |
168 | for x in filter(lambda p: p.requires_grad, net.parameters()):
169 | total_params += np.prod(x.data.numpy().shape)
170 | print("Total number of params", total_params)
171 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
172 |
173 |
174 | def wide_28_10_song(in_channels=3, num_classes=10):
175 | net = WideResnet(blocks_per_group=4, channel_multiplier=10, in_channels=in_channels, num_classes=num_classes)
176 | test(net)
177 | return net
178 |
--------------------------------------------------------------------------------
/noisy_classifier_DDAE.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from functools import partial
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import yaml
8 | import torch.nn as nn
9 | from datasets import get_dataset
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 | from tqdm import tqdm
12 | from ema_pytorch import EMA
13 |
14 | from model.models import get_models_class
15 | from model.block import TimeEmbedding
16 | from utils import Config, init_seeds, reduce_tensor, gather_tensor, DataLoaderDDP, print0
17 |
18 |
19 | def get_model(opt, load_epoch):
20 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type)
21 | diff = DIFFUSION(nn_model=NETWORK(**opt.network),
22 | **opt.diffusion,
23 | device=device,
24 | )
25 | diff.to(device)
26 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth")
27 | print0("loading model at", target)
28 | checkpoint = torch.load(target, map_location=device)
29 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
30 | ema.to(device)
31 | ema.load_state_dict(checkpoint['EMA'])
32 | model = ema.ema_model
33 | model.eval()
34 | return model
35 |
36 | ''' Train a two-layer noise-conditional MLP classifier.
37 | This training script is similar to `linear.py`, which performs the linear probing evaluation.
38 | '''
39 |
40 | class Classifier(nn.Module):
41 | def __init__(self, feat_func, blockname, dim, num_classes):
42 | super(Classifier, self).__init__()
43 | self.feat_func = feat_func
44 | self.blockname = blockname
45 | self.time_emb = TimeEmbedding(dim, augment_dim=0)
46 | self.cls = nn.Sequential(
47 | nn.Linear(dim, 2 * dim),
48 | nn.SiLU(),
49 | nn.Linear(2 * dim, num_classes)
50 | )
51 |
52 | def forward(self, x, t):
53 | with torch.no_grad():
54 | x = self.feat_func(x.to(device), t=t)
55 | x = x[self.blockname].detach()
56 | return self.cls(x + self.time_emb(t, aug_label=None))
57 |
58 |
59 | class DDPM:
60 | def __init__(self, device, n_T=1000, steps=20):
61 | self.device = device
62 | self.n_T = n_T
63 | self.test_timesteps = (torch.arange(0, self.n_T, self.n_T // steps) + 1).long().tolist()
64 |
65 | def train(self, x):
66 | _t = torch.randint(1, self.n_T + 1, (x.shape[0], ))
67 | return x, _t.to(self.device)
68 |
69 | def test(self, x, t):
70 | _t = torch.full((x.shape[0], ), t)
71 | return x, _t.to(self.device)
72 |
73 |
74 | class EDM:
75 | def __init__(self, device, steps=18):
76 | self.device = device
77 | self.steps = steps
78 | self.test_timesteps = range(1, steps + 1)
79 |
80 | def train(self, x):
81 | _t = torch.randint(1, self.steps + 1, (x.shape[0], ))
82 | return x, _t.to(self.device)
83 |
84 | def test(self, x, t):
85 | _t = torch.full((x.shape[0], ), t)
86 | return x, _t.to(self.device)
87 |
88 |
89 | def train(opt):
90 | def test(t):
91 | preds = []
92 | labels = []
93 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)):
94 | with torch.no_grad():
95 | model.eval()
96 | logit = model(*diff.test(image, t))
97 | pred = logit.argmax(dim=-1)
98 | preds.append(pred)
99 | labels.append(label.to(device))
100 |
101 | pred = torch.cat(preds)
102 | label = torch.cat(labels)
103 | dist.barrier()
104 | pred = gather_tensor(pred)
105 | label = gather_tensor(label)
106 | acc = (pred == label).sum().item() / len(label)
107 | return acc
108 |
109 | yaml_path = opt.config
110 | ep = opt.epoch
111 | use_amp = opt.use_amp
112 | with open(yaml_path, 'r') as f:
113 | opt = yaml.full_load(f)
114 | print0(opt)
115 | opt = Config(opt)
116 | if ep == -1:
117 | ep = opt.n_epoch - 1
118 | model = get_model(opt, ep)
119 |
120 | epoch = opt.linear['n_epoch']
121 | batch_size = opt.linear['batch_size']
122 | base_lr = opt.linear['lrate']
123 | blockname = opt.linear['blockname']
124 |
125 | mode = opt.model_type
126 | if mode == 'DDPM':
127 | diff = DDPM(device)
128 | elif mode == 'EDM':
129 | diff = EDM(device)
130 | else:
131 | raise NotImplementedError
132 |
133 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=True, crop=True)
134 | valid_set = get_dataset(name=opt.dataset, root="./data", train=False)
135 | train_loader, sampler = DataLoaderDDP(
136 | train_set,
137 | batch_size=batch_size,
138 | shuffle=True,
139 | )
140 | valid_loader, _ = DataLoaderDDP(
141 | valid_set,
142 | batch_size=batch_size,
143 | shuffle=False,
144 | )
145 |
146 | # define a two-layer noise-conditional MLP classifier
147 | feat_func = partial(model.get_feature, norm=False, use_amp=use_amp)
148 | with torch.no_grad():
149 | x = feat_func(next(iter(valid_loader))[0].to(device), t=0)
150 | print0("All block names:", x.keys())
151 | print0("Using block:", blockname)
152 |
153 | dim = x[blockname].shape[-1]
154 | model = Classifier(feat_func, blockname, dim, opt.classes).to(device)
155 | model = torch.nn.parallel.DistributedDataParallel(
156 | model, device_ids=[local_rank], output_device=local_rank)
157 |
158 | # train classifier
159 | loss_fn = nn.CrossEntropyLoss()
160 | DDP_multiplier = dist.get_world_size()
161 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier))
162 | base_lr *= DDP_multiplier
163 | optim = torch.optim.Adam(model.parameters(), lr=base_lr)
164 | scheduler = CosineAnnealingLR(optim, epoch)
165 | for e in range(epoch):
166 | sampler.set_epoch(e)
167 | pbar = tqdm(train_loader, disable=(local_rank!=0))
168 | for i, (image, label) in enumerate(pbar):
169 | model.train()
170 | logit = model(*diff.train(image))
171 | label = label.to(device)
172 | loss = loss_fn(logit, label)
173 | optim.zero_grad()
174 | loss.backward()
175 | optim.step()
176 |
177 | # logging
178 | dist.barrier()
179 | loss = reduce_tensor(loss)
180 | logit = gather_tensor(logit).cpu()
181 | label = gather_tensor(label).cpu()
182 |
183 | if local_rank == 0:
184 | pred = logit.argmax(dim=-1)
185 | acc = (pred == label).sum().item() / len(label)
186 | nowlr = optim.param_groups[0]['lr']
187 | pbar.set_description("[epoch %d / iter %d]: lr %.1e loss: %.3f, acc: %.3f" % (e, i, nowlr, loss.item(), acc))
188 | scheduler.step()
189 |
190 | accs = {}
191 | for t in diff.test_timesteps:
192 | test_acc = test(t)
193 | print0("[timestep %d]: Test acc: %.3f" % (t, test_acc))
194 | accs[t] = test_acc
195 |
196 |
197 | if __name__ == "__main__":
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument("--config", type=str)
200 | parser.add_argument("--epoch", type=int, default=-1)
201 | parser.add_argument('--local_rank', default=-1, type=int,
202 | help='node rank for distributed training')
203 | parser.add_argument("--use_amp", action='store_true', default=False)
204 | opt = parser.parse_args()
205 | print0(opt)
206 |
207 | local_rank = opt.local_rank
208 | init_seeds(no=local_rank)
209 | dist.init_process_group(backend='nccl')
210 | torch.cuda.set_device(local_rank)
211 | device = "cuda:%d" % local_rank
212 |
213 | train(opt)
214 |
--------------------------------------------------------------------------------
/linear.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from functools import partial
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import yaml
8 | import torch.nn as nn
9 | from datasets import get_dataset
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 | from tqdm import tqdm
12 | from ema_pytorch import EMA
13 |
14 | from model.models import get_models_class
15 | from utils import Config, init_seeds, gather_tensor, DataLoaderDDP, print0
16 |
17 |
18 | def get_model(opt, load_epoch):
19 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type)
20 | diff = DIFFUSION(nn_model=NETWORK(**opt.network),
21 | **opt.diffusion,
22 | device=device,
23 | )
24 | diff.to(device)
25 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth")
26 | print0("loading model at", target)
27 | checkpoint = torch.load(target, map_location=device)
28 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
29 | ema.to(device)
30 | ema.load_state_dict(checkpoint['EMA'])
31 | model = ema.ema_model
32 | model.eval()
33 | return model
34 |
35 |
36 | class ClassifierDict(nn.Module):
37 | def __init__(self, feat_func, time_list, name_list, base_lr, epoch, img_shape, local_rank, num_classes):
38 | super(ClassifierDict, self).__init__()
39 | self.feat_func = feat_func
40 | self.times = time_list
41 | self.names = name_list
42 | self.classifiers = nn.ModuleDict()
43 | self.optims = {}
44 | self.schedulers = {}
45 | self.loss_fn = nn.CrossEntropyLoss()
46 |
47 | for time in self.times:
48 | feats = self.feat_func(torch.zeros(1, *img_shape).to(device), time)
49 | if self.names is None:
50 | self.names = list(feats.keys()) # all available names
51 |
52 | for name in self.names:
53 | key = self.make_key(time, name)
54 | layers = nn.Linear(feats[name].shape[1], num_classes)
55 | layers = torch.nn.parallel.DistributedDataParallel(
56 | layers.to(device), device_ids=[local_rank], output_device=local_rank)
57 | optimizer = torch.optim.Adam(layers.parameters(), lr=base_lr)
58 | scheduler = CosineAnnealingLR(optimizer, epoch)
59 | self.classifiers[key] = layers
60 | self.optims[key] = optimizer
61 | self.schedulers[key] = scheduler
62 |
63 | def train(self, x, y):
64 | self.classifiers.train()
65 | for time in self.times:
66 | feats = self.feat_func(x, time)
67 | for name in self.names:
68 | key = self.make_key(time, name)
69 | representation = feats[name].detach()
70 | logit = self.classifiers[key](representation)
71 | loss = self.loss_fn(logit, y)
72 |
73 | self.optims[key].zero_grad()
74 | loss.backward()
75 | self.optims[key].step()
76 |
77 | def test(self, x):
78 | outputs = {}
79 | with torch.no_grad():
80 | self.classifiers.eval()
81 | for time in self.times:
82 | feats = self.feat_func(x, time)
83 | for name in self.names:
84 | key = self.make_key(time, name)
85 | representation = feats[name].detach()
86 | logit = self.classifiers[key](representation)
87 | pred = logit.argmax(dim=-1)
88 | outputs[key] = pred
89 | return outputs
90 |
91 | def make_key(self, t, n):
92 | return str(t) + '/' + n
93 |
94 | def get_lr(self):
95 | key = self.make_key(self.times[0], self.names[0])
96 | optim = self.optims[key]
97 | return optim.param_groups[0]['lr']
98 |
99 | def schedule_step(self):
100 | for time in self.times:
101 | for name in self.names:
102 | key = self.make_key(time, name)
103 | self.schedulers[key].step()
104 |
105 |
106 | def train(opt):
107 | def test():
108 | preds = {k: [] for k in classifiers.optims.keys()}
109 | accs = {}
110 | labels = []
111 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)):
112 | outputs = classifiers.test(image.to(device))
113 | for key in outputs:
114 | preds[key].append(outputs[key])
115 | labels.append(label.to(device))
116 |
117 | for key in preds:
118 | preds[key] = torch.cat(preds[key])
119 | label = torch.cat(labels)
120 | dist.barrier()
121 | label = gather_tensor(label)
122 | for key in preds:
123 | pred = gather_tensor(preds[key])
124 | accs[key] = (pred == label).sum().item() / len(label)
125 | return accs
126 |
127 | yaml_path = opt.config
128 | ep = opt.epoch
129 | use_amp = opt.use_amp
130 | grid_search = opt.grid
131 | with open(yaml_path, 'r') as f:
132 | opt = yaml.full_load(f)
133 | print0(opt)
134 | opt = Config(opt)
135 | if ep == -1:
136 | ep = opt.n_epoch - 1
137 | model = get_model(opt, ep)
138 |
139 | epoch = opt.linear['n_epoch']
140 | batch_size = opt.linear['batch_size']
141 | base_lr = opt.linear['lrate']
142 |
143 | if grid_search:
144 | time_list = [1, 11, 21] if opt.model_type == 'DDPM' else [3, 4, 5]
145 | name_list = None
146 | else:
147 | time_list = [opt.linear['timestep']]
148 | name_list = [opt.linear['blockname']]
149 |
150 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=True, crop=True)
151 | valid_set = get_dataset(name=opt.dataset, root="./data", train=False)
152 | train_loader, sampler = DataLoaderDDP(
153 | train_set,
154 | batch_size=batch_size,
155 | shuffle=True,
156 | )
157 | valid_loader, _ = DataLoaderDDP(
158 | valid_set,
159 | batch_size=batch_size,
160 | shuffle=False,
161 | )
162 |
163 | feat_func = partial(model.get_feature, norm=False, use_amp=use_amp)
164 | DDP_multiplier = dist.get_world_size()
165 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier))
166 | base_lr *= DDP_multiplier
167 | classifiers = ClassifierDict(feat_func, time_list, name_list,
168 | base_lr, epoch, opt.network['image_shape'], local_rank, opt.classes).to(model.device)
169 |
170 | for e in range(epoch):
171 | sampler.set_epoch(e)
172 | pbar = tqdm(train_loader, disable=(local_rank!=0))
173 | for i, (image, label) in enumerate(pbar):
174 | pbar.set_description("[epoch %d / iter %d]: lr: %.1e" % (e, i, classifiers.get_lr()))
175 | classifiers.train(image.to(device), label.to(device))
176 | classifiers.schedule_step()
177 |
178 | accs = test()
179 | for key in accs:
180 | print0("[key %s]: Test acc: %.2f" % (key, accs[key] * 100))
181 |
182 |
183 | if __name__ == "__main__":
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument("--config", type=str)
186 | parser.add_argument("--epoch", type=int, default=-1)
187 | parser.add_argument('--local_rank', default=-1, type=int,
188 | help='node rank for distributed training')
189 | parser.add_argument("--use_amp", action='store_true', default=False)
190 | parser.add_argument("--grid", action='store_true', default=False)
191 | opt = parser.parse_args()
192 | print0(opt)
193 |
194 | local_rank = opt.local_rank
195 | init_seeds(no=local_rank)
196 | dist.init_process_group(backend='nccl')
197 | torch.cuda.set_device(local_rank)
198 | device = "cuda:%d" % local_rank
199 |
200 | train(opt)
201 |
--------------------------------------------------------------------------------
/DiT/linear.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 | from functools import partial
5 |
6 | import torch
7 | import torch.distributed as dist
8 | from torch.cuda.amp import autocast as autocast
9 | import torch.nn as nn
10 | from torch.utils.data import Dataset
11 | from torch.optim.lr_scheduler import CosineAnnealingLR
12 | from tqdm import tqdm
13 | from torch.cuda.amp import autocast as autocast
14 |
15 | from diffusion import create_diffusion
16 | from download import find_model
17 | from models import DiT_XL_2
18 | import sys
19 | sys.path.append("..")
20 | from utils import init_seeds, gather_tensor, DataLoaderDDP, print0
21 |
22 |
23 | class LatentCodeDataset(Dataset):
24 | # warning: needs A LOT OF memory to load these datasets !
25 | def __init__(self, dataset, train=True, num_copies=10):
26 | if train:
27 | code_path = [f"latent_codes/{dataset}/train_code_{i}.npy" for i in range(num_copies)]
28 | label_path = f"latent_codes/{dataset}/train_label.npy"
29 | else:
30 | code_path = [f"latent_codes/{dataset}/test_code_0.npy"]
31 | label_path = f"latent_codes/{dataset}/test_label.npy"
32 |
33 | self.code = []
34 | for p in code_path:
35 | with open(p, 'rb') as f:
36 | data = np.load(f)
37 | self.code.append(data)
38 | with open(label_path, 'rb') as f:
39 | self.label = np.load(f)
40 |
41 | print0(f"Code shape: {len(self.code)} x {self.code[0].shape}")
42 | print0("Label shape:", self.label.shape)
43 |
44 | def __getitem__(self, index):
45 | replica = random.randrange(len(self.code))
46 | code = self.code[replica][index]
47 | label = self.label[index]
48 | return code, label
49 |
50 | def __len__(self):
51 | return len(self.code[0])
52 |
53 |
54 | def get_model(device):
55 | model = DiT_XL_2().to(device)
56 | state_dict = find_model(f"DiT-XL-2-256x256.pt")
57 | model.load_state_dict(state_dict)
58 | model.eval()
59 | diffusion = create_diffusion(None) # 1000-len betas
60 | return model, diffusion
61 |
62 |
63 | def denoise_feature(code, model, timestep, blockname, use_amp):
64 | '''
65 | Args:
66 |         `code`: Latent codes. (-1, 4, 32, 32) tensor.
67 | `timestep`: Time step to extract features. int.
68 | `blockname`: Block to extract features. str.
69 | Returns:
70 | Collected feature map.
71 | '''
72 | x = code.to(device)
73 | t = torch.tensor([timestep]).to(device).repeat(x.shape[0])
74 | noise = torch.randn_like(x)
75 | x_t = diffusion.q_sample(x, t, noise=noise)
76 | y_null = torch.tensor([1000] * x.shape[0], device=device)
77 |
78 | with torch.no_grad():
79 | with autocast(enabled=use_amp):
80 | _, acts = model(x_t, t, y_null, ret_activation=True)
81 | feat = acts[blockname].float().detach()
82 | # (-1, 256, 1152)
83 | # we average pool across the sequence dimension to extract
84 | # a 1152-dimensional vector of features per example
85 | return feat.mean(dim=1)
86 |
87 |
88 | class Classifier(nn.Module):
89 | def __init__(self, feat_func, base_lr, epoch, num_classes):
90 | super(Classifier, self).__init__()
91 | self.feat_func = feat_func
92 | self.loss_fn = nn.CrossEntropyLoss()
93 |
94 | hidden_size = feat_func(next(iter(valid_loader))[0]).shape[-1]
95 | layers = nn.Sequential(
96 | # nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6),
97 | nn.Linear(hidden_size, num_classes),
98 | )
99 | layers = torch.nn.parallel.DistributedDataParallel(
100 | layers.to(device), device_ids=[local_rank], output_device=local_rank)
101 | self.classifier = layers
102 | self.optim = torch.optim.Adam(self.classifier.parameters(), lr=base_lr)
103 | self.scheduler = CosineAnnealingLR(self.optim, epoch)
104 |
105 | def train(self, x, y):
106 | self.classifier.train()
107 | feat = self.feat_func(x)
108 | logit = self.classifier(feat)
109 | loss = self.loss_fn(logit, y)
110 |
111 | self.optim.zero_grad()
112 | loss.backward()
113 | self.optim.step()
114 |
115 | def test(self, x):
116 | with torch.no_grad():
117 | self.classifier.eval()
118 | feat = self.feat_func(x)
119 | logit = self.classifier(feat)
120 | pred = logit.argmax(dim=-1)
121 | return pred
122 |
123 | def get_lr(self):
124 | return self.optim.param_groups[0]['lr']
125 |
126 | def schedule_step(self):
127 | self.scheduler.step()
128 |
129 |
130 | def train(model, timestep, blockname, epoch, base_lr, use_amp):
131 | def test():
132 | preds = []
133 | labels = []
134 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)):
135 | pred = classifier.test(image.to(device))
136 | preds.append(pred)
137 | labels.append(label.to(device))
138 |
139 | pred = torch.cat(preds)
140 | label = torch.cat(labels)
141 | dist.barrier()
142 | pred = gather_tensor(pred)
143 | label = gather_tensor(label)
144 | acc = (pred == label).sum().item() / len(label)
145 | return acc
146 |
147 | print0(f"Feature extraction: time = {timestep}, name = {blockname}")
148 | feat_func = partial(denoise_feature, model=model, timestep=timestep, blockname=blockname, use_amp=use_amp)
149 | DDP_multiplier = dist.get_world_size()
150 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier))
151 | base_lr *= DDP_multiplier
152 | num_classes = 10 if opt.dataset == 'cifar' else 200
153 |
154 | classifier = Classifier(feat_func, base_lr, epoch, num_classes).to(device)
155 |
156 | for e in range(epoch):
157 | sampler.set_epoch(e)
158 | pbar = tqdm(train_loader, disable=(local_rank!=0))
159 | for i, (image, label) in enumerate(pbar):
160 | pbar.set_description("[epoch %d / iter %d]: lr: %.1e" % (e, i, classifier.get_lr()))
161 | classifier.train(image.to(device), label.to(device))
162 | classifier.schedule_step()
163 |
164 | acc = test()
165 | print0("Test acc: %.2f" % (acc * 100))
166 |
167 |
168 | def get_default_time(dataset, t):
169 | if t > 0:
170 | return t
171 | else:
172 | return {'cifar': 121, 'tiny': 81}[dataset]
173 |
174 |
175 | def get_default_name(dataset, b):
176 | if b != 'layer-0':
177 | return b
178 | else:
179 | return {'cifar': 'layer-13', 'tiny': 'layer-13'}[dataset]
180 |
181 |
182 | if __name__ == "__main__":
183 | parser = argparse.ArgumentParser()
184 | parser.add_argument("--dataset", default='cifar', type=str, choices=['cifar', 'tiny'])
185 | parser.add_argument('--local_rank', default=-1, type=int,
186 | help='node rank for distributed training')
187 | parser.add_argument("--use_amp", action='store_true', default=False)
188 | parser.add_argument('--batch_size', default=128, type=int)
189 | parser.add_argument('--lr', default=1e-3, type=float)
190 | parser.add_argument('--epoch', default=30, type=int)
191 | parser.add_argument('--time', type=int, default=0)
192 | parser.add_argument('--name', type=str, default='layer-0')
193 | opt = parser.parse_args()
194 |
195 | local_rank = opt.local_rank
196 | init_seeds(no=local_rank)
197 | dist.init_process_group(backend='nccl')
198 | torch.cuda.set_device(local_rank)
199 | device = "cuda:%d" % local_rank
200 | model, diffusion = get_model(device)
201 |
202 | train_set = LatentCodeDataset(opt.dataset, train=True)
203 | valid_set = LatentCodeDataset(opt.dataset, train=False)
204 | train_loader, sampler = DataLoaderDDP(
205 | train_set,
206 | batch_size=opt.batch_size,
207 | shuffle=True,
208 | )
209 | valid_loader, _ = DataLoaderDDP(
210 | valid_set,
211 | batch_size=opt.batch_size,
212 | shuffle=False,
213 | )
214 |
215 | # default timestep & blockname values
216 | opt.time = get_default_time(opt.dataset, opt.time)
217 | opt.name = get_default_name(opt.dataset, opt.name)
218 |
219 | print0(opt)
220 | train(model, timestep=opt.time, blockname=opt.name, epoch=opt.epoch, base_lr=opt.lr, use_amp=opt.use_amp)
221 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 🆕 **[2025] Please check out the more recent study [DDAE++](https://github.com/FutureXiang/ddae_plus_plus), which continues this line of work.**
2 |
3 |
4 | # Denoising Diffusion Autoencoders (DDAE)
5 |
6 |
7 |
8 |
9 |
10 | This is a multi-GPU PyTorch implementation of the paper [Denoising Diffusion Autoencoders are Unified Self-supervised Learners](https://arxiv.org/abs/2303.09769):
11 | ```bibtex
12 | @inproceedings{ddae2023,
13 | title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
14 | author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
15 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
16 | year={2023}
17 | }
18 | ```
19 | :star: (News) Our paper is cited by Kaiming He's new paper [Deconstructing Denoising Diffusion Models for Self-Supervised Learning](https://arxiv.org/abs/2401.14404), check it out! :fire:
20 |
21 | ## Overview
22 |
23 | This repo contains:
24 | - [x] Pre-training, sampling and FID evaluation code for diffusion models, including
25 | - Frameworks:
26 | - [x] DDPM & DDIM
27 | - [x] EDM (w/ or w/o data augmentation)
28 | - Networks:
29 | - [x] The basic 35.7M DDPM UNet
30 | - [x] A larger 56M DDPM++ UNet
31 | - Datasets:
32 | - [x] CIFAR-10
33 | - [ ] Tiny-ImageNet
34 | - [x] Feature quality evaluation code, including
35 | - [x] Linear probing and grid searching
36 | - [x] Contrastive metrics, i.e., alignment and uniformity
37 | - [ ] Fine-tuning
38 | - [x] Noise-conditional classifier training and evaluation, including
39 | - [x] MLP classifier based on DDPM/EDM features
40 | - [x] WideResNet with VP/VE perturbation
41 | - [x] Evaluation code for ImageNet-256 pre-trained [DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint
42 |
43 | ## Requirements
44 | - In addition to PyTorch environments, please install:
45 | ```sh
46 | conda install pyyaml
47 | pip install pytorch-fid ema-pytorch
48 | ```
49 | - We use 4 or 8 3080ti GPUs to conduct all the experiments presented in the paper. With automatic mixed precision enabled and 4 GPUs, training a basic 35.7M UNet on CIFAR-10 takes ~14 hours.
50 | - The `pytorch-fid` package requires image files to calculate the FID metric. Please refer to `extract_cifar10_pngs.ipynb` to unpack the CIFAR-10 training set into 50000 `.png` image files; a minimal sketch is shown below.
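A minimal sketch of that extraction (not the notebook itself; it assumes `torchvision` is available and writes to `./data/cifar10-pngs/`, the folder passed to the `pytorch-fid` command in the Usage section):
```python
import os
from torchvision.datasets import CIFAR10

# Assumed output folder; matches the path passed to pytorch-fid below.
out_dir = "./data/cifar10-pngs"
os.makedirs(out_dir, exist_ok=True)

# Without a transform, CIFAR10 yields PIL images that can be saved directly.
train_set = CIFAR10(root="./data", train=True, download=True)
for i, (img, _) in enumerate(train_set):
    img.save(os.path.join(out_dir, f"{i:05d}.png"))
```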
51 |
52 | ## Main results
53 | We present the generative and discriminative evaluation results that can be obtained by this codebase. The `EDM_ddpmpp_aug.yaml` training is performed on 8 GPUs, while other models are trained on 4 GPUs.
54 |
55 | Please note that this is an *over-simplified* DDPM / EDM implementation, and some network details, initialization, and hyper-parameters may *differ from* the official ones. Please refer to the respective official codebases to reproduce the *exact results* reported in the paper.
56 |
57 | | Config | Model | Network | Best linear probe checkpoint (epoch / FID / acc) | Best FID checkpoint (epoch / FID / acc) |
58 | |---|---|---|---|---|
59 | | DDPM_ddpm.yaml | DDPM | 35.7M UNet | 800 / 4.09 / 90.05 | 1999 / 3.62 / 88.23 |
60 | | EDM_ddpm.yaml | EDM | 35.7M UNet | 1200 / 3.97 / 90.44 | 1999 / 3.56 / 89.71 |
61 | | DDPM_ddpmpp.yaml | DDPM | 56.5M DDPM++ | 1200 / 3.08 / 93.97 | 1999 / 2.98 / 93.03 |
62 | | EDM_ddpmpp.yaml | EDM | 56.5M DDPM++ | 1200 / 2.23 / 94.50 | (same) |
63 | | EDM_ddpmpp_aug.yaml | EDM + data aug | 56.5M DDPM++ | 2000 / 2.34 / 95.49 | 3200 / 2.12 / 95.19 |
64 | 
133 | FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).
134 |
135 | ## Latent-space DiT
136 | We evaluate pre-trained Transformer-based diffusion networks, [DiT](https://github.com/facebookresearch/DiT), from the perspective of *transfer learning*. Please refer to the [ddae/DiT](DiT/) subfolder.
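For reference, a DiT linear-probe run might be launched as follows (a sketch based on the arguments of `DiT/linear.py`; it assumes the latent codes under `DiT/latent_codes/` have already been prepared):
```sh
cd DiT
python -m torch.distributed.launch --nproc_per_node=4
# linear probing on CIFAR-10 latent codes, with the default timestep / block
linear.py --dataset cifar --use_amp
```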
137 |
138 | ## Usage
139 | ### Diffusion pre-training
140 | To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:
141 | ```sh
142 | python -m torch.distributed.launch --nproc_per_node=4
143 | # diffusion pre-training with AMP enabled
144 | train.py --config config/DDPM_ddpm.yaml --use_amp
145 |
146 | # deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
147 | sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
148 |
149 | # stochastic sampling (i.e. DDPM 1000 steps)
150 | sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
151 | ```
152 | To calculate the FID metric on the training set, for example, run:
153 | ```sh
154 | python -m pytorch_fid data/cifar10-pngs/ output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
155 | ```
156 |
157 | ### Features produced by DDAE
158 | To evaluate the features produced by pre-trained DDAE, for example, run:
159 | ```sh
160 | python -m torch.distributed.launch --nproc_per_node=4
161 | # grid searching for proper layer-noise combination
162 | linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid
163 |
164 | # linear probing, using the layer-noise combination specified by config.yaml
165 | linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
166 |
167 | # showing the alignment-uniformity metrics with respect to different checkpoints
168 | contrastive.py --config config/DDPM_ddpm.yaml --use_amp
169 | ```
170 |
171 | ### Noise-conditional classifier
172 | To train WideResNet-based classifiers from scratch:
173 | ```sh
174 | python -m torch.distributed.launch --nproc_per_node=4
175 | # VP (DDPM) perturbation
176 | noisy_classifier_WRN.py --mode DDPM
177 | # VE (EDM) perturbation
178 | noisy_classifier_WRN.py --mode EDM
179 | ```
180 | and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads:
181 | ```sh
182 | python -m torch.distributed.launch --nproc_per_node=4
183 | # using DDPM DDAE encoder
184 | noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999
185 | # using EDM DDAE encoder
186 | noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
187 | ```
188 |
189 | ## Acknowledgments
190 | This repository is built on numerous open-source codebases such as [DDPM](https://github.com/hojonathanho/diffusion), [DDPM-pytorch](https://github.com/pesser/pytorch_diffusion), [DDIM](https://github.com/ermongroup/ddim), [EDM](https://github.com/NVlabs/edm), [Score-based SDE](https://github.com/yang-song/score_sde), [DiT](https://github.com/facebookresearch/DiT), and [align_uniform](https://github.com/SsnL/align_uniform).
191 |
--------------------------------------------------------------------------------
/noisy_classifier_WRN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import torch.nn as nn
7 | from datasets import get_dataset
8 | from torch.optim.lr_scheduler import CosineAnnealingLR
9 | from tqdm import tqdm
10 |
11 | from model.wideresnet_noise_song import wide_28_10_song
12 | from utils import init_seeds, reduce_tensor, gather_tensor, DataLoaderDDP, print0
13 |
14 |
15 | def normalize_to_neg_one_to_one(img):
16 | # [0.0, 1.0] -> [-1.0, 1.0]
17 | return img * 2 - 1
18 |
19 |
20 | class DDPM:
21 | def __init__(self, device, betas=[1.0e-4, 0.02], n_T=1000, steps=20):
22 | self.device = device
23 | self.n_T = n_T
24 | self.ddpm_sche = self.schedules(betas, n_T, device, 'DDPM')
25 | self.test_timesteps = (torch.arange(0, self.n_T, self.n_T // steps) + 1).long().tolist()
26 |
27 | def train(self, x):
28 | x = normalize_to_neg_one_to_one(x).to(self.device)
29 | # Perturbation
30 | _ts = torch.randint(1, self.n_T + 1, (x.shape[0], )).to(self.device)
31 | noise = torch.randn_like(x)
32 | sche = self.ddpm_sche
33 | x_noised = (sche["sqrtab"][_ts, None, None, None] * x +
34 | sche["sqrtmab"][_ts, None, None, None] * noise)
35 | return x_noised, _ts / self.n_T
36 |
37 | def test(self, x, t):
38 | x = normalize_to_neg_one_to_one(x).to(self.device)
39 | # Perturbation
40 | _ts = torch.tensor([t]).to(self.device).repeat(x.shape[0])
41 | noise = torch.randn_like(x)
42 | sche = self.ddpm_sche
43 | x_noised = (sche["sqrtab"][_ts, None, None, None] * x +
44 | sche["sqrtmab"][_ts, None, None, None] * noise)
45 | return x_noised, _ts / self.n_T
46 |
47 | def schedules(self, betas, T, device, type='DDPM'):
48 | def linear_beta_schedule(timesteps, beta1, beta2):
49 | assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
50 | return torch.linspace(beta1, beta2, timesteps)
51 |
52 | beta1, beta2 = betas
53 | schedule_fn = partial(linear_beta_schedule, beta1=beta1, beta2=beta2)
54 |
55 | if type == 'DDPM':
56 | beta_t = torch.cat([torch.tensor([0.0]), schedule_fn(T)])
57 | elif type == 'DDIM':
58 | beta_t = schedule_fn(T + 1)
59 | else:
60 | raise NotImplementedError()
61 | sqrt_beta_t = torch.sqrt(beta_t)
62 | alpha_t = 1 - beta_t
63 | log_alpha_t = torch.log(alpha_t)
64 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
65 |
66 | sqrtab = torch.sqrt(alphabar_t)
67 | oneover_sqrta = 1 / torch.sqrt(alpha_t)
68 |
69 | sqrtmab = torch.sqrt(1 - alphabar_t)
70 | ma_over_sqrtmab = (1 - alpha_t) / sqrtmab
71 |
72 | dic = {
73 | "alpha_t": alpha_t,
74 | "oneover_sqrta": oneover_sqrta,
75 | "sqrt_beta_t": sqrt_beta_t,
76 | "alphabar_t": alphabar_t,
77 | "sqrtab": sqrtab,
78 | "sqrtmab": sqrtmab,
79 | "ma_over_sqrtmab": ma_over_sqrtmab,
80 | }
81 | return {key: dic[key].to(device) for key in dic}
82 |
83 |
84 | class EDM:
85 | def __init__(self, device, p_std=1.2, p_mean=-1.2, sigma_min=0.002, sigma_max=80, rho=7, steps=18):
86 | self.device = device
87 | self.p_std = p_std
88 | self.p_mean = p_mean
89 | self.times = self.schedules(sigma_min, sigma_max, rho, steps)
90 | self.test_timesteps = range(1, steps + 1)
91 |
92 | def train(self, x):
93 | x = normalize_to_neg_one_to_one(x).to(self.device)
94 | # Perturbation
95 | rnd_normal = torch.randn((x.shape[0], 1, 1, 1)).to(self.device)
96 | sigma = (rnd_normal * self.p_std + self.p_mean).exp()
97 | noise = torch.randn_like(x)
98 | x_noised = x + noise * sigma
99 |
100 | sigma = sigma.reshape(x.shape[0],)
101 | return x_noised, sigma.log()
102 |
103 | def test(self, x, t):
104 | x = normalize_to_neg_one_to_one(x).to(self.device)
105 | # Perturbation
106 | noise = torch.randn_like(x)
107 | sigma = self.times[t]
108 | x_noised = x + noise * sigma
109 |
110 | sigma = torch.full((x.shape[0], ), sigma)
111 | return x_noised, sigma.log()
112 |
113 | def schedules(self, sigma_min, sigma_max, rho, steps):
114 | times = torch.arange(steps, dtype=torch.float64, device=self.device)
115 | times = (sigma_max ** (1 / rho) + times / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
116 | times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0
117 | times = reversed(times)
118 | return times
119 |
120 |
121 | def train(opt):
122 | def test(t):
123 | preds = []
124 | labels = []
125 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)):
126 | with torch.no_grad():
127 | model.eval()
128 | logit = model(*diff.test(image, t))
129 | pred = logit.argmax(dim=-1)
130 | preds.append(pred)
131 | labels.append(label.to(device))
132 |
133 | pred = torch.cat(preds)
134 | label = torch.cat(labels)
135 | dist.barrier()
136 | pred = gather_tensor(pred)
137 | label = gather_tensor(label)
138 | acc = (pred == label).sum().item() / len(label)
139 | return acc
140 |
141 | warm_epoch = opt.warm_epoch
142 | epoch = opt.epoch
143 | batch_size = opt.batch_size
144 | base_lr = opt.lr
145 | mode = opt.mode
146 |
147 | if mode == 'DDPM':
148 | diff = DDPM(device)
149 | elif mode == 'EDM':
150 | diff = EDM(device)
151 | else:
152 | raise NotImplementedError
153 |
154 | train_set = get_dataset(name='cifar', root="./data", train=True, flip=True, crop=True)
155 | valid_set = get_dataset(name='cifar', root="./data", train=False)
156 | train_loader, sampler = DataLoaderDDP(
157 | train_set,
158 | batch_size=batch_size,
159 | shuffle=True,
160 | )
161 | valid_loader, _ = DataLoaderDDP(
162 | valid_set,
163 | batch_size=batch_size,
164 | shuffle=False,
165 | )
166 |
167 | model = wide_28_10_song(num_classes=10).to(device)
168 | model = torch.nn.parallel.DistributedDataParallel(
169 | model, device_ids=[local_rank], output_device=local_rank)
170 | loss_fn = nn.CrossEntropyLoss()
171 | optim = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=5e-4)
172 | scheduler = CosineAnnealingLR(optim, epoch)
173 | for e in range(epoch):
174 | sampler.set_epoch(e)
175 | if (e + 1) <= warm_epoch:
176 | for g in optim.param_groups:
177 | g['lr'] = base_lr * (e + 1.0) / warm_epoch # warmup
178 |
179 | pbar = tqdm(train_loader, disable=(local_rank!=0))
180 | for i, (image, label) in enumerate(pbar):
181 | model.train()
182 | logit = model(*diff.train(image))
183 | label = label.to(device)
184 | loss = loss_fn(logit, label)
185 | optim.zero_grad()
186 | loss.backward()
187 | optim.step()
188 |
189 | # logging
190 | dist.barrier()
191 | loss = reduce_tensor(loss)
192 | logit = gather_tensor(logit).cpu()
193 | label = gather_tensor(label).cpu()
194 |
195 | if local_rank == 0:
196 | pred = logit.argmax(dim=-1)
197 | acc = (pred == label).sum().item() / len(label)
198 | nowlr = optim.param_groups[0]['lr']
199 | pbar.set_description("[epoch %d / iter %d]: lr %.1e loss: %.3f, acc: %.3f" % (e, i, nowlr, loss.item(), acc))
200 | scheduler.step()
201 |
202 | accs = {}
203 | for t in diff.test_timesteps:
204 | test_acc = test(t)
205 | print0("[timestep %d]: Test acc: %.3f" % (t, test_acc))
206 | accs[t] = test_acc
207 |
208 |
209 | if __name__ == "__main__":
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument('--local_rank', default=-1, type=int,
212 | help='node rank for distributed training')
213 | parser.add_argument('--batch_size', default=128, type=int)
214 | parser.add_argument('--lr', default=0.1, type=float)
215 | parser.add_argument('--epoch', default=200, type=int)
216 | parser.add_argument('--warm_epoch', default=5, type=int)
217 | parser.add_argument("--mode", type=str, choices=['DDPM', 'EDM'], default='DDPM')
218 | opt = parser.parse_args()
219 | print0(opt)
220 |
221 | local_rank = opt.local_rank
222 | init_seeds(no=local_rank)
223 | dist.init_process_group(backend='nccl')
224 | torch.cuda.set_device(local_rank)
225 | device = "cuda:%d" % local_rank
226 |
227 | train(opt)
228 |
--------------------------------------------------------------------------------
/model/DDPM.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch
3 | import torch.nn as nn
4 | from tqdm import tqdm
5 | from torch.cuda.amp import autocast as autocast
6 |
7 |
8 | def normalize_to_neg_one_to_one(img):
9 | # [0.0, 1.0] -> [-1.0, 1.0]
10 | return img * 2 - 1
11 |
12 |
13 | def unnormalize_to_zero_to_one(t):
14 | # [-1.0, 1.0] -> [0.0, 1.0]
15 | return (t + 1) * 0.5
16 |
17 |
18 | def linear_beta_schedule(timesteps, beta1, beta2):
19 | assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
20 | return torch.linspace(beta1, beta2, timesteps)
21 |
22 |
23 | def schedules(betas, T, device, type='DDPM'):
24 | beta1, beta2 = betas
25 | schedule_fn = partial(linear_beta_schedule, beta1=beta1, beta2=beta2)
26 |
27 | if type == 'DDPM':
28 | beta_t = torch.cat([torch.tensor([0.0]), schedule_fn(T)])
29 | elif type == 'DDIM':
30 | beta_t = schedule_fn(T + 1)
31 | else:
32 | raise NotImplementedError()
33 | sqrt_beta_t = torch.sqrt(beta_t)
34 | alpha_t = 1 - beta_t
35 | log_alpha_t = torch.log(alpha_t)
36 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
37 |
38 | sqrtab = torch.sqrt(alphabar_t)
39 | oneover_sqrta = 1 / torch.sqrt(alpha_t)
40 |
41 | sqrtmab = torch.sqrt(1 - alphabar_t)
42 | ma_over_sqrtmab = (1 - alpha_t) / sqrtmab
43 |
44 | dic = {
45 | "alpha_t": alpha_t,
46 | "oneover_sqrta": oneover_sqrta,
47 | "sqrt_beta_t": sqrt_beta_t,
48 | "alphabar_t": alphabar_t,
49 | "sqrtab": sqrtab,
50 | "sqrtmab": sqrtmab,
51 | "ma_over_sqrtmab": ma_over_sqrtmab,
52 | }
53 | return {key: dic[key].to(device) for key in dic}
54 |
55 |
56 | class DDPM(nn.Module):
57 | def __init__(self, nn_model, betas, n_T, device):
58 | ''' DDPM proposed by "Denoising Diffusion Probabilistic Models", and \
59 | DDIM sampler proposed by "Denoising Diffusion Implicit Models".
60 |
61 | Args:
62 | nn_model: A network (e.g. UNet) which performs same-shape mapping.
63 | device: The CUDA device that tensors run on.
64 | Parameters:
65 | betas, n_T
66 | '''
67 | super(DDPM, self).__init__()
68 | self.nn_model = nn_model.to(device)
69 | params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6
70 | print(f"nn model # params: {params:.1f}")
71 |
72 | self.device = device
73 | self.ddpm_sche = schedules(betas, n_T, device, 'DDPM')
74 | self.ddim_sche = schedules(betas, n_T, device, 'DDIM')
75 | self.n_T = n_T
76 | self.loss = nn.MSELoss()
77 |
78 | def perturb(self, x, t=None):
79 | ''' Add noise to a clean image (diffusion process).
80 |
81 | Args:
82 | x: The normalized image tensor.
83 | t: The specified timestep ranged in `[1, n_T]`. Type: int / torch.LongTensor / None. \
84 | Random `t ~ U[1, n_T]` is taken if t is None.
85 | Returns:
86 | The perturbed image, the corresponding timestep, and the noise.
87 | '''
88 | if t is None:
89 | t = torch.randint(1, self.n_T + 1, (x.shape[0], )).to(self.device)
90 | elif not isinstance(t, torch.Tensor):
91 | t = torch.tensor([t]).to(self.device).repeat(x.shape[0])
92 |
93 | noise = torch.randn_like(x)
94 | sche = self.ddpm_sche
95 | x_noised = (sche["sqrtab"][t, None, None, None] * x +
96 | sche["sqrtmab"][t, None, None, None] * noise)
97 | return x_noised, t, noise
98 |
99 | def forward(self, x, use_amp=False):
100 | ''' Training with simple noise prediction loss.
101 |
102 | Args:
103 | x: The clean image tensor ranged in `[0, 1]`.
104 | Returns:
105 | The simple MSE loss.
106 | '''
107 | x = normalize_to_neg_one_to_one(x)
108 | x_noised, t, noise = self.perturb(x, t=None)
109 |
110 | with autocast(enabled=use_amp):
111 | return self.loss(noise, self.nn_model(x_noised, t / self.n_T))
112 |
113 | def get_feature(self, x, t, name=None, norm=False, use_amp=False):
114 | ''' Get network's intermediate activation in a forward pass.
115 |
116 | Args:
117 | x: The clean image tensor ranged in `[0, 1]`.
118 | t: The specified timestep ranged in `[1, n_T]`. Type: int / torch.LongTensor.
119 |             norm: whether to normalize features to the unit hypersphere.
120 | Returns:
121 | A {name: tensor} dict which contains global average pooled features.
122 | '''
123 | x = normalize_to_neg_one_to_one(x)
124 | x_noised, t, noise = self.perturb(x, t)
125 |
126 | def gap_and_norm(act, norm=False):
127 | if len(act.shape) == 4:
128 | # unet (B, C, H, W)
129 | act = act.view(act.shape[0], act.shape[1], -1).float()
130 | act = torch.mean(act, dim=2)
131 | else:
132 | raise NotImplementedError
133 | if norm:
134 | act = torch.nn.functional.normalize(act)
135 | return act
136 |
137 | with autocast(enabled=use_amp):
138 | _, acts = self.nn_model(x_noised, t / self.n_T, ret_activation=True)
139 | all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts}
140 | if name is not None:
141 | return all_feats[name]
142 | else:
143 | return all_feats
144 |
145 | def sample(self, n_sample, size, notqdm=False, use_amp=False):
146 | ''' Sampling with DDPM sampler. Actual NFE is `n_T`.
147 |
148 | Args:
149 | n_sample: The batch size.
150 | size: The image shape (e.g. `(3, 32, 32)`).
151 | Returns:
152 | The sampled image tensor ranged in `[0, 1]`.
153 | '''
154 | sche = self.ddpm_sche
155 | x_i = torch.randn(n_sample, *size).to(self.device)
156 |
157 | for i in tqdm(range(self.n_T, 0, -1), disable=notqdm):
158 | t_is = torch.tensor([i / self.n_T]).to(self.device).repeat(n_sample)
159 |
160 | z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0
161 |
162 | alpha = sche["alphabar_t"][i]
163 | eps, _ = self.pred_eps_(x_i, t_is, alpha, use_amp)
164 |
165 | mean = sche["oneover_sqrta"][i] * (x_i - sche["ma_over_sqrtmab"][i] * eps)
166 | variance = sche["sqrt_beta_t"][i] # LET variance sigma_t = sqrt_beta_t
167 | x_i = mean + variance * z
168 |
169 | return unnormalize_to_zero_to_one(x_i)
170 |
171 | def ddim_sample(self, n_sample, size, steps=100, eta=0.0, notqdm=False, use_amp=False):
172 | ''' Sampling with DDIM sampler. Actual NFE is `steps`.
173 |
174 | Args:
175 | n_sample: The batch size.
176 | size: The image shape (e.g. `(3, 32, 32)`).
177 | steps: The number of total timesteps.
178 | eta: controls stochasticity. Set `eta=0` for deterministic sampling.
179 | Returns:
180 | The sampled image tensor ranged in `[0, 1]`.
181 | '''
182 | sche = self.ddim_sche
183 | x_i = torch.randn(n_sample, *size).to(self.device)
184 |
185 | times = torch.arange(0, self.n_T, self.n_T // steps) + 1
186 | times = list(reversed(times.int().tolist())) + [0]
187 | time_pairs = list(zip(times[:-1], times[1:]))
188 | # e.g. [(801, 601), (601, 401), (401, 201), (201, 1), (1, 0)]
189 |
190 | for time, time_next in tqdm(time_pairs, disable=notqdm):
191 | t_is = torch.tensor([time / self.n_T]).to(self.device).repeat(n_sample)
192 |
193 | z = torch.randn(n_sample, *size).to(self.device) if time_next > 0 else 0
194 |
195 | alpha = sche["alphabar_t"][time]
196 | eps, x0_t = self.pred_eps_(x_i, t_is, alpha, use_amp)
197 | alpha_next = sche["alphabar_t"][time_next]
198 | c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
199 | c2 = (1 - alpha_next - c1 ** 2).sqrt()
200 | x_i = alpha_next.sqrt() * x0_t + c2 * eps + c1 * z
201 |
202 | return unnormalize_to_zero_to_one(x_i)
203 |
204 | def pred_eps_(self, x, t, alpha, use_amp, clip_x=True):
205 | def pred_eps_from_x0(x0):
206 | return (x - x0 * alpha.sqrt()) / (1 - alpha).sqrt()
207 |
208 | def pred_x0_from_eps(eps):
209 | return (x - (1 - alpha).sqrt() * eps) / alpha.sqrt()
210 |
211 | # get prediction of x0
212 | with autocast(enabled=use_amp):
213 | eps = self.nn_model(x, t).float()
214 | denoised = pred_x0_from_eps(eps)
215 |
216 | # pixel-space clipping (optional)
217 | if clip_x:
218 | denoised = torch.clip(denoised, -1., 1.)
219 | eps = pred_eps_from_x0(denoised)
220 | return eps, denoised
221 |
--------------------------------------------------------------------------------
/model/EDM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from tqdm import tqdm
5 | from torch.cuda.amp import autocast as autocast
6 | from .augment import AugmentPipe
7 |
8 |
9 | def normalize_to_neg_one_to_one(img):
10 | # [0.0, 1.0] -> [-1.0, 1.0]
11 | return img * 2 - 1
12 |
13 |
14 | def unnormalize_to_zero_to_one(t):
15 | # [-1.0, 1.0] -> [0.0, 1.0]
16 | return (t + 1) * 0.5
17 |
18 |
19 | class EDM(nn.Module):
20 | def __init__(self, nn_model,
21 | sigma_data, p_mean, p_std,
22 | sigma_min, sigma_max, rho,
23 | S_min, S_max, S_noise,
24 | device,
25 | augment_prob=0):
26 | ''' EDM proposed by "Elucidating the Design Space of Diffusion-Based Generative Models".
27 |
28 | Args:
29 | nn_model: A network (e.g. UNet) which performs same-shape mapping.
30 | device: The CUDA device that tensors run on.
31 | Training parameters:
32 | sigma_data, p_mean, p_std
33 | augment_prob
34 | Sampling parameters:
35 | sigma_min, sigma_max, rho
36 | S_min, S_max, S_noise
37 | '''
38 | super(EDM, self).__init__()
39 | self.nn_model = nn_model.to(device)
40 | params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6
41 | print(f"nn model # params: {params:.1f}")
42 |
43 | self.device = device
44 |
45 | def number_to_torch_device(value):
46 | return torch.tensor(value).to(device)
47 |
48 | self.sigma_data = number_to_torch_device(sigma_data)
49 | self.p_mean = number_to_torch_device(p_mean)
50 | self.p_std = number_to_torch_device(p_std)
51 | self.sigma_min = number_to_torch_device(sigma_min)
52 | self.sigma_max = number_to_torch_device(sigma_max)
53 | self.rho = number_to_torch_device(rho)
54 | self.S_min = number_to_torch_device(S_min)
55 | self.S_max = number_to_torch_device(S_max)
56 | self.S_noise = number_to_torch_device(S_noise)
57 | if augment_prob > 0:
58 | self.augpipe = AugmentPipe(p=augment_prob, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
59 | else:
60 | self.augpipe = None
61 |
62 | def perturb(self, x, t=None, steps=None):
63 | ''' Add noise to a clean image (diffusion process).
64 |
65 | Args:
66 | x: The normalized image tensor.
67 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor / None. \
68 | Random `ln(sigma) ~ N(P_mean, P_std)` is taken if t is None.
69 | Returns:
70 | The perturbed image, and the corresponding sigma.
71 | '''
72 | if t is None:
73 | rnd_normal = torch.randn((x.shape[0], 1, 1, 1)).to(self.device)
74 | sigma = (rnd_normal * self.p_std + self.p_mean).exp()
75 | else:
76 | times = reversed(self.sample_schedule(steps))
77 | sigma = times[t]
78 | if len(sigma.shape) == 1:
79 | sigma = sigma[:, None, None, None]
80 |
81 | noise = torch.randn_like(x)
82 | x_noised = x + noise * sigma
83 | return x_noised, sigma
84 |
85 | def forward(self, x, use_amp=False):
86 | ''' Training with weighted denoising loss.
87 |
88 | Args:
89 | x: The clean image tensor ranged in `[0, 1]`.
90 | Returns:
91 | The weighted MSE loss.
92 | '''
93 | x = normalize_to_neg_one_to_one(x)
94 | x, aug_label = self.augpipe(x) if self.augpipe is not None else (x, None)
95 | x_noised, sigma = self.perturb(x, t=None)
96 |
97 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
98 | loss_4shape = weight * ((x - self.D_x(x_noised, sigma, use_amp, aug_label)) ** 2)
99 | return loss_4shape.mean()
100 |
101 | def get_feature(self, x, t, steps=18, name=None, norm=False, use_amp=False):
102 | ''' Get network's intermediate activation in a forward pass.
103 |
104 | Args:
105 | x: The clean image tensor ranged in `[0, 1]`.
106 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor.
107 |             norm: whether to normalize features to the unit hypersphere.
108 | Returns:
109 | A {name: tensor} dict which contains global average pooled features.
110 | '''
111 | x = normalize_to_neg_one_to_one(x)
112 | x_noised, sigma = self.perturb(x, t, steps)
113 |
114 | def gap_and_norm(act, norm=False):
115 | if len(act.shape) == 4:
116 | # unet (B, C, H, W)
117 | act = act.view(act.shape[0], act.shape[1], -1).float()
118 | act = torch.mean(act, dim=2)
119 | else:
120 | raise NotImplementedError
121 | if norm:
122 | act = torch.nn.functional.normalize(act)
123 | return act
124 |
125 | _, acts = self.D_x(x_noised, sigma, use_amp, ret_activation=True)
126 | all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts}
127 | if name is not None:
128 | return all_feats[name]
129 | else:
130 | return all_feats
131 |
132 | def edm_sample(self, n_sample, size, steps=18, eta=0.0, notqdm=False, use_amp=False):
133 | ''' Sampling with EDM sampler. Actual NFE is `2 * steps - 1`.
134 |
135 | Args:
136 | n_sample: The batch size.
137 | size: The image shape (e.g. `(3, 32, 32)`).
138 | steps: The number of total timesteps.
139 | eta: controls stochasticity. Set `eta=0` for deterministic sampling.
140 | Returns:
141 | The sampled image tensor ranged in `[0, 1]`.
142 | '''
143 | S_min, S_max, S_noise = self.S_min, self.S_max, self.S_noise
144 | gamma_stochasticity = torch.tensor(np.sqrt(2) - 1) * eta # S_churn = (sqrt(2) - 1) * eta * steps
145 |
146 | times = self.sample_schedule(steps)
147 | time_pairs = list(zip(times[:-1], times[1:]))
148 |
149 | x_next = torch.randn(n_sample, *size).to(self.device).to(torch.float64) * times[0]
150 | for i, (t_cur, t_next) in enumerate(tqdm(time_pairs, disable=notqdm)): # 0, ..., N-1
151 | x_cur = x_next
152 |
153 | # Increase noise temporarily.
154 | gamma = gamma_stochasticity if S_min <= t_cur <= S_max else 0
155 | t_hat = t_cur + gamma * t_cur
156 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)
157 |
158 | # Euler step.
159 | d_cur = self.pred_eps_(x_hat, t_hat, use_amp)
160 | x_next = x_hat + (t_next - t_hat) * d_cur
161 |
162 | # Apply 2nd order correction.
163 | if i < steps - 1:
164 | d_prime = self.pred_eps_(x_next, t_next, use_amp)
165 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
166 |
167 | return unnormalize_to_zero_to_one(x_next)
168 |
169 | def pred_eps_(self, x, t, use_amp, clip_x=True):
170 | denoised = self.D_x(x, t, use_amp).to(torch.float64)
171 | # pixel-space clipping (optional)
172 | if clip_x:
173 | denoised = torch.clip(denoised, -1., 1.)
174 | eps = (x - denoised) / t
175 | return eps
176 |
177 | def D_x(self, x_noised, sigma, use_amp, aug_label=None, ret_activation=False):
178 | ''' Denoising with network preconditioning.
179 |
180 | Args:
181 | x_noised: The perturbed image tensor.
182 | sigma: The variance (noise level) tensor.
183 | aug_label: The augmentation labels produced by AugmentPipe.
184 | Returns:
185 | The estimated denoised image tensor.
186 | The {name: (B, C, H, W) tensor} activation dict (if ret_activation is True).
187 | '''
188 | x_noised = x_noised.to(torch.float32)
189 | sigma = sigma.to(torch.float32)
190 |
191 | # Preconditioning
192 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
193 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
194 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2).sqrt()
195 | c_noise = sigma.log() / 4
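        # EDM (Karras et al., 2022) preconditioning: `c_in` rescales the input to roughly
        # unit variance (the clean data has std `sigma_data`), and the training weight used
        # in `forward()` equals 1 / c_out**2, so the loss is effectively an unweighted MSE
        # on the raw network output F_x.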
196 |
197 | # Denoising
198 | with autocast(enabled=use_amp):
199 | F_x = self.nn_model(c_in * x_noised, c_noise.flatten(), aug_label, ret_activation)
200 |
201 | if ret_activation:
202 | return c_skip * x_noised + c_out * F_x[0], F_x[1]
203 | else:
204 | return c_skip * x_noised + c_out * F_x
205 |
206 | def sample_schedule(self, steps):
207 | ''' Make the variance schedule for EDM sampling.
208 |
209 | Args:
210 | steps: The number of total timesteps. Typically 18, 50 or 100.
211 | Returns:
212 | times: A decreasing tensor list such that
213 | `times[0] == sigma_max`,
214 | `times[steps-1] == sigma_min`, and
215 | `times[steps] == 0`.
216 | '''
217 | sigma_min, sigma_max, rho = self.sigma_min, self.sigma_max, self.rho
218 | times = torch.arange(steps, dtype=torch.float64, device=self.device)
219 | times = (sigma_max ** (1 / rho) + times / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
220 | times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0
221 | return times
222 |
--------------------------------------------------------------------------------
/model/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .block import GroupNorm32, TimeEmbedding, AttentionBlock, Upsample, Downsample
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_channels, out_channels, time_channels, dropout=0.1, up=False, down=False):
8 | """
9 | * `in_channels` is the number of input channels
10 | * `out_channels` is the number of output channels
11 |         * `time_channels` is the number of channels in the time step ($t$) embeddings
12 | * `dropout` is the dropout rate
13 | """
14 | super().__init__()
15 | self.norm1 = GroupNorm32(in_channels)
16 | self.act1 = nn.SiLU()
17 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
18 |
19 | self.norm2 = GroupNorm32(out_channels)
20 | self.act2 = nn.SiLU()
21 | self.conv2 = nn.Sequential(
22 | nn.Dropout(dropout),
23 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
24 | )
25 |
26 | if in_channels != out_channels:
27 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
28 | else:
29 | self.shortcut = nn.Identity()
30 |
31 | # Linear layer for embeddings
32 | self.time_emb = nn.Sequential(
33 | nn.SiLU(),
34 | nn.Linear(time_channels, out_channels)
35 | )
36 |
37 | # BigGAN style: use resblock for up/downsampling
38 | self.updown = up or down
39 | if up:
40 | self.h_upd = Upsample(in_channels, use_conv=False)
41 | self.x_upd = Upsample(in_channels, use_conv=False)
42 | elif down:
43 | self.h_upd = Downsample(in_channels, use_conv=False)
44 | self.x_upd = Downsample(in_channels, use_conv=False)
45 | else:
46 | self.h_upd = self.x_upd = nn.Identity()
47 |
48 | def forward(self, x, t):
49 | """
50 | * `x` has shape `[batch_size, in_channels, height, width]`
51 | * `t` has shape `[batch_size, time_channels]`
52 | """
53 | if self.updown:
54 | h = self.conv1(self.h_upd(self.act1(self.norm1(x))))
55 | x = self.x_upd(x)
56 | else:
57 | h = self.conv1(self.act1(self.norm1(x)))
58 |
59 |         # Add time embedding (broadcast over the spatial dimensions)
60 | t_ = self.time_emb(t)[:, :, None, None]
61 | h = h + t_
62 |
63 | h = self.conv2(self.act2(self.norm2(h)))
64 | return h + self.shortcut(x)
65 |
66 |
67 | class ResAttBlock(nn.Module):
68 | def __init__(self, in_channels, out_channels, time_channels, has_attn, attn_channels_per_head, dropout):
69 | super().__init__()
70 | self.res = ResidualBlock(in_channels, out_channels, time_channels, dropout=dropout)
71 | if has_attn:
72 | self.attn = AttentionBlock(out_channels, attn_channels_per_head)
73 | else:
74 | self.attn = nn.Identity()
75 |
76 | def forward(self, x, t):
77 | x = self.res(x, t)
78 | x = self.attn(x)
79 | return x
80 |
81 |
82 | class MiddleBlock(nn.Module):
83 | def __init__(self, n_channels, time_channels, attn_channels_per_head, dropout):
84 | super().__init__()
85 | self.res1 = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout)
86 | self.attn = AttentionBlock(n_channels, attn_channels_per_head)
87 | self.res2 = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout)
88 |
89 | def forward(self, x, t):
90 | x = self.res1(x, t)
91 | x = self.attn(x)
92 | x = self.res2(x, t)
93 | return x
94 |
95 |
96 | class UpsampleRes(nn.Module):
97 | def __init__(self, n_channels, time_channels, dropout):
98 | super().__init__()
99 | self.op = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout, up=True)
100 |
101 | def forward(self, x, t):
102 | return self.op(x, t)
103 |
104 |
105 | class DownsampleRes(nn.Module):
106 | def __init__(self, n_channels, time_channels, dropout):
107 | super().__init__()
108 | self.op = ResidualBlock(n_channels, n_channels, time_channels, dropout=dropout, down=True)
109 |
110 | def forward(self, x, t):
111 | return self.op(x, t)
112 |
113 |
114 | class UNet(nn.Module):
115 | def __init__(self, image_shape = [3, 32, 32], n_channels = 128,
116 | ch_mults = (1, 2, 2, 2),
117 | is_attn = (False, True, False, False),
118 | attn_channels_per_head = None,
119 | dropout = 0.1,
120 | n_blocks = 2,
121 | use_res_for_updown = False,
122 | augment_dim = 0):
123 | """
124 | * `image_shape` is the (channel, height, width) size of images.
125 |         * `n_channels` is the number of channels in the initial feature map that we transform the image into
126 | * `ch_mults` is the list of channel numbers at each resolution. The number of channels is `n_channels * ch_mults[i]`
127 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
128 | * `dropout` is the dropout rate
129 |         * `n_blocks` is the number of `ResAttBlock`s at each resolution (the decoder path uses `n_blocks + 1`)
130 | * `use_res_for_updown` indicates whether to use ResBlocks for up/down sampling (BigGAN-style)
131 | * `augment_dim` indicates augmentation label dimensionality, 0 = no augmentation
132 | """
133 | super().__init__()
134 |
135 | n_resolutions = len(ch_mults)
136 |
137 | self.image_proj = nn.Conv2d(image_shape[0], n_channels, kernel_size=3, padding=1)
138 |
139 | # Embedding layers (time & augment)
140 | time_channels = n_channels * 4
141 | self.time_emb = TimeEmbedding(time_channels, augment_dim)
142 |
143 | # Down stages
144 | down = []
145 | in_channels = n_channels
146 | h_channels = [n_channels]
147 | for i in range(n_resolutions):
148 | # Number of output channels at this resolution
149 | out_channels = n_channels * ch_mults[i]
150 | # `n_blocks` at the same resolution
151 | down.append(ResAttBlock(in_channels, out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout))
152 | h_channels.append(out_channels)
153 | for _ in range(n_blocks - 1):
154 | down.append(ResAttBlock(out_channels, out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout))
155 | h_channels.append(out_channels)
156 | # Down sample at all resolutions except the last
157 | if i < n_resolutions - 1:
158 | if use_res_for_updown:
159 | down.append(DownsampleRes(out_channels, time_channels, dropout))
160 | else:
161 | down.append(Downsample(out_channels))
162 | h_channels.append(out_channels)
163 | in_channels = out_channels
164 | self.down = nn.ModuleList(down)
165 |
166 | # Middle block
167 | self.middle = MiddleBlock(out_channels, time_channels, attn_channels_per_head, dropout)
168 |
169 | # Up stages
170 | up = []
171 | in_channels = out_channels
172 | for i in reversed(range(n_resolutions)):
173 | # Number of output channels at this resolution
174 | out_channels = n_channels * ch_mults[i]
175 | # `n_blocks + 1` at the same resolution
176 | for _ in range(n_blocks + 1):
177 | up.append(ResAttBlock(in_channels + h_channels.pop(), out_channels, time_channels, is_attn[i], attn_channels_per_head, dropout))
178 | in_channels = out_channels
179 |             # Up sample at all resolutions except the last
180 | if i > 0:
181 | if use_res_for_updown:
182 | up.append(UpsampleRes(out_channels, time_channels, dropout))
183 | else:
184 | up.append(Upsample(out_channels))
185 | assert not h_channels
186 | self.up = nn.ModuleList(up)
187 |
188 | # Final normalization and convolution layer
189 | self.norm = nn.GroupNorm(8, out_channels)
190 | self.act = nn.SiLU()
191 | self.final = nn.Conv2d(out_channels, image_shape[0], kernel_size=3, padding=1)
192 |
193 | def forward(self, x, t, aug=None, ret_activation=False):
194 | if not ret_activation:
195 | return self.forward_core(x, t, aug)
196 |
197 | activation = {}
198 | def namedHook(name):
199 | def hook(module, input, output):
200 | activation[name] = output
201 | return hook
202 | hooks = {}
203 | no = 0
204 | for blk in self.up:
205 | if isinstance(blk, ResAttBlock):
206 | no += 1
207 | name = f'out_{no}'
208 | hooks[name] = blk.register_forward_hook(namedHook(name))
209 |
210 | result = self.forward_core(x, t, aug)
211 | for name in hooks:
212 | hooks[name].remove()
213 | return result, activation
214 |
215 | def forward_core(self, x, t, aug):
216 | """
217 | * `x` has shape `[batch_size, in_channels, height, width]`
218 | * `t` has shape `[batch_size]`
219 | """
220 |
221 | t = self.time_emb(t, aug)
222 | x = self.image_proj(x)
223 |
224 | # `h` will store outputs at each resolution for skip connection
225 | h = [x]
226 |
227 | for m in self.down:
228 | if isinstance(m, Downsample):
229 | x = m(x)
230 | elif isinstance(m, DownsampleRes):
231 | x = m(x, t)
232 | else:
233 | x = m(x, t).contiguous()
234 | h.append(x)
235 |
236 | x = self.middle(x, t).contiguous()
237 |
238 | for m in self.up:
239 | if isinstance(m, Upsample):
240 | x = m(x)
241 | elif isinstance(m, UpsampleRes):
242 | x = m(x, t)
243 | else:
244 | # Get the skip connection from first half of U-Net and concatenate
245 | s = h.pop()
246 | x = torch.cat((x, s), dim=1)
247 | x = m(x, t).contiguous()
248 |
249 | return self.final(self.act(self.norm(x)))
250 |
251 |
252 | '''
253 | from model.unet import UNet
254 | net = UNet()
255 | import torch
256 | x = torch.zeros(1, 3, 32, 32)
257 | t = torch.zeros(1,)
258 |
259 | net(x, t).shape
260 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
261 |
262 | >>> 35.746307 M parameters for CIFAR-10 model (original DDPM)
263 | '''
264 |
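# A second usage sketch in the same style as the block above: with `ret_activation=True`,
# the forward pass also returns the decoder activations keyed `out_1`, `out_2`, ... (one per
# `ResAttBlock` on the up path); `EDM.get_feature` pools exactly these maps.
'''
net = UNet()
x = torch.zeros(4, 3, 32, 32)
t = torch.zeros(4,)

out, acts = net(x, t, ret_activation=True)
print(list(acts.keys()))    # ['out_1', ..., 'out_12'] for the default arguments
print(acts['out_6'].shape)  # torch.Size([4, 256, 8, 8]) with the default arguments
'''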
--------------------------------------------------------------------------------
/DiT/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 |
18 |
19 | def modulate(x, shift, scale):
20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21 |
22 |
23 | #################################################################################
24 | # Embedding Layers for Timesteps and Class Labels #
25 | #################################################################################
26 |
27 | class TimestepEmbedder(nn.Module):
28 | """
29 | Embeds scalar timesteps into vector representations.
30 | """
31 | def __init__(self, hidden_size, frequency_embedding_size=256):
32 | super().__init__()
33 | self.mlp = nn.Sequential(
34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35 | nn.SiLU(),
36 | nn.Linear(hidden_size, hidden_size, bias=True),
37 | )
38 | self.frequency_embedding_size = frequency_embedding_size
39 |
40 | @staticmethod
41 | def timestep_embedding(t, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 | :param dim: the dimension of the output.
47 | :param max_period: controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 | half = dim // 2
52 | freqs = torch.exp(
53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54 | ).to(device=t.device)
55 | args = t[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59 | return embedding
60 |
61 | def forward(self, t):
62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63 | t_emb = self.mlp(t_freq)
64 | return t_emb
65 |
66 |
67 | class LabelEmbedder(nn.Module):
68 | """
69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70 | """
71 | def __init__(self, num_classes, hidden_size, dropout_prob):
72 | super().__init__()
73 | use_cfg_embedding = dropout_prob > 0
74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75 | self.num_classes = num_classes
76 | self.dropout_prob = dropout_prob
77 |
78 | def token_drop(self, labels, force_drop_ids=None):
79 | """
80 | Drops labels to enable classifier-free guidance.
81 | """
82 | if force_drop_ids is None:
83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84 | else:
85 | drop_ids = force_drop_ids == 1
86 | labels = torch.where(drop_ids, self.num_classes, labels)
87 | return labels
88 |
89 | def forward(self, labels, train, force_drop_ids=None):
90 | use_dropout = self.dropout_prob > 0
91 | if (train and use_dropout) or (force_drop_ids is not None):
92 | labels = self.token_drop(labels, force_drop_ids)
93 | embeddings = self.embedding_table(labels)
94 | return embeddings
95 |
96 |
97 | #################################################################################
98 | # Core DiT Model #
99 | #################################################################################
100 |
101 | class DiTBlock(nn.Module):
102 | """
103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104 | """
105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106 | super().__init__()
107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
111 | approx_gelu = lambda: nn.GELU(approximate="tanh")
112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113 | self.adaLN_modulation = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116 | )
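        # adaLN-Zero: `DiT.initialize_weights` zero-initializes this final Linear, so the
        # shifts, scales and gates start at zero and every block begins as an identity map.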
117 |
118 | def forward(self, x, c):
119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122 | return x
123 |
124 |
125 | class FinalLayer(nn.Module):
126 | """
127 | The final layer of DiT.
128 | """
129 | def __init__(self, hidden_size, patch_size, out_channels):
130 | super().__init__()
131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133 | self.adaLN_modulation = nn.Sequential(
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136 | )
137 |
138 | def forward(self, x, c):
139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140 | x = modulate(self.norm_final(x), shift, scale)
141 | x = self.linear(x)
142 | return x
143 |
144 |
145 | class DiT(nn.Module):
146 | """
147 | Diffusion model with a Transformer backbone.
148 | """
149 | def __init__(
150 | self,
151 | input_size=32,
152 | patch_size=2,
153 | in_channels=4,
154 | hidden_size=1152,
155 | depth=28,
156 | num_heads=16,
157 | mlp_ratio=4.0,
158 | class_dropout_prob=0.1,
159 | num_classes=1000,
160 | learn_sigma=True,
161 | ):
162 | super().__init__()
163 | self.learn_sigma = learn_sigma
164 | self.in_channels = in_channels
165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
166 | self.patch_size = patch_size
167 | self.num_heads = num_heads
168 |
169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
170 | self.t_embedder = TimestepEmbedder(hidden_size)
171 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
172 | num_patches = self.x_embedder.num_patches
173 | # Will use fixed sin-cos embedding:
174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
175 |
176 | self.blocks = nn.ModuleList([
177 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
178 | ])
179 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
180 | self.initialize_weights()
181 |
182 | def initialize_weights(self):
183 | # Initialize transformer layers:
184 | def _basic_init(module):
185 | if isinstance(module, nn.Linear):
186 | torch.nn.init.xavier_uniform_(module.weight)
187 | if module.bias is not None:
188 | nn.init.constant_(module.bias, 0)
189 | self.apply(_basic_init)
190 |
191 | # Initialize (and freeze) pos_embed by sin-cos embedding:
192 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
193 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
194 |
195 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
196 | w = self.x_embedder.proj.weight.data
197 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
198 | nn.init.constant_(self.x_embedder.proj.bias, 0)
199 |
200 | # Initialize label embedding table:
201 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
202 |
203 | # Initialize timestep embedding MLP:
204 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
205 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
206 |
207 | # Zero-out adaLN modulation layers in DiT blocks:
208 | for block in self.blocks:
209 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
210 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
211 |
212 | # Zero-out output layers:
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
214 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
215 | nn.init.constant_(self.final_layer.linear.weight, 0)
216 | nn.init.constant_(self.final_layer.linear.bias, 0)
217 |
218 | def unpatchify(self, x):
219 | """
220 | x: (N, T, patch_size**2 * C)
221 |         imgs: (N, C, H, W)
222 | """
223 | c = self.out_channels
224 | p = self.x_embedder.patch_size[0]
225 | h = w = int(x.shape[1] ** 0.5)
226 | assert h * w == x.shape[1]
227 |
228 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
229 | x = torch.einsum('nhwpqc->nchpwq', x)
230 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
231 | return imgs
232 |
233 | def forward(self, x, t, y, ret_activation=False):
234 | if not ret_activation:
235 | return self.forward_core(x, t, y)
236 |
237 | activation = {}
238 | def namedHook(name):
239 | def hook(module, input, output):
240 | activation[name] = output
241 | return hook
242 | hooks = {}
243 | for idx, block in enumerate(self.blocks):
244 | name = f"layer-{idx}"
245 | hooks[name] = block.register_forward_hook(namedHook(name))
246 |
247 | result = self.forward_core(x, t, y)
248 | for name in hooks:
249 | hooks[name].remove()
250 | return result, activation
251 |
252 | def forward_core(self, x, t, y):
253 | """
254 | Forward pass of DiT.
255 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
256 | t: (N,) tensor of diffusion timesteps
257 | y: (N,) tensor of class labels
258 | """
259 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
260 | t = self.t_embedder(t) # (N, D)
261 | y = self.y_embedder(y, self.training) # (N, D)
262 | c = t + y # (N, D)
263 | for block in self.blocks:
264 | x = block(x, c) # (N, T, D)
265 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
266 | x = self.unpatchify(x) # (N, out_channels, H, W)
267 | return x
268 |
269 | def forward_with_cfg(self, x, t, y, cfg_scale):
270 | """
271 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
272 | """
273 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
274 | half = x[: len(x) // 2]
275 | combined = torch.cat([half, half], dim=0)
276 | model_out = self.forward(combined, t, y)
277 | # For exact reproducibility reasons, we apply classifier-free guidance on only
278 | # three channels by default. The standard approach to cfg applies it to all channels.
279 | # This can be done by uncommenting the following line and commenting-out the line following that.
280 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
281 | eps, rest = model_out[:, :3], model_out[:, 3:]
282 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
283 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
284 | eps = torch.cat([half_eps, half_eps], dim=0)
285 | return torch.cat([eps, rest], dim=1)
286 |
287 |
288 | #################################################################################
289 | # Sine/Cosine Positional Embedding Functions #
290 | #################################################################################
291 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
292 |
293 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
294 | """
295 | grid_size: int of the grid height and width
296 | return:
297 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
298 | """
299 | grid_h = np.arange(grid_size, dtype=np.float32)
300 | grid_w = np.arange(grid_size, dtype=np.float32)
301 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
302 | grid = np.stack(grid, axis=0)
303 |
304 | grid = grid.reshape([2, 1, grid_size, grid_size])
305 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
306 | if cls_token and extra_tokens > 0:
307 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
308 | return pos_embed
309 |
310 |
311 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
312 | assert embed_dim % 2 == 0
313 |
314 | # use half of dimensions to encode grid_h
315 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
316 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
317 |
318 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
319 | return emb
320 |
321 |
322 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
323 | """
324 | embed_dim: output dimension for each position
325 | pos: a list of positions to be encoded: size (M,)
326 | out: (M, D)
327 | """
328 | assert embed_dim % 2 == 0
329 | omega = np.arange(embed_dim // 2, dtype=np.float64)
330 | omega /= embed_dim / 2.
331 | omega = 1. / 10000**omega # (D/2,)
332 |
333 | pos = pos.reshape(-1) # (M,)
334 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
335 |
336 | emb_sin = np.sin(out) # (M, D/2)
337 | emb_cos = np.cos(out) # (M, D/2)
338 |
339 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
340 | return emb
341 |
342 |
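# Shape sketch for the helpers above: with the DiT defaults (input_size=32, patch_size=2)
# the token grid is 16 x 16, so the fixed table has 256 rows.
'''
pe = get_2d_sincos_pos_embed(embed_dim=384, grid_size=32 // 2)
print(pe.shape)  # (256, 384) -- matches DiT-S/2's pos_embed, minus the leading batch dim
'''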
343 | #################################################################################
344 | # DiT Configs #
345 | #################################################################################
346 |
347 | def DiT_XL_2(**kwargs):
348 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
349 |
350 | def DiT_XL_4(**kwargs):
351 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
352 |
353 | def DiT_XL_8(**kwargs):
354 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
355 |
356 | def DiT_L_2(**kwargs):
357 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
358 |
359 | def DiT_L_4(**kwargs):
360 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
361 |
362 | def DiT_L_8(**kwargs):
363 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
364 |
365 | def DiT_B_2(**kwargs):
366 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
367 |
368 | def DiT_B_4(**kwargs):
369 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
370 |
371 | def DiT_B_8(**kwargs):
372 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
373 |
374 | def DiT_S_2(**kwargs):
375 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
376 |
377 | def DiT_S_4(**kwargs):
378 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
379 |
380 | def DiT_S_8(**kwargs):
381 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
382 |
383 |
384 | DiT_models = {
385 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
386 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
387 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
388 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
389 | }
390 |
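# Illustrative usage sketch; assumes the repository root is on PYTHONPATH and that `timm`
# is installed. The input and label sizes here are just example values.
'''
import torch
from DiT.models import DiT_models

net = DiT_models['DiT-S/2']()     # input_size=32, in_channels=4, num_classes=1000 by default
x = torch.randn(2, 4, 32, 32)     # e.g. VAE latents of 256x256 images
t = torch.randint(0, 1000, (2,))
y = torch.randint(0, 1000, (2,))

out = net(x, t, y)
print(out.shape)                  # (2, 8, 32, 32): predicted eps plus learned-sigma channels

out, acts = net(x, t, y, ret_activation=True)
print(list(acts.keys())[:3])      # ['layer-0', 'layer-1', 'layer-2']
'''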
--------------------------------------------------------------------------------
/model/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Augmentation pipeline used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models".
10 | Built around the same concepts that were originally proposed in the paper
11 | "Training Generative Adversarial Networks with Limited Data"."""
12 |
13 | import numpy as np
14 | import torch
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Coefficients of various wavelet decomposition low-pass filters.
45 |
46 | wavelets = {
47 | 'haar': [0.7071067811865476, 0.7071067811865476],
48 | 'db1': [0.7071067811865476, 0.7071067811865476],
49 | 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
50 | 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
51 | 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
52 | 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
53 | 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
54 | 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
55 | 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
56 | 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
57 | 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
58 | 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
59 | 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
60 | 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
61 | 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
62 | 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
63 | }
64 |
65 | #----------------------------------------------------------------------------
66 | # Helpers for constructing transformation matrices.
67 |
68 | def matrix(*rows, device=None):
69 | assert all(len(row) == len(rows[0]) for row in rows)
70 | elems = [x for row in rows for x in row]
71 | ref = [x for x in elems if isinstance(x, torch.Tensor)]
72 | if len(ref) == 0:
73 | return constant(np.asarray(rows), device=device)
74 | assert device is None or device == ref[0].device
75 | elems = [x if isinstance(x, torch.Tensor) else constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
76 | return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
77 |
78 | def translate2d(tx, ty, **kwargs):
79 | return matrix(
80 | [1, 0, tx],
81 | [0, 1, ty],
82 | [0, 0, 1],
83 | **kwargs)
84 |
85 | def translate3d(tx, ty, tz, **kwargs):
86 | return matrix(
87 | [1, 0, 0, tx],
88 | [0, 1, 0, ty],
89 | [0, 0, 1, tz],
90 | [0, 0, 0, 1],
91 | **kwargs)
92 |
93 | def scale2d(sx, sy, **kwargs):
94 | return matrix(
95 | [sx, 0, 0],
96 | [0, sy, 0],
97 | [0, 0, 1],
98 | **kwargs)
99 |
100 | def scale3d(sx, sy, sz, **kwargs):
101 | return matrix(
102 | [sx, 0, 0, 0],
103 | [0, sy, 0, 0],
104 | [0, 0, sz, 0],
105 | [0, 0, 0, 1],
106 | **kwargs)
107 |
108 | def rotate2d(theta, **kwargs):
109 | return matrix(
110 | [torch.cos(theta), torch.sin(-theta), 0],
111 | [torch.sin(theta), torch.cos(theta), 0],
112 | [0, 0, 1],
113 | **kwargs)
114 |
115 | def rotate3d(v, theta, **kwargs):
116 | vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
117 | s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
118 | return matrix(
119 | [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
120 | [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
121 | [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
122 | [0, 0, 0, 1],
123 | **kwargs)
124 |
125 | def translate2d_inv(tx, ty, **kwargs):
126 | return translate2d(-tx, -ty, **kwargs)
127 |
128 | def scale2d_inv(sx, sy, **kwargs):
129 | return scale2d(1 / sx, 1 / sy, **kwargs)
130 |
131 | def rotate2d_inv(theta, **kwargs):
132 | return rotate2d(-theta, **kwargs)
133 |
134 | #----------------------------------------------------------------------------
135 | # Augmentation pipeline main class.
136 | # All augmentations are disabled by default; individual augmentations can
137 | # be enabled by setting their probability multipliers to 1.
138 |
139 | class AugmentPipe:
140 | def __init__(self, p=1,
141 | xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125,
142 | scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125,
143 | brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
144 | ):
145 | super().__init__()
146 | self.p = float(p) # Overall multiplier for augmentation probability.
147 |
148 | # Pixel blitting.
149 | self.xflip = float(xflip) # Probability multiplier for x-flip.
150 | self.yflip = float(yflip) # Probability multiplier for y-flip.
151 | self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation.
152 | self.translate_int = float(translate_int) # Probability multiplier for integer translation.
153 | self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions.
154 |
155 | # Geometric transformations.
156 | self.scale = float(scale) # Probability multiplier for isotropic scaling.
157 | self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation.
158 | self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
159 | self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
160 | self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
161 | self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
162 | self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
163 | self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
164 |         self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.
165 |
166 | # Color transformations.
167 | self.brightness = float(brightness) # Probability multiplier for brightness.
168 | self.contrast = float(contrast) # Probability multiplier for contrast.
169 | self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
170 | self.hue = float(hue) # Probability multiplier for hue rotation.
171 | self.saturation = float(saturation) # Probability multiplier for saturation.
172 | self.brightness_std = float(brightness_std) # Standard deviation of brightness.
173 | self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
174 | self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
175 | self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
176 |
177 | def __call__(self, images):
178 | N, C, H, W = images.shape
179 | device = images.device
180 | labels = [torch.zeros([images.shape[0], 0], device=device)]
181 |
182 | # ---------------
183 | # Pixel blitting.
184 | # ---------------
185 |
186 | if self.xflip > 0:
187 | w = torch.randint(2, [N, 1, 1, 1], device=device)
188 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
189 | images = torch.where(w == 1, images.flip(3), images)
190 | labels += [w]
191 |
192 | if self.yflip > 0:
193 | w = torch.randint(2, [N, 1, 1, 1], device=device)
194 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
195 | images = torch.where(w == 1, images.flip(2), images)
196 | labels += [w]
197 |
198 | if self.rotate_int > 0:
199 | w = torch.randint(4, [N, 1, 1, 1], device=device)
200 | w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
201 | images = torch.where((w == 1) | (w == 2), images.flip(3), images)
202 | images = torch.where((w == 2) | (w == 3), images.flip(2), images)
203 | images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
204 | labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]
205 |
206 | if self.translate_int > 0:
207 | w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
208 | w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
209 | tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
210 | ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
211 | b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
212 | x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
213 | y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
214 | images = images.flatten()[(((b * C) + c) * H + y) * W + x]
215 | labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]
216 |
217 | # ------------------------------------------------
218 | # Select parameters for geometric transformations.
219 | # ------------------------------------------------
220 |
221 | I_3 = torch.eye(3, device=device)
222 | G_inv = I_3
223 |
224 | if self.scale > 0:
225 | w = torch.randn([N], device=device)
226 | w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w))
227 | s = w.mul(self.scale_std).exp2()
228 | G_inv = G_inv @ scale2d_inv(s, s)
229 | labels += [w]
230 |
231 | if self.rotate_frac > 0:
232 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max)
233 | w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w))
234 | G_inv = G_inv @ rotate2d_inv(-w)
235 | labels += [w.cos() - 1, w.sin()]
236 |
237 | if self.aniso > 0:
238 | w = torch.randn([N], device=device)
239 | r = (torch.rand([N], device=device) * 2 - 1) * np.pi
240 | w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w))
241 | r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r))
242 | s = w.mul(self.aniso_std).exp2()
243 | G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r)
244 | labels += [w * r.cos(), w * r.sin()]
245 |
246 | if self.translate_frac > 0:
247 | w = torch.randn([2, N], device=device)
248 | w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w))
249 | G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std))
250 | labels += [w[0], w[1]]
251 |
252 | # ----------------------------------
253 | # Execute geometric transformations.
254 | # ----------------------------------
255 |
256 | if G_inv is not I_3:
257 | cx = (W - 1) / 2
258 | cy = (H - 1) / 2
259 | cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
260 | cp = G_inv @ cp.t() # [batch, xyz, idx]
261 | Hz = np.asarray(wavelets['sym6'], dtype=np.float32)
262 | Hz_pad = len(Hz) // 4
263 | margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
264 | margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
265 | margin = margin + constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
266 | margin = margin.max(constant([0, 0] * 2, device=device))
267 | margin = margin.min(constant([W - 1, H - 1] * 2, device=device))
268 | mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
269 |
270 | # Pad image and adjust origin.
271 | images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
272 | G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
273 |
274 | # Upsample.
275 | conv_weight = constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
276 | conv_pad = (len(Hz) + 1) // 2
277 | images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1]
278 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad])
279 | images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :]
280 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0])
281 | G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
282 | G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
283 |
284 | # Execute transformation.
285 | shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2]
286 | G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
287 | grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
288 | images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
289 |
290 | # Downsample and crop.
291 | conv_weight = constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
292 | conv_pad = (len(Hz) - 1) // 2
293 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad]
294 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :]
295 |
296 | # --------------------------------------------
297 | # Select parameters for color transformations.
298 | # --------------------------------------------
299 |
300 | I_4 = torch.eye(4, device=device)
301 | M = I_4
302 | luma_axis = constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)
303 |
304 | if self.brightness > 0:
305 | w = torch.randn([N], device=device)
306 | w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w))
307 | b = w * self.brightness_std
308 | M = translate3d(b, b, b) @ M
309 | labels += [w]
310 |
311 | if self.contrast > 0:
312 | w = torch.randn([N], device=device)
313 | w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w))
314 | c = w.mul(self.contrast_std).exp2()
315 | M = scale3d(c, c, c) @ M
316 | labels += [w]
317 |
318 | if self.lumaflip > 0:
319 | w = torch.randint(2, [N, 1, 1], device=device)
320 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w))
321 | M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M
322 | labels += [w]
323 |
324 | if self.hue > 0:
325 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max)
326 | w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w))
327 | M = rotate3d(luma_axis, w) @ M
328 | labels += [w.cos() - 1, w.sin()]
329 |
330 | if self.saturation > 0:
331 | w = torch.randn([N, 1, 1], device=device)
332 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w))
333 | M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M
334 | labels += [w]
335 |
336 | # ------------------------------
337 | # Execute color transformations.
338 | # ------------------------------
339 |
340 | if M is not I_4:
341 | images = images.reshape([N, C, H * W])
342 | if C == 3:
343 | images = M[:, :3, :3] @ images + M[:, :3, 3:]
344 | elif C == 1:
345 | M = M[:, :3, :].mean(dim=1, keepdims=True)
346 | images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:]
347 | else:
348 | raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
349 | images = images.reshape([N, C, H, W])
350 |
351 | labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1)
352 | return images, labels
353 |
354 | #----------------------------------------------------------------------------
355 |
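# Illustrative usage sketch, assuming the repository root is importable: enable a subset of
# augmentations (example values) and note that `labels.shape[1]` is the `augment_dim` the
# denoiser network must be built with.
'''
import torch
from model.augment import AugmentPipe

pipe = AugmentPipe(p=0.12, xflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
images = torch.rand(8, 3, 32, 32) * 2 - 1     # images already normalized to [-1, 1]
images_aug, labels = pipe(images)
print(images_aug.shape, labels.shape)         # torch.Size([8, 3, 32, 32]) torch.Size([8, 8])
'''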
--------------------------------------------------------------------------------
/DiT/diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 |
7 | import math
8 |
9 | import numpy as np
10 | import torch as th
11 | import enum
12 |
13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14 |
15 |
16 | def mean_flat(tensor):
17 | """
18 | Take the mean over all non-batch dimensions.
19 | """
20 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
21 |
22 |
23 | class ModelMeanType(enum.Enum):
24 | """
25 | Which type of output the model predicts.
26 | """
27 |
28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29 | START_X = enum.auto() # the model predicts x_0
30 | EPSILON = enum.auto() # the model predicts epsilon
31 |
32 |
33 | class ModelVarType(enum.Enum):
34 | """
35 | What is used as the model's output variance.
36 | The LEARNED_RANGE option has been added to allow the model to predict
37 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38 | """
39 |
40 | LEARNED = enum.auto()
41 | FIXED_SMALL = enum.auto()
42 | FIXED_LARGE = enum.auto()
43 | LEARNED_RANGE = enum.auto()
44 |
45 |
46 | class LossType(enum.Enum):
47 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48 | RESCALED_MSE = (
49 | enum.auto()
50 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
51 | KL = enum.auto() # use the variational lower-bound
52 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53 |
54 | def is_vb(self):
55 | return self == LossType.KL or self == LossType.RESCALED_KL
56 |
57 |
58 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
61 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62 | return betas
63 |
64 |
65 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66 | """
67 | This is the deprecated API for creating beta schedules.
68 | See get_named_beta_schedule() for the new library of schedules.
69 | """
70 | if beta_schedule == "quad":
71 | betas = (
72 | np.linspace(
73 | beta_start ** 0.5,
74 | beta_end ** 0.5,
75 | num_diffusion_timesteps,
76 | dtype=np.float64,
77 | )
78 | ** 2
79 | )
80 | elif beta_schedule == "linear":
81 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82 | elif beta_schedule == "warmup10":
83 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84 | elif beta_schedule == "warmup50":
85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86 | elif beta_schedule == "const":
87 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89 | betas = 1.0 / np.linspace(
90 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91 | )
92 | else:
93 | raise NotImplementedError(beta_schedule)
94 | assert betas.shape == (num_diffusion_timesteps,)
95 | return betas
96 |
97 |
98 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99 | """
100 | Get a pre-defined beta schedule for the given name.
101 | The beta schedule library consists of beta schedules which remain similar
102 | in the limit of num_diffusion_timesteps.
103 | Beta schedules may be added, but should not be removed or changed once
104 | they are committed to maintain backwards compatibility.
105 | """
106 | if schedule_name == "linear":
107 | # Linear schedule from Ho et al, extended to work for any number of
108 | # diffusion steps.
109 | scale = 1000 / num_diffusion_timesteps
110 | return get_beta_schedule(
111 | "linear",
112 | beta_start=scale * 0.0001,
113 | beta_end=scale * 0.02,
114 | num_diffusion_timesteps=num_diffusion_timesteps,
115 | )
116 | elif schedule_name == "squaredcos_cap_v2":
117 | return betas_for_alpha_bar(
118 | num_diffusion_timesteps,
119 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120 | )
121 | else:
122 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123 |
124 |
125 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126 | """
127 | Create a beta schedule that discretizes the given alpha_t_bar function,
128 | which defines the cumulative product of (1-beta) over time from t = [0,1].
129 | :param num_diffusion_timesteps: the number of betas to produce.
130 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131 | produces the cumulative product of (1-beta) up to that
132 | part of the diffusion process.
133 | :param max_beta: the maximum beta to use; use values lower than 1 to
134 | prevent singularities.
135 | """
136 | betas = []
137 | for i in range(num_diffusion_timesteps):
138 | t1 = i / num_diffusion_timesteps
139 | t2 = (i + 1) / num_diffusion_timesteps
140 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141 | return np.array(betas)
142 |
143 |
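# Quick illustrative check of the two named schedules above.
'''
lin = get_named_beta_schedule("linear", 1000)
cos = get_named_beta_schedule("squaredcos_cap_v2", 1000)
print(lin.shape, lin[0], lin[-1])  # (1000,) 0.0001 0.02   (scale = 1000 / 1000 = 1)
print(cos.shape, cos.max())        # (1000,) 0.999         (capped at max_beta)
'''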
144 | class GaussianDiffusion:
145 | """
146 | Utilities for training and sampling diffusion models.
147 | Original ported from this codebase:
148 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
150 |                   ordered from t = 1 to t = T.
151 | """
152 |
153 | def __init__(
154 | self,
155 | *,
156 | betas,
157 | model_mean_type,
158 | model_var_type,
159 | loss_type
160 | ):
161 |
162 | self.model_mean_type = model_mean_type
163 | self.model_var_type = model_var_type
164 | self.loss_type = loss_type
165 |
166 | # Use float64 for accuracy.
167 | betas = np.array(betas, dtype=np.float64)
168 | self.betas = betas
169 | assert len(betas.shape) == 1, "betas must be 1-D"
170 | assert (betas > 0).all() and (betas <= 1).all()
171 |
172 | self.num_timesteps = int(betas.shape[0])
173 |
174 | alphas = 1.0 - betas
175 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
176 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179 |
180 | # calculations for diffusion q(x_t | x_{t-1}) and others
181 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186 |
187 | # calculations for posterior q(x_{t-1} | x_t, x_0)
188 | self.posterior_variance = (
189 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190 | )
191 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192 | self.posterior_log_variance_clipped = np.log(
193 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
194 | ) if len(self.posterior_variance) > 1 else np.array([])
195 |
196 | self.posterior_mean_coef1 = (
197 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198 | )
199 | self.posterior_mean_coef2 = (
200 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201 | )
202 |
203 | def q_mean_variance(self, x_start, t):
204 | """
205 | Get the distribution q(x_t | x_0).
206 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
207 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209 | """
210 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213 | return mean, variance, log_variance
214 |
215 | def q_sample(self, x_start, t, noise=None):
216 | """
217 | Diffuse the data for a given number of diffusion steps.
218 | In other words, sample from q(x_t | x_0).
219 | :param x_start: the initial data batch.
220 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221 | :param noise: if specified, the split-out normal noise.
222 | :return: A noisy version of x_start.
223 | """
224 | if noise is None:
225 | noise = th.randn_like(x_start)
226 | assert noise.shape == x_start.shape
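        # Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise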
227 | return (
228 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230 | )
231 |
232 | def q_posterior_mean_variance(self, x_start, x_t, t):
233 | """
234 | Compute the mean and variance of the diffusion posterior:
235 | q(x_{t-1} | x_t, x_0)
236 | """
237 | assert x_start.shape == x_t.shape
238 | posterior_mean = (
239 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241 | )
242 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243 | posterior_log_variance_clipped = _extract_into_tensor(
244 | self.posterior_log_variance_clipped, t, x_t.shape
245 | )
246 | assert (
247 | posterior_mean.shape[0]
248 | == posterior_variance.shape[0]
249 | == posterior_log_variance_clipped.shape[0]
250 | == x_start.shape[0]
251 | )
252 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
253 |
254 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255 | """
256 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257 | the initial x, x_0.
258 | :param model: the model, which takes a signal and a batch of timesteps
259 | as input.
260 | :param x: the [N x C x ...] tensor at time t.
261 | :param t: a 1-D Tensor of timesteps.
262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263 | :param denoised_fn: if not None, a function which applies to the
264 | x_start prediction before it is used to sample. Applies before
265 | clip_denoised.
266 | :param model_kwargs: if not None, a dict of extra keyword arguments to
267 | pass to the model. This can be used for conditioning.
268 | :return: a dict with the following keys:
269 | - 'mean': the model mean output.
270 | - 'variance': the model variance output.
271 | - 'log_variance': the log of 'variance'.
272 | - 'pred_xstart': the prediction for x_0.
273 | """
274 | if model_kwargs is None:
275 | model_kwargs = {}
276 |
277 | B, C = x.shape[:2]
278 | assert t.shape == (B,)
279 | model_output = model(x, t, **model_kwargs)
280 | if isinstance(model_output, tuple):
281 | model_output, extra = model_output
282 | else:
283 | extra = None
284 |
285 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286 | assert model_output.shape == (B, C * 2, *x.shape[2:])
287 | model_output, model_var_values = th.split(model_output, C, dim=1)
288 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290 | # The model_var_values is [-1, 1] for [min_var, max_var].
291 | frac = (model_var_values + 1) / 2
292 | model_log_variance = frac * max_log + (1 - frac) * min_log
293 | model_variance = th.exp(model_log_variance)
294 | else:
295 | model_variance, model_log_variance = {
296 | # for fixedlarge, we set the initial (log-)variance like so
297 | # to get a better decoder log likelihood.
298 | ModelVarType.FIXED_LARGE: (
299 | np.append(self.posterior_variance[1], self.betas[1:]),
300 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301 | ),
302 | ModelVarType.FIXED_SMALL: (
303 | self.posterior_variance,
304 | self.posterior_log_variance_clipped,
305 | ),
306 | }[self.model_var_type]
307 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
308 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309 |
310 | def process_xstart(x):
311 | if denoised_fn is not None:
312 | x = denoised_fn(x)
313 | if clip_denoised:
314 | return x.clamp(-1, 1)
315 | return x
316 |
317 | if self.model_mean_type == ModelMeanType.START_X:
318 | pred_xstart = process_xstart(model_output)
319 | else:
320 | pred_xstart = process_xstart(
321 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322 | )
323 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324 |
325 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326 | return {
327 | "mean": model_mean,
328 | "variance": model_variance,
329 | "log_variance": model_log_variance,
330 | "pred_xstart": pred_xstart,
331 | "extra": extra,
332 | }
333 |
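# p_mean_variance only assumes the model is a callable taking (x, t, **model_kwargs); when
# the variance is learned (ModelVarType.LEARNED / LEARNED_RANGE) the output must carry 2*C
# channels, with the extra C channels interpolated in log space between the clipped
# posterior variance and log(beta_t). A toy stand-in with that interface (hypothetical
# sketch, not a real architecture):
#
#     import torch.nn as nn
#
#     class ToyLearnedVarModel(nn.Module):
#         def __init__(self, channels=3):
#             super().__init__()
#             self.conv = nn.Conv2d(channels, channels * 2, 3, padding=1)
#
#         def forward(self, x, t, **kwargs):
#             # first C channels: eps prediction; last C: variance values in [-1, 1]
#             return self.conv(x)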
334 | def _predict_xstart_from_eps(self, x_t, t, eps):
335 | assert x_t.shape == eps.shape
336 | return (
337 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339 | )
340 |
341 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342 | return (
343 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345 |
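# _predict_xstart_from_eps and _predict_eps_from_xstart are the forward-process
# reparameterization solved for x_0 and for eps respectively, so they are exact inverses.
# A round-trip check (hypothetical; assumes the `diffusion` instance from the earlier
# sketches):
#
#     x_t = th.randn(4, 3, 32, 32)
#     t = th.randint(0, diffusion.num_timesteps, (4,))
#     eps = th.randn_like(x_t)
#     x0_hat = diffusion._predict_xstart_from_eps(x_t, t, eps)
#     assert th.allclose(diffusion._predict_eps_from_xstart(x_t, t, x0_hat), eps, atol=1e-4)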
346 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347 | """
348 | Compute the mean for the previous step, given a function cond_fn that
349 | computes the gradient of a conditional log probability with respect to
350 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351 | condition on y.
352 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353 | """
354 | gradient = cond_fn(x, t, **model_kwargs)
355 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356 | return new_mean
357 |
358 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359 | """
360 | Compute what the p_mean_variance output would have been, should the
361 | model's score function be conditioned by cond_fn.
362 | See condition_mean() for details on cond_fn.
363 | Unlike condition_mean(), this instead uses the conditioning strategy
364 | from Song et al. (2020).
365 | """
366 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367 |
368 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370 |
371 | out = p_mean_var.copy()
372 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374 | return out
375 |
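# Both conditioning hooks expect cond_fn(x, t, **model_kwargs) to return the gradient of a
# log-probability with respect to x; classifier guidance is the usual source of such a
# gradient. A hedged sketch of building one from a hypothetical noise-aware classifier
# (`classifier(x, t)` is assumed to return logits of shape (N, num_classes)):
#
#     import torch.nn.functional as F
#
#     def make_classifier_cond_fn(classifier, scale=1.0):
#         def cond_fn(x, t, y=None, **kwargs):
#             assert y is not None
#             with th.enable_grad():
#                 x_in = x.detach().requires_grad_(True)
#                 log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
#                 selected = log_probs[th.arange(y.shape[0]), y]
#                 return th.autograd.grad(selected.sum(), x_in)[0] * scale
#         return cond_fn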
376 | def p_sample(
377 | self,
378 | model,
379 | x,
380 | t,
381 | clip_denoised=True,
382 | denoised_fn=None,
383 | cond_fn=None,
384 | model_kwargs=None,
385 | ):
386 | """
387 | Sample x_{t-1} from the model at the given timestep.
388 | :param model: the model to sample from.
389 | :param x: the current tensor at timestep t (i.e., x_t).
390 | :param t: the value of t, starting at 0 for the first diffusion step.
391 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
392 | :param denoised_fn: if not None, a function which applies to the
393 | x_start prediction before it is used to sample.
394 | :param cond_fn: if not None, this is a gradient function that acts
395 | similarly to the model.
396 | :param model_kwargs: if not None, a dict of extra keyword arguments to
397 | pass to the model. This can be used for conditioning.
398 | :return: a dict containing the following keys:
399 | - 'sample': a random sample from the model.
400 | - 'pred_xstart': a prediction of x_0.
401 | """
402 | out = self.p_mean_variance(
403 | model,
404 | x,
405 | t,
406 | clip_denoised=clip_denoised,
407 | denoised_fn=denoised_fn,
408 | model_kwargs=model_kwargs,
409 | )
410 | noise = th.randn_like(x)
411 | nonzero_mask = (
412 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
413 | ) # no noise when t == 0
414 | if cond_fn is not None:
415 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
416 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
417 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
418 |
419 | def p_sample_loop(
420 | self,
421 | model,
422 | shape,
423 | noise=None,
424 | clip_denoised=True,
425 | denoised_fn=None,
426 | cond_fn=None,
427 | model_kwargs=None,
428 | device=None,
429 | progress=False,
430 | ):
431 | """
432 | Generate samples from the model.
433 | :param model: the model module.
434 | :param shape: the shape of the samples, (N, C, H, W).
435 | :param noise: if specified, the initial noise tensor to start sampling from.
436 | Must have the same shape as `shape`.
437 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
438 | :param denoised_fn: if not None, a function which applies to the
439 | x_start prediction before it is used to sample.
440 | :param cond_fn: if not None, this is a gradient function that acts
441 | similarly to the model.
442 | :param model_kwargs: if not None, a dict of extra keyword arguments to
443 | pass to the model. This can be used for conditioning.
444 | :param device: if specified, the device to create the samples on.
445 | If not specified, use a model parameter's device.
446 | :param progress: if True, show a tqdm progress bar.
447 | :return: a non-differentiable batch of samples.
448 | """
449 | final = None
450 | for sample in self.p_sample_loop_progressive(
451 | model,
452 | shape,
453 | noise=noise,
454 | clip_denoised=clip_denoised,
455 | denoised_fn=denoised_fn,
456 | cond_fn=cond_fn,
457 | model_kwargs=model_kwargs,
458 | device=device,
459 | progress=progress,
460 | ):
461 | final = sample
462 | return final["sample"]
463 |
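# Typical ancestral-sampling usage: run the full reverse chain starting from pure Gaussian
# noise. Hypothetical sketch (assumes a trained `model` and the `diffusion` instance from
# the sketches above):
#
#     samples = diffusion.p_sample_loop(
#         model,
#         shape=(16, 3, 32, 32),
#         clip_denoised=True,
#         model_kwargs={},            # e.g. {"y": labels} for a conditional model
#         device=th.device("cuda"),
#         progress=True,
#     )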
464 | def p_sample_loop_progressive(
465 | self,
466 | model,
467 | shape,
468 | noise=None,
469 | clip_denoised=True,
470 | denoised_fn=None,
471 | cond_fn=None,
472 | model_kwargs=None,
473 | device=None,
474 | progress=False,
475 | ):
476 | """
477 | Generate samples from the model and yield intermediate samples from
478 | each timestep of diffusion.
479 | Arguments are the same as p_sample_loop().
480 | Returns a generator over dicts, where each dict is the return value of
481 | p_sample().
482 | """
483 | if device is None:
484 | device = next(model.parameters()).device
485 | assert isinstance(shape, (tuple, list))
486 | if noise is not None:
487 | img = noise
488 | else:
489 | img = th.randn(*shape, device=device)
490 | indices = list(range(self.num_timesteps))[::-1]
491 |
492 | if progress:
493 | # Lazy import so that we don't depend on tqdm.
494 | from tqdm.auto import tqdm
495 |
496 | indices = tqdm(indices)
497 |
498 | for i in indices:
499 | t = th.tensor([i] * shape[0], device=device)
500 | with th.no_grad():
501 | out = self.p_sample(
502 | model,
503 | img,
504 | t,
505 | clip_denoised=clip_denoised,
506 | denoised_fn=denoised_fn,
507 | cond_fn=cond_fn,
508 | model_kwargs=model_kwargs,
509 | )
510 | yield out
511 | img = out["sample"]
512 |
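# The progressive variant yields the p_sample() dict at every timestep, which is handy for
# visualizing how the running sample and the x_0 prediction sharpen along the chain.
# Hypothetical sketch:
#
#     frames = []
#     for out in diffusion.p_sample_loop_progressive(model, (4, 3, 32, 32), progress=True):
#         frames.append(out["pred_xstart"].cpu())
#     # frames[0] is near-noise; frames[-1] matches the x_0 prediction of the final sample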
513 | def ddim_sample(
514 | self,
515 | model,
516 | x,
517 | t,
518 | clip_denoised=True,
519 | denoised_fn=None,
520 | cond_fn=None,
521 | model_kwargs=None,
522 | eta=0.0,
523 | ):
524 | """
525 | Sample x_{t-1} from the model using DDIM.
526 | Same usage as p_sample().
527 | """
528 | out = self.p_mean_variance(
529 | model,
530 | x,
531 | t,
532 | clip_denoised=clip_denoised,
533 | denoised_fn=denoised_fn,
534 | model_kwargs=model_kwargs,
535 | )
536 | if cond_fn is not None:
537 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
538 |
539 | # Usually our model outputs epsilon, but we re-derive it
540 | # in case we used x_start or x_prev prediction.
541 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
542 |
543 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
544 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
545 | sigma = (
546 | eta
547 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
548 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
549 | )
550 | # Equation 12.
551 | noise = th.randn_like(x)
552 | mean_pred = (
553 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
554 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
555 | )
556 | nonzero_mask = (
557 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
558 | ) # no noise when t == 0
559 | sample = mean_pred + nonzero_mask * sigma * noise
560 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
561 |
562 | def ddim_reverse_sample(
563 | self,
564 | model,
565 | x,
566 | t,
567 | clip_denoised=True,
568 | denoised_fn=None,
569 | cond_fn=None,
570 | model_kwargs=None,
571 | eta=0.0,
572 | ):
573 | """
574 | Sample x_{t+1} from the model using DDIM reverse ODE.
575 | """
576 | assert eta == 0.0, "Reverse ODE only for deterministic path"
577 | out = self.p_mean_variance(
578 | model,
579 | x,
580 | t,
581 | clip_denoised=clip_denoised,
582 | denoised_fn=denoised_fn,
583 | model_kwargs=model_kwargs,
584 | )
585 | if cond_fn is not None:
586 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
587 | # Usually our model outputs epsilon, but we re-derive it
588 | # in case we used x_start or x_prev prediction.
589 | eps = (
590 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
591 | - out["pred_xstart"]
592 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
593 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
594 |
595 | # Equation 12, reversed.
596 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
597 |
598 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
599 |
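# Running ddim_reverse_sample over increasing t deterministically encodes an image into the
# model's noise space (DDIM inversion); feeding the result back through ddim_sample_loop as
# `noise=` approximately reconstructs the input. Hypothetical sketch (assumes `real_images`
# scaled to [-1, 1] and a trained `model`):
#
#     x = real_images
#     for i in range(diffusion.num_timesteps):
#         t = th.full((x.shape[0],), i, device=x.device, dtype=th.long)
#         with th.no_grad():
#             x = diffusion.ddim_reverse_sample(model, x, t, eta=0.0)["sample"]
#     latent = x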
600 | def ddim_sample_loop(
601 | self,
602 | model,
603 | shape,
604 | noise=None,
605 | clip_denoised=True,
606 | denoised_fn=None,
607 | cond_fn=None,
608 | model_kwargs=None,
609 | device=None,
610 | progress=False,
611 | eta=0.0,
612 | ):
613 | """
614 | Generate samples from the model using DDIM.
615 | Same usage as p_sample_loop().
616 | """
617 | final = None
618 | for sample in self.ddim_sample_loop_progressive(
619 | model,
620 | shape,
621 | noise=noise,
622 | clip_denoised=clip_denoised,
623 | denoised_fn=denoised_fn,
624 | cond_fn=cond_fn,
625 | model_kwargs=model_kwargs,
626 | device=device,
627 | progress=progress,
628 | eta=eta,
629 | ):
630 | final = sample
631 | return final["sample"]
632 |
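# DDIM sampling usage mirrors p_sample_loop; eta=0.0 gives the deterministic sampler, while
# eta=1.0 reintroduces ancestral-level noise. In practice this class is typically wrapped by
# the respacing utilities (see respace.py) so far fewer than num_timesteps steps are taken;
# that wrapping is assumed, not shown, in this sketch:
#
#     samples = diffusion.ddim_sample_loop(
#         model,
#         shape=(16, 3, 32, 32),
#         eta=0.0,
#         progress=True,
#     )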
633 | def ddim_sample_loop_progressive(
634 | self,
635 | model,
636 | shape,
637 | noise=None,
638 | clip_denoised=True,
639 | denoised_fn=None,
640 | cond_fn=None,
641 | model_kwargs=None,
642 | device=None,
643 | progress=False,
644 | eta=0.0,
645 | ):
646 | """
647 | Use DDIM to sample from the model and yield intermediate samples from
648 | each timestep of DDIM.
649 | Same usage as p_sample_loop_progressive().
650 | """
651 | if device is None:
652 | device = next(model.parameters()).device
653 | assert isinstance(shape, (tuple, list))
654 | if noise is not None:
655 | img = noise
656 | else:
657 | img = th.randn(*shape, device=device)
658 | indices = list(range(self.num_timesteps))[::-1]
659 |
660 | if progress:
661 | # Lazy import so that we don't depend on tqdm.
662 | from tqdm.auto import tqdm
663 |
664 | indices = tqdm(indices)
665 |
666 | for i in indices:
667 | t = th.tensor([i] * shape[0], device=device)
668 | with th.no_grad():
669 | out = self.ddim_sample(
670 | model,
671 | img,
672 | t,
673 | clip_denoised=clip_denoised,
674 | denoised_fn=denoised_fn,
675 | cond_fn=cond_fn,
676 | model_kwargs=model_kwargs,
677 | eta=eta,
678 | )
679 | yield out
680 | img = out["sample"]
681 |
682 | def _vb_terms_bpd(
683 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
684 | ):
685 | """
686 | Get a term for the variational lower-bound.
687 | The resulting units are bits (rather than nats, as one might expect).
688 | This allows for comparison to other papers.
689 | :return: a dict with the following keys:
690 | - 'output': a shape [N] tensor of NLLs or KLs.
691 | - 'pred_xstart': the x_0 predictions.
692 | """
693 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
694 | x_start=x_start, x_t=x_t, t=t
695 | )
696 | out = self.p_mean_variance(
697 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
698 | )
699 | kl = normal_kl(
700 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
701 | )
702 | kl = mean_flat(kl) / np.log(2.0)
703 |
704 | decoder_nll = -discretized_gaussian_log_likelihood(
705 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
706 | )
707 | assert decoder_nll.shape == x_start.shape
708 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
709 |
710 | # At the first timestep return the decoder NLL,
711 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
712 | output = th.where((t == 0), decoder_nll, kl)
713 | return {"output": output, "pred_xstart": out["pred_xstart"]}
714 |
715 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
716 | """
717 | Compute training losses for a single timestep.
718 | :param model: the model to evaluate loss on.
719 | :param x_start: the [N x C x ...] tensor of inputs.
720 | :param t: a batch of timestep indices.
721 | :param model_kwargs: if not None, a dict of extra keyword arguments to
722 | pass to the model. This can be used for conditioning.
723 | :param noise: if specified, the specific Gaussian noise to try to remove.
724 | :return: a dict with the key "loss" containing a tensor of shape [N].
725 | Some mean or variance settings may also have other keys.
726 | """
727 | if model_kwargs is None:
728 | model_kwargs = {}
729 | if noise is None:
730 | noise = th.randn_like(x_start)
731 | x_t = self.q_sample(x_start, t, noise=noise)
732 |
733 | terms = {}
734 |
735 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
736 | terms["loss"] = self._vb_terms_bpd(
737 | model=model,
738 | x_start=x_start,
739 | x_t=x_t,
740 | t=t,
741 | clip_denoised=False,
742 | model_kwargs=model_kwargs,
743 | )["output"]
744 | if self.loss_type == LossType.RESCALED_KL:
745 | terms["loss"] *= self.num_timesteps
746 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
747 | model_output = model(x_t, t, **model_kwargs)
748 |
749 | if self.model_var_type in [
750 | ModelVarType.LEARNED,
751 | ModelVarType.LEARNED_RANGE,
752 | ]:
753 | B, C = x_t.shape[:2]
754 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
755 | model_output, model_var_values = th.split(model_output, C, dim=1)
756 | # Learn the variance using the variational bound, but don't let
757 | # it affect our mean prediction.
758 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
759 | terms["vb"] = self._vb_terms_bpd(
760 | model=lambda *args, r=frozen_out: r,
761 | x_start=x_start,
762 | x_t=x_t,
763 | t=t,
764 | clip_denoised=False,
765 | )["output"]
766 | if self.loss_type == LossType.RESCALED_MSE:
767 | # Divide by 1000 for equivalence with initial implementation.
768 | # Without a factor of 1/1000, the VB term hurts the MSE term.
769 | terms["vb"] *= self.num_timesteps / 1000.0
770 |
771 | target = {
772 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
773 | x_start=x_start, x_t=x_t, t=t
774 | )[0],
775 | ModelMeanType.START_X: x_start,
776 | ModelMeanType.EPSILON: noise,
777 | }[self.model_mean_type]
778 | assert model_output.shape == target.shape == x_start.shape
779 | terms["mse"] = mean_flat((target - model_output) ** 2)
780 | if "vb" in terms:
781 | terms["loss"] = terms["mse"] + terms["vb"]
782 | else:
783 | terms["loss"] = terms["mse"]
784 | else:
785 | raise NotImplementedError(self.loss_type)
786 |
787 | return terms
788 |
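# A single hypothetical optimization step against training_losses: draw timesteps uniformly
# (the samplers in timestep_sampler.py are the usual alternative), build the loss dict, and
# backpropagate the mean of terms["loss"]. Assumes `model`, `diffusion`, `optimizer`, and a
# `data_iter` yielding images in [-1, 1] are set up elsewhere:
#
#     x0, y = next(data_iter)
#     t = th.randint(0, diffusion.num_timesteps, (x0.shape[0],), device=x0.device)
#     losses = diffusion.training_losses(model, x0, t, model_kwargs={"y": y})
#     loss = losses["loss"].mean()
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()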
789 | def _prior_bpd(self, x_start):
790 | """
791 | Get the prior KL term for the variational lower-bound, measured in
792 | bits-per-dim.
793 | This term can't be optimized, as it only depends on the encoder.
794 | :param x_start: the [N x C x ...] tensor of inputs.
795 | :return: a batch of [N] KL values (in bits), one per batch element.
796 | """
797 | batch_size = x_start.shape[0]
798 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
799 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
800 | kl_prior = normal_kl(
801 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
802 | )
803 | return mean_flat(kl_prior) / np.log(2.0)
804 |
805 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
806 | """
807 | Compute the entire variational lower-bound, measured in bits-per-dim,
808 | as well as other related quantities.
809 | :param model: the model to evaluate loss on.
810 | :param x_start: the [N x C x ...] tensor of inputs.
811 | :param clip_denoised: if True, clip denoised samples.
812 | :param model_kwargs: if not None, a dict of extra keyword arguments to
813 | pass to the model. This can be used for conditioning.
814 | :return: a dict containing the following keys:
815 | - total_bpd: the total variational lower-bound, per batch element.
816 | - prior_bpd: the prior term in the lower-bound.
817 | - vb: an [N x T] tensor of terms in the lower-bound.
818 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
819 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
820 | """
821 | device = x_start.device
822 | batch_size = x_start.shape[0]
823 |
824 | vb = []
825 | xstart_mse = []
826 | mse = []
827 | for t in list(range(self.num_timesteps))[::-1]:
828 | t_batch = th.tensor([t] * batch_size, device=device)
829 | noise = th.randn_like(x_start)
830 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
831 | # Calculate VLB term at the current timestep
832 | with th.no_grad():
833 | out = self._vb_terms_bpd(
834 | model,
835 | x_start=x_start,
836 | x_t=x_t,
837 | t=t_batch,
838 | clip_denoised=clip_denoised,
839 | model_kwargs=model_kwargs,
840 | )
841 | vb.append(out["output"])
842 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
843 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
844 | mse.append(mean_flat((eps - noise) ** 2))
845 |
846 | vb = th.stack(vb, dim=1)
847 | xstart_mse = th.stack(xstart_mse, dim=1)
848 | mse = th.stack(mse, dim=1)
849 |
850 | prior_bpd = self._prior_bpd(x_start)
851 | total_bpd = vb.sum(dim=1) + prior_bpd
852 | return {
853 | "total_bpd": total_bpd,
854 | "prior_bpd": prior_bpd,
855 | "vb": vb,
856 | "xstart_mse": xstart_mse,
857 | "mse": mse,
858 | }
859 |
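# calc_bpd_loop evaluates the full variational bound once per timestep, so it costs
# num_timesteps forward passes per batch; it is an evaluation utility, not a training loss.
# Hypothetical usage:
#
#     bpd = diffusion.calc_bpd_loop(model, x0, clip_denoised=True)
#     print(bpd["total_bpd"].mean().item())   # NLL upper bound in bits/dim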
860 |
861 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
862 | """
863 | Extract values from a 1-D numpy array for a batch of indices.
864 | :param arr: the 1-D numpy array.
865 | :param timesteps: a tensor of indices into the array to extract.
866 | :param broadcast_shape: a larger shape of K dimensions with the batch
867 | dimension equal to the length of timesteps.
868 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
869 | """
870 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
871 | while len(res.shape) < len(broadcast_shape):
872 | res = res[..., None]
873 | return res + th.zeros(broadcast_shape, device=timesteps.device)
874 |
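# _extract_into_tensor is the gather-and-broadcast helper used throughout: it picks one
# scalar per batch element from a per-timestep schedule and reshapes it so it multiplies an
# [N x C x ...] tensor. Hypothetical sketch:
#
#     arr = np.linspace(0.1, 1.0, 10)                       # any per-timestep schedule
#     t = th.tensor([0, 3, 9])
#     vals = _extract_into_tensor(arr, t, (3, 3, 32, 32))
#     assert vals.shape == (3, 3, 32, 32)
#     assert th.allclose(vals[2], th.ones(3, 32, 32))        # arr[9] == 1.0, broadcast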
--------------------------------------------------------------------------------