├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── burgers.png
│   ├── cifar10.png
│   ├── main.png
│   └── pandas.png
├── configs
│   ├── data
│   │   ├── cifar10.yaml
│   │   ├── imagenet32.yaml
│   │   └── imagenet64c.yaml
│   ├── model
│   │   ├── adm.yaml
│   │   ├── ddpmpp.yaml
│   │   └── vdm.yaml
│   ├── train.yaml
│   └── vis.yaml
├── cube.py
├── datasets.py
├── datasets
│   ├── compile_imagenet64.py
│   ├── download_cifar10.py
│   └── download_imagenet32.sh
├── download_pretrained.sh
├── losses.py
├── models
│   ├── __init__.py
│   ├── adm.py
│   ├── ema.py
│   ├── layers.py
│   ├── layerspp.py
│   ├── layersv2.py
│   ├── ncsnpp.py
│   ├── normalization.py
│   ├── up_or_down_sampling.py
│   ├── utils.py
│   └── vdm.py
├── requirements.txt
├── run_train.py
├── run_vis.py
├── sampling.py
├── sde_lib.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | weights/ 2 | run/ 3 | vis/ 4 | datasets/ 5 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aaron Lou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reflected Diffusion Models 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2304.04740-b31b1b.svg)](https://arxiv.org/abs/2304.04740) 4 | [![blog](https://img.shields.io/badge/blogpost-%20-blue?style=social&logo=disqus)](https://aaronlou.com/blog/2023/reflected-diffusion/) 5 | [![twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Ftwitter.com%2Faaron_lou%2Fstatus%2F1646528998594482176%3Fs%3D20)](https://twitter.com/aaron_lou/status/1646528998594482176?s=20) 6 | [![hackernews](https://img.shields.io/badge/hacker%20News-%20-orange?style=social&logo=ycombinator)](https://news.ycombinator.com/item?id=35863309) 7 | [![youtube](https://img.shields.io/badge/youtube-%20-red?style=social&logo=youtube)](https://www.youtube.com/watch?v=YfneSNXJSLE&ab_channel=Valence) 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reflected-diffusion-models/image-generation-on-cifar-10)](https://paperswithcode.com/sota/image-generation-on-cifar-10?p=reflected-diffusion-models) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reflected-diffusion-models/image-generation-on-imagenet-32x32)](https://paperswithcode.com/sota/image-generation-on-imagenet-32x32?p=reflected-diffusion-models) 11 | 12 | This repo contains a PyTorch implementation for the paper [Reflected Diffusion Models](https://arxiv.org/abs/2304.04740) by [Aaron Lou](https://aaronlou.com) and [Stefano Ermon](https://cs.stanford.edu/~ermon/), appearing at ICML 2023. 13 | 14 | ![cover](assets/main.png) 15 | 16 | ## Setup 17 | 18 | Requisite packages can be installed directly from `requirements.txt` (note that this installs PyTorch 2.0.1 with CUDA 11.7 enabled, but this can be downgraded to match your setup): 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | One can also manually install the packages `torch`, `torchvision`, `scipy`, `hydra-core`, and `hydra-submitit-launcher` through `pip`. 23 | 24 | ## Working with Pretrained Models 25 | 26 | ### Download Models 27 | 28 | To download the pretrained models, simply run 29 | ``` 30 | sh download_pretrained.sh 31 | ``` 32 | This creates a `weights` folder with various model weights for CIFAR10 and ImageNet64. Each (main) model directory is of the form 33 | ``` 34 | ├── model_directory 35 | │ ├── .hydra 36 | │ │ ├── config.yaml 37 | │ ├── checkpoints 38 | │ │ ├── checkpoint_*.pth 39 | ``` 40 | which is the minimum file structure needed for visualization and is generated by our training script. 41 | 42 | ### Run Visualization 43 | 44 | We can visualize samples using 45 | ``` 46 | python run_vis.py load_dir=model_directory 47 | ``` 48 | where `model_directory` is a model directory of the form given above, such as `weights/cifar10` or `weights/imagenet64`. This creates a new directory `direc=vis/DATE/TIME` and outputs the images with the following directory structure 49 | ``` 50 | ├── direc 51 | │ ├── images 52 | │ │ ├── *.png 53 | │ │ ├── *.npz 54 | ``` 55 | Arguments can be added with `ARG_NAME=ARG_VALUE`.
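For example, adding `sampling.method=ode eval.batch_size=64` to the command above should sample with the probability flow ODE in batches of 64 (both keys are defined in `configs/vis.yaml`).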
Interesting arguments include 56 | ``` 57 | label imagenet class label (used with imagenet64 classifier-free guided model) 58 | w guidance weight (used with imagenet64 classifier-free guided model) 59 | sampling 60 | method "pc" or "ode" 61 | eval 62 | rounds number of rounds of image generation 63 | batch_size batch size of generated images 64 | ``` 65 | For more information on ImageNet class labels, refer to this [document](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/). As an example, to generate 4 rounds of 16 images of pandas with guidance weight 2.5 using the ODE, simply run 66 | ``` 67 | python run_vis.py load_dir=weights/imagenet64 label=388 w=2.5 sampling.method=ode eval.rounds=4 eval.batch_size=16 68 | ``` 69 | Changing these hyperparameters results in outputs like the following: 70 | 71 | ![pandas](assets/pandas.png) ![burgers](assets/burgers.png) ![cifar10](assets/cifar10.png) 72 | 73 | ## Training New Models 74 | 75 | ### Downloading Data 76 | 77 | To download the datasets used in the paper, `cd` into the `datasets` directory and run the corresponding command: 78 | ``` 79 | python download_cifar10.py 80 | sh download_imagenet32.sh 81 | python compile_imagenet64.py /path/to/imagenet/ 82 | ``` 83 | One can also add new datasets by modifying the `datasets.py` file directly. 84 | 85 | ### Run Training 86 | 87 | We can run training using the command 88 | ``` 89 | python run_train.py 90 | ``` 91 | This creates a new directory `direc=runs/DATE/TIME` with the following structure (compatible with running visualizations) 92 | ``` 93 | ├── direc 94 | │ ├── .hydra 95 | │ │ ├── config.yaml 96 | │ │ ├── ... 97 | │ ├── checkpoints 98 | │ │ ├── checkpoint_*.pth 99 | │ ├── checkpoints-meta 100 | │ │ ├── checkpoint.pth 101 | │ ├── samples 102 | │ │ ├── iter_* 103 | │ │ │ ├── sample_*.png 104 | │ │ │ ├── sample_*.npy 105 | │ ├── logs 106 | ``` 107 | Here, `checkpoints-meta` is used for reloading the run following interruptions, `samples` contains generated images as the run progresses, and `logs` contains the run output. Arguments can be added with `ARG_NAME=ARG_VALUE`, with important ones being: 108 | ``` 109 | ngpus the number of gpus to use in training (with PyTorch DDP) 110 | model one of ddpmpp, vdm, adm, where adm is used with imagenet64c while ddpmpp and vdm are used otherwise. 111 | * various args for the model 112 | data one of cifar10, imagenet32, imagenet64c. 113 | * various args for the data 114 | training 115 | batch_size training batch size 116 | n_iters number of gradient updates 117 | drop_label fraction of labels dropped during training (when using adm and imagenet64c) 118 | sde 119 | sigma_min minimum SDE noise level.
0.01 is used for image generation while 0.0001 is used for likelihood results 120 | optim 121 | lr learning rate of the optimizer 122 | ``` 123 | The commands used for CIFAR10 image generation, CIFAR10 likelihood, ImageNet32 likelihood, and ImageNet64 image generation are respectively given below (on an 80GB GPU): 124 | ``` 125 | python run_train.py data=cifar10 model=ddpmpp 126 | python run_train.py ngpus=4 data=cifar10 data.random_flip=False model=vdm training.n_iters=10000001 sde.sigma_min=0.0001 127 | python run_train.py ngpus=8 data=imagenet32 model=vdm training.n_iters=2000001 training.batch_size=512 sde.sigma_min=0.0001 128 | python run_train.py ngpus=8 data=imagenet64c model=adm training.n_iters=400001 training.batch_size=2048 optim.lr=1e-4 129 | ``` 130 | 131 | ## Other Features 132 | 133 | ### SLURM compatibility 134 | 135 | To run on SLURM, simply uncomment the last few lines of `configs/train.yaml` or `configs/vis.yaml` and fill in your cluster details in place of `null`. Then, run the corresponding train/visualization commands with `-m`, e.g. 136 | ``` 137 | python run_train.py -m 138 | ``` 139 | 140 | ## Citation 141 | ``` 142 | @inproceedings{lou2023reflected, 143 | title={Reflected Diffusion Models}, 144 | author={Aaron Lou and Stefano Ermon}, 145 | booktitle={International Conference on Machine Learning}, 146 | year={2023}, 147 | organization={PMLR} 148 | } 149 | ``` 150 | ## Acknowledgements 151 | 152 | This repository builds heavily on [score sde](https://github.com/yang-song/score_sde_pytorch) and [edm](https://github.com/NVlabs/edm). -------------------------------------------------------------------------------- /assets/burgers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/Reflected-Diffusion/dc4402607b7a4f5302f4b325ecfe2095b693b26f/assets/burgers.png -------------------------------------------------------------------------------- /assets/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/Reflected-Diffusion/dc4402607b7a4f5302f4b325ecfe2095b693b26f/assets/cifar10.png -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/Reflected-Diffusion/dc4402607b7a4f5302f4b325ecfe2095b693b26f/assets/main.png -------------------------------------------------------------------------------- /assets/pandas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/Reflected-Diffusion/dc4402607b7a4f5302f4b325ecfe2095b693b26f/assets/pandas.png -------------------------------------------------------------------------------- /configs/data/cifar10.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | image_size: 32 3 | random_flip: True 4 | num_channels: 3 5 | classes: False -------------------------------------------------------------------------------- /configs/data/imagenet32.yaml: -------------------------------------------------------------------------------- 1 | dataset: ImageNet32 2 | image_size: 32 3 | num_channels: 3 4 | classes: False -------------------------------------------------------------------------------- /configs/data/imagenet64c.yaml: -------------------------------------------------------------------------------- 1 | dataset:
ImageNet64C 2 | image_size: 64 3 | num_channels: 3 4 | classes: True 5 | num_classes: 1000 -------------------------------------------------------------------------------- /configs/model/adm.yaml: -------------------------------------------------------------------------------- 1 | name: adm 2 | model_channels: 192 3 | channel_mult: [1, 2, 3, 4] 4 | channel_mult_emb: 4 5 | num_blocks: 3 6 | attn_resolutions: [32, 16, 8] 7 | 8 | dropout: 0.1 9 | ema_rate: 0.9999 10 | scale_by_sigma: True 11 | -------------------------------------------------------------------------------- /configs/model/ddpmpp.yaml: -------------------------------------------------------------------------------- 1 | dropout: 0.1 2 | 3 | name: ncsnpp 4 | scale_by_sigma: True 5 | ema_rate: 0.9999 6 | normalization: GroupNorm 7 | nonlinearity: swish 8 | nf: 128 9 | ch_mult: [1, 2, 2, 2] 10 | num_res_blocks: 8 11 | attn_resolutions: [16,] 12 | resamp_with_conv: True 13 | conditional: True 14 | fir: False 15 | fir_kernel: [1, 3, 3, 1] 16 | skip_rescale: True 17 | resblock_type: biggan 18 | progressive: none 19 | progressive_input: residual 20 | progressive_combine: sum 21 | attention_type: ddpm 22 | init_scale: 0. 23 | embedding_type: fourier 24 | fourier_scale: 16 25 | conv_size: 3 26 | -------------------------------------------------------------------------------- /configs/model/vdm.yaml: -------------------------------------------------------------------------------- 1 | name: vdm 2 | channels: 128 3 | num_blocks: 32 4 | 5 | dropout: 0.1 6 | ema_rate: 0.9999 7 | scale_by_sigma: True 8 | 9 | image_fourier: True 10 | image_fourier_start: 6 11 | image_fourier_end: 8 12 | 13 | attention: False 14 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: ddpmpp 4 | - data: cifar10 5 | - override hydra/launcher: submitit_slurm 6 | 7 | ngpus: 1 8 | dataroot: datasets 9 | 10 | training: 11 | batch_size: 128 12 | n_iters: 1300001 13 | snapshot_freq: 50000 14 | log_freq: 50 15 | eval_freq: 100 16 | snapshot_freq_for_preemption: 10000 17 | snapshot_sampling: True 18 | likelihood_weighting: False 19 | reduce_mean: False 20 | drop_label: 0.2 21 | 22 | eval: 23 | batch_size: 256 24 | 25 | sde: 26 | name: vesde 27 | sigma_min: 0.01 28 | sigma_max: 5 29 | num_scales: 1000 30 | 31 | sampling: 32 | n_steps_each: 1 33 | noise_removal: True 34 | probability_flow: False 35 | snr: 0.01 36 | method: pc 37 | predictor: euler_maruyama 38 | corrector: none 39 | denoiser: none 40 | 41 | optim: 42 | weight_decay: 0 43 | optimizer: Adam 44 | lr: 2e-4 45 | beta1: 0.9 46 | beta2: 0.999 47 | eps: 1e-8 48 | warmup: 5000 49 | grad_clip: 1. 50 | 51 | 52 | hydra: 53 | run: 54 | dir: runs/${data.dataset}/${now:%Y.%m.%d}/${now:%H%M%S} 55 | sweep: 56 | dir: runs/${data.dataset}/${now:%Y.%m.%d}/${now:%H%M%S} 57 | subdir: ${hydra.job.num} 58 | # launcher: 59 | # max_num_timeout: null 60 | # timeout_min: null 61 | # partition: null 62 | # account: null 63 | # mem_gb: null 64 | # cpus_per_task: null 65 | # gpus_per_node: ${ngpus} 66 | # constraint: null 67 | -------------------------------------------------------------------------------- /configs/vis.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | load_dir: ??? 
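# ??? is OmegaConf's mandatory-value marker: load_dir must be supplied on the command line, e.g. load_dir=weights/cifar10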
6 | w: 0 7 | label: 388 8 | 9 | sampling: 10 | method: pc 11 | n_steps_each: 1 12 | noise_removal: True 13 | snr: 0.01 14 | predictor: euler_maruyama 15 | corrector: none 16 | denoiser: none 17 | moll: 200 18 | side_eps: 0.01 19 | 20 | denoiser_path: null 21 | 22 | eval: 23 | ckpt: -1 24 | batch_size: 100 25 | rounds: 1 26 | 27 | hydra: 28 | run: 29 | dir: vis/${now:%Y.%m.%d}/${now:%H%M%S} 30 | sweep: 31 | dir: vis/${now:%Y.%m.%d}/${now:%H%M%S} 32 | subdir: ${hydra.job.num} 33 | # launcher: 34 | # max_num_timeout: null 35 | # timeout_min: null 36 | # partition: null 37 | # account: null 38 | # mem_gb: null 39 | # cpus_per_task: null 40 | # gpus_per_node: ${ngpus} 41 | # constraint: null 42 | -------------------------------------------------------------------------------- /cube.py: -------------------------------------------------------------------------------- 1 | """Helper functions for the unit hypercube [0, 1]^D""" 2 | import torch 3 | from math import pi 4 | 5 | def unsqueeze_as(x, y, back=True): 6 | """ 7 | Unsqueeze x to have as many dimensions as y. For example, tensor shapes: 8 | 9 | x: (a, b, c), y: (a, b, c, d, e) -> output: (a, b, c, 1, 1) 10 | """ 11 | if back: 12 | return x.view(*x.shape, *((1,) * (len(y.shape) - len(x.shape)))) 13 | else: 14 | return x.view(*((1,) * (len(y.shape) - len(x.shape))), *x.shape) 15 | 16 | 17 | def inside(x): 18 | """ 19 | Checks if x is inside the unit hypercube, batchwise 20 | 21 | Args 22 | ---- 23 | x (Tensor): 24 | input of shape [B, ...] 25 | 26 | Returns 27 | ------- 28 | an output Tensor of shape [B] corresponding to whether each x[i] is in the cube 29 | """ 30 | x = x.flatten(1) 31 | return torch.logical_and(x >= 0, x <= 1).all(dim=-1) 32 | 33 | 34 | def reflect(x): 35 | """ 36 | Performs reflections until x is inside the domain. 37 | 38 | Args 39 | ---- 40 | x (Tensor): 41 | input of shape [B, ...] 42 | 43 | Returns 44 | ------- 45 | an output Tensor with the same shape as x which is the "reflected"-inside version. 46 | """ 47 | xm2 = x % 2 48 | xm2[xm2 > 1] = 2 - xm2[xm2 > 1] 49 | return xm2 50 | 51 | 52 | def sample_hk(x, sigma): 53 | """ 54 | Sample from the heat kernel starting at point x with coefficient sigma. 55 | 56 | Args 57 | ---- 58 | x (Tensor): 59 | input of shape [B, ...]. Corresponds to the pseudo-"mean" or "starting point". 60 | sigma (Tensor): 61 | input of shape [B]. Corresponds to the std dev of the underlying Gaussian; 62 | equivalently, sigma**2 / 2 is the time t of the heat equation PDE. 63 | Returns 64 | ------- 65 | an output Tensor with the same shape as x corresponding to a random sample. 66 | """ 67 | if not torch.is_tensor(sigma): 68 | sigma = sigma * torch.ones(x.shape[0]).to(x) 69 | samples_gauss = torch.randn_like(x) * unsqueeze_as(sigma, x) + x 70 | return reflect(samples_gauss) 71 | 72 | 73 | def _score_hk_ef(x, x_orig, t, efs=20): 74 | """ 75 | Computes the score of the heat kernel using eigenfunctions. 76 | 77 | Args 78 | ---- 79 | x (Tensor): 80 | shape [B, ...]. Corresponds to the sampled point. 81 | x_orig (Tensor): 82 | shape [B, ...] same as x. Corresponds to the origin/pseudo-mean. 83 | t (Tensor): 84 | shape [B]. Time of the heat equation PDE. 85 | efs (int): 86 | number of eigenfunctions to compute with 87 | 88 | Returns 89 | ------- 90 | an output tensor of the same shape as x corresponding to the score of the heat kernel.
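Note
----
For reference, the series evaluated below is the Neumann heat kernel on [0, 1],

    p_t(x, y) = 1 + 2 * sum_{n=1}^{efs} exp(-n^2 pi^2 t) cos(n pi x) cos(n pi y)

with y = x_orig, and the returned value is its log-derivative in x,

    d/dx log p_t = -2 pi * sum_n n exp(-n^2 pi^2 t) sin(n pi x) cos(n pi y) / p_t,

applied coordinatewise (a small epsilon guards the division).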
91 | """ 92 | eval_range = torch.arange(1, efs + 1).to(x) 93 | 94 | x_rescaled = pi * x.unsqueeze(0) * unsqueeze_as(eval_range, x.unsqueeze(0)) 95 | x_orig_rescaled = pi * x_orig.unsqueeze(0) * unsqueeze_as(eval_range, x_orig.unsqueeze(0)) 96 | 97 | x_sin = x_rescaled.sin() 98 | x_cos = x_rescaled.cos() 99 | x_orig_cos = x_orig_rescaled.cos() 100 | 101 | e_powers_denom = (-t.unsqueeze(0) * eval_range.unsqueeze(-1).pow(2) * (pi ** 2)).exp() 102 | e_powers_num = e_powers_denom * eval_range.unsqueeze(-1) 103 | 104 | num = - 2 * pi * (unsqueeze_as(e_powers_num, x_sin) * (x_sin * x_orig_cos)).sum(0) 105 | denom = 1 + 2 * (unsqueeze_as(e_powers_denom, x_sin) * (x_cos * x_orig_cos)).sum(0) 106 | 107 | return (num / (denom + 1e-12)) 108 | 109 | 110 | def _score_hk_refl(x, x_orig, t, refls=2): 111 | """ 112 | Computes the score of the heat kernel using reflection. 113 | 114 | Args 115 | ---- 116 | x (Tensor): 117 | shape [B, ...]. Corresponds to the sampled point. 118 | x_orig (Tensor): 119 | shape [B, ...] same as x. Corresponds to the origin/pseudo-mean. 120 | t (Tensor): 121 | shape [B]. Time of the heat flow PDE. 122 | refls (int): 123 | number of reflections to sum up. 124 | 125 | Returns 126 | ------- 127 | an output tensor of the same shape as x corresponding to the score of the heat kernel. 128 | """ 129 | refls = torch.arange(-2 * refls, 2 * refls + 1, 2).to(x) 130 | 131 | x_refl = torch.cat(( 132 | unsqueeze_as(refls, x.unsqueeze(0)) + x.unsqueeze(0), 133 | unsqueeze_as(refls, x.unsqueeze(0)) - x.unsqueeze(0) 134 | ), dim=0) 135 | refl_sign = torch.cat((torch.ones_like(refls), -torch.ones_like(refls)), dim=0) 136 | 137 | x_minus = x_refl - x_orig.unsqueeze(0) 138 | fourt = (4 * unsqueeze_as(t.unsqueeze(0), x_minus)) 139 | 140 | denom_coeff = - 2 * x_minus / fourt 141 | e_powers = (- x_minus.pow(2) / fourt).exp() 142 | 143 | num = (denom_coeff * e_powers * unsqueeze_as(refl_sign, e_powers)).sum(0) 144 | denom = e_powers.sum(0) 145 | 146 | return (num/ (denom + 1e-12)) 147 | 148 | 149 | def score_hk(x, x_orig, sigma, efs=20, refls=10, min_cutoff=1e-2): 150 | """ 151 | Computes the score of the heat kernel using eigenfunctions. 152 | 153 | Args 154 | ---- 155 | x (Tensor): 156 | shape [B, ...]. Corresponds to the sampled point. 157 | x_orig (Tensor): 158 | shape [B, ...] same as x. Corresponds to the origin/pseudo-mean. 159 | sigma (Tensor): 160 | shape [B]. Std dev of the underlying Guassian 161 | efs (int): 162 | see _score_hk_ef 163 | refls (int): 164 | see _score_hk_refl 165 | min_cutoff (float): 166 | value such that below computes with refls and above with efs 167 | 168 | Returns 169 | ------- 170 | an output tensor of the same shape as x corresponding to the score of the heat kernel. 
171 | """ 172 | t = sigma ** 2 / 2 173 | if not torch.is_tensor(t): 174 | t = t * torch.ones(x.shape[0]).to(x) 175 | 176 | ef_cond = t > min_cutoff 177 | x_ef = x[ef_cond] 178 | x_orig_ef = x_orig[ef_cond] 179 | t_ef = t[ef_cond] 180 | 181 | refl_cond = torch.logical_not(ef_cond) 182 | x_refl = x[refl_cond] 183 | x_orig_refl = x_orig[refl_cond] 184 | t_refl = t[refl_cond] 185 | 186 | scores_ef = _score_hk_ef(x_ef, x_orig_ef, t_ef, efs=efs) 187 | scores_refl = _score_hk_refl(x_refl, x_orig_refl, t_refl, refls=refls) 188 | 189 | scores = torch.zeros_like(x) 190 | scores[ef_cond] = scores_ef 191 | scores[refl_cond] = scores_refl 192 | 193 | return scores 194 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | """Return training and evaluation/test datasets from config files.""" 2 | import json 3 | import os 4 | import os.path 5 | import pickle 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torchvision.datasets as vdsets 12 | import torchvision.transforms as transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | from PIL import Image 15 | from torch.utils.data import Dataset, DataLoader, DistributedSampler, TensorDataset 16 | 17 | 18 | def identity(x): 19 | return x 20 | 21 | def cycle_loader(dataloader, sampler=None): 22 | while 1: 23 | if sampler is not None: 24 | sampler.set_epoch(np.random.randint(0, 100000)) 25 | for data in dataloader: 26 | yield data 27 | 28 | # fast class to load all images 29 | class ImageFolderFast(vdsets.VisionDataset): 30 | def __init__(self, root, transform=None): 31 | super().__init__(root, transform=transform) 32 | self.image_paths = os.listdir(root) 33 | self.transform = transform 34 | 35 | def __getitem__(self, index): 36 | image_path = os.path.join(self.root, self.image_paths[index]) 37 | with open(image_path, "rb") as f: 38 | img = Image.open(f) 39 | x = img.convert("RGB") 40 | if self.transform is not None: 41 | x = self.transform(x) 42 | return x, #needed to make it consistent: index dataset[0][0] for image 43 | 44 | def __len__(self): 45 | return len(self.image_paths) 46 | 47 | # fast class to load all images 48 | class ImageFolderClassFast(vdsets.VisionDataset): 49 | def __init__(self, root, transform=None): 50 | super().__init__(root, transform=transform) 51 | with open(os.path.join(root, "dataset.json"), "r") as f: 52 | self.image_paths = json.load(f)["labels"] 53 | self.transform = transform 54 | 55 | def __getitem__(self, index): 56 | pair = self.image_paths[index] 57 | image_path = os.path.join(self.root, pair[0]) 58 | with open(image_path, "rb") as f: 59 | img = Image.open(f) 60 | x = img.convert("RGB") 61 | if self.transform is not None: 62 | x = self.transform(x) 63 | return x, pair[1] 64 | 65 | def __len__(self): 66 | return len(self.image_paths) 67 | 68 | 69 | def get_dataset(config, evaluation=False, distributed=True): 70 | 71 | dataroot = config.dataroot 72 | if config.data.dataset == "CIFAR10": 73 | 74 | train_transforms = transforms.Compose( 75 | [ 76 | transforms.Resize(config.data.image_size), 77 | transforms.RandomHorizontalFlip() if config.data.random_flip else identity, 78 | transforms.ToTensor(), 79 | ] 80 | ) 81 | test_transforms = transforms.Compose( 82 | [ 83 | transforms.Resize(config.data.image_size), 84 | transforms.ToTensor(), 85 | ] 86 | ) 87 | 88 | train_set = vdsets.CIFAR10(dataroot, train=True, 
transform=train_transforms) 89 | test_set = vdsets.CIFAR10(dataroot, train=False, transform=test_transforms) 90 | workers = 2 91 | elif config.data.dataset == "ImageNet32": 92 | data_transforms = transforms.Compose( 93 | [ 94 | transforms.ToTensor(), 95 | ] 96 | ) 97 | train_set = ImageFolderFast(os.path.join(dataroot, "ds_imagenet", "train_32x32"), transform=data_transforms) 98 | test_set = ImageFolderFast(os.path.join(dataroot, "ds_imagenet", "valid_32x32"), transform=data_transforms) 99 | workers = 4 100 | elif config.data.dataset == "ImageNet64C": 101 | data_transforms = transforms.Compose( 102 | [ 103 | transforms.ToTensor(), 104 | ] 105 | ) 106 | train_set = ImageFolderClassFast(os.path.join(dataroot, "imagenet-64x64", "train"), transform=data_transforms) 107 | test_set = ImageFolderClassFast(os.path.join(dataroot, "imagenet-64x64", "valid"), transform=data_transforms) 108 | workers = 4 109 | else: 110 | raise ValueError(f"{config.data.dataset} is not valid") 111 | 112 | if evaluation: 113 | if distributed: 114 | sampler = DistributedSampler(test_set, shuffle=False) 115 | else: 116 | sampler = None 117 | 118 | test_loader = DataLoader( 119 | test_set, 120 | batch_size=config.eval.batch_size, 121 | sampler=sampler, 122 | num_workers=workers, 123 | pin_memory=True, 124 | shuffle=(sampler is None) 125 | ) 126 | 127 | return test_loader 128 | else: 129 | if config.training.batch_size % config.ngpus != 0: 130 | raise ValueError(f"Train Batch Size {config.training.batch_size} is not divisible by {config.ngpus} gpus.") 131 | if config.eval.batch_size % config.ngpus != 0: 132 | raise ValueError(f"Eval Batch Size {config.eval.batch_size} is not divisible by {config.ngpus} gpus.") 133 | 134 | if distributed: 135 | train_sampler = DistributedSampler(train_set) 136 | test_sampler = DistributedSampler(test_set) 137 | else: 138 | train_sampler = None 139 | test_sampler = None 140 | 141 | train_loader = DataLoader( 142 | train_set, 143 | batch_size=config.training.batch_size // config.ngpus, 144 | sampler=train_sampler, 145 | num_workers=workers, 146 | pin_memory=True, 147 | shuffle=(train_sampler is None), 148 | persistent_workers=True if workers > 0 else False, 149 | ) 150 | test_loader = DataLoader( 151 | test_set, 152 | batch_size=config.eval.batch_size // config.ngpus, 153 | sampler=test_sampler, 154 | num_workers=workers, 155 | pin_memory=True, 156 | shuffle=(test_sampler is None), 157 | ) 158 | 159 | train_loader, test_loader = cycle_loader(train_loader, train_sampler), cycle_loader(test_loader, test_sampler) 160 | return train_loader, test_loader 161 | -------------------------------------------------------------------------------- /datasets/compile_imagenet64.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import PIL.Image 4 | import sys 5 | from pathlib import Path 6 | 7 | PIL.Image.init() 8 | source_dir = sys.argv[1] 9 | output_dir = "imagenet-64x64" 10 | 11 | def file_ext(name): 12 | return str(name).split('.')[-1] 13 | 14 | def is_image_ext(fname): 15 | ext = file_ext(fname).lower() 16 | return f'.{ext}' in PIL.Image.EXTENSION 17 | 18 | if not os.path.exists(output_dir): 19 | os.makedirs(output_dir) 20 | 21 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] 22 | 23 | print(input_images[0:10]) 24 | toplevel_names = [os.path.relpath(fname, source_dir).split('/')[0] for fname in input_images] 25 | print(toplevel_names[0:10]) 26 | toplevel_indices = 
{toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names)))} 27 | 28 | def get_name(idx): 29 | idx_str = f'{idx:08d}' 30 | return f'{idx_str[:5]}/img{idx_str}.png' 31 | labels = [[get_name(i), toplevel_indices[toplevel_name]] for i, toplevel_name in enumerate(toplevel_names)] 32 | 33 | print(len(labels)) 34 | print(labels[0:10]) 35 | 36 | metadata = {'labels': labels} 37 | data = json.dumps(metadata) 38 | 39 | with open(os.path.join(output_dir, "dataset.json"), 'wb') as fout: 40 | if isinstance(data, str): 41 | data = data.encode('utf8') 42 | fout.write(data) 43 | -------------------------------------------------------------------------------- /datasets/download_cifar10.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | torchvision.datasets.CIFAR10('.', download=True) -------------------------------------------------------------------------------- /datasets/download_imagenet32.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir ds_imagenet 3 | cd ds_imagenet 4 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-08kPTbCYHhFcwerMbZpiYFWCbbtWR17' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-08kPTbCYHhFcwerMbZpiYFWCbbtWR17" -O train_32x32.tar && rm -rf /tmp/cookies.txt 5 | tar -xf train_32x32.tar 6 | rm train_32x32.tar 7 | 8 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=14CNPjwnkYAFXI77YYHSb0qa8HwIohxnh' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=14CNPjwnkYAFXI77YYHSb0qa8HwIohxnh" -O valid_32x32.tar && rm -rf /tmp/cookies.txt 9 | tar -xf valid_32x32.tar 10 | rm valid_32x32.tar 11 | cd .. 
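# The extracted ds_imagenet/train_32x32 and ds_imagenet/valid_32x32 folders are the paths that datasets.py expects under the data root.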
12 | -------------------------------------------------------------------------------- /download_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir weights 3 | cd weights 4 | 5 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1AYPr0R8-3CssADBfYYSi1JuYaVrpLkTm' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1AYPr0R8-3CssADBfYYSi1JuYaVrpLkTm" -O cifar10.tar.gz && rm -rf /tmp/cookies.txt 6 | tar -xvzf cifar10.tar.gz 7 | rm cifar10.tar.gz 8 | 9 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1CL5tM-SO4vn6tyXzrFh7VBzQv3jXDI6X' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1CL5tM-SO4vn6tyXzrFh7VBzQv3jXDI6X" -O denoiser.tar.gz && rm -rf /tmp/cookies.txt 10 | tar -xvzf denoiser.tar.gz 11 | rm denoiser.tar.gz 12 | 13 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1e177im3rwI1rsHcQ5wAsaCKBKcDYRllf' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1e177im3rwI1rsHcQ5wAsaCKBKcDYRllf" -O imagenet64.tar.gz && rm -rf /tmp/cookies.txt 14 | tar -xvzf imagenet64.tar.gz 15 | rm imagenet64.tar.gz 16 | 17 | cd .. -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """All functions related to loss computation and optimization. 
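The training loss is denoising score matching on the unit cube: data is perturbed
with the forward SDE and reflected back into [0, 1]^D, and the model score is
regressed onto the analytic heat-kernel score cube.score_hk, weighted by
sigma(t)^2 (or by the diffusion coefficient g(t)^2 when likelihood weighting is
enabled).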
2 | """ 3 | 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | from models import utils as mutils 8 | 9 | import cube 10 | 11 | 12 | def get_optimizer(config, params): 13 | if config.optim.optimizer == 'Adam': 14 | optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, config.optim.beta2), eps=config.optim.eps, 15 | weight_decay=config.optim.weight_decay) 16 | elif config.optim.optimizer == 'AdamW': 17 | optimizer = optim.AdamW(params, lr=config.optim.lr, betas=(config.optim.beta1, config.optim.beta2), eps=config.optim.eps, 18 | weight_decay=config.optim.weight_decay) 19 | else: 20 | raise NotImplementedError( 21 | f'Optimizer {config.optim.optimizer} not supported yet!') 22 | 23 | return optimizer 24 | 25 | 26 | def optimization_manager(config): 27 | """Returns an optimize_fn based on `config`.""" 28 | 29 | def optimize_fn(optimizer, params, step, lr=config.optim.lr, 30 | warmup=config.optim.warmup, 31 | grad_clip=config.optim.grad_clip, 32 | scaler=None): 33 | """Optimizes with warmup and gradient clipping (disabled if negative).""" 34 | if scaler is not None: 35 | scaler.unscale_(optimizer) 36 | 37 | if warmup > 0: 38 | for g in optimizer.param_groups: 39 | g['lr'] = lr * np.minimum(step / warmup, 1.0) 40 | if grad_clip >= 0: 41 | torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) 42 | 43 | if scaler is None: 44 | optimizer.step() 45 | else: 46 | scaler.step(optimizer) 47 | scaler.update() 48 | 49 | return optimize_fn 50 | 51 | 52 | def get_sde_loss_fn(sde, train, reduce_mean=True, likelihood_weighting=True, eps=1e-5): 53 | """Create a loss function for training with arbitrary SDEs. 54 | 55 | Args: 56 | sde: An `sde_lib.SDE` object that represents the forward SDE. 57 | train: `True` for training loss and `False` for evaluation loss. 58 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 59 | likelihood_weighting: If `True`, outputs the diffusion variational bound term. 60 | eps: A `float` number. The smallest time step to sample from. 61 | 62 | Returns: 63 | A loss function. 64 | """ 65 | reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * \ 66 | torch.sum(*args, **kwargs) 67 | 68 | def loss_fn(model, batch, class_labels=None): 69 | """Compute the loss function. 70 | 71 | Args: 72 | model: A score model. 73 | batch: A mini-batch of training data. 74 | 75 | Returns: 76 | loss: A scalar that represents the average loss value across the mini-batch. 77 | """ 78 | score_fn = mutils.get_score_fn(sde, model, train=train) 79 | t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps 80 | z = torch.randn_like(batch) 81 | 82 | mean, std = sde.marginal_prob(batch, t) 83 | perturbed_data = cube.reflect(mean + std[:, None, None, None] * z) 84 | score = score_fn(perturbed_data, t, class_labels=class_labels) 85 | score_hk = cube.score_hk(perturbed_data, mean, std) 86 | 87 | if not likelihood_weighting: 88 | losses = (std ** 2)[:, None, None, None] * (score - score_hk).pow(2) 89 | else: 90 | g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2 91 | losses = g2[:, None, None, None] * (score - score_hk).pow(2) 92 | 93 | losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) 94 | loss = torch.mean(losses) 95 | return loss 96 | 97 | return loss_fn 98 | 99 | 100 | def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, likelihood_weighting=False): 101 | """Create a one-step training/evaluation function. 
102 | 103 | Args: 104 | sde: An `sde_lib.SDE` object that represents the forward SDE. 105 | train: `True` for a training function and `False` for an evaluation function. 106 | optimize_fn: An optimization function. 107 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 108 | likelihood_weighting: If `True`, outputs the diffusion variational bound term. 109 | 110 | Returns: 111 | A one-step function for training or evaluation. 112 | """ 113 | loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, likelihood_weighting=likelihood_weighting) 114 | 115 | def step_fn(state, batch, class_labels=None): 116 | """Run one step of training or evaluation. 117 | 118 | When training, this performs the forward/backward pass, the optimizer step, and the EMA update. 119 | 120 | Args: 121 | state: A dictionary of training information, containing the score model, optimizer, 122 | EMA status, and number of optimization steps. 123 | batch: A mini-batch of training/evaluation data. 124 | 125 | Returns: 126 | loss: The average loss value over the mini-batch. 127 | """ 128 | model = state['model'] 129 | if train: 130 | optimizer = state['optimizer'] 131 | optimizer.zero_grad() 132 | loss = loss_fn(model, batch, class_labels=class_labels) 133 | if state['scaler'] is None: 134 | loss.backward() 135 | else: 136 | state['scaler'].scale(loss).backward() 137 | optimize_fn(optimizer, model.parameters(), step=state['step'], scaler=state['scaler']) 138 | state['step'] += 1 139 | state['ema'].update(model.parameters()) 140 | else: 141 | with torch.no_grad(): 142 | ema = state['ema'] 143 | ema.store(model.parameters()) 144 | ema.copy_to(model.parameters()) 145 | loss = loss_fn(model, batch, class_labels=class_labels) 146 | ema.restore(model.parameters()) 147 | 148 | return loss 149 | 150 | return step_fn 151 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/Reflected-Diffusion/dc4402607b7a4f5302f4b325ecfe2095b693b26f/models/__init__.py -------------------------------------------------------------------------------- /models/adm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.nn.functional import silu 6 | 7 | from .layersv2 import * 8 | 9 | from .
import utils 10 | 11 | 12 | class UNetBlock(nn.Module): 13 | def __init__(self, 14 | in_channels, out_channels, emb_channels, up=False, down=False, attention=False, 15 | num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5, 16 | resample_filter=[1,1], resample_proj=False, adaptive_scale=True, 17 | init=dict(), init_zero=dict(init_weight=0), init_attn=None, 18 | ): 19 | super().__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.emb_channels = emb_channels 23 | self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head 24 | self.dropout = dropout 25 | self.skip_scale = skip_scale 26 | self.adaptive_scale = adaptive_scale 27 | 28 | self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) 29 | self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init) 30 | self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init) 31 | self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) 32 | self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero) 33 | 34 | self.skip = None 35 | if out_channels != in_channels or up or down: 36 | kernel = 1 if resample_proj or out_channels!= in_channels else 0 37 | self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init) 38 | 39 | if self.num_heads: 40 | self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) 41 | self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init)) 42 | self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero) 43 | 44 | def forward(self, x, emb): 45 | orig = x 46 | x = self.conv0(silu(self.norm0(x))) 47 | 48 | params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) 49 | if self.adaptive_scale: 50 | scale, shift = params.chunk(chunks=2, dim=1) 51 | x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) 52 | else: 53 | x = silu(self.norm1(x.add_(params))) 54 | 55 | x = self.conv1(F.dropout(x, p=self.dropout, training=self.training)) 56 | x = x.add_(self.skip(orig) if self.skip is not None else orig) 57 | x = x * self.skip_scale 58 | 59 | if self.num_heads: 60 | q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2) 61 | w = AttentionOp.apply(q, k) 62 | a = torch.einsum('nqk,nck->ncq', w, v) 63 | x = self.proj(a.reshape(*x.shape)).add_(x) 64 | x = x * self.skip_scale 65 | return x 66 | 67 | 68 | class ADM(nn.Module): 69 | def __init__(self, 70 | img_resolution=64, 71 | in_channels=3, 72 | out_channels=3, 73 | label_dim=0, 74 | augment_dim=0, 75 | model_channels=192, 76 | channel_mult=[1,2,3,4], 77 | channel_mult_emb=4, 78 | num_blocks=3, 79 | attn_resolutions=[32,16,8], 80 | dropout=0.10, 81 | label_dropout=0, 82 | ): 83 | super().__init__() 84 | self.label_dropout = label_dropout 85 | emb_channels = model_channels * channel_mult_emb 86 | init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3)) 87 | init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0) 88 | block_kwargs = dict(emb_channels=emb_channels, channels_per_head=64, dropout=dropout, init=init, init_zero=init_zero) 89 | 90 | # Mapping. 
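# The noise level is embedded with a sinusoidal positional embedding and refined by a
# two-layer MLP; augmentation and class-label embeddings, when present, are added in (see forward).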
91 | self.map_noise = PositionalEmbedding(num_channels=model_channels) 92 | self.map_augment = Linear(in_features=augment_dim, out_features=model_channels, bias=False, **init_zero) if augment_dim else None 93 | self.map_layer0 = Linear(in_features=model_channels, out_features=emb_channels, **init) 94 | self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init) 95 | self.map_label = Linear(in_features=label_dim, out_features=emb_channels, bias=False, init_mode='kaiming_normal', init_weight=np.sqrt(label_dim)) if label_dim else None 96 | 97 | # Encoder. 98 | self.enc = nn.ModuleDict() 99 | cout = in_channels 100 | for level, mult in enumerate(channel_mult): 101 | res = img_resolution >> level 102 | if level == 0: 103 | cin = cout 104 | cout = model_channels * mult 105 | self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init) 106 | else: 107 | self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs) 108 | for idx in range(num_blocks): 109 | cin = cout 110 | cout = model_channels * mult 111 | self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs) 112 | skips = [block.out_channels for block in self.enc.values()] 113 | 114 | # Decoder. 115 | self.dec = nn.ModuleDict() 116 | for level, mult in reversed(list(enumerate(channel_mult))): 117 | res = img_resolution >> level 118 | if level == len(channel_mult) - 1: 119 | self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs) 120 | self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs) 121 | else: 122 | self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs) 123 | for idx in range(num_blocks + 1): 124 | cin = cout + skips.pop() 125 | cout = model_channels * mult 126 | self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs) 127 | self.out_norm = GroupNorm(num_channels=cout) 128 | self.out_conv = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero) 129 | 130 | def forward(self, x, noise_labels, class_labels, augment_labels=None): 131 | # Mapping. 132 | emb = self.map_noise(noise_labels) 133 | if self.map_augment is not None and augment_labels is not None: 134 | emb = emb + self.map_augment(augment_labels) 135 | emb = silu(self.map_layer0(emb)) 136 | emb = self.map_layer1(emb) 137 | if self.map_label is not None: 138 | tmp = class_labels 139 | if self.training and self.label_dropout: 140 | tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype) 141 | emb = emb + self.map_label(tmp) 142 | emb = silu(emb) 143 | 144 | # Encoder. 145 | skips = [] 146 | for block in self.enc.values(): 147 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x) 148 | skips.append(x) 149 | 150 | # Decoder. 
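# Whenever the running activation has fewer channels than the next block expects,
# the matching encoder feature is popped off the skip list and concatenated on.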
151 | for block in self.dec.values(): 152 | if x.shape[1] != block.in_channels: 153 | x = torch.cat([x, skips.pop()], dim=1) 154 | x = block(x, emb) 155 | x = self.out_conv(silu(self.out_norm(x))) 156 | return x 157 | 158 | 159 | @utils.register_model(name='adm') 160 | class WrappedADM(nn.Module): 161 | def __init__(self, cfg): 162 | super().__init__() 163 | self.sigma_min = cfg.sde.sigma_min 164 | self.sigma_max = cfg.sde.sigma_max 165 | self.model = ADM( 166 | img_resolution=cfg.data.image_size, 167 | in_channels=cfg.data.num_channels, 168 | out_channels=cfg.data.num_channels, 169 | label_dim=cfg.data.num_classes, 170 | augment_dim=0, 171 | model_channels=cfg.model.model_channels, 172 | channel_mult=cfg.model.channel_mult, 173 | channel_mult_emb=cfg.model.channel_mult_emb, 174 | num_blocks=cfg.model.num_blocks, 175 | attn_resolutions=cfg.model.attn_resolutions, 176 | dropout=cfg.model.dropout, 177 | label_dropout=cfg.training.drop_label, 178 | ) 179 | self.num_classes = cfg.data.num_classes 180 | self.scale_by_sigma = cfg.model.scale_by_sigma 181 | 182 | def forward(self, x, sigma, class_labels=None): 183 | if class_labels is None: 184 | class_labels = torch.zeros(x.shape[0], self.num_classes).to(x.device) 185 | else: 186 | class_labels = F.one_hot(class_labels, num_classes=self.num_classes).float() 187 | 188 | sigma_inp = (sigma / 2).log() 189 | Fx = self.model(x.half(), sigma_inp, class_labels=class_labels) 190 | 191 | if self.scale_by_sigma: 192 | Fx /= sigma[:, None, None, None] 193 | 194 | return Fx 195 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 
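The update applied to each shadow parameter s, given a new value p, is
s <- s - (1 - decay) * (s - p), i.e. s <- decay * s + (1 - decay) * p,
with decay additionally capped at (1 + num_updates) / (10 + num_updates)
early in training when `use_num_updates` is enabled.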
42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / 47 | (10 + self.num_updates)) 48 | one_minus_decay = 1.0 - decay 49 | with torch.no_grad(): 50 | parameters = [p for p in parameters if p.requires_grad] 51 | for s_param, param in zip(self.shadow_params, parameters): 52 | s_param.sub_(one_minus_decay * (s_param - param)) 53 | 54 | def copy_to(self, parameters): 55 | """ 56 | Copy current parameters into given collection of parameters. 57 | 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | updated with the stored moving averages. 61 | """ 62 | parameters = [p for p in parameters if p.requires_grad] 63 | for s_param, param in zip(self.shadow_params, parameters): 64 | if param.requires_grad: 65 | param.data.copy_(s_param.data) 66 | 67 | def store(self, parameters): 68 | """ 69 | Save the current parameters for restoring later. 70 | 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | temporarily stored. 74 | """ 75 | self.collected_params = [param.clone() for param in parameters] 76 | 77 | def restore(self, parameters): 78 | """ 79 | Restore the parameters stored with the `store` method. 80 | Useful to validate the model with EMA parameters without affecting the 81 | original optimization process. Store the parameters before the 82 | `copy_to` method. After validation (or model saving), use this to 83 | restore the former parameters. 84 | 85 | Args: 86 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 87 | updated with the stored parameters. 88 | """ 89 | for c_param, param in zip(self.collected_params, parameters): 90 | param.data.copy_(c_param.data) 91 | 92 | def state_dict(self): 93 | return dict(decay=self.decay, num_updates=self.num_updates, 94 | shadow_params=self.shadow_params) 95 | 96 | def load_state_dict(self, state_dict): 97 | self.decay = state_dict['decay'] 98 | self.num_updates = state_dict['num_updates'] 99 | self.shadow_params = state_dict['shadow_params'] 100 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | """Common layers for defining score networks. 3 | """ 4 | import math 5 | import string 6 | from functools import partial 7 | import torch.nn as nn 8 | import torch 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from .normalization import ConditionalInstanceNorm2dPlus 12 | 13 | 14 | def get_act(config): 15 | """Get activation functions from the config file.""" 16 | 17 | if config.model.nonlinearity.lower() == 'elu': 18 | return nn.ELU() 19 | elif config.model.nonlinearity.lower() == 'relu': 20 | return nn.ReLU() 21 | elif config.model.nonlinearity.lower() == 'lrelu': 22 | return nn.LeakyReLU(negative_slope=0.2) 23 | elif config.model.nonlinearity.lower() == 'swish': 24 | return nn.SiLU() 25 | else: 26 | raise NotImplementedError('activation function does not exist!') 27 | 28 | 29 | def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0): 30 | """1x1 convolution. 
Same as NCSNv1/v2.""" 31 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation, 32 | padding=padding) 33 | init_scale = 1e-10 if init_scale == 0 else init_scale 34 | conv.weight.data *= init_scale 35 | conv.bias.data *= init_scale 36 | return conv 37 | 38 | 39 | def variance_scaling(scale, mode, distribution, 40 | in_axis=1, out_axis=0, 41 | dtype=torch.float32, 42 | device='cpu'): 43 | """Ported from JAX. """ 44 | 45 | def _compute_fans(shape, in_axis=1, out_axis=0): 46 | receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] 47 | fan_in = shape[in_axis] * receptive_field_size 48 | fan_out = shape[out_axis] * receptive_field_size 49 | return fan_in, fan_out 50 | 51 | def init(shape, dtype=dtype, device=device): 52 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) 53 | if mode == "fan_in": 54 | denominator = fan_in 55 | elif mode == "fan_out": 56 | denominator = fan_out 57 | elif mode == "fan_avg": 58 | denominator = (fan_in + fan_out) / 2 59 | else: 60 | raise ValueError( 61 | "invalid mode for variance scaling initializer: {}".format(mode)) 62 | variance = scale / denominator 63 | if distribution == "normal": 64 | return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) 65 | elif distribution == "uniform": 66 | return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) 67 | else: 68 | raise ValueError("invalid distribution for variance scaling initializer") 69 | 70 | return init 71 | 72 | 73 | def default_init(scale=1.): 74 | """The same initialization used in DDPM.""" 75 | scale = 1e-10 if scale == 0 else scale 76 | return variance_scaling(scale, 'fan_avg', 'uniform') 77 | 78 | 79 | class Dense(nn.Module): 80 | """Linear layer with `default_init`.""" 81 | def __init__(self): 82 | super().__init__() 83 | 84 | 85 | def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): 86 | """1x1 convolution with DDPM initialization.""" 87 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) 88 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 89 | nn.init.zeros_(conv.bias) 90 | return conv 91 | 92 | 93 | def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): 94 | """3x3 convolution with PyTorch initialization. 
Same as NCSNv1/NCSNv2.""" 95 | init_scale = 1e-10 if init_scale == 0 else init_scale 96 | conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias, 97 | dilation=dilation, padding=padding, kernel_size=3) 98 | conv.weight.data *= init_scale 99 | conv.bias.data *= init_scale 100 | return conv 101 | 102 | 103 | def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): 104 | """3x3 convolution with DDPM initialization.""" 105 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, 106 | dilation=dilation, bias=bias) 107 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 108 | nn.init.zeros_(conv.bias) 109 | return conv 110 | 111 | ########################################################################### 112 | # Functions below are ported over from the NCSNv1/NCSNv2 codebase: 113 | # https://github.com/ermongroup/ncsn 114 | # https://github.com/ermongroup/ncsnv2 115 | ########################################################################### 116 | 117 | 118 | class CRPBlock(nn.Module): 119 | def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True): 120 | super().__init__() 121 | self.convs = nn.ModuleList() 122 | for i in range(n_stages): 123 | self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) 124 | self.n_stages = n_stages 125 | if maxpool: 126 | self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 127 | else: 128 | self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 129 | 130 | self.act = act 131 | 132 | def forward(self, x): 133 | x = self.act(x) 134 | path = x 135 | for i in range(self.n_stages): 136 | path = self.pool(path) 137 | path = self.convs[i](path) 138 | x = path + x 139 | return x 140 | 141 | 142 | class CondCRPBlock(nn.Module): 143 | def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()): 144 | super().__init__() 145 | self.convs = nn.ModuleList() 146 | self.norms = nn.ModuleList() 147 | self.normalizer = normalizer 148 | for i in range(n_stages): 149 | self.norms.append(normalizer(features, num_classes, bias=True)) 150 | self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) 151 | 152 | self.n_stages = n_stages 153 | self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 154 | self.act = act 155 | 156 | def forward(self, x, y): 157 | x = self.act(x) 158 | path = x 159 | for i in range(self.n_stages): 160 | path = self.norms[i](path, y) 161 | path = self.pool(path) 162 | path = self.convs[i](path) 163 | 164 | x = path + x 165 | return x 166 | 167 | 168 | class RCUBlock(nn.Module): 169 | def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()): 170 | super().__init__() 171 | 172 | for i in range(n_blocks): 173 | for j in range(n_stages): 174 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) 175 | 176 | self.stride = 1 177 | self.n_blocks = n_blocks 178 | self.n_stages = n_stages 179 | self.act = act 180 | 181 | def forward(self, x): 182 | for i in range(self.n_blocks): 183 | residual = x 184 | for j in range(self.n_stages): 185 | x = self.act(x) 186 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 187 | 188 | x += residual 189 | return x 190 | 191 | 192 | class CondRCUBlock(nn.Module): 193 | def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()): 194 | super().__init__() 195 | 196 | for i in range(n_blocks): 197 | for j in range(n_stages): 198 | setattr(self, 
'{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True)) 199 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) 200 | 201 | self.stride = 1 202 | self.n_blocks = n_blocks 203 | self.n_stages = n_stages 204 | self.act = act 205 | self.normalizer = normalizer 206 | 207 | def forward(self, x, y): 208 | for i in range(self.n_blocks): 209 | residual = x 210 | for j in range(self.n_stages): 211 | x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) 212 | x = self.act(x) 213 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 214 | 215 | x += residual 216 | return x 217 | 218 | 219 | class MSFBlock(nn.Module): 220 | def __init__(self, in_planes, features): 221 | super().__init__() 222 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 223 | self.convs = nn.ModuleList() 224 | self.features = features 225 | 226 | for i in range(len(in_planes)): 227 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 228 | 229 | def forward(self, xs, shape): 230 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 231 | for i in range(len(self.convs)): 232 | h = self.convs[i](xs[i]) 233 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 234 | sums += h 235 | return sums 236 | 237 | 238 | class CondMSFBlock(nn.Module): 239 | def __init__(self, in_planes, features, num_classes, normalizer): 240 | super().__init__() 241 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 242 | 243 | self.convs = nn.ModuleList() 244 | self.norms = nn.ModuleList() 245 | self.features = features 246 | self.normalizer = normalizer 247 | 248 | for i in range(len(in_planes)): 249 | self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) 250 | self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) 251 | 252 | def forward(self, xs, y, shape): 253 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 254 | for i in range(len(self.convs)): 255 | h = self.norms[i](xs[i], y) 256 | h = self.convs[i](h) 257 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 258 | sums += h 259 | return sums 260 | 261 | 262 | class RefineBlock(nn.Module): 263 | def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True): 264 | super().__init__() 265 | 266 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 267 | self.n_blocks = n_blocks = len(in_planes) 268 | 269 | self.adapt_convs = nn.ModuleList() 270 | for i in range(n_blocks): 271 | self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) 272 | 273 | self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) 274 | 275 | if not start: 276 | self.msf = MSFBlock(in_planes, features) 277 | 278 | self.crp = CRPBlock(features, 2, act, maxpool=maxpool) 279 | 280 | def forward(self, xs, output_shape): 281 | assert isinstance(xs, tuple) or isinstance(xs, list) 282 | hs = [] 283 | for i in range(len(xs)): 284 | h = self.adapt_convs[i](xs[i]) 285 | hs.append(h) 286 | 287 | if self.n_blocks > 1: 288 | h = self.msf(hs, output_shape) 289 | else: 290 | h = hs[0] 291 | 292 | h = self.crp(h) 293 | h = self.output_convs(h) 294 | 295 | return h 296 | 297 | 298 | class CondRefineBlock(nn.Module): 299 | def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False): 300 | super().__init__() 301 | 302 | assert isinstance(in_planes, tuple) or 
isinstance(in_planes, list) 303 | self.n_blocks = n_blocks = len(in_planes) 304 | 305 | self.adapt_convs = nn.ModuleList() 306 | for i in range(n_blocks): 307 | self.adapt_convs.append( 308 | CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) 309 | ) 310 | 311 | self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act) 312 | 313 | if not start: 314 | self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) 315 | 316 | self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) 317 | 318 | def forward(self, xs, y, output_shape): 319 | assert isinstance(xs, tuple) or isinstance(xs, list) 320 | hs = [] 321 | for i in range(len(xs)): 322 | h = self.adapt_convs[i](xs[i], y) 323 | hs.append(h) 324 | 325 | if self.n_blocks > 1: 326 | h = self.msf(hs, y, output_shape) 327 | else: 328 | h = hs[0] 329 | 330 | h = self.crp(h, y) 331 | h = self.output_convs(h, y) 332 | 333 | return h 334 | 335 | 336 | class ConvMeanPool(nn.Module): 337 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False): 338 | super().__init__() 339 | if not adjust_padding: 340 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 341 | self.conv = conv 342 | else: 343 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 344 | 345 | self.conv = nn.Sequential( 346 | nn.ZeroPad2d((1, 0, 1, 0)), 347 | conv 348 | ) 349 | 350 | def forward(self, inputs): 351 | output = self.conv(inputs) 352 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 353 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 354 | return output 355 | 356 | 357 | class MeanPoolConv(nn.Module): 358 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): 359 | super().__init__() 360 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 361 | 362 | def forward(self, inputs): 363 | output = inputs 364 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 365 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 
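# (summing the four stride-2 phases and dividing by 4 is exactly a 2x2 average
# pool: MeanPoolConv pools first and then convolves, while ConvMeanPool above
# convolves first and then pools)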
366 | return self.conv(output)
367 |
368 |
369 | class UpsampleConv(nn.Module):
370 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
371 | super().__init__()
372 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
373 | self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
374 |
375 | def forward(self, inputs):
376 | output = inputs
377 | output = torch.cat([output, output, output, output], dim=1)
378 | output = self.pixelshuffle(output)
379 | return self.conv(output)
380 |
381 |
382 | class ConditionalResidualBlock(nn.Module):
383 | def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU(),  # resample default was 1, which no branch below accepts; None matches ResidualBlock
384 | normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=1):  # dilation default was None, which breaks the `dilation > 1` checks below
385 | super().__init__()
386 | self.non_linearity = act
387 | self.input_dim = input_dim
388 | self.output_dim = output_dim
389 | self.resample = resample
390 | self.normalization = normalization
391 | if resample == 'down':
392 | if dilation > 1:
393 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
394 | self.normalize2 = normalization(input_dim, num_classes)
395 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
396 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
397 | else:
398 | self.conv1 = ncsn_conv3x3(input_dim, input_dim)
399 | self.normalize2 = normalization(input_dim, num_classes)
400 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
401 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
402 |
403 | elif resample is None:
404 | if dilation > 1:
405 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
406 | self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
407 | self.normalize2 = normalization(output_dim, num_classes)
408 | self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
409 | else:
410 | conv_shortcut = partial(ncsn_conv1x1)  # a bare nn.Conv2d here lacks kernel_size; use a 1x1 conv as in ResidualBlock below
411 | self.conv1 = ncsn_conv3x3(input_dim, output_dim)
412 | self.normalize2 = normalization(output_dim, num_classes)
413 | self.conv2 = ncsn_conv3x3(output_dim, output_dim)
414 | else:
415 | raise Exception('invalid resample value')
416 |
417 | if output_dim != input_dim or resample is not None:
418 | self.shortcut = conv_shortcut(input_dim, output_dim)
419 |
420 | self.normalize1 = normalization(input_dim, num_classes)
421 |
422 | def forward(self, x, y):
423 | output = self.normalize1(x, y)
424 | output = self.non_linearity(output)
425 | output = self.conv1(output)
426 | output = self.normalize2(output, y)
427 | output = self.non_linearity(output)
428 | output = self.conv2(output)
429 |
430 | if self.output_dim == self.input_dim and self.resample is None:
431 | shortcut = x
432 | else:
433 | shortcut = self.shortcut(x)
434 |
435 | return shortcut + output
436 |
437 |
438 | class ResidualBlock(nn.Module):
439 | def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
440 | normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
441 | super().__init__()
442 | self.non_linearity = act
443 | self.input_dim = input_dim
444 | self.output_dim = output_dim
445 | self.resample = resample
446 | self.normalization = normalization
447 | if resample == 'down':
448 | if dilation > 1:
449 | self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
450 | self.normalize2 = normalization(input_dim)
451 | self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
452 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
453 | else:
454 | self.conv1 = ncsn_conv3x3(input_dim, input_dim)
455 | self.normalize2 = normalization(input_dim)
456 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
457 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
458 |
459 | elif resample is None:
460 | if dilation > 1:
461 | conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
462 | self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
463 | self.normalize2 = normalization(output_dim)
464 | self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
465 | else:
466 | # conv_shortcut = nn.Conv2d  ### a bare nn.Conv2d would be missing its kernel_size (the "something weird" noted upstream), hence the 1x1 conv below.
467 | conv_shortcut = partial(ncsn_conv1x1)
468 | self.conv1 = ncsn_conv3x3(input_dim, output_dim)
469 | self.normalize2 = normalization(output_dim)
470 | self.conv2 = ncsn_conv3x3(output_dim, output_dim)
471 | else:
472 | raise Exception('invalid resample value')
473 |
474 | if output_dim != input_dim or resample is not None:
475 | self.shortcut = conv_shortcut(input_dim, output_dim)
476 |
477 | self.normalize1 = normalization(input_dim)
478 |
479 | def forward(self, x):
480 | output = self.normalize1(x)
481 | output = self.non_linearity(output)
482 | output = self.conv1(output)
483 | output = self.normalize2(output)
484 | output = self.non_linearity(output)
485 | output = self.conv2(output)
486 |
487 | if self.output_dim == self.input_dim and self.resample is None:
488 | shortcut = x
489 | else:
490 | shortcut = self.shortcut(x)
491 |
492 | return shortcut + output
493 |
494 |
495 | ###########################################################################
496 | # Functions below are ported over from the DDPM codebase:
497 | # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
498 | ###########################################################################
499 |
500 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
501 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
502 | half_dim = embedding_dim // 2
503 | # magic number 10000 is from transformers
504 | emb = math.log(max_positions) / (half_dim - 1)
505 | # emb = math.log(2.)
/ (half_dim - 1) 506 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 507 | # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] 508 | # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] 509 | emb = timesteps.float()[:, None] * emb[None, :] 510 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 511 | if embedding_dim % 2 == 1: # zero pad 512 | emb = F.pad(emb, (0, 1), mode='constant') 513 | assert emb.shape == (timesteps.shape[0], embedding_dim) 514 | return emb 515 | 516 | 517 | def _einsum(a, b, c, x, y): 518 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 519 | return torch.einsum(einsum_str, x, y) 520 | 521 | 522 | def contract_inner(x, y): 523 | """tensordot(x, y, 1).""" 524 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 525 | y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) 526 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 527 | out_chars = x_chars[:-1] + y_chars[1:] 528 | return _einsum(x_chars, y_chars, out_chars, x, y) 529 | 530 | 531 | class NIN(nn.Module): 532 | def __init__(self, in_dim, num_units, init_scale=0.1): 533 | super().__init__() 534 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 535 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 536 | 537 | def forward(self, x): 538 | x = x.permute(0, 2, 3, 1) 539 | y = contract_inner(x, self.W) + self.b 540 | return y.permute(0, 3, 1, 2) 541 | 542 | 543 | class AttnBlock(nn.Module): 544 | """Channel-wise self-attention block.""" 545 | def __init__(self, channels): 546 | super().__init__() 547 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) 548 | self.NIN_0 = NIN(channels, channels) 549 | self.NIN_1 = NIN(channels, channels) 550 | self.NIN_2 = NIN(channels, channels) 551 | self.NIN_3 = NIN(channels, channels, init_scale=0.) 
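# (init_scale=0. is mapped to 1e-10 by `default_init`, so NIN_3 starts essentially
# at zero and the attention block below begins as a near-identity residual, x + h ≈ x)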
552 | 553 | def forward(self, x): 554 | B, C, H, W = x.shape 555 | h = self.GroupNorm_0(x) 556 | q = self.NIN_0(h) 557 | k = self.NIN_1(h) 558 | v = self.NIN_2(h) 559 | 560 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 561 | w = torch.reshape(w, (B, H, W, H * W)) 562 | w = F.softmax(w, dim=-1) 563 | w = torch.reshape(w, (B, H, W, H, W)) 564 | h = torch.einsum('bhwij,bcij->bchw', w, v) 565 | h = self.NIN_3(h) 566 | return x + h 567 | 568 | 569 | class Upsample(nn.Module): 570 | def __init__(self, channels, with_conv=False): 571 | super().__init__() 572 | if with_conv: 573 | self.Conv_0 = ddpm_conv3x3(channels, channels) 574 | self.with_conv = with_conv 575 | 576 | def forward(self, x): 577 | B, C, H, W = x.shape 578 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest') 579 | if self.with_conv: 580 | h = self.Conv_0(h) 581 | return h 582 | 583 | 584 | class Downsample(nn.Module): 585 | def __init__(self, channels, with_conv=False): 586 | super().__init__() 587 | if with_conv: 588 | self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) 589 | self.with_conv = with_conv 590 | 591 | def forward(self, x): 592 | B, C, H, W = x.shape 593 | # Emulate 'SAME' padding 594 | if self.with_conv: 595 | x = F.pad(x, (0, 1, 0, 1)) 596 | x = self.Conv_0(x) 597 | else: 598 | x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) 599 | 600 | assert x.shape == (B, C, H // 2, W // 2) 601 | return x 602 | 603 | 604 | class ResnetBlockDDPM(nn.Module): 605 | """The ResNet Blocks used in DDPM.""" 606 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1): 607 | super().__init__() 608 | if out_ch is None: 609 | out_ch = in_ch 610 | self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) 611 | self.act = act 612 | self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) 613 | if temb_dim is not None: 614 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 615 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 616 | nn.init.zeros_(self.Dense_0.bias) 617 | 618 | self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) 619 | self.Dropout_0 = nn.Dropout(dropout) 620 | self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) 621 | if in_ch != out_ch: 622 | if conv_shortcut: 623 | self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) 624 | else: 625 | self.NIN_0 = NIN(in_ch, out_ch) 626 | self.out_ch = out_ch 627 | self.in_ch = in_ch 628 | self.conv_shortcut = conv_shortcut 629 | 630 | def forward(self, x, temb=None): 631 | B, C, H, W = x.shape 632 | assert C == self.in_ch 633 | out_ch = self.out_ch if self.out_ch else self.in_ch 634 | h = self.act(self.GroupNorm_0(x)) 635 | h = self.Conv_0(h) 636 | # Add bias to each feature map conditioned on the time embedding 637 | if temb is not None: 638 | h += self.Dense_0(self.act(temb))[:, :, None, None] 639 | h = self.act(self.GroupNorm_1(h)) 640 | h = self.Dropout_0(h) 641 | h = self.Conv_1(h) 642 | if C != out_ch: 643 | if self.conv_shortcut: 644 | x = self.Conv_2(x) 645 | else: 646 | x = self.NIN_0(x) 647 | return x + h 648 | -------------------------------------------------------------------------------- /models/layerspp.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | """Layers for defining NCSN++. 3 | """ 4 | from math import pi 5 | 6 | from . import layers 7 | from . 
import up_or_down_sampling 8 | import torch.nn as nn 9 | import torch 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | conv1x1 = layers.ddpm_conv1x1 14 | conv3x3 = layers.ddpm_conv3x3 15 | NIN = layers.NIN 16 | default_init = layers.default_init 17 | 18 | 19 | class GaussianFourierProjection(nn.Module): 20 | """Gaussian Fourier embeddings for noise levels. Meant for time condition usage.""" 21 | 22 | def __init__(self, embedding_size=256, scale=1.0): 23 | super().__init__() 24 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 25 | 26 | def forward(self, x): 27 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 28 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 29 | 30 | 31 | class ImageFourierFeatures(nn.Module): 32 | """Fourier features used in VDMs. Meant for usage in image space""" 33 | def __init__(self, start=6, end=8): 34 | super().__init__() 35 | self.register_buffer("freqs", 2 ** torch.arange(start, end)) 36 | 37 | def forward(self, x): 38 | freqs = (self.freqs * 2 * pi).repeat(x.shape[1]) 39 | x_inp = x 40 | x = x.repeat_interleave(len(self.freqs), dim=1) 41 | 42 | x = freqs[None, :, None, None] * x 43 | return torch.cat([x_inp, x.sin(), x.cos()], dim=1) 44 | 45 | def extra_repr(self): 46 | return f"ImageFourierFeatures({self.freqs.detach().cpu().numpy()})" 47 | 48 | 49 | class Combine(nn.Module): 50 | """Combine information from skip connections.""" 51 | 52 | def __init__(self, dim1, dim2, method='cat'): 53 | super().__init__() 54 | self.Conv_0 = conv1x1(dim1, dim2) 55 | self.method = method 56 | 57 | def forward(self, x, y): 58 | h = self.Conv_0(x) 59 | if self.method == 'cat': 60 | return torch.cat([h, y], dim=1) 61 | elif self.method == 'sum': 62 | return h + y 63 | else: 64 | raise ValueError(f'Method {self.method} not recognized.') 65 | 66 | 67 | class AttnBlockpp(nn.Module): 68 | """Channel-wise self-attention block. Modified from DDPM.""" 69 | 70 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 71 | super().__init__() 72 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 73 | eps=1e-6) 74 | self.NIN_0 = NIN(channels, channels) 75 | self.NIN_1 = NIN(channels, channels) 76 | self.NIN_2 = NIN(channels, channels) 77 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 78 | self.skip_rescale = skip_rescale 79 | 80 | def forward(self, x): 81 | B, C, H, W = x.shape 82 | h = self.GroupNorm_0(x) 83 | q = self.NIN_0(h) 84 | k = self.NIN_1(h) 85 | v = self.NIN_2(h) 86 | 87 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 88 | w = torch.reshape(w, (B, H, W, H * W)) 89 | w = F.softmax(w, dim=-1) 90 | w = torch.reshape(w, (B, H, W, H, W)) 91 | h = torch.einsum('bhwij,bcij->bchw', w, v) 92 | h = self.NIN_3(h) 93 | if not self.skip_rescale: 94 | return x + h 95 | else: 96 | return (x + h) / np.sqrt(2.) 
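# (the `/ np.sqrt(2.)` above is the skip_rescale convention: for independent
# unit-variance branches, Var(x + h) = 2, so dividing by sqrt(2) keeps the
# residual stream variance-preserving across many stacked blocks)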
97 | 98 | 99 | class Upsample(nn.Module): 100 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 101 | fir_kernel=(1, 3, 3, 1)): 102 | super().__init__() 103 | out_ch = out_ch if out_ch else in_ch 104 | if not fir: 105 | if with_conv: 106 | self.Conv_0 = conv3x3(in_ch, out_ch) 107 | else: 108 | if with_conv: 109 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 110 | kernel=3, up=True, 111 | resample_kernel=fir_kernel, 112 | use_bias=True, 113 | kernel_init=default_init()) 114 | self.fir = fir 115 | self.with_conv = with_conv 116 | self.fir_kernel = fir_kernel 117 | self.out_ch = out_ch 118 | 119 | def forward(self, x): 120 | B, C, H, W = x.shape 121 | if not self.fir: 122 | h = F.interpolate(x, size=(H * 2, W * 2), mode='nearest') 123 | if self.with_conv: 124 | h = self.Conv_0(h) 125 | else: 126 | if not self.with_conv: 127 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 128 | else: 129 | h = self.Conv2d_0(x) 130 | 131 | return h 132 | 133 | 134 | class Downsample(nn.Module): 135 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 136 | fir_kernel=(1, 3, 3, 1)): 137 | super().__init__() 138 | out_ch = out_ch if out_ch else in_ch 139 | if not fir: 140 | if with_conv: 141 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 142 | else: 143 | if with_conv: 144 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 145 | kernel=3, down=True, 146 | resample_kernel=fir_kernel, 147 | use_bias=True, 148 | kernel_init=default_init()) 149 | self.fir = fir 150 | self.fir_kernel = fir_kernel 151 | self.with_conv = with_conv 152 | self.out_ch = out_ch 153 | 154 | def forward(self, x): 155 | B, C, H, W = x.shape 156 | if not self.fir: 157 | if self.with_conv: 158 | x = F.pad(x, (0, 1, 0, 1)) 159 | x = self.Conv_0(x) 160 | else: 161 | x = F.avg_pool2d(x, 2, stride=2) 162 | else: 163 | if not self.with_conv: 164 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 165 | else: 166 | x = self.Conv2d_0(x) 167 | 168 | return x 169 | 170 | 171 | class ResnetBlockDDPMpp(nn.Module): 172 | """ResBlock adapted from DDPM.""" 173 | 174 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 175 | dropout=0.1, skip_rescale=False, init_scale=0.): 176 | super().__init__() 177 | out_ch = out_ch if out_ch else in_ch 178 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 179 | self.Conv_0 = conv3x3(in_ch, out_ch) 180 | if temb_dim is not None: 181 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 182 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 183 | nn.init.zeros_(self.Dense_0.bias) 184 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 185 | self.Dropout_0 = nn.Dropout(dropout) 186 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 187 | if in_ch != out_ch: 188 | if conv_shortcut: 189 | self.Conv_2 = conv3x3(in_ch, out_ch) 190 | else: 191 | self.NIN_0 = NIN(in_ch, out_ch) 192 | 193 | self.skip_rescale = skip_rescale 194 | self.act = act 195 | self.out_ch = out_ch 196 | self.conv_shortcut = conv_shortcut 197 | 198 | def forward(self, x, temb=None): 199 | h = self.act(self.GroupNorm_0(x)) 200 | h = self.Conv_0(h) 201 | if temb is not None: 202 | h += self.Dense_0(self.act(temb))[:, :, None, None] 203 | h = self.act(self.GroupNorm_1(h)) 204 | h = self.Dropout_0(h) 205 | h = self.Conv_1(h) 206 | if x.shape[1] != self.out_ch: 207 | if self.conv_shortcut: 208 | x = 
self.Conv_2(x) 209 | else: 210 | x = self.NIN_0(x) 211 | if not self.skip_rescale: 212 | return x + h 213 | else: 214 | return (x + h) / np.sqrt(2.) 215 | 216 | 217 | class ResnetBlockBigGANpp(nn.Module): 218 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 219 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 220 | skip_rescale=True, init_scale=0.): 221 | super().__init__() 222 | 223 | out_ch = out_ch if out_ch else in_ch 224 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 225 | self.up = up 226 | self.down = down 227 | self.fir = fir 228 | self.fir_kernel = fir_kernel 229 | 230 | self.Conv_0 = conv3x3(in_ch, out_ch) 231 | if temb_dim is not None: 232 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 233 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 234 | nn.init.zeros_(self.Dense_0.bias) 235 | 236 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 237 | self.Dropout_0 = nn.Dropout(dropout) 238 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 239 | if in_ch != out_ch or up or down: 240 | self.Conv_2 = conv1x1(in_ch, out_ch) 241 | 242 | self.skip_rescale = skip_rescale 243 | self.act = act 244 | self.in_ch = in_ch 245 | self.out_ch = out_ch 246 | 247 | def forward(self, x, temb=None): 248 | h = self.act(self.GroupNorm_0(x)) 249 | 250 | if self.up: 251 | if self.fir: 252 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 253 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 254 | else: 255 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 256 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 257 | elif self.down: 258 | if self.fir: 259 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 260 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 261 | else: 262 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 263 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 264 | 265 | h = self.Conv_0(h) 266 | # Add bias to each feature map conditioned on the time embedding 267 | if temb is not None: 268 | h += self.Dense_0(self.act(temb))[:, :, None, None] 269 | h = self.act(self.GroupNorm_1(h)) 270 | h = self.Dropout_0(h) 271 | h = self.Conv_1(h) 272 | 273 | if self.in_ch != self.out_ch or self.up or self.down: 274 | x = self.Conv_2(x) 275 | 276 | if not self.skip_rescale: 277 | return x + h 278 | else: 279 | return (x + h) / np.sqrt(2.) 
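For orientation, here is a minimal shape check for the `GaussianFourierProjection` defined at the top of this file. This is a sketch: `embedding_size=128` and `scale=16.0` are made-up stand-ins for the config-driven values (`nf` and `config.model.fourier_scale` in `ncsnpp.py`), and the class is assumed to be in scope.

```python
import torch

# Assumes GaussianFourierProjection from models/layerspp.py is in scope.
proj = GaussianFourierProjection(embedding_size=128, scale=16.0)
t = torch.rand(8)   # one continuous noise level per batch element
emb = proj(t)       # sin/cos features of t[:, None] * W[None, :] * 2*pi
print(emb.shape)    # torch.Size([8, 256]) -- sin half concatenated with cos half
```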
280 | -------------------------------------------------------------------------------- /models/layersv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class PositionalEmbedding(nn.Module): 8 | def __init__(self, num_channels, max_positions=10000, endpoint=False): 9 | super().__init__() 10 | self.num_channels = num_channels 11 | self.max_positions = max_positions 12 | self.endpoint = endpoint 13 | 14 | def forward(self, x): 15 | freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) 16 | freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) 17 | freqs = (1 / self.max_positions) ** freqs 18 | x = x.ger(freqs.to(x.dtype)) 19 | x = torch.cat([x.cos(), x.sin()], dim=1) 20 | return x 21 | 22 | def weight_init(shape, mode, fan_in, fan_out): 23 | if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) 24 | if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) 25 | if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) 26 | if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) 27 | raise ValueError(f'Invalid init mode "{mode}"') 28 | 29 | 30 | class Conv2d(nn.Module): 31 | def __init__(self, 32 | in_channels, out_channels, kernel, bias=True, up=False, down=False, 33 | resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0, 34 | ): 35 | assert not (up and down) 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.out_channels = out_channels 39 | self.up = up 40 | self.down = down 41 | self.fused_resample = fused_resample 42 | init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel) 43 | self.weight = nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None 44 | self.bias = nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None 45 | f = torch.as_tensor(resample_filter, dtype=torch.float32) 46 | f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() 47 | self.register_buffer('resample_filter', f if up or down else None) 48 | 49 | def forward(self, x): 50 | w = self.weight.to(x.dtype) if self.weight is not None else None 51 | b = self.bias.to(x.dtype) if self.bias is not None else None 52 | f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None 53 | w_pad = w.shape[-1] // 2 if w is not None else 0 54 | f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 55 | 56 | if self.fused_resample and self.up and w is not None: 57 | x = F.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0)) 58 | x = F.conv2d(x, w, padding=max(w_pad - f_pad, 0)) 59 | elif self.fused_resample and self.down and w is not None: 60 | x = F.conv2d(x, w, padding=w_pad+f_pad) 61 | x = F.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2) 62 | else: 63 | if self.up: 64 | x = F.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 65 | if self.down: 66 | x = F.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 67 | if w is not None: 68 | x = 
F.conv2d(x, w, padding=w_pad) 69 | if b is not None: 70 | x = x.add_(b.reshape(1, -1, 1, 1)) 71 | return x 72 | 73 | class Linear(nn.Module): 74 | def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): 75 | super().__init__() 76 | self.in_features = in_features 77 | self.out_features = out_features 78 | init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) 79 | self.weight = nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) 80 | self.bias = nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None 81 | 82 | def forward(self, x): 83 | x = x @ self.weight.to(x.dtype).t() 84 | if self.bias is not None: 85 | x = x.add_(self.bias.to(x.dtype)) 86 | return x 87 | 88 | class GroupNorm(nn.Module): 89 | def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): 90 | super().__init__() 91 | self.num_groups = min(num_groups, num_channels // min_channels_per_group) 92 | self.eps = eps 93 | self.weight = nn.Parameter(torch.ones(num_channels)) 94 | self.bias = nn.Parameter(torch.zeros(num_channels)) 95 | 96 | def forward(self, x): 97 | x = F.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps) 98 | return x 99 | 100 | class AttentionOp(torch.autograd.Function): 101 | @staticmethod 102 | def forward(ctx, q, k): 103 | w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype) 104 | ctx.save_for_backward(q, k, w) 105 | return w 106 | 107 | @staticmethod 108 | def backward(ctx, dw): 109 | q, k, w = ctx.saved_tensors 110 | db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32) 111 | dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1]) 112 | dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1]) 113 | return dq, dk -------------------------------------------------------------------------------- /models/ncsnpp.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from . 
import utils, layers, layerspp, normalization 4 | import torch.nn as nn 5 | import functools 6 | import torch 7 | import numpy as np 8 | 9 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp 10 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp 11 | Combine = layerspp.Combine 12 | conv3x3 = layerspp.conv3x3 13 | conv1x1 = layerspp.conv1x1 14 | get_act = layers.get_act 15 | get_normalization = normalization.get_normalization 16 | default_initializer = layers.default_init 17 | 18 | 19 | @utils.register_model(name='ncsnpp') 20 | class NCSNpp(nn.Module): 21 | """NCSN++ model""" 22 | 23 | def __init__(self, config): 24 | super().__init__() 25 | self.config = config 26 | self.act = act = get_act(config) 27 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config), dtype=torch.float32)) 28 | 29 | self.nf = nf = config.model.nf 30 | ch_mult = config.model.ch_mult 31 | self.num_res_blocks = num_res_blocks = config.model.num_res_blocks 32 | self.attn_resolutions = attn_resolutions = config.model.attn_resolutions 33 | dropout = config.model.dropout 34 | resamp_with_conv = config.model.resamp_with_conv 35 | self.num_resolutions = num_resolutions = len(ch_mult) 36 | self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] 37 | 38 | self.conditional = conditional = config.model.conditional # noise-conditional 39 | fir = config.model.fir 40 | fir_kernel = config.model.fir_kernel 41 | self.skip_rescale = skip_rescale = config.model.skip_rescale 42 | self.resblock_type = resblock_type = config.model.resblock_type.lower() 43 | self.progressive = progressive = config.model.progressive.lower() 44 | self.progressive_input = progressive_input = config.model.progressive_input.lower() 45 | self.embedding_type = embedding_type = config.model.embedding_type.lower() 46 | init_scale = config.model.init_scale 47 | assert progressive in ['none', 'output_skip', 'residual'] 48 | assert progressive_input in ['none', 'input_skip', 'residual'] 49 | assert embedding_type in ['fourier', 'positional'] 50 | combine_method = config.model.progressive_combine.lower() 51 | combiner = functools.partial(Combine, method=combine_method) 52 | 53 | modules = [] 54 | # timestep/noise_level embedding; only for continuous training 55 | if embedding_type == 'fourier': 56 | # Gaussian Fourier features embeddings. 
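# (applied to log(sigma) in forward(); the sin/cos concatenation makes the
# embedding 2 * nf wide, hence embed_dim = 2 * nf below)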
57 | 58 | modules.append(layerspp.GaussianFourierProjection( 59 | embedding_size=nf, scale=config.model.fourier_scale 60 | )) 61 | embed_dim = 2 * nf 62 | 63 | elif embedding_type == 'positional': 64 | embed_dim = nf 65 | 66 | else: 67 | raise ValueError(f'embedding type {embedding_type} unknown.') 68 | 69 | if conditional: 70 | modules.append(nn.Linear(embed_dim, nf * 4)) 71 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 72 | nn.init.zeros_(modules[-1].bias) 73 | modules.append(nn.Linear(nf * 4, nf * 4)) 74 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 75 | nn.init.zeros_(modules[-1].bias) 76 | 77 | AttnBlock = functools.partial(layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) 78 | 79 | Upsample = functools.partial(layerspp.Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 80 | 81 | if progressive == 'output_skip': 82 | self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 83 | elif progressive == 'residual': 84 | pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir, fir_kernel=fir_kernel, with_conv=True) 85 | 86 | Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 87 | 88 | if progressive_input == 'input_skip': 89 | self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 90 | elif progressive_input == 'residual': 91 | pyramid_downsample = functools.partial(layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True) 92 | 93 | if resblock_type == 'ddpm': 94 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, dropout=dropout, init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4) 95 | 96 | elif resblock_type == 'biggan': 97 | ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act, dropout=dropout, fir=fir, fir_kernel=fir_kernel, init_scale=init_scale, 98 | skip_rescale=skip_rescale, temb_dim=nf * 4) 99 | 100 | else: 101 | raise ValueError(f'resblock type {resblock_type} unrecognized.') 102 | 103 | channels = config.data.num_channels 104 | 105 | # Downsampling block 106 | 107 | if progressive_input != 'none': 108 | input_pyramid_ch = channels 109 | 110 | modules.append(conv3x3(channels, nf)) 111 | hs_c = [nf] 112 | 113 | in_ch = nf 114 | for i_level in range(num_resolutions): 115 | # Residual blocks for this resolution 116 | for i_block in range(num_res_blocks): 117 | out_ch = nf * ch_mult[i_level] 118 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 119 | in_ch = out_ch 120 | 121 | if all_resolutions[i_level] in attn_resolutions: 122 | modules.append(AttnBlock(channels=in_ch)) 123 | hs_c.append(in_ch) 124 | 125 | if i_level != num_resolutions - 1: 126 | if resblock_type == 'ddpm': 127 | modules.append(Downsample(in_ch=in_ch)) 128 | else: 129 | modules.append(ResnetBlock(down=True, in_ch=in_ch)) 130 | 131 | if progressive_input == 'input_skip': 132 | modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) 133 | if combine_method == 'cat': 134 | in_ch *= 2 135 | 136 | elif progressive_input == 'residual': 137 | modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) 138 | input_pyramid_ch = in_ch 139 | 140 | hs_c.append(in_ch) 141 | 142 | in_ch = hs_c[-1] 143 | modules.append(ResnetBlock(in_ch=in_ch)) 144 | modules.append(AttnBlock(channels=in_ch)) 145 | modules.append(ResnetBlock(in_ch=in_ch)) 146 | 147 | pyramid_ch = 0 148 | # Upsampling block 149 | for i_level in 
reversed(range(num_resolutions)): 150 | for i_block in range(num_res_blocks + 1): 151 | out_ch = nf * ch_mult[i_level] 152 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) 153 | in_ch = out_ch 154 | 155 | if all_resolutions[i_level] in attn_resolutions: 156 | modules.append(AttnBlock(channels=in_ch)) 157 | 158 | if progressive != 'none': 159 | if i_level == num_resolutions - 1: 160 | if progressive == 'output_skip': 161 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) 162 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 163 | pyramid_ch = channels 164 | elif progressive == 'residual': 165 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) 166 | modules.append(conv3x3(in_ch, in_ch, bias=True)) 167 | pyramid_ch = in_ch 168 | else: 169 | raise ValueError(f'{progressive} is not a valid name.') 170 | else: 171 | if progressive == 'output_skip': 172 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) 173 | modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) 174 | pyramid_ch = channels 175 | elif progressive == 'residual': 176 | modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) 177 | pyramid_ch = in_ch 178 | else: 179 | raise ValueError(f'{progressive} is not a valid name') 180 | 181 | if i_level != 0: 182 | if resblock_type == 'ddpm': 183 | modules.append(Upsample(in_ch=in_ch)) 184 | else: 185 | modules.append(ResnetBlock(in_ch=in_ch, up=True)) 186 | 187 | assert not hs_c 188 | 189 | if progressive != 'output_skip': 190 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) 191 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 192 | 193 | self.all_modules = nn.ModuleList(modules) 194 | 195 | def forward(self, x, time_cond, class_labels=None): 196 | # timestep/noise_level embedding; only for continuous training 197 | modules = self.all_modules 198 | m_idx = 0 199 | if self.embedding_type == 'fourier': 200 | # Gaussian Fourier features embeddings. 201 | used_sigmas = time_cond 202 | temb = modules[m_idx](torch.log(used_sigmas)) 203 | m_idx += 1 204 | 205 | elif self.embedding_type == 'positional': 206 | # Sinusoidal positional embeddings. 207 | timesteps = time_cond 208 | used_sigmas = self.sigmas[time_cond.long()] 209 | temb = layers.get_timestep_embedding(timesteps, self.nf) 210 | 211 | else: 212 | raise ValueError(f'embedding type {self.embedding_type} unknown.') 213 | 214 | if self.conditional: 215 | temb = modules[m_idx](temb) 216 | m_idx += 1 217 | temb = modules[m_idx](self.act(temb)) 218 | m_idx += 1 219 | else: 220 | temb = None 221 | 222 | # Transform data to [-1, 1] 223 | x = 2 * x - 1. 
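# Inputs are assumed to lie in [0, 1]; from here on, self.all_modules is
# consumed strictly in construction order via m_idx, mirroring __init__ above.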
224 | 225 | # Downsampling block 226 | input_pyramid = None 227 | if self.progressive_input != 'none': 228 | input_pyramid = x 229 | 230 | hs = [modules[m_idx](x)] 231 | m_idx += 1 232 | for i_level in range(self.num_resolutions): 233 | # Residual blocks for this resolution 234 | for i_block in range(self.num_res_blocks): 235 | h = modules[m_idx](hs[-1], temb) 236 | m_idx += 1 237 | if h.shape[-1] in self.attn_resolutions: 238 | h = modules[m_idx](h) 239 | m_idx += 1 240 | 241 | hs.append(h) 242 | 243 | if i_level != self.num_resolutions - 1: 244 | if self.resblock_type == 'ddpm': 245 | h = modules[m_idx](hs[-1]) 246 | m_idx += 1 247 | else: 248 | h = modules[m_idx](hs[-1], temb) 249 | m_idx += 1 250 | 251 | if self.progressive_input == 'input_skip': 252 | input_pyramid = self.pyramid_downsample(input_pyramid) 253 | h = modules[m_idx](input_pyramid, h) 254 | m_idx += 1 255 | 256 | elif self.progressive_input == 'residual': 257 | input_pyramid = modules[m_idx](input_pyramid) 258 | m_idx += 1 259 | if self.skip_rescale: 260 | input_pyramid = (input_pyramid + h) / np.sqrt(2.) 261 | else: 262 | input_pyramid = input_pyramid + h 263 | h = input_pyramid 264 | 265 | hs.append(h) 266 | 267 | h = hs[-1] 268 | h = modules[m_idx](h, temb) 269 | m_idx += 1 270 | h = modules[m_idx](h) 271 | m_idx += 1 272 | h = modules[m_idx](h, temb) 273 | m_idx += 1 274 | 275 | pyramid = None 276 | 277 | # Upsampling block 278 | for i_level in reversed(range(self.num_resolutions)): 279 | for i_block in range(self.num_res_blocks + 1): 280 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 281 | m_idx += 1 282 | 283 | if h.shape[-1] in self.attn_resolutions: 284 | h = modules[m_idx](h) 285 | m_idx += 1 286 | 287 | if self.progressive != 'none': 288 | if i_level == self.num_resolutions - 1: 289 | if self.progressive == 'output_skip': 290 | pyramid = self.act(modules[m_idx](h)) 291 | m_idx += 1 292 | pyramid = modules[m_idx](pyramid) 293 | m_idx += 1 294 | elif self.progressive == 'residual': 295 | pyramid = self.act(modules[m_idx](h)) 296 | m_idx += 1 297 | pyramid = modules[m_idx](pyramid) 298 | m_idx += 1 299 | else: 300 | raise ValueError(f'{self.progressive} is not a valid name.') 301 | else: 302 | if self.progressive == 'output_skip': 303 | pyramid = self.pyramid_upsample(pyramid) 304 | pyramid_h = self.act(modules[m_idx](h)) 305 | m_idx += 1 306 | pyramid_h = modules[m_idx](pyramid_h) 307 | m_idx += 1 308 | pyramid = pyramid + pyramid_h 309 | elif self.progressive == 'residual': 310 | pyramid = modules[m_idx](pyramid) 311 | m_idx += 1 312 | if self.skip_rescale: 313 | pyramid = (pyramid + h) / np.sqrt(2.) 
314 | else:
315 | pyramid = pyramid + h
316 | h = pyramid
317 | else:
318 | raise ValueError(f'{self.progressive} is not a valid name')
319 |
320 | if i_level != 0:
321 | if self.resblock_type == 'ddpm':
322 | h = modules[m_idx](h)
323 | m_idx += 1
324 | else:
325 | h = modules[m_idx](h, temb)
326 | m_idx += 1
327 |
328 | assert not hs
329 |
330 | if self.progressive == 'output_skip':
331 | h = pyramid
332 | else:
333 | h = self.act(modules[m_idx](h))
334 | m_idx += 1
335 | h = modules[m_idx](h)
336 | m_idx += 1
337 |
338 | assert m_idx == len(modules)
339 | if self.config.model.scale_by_sigma:
340 | used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
341 | h = h / used_sigmas
342 |
343 | return h
344 |
--------------------------------------------------------------------------------
/models/normalization.py:
--------------------------------------------------------------------------------
1 | """Normalization layers."""
2 | import torch.nn as nn
3 | import torch
4 | import functools
5 |
6 |
7 | def get_normalization(config, conditional=False):
8 | """Obtain normalization modules from the config file."""
9 | norm = config.model.normalization
10 | if conditional:
11 | if norm == 'InstanceNorm++':
12 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
13 | else:
14 | raise NotImplementedError(f'{norm} not implemented yet.')
15 | else:
16 | if norm == 'InstanceNorm':
17 | return nn.InstanceNorm2d
18 | elif norm == 'InstanceNorm++':
19 | return InstanceNorm2dPlus
20 | elif norm == 'VarianceNorm':
21 | return VarianceNorm2d
22 | elif norm == 'GroupNorm':
23 | return nn.GroupNorm
24 | else:
25 | raise ValueError('Unknown normalization: %s' % norm)
26 |
27 |
28 | class ConditionalBatchNorm2d(nn.Module):
29 | def __init__(self, num_features, num_classes, bias=True):
30 | super().__init__()
31 | self.num_features = num_features
32 | self.bias = bias
33 | self.bn = nn.BatchNorm2d(num_features, affine=False)
34 | if self.bias:
35 | self.embed = nn.Embedding(num_classes, num_features * 2)
36 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
37 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
38 | else:
39 | self.embed = nn.Embedding(num_classes, num_features)
40 | self.embed.weight.data.uniform_()
41 |
42 | def forward(self, x, y):
43 | out = self.bn(x)
44 | if self.bias:
45 | gamma, beta = self.embed(y).chunk(2, dim=1)
46 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
47 | else:
48 | gamma = self.embed(y)
49 | out = gamma.view(-1, self.num_features, 1, 1) * out
50 | return out
51 |
52 |
53 | class ConditionalInstanceNorm2d(nn.Module):
54 | def __init__(self, num_features, num_classes, bias=True):
55 | super().__init__()
56 | self.num_features = num_features
57 | self.bias = bias
58 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
59 | if bias:
60 | self.embed = nn.Embedding(num_classes, num_features * 2)
61 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
62 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
63 | else:
64 | self.embed = nn.Embedding(num_classes, num_features)
65 | self.embed.weight.data.uniform_()
66 |
67 | def forward(self, x, y):
68 | h = self.instance_norm(x)
69 | if self.bias:
70 | gamma, beta = self.embed(y).chunk(2, dim=-1)
71 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
72 | else:
73 | gamma = self.embed(y)
74 | out = gamma.view(-1, self.num_features, 1, 1) * h
75 | return out
76 |
77 |
78 | class ConditionalVarianceNorm2d(nn.Module):
79 | def __init__(self, num_features, num_classes, bias=False):
80 | super().__init__()
81 | self.num_features = num_features
82 | self.bias = bias
83 | self.embed = nn.Embedding(num_classes, num_features)
84 | self.embed.weight.data.normal_(1, 0.02)
85 |
86 | def forward(self, x, y):
87 | vars = torch.var(x, dim=(2, 3), keepdim=True)
88 | h = x / torch.sqrt(vars + 1e-5)
89 |
90 | gamma = self.embed(y)
91 | out = gamma.view(-1, self.num_features, 1, 1) * h
92 | return out
93 |
94 |
95 | class VarianceNorm2d(nn.Module):
96 | def __init__(self, num_features, bias=False):
97 | super().__init__()
98 | self.num_features = num_features
99 | self.bias = bias
100 | self.alpha = nn.Parameter(torch.zeros(num_features))
101 | self.alpha.data.normal_(1, 0.02)
102 |
103 | def forward(self, x):
104 | vars = torch.var(x, dim=(2, 3), keepdim=True)
105 | h = x / torch.sqrt(vars + 1e-5)
106 |
107 | out = self.alpha.view(-1, self.num_features, 1, 1) * h
108 | return out
109 |
110 |
111 | class ConditionalNoneNorm2d(nn.Module):
112 | def __init__(self, num_features, num_classes, bias=True):
113 | super().__init__()
114 | self.num_features = num_features
115 | self.bias = bias
116 | if bias:
117 | self.embed = nn.Embedding(num_classes, num_features * 2)
118 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
119 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
120 | else:
121 | self.embed = nn.Embedding(num_classes, num_features)
122 | self.embed.weight.data.uniform_()
123 |
124 | def forward(self, x, y):
125 | if self.bias:
126 | gamma, beta = self.embed(y).chunk(2, dim=-1)
127 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
128 | else:
129 | gamma = self.embed(y)
130 | out = gamma.view(-1, self.num_features, 1, 1) * x
131 | return out
132 |
133 |
134 | class NoneNorm2d(nn.Module):
135 | def __init__(self, num_features, bias=True):
136 | super().__init__()
137 |
138 | def forward(self, x):
139 | return x
140 |
141 |
142 | class InstanceNorm2dPlus(nn.Module):
143 | def __init__(self, num_features, bias=True):
144 | super().__init__()
145 | self.num_features = num_features
146 | self.bias = bias
147 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
148 | self.alpha = nn.Parameter(torch.zeros(num_features))
149 | self.gamma = nn.Parameter(torch.zeros(num_features))
150 | self.alpha.data.normal_(1, 0.02)
151 | self.gamma.data.normal_(1, 0.02)
152 | if bias:
153 | self.beta = nn.Parameter(torch.zeros(num_features))
154 |
155 | def forward(self, x):
156 | means = torch.mean(x, dim=(2, 3))
157 | m = torch.mean(means, dim=-1, keepdim=True)
158 | v = torch.var(means, dim=-1, keepdim=True)
159 | means = (means - m) / (torch.sqrt(v + 1e-5))
160 | h = self.instance_norm(x)
161 |
162 | if self.bias:
163 | h = h + means[..., None, None] * self.alpha[..., None, None]
164 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
165 | else:
166 | h = h + means[..., None, None] * self.alpha[..., None, None]
167 | out = self.gamma.view(-1, self.num_features, 1, 1) * h
168 | return out
169 |
170 |
171 | class ConditionalInstanceNorm2dPlus(nn.Module):
172 | def __init__(self, num_features, num_classes, bias=True):
173 | super().__init__()
174 | self.num_features = num_features
175 | self.bias = bias
176 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
177 | if bias:
178 | self.embed = nn.Embedding(num_classes, num_features * 3)
179 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
180 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
181 | else:
182 | self.embed = nn.Embedding(num_classes, 2 * num_features)
183 | self.embed.weight.data.normal_(1, 0.02)
184 |
185 | def forward(self, x, y):
186 | means = torch.mean(x, dim=(2, 3))
187 | m = torch.mean(means, dim=-1, keepdim=True)
188 | v = torch.var(means, dim=-1, keepdim=True)
189 | means = (means - m) / (torch.sqrt(v + 1e-5))
190 | h = self.instance_norm(x)
191 |
192 | if self.bias:
193 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
194 | h = h + means[..., None, None] * alpha[..., None, None]
195 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
196 | else:
197 | gamma, alpha = self.embed(y).chunk(2, dim=-1)
198 | h = h + means[..., None, None] * alpha[..., None, None]
199 | out = gamma.view(-1, self.num_features, 1, 1) * h
200 | return out
201 |
--------------------------------------------------------------------------------
/models/up_or_down_sampling.py:
--------------------------------------------------------------------------------
1 | """Layers used for up-sampling or down-sampling images.
2 |
3 | Many functions are ported from https://github.com/NVlabs/stylegan2.
4 | """
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | import numpy as np
10 | # from op import upfirdn2d  # NOTE: `upfirdn2d` is called by the FIR code paths below but is not defined in this file; StyleGAN2's fused op (or an equivalent native implementation) must be in scope.
11 |
12 |
13 | # Function ported from StyleGAN2
14 | def get_weight(module, shape, weight_var='weight', kernel_init=None):
15 | """Get/create weight tensor for a convolution or fully-connected layer."""
16 |
17 | return module.param(weight_var, kernel_init, shape)  # NB: TF-style helper; nn.Module has no .param(), and nothing below uses this
18 |
19 |
20 | class Conv2d(nn.Module):
21 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
22 |
23 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
24 | resample_kernel=(1, 3, 3, 1),
25 | use_bias=True,
26 | kernel_init=None):
27 | super().__init__()
28 | assert not (up and down)
29 | assert kernel >= 1 and kernel % 2 == 1
30 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
31 | if kernel_init is not None:
32 | self.weight.data = kernel_init(self.weight.data.shape)
33 | if use_bias:
34 | self.bias = nn.Parameter(torch.zeros(out_ch))
35 |
36 | self.up = up
37 | self.down = down
38 | self.resample_kernel = resample_kernel
39 | self.kernel = kernel
40 | self.use_bias = use_bias
41 |
42 | def forward(self, x):
43 | if self.up:
44 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
45 | elif self.down:
46 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
47 | else:
48 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
49 |
50 | if self.use_bias:
51 | x = x + self.bias.reshape(1, -1, 1, 1)
52 |
53 | return x
54 |
55 |
56 | def naive_upsample_2d(x, factor=2):
57 | _N, C, H, W = x.shape
58 | x = torch.reshape(x, (-1, C, H, 1, W, 1))
59 | x = x.repeat(1, 1, 1, factor, 1, factor)
60 | return torch.reshape(x, (-1, C, H * factor, W * factor))
61 |
62 |
63 | def naive_downsample_2d(x, factor=2):
64 | _N, C, H, W = x.shape
65 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
66 | return torch.mean(x, dim=(3, 5))
67 |
68 |
69 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
70 | """Fused `upsample_2d()` followed by a 2D convolution.
71 |
72 | Padding is performed only once at the beginning, not between the
73 | operations.
74 | The fused op is considerably more efficient than performing the same
75 | calculation using standard PyTorch ops.
76 | It supports gradients of arbitrary order.
77 | Args:
78 | x: Input tensor of the shape `[N, C, H, W]`.
79 | w: Weight tensor of the shape `[outChannels, inChannels, filterH, filterW]`.
80 | Grouped convolution can be performed by `inChannels = x.shape[1] // numGroups`.
81 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
82 | (separable). The default is `[1] * factor`, which corresponds to
83 | nearest-neighbor upsampling.
84 | factor: Integer upsampling factor (default: 2).
85 | gain: Scaling factor for signal magnitude (default: 1.0).
86 |
87 | Returns:
88 | Tensor of the shape `[N, C, H * factor, W * factor]`,
89 | with the same datatype as `x`.
90 | """
91 |
92 | assert isinstance(factor, int) and factor >= 1
93 |
94 | # Check weight shape.
95 | assert len(w.shape) == 4
96 | convH = w.shape[2]
97 | convW = w.shape[3]
98 | inC = w.shape[1]
99 | outC = w.shape[0]
100 |
101 | assert convW == convH
102 |
103 | # Setup filter kernel.
104 | if k is None:
105 | k = [1] * factor
106 | k = _setup_kernel(k) * (gain * (factor ** 2))
107 | p = (k.shape[0] - factor) - (convW - 1)
108 |
109 | # Determine data dimensions.
110 | stride = [factor, factor]  # symmetric stride; output_padding below then evaluates to (0, 0)
111 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
112 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
113 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
114 | assert output_padding[0] >= 0 and output_padding[1] >= 0
115 | num_groups = _shape(x, 1) // inC
116 |
117 | # Transpose weights.
118 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
119 | w = w.flip([-1, -2]).permute(0, 2, 1, 3, 4)
120 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
121 |
122 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
123 | return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
124 |
125 |
126 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
127 | """Fused 2D convolution followed by `downsample_2d()`.
128 |
129 | Padding is performed only once at the beginning, not between the operations.
130 | The fused op is considerably more efficient than performing the same
131 | calculation using standard PyTorch ops.
132 | It supports gradients of arbitrary order.
133 | Args:
134 | x: Input tensor of the shape `[N, C, H, W]`.
135 | w: Weight tensor of the shape `[outChannels, inChannels, filterH,
136 | filterW]`. Grouped convolution can be performed by
137 | `inChannels = x.shape[1] // numGroups`.
138 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
139 | (separable). The default is `[1] * factor`, which corresponds to
140 | average pooling.
141 | factor: Integer downsampling factor (default: 2).
142 | gain: Scaling factor for signal magnitude (default: 1.0).
143 |
144 | Returns:
145 | Tensor of the shape `[N, C, H // factor, W // factor]`,
146 | with the same datatype as `x`.
147 |
148 | """ 149 | 150 | assert isinstance(factor, int) and factor >= 1 151 | _outC, _inC, convH, convW = w.shape 152 | assert convW == convH 153 | if k is None: 154 | k = [1] * factor 155 | k = _setup_kernel(k) * gain 156 | p = (k.shape[0] - factor) + (convW - 1) 157 | s = [factor, factor] 158 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 159 | pad=((p + 1) // 2, p // 2)) 160 | return F.conv2d(x, w, stride=s, padding=0) 161 | 162 | 163 | def _setup_kernel(k): 164 | k = np.asarray(k, dtype=np.float32) 165 | if k.ndim == 1: 166 | k = np.outer(k, k) 167 | k /= np.sum(k) 168 | assert k.ndim == 2 169 | assert k.shape[0] == k.shape[1] 170 | return k 171 | 172 | 173 | def _shape(x, dim): 174 | return x.shape[dim] 175 | 176 | 177 | def upsample_2d(x, k=None, factor=2, gain=1): 178 | r"""Upsample a batch of 2D images with the given filter. 179 | 180 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 181 | and upsamples each image with the given filter. The filter is normalized so 182 | that 183 | if the input pixels are constant, they will be scaled by the specified 184 | `gain`. 185 | Pixels outside the image are assumed to be zero, and the filter is padded 186 | with 187 | zeros so that its shape is a multiple of the upsampling factor. 188 | Args: 189 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 190 | C]`. 191 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 192 | (separable). The default is `[1] * factor`, which corresponds to 193 | nearest-neighbor upsampling. 194 | factor: Integer upsampling factor (default: 2). 195 | gain: Scaling factor for signal magnitude (default: 1.0). 196 | 197 | Returns: 198 | Tensor of the shape `[N, C, H * factor, W * factor]` 199 | """ 200 | assert isinstance(factor, int) and factor >= 1 201 | if k is None: 202 | k = [1] * factor 203 | k = _setup_kernel(k) * (gain * (factor ** 2)) 204 | p = k.shape[0] - factor 205 | return upfirdn2d(x, torch.tensor(k, device=x.device), 206 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 207 | 208 | 209 | def downsample_2d(x, k=None, factor=2, gain=1): 210 | r"""Downsample a batch of 2D images with the given filter. 211 | 212 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 213 | and downsamples each image with the given filter. The filter is normalized 214 | so that 215 | if the input pixels are constant, they will be scaled by the specified 216 | `gain`. 217 | Pixels outside the image are assumed to be zero, and the filter is padded 218 | with 219 | zeros so that its shape is a multiple of the downsampling factor. 220 | Args: 221 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 222 | C]`. 223 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 224 | (separable). The default is `[1] * factor`, which corresponds to 225 | average pooling. 226 | factor: Integer downsampling factor (default: 2). 227 | gain: Scaling factor for signal magnitude (default: 1.0). 
228 | 
229 |     Returns:
230 |         Tensor of the shape `[N, C, H // factor, W // factor]`
231 |     """
232 | 
233 |     assert isinstance(factor, int) and factor >= 1
234 |     if k is None:
235 |         k = [1] * factor
236 |     k = _setup_kernel(k) * gain
237 |     p = k.shape[0] - factor
238 |     return upfirdn2d(x, torch.tensor(k).to(x),
239 |                      down=factor, pad=((p + 1) // 2, p // 2))
240 | 
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | """All functions and modules related to model definition.
2 | """
3 | 
4 | import torch
5 | import sde_lib
6 | import numpy as np
7 | 
8 | _MODELS = {}
9 | 
10 | 
11 | def register_model(cls=None, *, name=None):
12 |     """A decorator for registering model classes."""
13 | 
14 |     def _register(cls):
15 |         if name is None:
16 |             local_name = cls.__name__
17 |         else:
18 |             local_name = name
19 |         if local_name in _MODELS:
20 |             raise ValueError(
21 |                 f'Already registered model with name: {local_name}')
22 |         _MODELS[local_name] = cls
23 |         return cls
24 | 
25 |     if cls is None:
26 |         return _register
27 |     else:
28 |         return _register(cls)
29 | 
30 | 
31 | def get_model(name):
32 |     return _MODELS[name]
33 | 
34 | 
35 | def get_sigmas(config):
36 |     """Get sigmas --- the set of noise levels for SMLD from config files.
37 |     Args:
38 |         config: A config object parsed from the config file
39 |     Returns:
40 |         sigmas: a numpy array of noise levels
41 |     """
42 |     sigmas = np.exp(
43 |         np.linspace(np.log(config.sde.sigma_max), np.log(config.sde.sigma_min), config.sde.num_scales))
44 | 
45 |     return sigmas
46 | 
47 | 
48 | def create_model(config):
49 |     """Create the score model."""
50 |     model_name = config.model.name
51 |     score_model = get_model(model_name)(config)
52 |     return score_model
53 | 
54 | 
55 | def get_model_fn(model, train=False):
56 |     """Create a function to give the output of the score-based model.
57 | 
58 |     Args:
59 |         model: The score model.
60 |         train: `True` for training and `False` for evaluation.
61 | 
62 |     Returns:
63 |         A model function.
64 |     """
65 | 
66 |     def model_fn(x, time_cond, class_labels=None):
67 |         """Compute the output of the score-based model.
68 | 
69 |         Args:
70 |             x: A mini-batch of input data.
71 |             time_cond: A mini-batch of conditioning variables for time steps. Should be interpreted differently
72 |                 for different models.
73 |             class_labels: An optional mini-batch of class labels for conditional models.
74 |         Returns:
75 |             The model output.
76 |         """
77 |         if train:
78 |             model.train()
79 |         else:
80 |             model.eval()
81 | 
82 |         return model(x, time_cond, class_labels=class_labels)
83 | 
84 |     return model_fn
85 | 
86 | 
87 | def get_score_fn(sde, model, train=False):
88 |     """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
89 | 
90 |     Args:
91 |         sde: An `sde_lib.SDE` object that represents the forward SDE.
92 |         model: A score model.
93 |         train: `True` for training and `False` for evaluation.
94 | 
95 |     Returns:
96 |         A score function.
97 |     """
98 |     model_fn = get_model_fn(model, train=train)
99 | 
100 |     def score_fn(x, t, class_labels=None):
101 |         time_cond = sde.marginal_prob(torch.zeros_like(x), t)[1]
102 |         score = model_fn(x, time_cond, class_labels=class_labels)
103 |         return score
104 | 
105 |     return score_fn
106 | 
107 | 
108 | def get_cf_score_fn(sde, model, class_labels, weight):
109 |     """Wraps `score_fn` with classifier-free guidance weighting.
110 | 
111 |     Args:
112 |         sde: A `sde_lib.SDE` object
113 |         model: the score model
114 |         class_labels, weight: the class labels and per-sample guidance weights to apply.
115 |     Returns:
116 |         A weighted score function that takes `(x, t)`;
        `class_labels` and `weight` are fixed by this call.
117 |     """
118 |     score_fn = get_score_fn(sde, model, train=False)
119 | 
120 |     def weighted_score_fn(x, t):
121 |         concat_x = x.repeat(2, 1, 1, 1)
122 |         concat_t = t.repeat(2)
123 |         concat_cl = torch.cat([class_labels, torch.zeros_like(class_labels)], dim=0)
124 | 
125 |         concat_score = score_fn(concat_x, concat_t, concat_cl)
126 |         score_conditioned = concat_score[:x.shape[0]]
127 |         score_clean = concat_score[x.shape[0]:]
128 | 
129 |         return (1 + weight)[:, None, None, None] * score_conditioned - weight[:, None, None, None] * score_clean
130 | 
131 |     return weighted_score_fn
132 | 
133 | 
134 | def to_flattened_numpy(x):
135 |     """Flatten a torch tensor `x` and convert it to numpy."""
136 |     return x.detach().cpu().numpy().reshape((-1,))
137 | 
138 | 
139 | def from_flattened_numpy(x, shape):
140 |     """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
141 |     return torch.from_numpy(x.reshape(shape))
142 | 
--------------------------------------------------------------------------------
/models/vdm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | from typing import Iterable
7 | 
8 | from . import utils
9 | from .layersv2 import *
10 | from math import pi
11 | 
12 | 
13 | class ImageFourierFeatures(nn.Module):
14 |     """Fourier features used in VDMs. Meant for usage in image space"""
15 |     def __init__(self, start=6, end=8):
16 |         super().__init__()
17 |         self.register_buffer("freqs", 2 ** torch.arange(start, end))
18 | 
19 |     def forward(self, x):
20 |         freqs = (self.freqs * 2 * pi).repeat(x.shape[1])
21 |         x_inp = x
22 |         x = x.repeat_interleave(len(self.freqs), dim=1)
23 | 
24 |         x = freqs[None, :, None, None] * x
25 |         return torch.cat([x_inp, x.sin(), x.cos()], dim=1)
26 | 
27 |     def extra_repr(self):
28 |         return f"ImageFourierFeatures({self.freqs.detach().cpu().numpy()})"
29 | 
30 | 
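# Editorial sketch (not part of the original file): the embedding below is the
# standard sinusoidal timestep embedding. For a single diffusion time in [0, 1]:
#
#     t = torch.tensor([0.5])                     # diffusion time in [0, 1]
#     emb = get_timestep_embedding(t.clone(), 8)  # shape (1, 8): 4 sin + 4 cos terms
#
# Note that `timesteps *= 1000.` mutates its input in place, hence the
# defensive `.clone()` in this sketch.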
31 | def get_timestep_embedding(timesteps, embedding_dim, dtype=torch.float32):
32 |     assert len(timesteps.shape) == 1
33 |     timesteps *= 1000.
34 | 
35 |     half_dim = embedding_dim // 2
36 |     emb = np.log(10000) / (half_dim - 1)
37 |     emb = (torch.arange(half_dim, dtype=dtype, device=timesteps.device) * -emb).exp()
38 |     emb = timesteps.to(dtype)[:, None] * emb[None, :]
39 |     emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
40 |     if embedding_dim % 2 == 1:  # zero pad
41 |         emb = F.pad(emb, (0, 1))
42 |     assert emb.shape == (timesteps.shape[0], embedding_dim)
43 |     return emb
44 | 
45 | 
46 | class ResNetBlock(nn.Module):
47 |     def __init__(self, in_ch, out_ch, cond_dim, dropout=0.1):
48 |         super().__init__()
49 |         self.conv1 = Conv2d(in_ch, out_ch, 3)
50 |         self.conv2 = Conv2d(out_ch, out_ch, 3, init_weight=0)
51 | 
52 |         self.norm1 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
53 |         self.norm2 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
54 |         if in_ch != out_ch:
55 |             self.skip = Conv2d(in_ch, out_ch, 1)
56 |         else:
57 |             self.skip = nn.Identity()
58 | 
59 |         self.cond_map = Linear(cond_dim, out_ch, bias=False, init_weight=0)
60 | 
61 |         self.dropout = dropout
62 | 
63 |     def forward(self, x, cond):
64 |         # pre-activation residual branch:
65 |         # norm -> SiLU -> conv, twice, with conditioning added in between
66 |         h = F.silu(self.norm1(x))
67 |         h = self.conv1(h)
68 | 
69 |         # add in conditioning
70 |         h += self.cond_map(cond)[:, :, None, None]
71 | 
72 |         h = F.silu(self.norm2(h))
73 |         h = F.dropout(h, p=self.dropout, training=self.training)
74 | 
75 |         h = self.conv2(h)
76 |         x = h + self.skip(x)
77 | 
78 |         return x
79 | 
80 | 
81 | class AttnBlock(nn.Module):
82 |     """Self-attention residual block."""
83 |     def __init__(self, channels, num_heads=1):
84 |         super().__init__()
85 |         self.num_heads = num_heads
86 |         self.norm = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
87 |         self.qkv = Conv2d(channels, 3 * channels, 1)
88 |         self.proj_out = Conv2d(channels, channels, 1, init_weight=0)
89 | 
90 |     def forward(self, x):
91 |         q, k, v = self.qkv(self.norm(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
92 |         w = AttentionOp.apply(q, k)
93 |         a = torch.einsum('nqk,nck->ncq', w, v)
94 |         x = self.proj_out(a.reshape(*x.shape)).add_(x)
95 | 
96 |         return x
97 | 
98 | 
99 | @utils.register_model(name='vdm')
100 | class VDM(nn.Module):
101 |     def __init__(self, config):
102 |         super().__init__()
103 |         self.num_blocks = config.model.num_blocks
104 |         self.channels = channels = config.model.channels
105 |         self.attention = config.model.attention
106 |         dropout = config.model.dropout
107 |         input_ch = config.data.num_channels
108 | 
109 |         self.sigma_min = config.sde.sigma_min
110 |         self.sigma_max = config.sde.sigma_max
111 |         self.scale_by_sigma = config.model.scale_by_sigma
112 | 
113 |         self.cond_map = nn.Sequential(
114 |             Linear(channels, 4 * channels),
115 |             nn.SiLU(),
116 |             Linear(4 * channels, 4 * channels),
117 |         )
118 | 
119 |         if config.model.image_fourier:
120 |             self.image_fourier = ImageFourierFeatures(start=config.model.image_fourier_start, end=config.model.image_fourier_end)
121 |             freqs = config.model.image_fourier_end - config.model.image_fourier_start
122 |             fourier_channels = (2 * freqs + 1) * input_ch
123 |         else:
124 |             self.image_fourier = nn.Identity()
125 |             fourier_channels = input_ch
126 | 
127 |         self.conv_in = Conv2d(fourier_channels, channels, 3)
128 | 
129 | 
130 |         # "downsampling"
131 |         enc = []
132 |         for _ in range(self.num_blocks):
133 |             enc.append(ResNetBlock(channels, channels, 4 * channels, dropout=dropout))
134 |             if self.attention:
135 |                 enc.append(AttnBlock(channels))
136 |         self.enc = nn.ModuleList(enc)
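        # Editorial note: with attention enabled, `enc` interleaves
        # [ResNetBlock, AttnBlock] pairs, so forward() indexes `enc[2 * i]` for
        # the i-th ResNet block and `enc[2 * i + 1]` for its attention block;
        # otherwise `enc[i]` is used directly. The decoder list below follows
        # the same layout.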
137 | 
138 |         # middle
139 |         self.mid1 = ResNetBlock(channels, channels, 4 * channels, dropout=dropout)
140 |         self.midattn = AttnBlock(channels)
141 |         self.mid2 = ResNetBlock(channels, channels, 4 * channels, dropout=dropout)
142 | 
143 |         # "upsampling"
144 |         dec = []
145 |         for _ in range(self.num_blocks + 1):
146 |             dec.append(ResNetBlock(2 * channels, channels, 4 * channels, dropout=dropout))
147 |             if self.attention:
148 |                 dec.append(AttnBlock(channels))
149 |         self.dec = nn.ModuleList(dec)
150 | 
151 |         # output
152 |         self.out = nn.Sequential(
153 |             nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6),
154 |             nn.SiLU(),
155 |             Conv2d(channels, input_ch, 3, init_weight=0)
156 |         )
157 | 
158 |     def forward(self, x, cond, class_labels=None):
159 |         sigma_inp = cond
160 |         t = (cond - self.sigma_min) / (self.sigma_max - self.sigma_min)
161 |         temb = get_timestep_embedding(t, self.channels)
162 |         cond = self.cond_map(temb)
163 | 
164 |         x = self.image_fourier(x)
165 | 
166 |         outputs = []
167 | 
168 |         x = self.conv_in(x)
169 |         outputs.append(x)
170 | 
171 |         for i in range(self.num_blocks):
172 |             if self.attention:
173 |                 x = self.enc[2 * i](x, cond)
174 |                 x = self.enc[2 * i + 1](x)
175 |             else:
176 |                 x = self.enc[i](x, cond)
177 |             outputs.append(x)
178 | 
179 |         x = self.mid1(x, cond)
180 |         x = self.midattn(x)
181 |         x = self.mid2(x, cond)
182 | 
183 |         for i in range(self.num_blocks + 1):
184 |             if self.attention:
185 |                 x = self.dec[2 * i](torch.cat((x, outputs.pop()), dim=1), cond)
186 |                 x = self.dec[2 * i + 1](x)
187 |             else:
188 |                 x = self.dec[i](torch.cat((x, outputs.pop()), dim=1), cond)
189 | 
190 |         if len(outputs) > 0:
191 |             raise ValueError("Something went wrong with the blocks")
192 | 
193 |         out = self.out(x)
194 | 
195 |         if self.scale_by_sigma:
196 |             out = out / sigma_inp[:, None, None, None]
197 | 
198 |         return out
199 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | antlr4-python3-runtime==4.9.3
2 | certifi==2023.5.7
3 | charset-normalizer==3.1.0
4 | cloudpickle==2.2.1
5 | cmake==3.26.3
6 | filelock==3.12.0
7 | hydra-core==1.3.2
8 | hydra-submitit-launcher==1.2.0
9 | idna==3.4
10 | Jinja2==3.1.2
11 | lit==16.0.5.post0
12 | MarkupSafe==2.1.3
13 | mpmath==1.3.0
14 | networkx==3.1
15 | numpy==1.24.3
16 | nvidia-cublas-cu11==11.10.3.66
17 | nvidia-cuda-cupti-cu11==11.7.101
18 | nvidia-cuda-nvrtc-cu11==11.7.99
19 | nvidia-cuda-runtime-cu11==11.7.99
20 | nvidia-cudnn-cu11==8.5.0.96
21 | nvidia-cufft-cu11==10.9.0.58
22 | nvidia-curand-cu11==10.2.10.91
23 | nvidia-cusolver-cu11==11.4.0.1
24 | nvidia-cusparse-cu11==11.7.4.91
25 | nvidia-nccl-cu11==2.14.3
26 | nvidia-nvtx-cu11==11.7.91
27 | omegaconf==2.3.0
28 | packaging==23.1
29 | Pillow==9.5.0
30 | PyYAML==6.0
31 | requests==2.31.0
32 | scipy==1.10.1
33 | submitit==1.4.5
34 | sympy==1.12
35 | torch==2.0.1
36 | torchaudio==2.0.2
37 | torchvision==0.15.2
38 | triton==2.0.0
39 | typing_extensions==4.6.3
40 | urllib3==2.0.2
--------------------------------------------------------------------------------
/run_train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import gc
3 | import os
4 | import os.path
5 | 
6 | import hydra
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 | from hydra.core.hydra_config import HydraConfig
12 | from hydra.types import RunMode
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torchvision.utils import make_grid, save_image
15 | 
16 | import datasets
17 | import losses
18 | import sampling
19 | import sde_lib
20 | import utils
21 | from models import adm, ncsnpp, vdm  # needed so the model classes register themselves
22 | from models import utils as mutils
23 | from models.ema import ExponentialMovingAverage
24 | 
25 | 
26 | torch.backends.cudnn.benchmark = True
27 | 
28 | 
29 | def setup(rank, world_size, port):
30 |     os.environ["MASTER_ADDR"] = "localhost"
31 |     os.environ["MASTER_PORT"] = str(port)
32 | 
33 |     # initialize the process group
34 |     dist.init_process_group(
35 |         "nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(minutes=30)
36 |     )
37 | 
38 | 
39 | def cleanup():
40 |     dist.destroy_process_group()
41 | 
42 | 
43 | def run_multiprocess(rank, world_size, cfg, work_dir, port):
44 |     try:
45 |         setup(rank, world_size, port)
46 |         _run(rank, world_size, work_dir, cfg)
47 |     finally:
48 |         cleanup()
49 | 
50 | 
51 | def _run(rank, world_size, work_dir, cfg):
52 | 
53 |     # Create directories for experimental logs
54 |     sample_dir = os.path.join(work_dir, "samples")
55 |     checkpoint_dir = os.path.join(work_dir, "checkpoints")
56 |     checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
57 |     if rank == 0:
58 |         utils.makedirs(sample_dir)
59 |         utils.makedirs(checkpoint_dir)
60 |         utils.makedirs(os.path.dirname(checkpoint_meta_dir))
61 | 
62 |     # logging
63 |     if rank == 0:
64 |         logger = utils.get_logger(os.path.join(work_dir, "logs"))
65 | 
66 |     def mprint(msg):
67 |         if rank == 0:
68 |             logger.info(msg)
69 | 
70 |     # construct models etc...
71 |     device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
72 | 
73 |     score_model = mutils.create_model(cfg).to(device)
74 |     score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)
75 |     if torch.__version__.startswith('1.14'):  # NOTE: matches only pre-release 2.0 nightlies; with the pinned torch 2.0.1 this branch never fires
76 |         score_model = torch.compile(score_model)
77 |     ema = ExponentialMovingAverage(
78 |         score_model.parameters(), decay=cfg.model.ema_rate)
79 |     scaler = torch.cuda.amp.GradScaler() if cfg.model.name == "adm" else None
80 |     optimizer = losses.get_optimizer(cfg, score_model.parameters())
81 | 
82 |     mprint(score_model)
83 |     mprint(f"EMA: {ema}")
84 |     mprint(f"Optimizer: {optimizer}")
85 |     mprint(f"Scaler: {scaler}.")
86 | 
87 |     state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, scaler=scaler)
88 | 
89 |     state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
90 |     initial_step = int(state['step'])
91 | 
92 |     # Build data iterators
93 |     train_ds, eval_ds = datasets.get_dataset(cfg)
94 | 
95 |     train_iter = iter(train_ds)
96 |     eval_iter = iter(eval_ds)
97 | 
98 |     sde = sde_lib.RVESDE(sigma_min=cfg.sde.sigma_min, sigma_max=cfg.sde.sigma_max, N=cfg.sde.num_scales)
99 |     sampling_eps = 1e-5
100 | 
101 |     # Build one-step training and evaluation functions
102 |     optimize_fn = losses.optimization_manager(cfg)
103 |     reduce_mean = cfg.training.reduce_mean
104 |     likelihood_weighting = cfg.training.likelihood_weighting
105 |     train_step_fn = losses.get_step_fn(sde,
106 |                                        train=True,
107 |                                        optimize_fn=optimize_fn,
108 |                                        reduce_mean=reduce_mean,
109 |                                        likelihood_weighting=likelihood_weighting)
110 |     eval_step_fn = losses.get_step_fn(sde,
111 |                                       train=False,
112 |                                       optimize_fn=optimize_fn,
113 |                                       reduce_mean=reduce_mean,
114 |                                       likelihood_weighting=likelihood_weighting)
115 | 
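    # Editorial note (hedged): each DDP rank draws and saves its own snapshot
    # samples below, so the sampling shape uses the per-GPU share of the global
    # batch, i.e. cfg.training.batch_size // cfg.ngpus images per rank.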
116 |     # Build sampling functions
117 |     if cfg.training.snapshot_sampling:
118 |         sampling_shape = (cfg.training.batch_size // cfg.ngpus,
119 |                           cfg.data.num_channels,
120 |                           cfg.data.image_size,
121 |                           cfg.data.image_size)
122 |         sampling_fn = sampling.get_sampling_fn(
123 |             cfg, sde, sampling_shape, sampling_eps, device)
124 | 
125 |     num_train_steps = cfg.training.n_iters
126 |     mprint(f"Starting training loop at step {initial_step}.")
127 | 
128 |     for step in range(initial_step, num_train_steps + 1):
129 |         # clear out memory
130 |         torch.cuda.empty_cache()
131 |         gc.collect()
132 | 
133 |         batch = next(train_iter)
134 |         batch_imgs = batch[0].to(device)
135 |         batch_class = batch[1].to(device) if cfg.data.classes else None
136 |         loss = train_step_fn(state, batch_imgs, class_labels=batch_class)
137 | 
138 |         if step % cfg.training.log_freq == 0:
139 |             mprint("step: %d, training_loss: %.5e" % (step, loss.item()))
140 | 
141 |         # save checkpoint periodically
142 |         if step != 0 and step % cfg.training.snapshot_freq_for_preemption == 0 and rank == 0:
143 |             utils.save_checkpoint(checkpoint_meta_dir, state)
144 | 
145 |         # print out eval loss
146 |         if step % cfg.training.eval_freq == 0:
147 |             eval_batch = next(eval_iter)
148 |             batch_imgs = eval_batch[0].to(device)
149 |             batch_class = eval_batch[1].to(device) if cfg.data.classes else None
150 |             eval_loss = eval_step_fn(state, batch_imgs, class_labels=batch_class)
151 |             mprint("step: %d, evaluation_loss: %.5e" % (step, eval_loss.item()))
152 | 
153 |         if (step != 0 and step % cfg.training.snapshot_freq == 0) or step == num_train_steps:
154 |             # Save the checkpoint.
155 |             save_step = step // cfg.training.snapshot_freq
156 |             if rank == 0:
157 |                 utils.save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)
158 | 
159 |             # Generate and save samples
160 |             if cfg.training.snapshot_sampling:
161 |                 mprint(f"Generating images at step: {step}")
162 | 
163 |                 if cfg.data.classes:
164 |                     weight = 4 * torch.rand(sampling_shape[0]).to(device)
165 |                     class_labels = torch.randint(0, cfg.data.num_classes, (sampling_shape[0],)).to(device)
166 |                 else:
167 |                     weight = None
168 |                     class_labels = None
169 | 
170 |                 ema.store(score_model.parameters())
171 |                 ema.copy_to(score_model.parameters())
172 |                 sample, n = sampling_fn(score_model, weight=weight, class_labels=class_labels)
173 |                 ema.restore(score_model.parameters())
174 | 
175 |                 this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
176 |                 utils.makedirs(this_sample_dir)
177 |                 nrow = int(np.sqrt(sample.shape[0]))
178 |                 image_grid = make_grid(sample, nrow, padding=2)
179 |                 sample = np.clip(np.round(sample.permute(0, 2, 3, 1).cpu().numpy() * 255), 0, 255).astype(np.uint8)
180 |                 np.save(os.path.join(this_sample_dir, f"sample_{rank}"), sample)
181 |                 save_image(image_grid, os.path.join(this_sample_dir, f"sample_{rank}.png"))
182 |                 dist.barrier()
183 | 
184 | 
185 | 
186 | @hydra.main(version_base=None, config_path="configs", config_name="train")
187 | def main(cfg):
188 |     hydra_cfg = HydraConfig.get()
189 |     work_dir = hydra_cfg.run.dir if hydra_cfg.mode == RunMode.RUN else os.path.join(hydra_cfg.sweep.dir, hydra_cfg.sweep.subdir)
190 |     utils.makedirs(work_dir)
191 | 
192 |     # Run the training pipeline
193 |     port = int(np.random.randint(10000, 20000))
194 |     logger = utils.get_logger(os.path.join(work_dir, "logs"))
195 | 
196 | 
197 |     if hydra_cfg.mode != RunMode.RUN:
198 |         logger.info(f"Run id: {hydra_cfg.job.id}")
199 | 
200 |     try:
201 |         mp.set_start_method("forkserver")
202 |         mp.spawn(run_multiprocess, args=(cfg.ngpus, cfg, work_dir, port), nprocs=cfg.ngpus, join=True)
203 |     except Exception as e:
204 |         logger.critical(e, exc_info=True)
205 | 
206 | 
207 | if __name__ == "__main__":
208 |     main()
--------------------------------------------------------------------------------
/run_vis.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import os.path
4 | 
5 | import numpy as np
6 | import torch
7 | import hydra
8 | from hydra.core.hydra_config import HydraConfig
9 | from hydra.types import RunMode
10 | from torchvision.utils import make_grid, save_image
11 | from omegaconf import open_dict
12 | 
13 | import losses
14 | import sampling
15 | import sde_lib
16 | import utils
17 | from models import utils as mutils
18 | from models import adm, ncsnpp, vdm  # needed for creating the model
19 | from models.ema import ExponentialMovingAverage
20 | 
21 | 
22 | torch.backends.cudnn.benchmark = True
23 | 
24 | 
25 | def visualize(cfg, load_cfg, noise_removal_cfg, log_dir):
26 |     # set up
27 |     logger = utils.get_logger(os.path.join(log_dir, "logs"))
28 |     work_dir = cfg.load_dir
29 | 
30 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31 | 
32 |     sde = sde_lib.RVESDE(sigma_min=load_cfg.sde.sigma_min, sigma_max=load_cfg.sde.sigma_max, N=load_cfg.sde.num_scales)
33 |     sampling_eps = 1e-5
34 | 
35 |     sampling_shape = (cfg.eval.batch_size, load_cfg.data.num_channels, load_cfg.data.image_size, load_cfg.data.image_size)
36 |     sampling_fn = sampling.get_sampling_fn(load_cfg, sde, sampling_shape, sampling_eps, device)
37 | 
38 |     # load in models
39 |     score_model = mutils.create_model(load_cfg).to(device)
40 |     ema = ExponentialMovingAverage(score_model.parameters(), decay=load_cfg.model.ema_rate)
41 |     optimizer = losses.get_optimizer(load_cfg, score_model.parameters())
42 |     scaler = torch.cuda.amp.GradScaler() if load_cfg.model.name == "adm" else None
43 |     state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, scaler=scaler)
44 | 
45 |     if noise_removal_cfg is not None:
46 |         noise_removal_model = mutils.create_model(noise_removal_cfg).to(device)
47 |         utils.load_denoising_model(os.path.join(cfg.denoiser_path, "checkpoints/checkpoint.pth"), noise_removal_model)
48 |     else:
49 |         noise_removal_model = None
50 | 
51 |     ckpt = cfg.eval.ckpt
52 |     if ckpt == -1:
53 |         ckpts = os.listdir(os.path.join(work_dir, "checkpoints"))
54 |         ckpts = [int(x.split(".")[0].split("_")[1]) for x in ckpts]
55 |         ckpt = max(ckpts)
56 | 
57 |     checkpoint_dir = os.path.join(work_dir, "checkpoints", f"checkpoint_{ckpt}.pth")
58 |     state = utils.restore_checkpoint(checkpoint_dir, state, device, ddp=False)
59 |     ema.copy_to(score_model.parameters())
60 | 
61 |     # generate images
62 |     this_sample_dir = os.path.join(log_dir, "images")
63 |     utils.makedirs(this_sample_dir)
64 | 
65 |     if load_cfg.model.name == "adm":
66 |         w = cfg.w * torch.ones(sampling_shape[0], device=device)
67 |         labels = cfg.label * torch.ones(sampling_shape[0], device=device).long()
68 |     else:
69 |         w = None
70 |         labels = None
71 | 
72 |     logger.info(f"Generating samples for checkpoint {ckpt}")
73 |     for r in range(cfg.eval.rounds):
74 |         logger.info(f"Round {r}")
75 |         samples = sampling_fn(score_model, noise_removal_model=noise_removal_model, weight=w, class_labels=labels)[0]
76 |         samples_np = np.round(samples.clip(min=0, max=1).permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
77 | 
78 |         nrow = int(np.sqrt(samples.shape[0]))
79 |         image_grid = make_grid(samples, nrow, padding=0)
80 |         save_image(image_grid, os.path.join(this_sample_dir, f"samples_{r}.png"))
81 | 
82 |         with open(os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
83 |             io_buffer = io.BytesIO()
84 |             np.savez_compressed(io_buffer, samples=samples_np)
85 |             fout.write(io_buffer.getvalue())
86 | 
87 |     logger.info("Finished generating samples.")
88 | 
89 | 
90 | 
91 | @hydra.main(version_base=None, config_path="configs", config_name="vis")
92 | def main(cfg):
93 |     hydra_cfg = HydraConfig.get()
94 |     load_cfg = utils.load_hydra_config_from_run(cfg.load_dir)
95 | 
96 |     log_dir = hydra_cfg.run.dir if hydra_cfg.mode == RunMode.RUN else os.path.join(hydra_cfg.sweep.dir, hydra_cfg.sweep.subdir)
97 |     utils.makedirs(log_dir)
98 | 
99 |     # overwrite the sampling instructions
100 |     with open_dict(load_cfg):
101 |         load_cfg.sampling = cfg.sampling
102 | 
103 |     if cfg.sampling.denoiser == "network":
104 |         noise_removal_cfg = utils.load_hydra_config_from_run(cfg.denoiser_path)
105 |     else:
106 |         noise_removal_cfg = None
107 | 
108 |     logger = utils.get_logger(os.path.join(log_dir, "logs"))
109 |     logger.info(cfg)
110 |     logger.info(f"loaded in config from {cfg.load_dir}")
111 |     logger.info(load_cfg)
112 |     logger.info("Denoising config:")
113 |     logger.info(noise_removal_cfg)
114 | 
115 |     try:
116 |         visualize(cfg, load_cfg, noise_removal_cfg, log_dir)
117 |     except Exception as e:
118 |         logger.critical(e, exc_info=True)
119 | 
120 | if __name__ == "__main__":
121 |     main()
--------------------------------------------------------------------------------
/sampling.py:
--------------------------------------------------------------------------------
1 | """Various sampling methods."""
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import abc
7 | 
8 | from models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
9 | from scipy import integrate
10 | from models import utils as mutils
11 | import cube
12 | 
13 | _CORRECTORS = {}
14 | _PREDICTORS = {}
15 | _DENOISERS = {}
16 | 
17 | 
18 | def register_predictor(cls=None, *, name=None):
19 |     """A decorator for registering predictor classes."""
20 | 
21 |     def _register(cls):
22 |         if name is None:
23 |             local_name = cls.__name__
24 |         else:
25 |             local_name = name
26 |         if local_name in _PREDICTORS:
27 |             raise ValueError(
28 |                 f'Already registered predictor with name: {local_name}')
29 |         _PREDICTORS[local_name] = cls
30 |         return cls
31 | 
32 |     if cls is None:
33 |         return _register
34 |     else:
35 |         return _register(cls)
36 | 
37 | 
38 | def register_corrector(cls=None, *, name=None):
39 |     """A decorator for registering corrector classes."""
40 | 
41 |     def _register(cls):
42 |         if name is None:
43 |             local_name = cls.__name__
44 |         else:
45 |             local_name = name
46 |         if local_name in _CORRECTORS:
47 |             raise ValueError(
48 |                 f'Already registered corrector with name: {local_name}')
49 |         _CORRECTORS[local_name] = cls
50 |         return cls
51 | 
52 |     if cls is None:
53 |         return _register
54 |     else:
55 |         return _register(cls)
56 | 
57 | def register_denoiser(cls=None, *, name=None):
58 |     """A decorator for registering denoiser classes."""
59 |     def _register(cls):
60 |         if name is None:
61 |             local_name = cls.__name__
62 |         else:
63 |             local_name = name
64 |         if local_name in _DENOISERS:
65 |             raise ValueError(
66 |                 f'Already registered denoiser with name: {local_name}')
67 |         _DENOISERS[local_name] = cls
68 |         return cls
69 | 
70 |     if cls is None:
71 |         return _register
72 |     else:
73 |         return _register(cls)
74 | 
75 | 
76 | def get_predictor(name):
77 |     return _PREDICTORS[name]
78 | 
79 | 
80 | def get_corrector(name):
81 |     return _CORRECTORS[name]
82 | 
83 | def get_denoiser(name):
84 |     return _DENOISERS[name]
85 | 
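# Editorial sketch (not part of the original file): the registries above are
# looked up by the lowercased names from the sampling config, e.g.
#
#     predictor = get_predictor('euler_maruyama')  # ReflectedEulerMaruyamaPredictor
#     corrector = get_corrector('langevin')        # ReflectedLangevinCorrector
#     denoiser  = get_denoiser('mean')             # MeanDenoiser
#
# matching the names the classes below register under.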
86 | 
87 | def get_sampling_fn(config, sde, shape, eps, device):
88 |     """Create a sampling function.
89 | 
90 |     Args:
91 |         config: A config object (hydra/OmegaConf) that contains all configuration information.
92 |         sde: A `sde_lib.SDE` object that represents the forward SDE.
93 |         shape: A sequence of integers representing the expected shape of a single sample.
94 |         eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
95 |         device: The torch device on which sampling is run.
96 | 
97 |     Returns:
98 |         A sampling function that takes a score model and outputs samples
99 |         with the trailing dimensions matching `shape`.
100 |     """
101 | 
102 |     sampler_name = config.sampling.method
103 |     # Probability flow ODE sampling with black-box ODE solvers
104 |     if sampler_name.lower() == 'ode':
105 |         denoiser = get_denoiser(config.sampling.denoiser.lower())
106 |         sampling_fn = get_ode_sampler(sde=sde,
107 |                                       shape=shape,
108 |                                       eps=eps,
109 |                                       moll=config.sampling.moll,
110 |                                       side_eps=config.sampling.side_eps,
111 |                                       device=device)
112 | 
113 |     # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
114 |     elif sampler_name.lower() == 'pc':
115 |         predictor = get_predictor(config.sampling.predictor.lower())
116 |         corrector = get_corrector(config.sampling.corrector.lower())
117 |         denoiser = get_denoiser(config.sampling.denoiser.lower())
118 |         sampling_fn = get_pc_sampler(sde=sde,
119 |                                      shape=shape,
120 |                                      predictor=predictor,
121 |                                      corrector=corrector,
122 |                                      denoiser=denoiser,
123 |                                      snr=config.sampling.snr,
124 |                                      n_steps=config.sampling.n_steps_each,
125 |                                      eps=eps,
126 |                                      device=device)
127 |     else:
128 |         raise ValueError(f"Sampler name {sampler_name} unknown.")
129 | 
130 |     return sampling_fn
131 | 
132 | 
133 | class Predictor(abc.ABC):
134 |     """The abstract class for a predictor algorithm."""
135 | 
136 |     def __init__(self, sde, score_fn, probability_flow=False):
137 |         super().__init__()
138 |         self.sde = sde
139 |         # Compute the reverse SDE/ODE
140 |         self.rsde = sde.reverse(score_fn, probability_flow)
141 |         self.score_fn = score_fn
142 | 
143 |     @abc.abstractmethod
144 |     def update_fn(self, x, t):
145 |         """One update of the predictor.
146 | 
147 |         Args:
148 |             x: A PyTorch tensor representing the current state
149 |             t: A PyTorch tensor representing the current time step.
150 | 
151 |         Returns:
152 |             x: A PyTorch tensor of the next state.
153 |             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
154 |         """
155 |         pass
156 | 
157 | 
158 | class Corrector(abc.ABC):
159 |     """The abstract class for a corrector algorithm."""
160 | 
161 |     def __init__(self, sde, score_fn, snr, n_steps):
162 |         super().__init__()
163 |         self.sde = sde
164 |         self.score_fn = score_fn
165 |         self.snr = snr
166 |         self.n_steps = n_steps
167 | 
168 |     @abc.abstractmethod
169 |     def update_fn(self, x, t):
170 |         """One update of the corrector.
171 | 
172 |         Args:
173 |             x: A PyTorch tensor representing the current state
174 |             t: A PyTorch tensor representing the current time step.
175 | 
176 |         Returns:
177 |             x: A PyTorch tensor of the next state.
178 |             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
179 | """ 180 | pass 181 | 182 | class Denoiser(abc.ABC): 183 | """The abstract class for a denoiser""" 184 | def __init__(self, denoiser): 185 | super().__init__() 186 | self.denoiser = denoiser 187 | 188 | @abc.abstractmethod 189 | def update_fn(self, x, x_mean, t): 190 | pass 191 | 192 | 193 | @register_predictor(name='euler_maruyama') 194 | class ReflectedEulerMaruyamaPredictor(Predictor): 195 | def __init__(self, sde, score_fn, probability_flow=False): 196 | super().__init__(sde, score_fn, probability_flow) 197 | 198 | def update_fn(self, x, t): 199 | dt = -1. / self.rsde.N 200 | z = torch.randn_like(x) 201 | drift, diffusion = self.rsde.sde(x, t) 202 | x_mean = x + drift * dt 203 | x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z 204 | 205 | x, x_mean = cube.reflect(x), cube.reflect(x_mean) 206 | 207 | return x, x_mean 208 | 209 | 210 | @register_corrector(name='langevin') 211 | class ReflectedLangevinCorrector(Corrector): 212 | def __init__(self, sde, score_fn, snr, n_steps): 213 | super().__init__(sde, score_fn, snr, n_steps) 214 | 215 | def update_fn(self, x, t): 216 | sde = self.sde 217 | score_fn = self.score_fn 218 | n_steps = self.n_steps 219 | target_snr = self.snr 220 | alpha = torch.ones_like(t) 221 | 222 | for i in range(n_steps): 223 | grad = score_fn(x, t) 224 | noise = torch.randn_like(x) 225 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() 226 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 227 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 228 | x_mean = x + step_size[:, None, None, None] * grad 229 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 230 | 231 | x, x_mean = cube.reflect(x), cube.reflect(x_mean) 232 | 233 | return x, x_mean 234 | 235 | 236 | @register_corrector(name='none') 237 | class NoneCorrector(Corrector): 238 | """An empty corrector that does nothing.""" 239 | 240 | def update_fn(self, x, t): 241 | return x, x 242 | 243 | 244 | @register_denoiser(name='network') 245 | class TrainedDenoiser(Denoiser): 246 | """Apply network to denoise input""" 247 | def update_fn(self, x, x_mean, t): 248 | return (x - self.denoiser(x, t)).clamp(min=0, max=1) 249 | 250 | 251 | @register_denoiser(name="mean") 252 | class MeanDenoiser(Denoiser): 253 | def update_fn(self, x, x_mean, t): 254 | return x_mean 255 | 256 | 257 | @register_denoiser(name="none") 258 | class NoneDenoiser(Denoiser): 259 | def update_fn(self, x, x_mean, t): 260 | return x 261 | 262 | 263 | def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow): 264 | """A wrapper that configures and returns the update function of predictors.""" 265 | score_fn = mutils.get_score_fn(sde, model, train=False) 266 | if predictor is None: 267 | predictor_obj = NonePredictor(sde, score_fn, probability_flow) 268 | else: 269 | predictor_obj = predictor(sde, score_fn, probability_flow) 270 | return predictor_obj.update_fn(x, t) 271 | 272 | 273 | def shared_corrector_update_fn(x, t, sde, model, corrector, snr, n_steps): 274 | """A wrapper tha configures and returns the update function of correctors.""" 275 | score_fn = mutils.get_score_fn(sde, model, train=False) 276 | if corrector is None: 277 | # Predictor-only sampler 278 | corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) 279 | else: 280 | corrector_obj = corrector(sde, score_fn, snr, n_steps) 281 | return corrector_obj.update_fn(x, t) 282 | 283 | 284 | def shared_denoiser_update_fn(x, x_mean, t, denoiser, denoise_model): 285 | if 
284 | def shared_denoiser_update_fn(x, x_mean, t, denoiser, denoise_model):
285 |     if denoiser is None:
286 |         denoiser_obj = NoneDenoiser(denoise_model)
287 |     else:
288 |         denoiser_obj = denoiser(denoise_model)
289 |     return denoiser_obj.update_fn(x, x_mean, t)
290 | 
291 | 
292 | def get_pc_sampler(sde, shape, predictor, corrector, denoiser, snr,
293 |                    n_steps=1, eps=1e-3, device='cuda'):
294 |     """Create a Predictor-Corrector (PC) sampler."""
295 |     def pc_sampler(model, z=None, noise_removal_model=None, weight=0, class_labels=None):
296 |         """The PC sampler function.
297 | 
298 |         Args:
299 |             model: A score model.
300 |             z: If present, start sampling from `z` instead of a uniform draw.
301 |             noise_removal_model: A noise removal model (if used).
302 |             weight: Weight used for CF guidance.
303 |             class_labels: Class labels used for CF guidance.
304 |         Returns:
305 |             Samples, number of function evaluations.
306 |         """
307 |         # Initial sample
308 |         if z is None:
309 |             x = torch.rand(shape).to(device)
310 |         else:
311 |             x = z
312 | 
313 |         if class_labels is None:
314 |             score_fn = mutils.get_score_fn(sde, model, train=False)
315 |         else:
316 |             score_fn = mutils.get_cf_score_fn(sde, model, class_labels, weight)
317 | 
318 |         # Create update functions
319 |         pred = predictor(sde, score_fn)
320 |         corr = corrector(sde, score_fn, snr, n_steps)
321 |         deno = denoiser(noise_removal_model)
322 | 
323 |         with torch.no_grad():
324 |             timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
325 | 
326 |             for i in range(sde.N):
327 |                 t = timesteps[i]
328 |                 vec_t = torch.ones(shape[0], device=t.device) * t
329 |                 if i < sde.N - 1:
330 |                     x, _ = corr.update_fn(x, vec_t)
331 |                 x, x_mean = pred.update_fn(x, vec_t)
332 | 
333 |             vec_t = torch.ones(shape[0], device=t.device) * eps
334 |             x = deno.update_fn(x, x_mean, vec_t)
335 | 
336 |         return x, sde.N * (n_steps + 1)
337 | 
338 |     return pc_sampler
339 | 
340 | 
341 | def get_ode_sampler(sde, shape, rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3, moll=200, side_eps=1e-2, device='cuda'):
342 |     """Probability flow ODE sampler with the black-box ODE solver."""
343 | 
344 |     def drift_fn(score_fn, x, t):
345 |         """Get the drift function of the reverse-time SDE."""
346 |         rsde = sde.reverse(score_fn, probability_flow=True)
347 |         return rsde.sde(x, t)[0]
348 | 
349 |     def ode_sampler(model, z=None, noise_removal_model=None, weight=0, class_labels=None):
350 |         """The probability flow ODE sampler with black-box ODE solver.
351 | 
352 |         Args:
353 |             model: A score model.
354 |             z: If present, generate samples from latent code `z`.
355 |         Returns:
356 |             samples, number of function evaluations.
357 |         """
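        # Editorial note: the drift below is multiplied by a "bump" mollifier
        # that decays smoothly to zero at the faces of the unit cube, which
        # keeps the black-box solver's iterates inside [0, 1]; larger `moll`
        # values weaken the taper.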
358 | """ 359 | with torch.no_grad(): 360 | # Initial sample 361 | if z is None: 362 | x = (1 - 2 * side_eps) * torch.rand(shape).to(device) + side_eps 363 | else: 364 | x = z 365 | 366 | if class_labels is None: 367 | score_fn = mutils.get_score_fn(sde, model, train=False) 368 | else: 369 | score_fn = mutils.get_cf_score_fn(sde, model, class_labels, weight) 370 | 371 | def bump(x): 372 | if moll > 0: 373 | return ((- 1/ (0.5 ** 2 - (0.5 - x).pow(2)) + 4) / moll).exp() 374 | else: 375 | return x 376 | 377 | def ode_func(t, x): 378 | x = from_flattened_numpy(x, shape).to(device).type(torch.float32) 379 | vec_t = torch.ones(shape[0], device=x.device) * t 380 | drift = drift_fn(score_fn, x, vec_t) * bump(x) 381 | return to_flattened_numpy(drift) 382 | 383 | solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x), 384 | rtol=rtol, atol=atol, method=method) 385 | nfe = solution.nfev 386 | x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32) 387 | 388 | vec_t = torch.ones(shape[0], device=x.device) * eps 389 | 390 | return x, nfe 391 | 392 | return ode_sampler 393 | -------------------------------------------------------------------------------- /sde_lib.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | import abc 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class SDE(abc.ABC): 8 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 9 | 10 | def __init__(self, N): 11 | """Construct an SDE. 12 | 13 | Args: 14 | N: number of discretization time steps. 15 | """ 16 | super().__init__() 17 | self.N = N 18 | 19 | @property 20 | @abc.abstractmethod 21 | def T(self): 22 | """End time of the SDE.""" 23 | pass 24 | 25 | @abc.abstractmethod 26 | def sde(self, x, t): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def marginal_prob(self, x, t): 31 | """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" 32 | pass 33 | 34 | @abc.abstractmethod 35 | def prior_sampling(self, shape): 36 | """Generate one sample from the prior distribution, $p_T(x)$.""" 37 | pass 38 | 39 | @abc.abstractmethod 40 | def prior_logp(self, z): 41 | """Compute log-density of the prior distribution. 42 | 43 | Useful for computing the log-likelihood via probability flow ODE. 44 | 45 | Args: 46 | z: latent code 47 | Returns: 48 | log probability density 49 | """ 50 | pass 51 | 52 | def discretize(self, x, t): 53 | """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. 54 | 55 | Useful for reverse diffusion sampling and probabiliy flow sampling. 56 | Defaults to Euler-Maruyama discretization. 57 | 58 | Args: 59 | x: a torch tensor 60 | t: a torch float representing the time step (from 0 to `self.T`) 61 | 62 | Returns: 63 | f, G 64 | """ 65 | dt = 1 / self.N 66 | drift, diffusion = self.sde(x, t) 67 | f = drift * dt 68 | G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) 69 | return f, G 70 | 71 | def reverse(self, score_fn, probability_flow=False): 72 | """Create the reverse-time SDE/ODE. 73 | 74 | Args: 75 | score_fn: A time-dependent score-based model that takes x and t and returns the score. 76 | probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 77 | """ 78 | N = self.N 79 | T = self.T 80 | sde_fn = self.sde 81 | discretize_fn = self.discretize 82 | 83 | # Build the class for reverse-time SDE. 
84 |         class RSDE(self.__class__):
85 |             def __init__(self):
86 |                 self.N = N
87 |                 self.probability_flow = probability_flow
88 | 
89 |             @property
90 |             def T(self):
91 |                 return T
92 | 
93 |             def sde(self, x, t):
94 |                 """Create the drift and diffusion functions for the reverse SDE/ODE."""
95 |                 drift, diffusion = sde_fn(x, t)
96 |                 score = score_fn(x, t)
97 |                 drift = drift - diffusion[:, None, None, None] ** 2 * \
98 |                     score * (0.5 if self.probability_flow else 1.)
99 |                 # Set the diffusion function to zero for ODEs.
100 |                 diffusion = torch.zeros_like(diffusion) if self.probability_flow else diffusion
101 |                 return drift, diffusion
102 | 
103 |             def discretize(self, x, t):
104 |                 """Create discretized iteration rules for the reverse diffusion sampler."""
105 |                 f, G = discretize_fn(x, t)
106 |                 rev_f = f - G[:, None, None, None] ** 2 * \
107 |                     score_fn(x, t) * (0.5 if self.probability_flow else 1.)
108 |                 rev_G = torch.zeros_like(G) if self.probability_flow else G
109 |                 return rev_f, rev_G
110 | 
111 |         return RSDE()
112 | 
113 | 
114 | class RVESDE(SDE):
115 |     def __init__(self, sigma_min=0.01, sigma_max=50, N=1000, T=1):
116 |         """Construct a reflected Variance Exploding SDE.
117 | 
118 |         Args:
119 |             sigma_min: smallest sigma.
120 |             sigma_max: largest sigma.
121 |             N: number of discretization steps; T: final time.
122 |         """
123 |         super().__init__(N)
124 |         self.sigma_min = sigma_min
125 |         self.sigma_max = sigma_max
126 |         self.discrete_sigmas = torch.exp(torch.linspace(
127 |             np.log(self.sigma_min), np.log(self.sigma_max), N))
128 |         self.N = N
129 |         self.T_val = T
130 | 
131 |     @property
132 |     def T(self):
133 |         return self.T_val
134 | 
135 |     def sde(self, x, t):
136 |         sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
137 |         drift = torch.zeros_like(x)
138 |         diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
139 |                                                     device=t.device))
140 |         return drift, diffusion
141 | 
142 |     def marginal_prob(self, x, t):
143 |         std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
144 |         mean = x
145 |         return mean, std
146 | 
147 |     def prior_sampling(self, shape):
148 |         return torch.rand(*shape)
149 | 
150 |     def prior_logp(self, z):
151 |         return torch.zeros_like(z)
152 | 
153 |     def discretize(self, x, t):
154 |         """SMLD(NCSN) discretization."""
155 |         timestep = (t * (self.N - 1) / self.T).long()
156 |         sigma = self.discrete_sigmas.to(t.device)[timestep]
157 |         adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
158 |                                      self.discrete_sigmas.to(t.device)[timestep - 1])
159 |         f = torch.zeros_like(x)
160 |         G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
161 |         return f, G
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import logging
4 | from omegaconf import OmegaConf
5 | 
6 | 
7 | def load_hydra_config_from_run(load_dir):
8 |     cfg_path = os.path.join(load_dir, ".hydra/config.yaml")
9 |     cfg = OmegaConf.load(cfg_path)
10 |     return cfg
11 | 
12 | 
13 | def makedirs(dirname):
14 |     os.makedirs(dirname, exist_ok=True)
15 | 
16 | 
17 | def get_logger(logpath, package_files=[], displaying=True, saving=True, debug=False):
18 |     logger = logging.getLogger()
19 |     if debug:
20 |         level = logging.DEBUG
21 |     else:
22 |         level = logging.INFO
23 | 
24 |     if (logger.hasHandlers()):
25 |         logger.handlers.clear()
26 | 
27 |     logger.setLevel(level)
28 |     formatter = logging.Formatter('%(asctime)s - %(message)s')
29 |     if saving:
logging.FileHandler(logpath, mode="a") 31 | info_file_handler.setLevel(level) 32 | info_file_handler.setFormatter(formatter) 33 | logger.addHandler(info_file_handler) 34 | if displaying: 35 | console_handler = logging.StreamHandler() 36 | console_handler.setLevel(level) 37 | console_handler.setFormatter(formatter) 38 | logger.addHandler(console_handler) 39 | 40 | for f in package_files: 41 | logger.info(f) 42 | with open(f, "r") as package_f: 43 | logger.info(package_f.read()) 44 | 45 | return logger 46 | 47 | 48 | def restore_checkpoint(ckpt_dir, state, device, ddp=True): 49 | if not os.path.exists(ckpt_dir): 50 | makedirs(os.path.dirname(ckpt_dir)) 51 | logging.warning(f"No checkpoint found at {ckpt_dir}. " 52 | f"Returned the same state as input") 53 | return state 54 | else: 55 | loaded_state = torch.load(ckpt_dir, map_location=device) 56 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 57 | if ddp: 58 | state['model'].module.load_state_dict(loaded_state['model'], strict=False) 59 | else: 60 | state['model'].load_state_dict(loaded_state['model'], strict=False) 61 | state['ema'].load_state_dict(loaded_state['ema']) 62 | state['step'] = loaded_state['step'] 63 | if state['scaler'] is not None: 64 | state['scaler'].load_state_dict(loaded_state['scaler']) 65 | return state 66 | 67 | 68 | def load_denoising_model(ckpt_dir, model, device=torch.device('cpu')): 69 | if not os.path.exists(ckpt_dir): 70 | raise ValueError(f"No checkpoint found at {ckpt_dir}.") 71 | loaded_state = torch.load(ckpt_dir, map_location=device) 72 | model.load_state_dict(loaded_state['model'], strict=False) 73 | return model 74 | 75 | 76 | def save_checkpoint(ckpt_dir, state): 77 | saved_state = { 78 | 'optimizer': state['optimizer'].state_dict(), 79 | 'model': state['model'].module.state_dict(), 80 | 'ema': state['ema'].state_dict(), 81 | 'step': state['step'], 82 | 'scaler': state['scaler'].state_dict() if state['scaler'] else None 83 | } 84 | torch.save(saved_state, ckpt_dir) 85 | 86 | --------------------------------------------------------------------------------