├── .gitignore
├── README.md
├── aconfs
│   └── 1_node_1_gpu_ddp.yaml
├── adaptive_controller.py
├── demos
│   └── fpdm_inference.ipynb
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   ├── respace.py
│   └── timestep_sampler.py
├── environment.yml
├── models.py
├── sample.py
├── train.py
├── utils
│   └── imagenet-labels.json
└── visuals
    └── splash-figure-v1.png
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # ONLY FOR ANONYMOUS CODE
3 | CODE_OF_CONDUCT.md
4 | CONTRIBUTING.md
5 | LICENSE.txt
6 | convert-to-diffusers/
7 | download.py
8 | env-luke.yaml
9 | explore/
10 | extract_features.py
11 | figures/
12 | performance/
13 | run_DiT.ipynb
14 | sample_ddp.py
15 | scripts/
16 | train_options/
17 |
18 |
19 | # Custom
20 | .vscode
21 | wandb
22 | outputs
23 | features
24 | tmp*
25 | slurm-logs
26 | slurm_logs
27 | results
28 | archive
29 | debug
30 | samples/*
31 | !samples/imagenet-labels.json
32 | samples-examples
33 |
34 | # Byte-compiled / optimized / DLL files
35 | __pycache__/
36 | *.py[cod]
37 | *$py.class
38 | .github
39 |
40 | # C extensions
41 | *.so
42 |
43 | # Distribution / packaging
44 | .Python
45 | build/
46 | develop-eggs/
47 | dist/
48 | downloads/
49 | eggs/
50 | .eggs/
51 | lib/
52 | lib64/
53 | parts/
54 | sdist/
55 | var/
56 | wheels/
57 | *.egg-info/
58 | .installed.cfg
59 | *.egg
60 | MANIFEST
61 |
62 | # Lightning /research
63 | test_tube_exp/
64 | tests/tests_tt_dir/
65 | tests/save_dir
66 | default/
67 | data/
68 | test_tube_logs/
69 | test_tube_data/
70 | datasets/
71 | model_weights/
74 | processed/
75 | raw/
76 |
77 | # PyInstaller
78 | # Usually these files are written by a python script from a template
79 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
80 | *.manifest
81 | *.spec
82 |
83 | # Installer logs
84 | pip-log.txt
85 | pip-delete-this-directory.txt
86 |
87 | # Unit test / coverage reports
88 | htmlcov/
89 | .tox/
90 | .coverage
91 | .coverage.*
92 | .cache
93 | nosetests.xml
94 | coverage.xml
95 | *.cover
96 | .hypothesis/
97 | .pytest_cache/
98 |
99 | # Translations
100 | *.mo
101 | *.pot
102 |
103 | # Django stuff:
104 | *.log
105 | local_settings.py
106 | db.sqlite3
107 |
108 | # Flask stuff:
109 | instance/
110 | .webassets-cache
111 |
112 | # Scrapy stuff:
113 | .scrapy
114 |
115 | # Sphinx documentation
116 | docs/_build/
117 |
118 | # PyBuilder
119 | target/
120 |
121 | # Jupyter Notebook
122 | .ipynb_checkpoints
123 |
124 | # pyenv
125 | .python-version
126 |
127 | # celery beat schedule file
128 | celerybeat-schedule
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 |
155 | # IDEs
156 | .idea
157 | .vscode
158 |
159 | # seed project
160 | lightning_logs/
161 | MNIST
162 | .DS_Store
163 | *.code-workspace
164 | vis.zip
165 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | [![Contributors][contributors-shield]][contributors-url]
7 | [![Forks][forks-shield]][forks-url]
8 | [![Stargazers][stars-shield]][stars-url]
9 | [![Issues][issues-shield]][issues-url]
10 |
11 | ### Fixed Point Diffusion Models
12 |
13 | [Project Page](https://lukemelas.github.io/fixed-point-diffusion-models/) · [Paper](https://arxiv.org/abs/2401.08741)
14 |
15 |
16 |
17 |
18 |
19 | 
20 |
21 | ### Table of Contents
22 | - [Abstract](#abstract)
23 | - [Setup & Installation](#setup)
24 | - [Model](#model)
25 | - [Training](#training)
26 | - [Sampling](#sampling)
27 | - [Contribution](#contribution)
28 | - [Acknowledgements](#acknowledgements)
29 |
30 | ### Roadmap
31 |
32 | - [x] Code and paper release 🎉🎉
33 | - [x] Jupyter notebook example
34 | - [ ] Pretrained model release _(coming soon)_
35 | - [ ] Code walkthrough and tutorial
36 |
37 | ### Abstract
38 |
39 | We introduce the Fixed Point Diffusion Model (FPDM), a novel approach to image generation that integrates the concept of fixed point solving into the framework of diffusion-based generative modeling. Our approach embeds an implicit fixed point solving layer into the denoising network of a diffusion model, transforming the diffusion process into a sequence of closely-related fixed point problems. Combined with a new stochastic training method, this approach significantly reduces model size, reduces memory usage, and accelerates training. Moreover, it enables the development of two new techniques to improve sampling efficiency: reallocating computation across timesteps and reusing fixed point solutions between timesteps. We conduct extensive experiments with state-of-the-art models on ImageNet, FFHQ, CelebA-HQ, and LSUN-Church, demonstrating substantial improvements in performance and efficiency. Compared to the state-of-the-art DiT model, FPDM contains 87% fewer parameters, consumes 60% less memory during training, and improves image generation quality in situations where sampling computation or time is limited.
40 |
41 | ### Setup
42 |
43 | We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
44 | to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
45 |
46 | ```bash
47 | conda env create -f environment.yml
48 | conda activate DiT
49 | ```
50 |
51 | ### Model
52 |
53 | Our model definition, including all fixed point functionality, is included in `models.py`.
54 |
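For quick reference, below is a minimal sketch of how the model is constructed with fixed point functionality enabled. It mirrors the constructor arguments used in [`demos/fpdm_inference.ipynb`](demos/fpdm_inference.ipynb); the iteration counts here are illustrative, not prescriptive.

```python
from models import DiT_models

# DiT-XL/2 backbone with an implicit fixed point block.
# input_size is the latent resolution: image_size // 8 for the SD VAE.
model = DiT_models["DiT-XL/2"](
    input_size=32,
    num_classes=1000,
    fixed_point=True,
    fixed_point_pre_depth=1,             # explicit blocks before the fixed point layer
    fixed_point_post_depth=1,            # explicit blocks after the fixed point layer
    fixed_point_with_grad_min_iters=26,  # iterations of the shared fixed point block
    fixed_point_with_grad_max_iters=26,  # (26 = 28 - pre_depth - post_depth)
)
```
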
55 | ### Training
56 |
57 | Example training scripts:
58 | ```bash
59 | # Standard model
60 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py
61 |
62 | # Fixed Point Diffusion Model
63 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --fixed_point True --deq_pre_depth 1 --deq_post_depth 1
64 |
65 | # With v-prediction and zero-SNR
66 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 1 --deq_post_depth 1
67 |
68 | # With v-prediction and zero-SNR, with 4 pre- and post-layers
69 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 4 --deq_post_depth 4
70 | ```
71 |
72 | ### Sampling
73 |
74 | Example sampling scripts:
75 | ```bash
76 | # Sample
77 | python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 20
78 |
79 | # Sample with fewer iterations per timestep and more timesteps
80 | python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --fixed_point_iters 12 --num_sampling_steps 40 --fixed_point_reuse_solution True
81 | ```
82 |
83 | ### Contribution
84 |
85 | Pull requests are welcome!
86 |
87 | ### Acknowledgements
88 |
89 | * The strong baseline from DiT:
90 | ```
91 | @article{Peebles2022DiT,
92 | title={Scalable Diffusion Models with Transformers},
93 | author={William Peebles and Saining Xie},
94 | year={2022},
95 | journal={arXiv preprint arXiv:2212.09748},
96 | }
97 | ```
98 |
99 | * The fast-DiT code from [chuanyangjin](https://github.com/chuanyangjin/fast-DiT).
100 |
101 | * All the great work from the [CMU Locus Lab](https://github.com/locuslab) on Deep Equilibrium Models, which started with:
102 | ```
103 | @inproceedings{bai2019deep,
104 | author = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
105 | title = {Deep Equilibrium Models},
106 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
107 | year = {2019},
108 | }
109 | ```
110 |
111 | * L.M.K. thanks the Rhodes Trust for their scholarship support.
112 |
113 |
114 |
115 |
116 | [contributors-shield]: https://img.shields.io/github/contributors/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge
117 | [contributors-url]: https://github.com/lukemelas/fixed-point-diffusion-models/graphs/contributors
118 | [forks-shield]: https://img.shields.io/github/forks/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge
119 | [forks-url]: https://github.com/lukemelas/fixed-point-diffusion-models/network/members
120 | [stars-shield]: https://img.shields.io/github/stars/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge
121 | [stars-url]: https://github.com/lukemelas/fixed-point-diffusion-models/stargazers
122 | [issues-shield]: https://img.shields.io/github/issues/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge
123 | [issues-url]: https://github.com/lukemelas/fixed-point-diffusion-models/issues
124 | [license-shield]: https://img.shields.io/github/license/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge
125 | [license-url]: https://github.com/lukemelas/fixed-point-diffusion-models/blob/master/LICENSE.txt
126 |
--------------------------------------------------------------------------------
/aconfs/1_node_1_gpu_ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | downcast_bf16: 'no'
4 | gpu_ids: all
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: bf16
8 | num_machines: 1
9 | num_processes: 1
10 | main_process_port: 17337
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
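# Note: the training commands in README.md override the process count at launch
# time, e.g. `accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py`.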
--------------------------------------------------------------------------------
/adaptive_controller.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Compute the L2 norm of the delta over the non-batch dimensions, averaged over the batch.
4 | def batched_distance(dt):
5 | dist = torch.norm(dt, p = 2, dim = (1, 2))
6 | dist = torch.mean(dist)
7 | return dist
8 |
9 | class LinearController():
10 | def __init__(self, budget, tot_steps, ratio = [0.5, 1.5], type = "increasing") -> None:
11 | if type == "increasing":
12 | self.ratio_list = torch.linspace(ratio[0], ratio[1], tot_steps)
13 | elif type == "decreasing":
14 | self.ratio_list = torch.linspace(ratio[1], ratio[0], tot_steps)
15 |         elif type == "fixed":
16 |             self.ratio_list = torch.ones(tot_steps)
17 |         else: raise ValueError(f"unknown controller type: {type}")
18 | assert len(self.ratio_list) == tot_steps
19 | self.budget_list_float = 1 + (budget - tot_steps) * self.ratio_list / torch.sum(self.ratio_list)
20 |
21 | rounding_threshold = len(self.budget_list_float) // 2
22 | while True:
23 | self.budget_list = torch.zeros_like(self.budget_list_float)
24 | for i in range(len(self.budget_list_float)):
25 | # in the first half, round up
26 | # in the second half, round down
27 | # this ensures the sum of the budget is roughly equal to the total budget
28 | if i < rounding_threshold:
29 | self.budget_list[i] = torch.ceil(self.budget_list_float[i])
30 | else:
31 | self.budget_list[i] = torch.floor(self.budget_list_float[i])
32 | if torch.sum(self.budget_list) <= budget:
33 | break
34 | rounding_threshold -= 1
35 | print(f"fixed the rounding issue! in ", type)
36 |
37 | assert torch.sum(self.budget_list) <= budget
38 |
39 | self.pointer = len(self.budget_list) - 1
40 | self.threshold = None
41 | self.cost = None
42 | self.lowerbound = None
43 | self.upperbound = None
44 |
45 | def init_image(self):
46 | self.pointer = len(self.budget_list) - 1
47 | print(f"in Fix controller, budget list: {self.budget_list}")
48 | def end_image(self):
49 | pass
50 | def update(self):
51 | return True
52 | def get(self):
53 | ret = self.budget_list[self.pointer]
54 | self.pointer -= 1
55 | return int(ret)
56 |
57 |
58 | # A controller that always returns a fixed threshold (used to analyze threshold behavior)
59 | class FixedController():
60 | def __init__(self, threshold) -> None:
61 | self.threshold = threshold
62 | self.cost = 0
63 |
64 | def get(self):
65 | return self.threshold
66 |
67 | def add_cost(self, cost):
68 | self.cost += cost
69 |
70 |
71 |
72 | # class ThresholdController():
73 | # def __init__(self, budget, ratio_delta = 0) -> None:
74 | # self.budget = budget
75 | # self.threshold = 100.0
76 | # self.last_threshold = -1
77 | # self.lowerbound = 0.9 * budget
78 | # self.upperbound = 1.0 * budget
79 | # self.success_list = []
80 | # self.all_threshold_list = []
81 | # self.try_count = []
82 | # self.costs = []
83 |
84 | # self.delta_init_ratio = 0.03
85 | # self.delta = self.threshold * self.delta_init_ratio
86 | # self.count = 0
87 | # self.thresholds = []
88 | # self.cost = 0
89 | # self.pivot = False
90 |
91 | # self.max_count = 20 # the max number of threshold update on each batch.
92 |
93 | # # retrieve the number
94 | # self.upperbound_ratio = 1.0 + ratio_delta
95 | # self.lowerbound_ratio = 1.0 - ratio_delta
96 |
97 | # def init_image(self):
98 | # if len(self.all_threshold_list) != 0:
99 | # self.threshold = self.success_list[len(self.success_list) // 2]
100 |
101 | # self.last_threshold = -1
102 | # self.delta = self.threshold * self.delta_init_ratio
103 | # self.count = 0
104 | # self.cost = 0
105 | # self.pivot = False
106 | # self.thresholds = []
107 | # return self.threshold
108 |
109 | # def end_image(self):
110 | # self.success_list.append(self.threshold)
111 | # self.try_count.append(self.count)
112 | # self.all_threshold_list.append(self.thresholds)
113 | # self.success_list.sort()
114 |
115 | # def get(self):
116 | # return self.threshold
117 |
118 | # def add_cost(self, cost):
119 | # self.cost += cost
120 |
121 | # def update(self):
122 | # # after finishing an image
123 | # self.thresholds.append(self.threshold)
124 | # # print(f"in update: thres={self.threshold}, last thres={self.last_threshold}")
125 | # if self.cost > self.upperbound:
126 | # if self.last_threshold != -1 and self.threshold < self.last_threshold: # decrease, increase
127 | # self.pivot = True
128 | # if self.cost < self.lowerbound:
129 | # if self.last_threshold != -1 and self.threshold > self.last_threshold: # increase, decrease
130 | # self.pivot = True
131 |
132 | # if self.pivot:
133 | # self.delta /= 2
134 | # else:
135 | # self.delta *= 2
136 |
137 |
138 | # self.last_threshold = self.threshold
139 | # if self.cost > self.upperbound:
140 | # self.threshold = self.threshold + self.delta
141 | # if self.cost < self.lowerbound:
142 | # self.threshold = self.threshold - self.delta
143 |
144 | # passed = (self.lowerbound <= self.cost) and (self.cost <= self.upperbound)
145 | # # TODO: special case
146 | # if self.count > self.max_count and self.cost <= self.lowerbound:
147 | # passed = True
148 | # print(f"passed={passed}")
149 |
150 | # if passed:
151 | # self.costs.append(self.cost)
152 |
153 | # self.cost = 0
154 | # self.count += 1
155 |
156 | # return passed
157 |
158 | # def output(self):
159 | # """
160 | # success_list: list of threshold, (B)
161 | # all_threshold_list: list of list of threshold, (B, *)
162 | # try_count: list of try count, (B)
163 | # """
164 | # return self.success_list, self.all_threshold_list, self.try_count, self.costs
165 |
166 |
167 |
168 |
169 |
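# Usage sketch (hypothetical numbers): spread a budget of 100 fixed point
# iterations across 20 sampling timesteps. The budget list is consumed from the
# back, so with type="increasing" the first call to get() receives the largest
# per-step allocation.
if __name__ == "__main__":
    controller = LinearController(budget=100, tot_steps=20, type="increasing")
    controller.init_image()
    per_step = [controller.get() for _ in range(20)]
    print(per_step, sum(per_step))  # the rounding scheme keeps the sum <= 100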
--------------------------------------------------------------------------------
/demos/fpdm_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Fixed Point Diffusion Models (FPDM)\n",
8 | "This notebook shows how to run the image sampling with FPDM."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "## Set Up\n",
16 | "We provide an environment.yml file that can be used to create a Conda environment. See how to install all required packages in `README.md`."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# Standard library imports\n",
26 | "import json\n",
27 | "import math\n",
28 | "import random\n",
29 | "import sys\n",
30 | "from contextlib import nullcontext\n",
31 | "from pathlib import Path\n",
32 | "from typing import Optional\n",
33 | "\n",
34 | "# Third-party imports\n",
35 | "import torch\n",
36 | "import torch.nn as nn\n",
37 | "from PIL import Image\n",
38 | "from torch import Tensor\n",
39 | "from torch.utils.checkpoint import checkpoint\n",
40 | "from accelerate import Accelerator\n",
41 | "from accelerate.utils import set_seed\n",
42 | "from diffusers.models import AutoencoderKL\n",
43 | "from jaxtyping import Float, Shaped\n",
44 | "from tap import Tap\n",
45 | "from tqdm import trange\n",
46 | "from timm.models.vision_transformer import PatchEmbed, Attention, Mlp\n",
47 | "\n",
48 | "# Local module imports\n",
49 | "sys.path.append(\"..\")\n",
50 | "from diffusion import create_diffusion\n",
51 | "from download import find_model\n",
52 | "from models import DiT_models\n"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 | "## Hyperparameters"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "class Args(Tap):\n",
69 | " \"\"\"\n",
70 | " A class to define and store hyperparameters and configurations for the demo.\n",
71 | " \"\"\"\n",
72 | "\n",
73 | " # File and directory paths.\n",
74 | " output_dir: str = 'demo_samples'\n",
75 | "\n",
76 | " # Dataset configuration.\n",
77 | " dataset_name: str = \"imagenet256\"\n",
78 | "\n",
79 | " # Model specific parameters.\n",
80 | " model: str = \"DiT-XL/2\"\n",
81 | " vae: str = \"mse\"\n",
82 | " num_classes: int = 1000\n",
83 | " image_size: int = 256\n",
84 | " predict_v: bool = False\n",
85 | " use_zero_terminal_snr: bool = False\n",
86 | " unsupervised: bool = False\n",
87 | " dino_supervised: bool = False\n",
88 | " dino_supervised_dim: int = 768\n",
89 | " flow: bool = False\n",
90 | " debug: bool = False\n",
91 | "\n",
92 | " # Fixed Point settings.\n",
93 | " fixed_point: bool = False\n",
94 | " fixed_point_pre_depth: int = 2\n",
95 | " fixed_point_post_depth: int = 2\n",
96 | " fixed_point_iters: Optional[int] = None\n",
97 | " fixed_point_pre_post_timestep_conditioning: bool = False\n",
98 | " fixed_point_reuse_solution: bool = False\n",
99 | "\n",
100 | " # Sampling configuration.\n",
101 | " ddim: bool = False\n",
102 | " cfg_scale: float = 4.0\n",
103 | " num_sampling_steps: int = 250\n",
104 | " batch_size: int = 4\n",
105 | " ckpt: str = '/work/xingjian/diff-deq-inference/pretrained/DiT-XL-2/checkpoints/0500000.pt' # replace it with the Path to your checkpoint.\n",
106 | " global_seed: int = 0\n",
107 | "\n",
108 | " # Parallelization settings.\n",
109 | " sample_index_start: int = 0\n",
110 | " sample_index_end: Optional[int] = 32\n",
111 | "\n",
112 | " def process_args(self):\n",
113 | " \"\"\"\n",
114 | " Method for additional argument processing and validation.\n",
115 | " \"\"\"\n",
116 | " # Debug mode configuration.\n",
117 | " if self.debug:\n",
118 | " self.log_with = 'tensorboard'\n",
119 | " self.name = 'debug'\n",
120 | "\n",
121 | " # Set default values and validate image size.\n",
122 | " self.fixed_point_iters = self.fixed_point_iters or (28 - self.fixed_point_pre_depth - self.fixed_point_post_depth)\n",
123 | " assert self.image_size % 8 == 0, \"Image size must be divisible by 8 (for the VAE encoder).\"\n",
124 | " self.latent_size = self.H_lat = self.W_lat = self.image_size // 8\n",
125 | "\n",
126 | " # Additional checks and validations.\n",
127 | " if self.cfg_scale < 1.0:\n",
128 | " raise ValueError(\"In almost all cases, cfg_scale should be >= 1.0\")\n",
129 | " \n",
130 | " if self.unsupervised:\n",
131 | " assert self.cfg_scale == 1.0\n",
132 | " self.num_classes = 1\n",
133 | " elif self.dino_supervised:\n",
134 | " raise NotImplementedError()\n",
135 | " \n",
136 | " if not Path(self.ckpt).is_file():\n",
137 | " raise ValueError(self.ckpt)\n",
138 | " \n",
139 | " # Creating the output directory.\n",
140 | " output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name\n",
141 | " if self.debug:\n",
142 | " output_dirname = 'debug'\n",
143 | " else:\n",
144 | " output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'\n",
145 | " if self.fixed_point:\n",
146 | " output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'\n",
147 | " if self.ddim:\n",
148 | " output_dirname += f'--ddim'\n",
149 | " self.output_dir = str(output_parent / output_dirname)\n",
150 | " Path(self.output_dir).mkdir(exist_ok=True, parents=True)\n"
151 | ]
152 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "args = Args()\n",
242 | "args.process_args()"
243 | ]
244 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "## Network architecture\n",
255 | "We modify the original DiT class to support fixed point blocks."
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "\n",
265 | "from models import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed\n",
266 | "class DiT(nn.Module):\n",
267 | " \"\"\"\n",
268 | " Diffusion model with a Transformer backbone. It includes methods for the forward pass\n",
269 | " and initialization of weights. The model can operate in both standard and fixed-point modes.\n",
270 | " \"\"\"\n",
271 | " def __init__(\n",
272 | " self,\n",
273 | " input_size=32,\n",
274 | " patch_size=2,\n",
275 | " in_channels=4,\n",
276 | " hidden_size=1152,\n",
277 | " depth=28,\n",
278 | " num_heads=16,\n",
279 | " mlp_ratio=4.0,\n",
280 | " class_dropout_prob=0.1,\n",
281 | " num_classes=1000,\n",
282 | " learn_sigma=True,\n",
283 | " use_cfg_embedding: bool = True,\n",
284 | " use_gradient_checkpointing: bool = True,\n",
285 | " is_label_continuous: bool = False,\n",
286 | "\n",
287 | " # below are Fixed Point-specific arguments.\n",
288 | " fixed_point: bool = False,\n",
289 | "\n",
290 | " # size\n",
291 | " fixed_point_pre_depth: int = 1, \n",
292 | " fixed_point_post_depth: int = 1, \n",
293 | "\n",
294 | " # iteration counts\n",
295 | " fixed_point_no_grad_min_iters: int = 0, \n",
296 | " fixed_point_no_grad_max_iters: int = 0,\n",
297 | " fixed_point_with_grad_min_iters: int = 28, \n",
298 | " fixed_point_with_grad_max_iters: int = 28,\n",
299 | "\n",
300 | " # solution recycle\n",
301 | " fixed_point_reuse_solution = False,\n",
302 | " \n",
303 | " # pre_post_timestep_conditioning\n",
304 | " fixed_point_pre_post_timestep_conditioning: bool = True,\n",
305 | " ):\n",
306 | " super().__init__()\n",
307 | " self.learn_sigma = learn_sigma\n",
308 | " self.in_channels = in_channels\n",
309 | " self.out_channels = in_channels * 2 if learn_sigma else in_channels\n",
310 | " self.patch_size = patch_size\n",
311 | " self.num_heads = num_heads\n",
312 | " self.use_gradient_checkpointing = use_gradient_checkpointing\n",
313 | "\n",
314 | " self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)\n",
315 | " self.t_embedder = TimestepEmbedder(hidden_size)\n",
316 | " self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, \n",
317 | " use_cfg_embedding=use_cfg_embedding, continuous=is_label_continuous)\n",
318 | " num_patches = self.x_embedder.num_patches\n",
319 | " \n",
320 | " # Will use fixed sin-cos embedding:\n",
321 | " # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)\n",
322 | " self.register_buffer('pos_embed', torch.zeros(1, num_patches, hidden_size))\n",
323 | "\n",
324 | " # New: Fixed Point\n",
325 | " self.fixed_point = fixed_point\n",
326 | " if self.fixed_point:\n",
327 | " self.fixed_point_no_grad_min_iters = fixed_point_no_grad_min_iters\n",
328 | " self.fixed_point_no_grad_max_iters = fixed_point_no_grad_max_iters\n",
329 | " self.fixed_point_with_grad_min_iters = fixed_point_with_grad_min_iters\n",
330 | " self.fixed_point_with_grad_max_iters = fixed_point_with_grad_max_iters\n",
331 | " self.fixed_point_pre_post_timestep_conditioning = fixed_point_pre_post_timestep_conditioning\n",
332 | " self.blocks_pre = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_pre_depth)])\n",
333 | " self.block_pre_projection = nn.Linear(hidden_size, hidden_size)\n",
334 | " self.block_fixed_point_projection_fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size)\n",
335 | " self.block_fixed_point_projection_act = nn.GELU(approximate=\"tanh\")\n",
336 | " self.block_fixed_point_projection_fc2 = nn.Linear(2 * hidden_size, hidden_size)\n",
337 | " self.block_fixed_point = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)\n",
338 | " self.blocks_post = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_post_depth)])\n",
339 | " self.blocks = [*self.blocks_pre, self.block_fixed_point, *self.blocks_post]\n",
340 | " self.fixed_point_reuse_solution = fixed_point_reuse_solution\n",
341 | " self.last_solution = None\n",
342 | " else:\n",
343 | " self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])\n",
344 | " self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)\n",
345 | " self.initialize_weights()\n",
346 | "\n",
347 | " def initialize_weights(self):\n",
348 | " # Initialize transformer layers:\n",
349 | " def _basic_init(module):\n",
350 | " if isinstance(module, nn.Linear):\n",
351 | " torch.nn.init.xavier_uniform_(module.weight)\n",
352 | " if module.bias is not None:\n",
353 | " nn.init.constant_(module.bias, 0)\n",
354 | " self.apply(_basic_init)\n",
355 | "\n",
356 | " # Initialize (and freeze) pos_embed by sin-cos embedding:\n",
357 | " pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))\n",
358 | " self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))\n",
359 | "\n",
360 | " # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):\n",
361 | " w = self.x_embedder.proj.weight.data\n",
362 | " nn.init.xavier_uniform_(w.view([w.shape[0], -1]))\n",
363 | " nn.init.constant_(self.x_embedder.proj.bias, 0)\n",
364 | "\n",
365 | " # Initialize label embedding table:\n",
366 | " if self.y_embedder.continuous:\n",
367 | " nn.init.normal_(self.y_embedder.embedding_projection.weight, std=0.02)\n",
368 | " nn.init.constant_(self.y_embedder.embedding_projection.bias, 0)\n",
369 | " else:\n",
370 | " nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)\n",
371 | "\n",
372 | " # Initialize timestep embedding MLP:\n",
373 | " nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)\n",
374 | " nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)\n",
375 | "\n",
376 | " # Zero-out adaLN modulation layers in DiT blocks:\n",
377 | " for block in self.blocks:\n",
378 | " nn.init.constant_(block.adaLN_modulation[-1].weight, 0)\n",
379 | " nn.init.constant_(block.adaLN_modulation[-1].bias, 0)\n",
380 | "\n",
381 | " # Zero-out output layers:\n",
382 | " nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)\n",
383 | " nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)\n",
384 | " nn.init.constant_(self.final_layer.linear.weight, 0)\n",
385 | " nn.init.constant_(self.final_layer.linear.bias, 0)\n",
386 | "\n",
387 | " def unpatchify(self, x):\n",
388 | " \"\"\"\n",
389 | " Reshapes the patches back to image format.\n",
390 | " x: (N, T, patch_size**2 * C)\n",
391 | " imgs: (N, H, W, C)\n",
392 | " \"\"\"\n",
393 | " c = self.out_channels\n",
394 | " p = self.x_embedder.patch_size[0]\n",
395 | " h = w = int(x.shape[1] ** 0.5)\n",
396 | " assert h * w == x.shape[1]\n",
397 | "\n",
398 | " x = x.reshape(shape=(x.shape[0], h, w, p, p, c))\n",
399 | " x = torch.einsum('nhwpqc->nchpwq', x)\n",
400 | " imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))\n",
401 | " return imgs\n",
402 | " \n",
403 | " def ckpt_wrapper(self, module):\n",
404 | " \"\"\"\n",
405 | " Wrapper function for gradient checkpointing.\n",
406 | " \"\"\"\n",
407 | " def ckpt_forward(*inputs):\n",
408 | " outputs = module(*inputs)\n",
409 | " return outputs\n",
410 | " return ckpt_forward\n",
411 | "\n",
412 | " def _forward_dit(self, x, t, y):\n",
413 | " x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2\n",
414 | " t = self.t_embedder(t) # (N, D)\n",
415 | " y = self.y_embedder(y, self.training) # (N, D))\n",
416 | " c = t + y # (N, D)\n",
417 | " for block in self.blocks:\n",
418 | " x = checkpoint(self.ckpt_wrapper(block), x, c) if self.use_gradient_checkpointing else block(x, c) # (N, T, D)\n",
419 | " x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)\n",
420 | " x = self.unpatchify(x) # (N, out_channels, H, W)\n",
421 | " return x\n",
422 | " \n",
423 | " def _forward_fixed_point_blocks(\n",
424 | " self, x: Float[Tensor, \"b t d\"], x_input_injection: Float[Tensor, \"b t d\"], c: Float[Tensor, \"b d\"], num_iterations: int\n",
425 | " ) -> Float[Tensor, \"b t d\"]:\n",
426 | " for _ in range(num_iterations):\n",
427 | " x = torch.cat((x, x_input_injection), dim=-1) # (N, T, D * 2)\n",
428 | " x = self.block_fixed_point_projection_fc1(x) # (N, T, D * 2)\n",
429 | " x = self.block_fixed_point_projection_act(x) # (N, T, D * 2)\n",
430 | " x = self.block_fixed_point_projection_fc2(x) # (N, T, D)\n",
431 | " x = self.block_fixed_point(x, c) # (N, T, D)\n",
432 | " return x\n",
433 | " \n",
434 | " def _check_inputs(self, x: Float[Tensor, \"b c h w\"], t: Shaped[Tensor, \"b\"], y: Shaped[Tensor, \"b\"]) -> None:\n",
435 | " if self.fixed_point_reuse_solution:\n",
436 | " if not torch.all(t[0] == t).item():\n",
437 | " raise ValueError(t)\n",
438 | "\n",
439 | " def _forward_fixed_point(self, x: Float[Tensor, \"b c h w\"], t: Shaped[Tensor, \"b\"], y: Shaped[Tensor, \"b\"]) -> Float[Tensor, \"b c h w\"]:\n",
440 | " self._check_inputs(x, t, y)\n",
441 | " x: Float[Tensor, \"b t d\"] = self.x_embedder(x) + self.pos_embed\n",
442 | " t_emb: Float[Tensor, \"b d\"] = self.t_embedder(t)\n",
443 | " y: Float[Tensor, \"b d\"] = self.y_embedder(y, self.training)\n",
444 | " c: Float[Tensor, \"b d\"] = t_emb + y\n",
445 | " c_pre_post_fixed_point: Float[Tensor, \"b d\"] = (t_emb + y) if self.fixed_point_pre_post_timestep_conditioning else y\n",
446 | " \n",
447 | " # Pre-Fixed Point\n",
448 | " # Note: If using DDP with find_unused_parameters=True, checkpoint causes issues. For more \n",
449 | " # information, see https://github.com/allenai/longformer/issues/63#issuecomment-648861503\n",
450 | " for block in self.blocks_pre:\n",
451 | " x: Float[Tensor, \"b t d\"] = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)\n",
452 | " condition = x.clone()\n",
453 | "\n",
454 | " # Whether to reuse the previous solution at the next iteration\n",
455 | " init_solution = self.last_solution if (self.fixed_point_reuse_solution and self.last_solution is not None) else x.clone()\n",
456 | "\n",
457 | " # Fixed Point (we have condition and init_solution)\n",
458 | " x_input_injection = self.block_pre_projection(condition)\n",
459 | "\n",
460 | " # NOTE: This section of code should have no_grad, but cannot due to a DDP bug. See\n",
461 | " # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594\n",
462 | " # for more information\n",
463 | " with nullcontext(): # we use x.detach() in place of torch.no_grad due to DDP issue\n",
464 | " num_iterations_no_grad = random.randint(self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters)\n",
465 | " x = self._forward_fixed_point_blocks(x=init_solution.detach(), x_input_injection=x_input_injection.detach(), c=c, num_iterations=num_iterations_no_grad)\n",
466 | " x = x.detach() # no grad\n",
467 | " num_iterations_with_grad = random.randint(self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters)\n",
468 | " x = self._forward_fixed_point_blocks(x=x, x_input_injection=x_input_injection, c=c, num_iterations=num_iterations_with_grad)\n",
469 | "\n",
470 | " # Save solution for reuse at next step\n",
471 | " if self.fixed_point_reuse_solution:\n",
472 | " self.last_solution = x.clone()\n",
473 | " \n",
474 | " # Post-Fixed Point\n",
475 | " for block in self.blocks_post:\n",
476 | " x = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)\n",
477 | " \n",
478 | " # Output\n",
479 | " x: Float[Tensor, \"b t p2c\"] = self.final_layer(x, c_pre_post_fixed_point) # p2c = patch_size ** 2 * out_channels)\n",
480 | " x: Float[Tensor, \"b c h w\"] = self.unpatchify(x)\n",
481 | " return x\n",
482 | " \n",
483 | " def reset(self):\n",
484 | " self.last_solution = None\n",
485 | " \n",
486 | " def forward(self, x, t, y):\n",
487 | " \"\"\"\n",
488 | " General forward pass method which handles both standard and fixed point modes.\n",
489 | " x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)\n",
490 | " t: (N,) tensor of diffusion timesteps\n",
491 | " y: (N,) tensor of class labels\n",
492 | " \"\"\"\n",
493 | " if self.fixed_point:\n",
494 | " return self._forward_fixed_point(x, t, y)\n",
495 | " else:\n",
496 | " return self._forward_dit(x, t, y)\n",
497 | "\n",
498 | " def forward_with_cfg(self, x, t, y, cfg_scale):\n",
499 | " \"\"\"\n",
500 | " Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.\n",
501 | " \"\"\"\n",
502 | " # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb\n",
503 | " half = x[: len(x) // 2]\n",
504 | " combined = torch.cat([half, half], dim=0)\n",
505 | " model_out = self.forward(combined, t, y)\n",
506 | " # For exact reproducibility reasons, we apply classifier-free guidance on only\n",
507 | " # three channels by default. The standard approach to cfg applies it to all channels.\n",
508 | " # This can be done by uncommenting the following line and commenting-out the line following that.\n",
509 | " # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]\n",
510 | " eps, rest = model_out[:, :3], model_out[:, 3:]\n",
511 | " cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)\n",
512 | " half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)\n",
513 | " eps = torch.cat([half_eps, half_eps], dim=0)\n",
514 | " return torch.cat([eps, rest], dim=1)"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": null,
520 | "metadata": {},
521 | "outputs": [],
522 | "source": [
523 | "# Initialize the Accelerator for GPU/CPU acceleration and create the DiT model.\n",
524 | "accelerator = Accelerator()\n",
525 | "model = DiT_models[args.model](\n",
526 | " input_size=args.latent_size,\n",
527 | " num_classes=(args.dino_supervised_dim if args.dino_supervised else args.num_classes),\n",
528 | " is_label_continuous=args.dino_supervised,\n",
529 | " class_dropout_prob=0,\n",
530 | " learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models\n",
531 | " use_gradient_checkpointing=False,\n",
532 | " fixed_point=args.fixed_point,\n",
533 | " fixed_point_pre_depth=args.fixed_point_pre_depth,\n",
534 | " fixed_point_post_depth=args.fixed_point_post_depth,\n",
535 | " fixed_point_no_grad_min_iters=0, \n",
536 | " fixed_point_no_grad_max_iters=0,\n",
537 | " fixed_point_with_grad_min_iters=args.fixed_point_iters, \n",
538 | " fixed_point_with_grad_max_iters=args.fixed_point_iters,\n",
539 | " fixed_point_reuse_solution=args.fixed_point_reuse_solution,\n",
540 | " fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning,\n",
541 | " ).to(accelerator.device)"
542 | ]
543 | },
544 | {
545 | "cell_type": "markdown",
546 | "metadata": {},
547 | "source": [
548 | "## Load Model"
549 | ]
550 | },
551 | {
552 | "cell_type": "code",
553 | "execution_count": null,
554 | "metadata": {},
555 | "outputs": [],
556 | "source": [
557 | "# Load the pre-trained model checkpoint.\n",
558 | "state_dict = find_model(args.ckpt)\n",
559 | "model.load_state_dict(state_dict)\n",
560 | "model.eval() \n",
561 | "\n",
562 | "# Initialize the diffusion process with specified parameters.\n",
563 | "diffusion = create_diffusion(\n",
564 | " str(args.num_sampling_steps), \n",
565 | " use_flow=args.flow,\n",
566 | " predict_v=args.predict_v,\n",
567 | " use_zero_terminal_snr=args.use_zero_terminal_snr,\n",
568 | ")\n",
569 | "\n",
570 | "# Load the VAE model and evaluate it.\n",
571 | "vae = AutoencoderKL.from_pretrained(f\"stabilityai/sd-vae-ft-{args.vae}\").to(accelerator.device).eval()\n"
572 | ]
573 | },
574 | {
575 | "cell_type": "markdown",
576 | "metadata": {},
577 | "source": [
578 | "## Inference"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "metadata": {},
585 | "outputs": [],
586 | "source": [
587 | "# create generator, class labels, and latents\n",
588 | "N_images = 32\n",
589 | "generator = torch.Generator(device=accelerator.device)\n",
590 | "generator.manual_seed(args.global_seed)\n",
591 | "class_labels = torch.randint(0, args.num_classes, size=(N_images,), device=accelerator.device)\n",
592 | "generator.manual_seed(args.global_seed)\n",
593 | "latents = torch.randn(N_images, model.in_channels, args.H_lat, args.W_lat, device=accelerator.device, generator=generator)\n",
594 | "class_labels = class_labels[args.sample_index_start:args.sample_index_end]\n",
595 | "latents = latents[args.sample_index_start:args.sample_index_end]\n",
596 | "indices = list(range(args.sample_index_start, args.sample_index_end))\n",
597 | "print(f'Using pseudorandom class labels and latents (start={args.sample_index_start} and end={args.sample_index_end})')\n",
598 | "\n",
599 | " # Create output path\n",
600 | "output_dir = Path(args.output_dir)\n",
601 | "# if cfg is used\n",
602 | "using_cfg = args.cfg_scale > 1.0"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {},
609 | "outputs": [],
610 | "source": [
611 | "# Load class labels for helpful filenames\n",
612 | "if args.dataset_name == 'imagenet256':\n",
613 | " with open(\"../utils/imagenet-labels.json\", \"r\") as f:\n",
614 | " label_names: list[str] = json.load(f)\n",
615 | " label_names = [l.lower().replace(' ', '-').replace('\\'', '') for l in label_names]\n",
616 | "elif args.unsupervised:\n",
617 | " assert args.cfg_scale == 1.0\n",
618 | " label_names = [\"unlabeled\"]\n",
619 | "else:\n",
620 | " raise NotImplementedError()"
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": null,
626 | "metadata": {},
627 | "outputs": [],
628 | "source": [
629 | "with torch.inference_mode():\n",
630 | " # Sample loop\n",
631 | " num_batches = math.ceil(len(class_labels) / args.batch_size)\n",
632 | " for batch_idx in trange(num_batches, disable=(not accelerator.is_main_process)):\n",
633 | "\n",
634 | " # Get pre-sampled inputs\n",
635 | " z = latents[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n",
636 | " y = class_labels[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n",
637 | " idxs = indices[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n",
638 | " output_paths = [output_dir / f'{idx:05d}--{y_i:03d}--{label_names[y_i]}.png' for y_i, idx in zip(y.tolist(), idxs)]\n",
639 | "\n",
640 | " # Skip files that already exist\n",
641 | " if all(output_path.is_file() for output_path in output_paths):\n",
642 | " print(f'Files already exist (batch {batch_idx}). Skipping.')\n",
643 | " continue\n",
644 | "\n",
645 | " # Setup classifier-free guidance\n",
646 | " if using_cfg:\n",
647 | " y_null = torch.tensor([1000] * args.batch_size, device=accelerator.device)\n",
648 | " y = torch.cat([y, y_null], 0)\n",
649 | " z = torch.cat([z, z], 0)\n",
650 | " model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)\n",
651 | " sample_fn = model.forward_with_cfg\n",
652 | " else:\n",
653 | " model_kwargs = dict(y=y)\n",
654 | " sample_fn = model.forward\n",
655 | "\n",
656 | " # Sample latent images\n",
657 | " sample_kwargs = dict(model=sample_fn, shape=z.shape, noise=z, clip_denoised=False, model_kwargs=model_kwargs, \n",
658 | " progress=False, device=accelerator.device)\n",
659 | " if args.ddim:\n",
660 | " samples = diffusion.ddim_sample_loop(**sample_kwargs)\n",
661 | " else:\n",
662 | " samples = diffusion.p_sample_loop(**sample_kwargs)\n",
663 | "\n",
664 | " if using_cfg:\n",
665 | " samples, _ = samples.chunk(2, dim=0)\n",
666 | " \n",
667 | " # Reset model (resets the initial solution to None)\n",
668 | " model.reset()\n",
669 | "\n",
670 | " # Decode latents\n",
671 | " samples = vae.decode(samples / vae.config.scaling_factor).sample\n",
672 | " samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(\"cpu\", dtype=torch.uint8).numpy()\n",
673 | "\n",
674 | " # Save samples to disk as individual .png files\n",
675 | " for sample, output_path in zip(samples, output_paths):\n",
676 | " Image.fromarray(sample).save(output_path)"
677 | ]
678 | },
679 | {
680 | "cell_type": "markdown",
681 | "metadata": {},
682 | "source": [
683 | "## Visualization"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": null,
689 | "metadata": {},
690 | "outputs": [],
691 | "source": [
692 | "import matplotlib.pyplot as plt\n",
693 | "from PIL import Image\n",
694 | "from pathlib import Path\n",
695 | "import json\n",
696 | "\n",
697 | "# get all files in output_dir\n",
698 | "output_dir = Path(args.output_dir)\n",
699 | "files = list(output_dir.glob('*.png'))\n",
700 | "\n",
701 | "# visualize four random samples\n",
702 | "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n",
703 | "for ax in axs.flatten():\n",
704 | " img = Image.open(files[random.randint(0, len(files) - 1)])\n",
705 | " ax.imshow(img)\n",
706 | " ax.axis('off')\n",
707 | "plt.tight_layout()\n",
708 | "plt.show()"
709 | ]
710 | }
711 | ],
712 | "metadata": {
713 | "kernelspec": {
714 | "display_name": "renaissance3",
715 | "language": "python",
716 | "name": "python3"
717 | },
718 | "language_info": {
719 | "codemirror_mode": {
720 | "name": "ipython",
721 | "version": 3
722 | },
723 | "file_extension": ".py",
724 | "mimetype": "text/x-python",
725 | "name": "python",
726 | "nbconvert_exporter": "python",
727 | "pygments_lexer": "ipython3",
728 | "version": "3.10.13"
729 | }
730 | },
731 | "nbformat": 4,
732 | "nbformat_minor": 2
733 | }
734 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_flow=False,
14 | use_kl=False,
15 | sigma_small=False,
16 | predict_xstart=False,
17 | predict_v=False,
18 | learn_sigma=True,
19 | rescale_learned_sigmas=False,
20 | use_zero_terminal_snr=False,
21 | diffusion_steps=1000
22 | ):
23 |     betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
24 | if use_zero_terminal_snr:
25 | assert predict_v, "please use v-prediction if using zero terminal snr"
26 | betas = gd.enforce_zero_terminal_snr(betas).numpy()
27 | print('Rescaled betas to enforce zero terminal snr')
28 | if use_kl:
29 | loss_type = gd.LossType.RESCALED_KL
30 | elif use_flow:
31 | loss_type = gd.LossType.FLOW
32 | # TODO: Implement learned variance for flow-based models
33 | learn_sigma = False
34 | elif rescale_learned_sigmas:
35 | loss_type = gd.LossType.RESCALED_MSE
36 | else:
37 | loss_type = gd.LossType.MSE
38 | if timestep_respacing is None or timestep_respacing == "":
39 | timestep_respacing = [diffusion_steps]
40 | return SpacedDiffusion(
41 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
42 | betas=betas,
43 | model_mean_type=(
44 | gd.ModelMeanType.V if predict_v else
45 | gd.ModelMeanType.START_X if predict_xstart else
46 | gd.ModelMeanType.EPSILON
47 | ),
48 | model_var_type=(
49 | (
50 | gd.ModelVarType.FIXED_LARGE
51 | if not sigma_small
52 | else gd.ModelVarType.FIXED_SMALL
53 | )
54 | if not learn_sigma
55 | else gd.ModelVarType.LEARNED_RANGE
56 | ),
57 | loss_type=loss_type
58 | # rescale_timesteps=rescale_timesteps,
59 | )
60 |
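# Usage sketch, mirroring the call in demos/fpdm_inference.ipynb: build a
# 250-step sampler with v-prediction and a zero-terminal-SNR schedule.
#
#   diffusion = create_diffusion("250", predict_v=True, use_zero_terminal_snr=True)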
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 |     Compute the KL divergence between two Gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
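    # x is assumed to be uint8 image data rescaled to [-1, 1], so each of the
    # 256 levels occupies a bin of width 2/255; the likelihood integrates the
    # Gaussian CDF over the +-(1/255) interval around each target value.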
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 |
7 | import math
8 | from typing import Union
9 | import numpy as np
10 | import torch as th
11 | import enum
12 |
13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14 |
15 |
16 | def mean_flat(tensor):
17 | """
18 | Take the mean over all non-batch dimensions.
19 | """
20 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
21 |
22 |
23 | class ModelMeanType(enum.Enum):
24 | """
25 | Which type of output the model predicts.
26 | """
27 |
28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29 | START_X = enum.auto() # the model predicts x_0
30 | EPSILON = enum.auto() # the model predicts epsilon
31 | V = enum.auto() # the model predicts v == sqrt(alpha) * eps - sqrt(1-alpha) * x
32 |
33 |
34 | class ModelVarType(enum.Enum):
35 | """
36 | What is used as the model's output variance.
37 | The LEARNED_RANGE option has been added to allow the model to predict
38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39 | """
40 |
41 | LEARNED = enum.auto()
42 | FIXED_SMALL = enum.auto()
43 | FIXED_LARGE = enum.auto()
44 | LEARNED_RANGE = enum.auto()
45 |
46 |
47 | class LossType(enum.Enum):
48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49 | RESCALED_MSE = (
50 | enum.auto()
51 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
52 | KL = enum.auto() # use the variational lower-bound
53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54 | FLOW = enum.auto() # new: flow-based loss
55 |
56 | def is_vb(self):
57 | return self == LossType.KL or self == LossType.RESCALED_KL
58 |
59 |
60 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
61 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
62 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
63 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
64 | return betas
65 |
66 |
67 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
68 | """
69 | This is the deprecated API for creating beta schedules.
70 | See get_named_beta_schedule() for the new library of schedules.
71 | """
72 | if beta_schedule == "quad":
73 | betas = (
74 | np.linspace(
75 | beta_start ** 0.5,
76 | beta_end ** 0.5,
77 | num_diffusion_timesteps,
78 | dtype=np.float64,
79 | )
80 | ** 2
81 | )
82 | elif beta_schedule == "linear":
83 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
84 | elif beta_schedule == "warmup10":
85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
86 | elif beta_schedule == "warmup50":
87 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
88 | elif beta_schedule == "const":
89 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
90 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
91 | betas = 1.0 / np.linspace(
92 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
93 | )
94 | else:
95 | raise NotImplementedError(beta_schedule)
96 | assert betas.shape == (num_diffusion_timesteps,)
97 | return betas
98 |
99 |
100 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
101 | """
102 | Get a pre-defined beta schedule for the given name.
103 | The beta schedule library consists of beta schedules which remain similar
104 | in the limit of num_diffusion_timesteps.
105 | Beta schedules may be added, but should not be removed or changed once
106 | they are committed to maintain backwards compatibility.
107 | """
108 | if schedule_name == "linear":
109 | # Linear schedule from Ho et al, extended to work for any number of
110 | # diffusion steps.
111 | scale = 1000 / num_diffusion_timesteps
112 | betas = get_beta_schedule(
113 | "linear",
114 | beta_start=scale * 0.0001,
115 | beta_end=scale * 0.02,
116 | num_diffusion_timesteps=num_diffusion_timesteps,
117 | )
118 | elif schedule_name == "squaredcos_cap_v2":
119 | betas = betas_for_alpha_bar(
120 | num_diffusion_timesteps,
121 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
122 | )
123 | else:
124 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
125 | return betas
126 |
127 | from typing import Union
128 | # Zero-terminal-SNR rescaling from Lin et al. (2023), "Common Diffusion Noise
129 | # Schedules and Sample Steps are Flawed": shift and rescale sqrt(alpha_bar) so
130 | # that the terminal alpha_bar is exactly zero while the first is unchanged.
131 | def enforce_zero_terminal_snr(betas: Union[th.Tensor, np.ndarray]) -> th.Tensor:
132 |     betas = betas if th.is_tensor(betas) else th.from_numpy(betas)
131 |
132 | # Convert betas to alphas_bar_sqrt
133 | alphas = 1 - betas
134 | alphas_bar = alphas.cumprod(0)
135 | alphas_bar_sqrt = alphas_bar.sqrt()
136 |
137 | # Store old values.
138 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
139 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
140 | # Shift so last timestep is zero.
141 | alphas_bar_sqrt -= alphas_bar_sqrt_T
142 | # Scale so first timestep is back to old value.
143 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
144 |
145 | # Convert alphas_bar_sqrt to betas
146 | alphas_bar = alphas_bar_sqrt ** 2
147 | alphas = alphas_bar[1:] / alphas_bar[:-1]
148 | alphas = th.cat([alphas_bar[0:1], alphas])
149 | betas = 1 - alphas
150 | return betas
151 |
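# Illustrative sketch (not used anywhere in the repo): build a named schedule and
# check that enforce_zero_terminal_snr() above drives the terminal alpha_bar, and
# hence the terminal SNR, to exactly zero. Uses only the functions defined above.
def _demo_zero_terminal_snr():
    betas = get_named_beta_schedule("linear", 1000)
    alphas_bar = np.cumprod(1.0 - betas)
    print("terminal alpha_bar before rescaling:", alphas_bar[-1])  # tiny, but > 0
    betas_zsnr = enforce_zero_terminal_snr(betas)
    alphas_bar_zsnr = (1.0 - betas_zsnr).cumprod(dim=0)
    print("terminal alpha_bar after rescaling:", alphas_bar_zsnr[-1].item())  # exactly 0.0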
152 |
153 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
154 | """
155 | Create a beta schedule that discretizes the given alpha_t_bar function,
156 | which defines the cumulative product of (1-beta) over time from t = [0,1].
157 | :param num_diffusion_timesteps: the number of betas to produce.
158 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
159 | produces the cumulative product of (1-beta) up to that
160 | part of the diffusion process.
161 | :param max_beta: the maximum beta to use; use values lower than 1 to
162 | prevent singularities.
163 | """
164 | betas = []
165 | for i in range(num_diffusion_timesteps):
166 | t1 = i / num_diffusion_timesteps
167 | t2 = (i + 1) / num_diffusion_timesteps
168 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
169 | return np.array(betas)
170 |
171 |
172 | class GaussianDiffusion:
173 | """
174 | Utilities for training and sampling diffusion models.
175 | Original ported from this codebase:
176 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
177 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
178 | starting at T and going to 1.
179 | """
180 |
181 | def __init__(
182 | self,
183 | *,
184 | betas,
185 | model_mean_type,
186 | model_var_type,
187 | loss_type,
188 | weight_schedule: str = "p2"
189 | ):
190 |
191 | self.model_mean_type = model_mean_type
192 | self.model_var_type = model_var_type
193 | self.loss_type = loss_type
194 |
195 | # Use float64 for accuracy.
196 | betas = np.array(betas, dtype=np.float64)
197 | self.betas = betas
198 | assert len(betas.shape) == 1, "betas must be 1-D"
199 | assert (betas > 0).all() and (betas <= 1).all()
200 |
201 | self.num_timesteps = int(betas.shape[0])
202 |
203 | alphas = 1.0 - betas
204 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
205 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
206 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
207 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
208 |
209 | # calculations for diffusion q(x_t | x_{t-1}) and others
210 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
211 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
212 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
213 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
214 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
215 |
216 | # calculations for posterior q(x_{t-1} | x_t, x_0)
217 | self.posterior_variance = (
218 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
219 | )
220 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
221 | self.posterior_log_variance_clipped = np.log(
222 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
223 | ) if len(self.posterior_variance) > 1 else np.array([])
224 |
225 | self.posterior_mean_coef1 = (
226 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
227 | )
228 | self.posterior_mean_coef2 = (
229 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
230 | )
231 |
232 |         self.snrs = self.alphas_cumprod / (1.0 - self.alphas_cumprod)  # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
233 |
234 | self.weight_schedule = weight_schedule
235 |
236 | def q_mean_variance(self, x_start, t):
237 | """
238 | Get the distribution q(x_t | x_0).
239 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
240 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
241 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
242 | """
243 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
244 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
245 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
246 | return mean, variance, log_variance
247 |
248 | def q_sample(self, x_start, t, noise=None):
249 | """
250 | Diffuse the data for a given number of diffusion steps.
251 | In other words, sample from q(x_t | x_0).
252 | :param x_start: the initial data batch.
253 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
254 | :param noise: if specified, the split-out normal noise.
255 | :return: A noisy version of x_start.
256 | """
257 | if noise is None:
258 | noise = th.randn_like(x_start)
259 | assert noise.shape == x_start.shape
260 | if self.loss_type == LossType.FLOW:
261 | t_float = t[:, None, None, None] / self.num_timesteps
262 | # return x_start * t_float + (1 - t_float) * noise # old models
263 | return x_start * (1 - t_float) + t_float * noise
264 | else:
265 | return (
266 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
267 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
268 | )
269 |
270 | def q_posterior_mean_variance(self, x_start, x_t, t):
271 | """
272 | Compute the mean and variance of the diffusion posterior:
273 | q(x_{t-1} | x_t, x_0)
274 | """
275 | assert x_start.shape == x_t.shape
276 | posterior_mean = (
277 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
278 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
279 | )
280 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
281 | posterior_log_variance_clipped = _extract_into_tensor(
282 | self.posterior_log_variance_clipped, t, x_t.shape
283 | )
284 | assert (
285 | posterior_mean.shape[0]
286 | == posterior_variance.shape[0]
287 | == posterior_log_variance_clipped.shape[0]
288 | == x_start.shape[0]
289 | )
290 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
291 |
292 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
293 | """
294 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
295 | the initial x, x_0.
296 | :param model: the model, which takes a signal and a batch of timesteps
297 | as input.
298 | :param x: the [N x C x ...] tensor at time t.
299 | :param t: a 1-D Tensor of timesteps.
300 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
301 | :param denoised_fn: if not None, a function which applies to the
302 | x_start prediction before it is used to sample. Applies before
303 | clip_denoised.
304 | :param model_kwargs: if not None, a dict of extra keyword arguments to
305 | pass to the model. This can be used for conditioning.
306 | :return: a dict with the following keys:
307 | - 'mean': the model mean output.
308 | - 'variance': the model variance output.
309 | - 'log_variance': the log of 'variance'.
310 | - 'pred_xstart': the prediction for x_0.
311 | """
312 | if model_kwargs is None:
313 | model_kwargs = {}
314 |
315 | B, C = x.shape[:2]
316 | assert t.shape == (B,), f'{t.shape = } and {B = }'
317 | model_output = model(x, t, **model_kwargs)
318 | if isinstance(model_output, tuple):
319 | model_output, extra = model_output
320 | else:
321 | extra = None
322 |
323 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
324 | assert self.loss_type != LossType.FLOW # TODO: Implement learned variance for flow-based models
325 | assert model_output.shape == (B, C * 2, *x.shape[2:])
326 | model_output, model_var_values = th.split(model_output, C, dim=1)
327 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
328 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
329 | # The model_var_values is [-1, 1] for [min_var, max_var].
330 | frac = (model_var_values + 1) / 2
331 | model_log_variance = frac * max_log + (1 - frac) * min_log
332 | model_variance = th.exp(model_log_variance)
333 | else:
334 | model_variance, model_log_variance = {
335 | # for fixedlarge, we set the initial (log-)variance like so
336 | # to get a better decoder log likelihood.
337 | ModelVarType.FIXED_LARGE: (
338 | np.append(self.posterior_variance[1], self.betas[1:]),
339 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
340 | ),
341 | ModelVarType.FIXED_SMALL: (
342 | self.posterior_variance,
343 | self.posterior_log_variance_clipped,
344 | ),
345 | }[self.model_var_type]
346 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
347 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
348 |
349 | def process_xstart(x):
350 | if denoised_fn is not None:
351 | x = denoised_fn(x)
352 | if clip_denoised:
353 | return x.clamp(-1, 1)
354 | return x
355 |
356 | if self.loss_type == LossType.FLOW:
357 | velocity = model_output
358 | model_mean = x + velocity / self.num_timesteps # this is the simpler version, without clipping
359 |             t_float = t[:, None, None, None] / self.num_timesteps  # sampling passes integer t from T-1 down to 0, so t_float is in [0, 1)
360 | pred_xstart = process_xstart(x + velocity * t_float)
361 | # new_model_mean = x + (pred_xstart - x) / t_float / self.num_timesteps
362 | else:
363 | if self.model_mean_type == ModelMeanType.PREVIOUS_X:
364 | pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
365 | model_mean = model_output
366 | elif self.model_mean_type == ModelMeanType.START_X:
367 | pred_xstart = process_xstart(model_output)
368 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
369 | elif self.model_mean_type == ModelMeanType.V:
370 | pred_xstart = process_xstart(self._predict_xstart_from_v(x_t=x, t=t, v=model_output))
371 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
372 | elif self.model_mean_type == ModelMeanType.EPSILON:
373 | pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
374 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
375 | else:
376 | raise ValueError(self.model_mean_type)
377 |
378 |
379 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
380 | return {
381 | "mean": model_mean,
382 | "variance": model_variance,
383 | "model_output": model_output,
384 | "log_variance": model_log_variance,
385 | "pred_xstart": pred_xstart,
386 | "extra": extra,
387 | }
388 |
389 | def _predict_xstart_from_eps(self, x_t, t, eps):
390 | assert x_t.shape == eps.shape
391 | return (
392 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
393 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
394 | )
395 |
396 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
397 | return (
398 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
399 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
400 |
401 | def _predict_xstart_from_xprev(self, x_t, t, xprev):
402 | # (xprev - coef2*x_t) / coef1
403 | assert x_t.shape == xprev.shape
404 | return (_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev -
405 | _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t)
406 |
407 | def _predict_xstart_from_v(self, x_t, t, v):
408 | # x0_pred = sqrt(alpha) * x_t - sqrt(1-alpha) * v_pred
409 | assert x_t.shape == v.shape
410 | return (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
411 | _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v)
412 |
413 | def _predict_eps_from_v(self, x_t, t, v):
414 | # https://github.com/huggingface/diffusers/blob/6a89a6c93ae38927097f5181030e3ceb7de7f43d/src/diffusers/schedulers/scheduling_ddim.py#L416-L429
415 | # eps = (alpha_prod_t**0.5) * model_output_v_pred + (beta_prod_t**0.5) * sample
416 | assert x_t.shape == v.shape
417 | return (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
418 | _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t)
419 |
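    # Illustrative sketch (hypothetical helper, not in the original class): the
    # prediction helpers above are exact inverses of the forward process. Given the
    # true noise, _predict_xstart_from_eps recovers x_0, and the v identities are
    # mutually consistent with v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x_0.
    # Assumes a non-FLOW loss type, so q_sample uses the Gaussian marginal.
    def _demo_prediction_inverses(self):
        x0 = th.randn(2, 3, 8, 8)
        eps = th.randn_like(x0)
        t = th.tensor([self.num_timesteps // 2] * 2)
        x_t = self.q_sample(x0, t, noise=eps)
        assert th.allclose(self._predict_xstart_from_eps(x_t, t, eps), x0, atol=1e-4)
        v = (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * eps
             - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)
        assert th.allclose(self._predict_xstart_from_v(x_t, t, v), x0, atol=1e-4)
        assert th.allclose(self._predict_eps_from_v(x_t, t, v), eps, atol=1e-4)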
420 |
421 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
422 | """
423 | Compute the mean for the previous step, given a function cond_fn that
424 | computes the gradient of a conditional log probability with respect to
425 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
426 | condition on y.
427 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
428 | """
429 | gradient = cond_fn(x, t, **model_kwargs)
430 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
431 | return new_mean
432 |
433 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
434 | """
435 | Compute what the p_mean_variance output would have been, should the
436 | model's score function be conditioned by cond_fn.
437 | See condition_mean() for details on cond_fn.
438 | Unlike condition_mean(), this instead uses the conditioning strategy
439 | from Song et al (2020).
440 | """
441 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
442 |
443 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
444 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
445 |
446 | out = p_mean_var.copy()
447 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
448 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
449 | return out
450 |
451 | def p_sample(
452 | self,
453 | model,
454 | x,
455 | t,
456 | clip_denoised=True,
457 | denoised_fn=None,
458 | cond_fn=None,
459 | model_kwargs=None,
460 | ):
461 | """
462 | Sample x_{t-1} from the model at the given timestep.
463 | :param model: the model to sample from.
464 | :param x: the current tensor at x_{t-1}.
465 | :param t: the value of t, starting at 0 for the first diffusion step.
466 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
467 | :param denoised_fn: if not None, a function which applies to the
468 | x_start prediction before it is used to sample.
469 | :param cond_fn: if not None, this is a gradient function that acts
470 | similarly to the model.
471 | :param model_kwargs: if not None, a dict of extra keyword arguments to
472 | pass to the model. This can be used for conditioning.
473 | :return: a dict containing the following keys:
474 | - 'sample': a random sample from the model.
475 | - 'pred_xstart': a prediction of x_0.
476 | """
477 | out = self.p_mean_variance(
478 | model,
479 | x,
480 | t,
481 | clip_denoised=clip_denoised,
482 | denoised_fn=denoised_fn,
483 | model_kwargs=model_kwargs,
484 | )
486 | noise = th.randn_like(x)
487 | nonzero_mask = (
488 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
489 | ) # no noise when t == 0
490 | if cond_fn is not None:
491 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
492 | if self.loss_type == LossType.FLOW:
493 | sample = out["mean"]
494 | else:
495 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
496 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
497 |
498 | def p_sample_loop(
499 | self,
500 | model,
501 | shape,
502 | noise=None,
503 | clip_denoised=True,
504 | denoised_fn=None,
505 | cond_fn=None,
506 | model_kwargs=None,
507 | device=None,
508 | progress=False,
509 | return_all=False,
510 | ):
511 | """
512 | Generate samples from the model.
513 | :param model: the model module.
514 | :param shape: the shape of the samples, (N, C, H, W).
515 | :param noise: if specified, the noise from the encoder to sample.
516 | Should be of the same shape as `shape`.
517 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
518 | :param denoised_fn: if not None, a function which applies to the
519 | x_start prediction before it is used to sample.
520 | :param cond_fn: if not None, this is a gradient function that acts
521 | similarly to the model.
522 | :param model_kwargs: if not None, a dict of extra keyword arguments to
523 | pass to the model. This can be used for conditioning.
524 | :param device: if specified, the device to create the samples on.
525 | If not specified, use a model parameter's device.
526 | :param progress: if True, show a tqdm progress bar.
527 | :return: a non-differentiable batch of samples.
528 | """
529 | final = None
530 | all_samples = []
531 | for sample in self.p_sample_loop_progressive(
532 | model,
533 | shape,
534 | noise=noise,
535 | clip_denoised=clip_denoised,
536 | denoised_fn=denoised_fn,
537 | cond_fn=cond_fn,
538 | model_kwargs=model_kwargs,
539 | device=device,
540 | progress=progress,
541 | ):
542 | final = sample
543 | all_samples.append(sample)
544 | return (final["sample"], all_samples) if return_all else final["sample"]
545 |
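    # Illustrative usage (hypothetical): ancestral sampling end to end with a dummy
    # "model" that always predicts zero noise. Real callers pass a DiT and
    # model_kwargs (e.g. class labels); this assumes a fixed-variance,
    # epsilon-prediction configuration.
    def _demo_p_sample_loop(self):
        model = lambda x, t: th.zeros_like(x)
        samples = self.p_sample_loop(model, (2, 3, 8, 8), device="cpu")
        assert samples.shape == (2, 3, 8, 8)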
546 | def p_sample_loop_progressive(
547 | self,
548 | model,
549 | shape,
550 | noise=None,
551 | clip_denoised=True,
552 | denoised_fn=None,
553 | cond_fn=None,
554 | model_kwargs=None,
555 | device=None,
556 | progress=False,
557 | ):
558 | """
559 | Generate samples from the model and yield intermediate samples from
560 | each timestep of diffusion.
561 | Arguments are the same as p_sample_loop().
562 | Returns a generator over dicts, where each dict is the return value of
563 | p_sample().
564 | """
565 | if device is None:
566 | device = next(model.parameters()).device
567 | assert isinstance(shape, (tuple, list))
568 | if noise is not None:
569 | img = noise
570 | else:
571 | img = th.randn(*shape, device=device)
572 | indices = list(range(self.num_timesteps))[::-1]
573 |
574 | if progress:
575 | # Lazy import so that we don't depend on tqdm.
576 | from tqdm.auto import tqdm
577 |
578 | indices = tqdm(indices)
579 |
580 | for i in indices:
581 | t = th.tensor([i] * shape[0], device=device)
582 | with th.no_grad():
583 | out = self.p_sample(
584 | model,
585 | img,
586 | t,
587 | clip_denoised=clip_denoised,
588 | denoised_fn=denoised_fn,
589 | cond_fn=cond_fn,
590 | model_kwargs=model_kwargs,
591 | )
592 | yield out
593 | img = out["sample"]
594 |
595 | def ddim_sample(
596 | self,
597 | model,
598 | x,
599 | t,
600 | clip_denoised=True,
601 | denoised_fn=None,
602 | cond_fn=None,
603 | model_kwargs=None,
604 | eta=0.0,
605 | ):
606 | """
607 | Sample x_{t-1} from the model using DDIM.
608 | Same usage as p_sample().
609 | """
611 | out = self.p_mean_variance(
612 | model,
613 | x,
614 | t,
615 | clip_denoised=clip_denoised,
616 | denoised_fn=denoised_fn,
617 | model_kwargs=model_kwargs,
618 | )
619 | if cond_fn is not None:
620 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
621 |
622 | # Usually our model outputs epsilon, but we re-derive it
623 | # in case we used x_start or x_prev prediction.
624 |
625 |         # NOTE: For v-pred, these two ways of computing eps should agree, except that the latter
626 |         # way (1) breaks at the first sampling step of a zero-terminal-SNR schedule, where
627 |         # alpha_bar = 0 and 1 / alpha_bar is infinite, and (2) also has some additional
628 |         # post-processing applied to the sample (such as the clipping applied by the
629 |         # `process_xstart` function above). For this reason, we use the former (direct) route
630 |         # when alpha_bar is 0 and the latter otherwise.
629 | if self.model_mean_type == ModelMeanType.V and _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, (1,)).isinf().any().item():
630 | if cond_fn is not None:
631 | raise NotImplementedError()
632 | eps = self._predict_eps_from_v(x, t, out["model_output"])
633 | # eps_check = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
634 | # print(eps - eps_check) # <-- can use this to check that they are the same, for timesteps where alpha != 1
635 | else:
636 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
637 |
638 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
639 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
640 | sigma = (
641 | eta
642 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
643 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
644 | )
645 | # Equation 12.
646 | noise = th.randn_like(x)
647 | mean_pred = (
648 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
649 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
650 | )
651 | nonzero_mask = (
652 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
653 | ) # no noise when t == 0
654 | sample = mean_pred + nonzero_mask * sigma * noise
655 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
656 |
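    # Illustrative sketch (hypothetical): one deterministic DDIM step with the same
    # dummy zero-noise model as above; eta=0.0 removes the stochastic term, so the
    # update is just Equation 12 of the DDIM paper. Assumes a fixed-variance,
    # epsilon-prediction configuration.
    def _demo_ddim_step(self):
        model = lambda x, t: th.zeros_like(x)
        x = th.randn(1, 3, 8, 8)
        t = th.tensor([self.num_timesteps - 1])
        out = self.ddim_sample(model, x, t, eta=0.0)
        assert out["sample"].shape == x.shape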
657 | def ddim_reverse_sample(
658 | self,
659 | model,
660 | x,
661 | t,
662 | clip_denoised=True,
663 | denoised_fn=None,
664 | cond_fn=None,
665 | model_kwargs=None,
666 | eta=0.0,
667 | ):
668 | """
669 | Sample x_{t+1} from the model using DDIM reverse ODE.
670 | """
671 | assert eta == 0.0, "Reverse ODE only for deterministic path"
672 | out = self.p_mean_variance(
673 | model,
674 | x,
675 | t,
676 | clip_denoised=clip_denoised,
677 | denoised_fn=denoised_fn,
678 | model_kwargs=model_kwargs,
679 | )
680 | if cond_fn is not None:
681 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
682 |         # Usually our model outputs epsilon, but we re-derive it
683 |         # in case we used x_start or x_prev prediction.
684 |         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
688 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
689 |
690 | # Equation 12. reversed
691 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
692 |
693 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
694 |
695 | def ddim_sample_loop(
696 | self,
697 | model,
698 | shape,
699 | noise=None,
700 | clip_denoised=True,
701 | denoised_fn=None,
702 | cond_fn=None,
703 | model_kwargs=None,
704 | device=None,
705 | progress=False,
706 | eta=0.0,
707 | ):
708 | """
709 | Generate samples from the model using DDIM.
710 | Same usage as p_sample_loop().
711 | """
712 | final = None
713 | for sample in self.ddim_sample_loop_progressive(
714 | model,
715 | shape,
716 | noise=noise,
717 | clip_denoised=clip_denoised,
718 | denoised_fn=denoised_fn,
719 | cond_fn=cond_fn,
720 | model_kwargs=model_kwargs,
721 | device=device,
722 | progress=progress,
723 | eta=eta,
724 | ):
725 | final = sample
726 | return final["sample"]
727 |
728 | def ddim_sample_loop_progressive(
729 | self,
730 | model,
731 | shape,
732 | noise=None,
733 | clip_denoised=True,
734 | denoised_fn=None,
735 | cond_fn=None,
736 | model_kwargs=None,
737 | device=None,
738 | progress=False,
739 | eta=0.0,
740 | ):
741 | """
742 | Use DDIM to sample from the model and yield intermediate samples from
743 | each timestep of DDIM.
744 | Same usage as p_sample_loop_progressive().
745 | """
746 | if device is None:
747 | device = next(model.parameters()).device
748 | assert isinstance(shape, (tuple, list))
749 | if noise is not None:
750 | img = noise
751 | else:
752 | img = th.randn(*shape, device=device)
753 | indices = list(range(self.num_timesteps))[::-1]
754 |
755 | if progress:
756 | # Lazy import so that we don't depend on tqdm.
757 | from tqdm.auto import tqdm
758 |
759 | indices = tqdm(indices)
760 |
761 | for i in indices:
762 | t = th.tensor([i] * shape[0], device=device)
763 | with th.no_grad():
764 | out = self.ddim_sample(
765 | model,
766 | img,
767 | t,
768 | clip_denoised=clip_denoised,
769 | denoised_fn=denoised_fn,
770 | cond_fn=cond_fn,
771 | model_kwargs=model_kwargs,
772 | eta=eta,
773 | )
774 | yield out
775 | img = out["sample"]
776 |
777 | def _vb_terms_bpd(
778 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
779 | ):
780 | """
781 | Get a term for the variational lower-bound.
782 | The resulting units are bits (rather than nats, as one might expect).
783 | This allows for comparison to other papers.
784 | :return: a dict with the following keys:
785 | - 'output': a shape [N] tensor of NLLs or KLs.
786 | - 'pred_xstart': the x_0 predictions.
787 | """
788 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
789 | x_start=x_start, x_t=x_t, t=t
790 | )
791 | out = self.p_mean_variance(
792 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
793 | )
794 | kl = normal_kl(
795 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
796 | )
797 | kl = mean_flat(kl) / np.log(2.0)
798 |
799 | decoder_nll = -discretized_gaussian_log_likelihood(
800 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
801 | )
802 | assert decoder_nll.shape == x_start.shape
803 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
804 |
805 | # At the first timestep return the decoder NLL,
806 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
807 | output = th.where((t == 0), decoder_nll, kl)
808 | return {"output": output, "pred_xstart": out["pred_xstart"]}
809 |
810 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
811 | """
812 | Compute training losses for a single timestep.
813 | :param model: the model to evaluate loss on.
814 | :param x_start: the [N x C x ...] tensor of inputs.
815 | :param t: a batch of timestep indices.
816 | :param model_kwargs: if not None, a dict of extra keyword arguments to
817 | pass to the model. This can be used for conditioning.
818 | :param noise: if specified, the specific Gaussian noise to try to remove.
819 | :return: a dict with the key "loss" containing a tensor of shape [N].
820 | Some mean or variance settings may also have other keys.
821 | """
822 | if model_kwargs is None:
823 | model_kwargs = {}
824 | if noise is None:
825 | noise = th.randn_like(x_start)
826 | x_t = self.q_sample(x_start, t, noise=noise)
827 |
828 | terms = {}
829 |
830 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
831 | terms["loss"] = self._vb_terms_bpd(
832 | model=model,
833 | x_start=x_start,
834 | x_t=x_t,
835 | t=t,
836 | clip_denoised=False,
837 | model_kwargs=model_kwargs,
838 | )["output"]
839 | if self.loss_type == LossType.RESCALED_KL:
840 | terms["loss"] *= self.num_timesteps
841 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
842 | model_output = model(x_t, t, **model_kwargs)
843 |
844 | if self.model_var_type in [
845 | ModelVarType.LEARNED,
846 | ModelVarType.LEARNED_RANGE,
847 | ]:
848 | B, C = x_t.shape[:2]
849 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
850 | model_output, model_var_values = th.split(model_output, C, dim=1)
851 | # Learn the variance using the variational bound, but don't let
852 | # it affect our mean prediction.
853 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
854 | terms["vb"] = self._vb_terms_bpd(
855 | model=lambda *args, r=frozen_out: r,
856 | x_start=x_start,
857 | x_t=x_t,
858 | t=t,
859 | clip_denoised=False,
860 | )["output"]
861 | if self.loss_type == LossType.RESCALED_MSE:
862 | # Divide by 1000 for equivalence with initial implementation.
863 | # Without a factor of 1/1000, the VB term hurts the MSE term.
864 | terms["vb"] *= self.num_timesteps / 1000.0
865 |
866 | if self.model_mean_type == ModelMeanType.V:
867 |                 # from https://github.com/ericl122333/PatchDiffusion-Pytorch/blob/main/patch_diffusion/gaussian_diffusion.py#L841
868 |                 # TODO: consider refactoring everything to this format. The dispatch
869 |                 # below is copied wholesale; only the V case can fire under the outer check.
870 |                 # Get predicted x_0
870 | if self.model_mean_type == ModelMeanType.START_X:
871 | model_pred_x0 = model_output
872 | elif self.model_mean_type == ModelMeanType.EPSILON:
873 | model_pred_x0 = self._predict_xstart_from_eps(x_t, t, model_output)
874 | elif self.model_mean_type == ModelMeanType.V:
875 | model_pred_x0 = self._predict_xstart_from_v(x_t, t, model_output)
876 | elif self.model_mean_type == ModelMeanType.PREVIOUS_X:
877 | model_pred_x0 = self._predict_xstart_from_xprev(x_t, t, model_output)
878 | else:
879 | raise NotImplementedError()
880 | # Get loss weights depending on timestep
881 | if self.weight_schedule == "sqrt_snr":
882 | weights = np.sqrt(self.snrs)
883 | elif self.weight_schedule == "p2": # from https://arxiv.org/abs/2204.00227
884 | weights = self.snrs * self.sqrt_one_minus_alphas_cumprod
885 | elif self.weight_schedule == "snr":
886 | weights = self.snrs
887 | elif self.weight_schedule == "snr+1":
888 | weights = self.snrs + 1
889 | elif self.weight_schedule == "truncated_snr":
890 | weights = np.maximum(self.snrs, 1.0)
891 | else:
892 | raise NotImplementedError()
893 | weights = _extract_into_tensor(weights, t, x_start.shape)
894 | # Compute loss
895 | assert model_output.shape == x_start.shape
896 | terms["mse"] = mean_flat(weights * (x_start - model_pred_x0)**2)
897 | else:
898 | target = {
899 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
900 | ModelMeanType.START_X: x_start,
901 | ModelMeanType.EPSILON: noise,
902 | }[self.model_mean_type]
903 | assert model_output.shape == target.shape == x_start.shape
904 | terms["mse"] = mean_flat((target - model_output) ** 2)
905 | if "vb" in terms:
906 | terms["loss"] = terms["mse"] + terms["vb"]
907 | else:
908 | terms["loss"] = terms["mse"]
909 | elif self.loss_type == LossType.FLOW:
910 | model_output = model(x_t, t, **model_kwargs)
911 |
912 | # TODO: Implement learned variance for flow-based models
913 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
914 | raise NotImplementedError(f'{self.model_var_type = } is not implemented for flow yet')
915 |
916 | # Loss
917 | target = x_start - noise
918 | terms["flow"] = mean_flat((model_output - target) ** 2)
919 | terms["loss"] = terms["flow"]
920 |
921 | else:
922 | raise NotImplementedError(self.loss_type)
923 |
924 | return terms
925 |
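    # Illustrative sketch (hypothetical): materialize the loss-weight schedules used
    # in the V branch above as plain numpy arrays, to compare how each one treats
    # the least-noisy (t=0) versus most-noisy (t=T-1) timesteps.
    def _demo_weight_schedules(self):
        schedules = {
            "sqrt_snr": np.sqrt(self.snrs),
            "p2": self.snrs * self.sqrt_one_minus_alphas_cumprod,  # https://arxiv.org/abs/2204.00227
            "snr": self.snrs,
            "snr+1": self.snrs + 1,
            "truncated_snr": np.maximum(self.snrs, 1.0),
        }
        for name, w in schedules.items():
            print(f"{name}: w[0] = {w[0]:.3g}, w[-1] = {w[-1]:.3g}")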
926 | def _prior_bpd(self, x_start):
927 | """
928 | Get the prior KL term for the variational lower-bound, measured in
929 | bits-per-dim.
930 | This term can't be optimized, as it only depends on the encoder.
931 | :param x_start: the [N x C x ...] tensor of inputs.
932 | :return: a batch of [N] KL values (in bits), one per batch element.
933 | """
934 | batch_size = x_start.shape[0]
935 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
936 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
937 | kl_prior = normal_kl(
938 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
939 | )
940 | return mean_flat(kl_prior) / np.log(2.0)
941 |
942 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
943 | """
944 | Compute the entire variational lower-bound, measured in bits-per-dim,
945 | as well as other related quantities.
946 | :param model: the model to evaluate loss on.
947 | :param x_start: the [N x C x ...] tensor of inputs.
948 | :param clip_denoised: if True, clip denoised samples.
949 | :param model_kwargs: if not None, a dict of extra keyword arguments to
950 | pass to the model. This can be used for conditioning.
951 | :return: a dict containing the following keys:
952 | - total_bpd: the total variational lower-bound, per batch element.
953 | - prior_bpd: the prior term in the lower-bound.
954 | - vb: an [N x T] tensor of terms in the lower-bound.
955 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
956 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
957 | """
958 | device = x_start.device
959 | batch_size = x_start.shape[0]
960 |
961 | vb = []
962 | xstart_mse = []
963 | mse = []
964 | for t in list(range(self.num_timesteps))[::-1]:
965 | t_batch = th.tensor([t] * batch_size, device=device)
966 | noise = th.randn_like(x_start)
967 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
968 | # Calculate VLB term at the current timestep
969 | with th.no_grad():
970 | out = self._vb_terms_bpd(
971 | model,
972 | x_start=x_start,
973 | x_t=x_t,
974 | t=t_batch,
975 | clip_denoised=clip_denoised,
976 | model_kwargs=model_kwargs,
977 | )
978 | vb.append(out["output"])
979 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
980 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
981 | mse.append(mean_flat((eps - noise) ** 2))
982 |
983 | vb = th.stack(vb, dim=1)
984 | xstart_mse = th.stack(xstart_mse, dim=1)
985 | mse = th.stack(mse, dim=1)
986 |
987 | prior_bpd = self._prior_bpd(x_start)
988 | total_bpd = vb.sum(dim=1) + prior_bpd
989 | return {
990 | "total_bpd": total_bpd,
991 | "prior_bpd": prior_bpd,
992 | "vb": vb,
993 | "xstart_mse": xstart_mse,
994 | "mse": mse,
995 | }
996 |
997 | def p_sample_loop_ode(
998 | self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None,
999 | device=None, progress=False, return_all=False,
1000 | ):
1001 | """
1002 | Generate samples from the model and yield intermediate samples from
1003 | each timestep of diffusion.
1004 | Arguments are the same as p_sample_loop().
1005 | Returns a generator over dicts, where each dict is the return value of
1006 | p_sample().
1007 | """
1008 | assert isinstance(shape, (tuple, list))
1009 | device = device or next(model.parameters()).device
1010 | img = noise if noise is not None else th.randn(*shape, device=device)
1011 |
1012 | intermediates = []
1013 |
1014 |         def ode_func(t: float, x: th.FloatTensor):
1015 |             nonlocal intermediates
1016 |             # Map solver time t in [0, 1] onto a discrete diffusion timestep in [T-1, 0].
1017 |             t_input = th.round((1 - t) * self.num_timesteps).expand(shape[0]).to(device=device, dtype=th.long)
1018 |             t_input = th.clip(t_input - 1, min=0, max=None)
1019 |             out = self.p_sample(
1020 |                 model, x, t_input, clip_denoised=clip_denoised,
1021 |                 denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs,
1022 |             )
1023 |             x_new = out["sample"]
1024 |             intermediates.append({
1025 |                 't': t, 't_input': th.round((1 - t) * self.num_timesteps), 'sample': x_new,
1026 |                 'pred_xstart': x_new,  # <-- this should be computed separately but that's fine
1027 |             })
1028 |             # p_sample returns the next state x + v / T, while odeint expects the time
1029 |             # derivative dx/dt; convert the one-step update back into a velocity.
1030 |             return (x_new - x) * self.num_timesteps
1031 |
1032 | from torchdiffeq import odeint
1033 |
1034 | odeint_timescale = 1.0
1035 | odeint_kwargs = {
1036 | "atol": 1e-5,
1037 | "rtol": 1e-5,
1038 | "options": {
1039 | "step_t": [1.0 + 1e-7],
1040 | }
1041 | }
1042 |
1043 | with th.no_grad():
1044 | t = th.tensor([0, 1.0], device=device)
1045 |             zs = odeint(
1046 |                 ode_func,
1047 |                 img,
1048 |                 t * odeint_timescale,
1049 |                 **odeint_kwargs,
1050 |             )
1051 | 
1052 |         # odeint returns the solution at every time in `t`; the last entry is the final sample.
1053 |         final_sample = zs[-1]
1054 |         return final_sample, intermediates
1054 |
1055 |
1056 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
1057 | """
1058 | Extract values from a 1-D numpy array for a batch of indices.
1059 | :param arr: the 1-D numpy array.
1060 | :param timesteps: a tensor of indices into the array to extract.
1061 | :param broadcast_shape: a larger shape of K dimensions with the batch
1062 | dimension equal to the length of timesteps.
1063 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1064 | """
1065 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1066 | while len(res.shape) < len(broadcast_shape):
1067 | res = res[..., None]
1068 | return res + th.zeros(broadcast_shape, device=timesteps.device)
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
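# Illustrative usage (hypothetical, not called anywhere): reproduce the docstring
# example above and the "ddimN" special case.
def _demo_space_timesteps():
    steps = space_timesteps(300, [10, 15, 20])
    assert len(steps) == 10 + 15 + 20  # one sub-schedule per 100-step section
    ddim_steps = space_timesteps(1000, "ddim50")
    assert ddim_steps == set(range(0, 1000, 20))  # a stride of 20 yields exactly 50 steps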
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130 | return self.model(x, new_ts, **kwargs)
131 |
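
# Illustrative sketch (hypothetical, mirroring how sampling scripts typically use
# this module): wrap a 1000-step base process into a 50-step spaced one. The betas
# are recomputed in __init__ so the kept steps form a valid diffusion on their own.
def _demo_spaced_diffusion():
    from . import gaussian_diffusion as gd
    diffusion = SpacedDiffusion(
        use_timesteps=space_timesteps(1000, "ddim50"),
        betas=gd.get_named_beta_schedule("linear", 1000),
        model_mean_type=gd.ModelMeanType.EPSILON,
        model_var_type=gd.ModelVarType.FIXED_SMALL,
        loss_type=gd.LossType.MSE,
    )
    assert diffusion.num_timesteps == 50
    assert len(diffusion.timestep_map) == 50  # maps 0..49 back to original indices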
--------------------------------------------------------------------------------
/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 |         timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
96 |         loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
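# Illustrative sketch (hypothetical): the weights returned by sample() undo the
# sampling probabilities, so the reweighted loss is an unbiased estimate of the
# uniform objective: sum_t p_t * (1 / (T * p_t)) * loss_t = (1 / T) * sum_t loss_t.
def _demo_importance_weights():
    class _Toy:  # stand-in exposing the only attribute the samplers need
        num_timesteps = 10
    sampler = UniformSampler(_Toy())
    t, w = sampler.sample(batch_size=4, device="cpu")
    assert t.shape == (4,) and th.allclose(w, th.ones(4))  # uniform => unit weights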
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: DiT
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python >= 3.8
7 | - pytorch >= 1.13
8 | - torchvision
9 | - pytorch-cuda=11.7
10 | - pip:
11 | - timm
12 | - diffusers
13 |     - accelerate
14 |     - torchdiffeq  # used by GaussianDiffusion.p_sample_loop_ode
15 |     - jaxtyping    # imported at the top of models.py
14 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import math
13 | import random
14 | from typing import Optional
15 | from contextlib import nullcontext
16 |
17 | import torch
18 | import torch.nn as nn
19 | import numpy as np
20 | from jaxtyping import Float, Shaped
21 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
22 | from torch import Tensor
23 | from torch.utils.checkpoint import checkpoint
24 |
25 | def modulate(x, shift, scale):
26 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
27 |
28 |
29 | #################################################################################
30 | # Embedding Layers for Timesteps and Class Labels #
31 | #################################################################################
32 |
33 | class TimestepEmbedder(nn.Module):
34 | """
35 | Embeds scalar timesteps into vector representations.
36 | """
37 | def __init__(self, hidden_size, frequency_embedding_size=256):
38 | super().__init__()
39 | self.mlp = nn.Sequential(
40 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
41 | nn.SiLU(),
42 | nn.Linear(hidden_size, hidden_size, bias=True),
43 | )
44 | self.frequency_embedding_size = frequency_embedding_size
45 |
46 | @staticmethod
47 | def timestep_embedding(t, dim, max_period=10000):
48 | """
49 | Create sinusoidal timestep embeddings.
50 | :param t: a 1-D Tensor of N indices, one per batch element.
51 | These may be fractional.
52 | :param dim: the dimension of the output.
53 | :param max_period: controls the minimum frequency of the embeddings.
54 | :return: an (N, D) Tensor of positional embeddings.
55 | """
56 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
57 | half = dim // 2
58 | freqs = torch.exp(
59 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
60 | ).to(device=t.device)
61 | args = t[:, None].float() * freqs[None]
62 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
63 | if dim % 2:
64 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
65 | return embedding
66 |
67 | def forward(self, t):
68 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
69 | t_emb = self.mlp(t_freq)
70 | return t_emb
71 |
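# Illustrative usage (hypothetical, not part of the model definition): embed a
# batch of integer timesteps; each becomes a single hidden_size-dim vector.
def _demo_timestep_embedder():
    embedder = TimestepEmbedder(hidden_size=1152)
    t = torch.tensor([0, 250, 999])
    assert embedder(t).shape == (3, 1152)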
72 |
73 | class LabelEmbedder(nn.Module):
74 | """
75 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
76 | """
77 | def __init__(self, num_classes, hidden_size, dropout_prob, use_cfg_embedding: bool = True, continuous: bool = False):
78 | super().__init__()
79 | self.continuous = continuous
80 | self.num_classes = num_classes
81 | self.dropout_prob = dropout_prob
82 | if self.continuous:
83 | self.embedding_projection = nn.Linear(num_classes, hidden_size)
84 | else:
85 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
86 |
87 | def token_drop(self, labels, force_drop_ids=None):
88 | """
89 | Drops labels to enable classifier-free guidance.
90 | """
91 | if force_drop_ids is None:
92 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
93 | else:
94 | drop_ids = force_drop_ids == 1
95 | if self.continuous:
96 | labels = labels * (1 - drop_ids[:, None].to(labels))
97 | else:
98 | labels = torch.where(drop_ids, self.num_classes, labels)
99 | return labels
100 |
101 | def forward(self, labels, train, force_drop_ids=None):
102 | use_dropout = self.dropout_prob > 0
103 | if (train and use_dropout) or (force_drop_ids is not None):
104 | labels = self.token_drop(labels, force_drop_ids)
105 | embeddings = self.embedding_projection(labels) if self.continuous else self.embedding_table(labels)
106 | return embeddings
107 |
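# Illustrative sketch (hypothetical): token_drop replaces a random subset of labels
# with the extra "null" index num_classes; that null embedding is what
# classifier-free guidance conditions on at sampling time.
def _demo_label_embedder_cfg():
    emb = LabelEmbedder(num_classes=1000, hidden_size=1152, dropout_prob=0.5)
    labels = torch.randint(0, 1000, (8,))
    dropped = emb.token_drop(labels)
    assert ((dropped == 1000) | (dropped == labels)).all()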
108 |
109 |
110 | #################################################################################
111 | # Core DiT Model #
112 | #################################################################################
113 |
114 | class DiTBlock(nn.Module):
115 | """
116 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
117 | """
118 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
119 | super().__init__()
120 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
121 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
122 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
123 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
124 | approx_gelu = lambda: nn.GELU(approximate="tanh")
125 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
126 | self.adaLN_modulation = nn.Sequential(
127 | nn.SiLU(),
128 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
129 | )
130 |
131 | def forward(self, x, c):
132 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
133 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
134 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
135 | return x
136 |
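# Illustrative usage (hypothetical): a DiTBlock maps (tokens, conditioning vector)
# to tokens of the same shape. After DiT.initialize_weights() zeroes the
# adaLN_modulation output layers, both gates start at zero, so every block begins
# as the identity function (the "Zero" in adaLN-Zero).
def _demo_dit_block():
    block = DiTBlock(hidden_size=384, num_heads=6)
    x = torch.randn(2, 16, 384)  # (batch, tokens, hidden)
    c = torch.randn(2, 384)      # conditioning (timestep embedding + label embedding)
    assert block(x, c).shape == x.shape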
137 |
138 | class FinalLayer(nn.Module):
139 | """
140 | The final layer of DiT.
141 | """
142 | def __init__(self, hidden_size, patch_size, out_channels):
143 | super().__init__()
144 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
145 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
146 | self.adaLN_modulation = nn.Sequential(
147 | nn.SiLU(),
148 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
149 | )
150 |
151 | def forward(self, x, c):
152 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
153 | x = modulate(self.norm_final(x), shift, scale)
154 | x = self.linear(x)
155 | return x
156 |
157 |
158 | class DiT(nn.Module):
159 | """
160 | Diffusion model with a Transformer backbone.
161 | """
162 | def __init__(
163 | self,
164 | input_size=32,
165 | patch_size=2,
166 | in_channels=4,
167 | hidden_size=1152,
168 | depth=28,
169 | num_heads=16,
170 | mlp_ratio=4.0,
171 | class_dropout_prob=0.1,
172 | num_classes=1000,
173 | learn_sigma=True,
174 | use_cfg_embedding: bool = True,
175 | use_gradient_checkpointing: bool = True,
176 | is_label_continuous: bool = False,
177 |
178 | # below are Fixed Point-specific arguments.
179 | fixed_point: bool = False,
180 |
181 | # size
182 | fixed_point_pre_depth: int = 1,
183 | fixed_point_post_depth: int = 1,
184 |
185 | # iteration counts
186 | fixed_point_no_grad_min_iters: int = 0,
187 | fixed_point_no_grad_max_iters: int = 0,
188 | fixed_point_with_grad_min_iters: int = 28,
189 | fixed_point_with_grad_max_iters: int = 28,
190 |
191 |         # solution recycling
192 |         fixed_point_reuse_solution: bool = False,
193 |
194 | # pre_post_timestep_conditioning
195 | fixed_point_pre_post_timestep_conditioning: bool = True,
196 |
197 | # adaptively distributing iterations among timesteps. Currently we only support linear distribution.
198 | adaptive: bool = False,
199 |         iteration_controller=None,
200 | ):
201 | super().__init__()
202 | self.learn_sigma = learn_sigma
203 | self.in_channels = in_channels
204 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
205 | self.patch_size = patch_size
206 | self.num_heads = num_heads
207 | self.use_gradient_checkpointing = use_gradient_checkpointing
208 |
209 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
210 | self.t_embedder = TimestepEmbedder(hidden_size)
211 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob,
212 | use_cfg_embedding=use_cfg_embedding, continuous=is_label_continuous)
213 | num_patches = self.x_embedder.num_patches
214 |
215 | # Will use fixed sin-cos embedding:
216 | # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
217 | self.register_buffer('pos_embed', torch.zeros(1, num_patches, hidden_size))
218 |
219 | # New: Fixed Point
220 | self.fixed_point = fixed_point
221 | if self.fixed_point:
222 | self.fixed_point_no_grad_min_iters = fixed_point_no_grad_min_iters
223 | self.fixed_point_no_grad_max_iters = fixed_point_no_grad_max_iters
224 | self.fixed_point_with_grad_min_iters = fixed_point_with_grad_min_iters
225 | self.fixed_point_with_grad_max_iters = fixed_point_with_grad_max_iters
226 | self.fixed_point_pre_post_timestep_conditioning = fixed_point_pre_post_timestep_conditioning
227 | self.blocks_pre = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_pre_depth)])
228 | self.block_pre_projection = nn.Linear(hidden_size, hidden_size)
229 | self.block_fixed_point_projection_fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
230 | self.block_fixed_point_projection_act = nn.GELU(approximate="tanh")
231 | self.block_fixed_point_projection_fc2 = nn.Linear(2 * hidden_size, hidden_size)
232 | self.block_fixed_point = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
233 | self.blocks_post = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_post_depth)])
234 | self.blocks = [*self.blocks_pre, self.block_fixed_point, *self.blocks_post]
235 | self.fixed_point_reuse_solution = fixed_point_reuse_solution
236 | self.last_solution = None
237 |
238 | self.adaptive = adaptive
239 | self.iteration_controller = iteration_controller
240 | else:
241 | self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
242 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
243 | self.initialize_weights()
244 |
245 | def initialize_weights(self):
246 | # Initialize transformer layers:
247 | def _basic_init(module):
248 | if isinstance(module, nn.Linear):
249 | torch.nn.init.xavier_uniform_(module.weight)
250 | if module.bias is not None:
251 | nn.init.constant_(module.bias, 0)
252 | self.apply(_basic_init)
253 |
254 | # Initialize (and freeze) pos_embed by sin-cos embedding:
255 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
256 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
257 |
258 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
259 | w = self.x_embedder.proj.weight.data
260 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
261 | nn.init.constant_(self.x_embedder.proj.bias, 0)
262 |
263 | # Initialize label embedding table:
264 | if self.y_embedder.continuous:
265 | nn.init.normal_(self.y_embedder.embedding_projection.weight, std=0.02)
266 | nn.init.constant_(self.y_embedder.embedding_projection.bias, 0)
267 | else:
268 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
269 |
270 | # Initialize timestep embedding MLP:
271 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
272 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
273 |
274 | # Zero-out adaLN modulation layers in DiT blocks:
275 | for block in self.blocks:
276 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
277 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
278 |
279 | # Zero-out output layers:
280 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
281 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
282 | nn.init.constant_(self.final_layer.linear.weight, 0)
283 | nn.init.constant_(self.final_layer.linear.bias, 0)
284 |
285 | def unpatchify(self, x):
286 | """
287 | x: (N, T, patch_size**2 * C)
288 |         imgs: (N, C, H, W)
289 | """
290 | c = self.out_channels
291 | p = self.x_embedder.patch_size[0]
292 | h = w = int(x.shape[1] ** 0.5)
293 | assert h * w == x.shape[1]
294 |
295 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
296 | x = torch.einsum('nhwpqc->nchpwq', x)
297 |         imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
298 | return imgs
299 |
300 | def ckpt_wrapper(self, module):
301 | def ckpt_forward(*inputs):
302 | outputs = module(*inputs)
303 | return outputs
304 | return ckpt_forward
305 |
306 | def _forward_dit(self, x, t, y):
307 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
308 | t = self.t_embedder(t) # (N, D)
309 |         y = self.y_embedder(y, self.training)    # (N, D)
310 | c = t + y # (N, D)
311 | for block in self.blocks:
312 | x = checkpoint(self.ckpt_wrapper(block), x, c) if self.use_gradient_checkpointing else block(x, c) # (N, T, D)
313 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
314 | x = self.unpatchify(x) # (N, out_channels, H, W)
315 | return x
316 |
317 | def _forward_fixed_point_blocks(
318 | self, x: Float[Tensor, "b t d"], x_input_injection: Float[Tensor, "b t d"], c: Float[Tensor, "b d"], num_iterations: int
319 | ) -> Float[Tensor, "b t d"]:
320 | def forward_pass(x):
321 | x = torch.cat((x, x_input_injection), dim=-1) # (N, T, D * 2)
322 | x = self.block_fixed_point_projection_fc1(x) # (N, T, D * 2)
323 | x = self.block_fixed_point_projection_act(x) # (N, T, D * 2)
324 | x = self.block_fixed_point_projection_fc2(x) # (N, T, D)
325 | x = self.block_fixed_point(x, c) # (N, T, D)
326 | return x
327 |
328 | if self.adaptive:
329 | num_iterations = self.iteration_controller.get()
330 |
331 | for _ in range(num_iterations):
332 | x = forward_pass(x)
333 | return x
334 |
335 | def _check_inputs(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> None:
336 | if self.fixed_point_reuse_solution:
337 | if not torch.all(t[0] == t).item():
338 |                 raise ValueError(f"fixed_point_reuse_solution requires all timesteps in a batch to be equal, got t={t}")
339 |
340 | def _forward_fixed_point(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> Float[Tensor, "b c h w"]:
341 | self._check_inputs(x, t, y)
342 | x: Float[Tensor, "b t d"] = self.x_embedder(x) + self.pos_embed
343 | t_emb: Float[Tensor, "b d"] = self.t_embedder(t)
344 | y: Float[Tensor, "b d"] = self.y_embedder(y, self.training)
345 | c: Float[Tensor, "b d"] = t_emb + y
346 | c_pre_post_fixed_point: Float[Tensor, "b d"] = (t_emb + y) if self.fixed_point_pre_post_timestep_conditioning else y
347 |
348 | # Pre-Fixed Point
349 | # Note: If using DDP with find_unused_parameters=True, checkpoint causes issues. For more
350 | # information, see https://github.com/allenai/longformer/issues/63#issuecomment-648861503
351 | for block in self.blocks_pre:
352 | x: Float[Tensor, "b t d"] = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)
353 | condition = x.clone()
354 |
355 | # Whether to reuse the previous solution at the next iteration
356 | init_solution = self.last_solution if (self.fixed_point_reuse_solution and self.last_solution is not None) else x.clone()
357 |
358 | # Fixed Point (we have condition and init_solution)
359 | x_input_injection = self.block_pre_projection(condition)
360 |
361 | # NOTE: This section of code should have no_grad, but cannot due to a DDP bug. See
362 | # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
363 | # for more information
364 | with nullcontext(): # we use x.detach() in place of torch.no_grad due to DDP issue
365 | num_iterations_no_grad = random.randint(self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters)
366 | x = self._forward_fixed_point_blocks(x=init_solution.detach(), x_input_injection=x_input_injection.detach(), c=c, num_iterations=num_iterations_no_grad)
367 | x = x.detach() # no grad
368 | num_iterations_with_grad = random.randint(self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters)
369 | x = self._forward_fixed_point_blocks(x=x, x_input_injection=x_input_injection, c=c, num_iterations=num_iterations_with_grad)
370 |
371 | # Save solution for reuse at next step
372 | if self.fixed_point_reuse_solution:
373 | self.last_solution = x.clone()
374 |
375 | # Post-Fixed Point
376 | for block in self.blocks_post:
377 | x = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)
378 |
379 | # Output
380 |         x: Float[Tensor, "b t p2c"] = self.final_layer(x, c_pre_post_fixed_point)  # p2c = patch_size ** 2 * out_channels
381 | x: Float[Tensor, "b c h w"] = self.unpatchify(x)
382 | return x
383 |
384 | def reset(self):
385 | self.last_solution = None
386 |
387 | def forward(self, x, t, y):
388 | """
389 | Forward pass of DiT.
390 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
391 | t: (N,) tensor of diffusion timesteps
392 | y: (N,) tensor of class labels
393 | """
394 | if self.fixed_point:
395 | return self._forward_fixed_point(x, t, y)
396 | else:
397 | return self._forward_dit(x, t, y)
398 |
399 | def forward_with_cfg(self, x, t, y, cfg_scale):
400 | """
401 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
402 | """
403 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
404 | half = x[: len(x) // 2]
405 | combined = torch.cat([half, half], dim=0)
406 | model_out = self.forward(combined, t, y)
407 | # For exact reproducibility reasons, we apply classifier-free guidance on only
408 | # three channels by default. The standard approach to cfg applies it to all channels.
409 | # This can be done by uncommenting the following line and commenting-out the line following that.
410 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
411 | eps, rest = model_out[:, :3], model_out[:, 3:]
412 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
413 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
414 | eps = torch.cat([half_eps, half_eps], dim=0)
415 | return torch.cat([eps, rest], dim=1)
416 |
417 |
418 | #################################################################################
419 | # Sine/Cosine Positional Embedding Functions #
420 | #################################################################################
421 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
422 |
423 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
424 | """
425 | grid_size: int of the grid height and width
426 | return:
427 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
428 | """
429 | grid_h = np.arange(grid_size, dtype=np.float32)
430 | grid_w = np.arange(grid_size, dtype=np.float32)
431 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
432 | grid = np.stack(grid, axis=0)
433 |
434 | grid = grid.reshape([2, 1, grid_size, grid_size])
435 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
436 | if cls_token and extra_tokens > 0:
437 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
438 | return pos_embed
439 |
440 |
441 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
442 | assert embed_dim % 2 == 0
443 |
444 | # use half of dimensions to encode grid_h
445 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
446 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
447 |
448 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
449 | return emb
450 |
451 |
452 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
453 | """
454 | embed_dim: output dimension for each position
455 | pos: a list of positions to be encoded: size (M,)
456 | out: (M, D)
457 | """
458 | assert embed_dim % 2 == 0
459 | omega = np.arange(embed_dim // 2, dtype=np.float64)
460 | omega /= embed_dim / 2.
461 | omega = 1. / 10000**omega # (D/2,)
462 |
463 | pos = pos.reshape(-1) # (M,)
464 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
465 |
466 | emb_sin = np.sin(out) # (M, D/2)
467 | emb_cos = np.cos(out) # (M, D/2)
468 |
469 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
470 | return emb
471 |
472 |
473 | #################################################################################
474 | # DiT Configs #
475 | #################################################################################
476 |
477 | def DiT_XL_2(**kwargs):
478 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
479 |
480 | def DiT_XL_4(**kwargs):
481 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
482 |
483 | def DiT_XL_8(**kwargs):
484 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
485 |
486 | def DiT_L_2(**kwargs):
487 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
488 |
489 | def DiT_L_4(**kwargs):
490 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
491 |
492 | def DiT_L_8(**kwargs):
493 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
494 |
495 | def DiT_B_2(**kwargs):
496 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
497 |
498 | def DiT_B_4(**kwargs):
499 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
500 |
501 | def DiT_B_8(**kwargs):
502 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
503 |
504 | def DiT_S_2(**kwargs):
505 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
506 |
507 | def DiT_S_4(**kwargs):
508 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
509 |
510 | def DiT_S_8(**kwargs):
511 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
512 |
513 |
514 | DiT_models = {
515 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
516 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
517 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
518 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
519 | }
520 |
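521 | 
522 | if __name__ == "__main__":
523 |     # Illustrative smoke test (an assumption for demonstration, not part of the training or
524 |     # sampling pipeline): build a small fixed-point DiT and run one forward pass on random
525 |     # latents. A 4-channel 32x32 latent corresponds to the SD VAE operating on 256x256 images.
526 |     model = DiT_S_2(
527 |         fixed_point=True,
528 |         fixed_point_pre_depth=1,
529 |         fixed_point_post_depth=1,
530 |         fixed_point_with_grad_min_iters=4,
531 |         fixed_point_with_grad_max_iters=4,
532 |         use_gradient_checkpointing=False,
533 |     )
534 |     model.eval()
535 |     x = torch.randn(2, 4, 32, 32)     # (N, C, H, W) latents
536 |     t = torch.randint(0, 1000, (2,))  # (N,) diffusion timesteps
537 |     y = torch.randint(0, 1000, (2,))  # (N,) class labels
538 |     with torch.no_grad():             # fine at inference; the DDP caveat above is training-only
539 |         out = model(x, t, y)
540 |     print(out.shape)                  # torch.Size([2, 8, 32, 32]) because learn_sigma=True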
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import torch
7 | from accelerate import Accelerator
8 | from accelerate.utils import set_seed
9 | from diffusers.models import AutoencoderKL
10 | from PIL import Image
11 | from tap import Tap
12 | from tqdm import trange
13 |
14 |
15 | from diffusion import create_diffusion
16 | from download import find_model
17 | from models import DiT_models
18 | from adaptive_controller import LinearController
19 |
20 |
21 | class Args(Tap):
22 |
23 | # Paths
24 | output_dir: str = 'samples'
25 |
26 | # Dataset
27 | dataset_name: str = "imagenet256"
28 |
29 | # Model
30 | model: str = "DiT-XL/2"
31 | vae: str = "mse"
32 | num_classes: int = 1000
33 | image_size: int = 256
34 | predict_v: bool = False
35 | use_zero_terminal_snr: bool = False
36 | unsupervised: bool = False
37 | dino_supervised: bool = False
38 | dino_supervised_dim: int = 768
39 | flow: bool = False
40 | debug: bool = False
41 |
42 | # Fixed Point settings
43 | fixed_point: bool = False
44 | fixed_point_pre_depth: int = 2
45 | fixed_point_post_depth: int = 2
46 | fixed_point_iters: Optional[int] = None
47 | fixed_point_pre_post_timestep_conditioning: bool = False
48 | fixed_point_reuse_solution: bool = False
49 |
50 | # Sampling
51 | ddim: bool = False
52 | cfg_scale: float = 4.0
53 | num_sampling_steps: int = 250
54 | batch_size: int = 32
55 | ckpt: str = '...'
56 | global_seed: int = 0
57 |
58 | # Parallelization
59 | sample_index_start: int = 0
60 | sample_index_end: Optional[int] = 50_000
61 |
62 | # Adaptive
63 | adaptive: bool = False
64 |     adaptive_type: str = "increasing"  # currently supports "increasing", "fixed", and "decreasing"
65 |
66 | def process_args(self) -> None:
67 | """Additional argument processing"""
68 | if self.debug:
69 | self.log_with = 'tensorboard'
70 | self.name = 'debug'
71 |
72 | # Defaults
73 | self.fixed_point_iters = self.fixed_point_iters or (28 - self.fixed_point_pre_depth - self.fixed_point_post_depth)
74 |
75 | # Checks
76 | if self.cfg_scale < 1.0:
77 | raise ValueError("In almost all cases, cfg_scale should be >= 1.0")
78 | if self.unsupervised:
79 | assert self.cfg_scale == 1.0
80 | self.num_classes = 1
81 | elif self.dino_supervised:
82 | raise NotImplementedError()
83 | if not Path(self.ckpt).is_file():
84 | raise ValueError(self.ckpt)
85 |
86 | # Create output directory
87 | output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name
88 | if self.debug:
89 | output_dirname = 'debug'
90 | else:
91 | output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'
92 | if self.fixed_point:
93 | output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'
94 | if self.ddim:
95 |             output_dirname += '--ddim'
96 | self.output_dir = str(output_parent / output_dirname)
97 | Path(self.output_dir).mkdir(exist_ok=True, parents=True)
98 |
99 |         self.iteration_controller = None
100 |         if self.adaptive:
101 |             self.budget = self.num_sampling_steps * self.fixed_point_iters
102 |             if self.adaptive_type not in ("increasing", "fixed", "decreasing"):
103 |                 raise NotImplementedError()
104 |             self.iteration_controller = LinearController(self.budget, self.num_sampling_steps, type=self.adaptive_type)
105 |
106 | def main(args: Args):
107 |
108 | # Setup accelerator, logging, randomness
109 | accelerator = Accelerator()
110 | set_seed(args.global_seed + args.sample_index_start)
111 |
112 | # Load model
113 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
114 | latent_size = H_lat = W_lat = args.image_size // 8
115 | # print(f"!!! Pre model init, fixed_point_reuse_solution: {args.fixed_point_reuse_solution}")
116 | model = DiT_models[args.model](
117 | input_size=latent_size,
118 | num_classes=(args.dino_supervised_dim if args.dino_supervised else args.num_classes),
119 | is_label_continuous=args.dino_supervised,
120 | class_dropout_prob=0,
121 | learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models
122 | use_gradient_checkpointing=False,
123 | fixed_point=args.fixed_point,
124 | fixed_point_pre_depth=args.fixed_point_pre_depth,
125 | fixed_point_post_depth=args.fixed_point_post_depth,
126 | fixed_point_no_grad_min_iters=0,
127 | fixed_point_no_grad_max_iters=0,
128 | fixed_point_with_grad_min_iters=args.fixed_point_iters,
129 | fixed_point_with_grad_max_iters=args.fixed_point_iters,
130 | fixed_point_reuse_solution=args.fixed_point_reuse_solution,
131 | fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning,
132 | adaptive=args.adaptive,
133 | iteration_controller=args.iteration_controller,
134 | ).to(accelerator.device)
135 | print(f'Loaded model with params: {sum(p.numel() for p in model.parameters()):_}')
136 |
137 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py
138 | state_dict = find_model(args.ckpt)
139 | model.load_state_dict(state_dict)
140 | model.eval()
141 | diffusion = create_diffusion(
142 | str(args.num_sampling_steps),
143 | use_flow=args.flow,
144 | predict_v=args.predict_v,
145 | use_zero_terminal_snr=args.use_zero_terminal_snr,
146 | )
147 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(accelerator.device).eval()
148 | using_cfg = args.cfg_scale > 1.0
149 |
150 | # Generate pseudorandom class labels and noises. Note that these are generated using the global
151 | # seed, which is shared between all processes. As a result, all processes will generate the same
152 | # list of class labels and noises. Then, we take a subset of these based on the `process_index`
153 | # of the current process.
154 | N = 50_000 # this assumes we will never sample more than 50K samples, which I think is reasonable
155 | generator = torch.Generator(device=accelerator.device)
156 | generator.manual_seed(args.global_seed)
157 |     class_labels = torch.randint(0, args.num_classes, size=(N,), device=accelerator.device, generator=generator)
158 | generator.manual_seed(args.global_seed)
159 | latents = torch.randn(N, model.in_channels, H_lat, W_lat, device=accelerator.device, generator=generator)
160 | class_labels = class_labels[args.sample_index_start:args.sample_index_end]
161 | latents = latents[args.sample_index_start:args.sample_index_end]
162 | indices = list(range(args.sample_index_start, args.sample_index_end))
163 | print(f'Using pseudorandom class labels and latents (start={args.sample_index_start} and end={args.sample_index_end})')
164 |
165 | # Create output path
166 | output_dir = Path(args.output_dir)
167 | args.save(output_dir / 'args.json')
168 | print(f'Saving samples to {output_dir.resolve()}')
169 |
170 | # Load class labels for helpful filenames
171 | if args.dataset_name == 'imagenet256':
172 | with open("utils/imagenet-labels.json", "r") as f:
173 | label_names: list[str] = json.load(f)
174 | label_names = [l.lower().replace(' ', '-').replace('\'', '') for l in label_names]
175 | elif args.unsupervised:
176 | assert args.cfg_scale == 1.0
177 | label_names = ["unlabeled"]
178 | else:
179 | raise NotImplementedError()
180 |
181 | # Disable gradient
182 | with torch.inference_mode():
183 |
184 | # Sample loop
185 | num_batches = math.ceil(len(class_labels) / args.batch_size)
186 | for batch_idx in trange(num_batches, disable=(not accelerator.is_main_process)):
187 |
188 | # Get pre-sampled inputs
189 | z = latents[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
190 | y = class_labels[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
191 | idxs = indices[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
192 | output_paths = [output_dir / f'{idx:05d}--{y_i:03d}--{label_names[y_i]}.png' for y_i, idx in zip(y.tolist(), idxs)]
193 |
194 | # Skip files that already exist
195 | if all(output_path.is_file() for output_path in output_paths):
196 | print(f'Files already exist (batch {batch_idx}). Skipping.')
197 | continue
198 |
199 | # Setup classifier-free guidance
200 | if using_cfg:
201 |                 y_null = torch.tensor([1000] * len(y), device=accelerator.device)  # len(y): the last batch can be smaller than batch_size
202 | y = torch.cat([y, y_null], 0)
203 | z = torch.cat([z, z], 0)
204 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
205 | sample_fn = model.forward_with_cfg
206 | else:
207 | model_kwargs = dict(y=y)
208 | sample_fn = model.forward
209 |
210 | if args.adaptive:
211 |                 model.iteration_controller.init_image()
212 |
213 | # Sample latent images
214 | sample_kwargs = dict(model=sample_fn, shape=z.shape, noise=z, clip_denoised=False, model_kwargs=model_kwargs,
215 | progress=False, device=accelerator.device)
216 | if args.ddim:
217 | samples = diffusion.ddim_sample_loop(**sample_kwargs)
218 | else:
219 | samples = diffusion.p_sample_loop(**sample_kwargs)
220 |
221 |
222 |
223 | if using_cfg:
224 | samples, _ = samples.chunk(2, dim=0)
225 |
226 |
227 | # Reset model (resets the initial solution to None)
228 | model.reset()
229 |
230 | # Decode latents
231 | samples = vae.decode(samples / vae.config.scaling_factor).sample
232 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
233 |
234 | # Save samples to disk as individual .png files
235 | for sample, output_path in zip(samples, output_paths):
236 | Image.fromarray(sample).save(output_path)
237 |
238 |
239 | if __name__ == "__main__":
240 | args = Args(explicit_bool=True).parse_args()
241 | main(args)
242 |
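243 | # Example invocation (illustrative only; the checkpoint path is a placeholder and the flags
244 | # map onto the Args fields above; explicit_bool=True means booleans take explicit values):
245 | #
246 | #   python sample.py --ckpt /path/to/checkpoints/0400000.pt \
247 | #       --model DiT-XL/2 --num_sampling_steps 250 --cfg_scale 4.0 --batch_size 32
248 | #
249 | # For a fixed-point checkpoint, also pass the matching architecture flags, e.g.
250 | #
251 | #   ... --fixed_point True --fixed_point_pre_depth 2 --fixed_point_post_depth 2 --fixed_point_iters 24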
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from collections import OrderedDict
4 | from copy import deepcopy
5 | from glob import glob
6 | from pathlib import Path
7 | from time import time
8 | from typing import Callable, Optional
9 |
10 | import numpy as np
11 | import torch
12 | from accelerate import Accelerator, DistributedDataParallelKwargs
13 | from accelerate.utils import ProjectConfiguration
14 | from PIL import Image
15 | from tap import Tap
16 | from torch.utils.data import DataLoader, Dataset
17 | from tqdm import tqdm
18 |
19 | from diffusion import create_diffusion
20 | from models import DiT_models
21 |
22 | try:
23 | from streaming import StreamingDataset
24 | except ImportError:
25 | StreamingDataset = Dataset
26 |
27 |
28 | class Args(Tap):
29 |
30 | # Paths
31 | feature_path: Optional[str] = None
32 | dataset_name: str = "imagenet256"
33 | name: Optional[str] = None
34 | output_dir: str = "results"
35 | output_subdir: str = "runs"
36 |
37 | # Model
38 | model: str = "DiT-XL/2"
39 | num_classes: int = 1000
40 | image_size: int = 256
41 | predict_v: bool = False
42 | use_zero_terminal_snr: bool = False
43 | unsupervised: bool = False
44 | dino_supervised: bool = False
45 | dino_supervised_dim: int = 768
46 | flow: bool = False
47 | fixed_point: bool = False
48 | fixed_point_pre_depth: int = 1
49 | fixed_point_post_depth: int = 1
50 | fixed_point_no_grad_min_iters: int = 0
51 | fixed_point_no_grad_max_iters: int = 10
52 | fixed_point_with_grad_min_iters: int = 1
53 | fixed_point_with_grad_max_iters: int = 12
54 | fixed_point_pre_post_timestep_conditioning: bool = False
55 |
56 | # Training
57 | epochs: int = 1400
58 | global_batch_size: int = 512
59 | global_seed: int = 0
60 | num_workers: int = 4
61 | log_every: int = 100
62 | ckpt_every: int = 100_000
63 | lr: float = 1e-4
64 | log_with: str = "wandb"
65 | resume: Optional[str] = None
66 | # use_streaming_dataset: bool = False
67 | compile: bool = False
68 | debug: bool = False
69 |
70 | def process_args(self) -> None:
71 | """Additional argument processing"""
72 | if self.debug:
73 | self.log_with = 'tensorboard'
74 | self.output_subdir = 'debug'
75 |
76 | # Auto-generated name
77 | if self.name is None:
78 | experiment_index = len(glob(os.path.join(self.output_dir, self.output_subdir, "*")))
79 | model_string_name = self.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
80 | model_string_name += '--flow' if self.flow else '--diff'
81 | model_string_name += '--unsupervised' if self.unsupervised else '--dino_supervised' if self.dino_supervised else ''
82 | model_string_name += f'--{self.lr:f}'
83 |             model_string_name += '--v' if self.predict_v else ''
84 |             model_string_name += '--zero_snr' if self.use_zero_terminal_snr else ''
85 | ngmin, ngmax = self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters
86 | gmin, gmax = self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters
87 | model_string_name += (f'--fixed_point-pre_depth-{self.fixed_point_pre_depth}-post_depth-{self.fixed_point_post_depth}' +
88 | f'-no_grad_iters-{ngmin:02d}-{ngmax:02d}-with_grad_iters-{gmin:02d}-{gmax:02d}' +
89 | f'-pre_post_time_cond_{self.fixed_point_pre_post_timestep_conditioning}'
90 | if self.fixed_point else '--dit')
91 | self.name = f'{experiment_index:03d}-{model_string_name}'
92 |
93 | # Copy data to scratch
94 | if self.feature_path is None:
95 | assert os.getenv('SLURM_JOBID', None) is not None
96 | os.environ['TMPDIR'] = TMPDIR = os.path.join('/opt/dlami/nvme/slurm_tmpdir', os.getenv('SLURM_JOBID'))
97 |             self.feature_path = TMPDIR
98 | features_dir = f"{self.feature_path}/{self.dataset_name}_features"
99 | labels_dir = f"{self.feature_path}/{self.dataset_name}_{'dino_vitb8' if self.dino_supervised else 'labels'}"
100 | assert Path(features_dir).is_dir() == Path(labels_dir).is_dir()
101 | if Path(features_dir).is_dir():
102 | print(f'Features already exist in {TMPDIR}')
103 | else:
104 | start = time()
105 | print(f'Copying features to {TMPDIR}')
106 | copy_cmd_1 = f'cp ./features/{self.dataset_name}_npy.tar {TMPDIR}'
107 | copy_cmd_2 = f'tar xf {os.path.join(TMPDIR, self.dataset_name)}_npy.tar -C {TMPDIR}'
108 | print(copy_cmd_1)
109 | os.system(copy_cmd_1)
110 | print(copy_cmd_2)
111 | os.system(copy_cmd_2)
112 | print(f'Finished copying features to {TMPDIR} in {time() - start:.2f}s')
113 |
114 | # Create output directory
115 | self.output_dir = os.path.join(self.output_dir, self.output_subdir, self.name)
116 | Path(self.output_dir).mkdir(exist_ok=True, parents=True)
117 |
118 | #################################################################################
119 | # Training Helper Functions #
120 | #################################################################################
121 |
122 | @torch.no_grad()
123 | def update_ema(ema_model, model, decay=0.9999):
124 | """
125 | Step the EMA model towards the current model.
126 | """
127 | ema_params = OrderedDict(ema_model.named_parameters())
128 | model_params = OrderedDict(model.named_parameters())
129 |
130 | for name, param in model_params.items():
131 | name = name.replace("module.", "")
132 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
133 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
134 |
135 |
136 | def requires_grad(model, flag=True):
137 | """
138 | Set requires_grad flag for all parameters in a model.
139 | """
140 | for p in model.parameters():
141 | p.requires_grad = flag
142 |
143 |
144 | def create_logger(logging_dir):
145 | """
146 | Create a logger that writes to a log file and stdout.
147 | """
148 | logging.basicConfig(
149 | level=logging.INFO,
150 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
151 | datefmt='%Y-%m-%d %H:%M:%S',
152 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
153 | )
154 | logger = logging.getLogger(__name__)
155 | return logger
156 |
157 |
158 | def center_crop_arr(pil_image, image_size):
159 | """
160 | Center cropping implementation from ADM.
161 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
162 | """
163 | while min(*pil_image.size) >= 2 * image_size:
164 | pil_image = pil_image.resize(
165 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
166 | )
167 |
168 | scale = image_size / min(*pil_image.size)
169 | pil_image = pil_image.resize(
170 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
171 | )
172 |
173 | arr = np.array(pil_image)
174 | crop_y = (arr.shape[0] - image_size) // 2
175 | crop_x = (arr.shape[1] - image_size) // 2
176 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
177 |
178 |
179 | class CustomDataset(Dataset):
180 | def __init__(self, features_dir, labels_dir):
181 | self.features_dir = features_dir
182 | self.labels_dir = labels_dir
183 |
184 |         self.features_files = sorted(os.listdir(features_dir))  # sort so feature/label files pair up by index
185 |         self.labels_files = sorted(os.listdir(labels_dir))
186 |
187 | def __len__(self):
188 | assert len(self.features_files) == len(self.labels_files), \
189 |             "Number of feature files and label files should be the same"
190 | return len(self.features_files)
191 |
192 | def __getitem__(self, idx):
193 | feature_file = self.features_files[idx]
194 | label_file = self.labels_files[idx]
195 |
196 | features = np.load(os.path.join(self.features_dir, feature_file))
197 | labels = np.load(os.path.join(self.labels_dir, label_file))
198 | return torch.from_numpy(features), torch.from_numpy(labels)
199 |
200 |
201 | class CustomStreamingDataset(StreamingDataset):
202 | def __init__(
203 | self,
204 | local: str,
205 | remote: Optional[str] = None,
206 | shuffle: bool = False,
207 | batch_size: int = 1,
208 | transform: Optional[Callable] = None,
209 | ):
210 | remote = local if remote is None else remote
211 | super().__init__(remote=remote, local=local, shuffle=shuffle, batch_size=batch_size)
212 | self.transform = transform
213 |
214 | def __getitem__(self, idx):
215 | item = super().__getitem__(idx)
216 | feats = item['features'].squeeze(0)
217 | label = item['class']
218 | if self.transform is not None:
219 | feats = self.transform(feats)
220 | return feats, label
221 |
222 |
223 | #################################################################################
224 | # Training Loop #
225 | #################################################################################
226 |
227 | def main(args: Args):
228 | """
229 | Trains a new DiT model.
230 | """
231 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
232 |
233 | # Setup an experiment folder:
234 | checkpoint_dir = f"{args.output_dir}/checkpoints" # Stores saved model checkpoints
235 | os.makedirs(checkpoint_dir, exist_ok=True)
236 | args.save(f"{args.output_dir}/args.json")
237 |
238 | # Setup accelerator:
239 | find_unused_parameters = False # args.fixed_point and args.fixed_point_b_solver != 'backprop'
240 | print(f'Using {find_unused_parameters = }')
241 | accelerator_ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)
242 | accelerator = Accelerator(
243 | log_with=args.log_with,
244 | project_config=ProjectConfiguration(project_dir=args.output_dir),
245 | kwargs_handlers=[accelerator_ddp_kwargs],
246 | dynamo_backend=("inductor" if args.compile else None),
247 | )
248 | device = accelerator.device
249 |
250 | # Create trackers
251 | if accelerator.is_main_process:
252 | accelerator.init_trackers("dit", config=args.as_dict(), init_kwargs={"wandb": {"name": args.name}})
253 | if args.log_with == 'wandb':
254 | accelerator.get_tracker("wandb", unwrap=True).log_code()
255 | print(args)
256 |
257 | # Create model
258 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
259 | latent_size = args.image_size // 8
260 | model = DiT_models[args.model](
261 | input_size=latent_size,
262 | num_classes=(1 if args.unsupervised else (args.dino_supervised_dim if args.dino_supervised else args.num_classes)),
263 | is_label_continuous=args.dino_supervised,
264 | class_dropout_prob=(0.0 if args.unsupervised else 0.1),
265 | learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models
266 | use_gradient_checkpointing=(not args.compile and not find_unused_parameters),
267 | fixed_point=args.fixed_point,
268 | fixed_point_pre_depth=args.fixed_point_pre_depth,
269 | fixed_point_post_depth=args.fixed_point_post_depth,
270 | fixed_point_no_grad_min_iters=args.fixed_point_no_grad_min_iters,
271 | fixed_point_no_grad_max_iters=args.fixed_point_no_grad_max_iters,
272 | fixed_point_with_grad_min_iters=args.fixed_point_with_grad_min_iters,
273 | fixed_point_with_grad_max_iters=args.fixed_point_with_grad_max_iters,
274 | fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning,
275 | ).to(device)
276 | print(f'Loaded model with params: {sum(p.numel() for p in model.parameters()):_}')
277 |
278 | # Note that parameter initialization is done within the DiT constructor
279 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
280 | requires_grad(ema, False)
281 | diffusion = create_diffusion(
282 | timestep_respacing="",
283 | use_flow=args.flow,
284 | predict_v=args.predict_v,
285 | use_zero_terminal_snr=args.use_zero_terminal_snr,
286 | ) # default: 1000 steps, linear noise schedule
287 | # # Note: the VAE is not used because we assume all images are already preprocessed
288 | # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
289 | print(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
290 |
291 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
292 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0)
293 |
294 | # Setup data:
295 | batch_size = int(args.global_batch_size // accelerator.num_processes)
296 | # if args.use_streaming_dataset:
297 | # data_dir = f"{args.feature_path}/{args.dataset_name}_streaming"
298 | # dataset = CustomStreamingDataset(data_dir, shuffle=True, batch_size=batch_size)
299 | # load_kwargs = dict()
300 | # else:
301 | features_dir = f"{args.feature_path}/{args.dataset_name}_features"
302 | labels_dir = f"{args.feature_path}/{args.dataset_name}_{'dino_vitb8' if args.dino_supervised else 'labels'}"
303 | dataset = CustomDataset(features_dir, labels_dir)
304 | load_kwargs = dict(shuffle=True, pin_memory=True, drop_last=True)
305 | loader = DataLoader(
306 | dataset, batch_size=batch_size, num_workers=args.num_workers, **load_kwargs
307 | )
308 | print(f"Dataset contains {len(dataset):,} images ({args.feature_path})")
309 |
310 | # Load from checkpoint
311 | train_steps = 0
312 | if args.resume is not None:
313 | checkpoint: dict = torch.load(args.resume, map_location='cpu')
314 | model.load_state_dict(checkpoint['model'])
315 | ema.load_state_dict(checkpoint['ema'])
316 | opt.load_state_dict(checkpoint['opt'])
317 | train_steps = checkpoint['train_steps'] if 'train_steps' in checkpoint else int(Path(args.resume).stem)
318 | print(f'Resuming from checkpoint: {args.resume}')
319 |
320 | # Prepare models for training:
321 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
322 | model.train() # important! This enables embedding dropout for classifier-free guidance
323 | ema.eval() # EMA model should always be in eval mode
324 | model, opt, loader = accelerator.prepare(model, opt, loader)
325 |
326 | # Train
327 | log_steps = 0
328 | running_loss = 0
329 | steps_per_sec = 0
330 | start_time = time()
331 | progress_bar = tqdm()
332 | print(f"Training for {args.epochs} epochs...")
333 | for epoch in range(args.epochs):
334 | if accelerator.is_main_process:
335 | print(f"Beginning epoch {epoch}...")
336 | for x, y in loader:
337 | x = x.to(device)
338 | y = y.to(device)
339 | x = x.squeeze(dim=1)
340 | y = y.squeeze(dim=-1)
341 | if args.unsupervised: # replace class labels with zeros
342 | y = torch.zeros_like(y)
343 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
344 | model_kwargs = dict(y=y)
345 | loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
346 | loss = loss_dict["loss"].mean()
347 | loss_float = loss.item()
348 | opt.zero_grad()
349 | accelerator.backward(loss)
350 | if train_steps < 5 and accelerator.is_main_process: # debug
351 | print(f'[Step {train_steps}] Params total: {sum(p.numel() for p in model.parameters()):_}')
352 | print(f'[Step {train_steps}] Params req. grad: {sum(p.numel() for p in model.parameters() if p.requires_grad):_}')
353 | print(f'[Step {train_steps}] Params with grad: {sum(p.numel() for p in model.parameters() if p.requires_grad and p.grad is not None):_}')
354 | opt.step()
355 | update_ema(ema, model)
356 |
357 |             # Log every few steps
358 | if train_steps % 5 == 0 and accelerator.is_main_process:
359 | accelerator.log({
360 | "train/step": train_steps, "train/loss": loss_float,
361 | }, step=train_steps)
362 | progress_bar.set_description_str((f"train/step: {train_steps}, train/steps_per_sec: {steps_per_sec:.2f}, train/loss: {loss_float:.4f}"))
363 |
364 | # Print periodically
365 | running_loss += loss_float
366 | log_steps += 1
367 | train_steps += 1
368 | progress_bar.update()
369 | if train_steps % args.log_every == 0:
370 | # Measure training speed:
371 | torch.cuda.synchronize()
372 | end_time = time()
373 | steps_per_sec = log_steps / (end_time - start_time)
374 | # Reduce loss history over all processes:
375 | avg_loss = torch.tensor(running_loss / log_steps, device=device)
376 |                 avg_loss = accelerator.reduce(avg_loss, reduction="mean").item()
377 | print(f"\n(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
378 | accelerator.log({"train/steps_per_sec": steps_per_sec}, step=train_steps) # also log steps per second
379 | # Reset monitoring variables:
380 | running_loss = 0
381 | log_steps = 0
382 | start_time = time()
383 |
384 | # Save DiT checkpoint:
385 | if train_steps > 0 and accelerator.is_main_process and (train_steps % 5000 == 0 or train_steps % args.ckpt_every == 0):
386 | checkpoint = {
387 |                     "model": accelerator.unwrap_model(model).state_dict(),
388 | "ema": ema.state_dict(),
389 | "opt": opt.state_dict(),
390 | "args": args.as_dict(),
391 | "train_steps": train_steps,
392 | }
393 | if train_steps % 5000 == 0:
394 | checkpoint_path = f"{checkpoint_dir}/latest.pt"
395 | torch.save(checkpoint, checkpoint_path)
396 | if train_steps % args.ckpt_every == 0:
397 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
398 | torch.save(checkpoint, checkpoint_path)
399 | print(f"Saved checkpoint to {checkpoint_path}")
400 |
401 |         model.eval()  # disables randomized embedding dropout; do any sampling/FID/etc. with ema (or model) here ...
402 |         model.train()  # important! Re-enable embedding dropout before the next epoch
403 |
404 | accelerator.end_training()
405 | print("Done!")
406 |
407 |
408 | if __name__ == "__main__":
409 | torch.backends.cuda.matmul.allow_tf32 = True
410 | torch.backends.cudnn.allow_tf32 = True
411 | args = Args(explicit_bool=True).parse_args()
412 | main(args)
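413 | 
414 | # Example launch (illustrative only): training is driven by accelerate, so a typical run
415 | # looks like
416 | #
417 | #   accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml train.py \
418 | #       --model DiT-XL/2 --feature_path features --global_batch_size 256
419 | #
420 | # where feature_path points at a directory containing the precomputed
421 | # {dataset_name}_features and {dataset_name}_labels folders read by CustomDataset.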
--------------------------------------------------------------------------------
/utils/imagenet-labels.json:
--------------------------------------------------------------------------------
1 | ["tench",
2 | "goldfish",
3 | "great white shark",
4 | "tiger shark",
5 | "hammerhead shark",
6 | "electric ray",
7 | "stingray",
8 | "cock",
9 | "hen",
10 | "ostrich",
11 | "brambling",
12 | "goldfinch",
13 | "house finch",
14 | "junco",
15 | "indigo bunting",
16 | "American robin",
17 | "bulbul",
18 | "jay",
19 | "magpie",
20 | "chickadee",
21 | "American dipper",
22 | "kite",
23 | "bald eagle",
24 | "vulture",
25 | "great grey owl",
26 | "fire salamander",
27 | "smooth newt",
28 | "newt",
29 | "spotted salamander",
30 | "axolotl",
31 | "American bullfrog",
32 | "tree frog",
33 | "tailed frog",
34 | "loggerhead sea turtle",
35 | "leatherback sea turtle",
36 | "mud turtle",
37 | "terrapin",
38 | "box turtle",
39 | "banded gecko",
40 | "green iguana",
41 | "Carolina anole",
42 | "desert grassland whiptail lizard",
43 | "agama",
44 | "frilled-necked lizard",
45 | "alligator lizard",
46 | "Gila monster",
47 | "European green lizard",
48 | "chameleon",
49 | "Komodo dragon",
50 | "Nile crocodile",
51 | "American alligator",
52 | "triceratops",
53 | "worm snake",
54 | "ring-necked snake",
55 | "eastern hog-nosed snake",
56 | "smooth green snake",
57 | "kingsnake",
58 | "garter snake",
59 | "water snake",
60 | "vine snake",
61 | "night snake",
62 | "boa constrictor",
63 | "African rock python",
64 | "Indian cobra",
65 | "green mamba",
66 | "sea snake",
67 | "Saharan horned viper",
68 | "eastern diamondback rattlesnake",
69 | "sidewinder",
70 | "trilobite",
71 | "harvestman",
72 | "scorpion",
73 | "yellow garden spider",
74 | "barn spider",
75 | "European garden spider",
76 | "southern black widow",
77 | "tarantula",
78 | "wolf spider",
79 | "tick",
80 | "centipede",
81 | "black grouse",
82 | "ptarmigan",
83 | "ruffed grouse",
84 | "prairie grouse",
85 | "peacock",
86 | "quail",
87 | "partridge",
88 | "grey parrot",
89 | "macaw",
90 | "sulphur-crested cockatoo",
91 | "lorikeet",
92 | "coucal",
93 | "bee eater",
94 | "hornbill",
95 | "hummingbird",
96 | "jacamar",
97 | "toucan",
98 | "duck",
99 | "red-breasted merganser",
100 | "goose",
101 | "black swan",
102 | "tusker",
103 | "echidna",
104 | "platypus",
105 | "wallaby",
106 | "koala",
107 | "wombat",
108 | "jellyfish",
109 | "sea anemone",
110 | "brain coral",
111 | "flatworm",
112 | "nematode",
113 | "conch",
114 | "snail",
115 | "slug",
116 | "sea slug",
117 | "chiton",
118 | "chambered nautilus",
119 | "Dungeness crab",
120 | "rock crab",
121 | "fiddler crab",
122 | "red king crab",
123 | "American lobster",
124 | "spiny lobster",
125 | "crayfish",
126 | "hermit crab",
127 | "isopod",
128 | "white stork",
129 | "black stork",
130 | "spoonbill",
131 | "flamingo",
132 | "little blue heron",
133 | "great egret",
134 | "bittern",
135 | "crane (bird)",
136 | "limpkin",
137 | "common gallinule",
138 | "American coot",
139 | "bustard",
140 | "ruddy turnstone",
141 | "dunlin",
142 | "common redshank",
143 | "dowitcher",
144 | "oystercatcher",
145 | "pelican",
146 | "king penguin",
147 | "albatross",
148 | "grey whale",
149 | "killer whale",
150 | "dugong",
151 | "sea lion",
152 | "Chihuahua",
153 | "Japanese Chin",
154 | "Maltese",
155 | "Pekingese",
156 | "Shih Tzu",
157 | "King Charles Spaniel",
158 | "Papillon",
159 | "toy terrier",
160 | "Rhodesian Ridgeback",
161 | "Afghan Hound",
162 | "Basset Hound",
163 | "Beagle",
164 | "Bloodhound",
165 | "Bluetick Coonhound",
166 | "Black and Tan Coonhound",
167 | "Treeing Walker Coonhound",
168 | "English foxhound",
169 | "Redbone Coonhound",
170 | "borzoi",
171 | "Irish Wolfhound",
172 | "Italian Greyhound",
173 | "Whippet",
174 | "Ibizan Hound",
175 | "Norwegian Elkhound",
176 | "Otterhound",
177 | "Saluki",
178 | "Scottish Deerhound",
179 | "Weimaraner",
180 | "Staffordshire Bull Terrier",
181 | "American Staffordshire Terrier",
182 | "Bedlington Terrier",
183 | "Border Terrier",
184 | "Kerry Blue Terrier",
185 | "Irish Terrier",
186 | "Norfolk Terrier",
187 | "Norwich Terrier",
188 | "Yorkshire Terrier",
189 | "Wire Fox Terrier",
190 | "Lakeland Terrier",
191 | "Sealyham Terrier",
192 | "Airedale Terrier",
193 | "Cairn Terrier",
194 | "Australian Terrier",
195 | "Dandie Dinmont Terrier",
196 | "Boston Terrier",
197 | "Miniature Schnauzer",
198 | "Giant Schnauzer",
199 | "Standard Schnauzer",
200 | "Scottish Terrier",
201 | "Tibetan Terrier",
202 | "Australian Silky Terrier",
203 | "Soft-coated Wheaten Terrier",
204 | "West Highland White Terrier",
205 | "Lhasa Apso",
206 | "Flat-Coated Retriever",
207 | "Curly-coated Retriever",
208 | "Golden Retriever",
209 | "Labrador Retriever",
210 | "Chesapeake Bay Retriever",
211 | "German Shorthaired Pointer",
212 | "Vizsla",
213 | "English Setter",
214 | "Irish Setter",
215 | "Gordon Setter",
216 | "Brittany",
217 | "Clumber Spaniel",
218 | "English Springer Spaniel",
219 | "Welsh Springer Spaniel",
220 | "Cocker Spaniels",
221 | "Sussex Spaniel",
222 | "Irish Water Spaniel",
223 | "Kuvasz",
224 | "Schipperke",
225 | "Groenendael",
226 | "Malinois",
227 | "Briard",
228 | "Australian Kelpie",
229 | "Komondor",
230 | "Old English Sheepdog",
231 | "Shetland Sheepdog",
232 | "collie",
233 | "Border Collie",
234 | "Bouvier des Flandres",
235 | "Rottweiler",
236 | "German Shepherd Dog",
237 | "Dobermann",
238 | "Miniature Pinscher",
239 | "Greater Swiss Mountain Dog",
240 | "Bernese Mountain Dog",
241 | "Appenzeller Sennenhund",
242 | "Entlebucher Sennenhund",
243 | "Boxer",
244 | "Bullmastiff",
245 | "Tibetan Mastiff",
246 | "French Bulldog",
247 | "Great Dane",
248 | "St. Bernard",
249 | "husky",
250 | "Alaskan Malamute",
251 | "Siberian Husky",
252 | "Dalmatian",
253 | "Affenpinscher",
254 | "Basenji",
255 | "pug",
256 | "Leonberger",
257 | "Newfoundland",
258 | "Pyrenean Mountain Dog",
259 | "Samoyed",
260 | "Pomeranian",
261 | "Chow Chow",
262 | "Keeshond",
263 | "Griffon Bruxellois",
264 | "Pembroke Welsh Corgi",
265 | "Cardigan Welsh Corgi",
266 | "Toy Poodle",
267 | "Miniature Poodle",
268 | "Standard Poodle",
269 | "Mexican hairless dog",
270 | "grey wolf",
271 | "Alaskan tundra wolf",
272 | "red wolf",
273 | "coyote",
274 | "dingo",
275 | "dhole",
276 | "African wild dog",
277 | "hyena",
278 | "red fox",
279 | "kit fox",
280 | "Arctic fox",
281 | "grey fox",
282 | "tabby cat",
283 | "tiger cat",
284 | "Persian cat",
285 | "Siamese cat",
286 | "Egyptian Mau",
287 | "cougar",
288 | "lynx",
289 | "leopard",
290 | "snow leopard",
291 | "jaguar",
292 | "lion",
293 | "tiger",
294 | "cheetah",
295 | "brown bear",
296 | "American black bear",
297 | "polar bear",
298 | "sloth bear",
299 | "mongoose",
300 | "meerkat",
301 | "tiger beetle",
302 | "ladybug",
303 | "ground beetle",
304 | "longhorn beetle",
305 | "leaf beetle",
306 | "dung beetle",
307 | "rhinoceros beetle",
308 | "weevil",
309 | "fly",
310 | "bee",
311 | "ant",
312 | "grasshopper",
313 | "cricket",
314 | "stick insect",
315 | "cockroach",
316 | "mantis",
317 | "cicada",
318 | "leafhopper",
319 | "lacewing",
320 | "dragonfly",
321 | "damselfly",
322 | "red admiral",
323 | "ringlet",
324 | "monarch butterfly",
325 | "small white",
326 | "sulphur butterfly",
327 | "gossamer-winged butterfly",
328 | "starfish",
329 | "sea urchin",
330 | "sea cucumber",
331 | "cottontail rabbit",
332 | "hare",
333 | "Angora rabbit",
334 | "hamster",
335 | "porcupine",
336 | "fox squirrel",
337 | "marmot",
338 | "beaver",
339 | "guinea pig",
340 | "common sorrel",
341 | "zebra",
342 | "pig",
343 | "wild boar",
344 | "warthog",
345 | "hippopotamus",
346 | "ox",
347 | "water buffalo",
348 | "bison",
349 | "ram",
350 | "bighorn sheep",
351 | "Alpine ibex",
352 | "hartebeest",
353 | "impala",
354 | "gazelle",
355 | "dromedary",
356 | "llama",
357 | "weasel",
358 | "mink",
359 | "European polecat",
360 | "black-footed ferret",
361 | "otter",
362 | "skunk",
363 | "badger",
364 | "armadillo",
365 | "three-toed sloth",
366 | "orangutan",
367 | "gorilla",
368 | "chimpanzee",
369 | "gibbon",
370 | "siamang",
371 | "guenon",
372 | "patas monkey",
373 | "baboon",
374 | "macaque",
375 | "langur",
376 | "black-and-white colobus",
377 | "proboscis monkey",
378 | "marmoset",
379 | "white-headed capuchin",
380 | "howler monkey",
381 | "titi",
382 | "Geoffroy's spider monkey",
383 | "common squirrel monkey",
384 | "ring-tailed lemur",
385 | "indri",
386 | "Asian elephant",
387 | "African bush elephant",
388 | "red panda",
389 | "giant panda",
390 | "snoek",
391 | "eel",
392 | "coho salmon",
393 | "rock beauty",
394 | "clownfish",
395 | "sturgeon",
396 | "garfish",
397 | "lionfish",
398 | "pufferfish",
399 | "abacus",
400 | "abaya",
401 | "academic gown",
402 | "accordion",
403 | "acoustic guitar",
404 | "aircraft carrier",
405 | "airliner",
406 | "airship",
407 | "altar",
408 | "ambulance",
409 | "amphibious vehicle",
410 | "analog clock",
411 | "apiary",
412 | "apron",
413 | "waste container",
414 | "assault rifle",
415 | "backpack",
416 | "bakery",
417 | "balance beam",
418 | "balloon",
419 | "ballpoint pen",
420 | "Band-Aid",
421 | "banjo",
422 | "baluster",
423 | "barbell",
424 | "barber chair",
425 | "barbershop",
426 | "barn",
427 | "barometer",
428 | "barrel",
429 | "wheelbarrow",
430 | "baseball",
431 | "basketball",
432 | "bassinet",
433 | "bassoon",
434 | "swimming cap",
435 | "bath towel",
436 | "bathtub",
437 | "station wagon",
438 | "lighthouse",
439 | "beaker",
440 | "military cap",
441 | "beer bottle",
442 | "beer glass",
443 | "bell-cot",
444 | "bib",
445 | "tandem bicycle",
446 | "bikini",
447 | "ring binder",
448 | "binoculars",
449 | "birdhouse",
450 | "boathouse",
451 | "bobsleigh",
452 | "bolo tie",
453 | "poke bonnet",
454 | "bookcase",
455 | "bookstore",
456 | "bottle cap",
457 | "bow",
458 | "bow tie",
459 | "brass",
460 | "bra",
461 | "breakwater",
462 | "breastplate",
463 | "broom",
464 | "bucket",
465 | "buckle",
466 | "bulletproof vest",
467 | "high-speed train",
468 | "butcher shop",
469 | "taxicab",
470 | "cauldron",
471 | "candle",
472 | "cannon",
473 | "canoe",
474 | "can opener",
475 | "cardigan",
476 | "car mirror",
477 | "carousel",
478 | "tool kit",
479 | "carton",
480 | "car wheel",
481 | "automated teller machine",
482 | "cassette",
483 | "cassette player",
484 | "castle",
485 | "catamaran",
486 | "CD player",
487 | "cello",
488 | "mobile phone",
489 | "chain",
490 | "chain-link fence",
491 | "chain mail",
492 | "chainsaw",
493 | "chest",
494 | "chiffonier",
495 | "chime",
496 | "china cabinet",
497 | "Christmas stocking",
498 | "church",
499 | "movie theater",
500 | "cleaver",
501 | "cliff dwelling",
502 | "cloak",
503 | "clogs",
504 | "cocktail shaker",
505 | "coffee mug",
506 | "coffeemaker",
507 | "coil",
508 | "combination lock",
509 | "computer keyboard",
510 | "confectionery store",
511 | "container ship",
512 | "convertible",
513 | "corkscrew",
514 | "cornet",
515 | "cowboy boot",
516 | "cowboy hat",
517 | "cradle",
518 | "crane (machine)",
519 | "crash helmet",
520 | "crate",
521 | "infant bed",
522 | "Crock Pot",
523 | "croquet ball",
524 | "crutch",
525 | "cuirass",
526 | "dam",
527 | "desk",
528 | "desktop computer",
529 | "rotary dial telephone",
530 | "diaper",
531 | "digital clock",
532 | "digital watch",
533 | "dining table",
534 | "dishcloth",
535 | "dishwasher",
536 | "disc brake",
537 | "dock",
538 | "dog sled",
539 | "dome",
540 | "doormat",
541 | "drilling rig",
542 | "drum",
543 | "drumstick",
544 | "dumbbell",
545 | "Dutch oven",
546 | "electric fan",
547 | "electric guitar",
548 | "electric locomotive",
549 | "entertainment center",
550 | "envelope",
551 | "espresso machine",
552 | "face powder",
553 | "feather boa",
554 | "filing cabinet",
555 | "fireboat",
556 | "fire engine",
557 | "fire screen sheet",
558 | "flagpole",
559 | "flute",
560 | "folding chair",
561 | "football helmet",
562 | "forklift",
563 | "fountain",
564 | "fountain pen",
565 | "four-poster bed",
566 | "freight car",
567 | "French horn",
568 | "frying pan",
569 | "fur coat",
570 | "garbage truck",
571 | "gas mask",
572 | "gas pump",
573 | "goblet",
574 | "go-kart",
575 | "golf ball",
576 | "golf cart",
577 | "gondola",
578 | "gong",
579 | "gown",
580 | "grand piano",
581 | "greenhouse",
582 | "grille",
583 | "grocery store",
584 | "guillotine",
585 | "barrette",
586 | "hair spray",
587 | "half-track",
588 | "hammer",
589 | "hamper",
590 | "hair dryer",
591 | "hand-held computer",
592 | "handkerchief",
593 | "hard disk drive",
594 | "harmonica",
595 | "harp",
596 | "harvester",
597 | "hatchet",
598 | "holster",
599 | "home theater",
600 | "honeycomb",
601 | "hook",
602 | "hoop skirt",
603 | "horizontal bar",
604 | "horse-drawn vehicle",
605 | "hourglass",
606 | "iPod",
607 | "clothes iron",
608 | "jack-o'-lantern",
609 | "jeans",
610 | "jeep",
611 | "T-shirt",
612 | "jigsaw puzzle",
613 | "pulled rickshaw",
614 | "joystick",
615 | "kimono",
616 | "knee pad",
617 | "knot",
618 | "lab coat",
619 | "ladle",
620 | "lampshade",
621 | "laptop computer",
622 | "lawn mower",
623 | "lens cap",
624 | "paper knife",
625 | "library",
626 | "lifeboat",
627 | "lighter",
628 | "limousine",
629 | "ocean liner",
630 | "lipstick",
631 | "slip-on shoe",
632 | "lotion",
633 | "speaker",
634 | "loupe",
635 | "sawmill",
636 | "magnetic compass",
637 | "mail bag",
638 | "mailbox",
639 | "tights",
640 | "tank suit",
641 | "manhole cover",
642 | "maraca",
643 | "marimba",
644 | "mask",
645 | "match",
646 | "maypole",
647 | "maze",
648 | "measuring cup",
649 | "medicine chest",
650 | "megalith",
651 | "microphone",
652 | "microwave oven",
653 | "military uniform",
654 | "milk can",
655 | "minibus",
656 | "miniskirt",
657 | "minivan",
658 | "missile",
659 | "mitten",
660 | "mixing bowl",
661 | "mobile home",
662 | "Model T",
663 | "modem",
664 | "monastery",
665 | "monitor",
666 | "moped",
667 | "mortar",
668 | "square academic cap",
669 | "mosque",
670 | "mosquito net",
671 | "scooter",
672 | "mountain bike",
673 | "tent",
674 | "computer mouse",
675 | "mousetrap",
676 | "moving van",
677 | "muzzle",
678 | "nail",
679 | "neck brace",
680 | "necklace",
681 | "nipple",
682 | "notebook computer",
683 | "obelisk",
684 | "oboe",
685 | "ocarina",
686 | "odometer",
687 | "oil filter",
688 | "organ",
689 | "oscilloscope",
690 | "overskirt",
691 | "bullock cart",
692 | "oxygen mask",
693 | "packet",
694 | "paddle",
695 | "paddle wheel",
696 | "padlock",
697 | "paintbrush",
698 | "pajamas",
699 | "palace",
700 | "pan flute",
701 | "paper towel",
702 | "parachute",
703 | "parallel bars",
704 | "park bench",
705 | "parking meter",
706 | "passenger car",
707 | "patio",
708 | "payphone",
709 | "pedestal",
710 | "pencil case",
711 | "pencil sharpener",
712 | "perfume",
713 | "Petri dish",
714 | "photocopier",
715 | "plectrum",
716 | "Pickelhaube",
717 | "picket fence",
718 | "pickup truck",
719 | "pier",
720 | "piggy bank",
721 | "pill bottle",
722 | "pillow",
723 | "ping-pong ball",
724 | "pinwheel",
725 | "pirate ship",
726 | "pitcher",
727 | "hand plane",
728 | "planetarium",
729 | "plastic bag",
730 | "plate rack",
731 | "plow",
732 | "plunger",
733 | "Polaroid camera",
734 | "pole",
735 | "police van",
736 | "poncho",
737 | "billiard table",
738 | "soda bottle",
739 | "pot",
740 | "potter's wheel",
741 | "power drill",
742 | "prayer rug",
743 | "printer",
744 | "prison",
745 | "projectile",
746 | "projector",
747 | "hockey puck",
748 | "punching bag",
749 | "purse",
750 | "quill",
751 | "quilt",
752 | "race car",
753 | "racket",
754 | "radiator",
755 | "radio",
756 | "radio telescope",
757 | "rain barrel",
758 | "recreational vehicle",
759 | "reel",
760 | "reflex camera",
761 | "refrigerator",
762 | "remote control",
763 | "restaurant",
764 | "revolver",
765 | "rifle",
766 | "rocking chair",
767 | "rotisserie",
768 | "eraser",
769 | "rugby ball",
770 | "ruler",
771 | "running shoe",
772 | "safe",
773 | "safety pin",
774 | "salt shaker",
775 | "sandal",
776 | "sarong",
777 | "saxophone",
778 | "scabbard",
779 | "weighing scale",
780 | "school bus",
781 | "schooner",
782 | "scoreboard",
783 | "CRT screen",
784 | "screw",
785 | "screwdriver",
786 | "seat belt",
787 | "sewing machine",
788 | "shield",
789 | "shoe store",
790 | "shoji",
791 | "shopping basket",
792 | "shopping cart",
793 | "shovel",
794 | "shower cap",
795 | "shower curtain",
796 | "ski",
797 | "ski mask",
798 | "sleeping bag",
799 | "slide rule",
800 | "sliding door",
801 | "slot machine",
802 | "snorkel",
803 | "snowmobile",
804 | "snowplow",
805 | "soap dispenser",
806 | "soccer ball",
807 | "sock",
808 | "solar thermal collector",
809 | "sombrero",
810 | "soup bowl",
811 | "space bar",
812 | "space heater",
813 | "space shuttle",
814 | "spatula",
815 | "motorboat",
816 | "spider web",
817 | "spindle",
818 | "sports car",
819 | "spotlight",
820 | "stage",
821 | "steam locomotive",
822 | "through arch bridge",
823 | "steel drum",
824 | "stethoscope",
825 | "scarf",
826 | "stone wall",
827 | "stopwatch",
828 | "stove",
829 | "strainer",
830 | "tram",
831 | "stretcher",
832 | "couch",
833 | "stupa",
834 | "submarine",
835 | "suit",
836 | "sundial",
837 | "sunglass",
838 | "sunglasses",
839 | "sunscreen",
840 | "suspension bridge",
841 | "mop",
842 | "sweatshirt",
843 | "swimsuit",
844 | "swing",
845 | "switch",
846 | "syringe",
847 | "table lamp",
848 | "tank",
849 | "tape player",
850 | "teapot",
851 | "teddy bear",
852 | "television",
853 | "tennis ball",
854 | "thatched roof",
855 | "front curtain",
856 | "thimble",
857 | "threshing machine",
858 | "throne",
859 | "tile roof",
860 | "toaster",
861 | "tobacco shop",
862 | "toilet seat",
863 | "torch",
864 | "totem pole",
865 | "tow truck",
866 | "toy store",
867 | "tractor",
868 | "semi-trailer truck",
869 | "tray",
870 | "trench coat",
871 | "tricycle",
872 | "trimaran",
873 | "tripod",
874 | "triumphal arch",
875 | "trolleybus",
876 | "trombone",
877 | "tub",
878 | "turnstile",
879 | "typewriter keyboard",
880 | "umbrella",
881 | "unicycle",
882 | "upright piano",
883 | "vacuum cleaner",
884 | "vase",
885 | "vault",
886 | "velvet",
887 | "vending machine",
888 | "vestment",
889 | "viaduct",
890 | "violin",
891 | "volleyball",
892 | "waffle iron",
893 | "wall clock",
894 | "wallet",
895 | "wardrobe",
896 | "military aircraft",
897 | "sink",
898 | "washing machine",
899 | "water bottle",
900 | "water jug",
901 | "water tower",
902 | "whiskey jug",
903 | "whistle",
904 | "wig",
905 | "window screen",
906 | "window shade",
907 | "Windsor tie",
908 | "wine bottle",
909 | "wing",
910 | "wok",
911 | "wooden spoon",
912 | "wool",
913 | "split-rail fence",
914 | "shipwreck",
915 | "yawl",
916 | "yurt",
917 | "website",
918 | "comic book",
919 | "crossword",
920 | "traffic sign",
921 | "traffic light",
922 | "dust jacket",
923 | "menu",
924 | "plate",
925 | "guacamole",
926 | "consomme",
927 | "hot pot",
928 | "trifle",
929 | "ice cream",
930 | "ice pop",
931 | "baguette",
932 | "bagel",
933 | "pretzel",
934 | "cheeseburger",
935 | "hot dog",
936 | "mashed potato",
937 | "cabbage",
938 | "broccoli",
939 | "cauliflower",
940 | "zucchini",
941 | "spaghetti squash",
942 | "acorn squash",
943 | "butternut squash",
944 | "cucumber",
945 | "artichoke",
946 | "bell pepper",
947 | "cardoon",
948 | "mushroom",
949 | "Granny Smith",
950 | "strawberry",
951 | "orange",
952 | "lemon",
953 | "fig",
954 | "pineapple",
955 | "banana",
956 | "jackfruit",
957 | "custard apple",
958 | "pomegranate",
959 | "hay",
960 | "carbonara",
961 | "chocolate syrup",
962 | "dough",
963 | "meatloaf",
964 | "pizza",
965 | "pot pie",
966 | "burrito",
967 | "red wine",
968 | "espresso",
969 | "cup",
970 | "eggnog",
971 | "alp",
972 | "bubble",
973 | "cliff",
974 | "coral reef",
975 | "geyser",
976 | "lakeshore",
977 | "promontory",
978 | "shoal",
979 | "seashore",
980 | "valley",
981 | "volcano",
982 | "baseball player",
983 | "bridegroom",
984 | "scuba diver",
985 | "rapeseed",
986 | "daisy",
987 | "yellow lady's slipper",
988 | "corn",
989 | "acorn",
990 | "rose hip",
991 | "horse chestnut seed",
992 | "coral fungus",
993 | "agaric",
994 | "gyromitra",
995 | "stinkhorn mushroom",
996 | "earth star",
997 | "hen-of-the-woods",
998 | "bolete",
999 | "ear",
1000 | "toilet paper"]
1001 |
--------------------------------------------------------------------------------
/visuals/splash-figure-v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukemelas/fixed-point-diffusion-models/519e1286ba27c34e177e05962c5d9e66edce31e6/visuals/splash-figure-v1.png
--------------------------------------------------------------------------------
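
Note: `utils/imagenet-labels.json` above is a flat JSON array of 1000 human-readable ImageNet-1k class names (entry 1000, the last, is `"toilet paper"`). As a minimal sketch of how such a file can be consumed, assuming only that it is read from the repo root (the repo's own scripts, e.g. `sample.py`, may load it differently):

```python
import json

# Load the flat 1000-entry label array and map a class index to its name.
# Path and usage are illustrative, not the repo's canonical API.
with open("utils/imagenet-labels.json") as f:
    labels = json.load(f)  # list[str], one name per ImageNet-1k class

assert len(labels) == 1000
print(labels[999])  # -> "toilet paper" (0-indexed class 999, the final class)
```

A lookup table like this is typically used when sampling class-conditionally: the model takes an integer class index, and the array translates that index into a readable name for logging or figure captions.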