├── .gitignore ├── LICENSE ├── README.md ├── accelerate_config.yaml ├── assets ├── fourier_demo_1d.gif ├── mandrill_fourier.gif ├── sample-ema-2M-1000.png └── training-curve.png ├── eval.py ├── requirements.txt ├── train.py ├── utils.py ├── vdm.py └── vdm_unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # Other 133 | .DS_Store 134 | .idea 135 | data/ 136 | results/ 137 | plots/ 138 | Makefile 139 | 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrea Dittadi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Diffusion Models (VDM) 2 | 3 | This is a PyTorch implementation of [Variational Diffusion Models](https://arxiv.org/abs/2107.00630), 4 | where the focus is on optimizing *likelihood rather than sample quality*, 5 | in the spirit of *probabilistic* generative modeling. 6 | 7 | This implementation should match the 8 | [official one](https://github.com/google-research/vdm) in JAX. 9 | However, the purpose is mainly educational and the focus is on simplicity. 10 | So far, the repo only includes CIFAR10, and variance minimization 11 | with the $\gamma_{\eta}$ network (see Appendix `I.2` in the paper) is not 12 | implemented (it's only used for CIFAR10 *with augmentations* and, according 13 | to the paper, it does not have a significant impact). 14 | 15 | 16 | ## Results 17 | 18 | The samples below are from a model trained on CIFAR10 for 2M steps with gradient clipping and with a fixed noise 19 | schedule such that $\log \mathrm{SNR}(t)$ is linear, with $\log \mathrm{SNR}(0) = 13.3$ and $\log \mathrm{SNR}(1) = -5$. 20 | These samples are generated from the EMA model in 1000 denoising steps. 21 | 22 |

23 | ![Random samples from a model trained on CIFAR10 for 2M steps](assets/sample-ema-2M-1000.png) 24 |

25 | 26 | Without gradient clipping (as in the paper), the test set variational lower bound (VLB) is 2.715 bpd after 2M steps 27 | (the paper reports 2.65 after 10M steps). 28 | However, training is somewhat unstable and requires some care: the model tends to overfit, 29 | and the train-test gap is rather large. 30 | With gradient clipping, the test set VLB is slightly worse, but training seems better behaved. 31 | 32 |

33 | ![Training curves](assets/training-curve.png) 34 |
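For reference, the noise-schedule endpoints quoted above map directly onto the $\gamma_t$ convention used in the overview below ($\gamma_t$ is the negative log SNR, so $\log \mathrm{SNR}(0) = 13.3$ and $\log \mathrm{SNR}(1) = -5$ correspond to the `--gamma-min`/`--gamma-max` defaults of $-13.3$ and $5.0$ in `train.py`). Below is a minimal standalone sketch of the fixed linear schedule, mirroring `FixedLinearSchedule` in `vdm.py` (an illustration, not the repo's code):

```python
import torch

# Fixed linear noise schedule: gamma(t) = gamma_min + (gamma_max - gamma_min) * t,
# with gamma_t = -log SNR(t). Endpoint values are the defaults from train.py.
gamma_min, gamma_max = -13.3, 5.0

def gamma(t: torch.Tensor) -> torch.Tensor:
    return gamma_min + (gamma_max - gamma_min) * t

for t in [0.0, 0.5, 1.0]:
    g = gamma(torch.tensor(t))
    log_snr = -g                    # log SNR(0) = 13.3, log SNR(1) = -5
    sigma_sq = torch.sigmoid(g)     # noise variance sigma_t^2
    alpha_sq = torch.sigmoid(-g)    # signal coefficient alpha_t^2 = 1 - sigma_t^2
    print(f"t={t:.1f}  log SNR={log_snr.item():6.2f}  "
          f"alpha^2={alpha_sq.item():.6f}  sigma^2={sigma_sq.item():.6f}")
```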

35 | 36 | 37 | ## Overview of the model 38 | 39 | ### Diffusion process 40 | 41 | Let $\mathbf{x}$ be a data point, $\mathbf{z}_t$ the latent variable at time $t \in [0,1]$, and 42 | 43 | $$\sigma^2_t = \mathrm{sigmoid}(\gamma_t)$$ 44 | 45 | $$\alpha^2_t = 1 - \sigma^2_t = \mathrm{sigmoid}(-\gamma_t)$$ 46 | 47 | with $\gamma_t$ the negative log SNR at time $t$. 48 | Then the forward diffusion process is: 49 | 50 | $$q\left(\mathbf{z}_t \mid \mathbf{x}\right)=\mathcal{N}\left(\alpha_t \mathbf{x}, \sigma_t^2 \mathbf{I}\right)$$ 51 | 52 | 53 | ### Reverse generative process 54 | 55 | In discrete time, the generative (denoising) process in $T$ steps is 56 | 57 | $$p(\mathbf{x})=\int_{\mathbf{z}} p\left(\mathbf{z}\_1\right) p\left(\mathbf{x} \mid \mathbf{z}\_0\right) \prod_{i=1}^T p\left(\mathbf{z}\_{s(i)} \mid \mathbf{z}_{t(i)}\right)$$ 58 | 59 | $$p(\mathbf{z}_1) = \mathcal{N}(\mathbf{0}, \mathbf{I})$$ 60 | 61 | $$p(\mathbf{x} \mid \mathbf{z}\_0) = \prod_{i=1}^N p(x_i \mid z_{0,i})$$ 62 | 63 | $$p(x_i \mid z_{0,i}) \propto q(z_{0,i} \mid x_i)$$ 64 | 65 | where $s(i) = \frac{i-1}{T}$ and $t(i) = \frac{i}{T}$. 66 | We then choose the one-step denoising distribution to be equal to the 67 | true denoising distribution given the data (which is available in 68 | closed form) except that we substitute the unavailable data 69 | with a prediction of the clean data at the previous time step: 70 | 71 | $$p\left(\mathbf{z}\_s \mid \mathbf{z}\_t\right)=q\left(\mathbf{z}\_s \mid \mathbf{z}\_t, \mathbf{x}=\hat{\mathbf{x}}\_\theta\left(\mathbf{z}\_t ; t\right)\right)$$ 72 | 73 | where $\hat{\mathbf{x}}_\theta$ is a denoising model with parameters $\theta$. 74 | 75 | 76 | ### Optimization in continuous time 77 | 78 | The loss function is given by the usual variational lower bound: 79 | 80 | $$-\log p(\mathbf{x}) \leq-\text{VLB}(\mathbf{x})=D\_{KL}\left(q\left(\mathbf{z}\_1 \mid \mathbf{x}\right)\ ||\ p\left(\mathbf{z}\_1\right)\right)+\mathbb{E}\_{q\left(\mathbf{z}\_0 \mid \mathbf{x}\right)}\left[-\log p\left(\mathbf{x} \mid \mathbf{z}\_0\right)\right]+\mathcal{L}\_T(\mathbf{x})$$ 81 | 82 | where the diffusion loss $\mathcal{L}_T(\mathbf{x})$ is 83 | 84 | $$\mathcal{L}\_T (\mathbf{x}) = \sum_{i=1}^T \mathbb{E}\_{q \left(\mathbf{z}\_t \mid \mathbf{x}\right)} D\_{KL}\left[q\left(\mathbf{z}\_s \mid \mathbf{z}\_t, \mathbf{x}\right)\ ||\ p\left(\mathbf{z}\_s \mid \mathbf{z}\_t \right)\right]$$ 85 | 86 | Long story short, using the classic noise-prediction parameterization of the denoising model: 87 | 88 | $$\hat{\mathbf{x}}\_\theta\left(\mathbf{z}\_t ; t\right) = \frac{\mathbf{z}\_t-\sigma\_t \hat{\boldsymbol{\epsilon}}\_\theta\left(\mathbf{z}\_t ; t\right)}{\alpha\_t}$$ 89 | 90 | and considering the continuous-time limit ($T \to \infty$), 91 | the diffusion loss simplifies to: 92 | 93 | $$\mathcal{L}\_{\infty}(\mathbf{x})=\frac{1}{2} \mathbb{E}\_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I}), t \sim \mathcal{U}(0,1)}\left[ \frac{d\gamma\_t}{dt} \ \|\| \boldsymbol{\epsilon}-\hat{\boldsymbol{\epsilon}}\_{\boldsymbol{\theta}}\left(\mathbf{z}\_t ; t\right) \|\|\_2^2\right]$$ 94 | 95 | 96 | 97 | 98 | ### Fourier features 99 | 100 | One of the key components to reach SOTA likelihood is the 101 | concatenation of Fourier features to $\mathbf{z}_t$ before feeding it into the 102 | UNet. 
For each element $z_t^i$ of $\mathbf{z}_t$ (e.g., one channel of 103 | a specific pixel), we concatenate: 104 | 105 | $$f_n^{i} = \sin \left(2^n z_t^{i} 2\pi\right)$$ 106 | 107 | $$g_n^{i} = \cos \left(2^n z_t^{i} 2\pi\right)$$ 108 | 109 | with $n$ taking a set of integer values. 110 | 111 | Assume that each scalar variable takes values: 112 | 113 | $$\frac{2k + 1}{2^{m+1}} \ \text{ with }\ k = 0, ..., 2^m - 1 \ \text{ and }\ m \in \mathbb{N}.$$ 114 | 115 | E.g., in our case the $2^m = 256$ pixel values are $\left\\{\frac{1}{512}, \frac{3}{512}, ..., \frac{511}{512} \right\\}$. 116 | The argument of $\sin$ and $\cos$ is then 117 | 118 | $$\frac{2k + 1}{2^m} 2^n \pi = 2^{n-m} \pi + 2\pi 2^{n-m}k$$ 119 | 120 | which means the features have period $2^{m-n}$ in $k$. 121 | Therefore, at very high SNR (i.e., almost discrete values with negligible noise), where 122 | Fourier features are expected to be most useful to deal with fine details, we should choose 123 | $n < m$, such that the period is greater than 1. 124 | For the cosine, the condition is even stricter, because if $n = m-1$ then 125 | $g_n^i = \cos\left(\frac{\pi}{2} + k\pi\right) = 0$. 126 | Since in our case $m=8$, we take $n \leq 7$. 127 | In the code we use $n \leq 6$ because images have twice the range 128 | (between $\pm \frac{255}{256}$). 129 | 130 | Below we visualize the feature values for pixel values 0 to 25, varying the 131 | frequency $2^n$ with $n$ from 0 to 7. At $n=m-1=7$, the cosine features are constant, 132 | and the sine features measure the least significant bit of the pixel value. 133 | On clean data, any frequency $2^n$ with $n$ integer and $n > 7$ would 134 | be useless (1 would be a multiple of the period). 135 | 136 |

137 | ![Animation showing Fourier features at different frequencies on discrete pixel values](assets/fourier_demo_1d.gif) 138 |
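The periodicity argument above can also be checked numerically. Here is a minimal, self-contained sketch using the $\sin\left(2^n z\, 2\pi\right)$ and $\cos\left(2^n z\, 2\pi\right)$ features on the single-range grid $\frac{2k+1}{2^{m+1}}$ with $m=8$ (an illustration only, not the repo's `FourierFeatures` module, which operates on the doubled $[-1, 1]$ range with $n \leq 6$):

```python
import torch

# Discrete pixel values on the grid from the text: (2k + 1) / 2^(m+1) with m = 8.
m = 8
k = torch.arange(2 ** m)             # k = 0, ..., 255
z = (2 * k + 1) / 2 ** (m + 1)       # 1/512, 3/512, ..., 511/512

for n in (7, 8):
    arg = (2 ** n) * z * 2 * torch.pi
    sin_feat, cos_feat = torch.sin(arg), torch.cos(arg)
    sin4 = [round(v, 3) for v in sin_feat[:4].tolist()]
    cos4 = [round(v, 3) for v in cos_feat[:4].tolist()]
    print(f"n={n}  sin[:4]={sin4}  cos[:4]={cos4}")

# Expected: at n = 7 the sine alternates +1/-1 with k (it reads off the least
# significant bit) while the cosine is ~0 everywhere; at n = 8 both are constant,
# i.e. useless, matching the n <= m - 1 condition above.
```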

139 | 140 | Below are the sine features on the Mandrill image (and detail on the right) with smoothly increasing frequency 141 | from $2^0$ to $2^{4.5}$. 142 | 143 |

144 | ![Animation showing Fourier features at different frequencies on the Mandrill test image](assets/mandrill_fourier.gif) 145 |
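Finally, to tie the overview above together, here is a minimal sketch of a Monte Carlo estimate of the continuous-time diffusion loss $\mathcal{L}_\infty$ with the fixed linear schedule. The `eps_model` below is a dummy stand-in for the UNet, and the batch is random data rescaled to $[-1, 1]$; see `VDM.forward` in `vdm.py` for the actual implementation, which also adds the latent and reconstruction terms and converts everything to bpd:

```python
import torch

gamma_min, gamma_max = -13.3, 5.0

def gamma(t):                                  # fixed linear schedule, gamma = -log SNR
    return gamma_min + (gamma_max - gamma_min) * t

def eps_model(z_t, gamma_t):                   # dummy noise predictor standing in for UNetVDM
    return torch.zeros_like(z_t)

x = torch.rand(8, 3, 32, 32) * 2 - 1           # fake batch rescaled to [-1, 1]
t = torch.rand(x.shape[0])                     # t ~ U(0, 1)
g_t = gamma(t).view(-1, 1, 1, 1)
alpha_t = torch.sigmoid(-g_t).sqrt()
sigma_t = torch.sigmoid(g_t).sqrt()
eps = torch.randn_like(x)
z_t = alpha_t * x + sigma_t * eps              # sample from q(z_t | x)

dgamma_dt = gamma_max - gamma_min              # derivative of the linear schedule is constant
diffusion_loss = 0.5 * dgamma_dt * ((eps - eps_model(z_t, g_t)) ** 2).sum(dim=(1, 2, 3))
print(diffusion_loss.mean())                   # per-image loss, before the bpd conversion in vdm.py
```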

146 | 147 | 148 | 149 | 150 | ## Setup 151 | 152 | The environment can be set up with `requirements.txt`. For example with conda: 153 | 154 | ``` 155 | conda create --name vdm python=3.9 156 | conda activate vdm 157 | pip install -r requirements.txt 158 | ``` 159 | 160 | 161 | ## Training with 🤗 Accelerate 162 | 163 | To train with default parameters and options: 164 | 165 | ```bash 166 | accelerate launch --config_file accelerate_config.yaml train.py --results-path results/my_experiment/ 167 | ``` 168 | 169 | Append `--resume` to the command above to resume training from the latest checkpoint. 170 | See [`train.py`](train.py) for more training options. 171 | 172 | Here we provide a sensible configuration for training on 2 GPUs in the file 173 | [`accelerate_config.yaml`](accelerate_config.yaml). This can be modified directly, or overridden 174 | on the command line by adding flags before "`train.py`" (e.g., `--num_processes N` 175 | to train on N GPUs). 176 | See the [Accelerate docs](https://huggingface.co/docs/accelerate/index) for more configuration options. 177 | After initialization, we print an estimate of the required GPU memory for the given 178 | batch size, so that the number of GPUs can be adjusted accordingly. 179 | The training loop periodically logs train and validation metrics to a JSONL file, 180 | and generates samples. 181 | 182 | 183 | ## Evaluating from checkpoint 184 | 185 | ```bash 186 | python eval.py --results-path results/my_experiment/ --n-sample-steps 1000 187 | ``` 188 | 189 | 190 | ## Credits 191 | 192 | This implementation is based on the VDM [paper](https://arxiv.org/abs/2107.00630) and [official code](https://github.com/google-research/vdm). The code structure for training diffusion models with Accelerate is inspired by [this repo](https://github.com/lucidrains/denoising-diffusion-pytorch). 
193 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | fsdp_config: {} 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_process_ip: null 9 | main_process_port: null 10 | main_training_function: main 11 | mixed_precision: 'no' 12 | num_machines: 1 13 | num_processes: 2 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /assets/fourier_demo_1d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/fourier_demo_1d.gif -------------------------------------------------------------------------------- /assets/mandrill_fourier.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/mandrill_fourier.gif -------------------------------------------------------------------------------- /assets/sample-ema-2M-1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/sample-ema-2M-1000.png -------------------------------------------------------------------------------- /assets/training-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/training-curve.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from pathlib import Path 4 | 5 | import torch 6 | import yaml 7 | from accelerate.utils import set_seed 8 | from ema_pytorch import EMA 9 | from torch.utils.data import Subset 10 | from torchvision.utils import save_image 11 | 12 | from utils import ( 13 | DeviceAwareDataLoader, 14 | TrainConfig, 15 | evaluate_model_and_log, 16 | get_date_str, 17 | has_int_squareroot, 18 | log, 19 | make_cifar, 20 | print_model_summary, 21 | sample_batched, 22 | ) 23 | from vdm import VDM 24 | from vdm_unet import UNetVDM 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--batch-size", type=int, default=128) 30 | parser.add_argument("--seed", type=int, default=12345) 31 | parser.add_argument("--results-path", type=str, required=True) 32 | parser.add_argument("--num-workers", type=int, default=1) 33 | parser.add_argument("--device", type=str, default="cuda") 34 | parser.add_argument("--n-sample-steps", type=int, default=250) 35 | parser.add_argument("--clip-samples", type=bool, default=True) 36 | parser.add_argument("--n-samples-for-eval", type=int, default=1) 37 | args = parser.parse_args() 38 | set_seed(args.seed) 39 | 40 | # Load config from YAML. 
41 | with open(Path(args.results_path) / "config.yaml", "r") as f: 42 | cfg = TrainConfig(**yaml.safe_load(f)) 43 | 44 | model = UNetVDM(cfg) 45 | print_model_summary(model, batch_size=None, shape=(3, 32, 32)) 46 | train_set = make_cifar(train=True, download=True) 47 | validation_set = make_cifar(train=False, download=False) 48 | diffusion = VDM(model, cfg, image_shape=train_set[0][0].shape) 49 | Evaluator( 50 | diffusion, 51 | train_set, 52 | validation_set, 53 | config=cfg, 54 | eval_batch_size=args.batch_size, 55 | results_path=Path(args.results_path), 56 | num_dataloader_workers=args.num_workers, 57 | device=args.device, 58 | n_sample_steps=args.n_sample_steps, 59 | clip_samples=args.clip_samples, 60 | n_samples_for_eval=args.n_samples_for_eval, 61 | ).eval() 62 | 63 | 64 | class Evaluator: 65 | def __init__( 66 | self, 67 | diffusion_model, 68 | train_set, 69 | validation_set, 70 | config, 71 | *, 72 | eval_batch_size, 73 | device, 74 | results_path, 75 | num_samples=64, 76 | num_dataloader_workers=1, 77 | n_sample_steps=250, 78 | clip_samples=True, 79 | n_samples_for_eval=4, 80 | ): 81 | assert has_int_squareroot(num_samples), "num_samples must have an integer sqrt" 82 | self.num_samples = num_samples 83 | self.cfg = config 84 | self.n_sample_steps = n_sample_steps 85 | self.clip_samples = clip_samples 86 | self.device = device 87 | self.eval_batch_size = eval_batch_size 88 | self.n_samples_for_eval = n_samples_for_eval 89 | 90 | def make_dataloader(dataset, limit_size=None): 91 | # If limit_size is not None, only use a subset of the dataset 92 | if limit_size is not None: 93 | dataset = Subset(dataset, range(limit_size)) 94 | return DeviceAwareDataLoader( 95 | dataset, 96 | eval_batch_size, 97 | device=device, 98 | shuffle=False, 99 | pin_memory=True, 100 | num_workers=num_dataloader_workers, 101 | drop_last=True, 102 | ) 103 | 104 | self.validation_dataloader = make_dataloader(validation_set) 105 | self.train_eval_dataloader = make_dataloader(train_set, len(validation_set)) 106 | self.diffusion_model = diffusion_model.eval().to(self.device) 107 | # No need to set EMA parameters since we only use it for eval from checkpoint. 
108 | self.ema = EMA(self.diffusion_model).to(self.device) 109 | self.ema.ema_model.eval() 110 | self.path = results_path 111 | self.eval_path = self.path / f"eval_{get_date_str()}" 112 | self.eval_path.mkdir() 113 | self.checkpoint_file = self.path / f"model.pt" 114 | with open(self.eval_path / "eval_config.yaml", "w") as f: 115 | eval_conf = { 116 | "n_sample_steps": n_sample_steps, 117 | "clip_samples": clip_samples, 118 | "n_samples_for_eval": n_samples_for_eval, 119 | } 120 | yaml.dump(eval_conf, f) 121 | self.load_checkpoint() 122 | 123 | def load_checkpoint(self): 124 | data = torch.load(self.checkpoint_file, map_location=self.device) 125 | log(f"Loading checkpoint '{self.checkpoint_file}'") 126 | self.diffusion_model.load_state_dict(data["model"]) 127 | self.ema.load_state_dict(data["ema"]) 128 | 129 | @torch.no_grad() 130 | def eval(self): 131 | self.eval_model(self.diffusion_model, is_ema=False) 132 | self.eval_model(self.ema.ema_model, is_ema=True) 133 | 134 | def eval_model(self, model, *, is_ema): 135 | log(f"\n *** Evaluating {'EMA' if is_ema else 'online'} model\n") 136 | self.sample_images(model, is_ema=is_ema) 137 | for validation in [True, False]: 138 | evaluate_model_and_log( 139 | model, 140 | self.validation_dataloader 141 | if validation 142 | else self.train_eval_dataloader, 143 | self.eval_path / ("ema-metrics.jsonl" if is_ema else "metrics.jsonl"), 144 | "validation" if validation else "train", 145 | n=self.n_samples_for_eval, 146 | ) 147 | 148 | def sample_images(self, model, *, is_ema): 149 | samples = sample_batched( 150 | model, 151 | self.num_samples, 152 | self.eval_batch_size, 153 | self.n_sample_steps, 154 | self.clip_samples, 155 | ) 156 | path = self.eval_path / f"sample{'-ema' if is_ema else ''}.png" 157 | save_image(samples, str(path), nrow=int(math.sqrt(self.num_samples))) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.16.0 2 | black==23.1.0 3 | ema-pytorch==0.1.4 4 | isort==5.12.0 5 | numpy==1.24.2 6 | Pillow==9.4.0 7 | PyYAML==6.0 8 | torch==1.13.1 9 | torchinfo==1.7.2 10 | torchvision==0.14.1 11 | tqdm==4.64.1 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import math 4 | from argparse import BooleanOptionalAction 5 | 6 | import torch 7 | import yaml 8 | from accelerate import Accelerator 9 | from accelerate.utils import set_seed 10 | from ema_pytorch import EMA 11 | from torch.utils.data import Subset 12 | from torchvision.utils import save_image 13 | from tqdm.auto import tqdm 14 | 15 | from utils import ( 16 | DeviceAwareDataLoader, 17 | TrainConfig, 18 | check_config_matches_checkpoint, 19 | cycle, 20 | evaluate_model_and_log, 21 | get_date_str, 22 | handle_results_path, 23 | has_int_squareroot, 24 | init_config_from_args, 25 | init_logger, 26 | log, 27 | make_cifar, 28 | print_model_summary, 29 | sample_batched, 30 | ) 31 | from vdm import VDM 32 | from vdm_unet import UNetVDM 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | 38 | # Architecture 39 | parser.add_argument("--embedding-dim", type=int, default=128) 40 | parser.add_argument("--n-blocks", type=int, default=32) 41 | parser.add_argument("--n-attention-heads", type=int, 
default=1) 42 | parser.add_argument("--dropout-prob", type=float, default=0.1) 43 | parser.add_argument("--norm-groups", type=int, default=32) 44 | parser.add_argument("--input-channels", type=int, default=3) 45 | parser.add_argument("--use-fourier-features", action=BooleanOptionalAction, default=True) 46 | parser.add_argument("--attention-everywhere", action=BooleanOptionalAction, default=False) 47 | 48 | # Training 49 | parser.add_argument("--batch-size", type=int, default=128) 50 | parser.add_argument("--noise-schedule", type=str, default="fixed_linear") 51 | parser.add_argument("--gamma-min", type=float, default=-13.3) 52 | parser.add_argument("--gamma-max", type=float, default=5.0) 53 | parser.add_argument("--antithetic-time-sampling", action=BooleanOptionalAction, default=True) 54 | parser.add_argument("--lr", type=float, default=2e-4) 55 | parser.add_argument("--weight-decay", type=float, default=0.01) 56 | parser.add_argument("--clip-grad-norm", action=BooleanOptionalAction, default=True) 57 | 58 | parser.add_argument("--eval-every", type=int, default=10_000) 59 | parser.add_argument("--seed", type=int, default=12345) 60 | parser.add_argument("--results-path", type=str, default=None) 61 | parser.add_argument("--resume", action="store_true") 62 | parser.add_argument("--num-workers", type=int, default=2) 63 | args = parser.parse_args() 64 | 65 | set_seed(args.seed) 66 | accelerator = Accelerator(split_batches=True) 67 | init_logger(accelerator) 68 | cfg = init_config_from_args(TrainConfig, args) 69 | 70 | model = UNetVDM(cfg) 71 | print_model_summary(model, batch_size=cfg.batch_size, shape=(3, 32, 32)) 72 | with accelerator.local_main_process_first(): 73 | train_set = make_cifar(train=True, download=accelerator.is_local_main_process) 74 | validation_set = make_cifar(train=False, download=False) 75 | diffusion = VDM(model, cfg, image_shape=train_set[0][0].shape) 76 | Trainer( 77 | diffusion, 78 | train_set, 79 | validation_set, 80 | accelerator, 81 | make_opt=lambda params: torch.optim.AdamW( 82 | params, cfg.lr, betas=(0.9, 0.99), weight_decay=cfg.weight_decay, eps=1e-8 83 | ), 84 | config=cfg, 85 | save_and_eval_every=args.eval_every, 86 | results_path=handle_results_path(args.results_path), 87 | resume=args.resume, 88 | num_dataloader_workers=args.num_workers, 89 | ).train() 90 | 91 | 92 | class Trainer: 93 | def __init__( 94 | self, 95 | diffusion_model, 96 | train_set, 97 | validation_set, 98 | accelerator, 99 | make_opt, 100 | config, 101 | *, 102 | train_num_steps=10_000_000, 103 | ema_decay=0.9999, 104 | ema_update_every=1, 105 | ema_power=3 / 4, # 0.999 at 10k, 0.9997 at 50k, 0.9999 at 200k 106 | save_and_eval_every=1000, 107 | num_samples=64, 108 | results_path=None, 109 | resume=False, 110 | num_dataloader_workers=1, 111 | n_sample_steps=250, 112 | clip_samples=True, 113 | ): 114 | super().__init__() 115 | assert has_int_squareroot(num_samples), "num_samples must have an integer sqrt" 116 | self.num_samples = num_samples 117 | self.save_and_eval_every = save_and_eval_every 118 | self.cfg = config 119 | self.train_num_steps = train_num_steps 120 | self.n_sample_steps = n_sample_steps 121 | self.clip_samples = clip_samples 122 | self.accelerator = accelerator 123 | self.step = 0 124 | 125 | def make_dataloader(dataset, limit_size=None, *, train=False): 126 | if limit_size is not None: 127 | dataset = Subset(dataset, range(limit_size)) 128 | dataloader = DeviceAwareDataLoader( 129 | dataset, 130 | config.batch_size, 131 | shuffle=train, 132 | pin_memory=True, 133 | 
num_workers=num_dataloader_workers, 134 | drop_last=True, 135 | device=accelerator.device if not train else None, # None -> standard DL 136 | ) 137 | if train: 138 | dataloader = accelerator.prepare(dataloader) 139 | return dataloader 140 | 141 | self.train_dataloader = cycle(make_dataloader(train_set, train=True)) 142 | self.validation_dataloader = make_dataloader(validation_set) 143 | self.train_eval_dataloader = make_dataloader(train_set, len(validation_set)) 144 | 145 | self.path = results_path 146 | self.checkpoint_file = self.path / f"model.pt" 147 | if accelerator.is_main_process: 148 | self.ema = EMA( 149 | diffusion_model.to(accelerator.device), 150 | beta=ema_decay, 151 | update_every=ema_update_every, 152 | power=ema_power, 153 | ) 154 | self.ema.ema_model.eval() 155 | self.path.mkdir(exist_ok=True, parents=True) 156 | self.diffusion_model = accelerator.prepare(diffusion_model) 157 | self.opt = accelerator.prepare(make_opt(self.diffusion_model.parameters())) 158 | if resume: 159 | self.load_checkpoint() 160 | else: 161 | if len(list(self.path.glob("*.pt"))) > 0: 162 | raise ValueError(f"'{self.path}' contains checkpoints but resume=False") 163 | if accelerator.is_main_process: 164 | with open(self.path / "config.yaml", "w") as f: 165 | yaml.dump(dataclasses.asdict(config), f) 166 | 167 | def save_checkpoint(self): 168 | tmp_file = self.checkpoint_file.with_suffix(f".tmp.{get_date_str()}.pt") 169 | if self.checkpoint_file.exists(): 170 | self.checkpoint_file.rename(tmp_file) # Rename old checkpoint to temp file 171 | checkpoint = { 172 | "step": self.step, 173 | "model": self.accelerator.get_state_dict(self.diffusion_model), 174 | "opt": self.opt.state_dict(), 175 | "ema": self.ema.state_dict(), 176 | } 177 | torch.save(checkpoint, self.checkpoint_file) 178 | tmp_file.unlink(missing_ok=True) # Delete temp file 179 | 180 | def load_checkpoint(self): 181 | check_config_matches_checkpoint(self.cfg, self.path) 182 | data = torch.load(self.checkpoint_file, map_location=self.accelerator.device) 183 | self.step = data["step"] 184 | log(f"Resuming from checkpoint '{self.checkpoint_file}' (step {self.step})") 185 | model = self.accelerator.unwrap_model(self.diffusion_model) 186 | model.load_state_dict(data["model"]) 187 | self.opt.load_state_dict(data["opt"]) 188 | if self.accelerator.is_main_process: 189 | self.ema.load_state_dict(data["ema"]) 190 | 191 | def train(self): 192 | with tqdm( 193 | initial=self.step, 194 | total=self.train_num_steps, 195 | disable=not self.accelerator.is_main_process, 196 | ) as pbar: 197 | while self.step < self.train_num_steps: 198 | data = next(self.train_dataloader) 199 | self.opt.zero_grad() 200 | loss, _ = self.diffusion_model(data) 201 | self.accelerator.backward(loss) 202 | if self.cfg.clip_grad_norm: 203 | self.accelerator.clip_grad_norm_( 204 | self.diffusion_model.parameters(), 1.0 205 | ) 206 | self.opt.step() 207 | pbar.set_description(f"loss: {loss.item():.4f}") 208 | self.step += 1 209 | self.accelerator.wait_for_everyone() 210 | if self.accelerator.is_main_process: 211 | self.ema.update() 212 | if self.step % self.save_and_eval_every == 0: 213 | self.eval() 214 | pbar.update() 215 | 216 | @torch.no_grad() 217 | def eval(self): 218 | self.save_checkpoint() 219 | self.sample_images(self.ema.ema_model, is_ema=True) 220 | self.sample_images(self.diffusion_model, is_ema=False) 221 | self.evaluate_ema_model_and_log(validation=True) 222 | self.evaluate_ema_model_and_log(validation=False) 223 | 224 | def evaluate_ema_model_and_log(self, *, 
validation): 225 | evaluate_model_and_log( 226 | self.ema.ema_model, 227 | self.validation_dataloader if validation else self.train_eval_dataloader, 228 | self.path / "metrics_log.jsonl", 229 | "validation" if validation else "train", 230 | self.step, 231 | ) 232 | 233 | def sample_images(self, model, *, is_ema): 234 | train_state = model.training 235 | model.eval() 236 | samples = sample_batched( 237 | self.accelerator.unwrap_model(model), 238 | self.num_samples, 239 | self.cfg.batch_size, 240 | self.n_sample_steps, 241 | self.clip_samples, 242 | ) 243 | path = self.path / f"sample-{'ema-' if is_ema else ''}{self.step}.png" 244 | save_image(samples, str(path), nrow=int(math.sqrt(self.num_samples))) 245 | model.train(train_state) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import math 4 | import warnings 5 | from collections import defaultdict 6 | from dataclasses import dataclass 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torchinfo 14 | import yaml 15 | from accelerate import Accelerator 16 | from torch import nn 17 | from torch.utils.data import DataLoader 18 | from torchvision import transforms 19 | from torchvision.datasets import CIFAR10 20 | from tqdm.auto import tqdm 21 | 22 | 23 | @dataclass 24 | class TrainConfig: 25 | embedding_dim: float 26 | n_blocks: int 27 | n_attention_heads: int 28 | dropout_prob: float 29 | norm_groups: int 30 | input_channels: int 31 | use_fourier_features: bool 32 | attention_everywhere: bool 33 | batch_size: int 34 | noise_schedule: str 35 | gamma_min: float 36 | gamma_max: float 37 | antithetic_time_sampling: bool 38 | lr: float 39 | weight_decay: float 40 | clip_grad_norm: bool 41 | 42 | 43 | def print_model_summary(model, *, batch_size, shape, depth=4, batch_size_torchinfo=1): 44 | summary = torchinfo.summary( 45 | model, 46 | [(batch_size_torchinfo, *shape), (batch_size_torchinfo,)], 47 | depth=depth, 48 | col_names=["input_size", "output_size", "num_params"], 49 | verbose=0, # quiet 50 | ) 51 | log(summary) 52 | if batch_size is None or batch_size == batch_size_torchinfo: 53 | return 54 | output_bytes_large = summary.total_output_bytes / batch_size_torchinfo * batch_size 55 | total_bytes = summary.total_input + output_bytes_large + summary.total_param_bytes 56 | log( 57 | f"\n--- With batch size {batch_size} ---\n" 58 | f"Forward/backward pass size: {output_bytes_large / 1e9:0.2f} GB\n" 59 | f"Estimated Total Size: {total_bytes / 1e9:0.2f} GB\n" 60 | + "=" * len(str(summary).splitlines()[-1]) 61 | + "\n" 62 | ) 63 | 64 | 65 | def cycle(dl): 66 | # We don't use itertools.cycle because it caches the entire iterator. 
67 | while True: 68 | for data in dl: 69 | yield data 70 | 71 | 72 | def has_int_squareroot(num): 73 | return (math.sqrt(num) ** 2) == num 74 | 75 | 76 | def sample_batched(model, num_samples, batch_size, n_sample_steps, clip_samples): 77 | samples = [] 78 | for i in range(0, num_samples, batch_size): 79 | corrected_batch_size = min(batch_size, num_samples - i) 80 | samples.append(model.sample(corrected_batch_size, n_sample_steps, clip_samples)) 81 | return torch.cat(samples, dim=0) 82 | 83 | 84 | def get_date_str(): 85 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 86 | 87 | 88 | class DeviceAwareDataLoader(DataLoader): 89 | """A DataLoader that moves batches to a device. If device is None, it is equivalent to a standard DataLoader.""" 90 | 91 | def __init__(self, *args, device=None, **kwargs): 92 | super().__init__(*args, **kwargs) 93 | self.device = device 94 | 95 | def __iter__(self): 96 | for batch in super().__iter__(): 97 | yield self.move_to_device(batch) 98 | 99 | def move_to_device(self, batch): 100 | if self.device is None: 101 | return batch 102 | if isinstance(batch, (tuple, list)): 103 | return [self.move_to_device(x) for x in batch] 104 | elif isinstance(batch, dict): 105 | return {k: self.move_to_device(v) for k, v in batch.items()} 106 | elif isinstance(batch, torch.Tensor): 107 | return batch.to(self.device) 108 | else: 109 | return batch 110 | 111 | 112 | def evaluate_model(model, dataloader): 113 | all_metrics = defaultdict(list) 114 | for batch in tqdm(dataloader, desc="evaluation"): 115 | loss, metrics = model(batch) 116 | for k, v in metrics.items(): 117 | try: 118 | v = v.item() 119 | except AttributeError: 120 | pass 121 | all_metrics[k].append(v) 122 | return {k: sum(v) / len(v) for k, v in all_metrics.items()} # average over dataset 123 | 124 | 125 | def log_and_save_metrics(avg_metrics, dataset_split, step, filename): 126 | log(f"\n{dataset_split} metrics:") 127 | for k, v in avg_metrics.items(): 128 | log(f" {k}: {v}") 129 | 130 | avg_metrics = {"step": step, "set": dataset_split, **avg_metrics} 131 | with open(filename, "a") as f: 132 | json.dump(avg_metrics, f) 133 | f.write("\n") 134 | 135 | 136 | def dict_stats(dictionaries: list[dict]) -> dict: 137 | """Computes the average and standard deviation of metrics in a list of dictionaries. 138 | 139 | Args: 140 | dictionaries: A list of dictionaries, where each dictionary contains the same keys, 141 | and the values are numbers. 142 | 143 | Returns: 144 | A dictionary of the same keys as the input dictionaries, with the average and 145 | standard deviation of the values. If the list has length 1, the original dictionary 146 | is returned instead. 147 | """ 148 | if len(dictionaries) == 1: 149 | return dictionaries[0] 150 | 151 | # Convert the list of dictionaries to a dictionary of lists. 152 | lists = defaultdict(list) 153 | for d in dictionaries: 154 | for k, v in d.items(): 155 | lists[k].append(v) 156 | 157 | # Compute the average and standard deviation of each list. 158 | stats = {} 159 | for k, v in lists.items(): 160 | stats[f"{k}_avg"] = np.mean(v) 161 | stats[f"{k}_std"] = np.std(v) 162 | return stats 163 | 164 | 165 | def evaluate_model_and_log(model, dataloader, filename, split, step=None, n=1): 166 | # Call evaluate_model multiple times. Each call returns a dictionary of metrics, and 167 | # we then compute their average and standard deviation. 
168 | if n > 1: 169 | log(f"\nRunning {n} evaluations to compute average metrics") 170 | metrics = dict_stats([evaluate_model(model, dataloader) for _ in range(n)]) 171 | log_and_save_metrics(metrics, split, step, filename) 172 | 173 | 174 | @torch.no_grad() 175 | def zero_init(module: nn.Module) -> nn.Module: 176 | """Sets to zero all the parameters of a module, and returns the module.""" 177 | for p in module.parameters(): 178 | nn.init.zeros_(p.data) 179 | return module 180 | 181 | 182 | def maybe_unpack_batch(batch): 183 | if isinstance(batch, (tuple, list)) and len(batch) == 2: 184 | return batch 185 | else: 186 | return batch, None 187 | 188 | 189 | def make_cifar(*, train, download): 190 | return CIFAR10( 191 | root="data", 192 | download=download, 193 | train=train, 194 | transform=transforms.Compose([transforms.ToTensor()]), 195 | ) 196 | 197 | 198 | def handle_results_path(res_path: str, default_root: str = "./results") -> Path: 199 | """Sets results path if it doesn't exist yet.""" 200 | if res_path is None: 201 | results_path = Path(default_root) / get_date_str() 202 | else: 203 | results_path = Path(res_path) 204 | log(f"Results will be saved to '{results_path}'") 205 | return results_path 206 | 207 | 208 | def unsqueeze_right(x, num_dims=1): 209 | """Unsqueezes the last `num_dims` dimensions of `x`.""" 210 | return x.view(x.shape + (1,) * num_dims) 211 | 212 | 213 | def init_config_from_args(cls, args): 214 | """Initializes a dataclass from a Namespace, ignoring unknown fields.""" 215 | return cls(**{f.name: getattr(args, f.name) for f in dataclasses.fields(cls)}) 216 | 217 | 218 | def check_config_matches_checkpoint(config, checkpoint_path): 219 | with open(checkpoint_path / "config.yaml", "r") as f: 220 | ckpt_config = yaml.safe_load(f) 221 | config = dataclasses.asdict(config) 222 | if config != ckpt_config: 223 | config_str = "\n ".join(f"{k}: {config[k]}" for k in sorted(config)) 224 | ckpt_str = "\n ".join(f"{k}: {ckpt_config[k]}" for k in sorted(ckpt_config)) 225 | raise ValueError( 226 | f"Config mismatch:\n\n" 227 | f"> Config:\n {config_str}\n\n" 228 | f"> Checkpoint:\n {ckpt_str}\n\n" 229 | ) 230 | 231 | 232 | _accelerator: Optional[Accelerator] = None 233 | 234 | 235 | def init_logger(accelerator: Accelerator): 236 | global _accelerator 237 | if _accelerator is not None: 238 | raise ValueError("Accelerator already set") 239 | _accelerator = accelerator 240 | 241 | 242 | def log(message): 243 | global _accelerator 244 | if _accelerator is None: 245 | warnings.warn("Accelerator not set, using print instead.") 246 | print_fn = print 247 | else: 248 | print_fn = _accelerator.print 249 | print_fn(message) 250 | -------------------------------------------------------------------------------- /vdm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import allclose, argmax, autograd, exp, linspace, nn, sigmoid, sqrt 4 | from torch.special import expm1 5 | from tqdm import trange 6 | 7 | from utils import maybe_unpack_batch, unsqueeze_right 8 | 9 | 10 | class VDM(nn.Module): 11 | def __init__(self, model, cfg, image_shape): 12 | super().__init__() 13 | self.model = model 14 | self.cfg = cfg 15 | self.image_shape = image_shape 16 | self.vocab_size = 256 17 | if cfg.noise_schedule == "fixed_linear": 18 | self.gamma = FixedLinearSchedule(cfg.gamma_min, cfg.gamma_max) 19 | elif cfg.noise_schedule == "learned_linear": 20 | self.gamma = LearnedLinearSchedule(cfg.gamma_min, cfg.gamma_max) 21 | 
else: 22 | raise ValueError(f"Unknown noise schedule {cfg.noise_schedule}") 23 | 24 | @property 25 | def device(self): 26 | return next(self.model.parameters()).device 27 | 28 | @torch.no_grad() 29 | def sample_p_s_t(self, z, t, s, clip_samples): 30 | """Samples from p(z_s | z_t, x). Used for standard ancestral sampling.""" 31 | gamma_t = self.gamma(t) 32 | gamma_s = self.gamma(s) 33 | c = -expm1(gamma_s - gamma_t) 34 | alpha_t = sqrt(sigmoid(-gamma_t)) 35 | alpha_s = sqrt(sigmoid(-gamma_s)) 36 | sigma_t = sqrt(sigmoid(gamma_t)) 37 | sigma_s = sqrt(sigmoid(gamma_s)) 38 | 39 | pred_noise = self.model(z, gamma_t) 40 | if clip_samples: 41 | x_start = (z - sigma_t * pred_noise) / alpha_t 42 | x_start.clamp_(-1.0, 1.0) 43 | mean = alpha_s * (z * (1 - c) / alpha_t + c * x_start) 44 | else: 45 | mean = alpha_s / alpha_t * (z - c * sigma_t * pred_noise) 46 | scale = sigma_s * sqrt(c) 47 | return mean + scale * torch.randn_like(z) 48 | 49 | @torch.no_grad() 50 | def sample(self, batch_size, n_sample_steps, clip_samples): 51 | z = torch.randn((batch_size, *self.image_shape), device=self.device) 52 | steps = linspace(1.0, 0.0, n_sample_steps + 1, device=self.device) 53 | for i in trange(n_sample_steps, desc="sampling"): 54 | z = self.sample_p_s_t(z, steps[i], steps[i + 1], clip_samples) 55 | logprobs = self.log_probs_x_z0(z_0=z) # (B, C, H, W, vocab_size) 56 | x = argmax(logprobs, dim=-1) # (B, C, H, W) 57 | return x.float() / (self.vocab_size - 1) # normalize to [0, 1] 58 | 59 | def sample_q_t_0(self, x, times, noise=None): 60 | """Samples from the distributions q(x_t | x_0) at the given time steps.""" 61 | with torch.enable_grad(): # Need gradient to compute loss even when evaluating 62 | gamma_t = self.gamma(times) 63 | gamma_t_padded = unsqueeze_right(gamma_t, x.ndim - gamma_t.ndim) 64 | mean = x * sqrt(sigmoid(-gamma_t_padded)) # x * alpha 65 | scale = sqrt(sigmoid(gamma_t_padded)) 66 | if noise is None: 67 | noise = torch.randn_like(x) 68 | return mean + noise * scale, gamma_t 69 | 70 | def sample_times(self, batch_size): 71 | if self.cfg.antithetic_time_sampling: 72 | t0 = np.random.uniform(0, 1 / batch_size) 73 | times = torch.arange(t0, 1.0, 1.0 / batch_size, device=self.device) 74 | else: 75 | times = torch.rand(batch_size, device=self.device) 76 | return times 77 | 78 | def forward(self, batch, *, noise=None): 79 | x, label = maybe_unpack_batch(batch) 80 | assert x.shape[1:] == self.image_shape 81 | assert 0.0 <= x.min() and x.max() <= 1.0 82 | bpd_factor = 1 / (np.prod(x.shape[1:]) * np.log(2)) 83 | 84 | # Convert image to integers in range [0, vocab_size - 1]. 85 | img_int = torch.round(x * (self.vocab_size - 1)).long() 86 | assert (img_int >= 0).all() and (img_int <= self.vocab_size - 1).all() 87 | # Check that the image was discrete with vocab_size values. 88 | assert allclose(img_int / (self.vocab_size - 1), x) 89 | 90 | # Rescale integer image to [-1 + 1/vocab_size, 1 - 1/vocab_size] 91 | x = 2 * ((img_int + 0.5) / self.vocab_size) - 1 92 | 93 | # Sample from q(x_t | x_0) with random t. 
94 | times = self.sample_times(x.shape[0]).requires_grad_(True) 95 | if noise is None: 96 | noise = torch.randn_like(x) 97 | x_t, gamma_t = self.sample_q_t_0(x=x, times=times, noise=noise) 98 | 99 | # Forward through model 100 | model_out = self.model(x_t, gamma_t) 101 | 102 | # *** Diffusion loss (bpd) 103 | gamma_grad = autograd.grad( # gamma_grad shape: (B, ) 104 | gamma_t, # (B, ) 105 | times, # (B, ) 106 | grad_outputs=torch.ones_like(gamma_t), 107 | create_graph=True, 108 | retain_graph=True, 109 | )[0] 110 | pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3)) # (B, ) 111 | diffusion_loss = 0.5 * pred_loss * gamma_grad * bpd_factor 112 | 113 | # *** Latent loss (bpd): KL divergence from N(0, 1) to q(z_1 | x) 114 | gamma_1 = self.gamma(torch.tensor([1.0], device=self.device)) 115 | sigma_1_sq = sigmoid(gamma_1) 116 | mean_sq = (1 - sigma_1_sq) * x**2 # (alpha_1 * x)**2 117 | latent_loss = kl_std_normal(mean_sq, sigma_1_sq).sum((1, 2, 3)) * bpd_factor 118 | 119 | # *** Reconstruction loss (bpd): - E_{q(z_0 | x)} [log p(x | z_0)]. 120 | # Compute log p(x | z_0) for all possible values of each pixel in x. 121 | log_probs = self.log_probs_x_z0(x) # (B, C, H, W, vocab_size) 122 | # One-hot representation of original image. Shape: (B, C, H, W, vocab_size). 123 | x_one_hot = torch.zeros((*x.shape, self.vocab_size), device=self.device) 124 | x_one_hot.scatter_(4, img_int.unsqueeze(-1), 1) # one-hot over last dim 125 | # Select the correct log probabilities. 126 | log_probs = (x_one_hot * log_probs).sum(-1) # (B, C, H, W) 127 | # Overall logprob for each image in batch. 128 | recons_loss = -log_probs.sum((1, 2, 3)) * bpd_factor 129 | 130 | # *** Overall loss in bpd. Shape (B, ). 131 | loss = diffusion_loss + latent_loss + recons_loss 132 | 133 | with torch.no_grad(): 134 | gamma_0 = self.gamma(torch.tensor([0.0], device=self.device)) 135 | metrics = { 136 | "bpd": loss.mean(), 137 | "diff_loss": diffusion_loss.mean(), 138 | "latent_loss": latent_loss.mean(), 139 | "loss_recon": recons_loss.mean(), 140 | "gamma_0": gamma_0.item(), 141 | "gamma_1": gamma_1.item(), 142 | } 143 | return loss.mean(), metrics 144 | 145 | def log_probs_x_z0(self, x=None, z_0=None): 146 | """Computes log p(x | z_0) for all possible values of x. 147 | 148 | Compute p(x_i | z_0i), with i = pixel index, for all possible values of x_i in 149 | the vocabulary. We approximate this with q(z_0i | x_i). Unnormalized logits are: 150 | -1/2 SNR_0 (z_0 / alpha_0 - k)^2 151 | where k takes all possible x_i values. Logits are then normalized to logprobs. 152 | 153 | The method returns a tensor of shape (B, C, H, W, vocab_size) containing, for 154 | each pixel, the log probabilities for all `vocab_size` possible values of that 155 | pixel. The output sums to 1 over the last dimension. 156 | 157 | The method accepts either `x` or `z_0` as input. If `z_0` is given, it is used 158 | directly. If `x` is given, a sample z_0 is drawn from q(z_0 | x). It's more 159 | efficient to pass `x` directly, if available. 160 | 161 | Args: 162 | x: Input image, shape (B, C, H, W). 163 | z_0: z_0 to be decoded, shape (B, C, H, W). 164 | 165 | Returns: 166 | log_probs: Log probabilities of shape (B, C, H, W, vocab_size). 
167 | """ 168 | gamma_0 = self.gamma(torch.tensor([0.0], device=self.device)) 169 | if x is None and z_0 is not None: 170 | z_0_rescaled = z_0 / sqrt(sigmoid(-gamma_0)) # z_0 / alpha_0 171 | elif z_0 is None and x is not None: 172 | # Equal to z_0/alpha_0 with z_0 sampled from q(z_0 | x) 173 | z_0_rescaled = x + exp(0.5 * gamma_0) * torch.randn_like(x) # (B, C, H, W) 174 | else: 175 | raise ValueError("Must provide either x or z_0, not both.") 176 | z_0_rescaled = z_0_rescaled.unsqueeze(-1) # (B, C, H, W, 1) 177 | x_lim = 1 - 1 / self.vocab_size 178 | x_values = linspace(-x_lim, x_lim, self.vocab_size, device=self.device) 179 | logits = -0.5 * exp(-gamma_0) * (z_0_rescaled - x_values) ** 2 # broadcast x 180 | log_probs = torch.log_softmax(logits, dim=-1) # (B, C, H, W, vocab_size) 181 | return log_probs 182 | 183 | 184 | def kl_std_normal(mean_squared, var): 185 | return 0.5 * (var + mean_squared - torch.log(var.clamp(min=1e-15)) - 1.0) 186 | 187 | 188 | class FixedLinearSchedule(nn.Module): 189 | def __init__(self, gamma_min, gamma_max): 190 | super().__init__() 191 | self.gamma_min = gamma_min 192 | self.gamma_max = gamma_max 193 | 194 | def forward(self, t): 195 | return self.gamma_min + (self.gamma_max - self.gamma_min) * t 196 | 197 | 198 | class LearnedLinearSchedule(nn.Module): 199 | def __init__(self, gamma_min, gamma_max): 200 | super().__init__() 201 | self.b = nn.Parameter(torch.tensor(gamma_min)) 202 | self.w = nn.Parameter(torch.tensor(gamma_max - gamma_min)) 203 | 204 | def forward(self, t): 205 | return self.b + self.w.abs() * t 206 | -------------------------------------------------------------------------------- /vdm_unet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import einsum, nn, pi, softmax 4 | 5 | from utils import zero_init 6 | 7 | 8 | class UNetVDM(nn.Module): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | self.cfg = cfg 12 | 13 | attention_params = dict( 14 | n_heads=cfg.n_attention_heads, 15 | n_channels=cfg.embedding_dim, 16 | norm_groups=cfg.norm_groups, 17 | ) 18 | resnet_params = dict( 19 | ch_in=cfg.embedding_dim, 20 | ch_out=cfg.embedding_dim, 21 | condition_dim=4 * cfg.embedding_dim, 22 | dropout_prob=cfg.dropout_prob, 23 | norm_groups=cfg.norm_groups, 24 | ) 25 | if cfg.use_fourier_features: 26 | self.fourier_features = FourierFeatures() 27 | self.embed_conditioning = nn.Sequential( 28 | nn.Linear(cfg.embedding_dim, cfg.embedding_dim * 4), 29 | nn.SiLU(), 30 | nn.Linear(cfg.embedding_dim * 4, cfg.embedding_dim * 4), 31 | nn.SiLU(), 32 | ) 33 | total_input_ch = cfg.input_channels 34 | if cfg.use_fourier_features: 35 | total_input_ch *= 1 + self.fourier_features.num_features 36 | self.conv_in = nn.Conv2d(total_input_ch, cfg.embedding_dim, 3, padding=1) 37 | 38 | # Down path: n_blocks blocks with a resnet block and maybe attention. 39 | self.down_blocks = nn.ModuleList( 40 | UpDownBlock( 41 | resnet_block=ResnetBlock(**resnet_params), 42 | attention_block=AttentionBlock(**attention_params) 43 | if cfg.attention_everywhere 44 | else None, 45 | ) 46 | for _ in range(cfg.n_blocks) 47 | ) 48 | 49 | self.mid_resnet_block_1 = ResnetBlock(**resnet_params) 50 | self.mid_attn_block = AttentionBlock(**attention_params) 51 | self.mid_resnet_block_2 = ResnetBlock(**resnet_params) 52 | 53 | # Up path: n_blocks+1 blocks with a resnet block and maybe attention. 
54 | resnet_params["ch_in"] *= 2 # double input channels due to skip connections 55 | self.up_blocks = nn.ModuleList( 56 | UpDownBlock( 57 | resnet_block=ResnetBlock(**resnet_params), 58 | attention_block=AttentionBlock(**attention_params) 59 | if cfg.attention_everywhere 60 | else None, 61 | ) 62 | for _ in range(cfg.n_blocks + 1) 63 | ) 64 | 65 | self.conv_out = nn.Sequential( 66 | nn.GroupNorm(num_groups=cfg.norm_groups, num_channels=cfg.embedding_dim), 67 | nn.SiLU(), 68 | zero_init(nn.Conv2d(cfg.embedding_dim, cfg.input_channels, 3, padding=1)), 69 | ) 70 | 71 | def forward(self, z, g_t): 72 | # Get gamma to shape (B, ). 73 | g_t = g_t.expand(z.shape[0]) # assume shape () or (1,) or (B,) 74 | assert g_t.shape == (z.shape[0],) 75 | # Rescale to [0, 1], but only approximately since gamma0 & gamma1 are not fixed. 76 | t = (g_t - self.cfg.gamma_min) / (self.cfg.gamma_max - self.cfg.gamma_min) 77 | t_embedding = get_timestep_embedding(t, self.cfg.embedding_dim) 78 | # We will condition on time embedding. 79 | cond = self.embed_conditioning(t_embedding) 80 | 81 | h = self.maybe_concat_fourier(z) 82 | h = self.conv_in(h) # (B, embedding_dim, H, W) 83 | hs = [] 84 | for down_block in self.down_blocks: # n_blocks times 85 | hs.append(h) 86 | h = down_block(h, cond) 87 | hs.append(h) 88 | h = self.mid_resnet_block_1(h, cond) 89 | h = self.mid_attn_block(h) 90 | h = self.mid_resnet_block_2(h, cond) 91 | for up_block in self.up_blocks: # n_blocks+1 times 92 | h = torch.cat([h, hs.pop()], dim=1) 93 | h = up_block(h, cond) 94 | prediction = self.conv_out(h) 95 | assert prediction.shape == z.shape, (prediction.shape, z.shape) 96 | return prediction + z 97 | 98 | def maybe_concat_fourier(self, z): 99 | if self.cfg.use_fourier_features: 100 | return torch.cat([z, self.fourier_features(z)], dim=1) 101 | return z 102 | 103 | 104 | class ResnetBlock(nn.Module): 105 | def __init__( 106 | self, 107 | ch_in, 108 | ch_out=None, 109 | condition_dim=None, 110 | dropout_prob=0.0, 111 | norm_groups=32, 112 | ): 113 | super().__init__() 114 | ch_out = ch_in if ch_out is None else ch_out 115 | self.ch_out = ch_out 116 | self.condition_dim = condition_dim 117 | self.net1 = nn.Sequential( 118 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in), 119 | nn.SiLU(), 120 | nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), 121 | ) 122 | if condition_dim is not None: 123 | self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False)) 124 | self.net2 = nn.Sequential( 125 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out), 126 | nn.SiLU(), 127 | *([nn.Dropout(dropout_prob)] * (dropout_prob > 0.0)), 128 | zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)), 129 | ) 130 | if ch_in != ch_out: 131 | self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1) 132 | 133 | def forward(self, x, condition): 134 | h = self.net1(x) 135 | if condition is not None: 136 | assert condition.shape == (x.shape[0], self.condition_dim) 137 | condition = self.cond_proj(condition) 138 | condition = condition[:, :, None, None] 139 | h = h + condition 140 | h = self.net2(h) 141 | if x.shape[1] != self.ch_out: 142 | x = self.skip_conv(x) 143 | assert x.shape == h.shape 144 | return x + h 145 | 146 | 147 | def get_timestep_embedding( 148 | timesteps, 149 | embedding_dim: int, 150 | dtype=torch.float32, 151 | max_timescale=10_000, 152 | min_timescale=1, 153 | ): 154 | # Adapted from tensor2tensor and VDM codebase. 
155 | assert timesteps.ndim == 1 156 | assert embedding_dim % 2 == 0 157 | timesteps *= 1000.0 # In DDPM the time step is in [0, 1000], here [0, 1] 158 | num_timescales = embedding_dim // 2 159 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) 160 | -np.log10(min_timescale), 161 | -np.log10(max_timescale), 162 | num_timescales, 163 | device=timesteps.device, 164 | ) 165 | emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) 166 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) 167 | 168 | 169 | class FourierFeatures(nn.Module): 170 | def __init__(self, first=5.0, last=6.0, step=1.0): 171 | super().__init__() 172 | self.freqs_exponent = torch.arange(first, last + 1e-8, step) 173 | 174 | @property 175 | def num_features(self): 176 | return len(self.freqs_exponent) * 2 177 | 178 | def forward(self, x): 179 | assert len(x.shape) >= 2 180 | 181 | # Compute (2pi * 2^n) for n in freqs. 182 | freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device) # (F, ) 183 | freqs = 2.0**freqs_exponent * 2 * pi # (F, ) 184 | freqs = freqs.view(-1, *([1] * (x.dim() - 1))) # (F, 1, 1, ...) 185 | 186 | # Compute (2pi * 2^n * x) for n in freqs. 187 | features = freqs * x.unsqueeze(1) # (B, F, X1, X2, ...) 188 | features = features.flatten(1, 2) # (B, F * C, X1, X2, ...) 189 | 190 | # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W). 191 | return torch.cat([features.sin(), features.cos()], dim=1) 192 | 193 | 194 | def attention_inner_heads(qkv, num_heads): 195 | """Computes attention with heads inside of qkv in the channel dimension. 196 | 197 | Args: 198 | qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where: 199 | H = number of heads, 200 | C = number of channels per head. 201 | num_heads: number of heads. 202 | 203 | Returns: 204 | Attention output of shape (B, H*C, T). 205 | """ 206 | 207 | bs, width, length = qkv.shape 208 | ch = width // (3 * num_heads) 209 | 210 | # Split into (q, k, v) of shape (B, H*C, T). 211 | q, k, v = qkv.chunk(3, dim=1) 212 | 213 | # Rescale q and k. This makes them contiguous in memory. 214 | scale = ch ** (-1 / 4) # scale with 4th root = scaling output by sqrt 215 | q = q * scale 216 | k = k * scale 217 | 218 | # Reshape qkv to (B*H, C, T). 219 | new_shape = (bs * num_heads, ch, length) 220 | q = q.view(*new_shape) 221 | k = k.view(*new_shape) 222 | v = v.reshape(*new_shape) 223 | 224 | # Compute attention. 
225 | weight = einsum("bct,bcs->bts", q, k) # (B*H, T, T) 226 | weight = softmax(weight.float(), dim=-1).to(weight.dtype) # (B*H, T, T) 227 | out = einsum("bts,bcs->bct", weight, v) # (B*H, C, T) 228 | return out.reshape(bs, num_heads * ch, length) # (B, H*C, T) 229 | 230 | 231 | class Attention(nn.Module): 232 | """Based on https://github.com/openai/guided-diffusion.""" 233 | 234 | def __init__(self, n_heads): 235 | super().__init__() 236 | self.n_heads = n_heads 237 | 238 | def forward(self, qkv): 239 | assert qkv.dim() >= 3, qkv.dim() 240 | assert qkv.shape[1] % (3 * self.n_heads) == 0 241 | spatial_dims = qkv.shape[2:] 242 | qkv = qkv.view(*qkv.shape[:2], -1) # (B, 3*H*C, T) 243 | out = attention_inner_heads(qkv, self.n_heads) # (B, H*C, T) 244 | return out.view(*out.shape[:2], *spatial_dims) 245 | 246 | 247 | class AttentionBlock(nn.Module): 248 | """Self-attention residual block.""" 249 | 250 | def __init__(self, n_heads, n_channels, norm_groups): 251 | super().__init__() 252 | assert n_channels % n_heads == 0 253 | self.layers = nn.Sequential( 254 | nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels), 255 | nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1), # (B, 3 * C, H, W) 256 | Attention(n_heads), 257 | zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)), 258 | ) 259 | 260 | def forward(self, x): 261 | return self.layers(x) + x 262 | 263 | 264 | class UpDownBlock(nn.Module): 265 | def __init__(self, resnet_block, attention_block=None): 266 | super().__init__() 267 | self.resnet_block = resnet_block 268 | self.attention_block = attention_block 269 | 270 | def forward(self, x, cond): 271 | x = self.resnet_block(x, cond) 272 | if self.attention_block is not None: 273 | x = self.attention_block(x) 274 | return x 275 | --------------------------------------------------------------------------------