├── .gitignore
├── LICENSE
├── README.md
├── accelerate_config.yaml
├── assets
│   ├── fourier_demo_1d.gif
│   ├── mandrill_fourier.gif
│   ├── sample-ema-2M-1000.png
│   └── training-curve.png
├── eval.py
├── requirements.txt
├── train.py
├── utils.py
├── vdm.py
└── vdm_unet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | # Other
133 | .DS_Store
134 | .idea
135 | data/
136 | results/
137 | plots/
138 | Makefile
139 |
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Andrea Dittadi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Diffusion Models (VDM)
2 |
3 | This is a PyTorch implementation of [Variational Diffusion Models](https://arxiv.org/abs/2107.00630),
4 | where the focus is on optimizing *likelihood rather than sample quality*,
5 | in the spirit of *probabilistic* generative modeling.
6 |
7 | This implementation should match the
8 | [official one](https://github.com/google-research/vdm) in JAX.
9 | However, the purpose is mainly educational and the focus is on simplicity.
10 | So far, the repo only includes CIFAR10, and variance minimization
11 | with the $\gamma_{\eta}$ network (see Appendix `I.2` in the paper) is not
12 | implemented (it's only used for CIFAR10 *with augmentations* and, according
13 | to the paper, it does not have a significant impact).
14 |
15 |
16 | ## Results
17 |
18 | The samples below are from a model trained on CIFAR10 for 2M steps with gradient clipping and with a fixed noise
19 | schedule such that $\log \mathrm{SNR}(t)$ is linear, with $\log \mathrm{SNR}(0) = 13.3$ and $\log \mathrm{SNR}(1) = -5$.
20 | These samples are generated from the EMA model in 1000 denoising steps.
21 |
22 |
23 | ![Samples from the EMA model after 2M training steps, 1000 denoising steps](assets/sample-ema-2M-1000.png)
24 |
25 |
26 | Without gradient clipping (as in the paper), the test set variational lower bound (VLB) is 2.715 bpd after 2M steps
27 | (the paper reports 2.65 after 10M steps).
28 | However, training is somewhat unstable and requires care: the model tends to overfit,
29 | and the train-test gap is rather large.
30 | With gradient clipping, the test set VLB is slightly worse, but training is better behaved.
31 |
32 |
33 | ![Training curve](assets/training-curve.png)
34 |
35 |
36 |
37 | ## Overview of the model
38 |
39 | ### Diffusion process
40 |
41 | Let $\mathbf{x}$ be a data point, $\mathbf{z}_t$ the latent variable at time $t \in [0,1]$, and
42 |
43 | $$\sigma^2_t = \mathrm{sigmoid}(\gamma_t)$$
44 |
45 | $$\alpha^2_t = 1 - \sigma^2_t = \mathrm{sigmoid}(-\gamma_t)$$
46 |
47 | with $\gamma_t$ the negative log SNR at time $t$.
48 | Then the forward diffusion process is:
49 |
50 | $$q\left(\mathbf{z}_t \mid \mathbf{x}\right)=\mathcal{N}\left(\alpha_t \mathbf{x}, \sigma_t^2 \mathbf{I}\right)$$
51 |
52 |
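Sampling $\mathbf{z}_t \sim q(\mathbf{z}_t \mid \mathbf{x})$ therefore only requires $\gamma_t$. A minimal sketch, mirroring `VDM.sample_q_t_0` in [`vdm.py`](vdm.py):

```python
import torch

def sample_q_t_0(x, gamma_t, noise=None):
    # gamma_t is the negative log SNR at time t (a tensor broadcastable to x).
    alpha_t = torch.sqrt(torch.sigmoid(-gamma_t))  # alpha_t^2 = sigmoid(-gamma_t)
    sigma_t = torch.sqrt(torch.sigmoid(gamma_t))   # sigma_t^2 = sigmoid(gamma_t)
    if noise is None:
        noise = torch.randn_like(x)
    return alpha_t * x + sigma_t * noise  # z_t ~ q(z_t | x)
```
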
53 | ### Reverse generative process
54 |
55 | In discrete time, the generative (denoising) process in $T$ steps is
56 |
57 | $$p(\mathbf{x})=\int_{\mathbf{z}} p\left(\mathbf{z}\_1\right) p\left(\mathbf{x} \mid \mathbf{z}\_0\right) \prod_{i=1}^T p\left(\mathbf{z}\_{s(i)} \mid \mathbf{z}_{t(i)}\right)$$
58 |
59 | $$p(\mathbf{z}_1) = \mathcal{N}(\mathbf{0}, \mathbf{I})$$
60 |
61 | $$p(\mathbf{x} \mid \mathbf{z}\_0) = \prod_{i=1}^N p(x_i \mid z_{0,i})$$
62 |
63 | $$p(x_i \mid z_{0,i}) \propto q(z_{0,i} \mid x_i)$$
64 |
65 | where $s(i) = \frac{i-1}{T}$ and $t(i) = \frac{i}{T}$.
66 | We then choose each one-step denoising distribution to be equal to the
67 | true denoising distribution given the data (which is available in
68 | closed form), except that the unavailable data is replaced by
69 | a prediction of the clean data computed from $\mathbf{z}_t$:
70 |
71 | $$p\left(\mathbf{z}\_s \mid \mathbf{z}\_t\right)=q\left(\mathbf{z}\_s \mid \mathbf{z}\_t, \mathbf{x}=\hat{\mathbf{x}}\_\theta\left(\mathbf{z}\_t ; t\right)\right)$$
72 |
73 | where $\hat{\mathbf{x}}_\theta$ is a denoising model with parameters $\theta$.
74 |
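A single ancestral sampling step $\mathbf{z}_t \to \mathbf{z}_s$ (with $s < t$) then looks as follows; this is a condensed sketch of `VDM.sample_p_s_t` in [`vdm.py`](vdm.py), omitting the optional clipping of the predicted clean image:

```python
import torch
from torch.special import expm1

def sample_p_s_t(model, z, gamma_t, gamma_s):
    # One step of p(z_s | z_t) = q(z_s | z_t, x = x_hat(z_t; t)), with s < t.
    c = -expm1(gamma_s - gamma_t)  # c = 1 - SNR(t) / SNR(s)
    alpha_t = torch.sqrt(torch.sigmoid(-gamma_t))
    alpha_s = torch.sqrt(torch.sigmoid(-gamma_s))
    sigma_t = torch.sqrt(torch.sigmoid(gamma_t))
    sigma_s = torch.sqrt(torch.sigmoid(gamma_s))
    pred_noise = model(z, gamma_t)  # noise prediction eps_hat(z_t; t)
    mean = alpha_s / alpha_t * (z - c * sigma_t * pred_noise)
    scale = sigma_s * torch.sqrt(c)
    return mean + scale * torch.randn_like(z)
```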
75 |
76 | ### Optimization in continuous time
77 |
78 | The loss function is given by the usual variational lower bound:
79 |
80 | $$-\log p(\mathbf{x}) \leq-\text{VLB}(\mathbf{x})=D\_{KL}\left(q\left(\mathbf{z}\_1 \mid \mathbf{x}\right)\ ||\ p\left(\mathbf{z}\_1\right)\right)+\mathbb{E}\_{q\left(\mathbf{z}\_0 \mid \mathbf{x}\right)}\left[-\log p\left(\mathbf{x} \mid \mathbf{z}\_0\right)\right]+\mathcal{L}\_T(\mathbf{x})$$
81 |
82 | where the diffusion loss $\mathcal{L}_T(\mathbf{x})$ is
83 |
84 | $$\mathcal{L}\_T (\mathbf{x}) = \sum_{i=1}^T \mathbb{E}\_{q \left(\mathbf{z}\_t \mid \mathbf{x}\right)} D\_{KL}\left[q\left(\mathbf{z}\_s \mid \mathbf{z}\_t, \mathbf{x}\right)\ ||\ p\left(\mathbf{z}\_s \mid \mathbf{z}\_t \right)\right]$$
85 |
86 | Long story short, using the classic noise-prediction parameterization of the denoising model:
87 |
88 | $$\hat{\mathbf{x}}\_\theta\left(\mathbf{z}\_t ; t\right) = \frac{\mathbf{z}\_t-\sigma\_t \hat{\boldsymbol{\epsilon}}\_\theta\left(\mathbf{z}\_t ; t\right)}{\alpha\_t}$$
89 |
90 | and considering the continuous-time limit ($T \to \infty$),
91 | the diffusion loss simplifies to:
92 |
93 | $$\mathcal{L}\_{\infty}(\mathbf{x})=\frac{1}{2} \mathbb{E}\_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I}), t \sim \mathcal{U}(0,1)}\left[ \frac{d\gamma\_t}{dt} \ \|\| \boldsymbol{\epsilon}-\hat{\boldsymbol{\epsilon}}\_{\boldsymbol{\theta}}\left(\mathbf{z}\_t ; t\right) \|\|\_2^2\right]$$
94 |
95 |
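In practice this expectation is estimated with one Monte Carlo sample of $t$ and $\boldsymbol{\epsilon}$ per data point, and $\frac{d\gamma_t}{dt}$ is obtained with autograd. A minimal sketch of the estimator (the full loss in [`vdm.py`](vdm.py) additionally uses antithetic time sampling and adds the latent and reconstruction terms):

```python
import torch
from torch import autograd

def diffusion_loss(model, gamma, x):
    # Single-sample Monte Carlo estimate of the continuous-time diffusion loss, per image.
    times = torch.rand(x.shape[0], device=x.device, requires_grad=True)
    gamma_t = gamma(times)  # (B,), negative log SNR
    gamma_t_pad = gamma_t.view(-1, *([1] * (x.ndim - 1)))  # e.g. (B, 1, 1, 1) for images
    noise = torch.randn_like(x)
    z_t = torch.sqrt(torch.sigmoid(-gamma_t_pad)) * x + torch.sqrt(torch.sigmoid(gamma_t_pad)) * noise
    pred_noise = model(z_t, gamma_t)
    # d(gamma_t)/dt via autograd, as done in the repo.
    gamma_grad = autograd.grad(gamma_t, times, torch.ones_like(gamma_t), create_graph=True)[0]
    pred_loss = ((pred_noise - noise) ** 2).sum(dim=tuple(range(1, x.ndim)))  # (B,)
    return 0.5 * pred_loss * gamma_grad  # (B,), in nats
```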
96 |
97 |
98 | ### Fourier features
99 |
100 | One of the key components to reach SOTA likelihood is the
101 | concatenation of Fourier features to $\mathbf{z}_t$ before feeding it into the
102 | UNet. For each element $z_t^i$ of $\mathbf{z}_t$ (e.g., one channel of
103 | a specific pixel), we concatenate:
104 |
105 | $$f_n^{i} = \sin \left(2^n z_t^{i} 2\pi\right)$$
106 |
107 | $$g_n^{i} = \cos \left(2^n z_t^{i} 2\pi\right)$$
108 |
109 | with $n$ taking a set of integer values.
110 |
111 | Assume that each scalar variable takes values:
112 |
113 | $$\frac{2k + 1}{2^{m+1}} \ \text{ with }\ k = 0, ..., 2^m - 1 \ \text{ and }\ m \in \mathbb{N}.$$
114 |
115 | E.g., in our case the $2^m = 256$ pixel values are $\left\\{\frac{1}{512}, \frac{3}{512}, ..., \frac{511}{512} \right\\}$.
116 | The argument of $\sin$ and $\cos$ is then
117 |
118 | $$\frac{2k + 1}{2^m} 2^n \pi = 2^{n-m} \pi + 2\pi 2^{n-m}k$$
119 |
120 | which means the features have period $2^{m-n}$ in $k$.
121 | Therefore, at very high SNR (i.e., almost discrete values with negligible noise), where
122 | Fourier features are expected to be most useful to deal with fine details, we should choose
123 | $n < m$, such that the period is greater than 1.
124 | For the cosine, the condition is even stricter, because if $n = m-1$ then
125 | $g_n^i = \cos\left(\frac{\pi}{2} + k\pi\right) = 0$.
126 | Since in our case $m=8$, we take $n \leq 7$.
127 | In the code we use $n \leq 6$ because images have twice the range
128 | (between $\pm \frac{255}{256}$).
129 |
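In code, these features are computed for all pixels and channels at once by multiplying $\mathbf{z}_t$ with the chosen frequencies and concatenating the sines and cosines along the channel dimension; a condensed sketch of the `FourierFeatures` module in [`vdm_unet.py`](vdm_unet.py):

```python
import torch
from torch import nn, pi

class FourierFeatures(nn.Module):
    def __init__(self, exponents=(5.0, 6.0)):  # n in {5, 6}, as used in the repo
        super().__init__()
        self.exponents = torch.tensor(exponents)

    def forward(self, z):  # z: (B, C, H, W)
        freqs = 2.0 ** self.exponents.to(z) * 2 * pi          # (F,)
        features = freqs.view(-1, 1, 1, 1) * z.unsqueeze(1)   # (B, F, C, H, W)
        features = features.flatten(1, 2)                     # (B, F*C, H, W)
        return torch.cat([features.sin(), features.cos()], dim=1)  # (B, 2*F*C, H, W)
```
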
130 | Below we visualize the feature values for pixel values 0 to 25, varying the
131 | frequency $2^n$ with $n$ from 0 to 7. At $n=m-1=7$, the cosine features are constant,
132 | and the sine features measure the least significant bit of the pixel value.
133 | On clean data, any frequency $2^n$ with $n$ integer and $n > 7$ would
134 | be useless (1 would be a multiple of the period).
135 |
136 |
137 | ![Fourier features for increasing pixel values at frequencies from 2^0 to 2^7](assets/fourier_demo_1d.gif)
138 |
139 |
140 | Below are the sine features on the Mandrill image (and detail on the right) with smoothly increasing frequency
141 | from $2^0$ to $2^{4.5}$.
142 |
143 |
144 | ![Sine Fourier features on the Mandrill image](assets/mandrill_fourier.gif)
145 |
146 |
147 |
148 |
149 |
150 | ## Setup
151 |
152 | The environment can be set up with `requirements.txt`. For example with conda:
153 |
154 | ```
155 | conda create --name vdm python=3.9
156 | conda activate vdm
157 | pip install -r requirements.txt
158 | ```
159 |
160 |
161 | ## Training with 🤗 Accelerate
162 |
163 | To train with default parameters and options:
164 |
165 | ```bash
166 | accelerate launch --config_file accelerate_config.yaml train.py --results-path results/my_experiment/
167 | ```
168 |
169 | Append `--resume` to the command above to resume training from the latest checkpoint.
170 | See [`train.py`](train.py) for more training options.
171 |
172 | Here we provide a sensible configuration for training on 2 GPUs in the file
173 | [`accelerate_config.yaml`](accelerate_config.yaml). This can be modified directly, or overridden
174 | on the command line by adding flags before "`train.py`" (e.g., `--num_processes N`
175 | to train on N GPUs).
176 | See the [Accelerate docs](https://huggingface.co/docs/accelerate/index) for more configuration options.
177 | After initialization, we print an estimate of the required GPU memory for the given
178 | batch size, so that the number of GPUs can be adjusted accordingly.
179 | The training loop periodically logs train and validation metrics to a JSONL file,
180 | and generates samples.
181 |
182 |
183 | ## Evaluating from checkpoint
184 |
185 | ```bash
186 | python eval.py --results-path results/my_experiment/ --n-sample-steps 1000
187 | ```
188 |
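Alternatively, a trained model can be loaded and sampled from programmatically. A minimal sketch along the lines of what `eval.py` does (the results path is just an example, and `eval.py` also evaluates the EMA weights stored under the `"ema"` key of the checkpoint):

```python
from pathlib import Path

import torch
import yaml

from utils import TrainConfig
from vdm import VDM
from vdm_unet import UNetVDM

results_path = Path("results/my_experiment")  # example path
with open(results_path / "config.yaml") as f:
    cfg = TrainConfig(**yaml.safe_load(f))

model = UNetVDM(cfg)
diffusion = VDM(model, cfg, image_shape=(3, 32, 32)).to("cuda").eval()
checkpoint = torch.load(results_path / "model.pt", map_location="cuda")
diffusion.load_state_dict(checkpoint["model"])

with torch.no_grad():
    # Returns a (16, 3, 32, 32) tensor of samples in [0, 1].
    samples = diffusion.sample(batch_size=16, n_sample_steps=1000, clip_samples=True)
```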
189 |
190 | ## Credits
191 |
192 | This implementation is based on the VDM [paper](https://arxiv.org/abs/2107.00630) and [official code](https://github.com/google-research/vdm). The code structure for training diffusion models with Accelerate is inspired by [this repo](https://github.com/lucidrains/denoising-diffusion-pytorch).
193 |
--------------------------------------------------------------------------------
/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config: {}
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | fsdp_config: {}
6 | gpu_ids: all
7 | machine_rank: 0
8 | main_process_ip: null
9 | main_process_port: null
10 | main_training_function: main
11 | mixed_precision: 'no'
12 | num_machines: 1
13 | num_processes: 2
14 | rdzv_backend: static
15 | same_network: true
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/assets/fourier_demo_1d.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/fourier_demo_1d.gif
--------------------------------------------------------------------------------
/assets/mandrill_fourier.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/mandrill_fourier.gif
--------------------------------------------------------------------------------
/assets/sample-ema-2M-1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/sample-ema-2M-1000.png
--------------------------------------------------------------------------------
/assets/training-curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/variational-diffusion-models/90a8b59d16845f8fb52e4587d149073caf3465fb/assets/training-curve.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | from pathlib import Path
4 |
5 | import torch
6 | import yaml
7 | from accelerate.utils import set_seed
8 | from ema_pytorch import EMA
9 | from torch.utils.data import Subset
10 | from torchvision.utils import save_image
11 |
12 | from utils import (
13 | DeviceAwareDataLoader,
14 | TrainConfig,
15 | evaluate_model_and_log,
16 | get_date_str,
17 | has_int_squareroot,
18 | log,
19 | make_cifar,
20 | print_model_summary,
21 | sample_batched,
22 | )
23 | from vdm import VDM
24 | from vdm_unet import UNetVDM
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--batch-size", type=int, default=128)
30 | parser.add_argument("--seed", type=int, default=12345)
31 | parser.add_argument("--results-path", type=str, required=True)
32 | parser.add_argument("--num-workers", type=int, default=1)
33 | parser.add_argument("--device", type=str, default="cuda")
34 | parser.add_argument("--n-sample-steps", type=int, default=250)
35 | parser.add_argument("--clip-samples", type=bool, default=True)
36 | parser.add_argument("--n-samples-for-eval", type=int, default=1)
37 | args = parser.parse_args()
38 | set_seed(args.seed)
39 |
40 | # Load config from YAML.
41 | with open(Path(args.results_path) / "config.yaml", "r") as f:
42 | cfg = TrainConfig(**yaml.safe_load(f))
43 |
44 | model = UNetVDM(cfg)
45 | print_model_summary(model, batch_size=None, shape=(3, 32, 32))
46 | train_set = make_cifar(train=True, download=True)
47 | validation_set = make_cifar(train=False, download=False)
48 | diffusion = VDM(model, cfg, image_shape=train_set[0][0].shape)
49 | Evaluator(
50 | diffusion,
51 | train_set,
52 | validation_set,
53 | config=cfg,
54 | eval_batch_size=args.batch_size,
55 | results_path=Path(args.results_path),
56 | num_dataloader_workers=args.num_workers,
57 | device=args.device,
58 | n_sample_steps=args.n_sample_steps,
59 | clip_samples=args.clip_samples,
60 | n_samples_for_eval=args.n_samples_for_eval,
61 | ).eval()
62 |
63 |
64 | class Evaluator:
65 | def __init__(
66 | self,
67 | diffusion_model,
68 | train_set,
69 | validation_set,
70 | config,
71 | *,
72 | eval_batch_size,
73 | device,
74 | results_path,
75 | num_samples=64,
76 | num_dataloader_workers=1,
77 | n_sample_steps=250,
78 | clip_samples=True,
79 | n_samples_for_eval=4,
80 | ):
81 | assert has_int_squareroot(num_samples), "num_samples must have an integer sqrt"
82 | self.num_samples = num_samples
83 | self.cfg = config
84 | self.n_sample_steps = n_sample_steps
85 | self.clip_samples = clip_samples
86 | self.device = device
87 | self.eval_batch_size = eval_batch_size
88 | self.n_samples_for_eval = n_samples_for_eval
89 |
90 | def make_dataloader(dataset, limit_size=None):
91 | # If limit_size is not None, only use a subset of the dataset
92 | if limit_size is not None:
93 | dataset = Subset(dataset, range(limit_size))
94 | return DeviceAwareDataLoader(
95 | dataset,
96 | eval_batch_size,
97 | device=device,
98 | shuffle=False,
99 | pin_memory=True,
100 | num_workers=num_dataloader_workers,
101 | drop_last=True,
102 | )
103 |
104 | self.validation_dataloader = make_dataloader(validation_set)
105 | self.train_eval_dataloader = make_dataloader(train_set, len(validation_set))
106 | self.diffusion_model = diffusion_model.eval().to(self.device)
107 | # No need to set EMA parameters since we only use it for eval from checkpoint.
108 | self.ema = EMA(self.diffusion_model).to(self.device)
109 | self.ema.ema_model.eval()
110 | self.path = results_path
111 | self.eval_path = self.path / f"eval_{get_date_str()}"
112 | self.eval_path.mkdir()
113 | self.checkpoint_file = self.path / f"model.pt"
114 | with open(self.eval_path / "eval_config.yaml", "w") as f:
115 | eval_conf = {
116 | "n_sample_steps": n_sample_steps,
117 | "clip_samples": clip_samples,
118 | "n_samples_for_eval": n_samples_for_eval,
119 | }
120 | yaml.dump(eval_conf, f)
121 | self.load_checkpoint()
122 |
123 | def load_checkpoint(self):
124 | data = torch.load(self.checkpoint_file, map_location=self.device)
125 | log(f"Loading checkpoint '{self.checkpoint_file}'")
126 | self.diffusion_model.load_state_dict(data["model"])
127 | self.ema.load_state_dict(data["ema"])
128 |
129 | @torch.no_grad()
130 | def eval(self):
131 | self.eval_model(self.diffusion_model, is_ema=False)
132 | self.eval_model(self.ema.ema_model, is_ema=True)
133 |
134 | def eval_model(self, model, *, is_ema):
135 | log(f"\n *** Evaluating {'EMA' if is_ema else 'online'} model\n")
136 | self.sample_images(model, is_ema=is_ema)
137 | for validation in [True, False]:
138 | evaluate_model_and_log(
139 | model,
140 | self.validation_dataloader
141 | if validation
142 | else self.train_eval_dataloader,
143 | self.eval_path / ("ema-metrics.jsonl" if is_ema else "metrics.jsonl"),
144 | "validation" if validation else "train",
145 | n=self.n_samples_for_eval,
146 | )
147 |
148 | def sample_images(self, model, *, is_ema):
149 | samples = sample_batched(
150 | model,
151 | self.num_samples,
152 | self.eval_batch_size,
153 | self.n_sample_steps,
154 | self.clip_samples,
155 | )
156 | path = self.eval_path / f"sample{'-ema' if is_ema else ''}.png"
157 | save_image(samples, str(path), nrow=int(math.sqrt(self.num_samples)))
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.16.0
2 | black==23.1.0
3 | ema-pytorch==0.1.4
4 | isort==5.12.0
5 | numpy==1.24.2
6 | Pillow==9.4.0
7 | PyYAML==6.0
8 | torch==1.13.1
9 | torchinfo==1.7.2
10 | torchvision==0.14.1
11 | tqdm==4.64.1
12 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 | import math
4 | from argparse import BooleanOptionalAction
5 |
6 | import torch
7 | import yaml
8 | from accelerate import Accelerator
9 | from accelerate.utils import set_seed
10 | from ema_pytorch import EMA
11 | from torch.utils.data import Subset
12 | from torchvision.utils import save_image
13 | from tqdm.auto import tqdm
14 |
15 | from utils import (
16 | DeviceAwareDataLoader,
17 | TrainConfig,
18 | check_config_matches_checkpoint,
19 | cycle,
20 | evaluate_model_and_log,
21 | get_date_str,
22 | handle_results_path,
23 | has_int_squareroot,
24 | init_config_from_args,
25 | init_logger,
26 | log,
27 | make_cifar,
28 | print_model_summary,
29 | sample_batched,
30 | )
31 | from vdm import VDM
32 | from vdm_unet import UNetVDM
33 |
34 |
35 | def main():
36 | parser = argparse.ArgumentParser()
37 |
38 | # Architecture
39 | parser.add_argument("--embedding-dim", type=int, default=128)
40 | parser.add_argument("--n-blocks", type=int, default=32)
41 | parser.add_argument("--n-attention-heads", type=int, default=1)
42 | parser.add_argument("--dropout-prob", type=float, default=0.1)
43 | parser.add_argument("--norm-groups", type=int, default=32)
44 | parser.add_argument("--input-channels", type=int, default=3)
45 | parser.add_argument("--use-fourier-features", action=BooleanOptionalAction, default=True)
46 | parser.add_argument("--attention-everywhere", action=BooleanOptionalAction, default=False)
47 |
48 | # Training
49 | parser.add_argument("--batch-size", type=int, default=128)
50 | parser.add_argument("--noise-schedule", type=str, default="fixed_linear")
51 | parser.add_argument("--gamma-min", type=float, default=-13.3)
52 | parser.add_argument("--gamma-max", type=float, default=5.0)
53 | parser.add_argument("--antithetic-time-sampling", action=BooleanOptionalAction, default=True)
54 | parser.add_argument("--lr", type=float, default=2e-4)
55 | parser.add_argument("--weight-decay", type=float, default=0.01)
56 | parser.add_argument("--clip-grad-norm", action=BooleanOptionalAction, default=True)
57 |
58 | parser.add_argument("--eval-every", type=int, default=10_000)
59 | parser.add_argument("--seed", type=int, default=12345)
60 | parser.add_argument("--results-path", type=str, default=None)
61 | parser.add_argument("--resume", action="store_true")
62 | parser.add_argument("--num-workers", type=int, default=2)
63 | args = parser.parse_args()
64 |
65 | set_seed(args.seed)
66 | accelerator = Accelerator(split_batches=True)
67 | init_logger(accelerator)
68 | cfg = init_config_from_args(TrainConfig, args)
69 |
70 | model = UNetVDM(cfg)
71 | print_model_summary(model, batch_size=cfg.batch_size, shape=(3, 32, 32))
72 | with accelerator.local_main_process_first():
73 | train_set = make_cifar(train=True, download=accelerator.is_local_main_process)
74 | validation_set = make_cifar(train=False, download=False)
75 | diffusion = VDM(model, cfg, image_shape=train_set[0][0].shape)
76 | Trainer(
77 | diffusion,
78 | train_set,
79 | validation_set,
80 | accelerator,
81 | make_opt=lambda params: torch.optim.AdamW(
82 | params, cfg.lr, betas=(0.9, 0.99), weight_decay=cfg.weight_decay, eps=1e-8
83 | ),
84 | config=cfg,
85 | save_and_eval_every=args.eval_every,
86 | results_path=handle_results_path(args.results_path),
87 | resume=args.resume,
88 | num_dataloader_workers=args.num_workers,
89 | ).train()
90 |
91 |
92 | class Trainer:
93 | def __init__(
94 | self,
95 | diffusion_model,
96 | train_set,
97 | validation_set,
98 | accelerator,
99 | make_opt,
100 | config,
101 | *,
102 | train_num_steps=10_000_000,
103 | ema_decay=0.9999,
104 | ema_update_every=1,
105 | ema_power=3 / 4, # 0.999 at 10k, 0.9997 at 50k, 0.9999 at 200k
106 | save_and_eval_every=1000,
107 | num_samples=64,
108 | results_path=None,
109 | resume=False,
110 | num_dataloader_workers=1,
111 | n_sample_steps=250,
112 | clip_samples=True,
113 | ):
114 | super().__init__()
115 | assert has_int_squareroot(num_samples), "num_samples must have an integer sqrt"
116 | self.num_samples = num_samples
117 | self.save_and_eval_every = save_and_eval_every
118 | self.cfg = config
119 | self.train_num_steps = train_num_steps
120 | self.n_sample_steps = n_sample_steps
121 | self.clip_samples = clip_samples
122 | self.accelerator = accelerator
123 | self.step = 0
124 |
125 | def make_dataloader(dataset, limit_size=None, *, train=False):
126 | if limit_size is not None:
127 | dataset = Subset(dataset, range(limit_size))
128 | dataloader = DeviceAwareDataLoader(
129 | dataset,
130 | config.batch_size,
131 | shuffle=train,
132 | pin_memory=True,
133 | num_workers=num_dataloader_workers,
134 | drop_last=True,
135 | device=accelerator.device if not train else None, # None -> standard DL
136 | )
137 | if train:
138 | dataloader = accelerator.prepare(dataloader)
139 | return dataloader
140 |
141 | self.train_dataloader = cycle(make_dataloader(train_set, train=True))
142 | self.validation_dataloader = make_dataloader(validation_set)
143 | self.train_eval_dataloader = make_dataloader(train_set, len(validation_set))
144 |
145 | self.path = results_path
146 | self.checkpoint_file = self.path / f"model.pt"
147 | if accelerator.is_main_process:
148 | self.ema = EMA(
149 | diffusion_model.to(accelerator.device),
150 | beta=ema_decay,
151 | update_every=ema_update_every,
152 | power=ema_power,
153 | )
154 | self.ema.ema_model.eval()
155 | self.path.mkdir(exist_ok=True, parents=True)
156 | self.diffusion_model = accelerator.prepare(diffusion_model)
157 | self.opt = accelerator.prepare(make_opt(self.diffusion_model.parameters()))
158 | if resume:
159 | self.load_checkpoint()
160 | else:
161 | if len(list(self.path.glob("*.pt"))) > 0:
162 | raise ValueError(f"'{self.path}' contains checkpoints but resume=False")
163 | if accelerator.is_main_process:
164 | with open(self.path / "config.yaml", "w") as f:
165 | yaml.dump(dataclasses.asdict(config), f)
166 |
167 | def save_checkpoint(self):
168 | tmp_file = self.checkpoint_file.with_suffix(f".tmp.{get_date_str()}.pt")
169 | if self.checkpoint_file.exists():
170 | self.checkpoint_file.rename(tmp_file) # Rename old checkpoint to temp file
171 | checkpoint = {
172 | "step": self.step,
173 | "model": self.accelerator.get_state_dict(self.diffusion_model),
174 | "opt": self.opt.state_dict(),
175 | "ema": self.ema.state_dict(),
176 | }
177 | torch.save(checkpoint, self.checkpoint_file)
178 | tmp_file.unlink(missing_ok=True) # Delete temp file
179 |
180 | def load_checkpoint(self):
181 | check_config_matches_checkpoint(self.cfg, self.path)
182 | data = torch.load(self.checkpoint_file, map_location=self.accelerator.device)
183 | self.step = data["step"]
184 | log(f"Resuming from checkpoint '{self.checkpoint_file}' (step {self.step})")
185 | model = self.accelerator.unwrap_model(self.diffusion_model)
186 | model.load_state_dict(data["model"])
187 | self.opt.load_state_dict(data["opt"])
188 | if self.accelerator.is_main_process:
189 | self.ema.load_state_dict(data["ema"])
190 |
191 | def train(self):
192 | with tqdm(
193 | initial=self.step,
194 | total=self.train_num_steps,
195 | disable=not self.accelerator.is_main_process,
196 | ) as pbar:
197 | while self.step < self.train_num_steps:
198 | data = next(self.train_dataloader)
199 | self.opt.zero_grad()
200 | loss, _ = self.diffusion_model(data)
201 | self.accelerator.backward(loss)
202 | if self.cfg.clip_grad_norm:
203 | self.accelerator.clip_grad_norm_(
204 | self.diffusion_model.parameters(), 1.0
205 | )
206 | self.opt.step()
207 | pbar.set_description(f"loss: {loss.item():.4f}")
208 | self.step += 1
209 | self.accelerator.wait_for_everyone()
210 | if self.accelerator.is_main_process:
211 | self.ema.update()
212 | if self.step % self.save_and_eval_every == 0:
213 | self.eval()
214 | pbar.update()
215 |
216 | @torch.no_grad()
217 | def eval(self):
218 | self.save_checkpoint()
219 | self.sample_images(self.ema.ema_model, is_ema=True)
220 | self.sample_images(self.diffusion_model, is_ema=False)
221 | self.evaluate_ema_model_and_log(validation=True)
222 | self.evaluate_ema_model_and_log(validation=False)
223 |
224 | def evaluate_ema_model_and_log(self, *, validation):
225 | evaluate_model_and_log(
226 | self.ema.ema_model,
227 | self.validation_dataloader if validation else self.train_eval_dataloader,
228 | self.path / "metrics_log.jsonl",
229 | "validation" if validation else "train",
230 | self.step,
231 | )
232 |
233 | def sample_images(self, model, *, is_ema):
234 | train_state = model.training
235 | model.eval()
236 | samples = sample_batched(
237 | self.accelerator.unwrap_model(model),
238 | self.num_samples,
239 | self.cfg.batch_size,
240 | self.n_sample_steps,
241 | self.clip_samples,
242 | )
243 | path = self.path / f"sample-{'ema-' if is_ema else ''}{self.step}.png"
244 | save_image(samples, str(path), nrow=int(math.sqrt(self.num_samples)))
245 | model.train(train_state)
246 |
247 |
248 | if __name__ == "__main__":
249 | main()
250 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | import math
4 | import warnings
5 | from collections import defaultdict
6 | from dataclasses import dataclass
7 | from datetime import datetime
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import numpy as np
12 | import torch
13 | import torchinfo
14 | import yaml
15 | from accelerate import Accelerator
16 | from torch import nn
17 | from torch.utils.data import DataLoader
18 | from torchvision import transforms
19 | from torchvision.datasets import CIFAR10
20 | from tqdm.auto import tqdm
21 |
22 |
23 | @dataclass
24 | class TrainConfig:
25 | embedding_dim: int
26 | n_blocks: int
27 | n_attention_heads: int
28 | dropout_prob: float
29 | norm_groups: int
30 | input_channels: int
31 | use_fourier_features: bool
32 | attention_everywhere: bool
33 | batch_size: int
34 | noise_schedule: str
35 | gamma_min: float
36 | gamma_max: float
37 | antithetic_time_sampling: bool
38 | lr: float
39 | weight_decay: float
40 | clip_grad_norm: bool
41 |
42 |
43 | def print_model_summary(model, *, batch_size, shape, depth=4, batch_size_torchinfo=1):
44 | summary = torchinfo.summary(
45 | model,
46 | [(batch_size_torchinfo, *shape), (batch_size_torchinfo,)],
47 | depth=depth,
48 | col_names=["input_size", "output_size", "num_params"],
49 | verbose=0, # quiet
50 | )
51 | log(summary)
52 | if batch_size is None or batch_size == batch_size_torchinfo:
53 | return
54 | output_bytes_large = summary.total_output_bytes / batch_size_torchinfo * batch_size
55 | total_bytes = summary.total_input + output_bytes_large + summary.total_param_bytes
56 | log(
57 | f"\n--- With batch size {batch_size} ---\n"
58 | f"Forward/backward pass size: {output_bytes_large / 1e9:0.2f} GB\n"
59 | f"Estimated Total Size: {total_bytes / 1e9:0.2f} GB\n"
60 | + "=" * len(str(summary).splitlines()[-1])
61 | + "\n"
62 | )
63 |
64 |
65 | def cycle(dl):
66 | # We don't use itertools.cycle because it caches the entire iterator.
67 | while True:
68 | for data in dl:
69 | yield data
70 |
71 |
72 | def has_int_squareroot(num):
73 | return math.isqrt(num) ** 2 == num
74 |
75 |
76 | def sample_batched(model, num_samples, batch_size, n_sample_steps, clip_samples):
77 | samples = []
78 | for i in range(0, num_samples, batch_size):
79 | corrected_batch_size = min(batch_size, num_samples - i)
80 | samples.append(model.sample(corrected_batch_size, n_sample_steps, clip_samples))
81 | return torch.cat(samples, dim=0)
82 |
83 |
84 | def get_date_str():
85 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
86 |
87 |
88 | class DeviceAwareDataLoader(DataLoader):
89 | """A DataLoader that moves batches to a device. If device is None, it is equivalent to a standard DataLoader."""
90 |
91 | def __init__(self, *args, device=None, **kwargs):
92 | super().__init__(*args, **kwargs)
93 | self.device = device
94 |
95 | def __iter__(self):
96 | for batch in super().__iter__():
97 | yield self.move_to_device(batch)
98 |
99 | def move_to_device(self, batch):
100 | if self.device is None:
101 | return batch
102 | if isinstance(batch, (tuple, list)):
103 | return [self.move_to_device(x) for x in batch]
104 | elif isinstance(batch, dict):
105 | return {k: self.move_to_device(v) for k, v in batch.items()}
106 | elif isinstance(batch, torch.Tensor):
107 | return batch.to(self.device)
108 | else:
109 | return batch
110 |
111 |
112 | def evaluate_model(model, dataloader):
113 | all_metrics = defaultdict(list)
114 | for batch in tqdm(dataloader, desc="evaluation"):
115 | loss, metrics = model(batch)
116 | for k, v in metrics.items():
117 | try:
118 | v = v.item()
119 | except AttributeError:
120 | pass
121 | all_metrics[k].append(v)
122 | return {k: sum(v) / len(v) for k, v in all_metrics.items()} # average over dataset
123 |
124 |
125 | def log_and_save_metrics(avg_metrics, dataset_split, step, filename):
126 | log(f"\n{dataset_split} metrics:")
127 | for k, v in avg_metrics.items():
128 | log(f" {k}: {v}")
129 |
130 | avg_metrics = {"step": step, "set": dataset_split, **avg_metrics}
131 | with open(filename, "a") as f:
132 | json.dump(avg_metrics, f)
133 | f.write("\n")
134 |
135 |
136 | def dict_stats(dictionaries: list[dict]) -> dict:
137 | """Computes the average and standard deviation of metrics in a list of dictionaries.
138 |
139 | Args:
140 | dictionaries: A list of dictionaries, where each dictionary contains the same keys,
141 | and the values are numbers.
142 |
143 | Returns:
144 | A dictionary of the same keys as the input dictionaries, with the average and
145 | standard deviation of the values. If the list has length 1, the original dictionary
146 | is returned instead.
147 | """
148 | if len(dictionaries) == 1:
149 | return dictionaries[0]
150 |
151 | # Convert the list of dictionaries to a dictionary of lists.
152 | lists = defaultdict(list)
153 | for d in dictionaries:
154 | for k, v in d.items():
155 | lists[k].append(v)
156 |
157 | # Compute the average and standard deviation of each list.
158 | stats = {}
159 | for k, v in lists.items():
160 | stats[f"{k}_avg"] = np.mean(v)
161 | stats[f"{k}_std"] = np.std(v)
162 | return stats
163 |
164 |
165 | def evaluate_model_and_log(model, dataloader, filename, split, step=None, n=1):
166 | # Call evaluate_model multiple times. Each call returns a dictionary of metrics, and
167 | # we then compute their average and standard deviation.
168 | if n > 1:
169 | log(f"\nRunning {n} evaluations to compute average metrics")
170 | metrics = dict_stats([evaluate_model(model, dataloader) for _ in range(n)])
171 | log_and_save_metrics(metrics, split, step, filename)
172 |
173 |
174 | @torch.no_grad()
175 | def zero_init(module: nn.Module) -> nn.Module:
176 | """Sets to zero all the parameters of a module, and returns the module."""
177 | for p in module.parameters():
178 | nn.init.zeros_(p.data)
179 | return module
180 |
181 |
182 | def maybe_unpack_batch(batch):
183 | if isinstance(batch, (tuple, list)) and len(batch) == 2:
184 | return batch
185 | else:
186 | return batch, None
187 |
188 |
189 | def make_cifar(*, train, download):
190 | return CIFAR10(
191 | root="data",
192 | download=download,
193 | train=train,
194 | transform=transforms.Compose([transforms.ToTensor()]),
195 | )
196 |
197 |
198 | def handle_results_path(res_path: str, default_root: str = "./results") -> Path:
199 | """Sets results path if it doesn't exist yet."""
200 | if res_path is None:
201 | results_path = Path(default_root) / get_date_str()
202 | else:
203 | results_path = Path(res_path)
204 | log(f"Results will be saved to '{results_path}'")
205 | return results_path
206 |
207 |
208 | def unsqueeze_right(x, num_dims=1):
209 | """Unsqueezes the last `num_dims` dimensions of `x`."""
210 | return x.view(x.shape + (1,) * num_dims)
211 |
212 |
213 | def init_config_from_args(cls, args):
214 | """Initializes a dataclass from a Namespace, ignoring unknown fields."""
215 | return cls(**{f.name: getattr(args, f.name) for f in dataclasses.fields(cls)})
216 |
217 |
218 | def check_config_matches_checkpoint(config, checkpoint_path):
219 | with open(checkpoint_path / "config.yaml", "r") as f:
220 | ckpt_config = yaml.safe_load(f)
221 | config = dataclasses.asdict(config)
222 | if config != ckpt_config:
223 | config_str = "\n ".join(f"{k}: {config[k]}" for k in sorted(config))
224 | ckpt_str = "\n ".join(f"{k}: {ckpt_config[k]}" for k in sorted(ckpt_config))
225 | raise ValueError(
226 | f"Config mismatch:\n\n"
227 | f"> Config:\n {config_str}\n\n"
228 | f"> Checkpoint:\n {ckpt_str}\n\n"
229 | )
230 |
231 |
232 | _accelerator: Optional[Accelerator] = None
233 |
234 |
235 | def init_logger(accelerator: Accelerator):
236 | global _accelerator
237 | if _accelerator is not None:
238 | raise ValueError("Accelerator already set")
239 | _accelerator = accelerator
240 |
241 |
242 | def log(message):
243 | global _accelerator
244 | if _accelerator is None:
245 | warnings.warn("Accelerator not set, using print instead.")
246 | print_fn = print
247 | else:
248 | print_fn = _accelerator.print
249 | print_fn(message)
250 |
--------------------------------------------------------------------------------
/vdm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import allclose, argmax, autograd, exp, linspace, nn, sigmoid, sqrt
4 | from torch.special import expm1
5 | from tqdm import trange
6 |
7 | from utils import maybe_unpack_batch, unsqueeze_right
8 |
9 |
10 | class VDM(nn.Module):
11 | def __init__(self, model, cfg, image_shape):
12 | super().__init__()
13 | self.model = model
14 | self.cfg = cfg
15 | self.image_shape = image_shape
16 | self.vocab_size = 256
17 | if cfg.noise_schedule == "fixed_linear":
18 | self.gamma = FixedLinearSchedule(cfg.gamma_min, cfg.gamma_max)
19 | elif cfg.noise_schedule == "learned_linear":
20 | self.gamma = LearnedLinearSchedule(cfg.gamma_min, cfg.gamma_max)
21 | else:
22 | raise ValueError(f"Unknown noise schedule {cfg.noise_schedule}")
23 |
24 | @property
25 | def device(self):
26 | return next(self.model.parameters()).device
27 |
28 | @torch.no_grad()
29 | def sample_p_s_t(self, z, t, s, clip_samples):
30 | """Samples from p(z_s | z_t, x). Used for standard ancestral sampling."""
31 | gamma_t = self.gamma(t)
32 | gamma_s = self.gamma(s)
33 | c = -expm1(gamma_s - gamma_t)
34 | alpha_t = sqrt(sigmoid(-gamma_t))
35 | alpha_s = sqrt(sigmoid(-gamma_s))
36 | sigma_t = sqrt(sigmoid(gamma_t))
37 | sigma_s = sqrt(sigmoid(gamma_s))
38 |
39 | pred_noise = self.model(z, gamma_t)
40 | if clip_samples:
41 | x_start = (z - sigma_t * pred_noise) / alpha_t
42 | x_start.clamp_(-1.0, 1.0)
43 | mean = alpha_s * (z * (1 - c) / alpha_t + c * x_start)
44 | else:
45 | mean = alpha_s / alpha_t * (z - c * sigma_t * pred_noise)
46 | scale = sigma_s * sqrt(c)
47 | return mean + scale * torch.randn_like(z)
48 |
49 | @torch.no_grad()
50 | def sample(self, batch_size, n_sample_steps, clip_samples):
51 | z = torch.randn((batch_size, *self.image_shape), device=self.device)
52 | steps = linspace(1.0, 0.0, n_sample_steps + 1, device=self.device)
53 | for i in trange(n_sample_steps, desc="sampling"):
54 | z = self.sample_p_s_t(z, steps[i], steps[i + 1], clip_samples)
55 | logprobs = self.log_probs_x_z0(z_0=z) # (B, C, H, W, vocab_size)
56 | x = argmax(logprobs, dim=-1) # (B, C, H, W)
57 | return x.float() / (self.vocab_size - 1) # normalize to [0, 1]
58 |
59 | def sample_q_t_0(self, x, times, noise=None):
60 | """Samples from the distributions q(x_t | x_0) at the given time steps."""
61 | with torch.enable_grad(): # Need gradient to compute loss even when evaluating
62 | gamma_t = self.gamma(times)
63 | gamma_t_padded = unsqueeze_right(gamma_t, x.ndim - gamma_t.ndim)
64 | mean = x * sqrt(sigmoid(-gamma_t_padded)) # x * alpha
65 | scale = sqrt(sigmoid(gamma_t_padded))
66 | if noise is None:
67 | noise = torch.randn_like(x)
68 | return mean + noise * scale, gamma_t
69 |
70 | def sample_times(self, batch_size):
71 | if self.cfg.antithetic_time_sampling:
72 | t0 = np.random.uniform(0, 1 / batch_size)
73 | times = torch.arange(t0, 1.0, 1.0 / batch_size, device=self.device)
74 | else:
75 | times = torch.rand(batch_size, device=self.device)
76 | return times
77 |
78 | def forward(self, batch, *, noise=None):
79 | x, label = maybe_unpack_batch(batch)
80 | assert x.shape[1:] == self.image_shape
81 | assert 0.0 <= x.min() and x.max() <= 1.0
82 | bpd_factor = 1 / (np.prod(x.shape[1:]) * np.log(2))
83 |
84 | # Convert image to integers in range [0, vocab_size - 1].
85 | img_int = torch.round(x * (self.vocab_size - 1)).long()
86 | assert (img_int >= 0).all() and (img_int <= self.vocab_size - 1).all()
87 | # Check that the image was discrete with vocab_size values.
88 | assert allclose(img_int / (self.vocab_size - 1), x)
89 |
90 | # Rescale integer image to [-1 + 1/vocab_size, 1 - 1/vocab_size]
91 | x = 2 * ((img_int + 0.5) / self.vocab_size) - 1
92 |
93 | # Sample from q(x_t | x_0) with random t.
94 | times = self.sample_times(x.shape[0]).requires_grad_(True)
95 | if noise is None:
96 | noise = torch.randn_like(x)
97 | x_t, gamma_t = self.sample_q_t_0(x=x, times=times, noise=noise)
98 |
99 | # Forward through model
100 | model_out = self.model(x_t, gamma_t)
101 |
102 | # *** Diffusion loss (bpd)
103 | gamma_grad = autograd.grad( # gamma_grad shape: (B, )
104 | gamma_t, # (B, )
105 | times, # (B, )
106 | grad_outputs=torch.ones_like(gamma_t),
107 | create_graph=True,
108 | retain_graph=True,
109 | )[0]
110 | pred_loss = ((model_out - noise) ** 2).sum((1, 2, 3)) # (B, )
111 | diffusion_loss = 0.5 * pred_loss * gamma_grad * bpd_factor
112 |
113 | # *** Latent loss (bpd): KL divergence from N(0, 1) to q(z_1 | x)
114 | gamma_1 = self.gamma(torch.tensor([1.0], device=self.device))
115 | sigma_1_sq = sigmoid(gamma_1)
116 | mean_sq = (1 - sigma_1_sq) * x**2 # (alpha_1 * x)**2
117 | latent_loss = kl_std_normal(mean_sq, sigma_1_sq).sum((1, 2, 3)) * bpd_factor
118 |
119 | # *** Reconstruction loss (bpd): - E_{q(z_0 | x)} [log p(x | z_0)].
120 | # Compute log p(x | z_0) for all possible values of each pixel in x.
121 | log_probs = self.log_probs_x_z0(x) # (B, C, H, W, vocab_size)
122 | # One-hot representation of original image. Shape: (B, C, H, W, vocab_size).
123 | x_one_hot = torch.zeros((*x.shape, self.vocab_size), device=self.device)
124 | x_one_hot.scatter_(4, img_int.unsqueeze(-1), 1) # one-hot over last dim
125 | # Select the correct log probabilities.
126 | log_probs = (x_one_hot * log_probs).sum(-1) # (B, C, H, W)
127 | # Overall logprob for each image in batch.
128 | recons_loss = -log_probs.sum((1, 2, 3)) * bpd_factor
129 |
130 | # *** Overall loss in bpd. Shape (B, ).
131 | loss = diffusion_loss + latent_loss + recons_loss
132 |
133 | with torch.no_grad():
134 | gamma_0 = self.gamma(torch.tensor([0.0], device=self.device))
135 | metrics = {
136 | "bpd": loss.mean(),
137 | "diff_loss": diffusion_loss.mean(),
138 | "latent_loss": latent_loss.mean(),
139 | "loss_recon": recons_loss.mean(),
140 | "gamma_0": gamma_0.item(),
141 | "gamma_1": gamma_1.item(),
142 | }
143 | return loss.mean(), metrics
144 |
145 | def log_probs_x_z0(self, x=None, z_0=None):
146 | """Computes log p(x | z_0) for all possible values of x.
147 |
148 | Compute p(x_i | z_0i), with i = pixel index, for all possible values of x_i in
149 | the vocabulary. We approximate this with q(z_0i | x_i). Unnormalized logits are:
150 | -1/2 SNR_0 (z_0 / alpha_0 - k)^2
151 | where k takes all possible x_i values. Logits are then normalized to logprobs.
152 |
153 | The method returns a tensor of shape (B, C, H, W, vocab_size) containing, for
154 | each pixel, the log probabilities for all `vocab_size` possible values of that
155 | pixel. Exponentiating the output gives probabilities that sum to 1 over the last dimension.
156 |
157 | The method accepts either `x` or `z_0` as input. If `z_0` is given, it is used
158 | directly. If `x` is given, a sample z_0 is drawn from q(z_0 | x). It's more
159 | efficient to pass `x` directly, if available.
160 |
161 | Args:
162 | x: Input image, shape (B, C, H, W).
163 | z_0: z_0 to be decoded, shape (B, C, H, W).
164 |
165 | Returns:
166 | log_probs: Log probabilities of shape (B, C, H, W, vocab_size).
167 | """
168 | gamma_0 = self.gamma(torch.tensor([0.0], device=self.device))
169 | if x is None and z_0 is not None:
170 | z_0_rescaled = z_0 / sqrt(sigmoid(-gamma_0)) # z_0 / alpha_0
171 | elif z_0 is None and x is not None:
172 | # Equal to z_0/alpha_0 with z_0 sampled from q(z_0 | x)
173 | z_0_rescaled = x + exp(0.5 * gamma_0) * torch.randn_like(x) # (B, C, H, W)
174 | else:
175 | raise ValueError("Must provide either x or z_0, not both.")
176 | z_0_rescaled = z_0_rescaled.unsqueeze(-1) # (B, C, H, W, 1)
177 | x_lim = 1 - 1 / self.vocab_size
178 | x_values = linspace(-x_lim, x_lim, self.vocab_size, device=self.device)
179 | logits = -0.5 * exp(-gamma_0) * (z_0_rescaled - x_values) ** 2 # broadcast x
180 | log_probs = torch.log_softmax(logits, dim=-1) # (B, C, H, W, vocab_size)
181 | return log_probs
182 |
183 |
184 | def kl_std_normal(mean_squared, var):
185 | return 0.5 * (var + mean_squared - torch.log(var.clamp(min=1e-15)) - 1.0)
186 |
187 |
188 | class FixedLinearSchedule(nn.Module):
189 | def __init__(self, gamma_min, gamma_max):
190 | super().__init__()
191 | self.gamma_min = gamma_min
192 | self.gamma_max = gamma_max
193 |
194 | def forward(self, t):
195 | return self.gamma_min + (self.gamma_max - self.gamma_min) * t
196 |
197 |
198 | class LearnedLinearSchedule(nn.Module):
199 | def __init__(self, gamma_min, gamma_max):
200 | super().__init__()
201 | self.b = nn.Parameter(torch.tensor(gamma_min))
202 | self.w = nn.Parameter(torch.tensor(gamma_max - gamma_min))
203 |
204 | def forward(self, t):
205 | return self.b + self.w.abs() * t
206 |
--------------------------------------------------------------------------------
/vdm_unet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import einsum, nn, pi, softmax
4 |
5 | from utils import zero_init
6 |
7 |
8 | class UNetVDM(nn.Module):
9 | def __init__(self, cfg):
10 | super().__init__()
11 | self.cfg = cfg
12 |
13 | attention_params = dict(
14 | n_heads=cfg.n_attention_heads,
15 | n_channels=cfg.embedding_dim,
16 | norm_groups=cfg.norm_groups,
17 | )
18 | resnet_params = dict(
19 | ch_in=cfg.embedding_dim,
20 | ch_out=cfg.embedding_dim,
21 | condition_dim=4 * cfg.embedding_dim,
22 | dropout_prob=cfg.dropout_prob,
23 | norm_groups=cfg.norm_groups,
24 | )
25 | if cfg.use_fourier_features:
26 | self.fourier_features = FourierFeatures()
27 | self.embed_conditioning = nn.Sequential(
28 | nn.Linear(cfg.embedding_dim, cfg.embedding_dim * 4),
29 | nn.SiLU(),
30 | nn.Linear(cfg.embedding_dim * 4, cfg.embedding_dim * 4),
31 | nn.SiLU(),
32 | )
33 | total_input_ch = cfg.input_channels
34 | if cfg.use_fourier_features:
35 | total_input_ch *= 1 + self.fourier_features.num_features
36 | self.conv_in = nn.Conv2d(total_input_ch, cfg.embedding_dim, 3, padding=1)
37 |
38 | # Down path: n_blocks blocks with a resnet block and maybe attention.
39 | self.down_blocks = nn.ModuleList(
40 | UpDownBlock(
41 | resnet_block=ResnetBlock(**resnet_params),
42 | attention_block=AttentionBlock(**attention_params)
43 | if cfg.attention_everywhere
44 | else None,
45 | )
46 | for _ in range(cfg.n_blocks)
47 | )
48 |
49 | self.mid_resnet_block_1 = ResnetBlock(**resnet_params)
50 | self.mid_attn_block = AttentionBlock(**attention_params)
51 | self.mid_resnet_block_2 = ResnetBlock(**resnet_params)
52 |
53 | # Up path: n_blocks+1 blocks with a resnet block and maybe attention.
54 | resnet_params["ch_in"] *= 2 # double input channels due to skip connections
55 | self.up_blocks = nn.ModuleList(
56 | UpDownBlock(
57 | resnet_block=ResnetBlock(**resnet_params),
58 | attention_block=AttentionBlock(**attention_params)
59 | if cfg.attention_everywhere
60 | else None,
61 | )
62 | for _ in range(cfg.n_blocks + 1)
63 | )
64 |
65 | self.conv_out = nn.Sequential(
66 | nn.GroupNorm(num_groups=cfg.norm_groups, num_channels=cfg.embedding_dim),
67 | nn.SiLU(),
68 | zero_init(nn.Conv2d(cfg.embedding_dim, cfg.input_channels, 3, padding=1)),
69 | )
70 |
71 | def forward(self, z, g_t):
72 | # Get gamma to shape (B, ).
73 | g_t = g_t.expand(z.shape[0]) # assume shape () or (1,) or (B,)
74 | assert g_t.shape == (z.shape[0],)
75 | # Rescale to [0, 1], but only approximately since gamma0 & gamma1 are not fixed.
76 | t = (g_t - self.cfg.gamma_min) / (self.cfg.gamma_max - self.cfg.gamma_min)
77 | t_embedding = get_timestep_embedding(t, self.cfg.embedding_dim)
78 | # We will condition on time embedding.
79 | cond = self.embed_conditioning(t_embedding)
80 |
81 | h = self.maybe_concat_fourier(z)
82 | h = self.conv_in(h) # (B, embedding_dim, H, W)
83 | hs = []
84 | for down_block in self.down_blocks: # n_blocks times
85 | hs.append(h)
86 | h = down_block(h, cond)
87 | hs.append(h)
88 | h = self.mid_resnet_block_1(h, cond)
89 | h = self.mid_attn_block(h)
90 | h = self.mid_resnet_block_2(h, cond)
91 | for up_block in self.up_blocks: # n_blocks+1 times
92 | h = torch.cat([h, hs.pop()], dim=1)
93 | h = up_block(h, cond)
94 | prediction = self.conv_out(h)
95 | assert prediction.shape == z.shape, (prediction.shape, z.shape)
96 | return prediction + z
97 |
98 | def maybe_concat_fourier(self, z):
99 | if self.cfg.use_fourier_features:
100 | return torch.cat([z, self.fourier_features(z)], dim=1)
101 | return z
102 |
103 |
104 | class ResnetBlock(nn.Module):
105 | def __init__(
106 | self,
107 | ch_in,
108 | ch_out=None,
109 | condition_dim=None,
110 | dropout_prob=0.0,
111 | norm_groups=32,
112 | ):
113 | super().__init__()
114 | ch_out = ch_in if ch_out is None else ch_out
115 | self.ch_out = ch_out
116 | self.condition_dim = condition_dim
117 | self.net1 = nn.Sequential(
118 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),
119 | nn.SiLU(),
120 | nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
121 | )
122 | if condition_dim is not None:
123 | self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))
124 | self.net2 = nn.Sequential(
125 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),
126 | nn.SiLU(),
127 | *([nn.Dropout(dropout_prob)] * (dropout_prob > 0.0)),
128 | zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),
129 | )
130 | if ch_in != ch_out:
131 | self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)
132 |
133 | def forward(self, x, condition):
134 | h = self.net1(x)
135 | if condition is not None:
136 | assert condition.shape == (x.shape[0], self.condition_dim)
137 | condition = self.cond_proj(condition)
138 | condition = condition[:, :, None, None]
139 | h = h + condition
140 | h = self.net2(h)
141 | if x.shape[1] != self.ch_out:
142 | x = self.skip_conv(x)
143 | assert x.shape == h.shape
144 | return x + h
145 |
146 |
147 | def get_timestep_embedding(
148 | timesteps,
149 | embedding_dim: int,
150 | dtype=torch.float32,
151 | max_timescale=10_000,
152 | min_timescale=1,
153 | ):
154 | # Adapted from tensor2tensor and VDM codebase.
155 | assert timesteps.ndim == 1
156 | assert embedding_dim % 2 == 0
157 | timesteps *= 1000.0 # In DDPM the time step is in [0, 1000], here [0, 1]
158 | num_timescales = embedding_dim // 2
159 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n))
160 | -np.log10(min_timescale),
161 | -np.log10(max_timescale),
162 | num_timescales,
163 | device=timesteps.device,
164 | )
165 | emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2)
166 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D)
167 |
168 |
169 | class FourierFeatures(nn.Module):
170 | def __init__(self, first=5.0, last=6.0, step=1.0):
171 | super().__init__()
172 | self.freqs_exponent = torch.arange(first, last + 1e-8, step)
173 |
174 | @property
175 | def num_features(self):
176 | return len(self.freqs_exponent) * 2
177 |
178 | def forward(self, x):
179 | assert len(x.shape) >= 2
180 |
181 | # Compute (2pi * 2^n) for n in freqs.
182 | freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device) # (F, )
183 | freqs = 2.0**freqs_exponent * 2 * pi # (F, )
184 | freqs = freqs.view(-1, *([1] * (x.dim() - 1))) # (F, 1, 1, ...)
185 |
186 | # Compute (2pi * 2^n * x) for n in freqs.
187 | features = freqs * x.unsqueeze(1) # (B, F, X1, X2, ...)
188 | features = features.flatten(1, 2) # (B, F * C, X1, X2, ...)
189 |
190 | # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).
191 | return torch.cat([features.sin(), features.cos()], dim=1)
192 |
193 |
194 | def attention_inner_heads(qkv, num_heads):
195 | """Computes attention with heads inside of qkv in the channel dimension.
196 |
197 | Args:
198 | qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
199 | H = number of heads,
200 | C = number of channels per head.
201 | num_heads: number of heads.
202 |
203 | Returns:
204 | Attention output of shape (B, H*C, T).
205 | """
206 |
207 | bs, width, length = qkv.shape
208 | ch = width // (3 * num_heads)
209 |
210 | # Split into (q, k, v) of shape (B, H*C, T).
211 | q, k, v = qkv.chunk(3, dim=1)
212 |
213 | # Rescale q and k. This makes them contiguous in memory.
214 | scale = ch ** (-1 / 4) # scaling q and k by the 4th root = scaling the dot product by 1/sqrt(ch)
215 | q = q * scale
216 | k = k * scale
217 |
218 | # Reshape qkv to (B*H, C, T).
219 | new_shape = (bs * num_heads, ch, length)
220 | q = q.view(*new_shape)
221 | k = k.view(*new_shape)
222 | v = v.reshape(*new_shape)
223 |
224 | # Compute attention.
225 | weight = einsum("bct,bcs->bts", q, k) # (B*H, T, T)
226 | weight = softmax(weight.float(), dim=-1).to(weight.dtype) # (B*H, T, T)
227 | out = einsum("bts,bcs->bct", weight, v) # (B*H, C, T)
228 | return out.reshape(bs, num_heads * ch, length) # (B, H*C, T)
229 |
230 |
231 | class Attention(nn.Module):
232 | """Based on https://github.com/openai/guided-diffusion."""
233 |
234 | def __init__(self, n_heads):
235 | super().__init__()
236 | self.n_heads = n_heads
237 |
238 | def forward(self, qkv):
239 | assert qkv.dim() >= 3, qkv.dim()
240 | assert qkv.shape[1] % (3 * self.n_heads) == 0
241 | spatial_dims = qkv.shape[2:]
242 | qkv = qkv.view(*qkv.shape[:2], -1) # (B, 3*H*C, T)
243 | out = attention_inner_heads(qkv, self.n_heads) # (B, H*C, T)
244 | return out.view(*out.shape[:2], *spatial_dims)
245 |
246 |
247 | class AttentionBlock(nn.Module):
248 | """Self-attention residual block."""
249 |
250 | def __init__(self, n_heads, n_channels, norm_groups):
251 | super().__init__()
252 | assert n_channels % n_heads == 0
253 | self.layers = nn.Sequential(
254 | nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),
255 | nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1), # (B, 3 * C, H, W)
256 | Attention(n_heads),
257 | zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),
258 | )
259 |
260 | def forward(self, x):
261 | return self.layers(x) + x
262 |
263 |
264 | class UpDownBlock(nn.Module):
265 | def __init__(self, resnet_block, attention_block=None):
266 | super().__init__()
267 | self.resnet_block = resnet_block
268 | self.attention_block = attention_block
269 |
270 | def forward(self, x, cond):
271 | x = self.resnet_block(x, cond)
272 | if self.attention_block is not None:
273 | x = self.attention_block(x)
274 | return x
275 |
--------------------------------------------------------------------------------