├── .gitignore
├── README.md
├── aconfs
│   └── 1_node_1_gpu_ddp.yaml
├── adaptive_controller.py
├── demos
│   └── fpdm_inference.ipynb
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   ├── respace.py
│   └── timestep_sampler.py
├── environment.yml
├── models.py
├── sample.py
├── train.py
├── utils
│   └── imagenet-labels.json
└── visuals
    └── splash-figure-v1.png

/.gitignore:
--------------------------------------------------------------------------------

# ONLY FOR ANONYMOUS CODE
CODE_OF_CONDUCT.md
CONTRIBUTING.md
LICENSE.txt
convert-to-diffusers/
download.py
env-luke.yaml
explore/
extract_features.py
figures/
performance/
run_DiT.ipynb
sample_ddp.py
scripts/
train_options/


# Custom
.vscode
wandb
outputs
features
tmp*
slurm-logs
slurm_logs
results
archive
debug
samples/*
!samples/imagenet-labels.json
samples-examples

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Lightning /research
test_tube_exp/
tests/tests_tt_dir/
tests/save_dir
default/
data/
test_tube_logs/
test_tube_data/
datasets/
model_weights/
tests/save_dir
tests/tests_tt_dir/
processed/
raw/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDEs
.idea
.vscode

# seed project
lightning_logs/
MNIST
.DS_Store
*.code-workspace
vis.zip

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------



<div align="center">
5 | 6 | [![Contributors][contributors-shield]][contributors-url] 7 | [![Forks][forks-shield]][forks-url] 8 | [![Stargazers][stars-shield]][stars-url] 9 | [![Issues][issues-shield]][issues-url] 10 | 11 | ### Fixed Point Diffusion Models 12 | 13 | [Project Page](https://lukemelas.github.io/fixed-point-diffusion-models/) · [Paper](https://arxiv.org/abs/2401.08741) 14 | 15 |
16 | 17 |
18 | 19 | ![DiT samples](visuals/splash-figure-v1.png) 20 | 21 | ### Table of Contents 22 | - [Abstract](#abstract) 23 | - [Setup & Installation](#setup) 24 | - [Model](#model) 25 | - [Training](#training) 26 | - [Sampling](#sampling) 27 | - [Contribution](#contribution) 28 | - [Acknowledgements](#acknowledgements) 29 | 30 | ### Roadmap 31 | 32 | - [x] Code and paper release 🎉🎉 33 | - [x] Jupyter notebook example 34 | - [ ] Pretrained model release _(coming soon)_ 35 | - [ ] Code walkthrough and tutorial 36 | 37 | ### Abstract 38 | 39 | We introduce the Fixed Point Diffusion Model (FPDM), a novel approach to image generation that integrates the concept of fixed point solving into the framework of diffusion-based generative modeling. Our approach embeds an implicit fixed point solving layer into the denoising network of a diffusion model, transforming the diffusion process into a sequence of closely-related fixed point problems. Combined with a new stochastic training method, this approach significantly reduces model size, reduces memory usage, and accelerates training. Moreover, it enables the development of two new techniques to improve sampling efficiency: reallocating computation across timesteps and reusing fixed point solutions between timesteps. We conduct extensive experiments with state-of-the-art models on ImageNet, FFHQ, CelebA-HQ, and LSUN-Church, demonstrating substantial improvements in performance and efficiency. Compared to the state-of-the-art DiT model, FPDM contains 87% fewer parameters, consumes 60% less memory during training, and improves image generation quality in situations where sampling computation or time is limited. 40 | 41 | ### Setup 42 | 43 | We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want 44 | to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file. 45 | 46 | ```bash 47 | conda env create -f environment.yml 48 | conda activate DiT 49 | ``` 50 | 51 | ### Model 52 | 53 | Our model definition, including all fixed point functionality, is included in `models.py`. 
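
In brief, the network runs a few standard DiT blocks, then repeatedly applies a single weight-tied block with an input injection, then a few more standard blocks. Below is a condensed sketch of that core loop; the `FixedPointLayer` wrapper is illustrative only, and the actual logic lives in the `_forward_fixed_point_blocks` method of `DiT` in `models.py`:

```python
import torch
from torch import nn

class FixedPointLayer(nn.Module):
    """Illustrative sketch of FPDM's implicit layer: one weight-tied
    DiT-style block iterated with an input injection (names shortened
    relative to models.py)."""

    def __init__(self, block: nn.Module, hidden_size: int):
        super().__init__()
        self.block = block  # a single DiT block, reused at every iteration
        self.fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, x, x_input_injection, c, num_iterations: int):
        # x: latent tokens (N, T, D); x_input_injection: projection of the
        # pre-blocks' output (N, T, D); c: timestep/class conditioning (N, D)
        for _ in range(num_iterations):
            h = torch.cat((x, x_input_injection), dim=-1)  # (N, T, 2D)
            h = self.fc2(self.act(self.fc1(h)))            # back to (N, T, D)
            x = self.block(h, c)                           # weight-tied update
        return x
```

Because the iterated block is weight-tied, inference-time depth becomes a budget: the number of iterations per timestep (`--fixed_point_iters`) can be traded off against the number of sampling timesteps, and the solution from one timestep can be reused to initialize the next (`--fixed_point_reuse_solution`).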
54 | 55 | ### Training 56 | 57 | Example training scripts: 58 | ```bash 59 | # Standard model 60 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py 61 | 62 | # Fixed Point Diffusion Model 63 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --fixed_point True --deq_pre_depth 1 --deq_post_depth 1 64 | 65 | # With v-prediction and zero-SNR 66 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 1 --deq_post_depth 1 67 | 68 | # With v-prediction and zero-SNR, with 4 pre- and post-layers 69 | accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 4 --deq_post_depth 4 70 | ``` 71 | 72 | ### Sampling 73 | 74 | Example sampling scripts: 75 | ```bash 76 | # Sample 77 | python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 20 78 | 79 | # Sample with fewer iterations per timestep and more timesteps 80 | python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --fixed_point_iters 12 --num_sampling_steps 40 --fixed_point_reuse_solution True 81 | ``` 82 | 83 | ### Contribution 84 | 85 | Pull requests are welcome! 86 | 87 | ### Acknowledgements 88 | 89 | * The strong baseline from DiT: 90 | ``` 91 | @article{Peebles2022DiT, 92 | title={Scalable Diffusion Models with Transformers}, 93 | author={William Peebles and Saining Xie}, 94 | year={2022}, 95 | journal={arXiv preprint arXiv:2212.09748}, 96 | } 97 | ``` 98 | 99 | * The fast-DiT code from [chuanyangjin](https://github.com/chuanyangjin/fast-DiT): 100 | 101 | * All the great work from the [CMU Locus Lab](https://github.com/locuslab) on Deep Equilibrium Models, which started with: 102 | ``` 103 | @inproceedings{bai2019deep, 104 | author = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun}, 105 | title = {Deep Equilibrium Models}, 106 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 107 | year = {2019}, 108 | } 109 | ``` 110 | 111 | * L.M.K. thanks the Rhodes Trust for their scholarship support. 
112 | 113 | 114 | 115 | 116 | [contributors-shield]: https://img.shields.io/github/contributors/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge 117 | [contributors-url]: https://github.com/lukemelas/fixed-point-diffusion-models/graphs/contributors 118 | [forks-shield]: https://img.shields.io/github/forks/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge 119 | [forks-url]: https://github.com/lukemelas/fixed-point-diffusion-models/network/members 120 | [stars-shield]: https://img.shields.io/github/stars/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge 121 | [stars-url]: https://github.com/lukemelas/fixed-point-diffusion-models/stargazers 122 | [issues-shield]: https://img.shields.io/github/issues/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge 123 | [issues-url]: https://github.com/lukemelas/fixed-point-diffusion-models/issues 124 | [license-shield]: https://img.shields.io/github/license/lukemelas/fixed-point-diffusion-models.svg?style=for-the-badge 125 | [license-url]: https://github.com/lukemelas/fixed-point-diffusion-models/blob/master/LICENSE.txt 126 | -------------------------------------------------------------------------------- /aconfs/1_node_1_gpu_ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | downcast_bf16: 'no' 4 | gpu_ids: all 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 1 10 | main_process_port: 17337 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /adaptive_controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # calculate the norm (modified) of the delta 4 | def batched_distance(dt): 5 | dist = torch.norm(dt, p = 2, dim = (1, 2)) 6 | dist = torch.mean(dist) 7 | return dist 8 | 9 | class LinearController(): 10 | def __init__(self, budget, tot_steps, ratio = [0.5, 1.5], type = "increasing") -> None: 11 | if type == "increasing": 12 | self.ratio_list = torch.linspace(ratio[0], ratio[1], tot_steps) 13 | elif type == "decreasing": 14 | self.ratio_list = torch.linspace(ratio[1], ratio[0], tot_steps) 15 | elif type == "fixed": 16 | self.ratio_list = torch.ones(tot_steps) 17 | 18 | assert len(self.ratio_list) == tot_steps 19 | self.budget_list_float = 1 + (budget - tot_steps) * self.ratio_list / torch.sum(self.ratio_list) 20 | 21 | rounding_threshold = len(self.budget_list_float) // 2 22 | while True: 23 | self.budget_list = torch.zeros_like(self.budget_list_float) 24 | for i in range(len(self.budget_list_float)): 25 | # in the first half, round up 26 | # in the second half, round down 27 | # this ensures the sum of the budget is roughly equal to the total budget 28 | if i < rounding_threshold: 29 | self.budget_list[i] = torch.ceil(self.budget_list_float[i]) 30 | else: 31 | self.budget_list[i] = torch.floor(self.budget_list_float[i]) 32 | if torch.sum(self.budget_list) <= budget: 33 | break 34 | rounding_threshold -= 1 35 | print(f"fixed the rounding issue! 
in the {type} schedule")

        assert torch.sum(self.budget_list) <= budget

        self.pointer = len(self.budget_list) - 1
        self.threshold = None
        self.cost = None
        self.lowerbound = None
        self.upperbound = None

    def init_image(self):
        self.pointer = len(self.budget_list) - 1
        print(f"in LinearController, budget list: {self.budget_list}")

    def end_image(self):
        pass

    def update(self):
        return True

    def get(self):
        ret = self.budget_list[self.pointer]
        self.pointer -= 1
        return int(ret)


# A controller that always returns a fixed threshold and accumulates the cost
class FixedController():
    def __init__(self, threshold) -> None:
        self.threshold = threshold
        self.cost = 0

    def get(self):
        return self.threshold

    def add_cost(self, cost):
        self.cost += cost



# class ThresholdController():
#     def __init__(self, budget, ratio_delta = 0) -> None:
#         self.budget = budget
#         self.threshold = 100.0
#         self.last_threshold = -1
#         self.lowerbound = 0.9 * budget
#         self.upperbound = 1.0 * budget
#         self.success_list = []
#         self.all_threshold_list = []
#         self.try_count = []
#         self.costs = []

#         self.delta_init_ratio = 0.03
#         self.delta = self.threshold * self.delta_init_ratio
#         self.count = 0
#         self.thresholds = []
#         self.cost = 0
#         self.pivot = False

#         self.max_count = 20  # the max number of threshold updates on each batch

#         # retrieve the number
#         self.upperbound_ratio = 1.0 + ratio_delta
#         self.lowerbound_ratio = 1.0 - ratio_delta

#     def init_image(self):
#         if len(self.all_threshold_list) != 0:
#             self.threshold = self.success_list[len(self.success_list) // 2]

#         self.last_threshold = -1
#         self.delta = self.threshold * self.delta_init_ratio
#         self.count = 0
#         self.cost = 0
#         self.pivot = False
#         self.thresholds = []
#         return self.threshold

#     def end_image(self):
#         self.success_list.append(self.threshold)
#         self.try_count.append(self.count)
#         self.all_threshold_list.append(self.thresholds)
#         self.success_list.sort()

#     def get(self):
#         return self.threshold

#     def add_cost(self, cost):
#         self.cost += cost

#     def update(self):
#         # after finishing an image
#         self.thresholds.append(self.threshold)
#         # print(f"in update: thres={self.threshold}, last thres={self.last_threshold}")
#         if self.cost > self.upperbound:
#             if self.last_threshold != -1 and self.threshold < self.last_threshold:  # decrease, increase
#                 self.pivot = True
#         if self.cost < self.lowerbound:
#             if self.last_threshold != -1 and self.threshold > self.last_threshold:  # increase, decrease
#                 self.pivot = True

#         if self.pivot:
#             self.delta /= 2
#         else:
#             self.delta *= 2

#         self.last_threshold = self.threshold
#         if self.cost > self.upperbound:
#             self.threshold = self.threshold + self.delta
#         if self.cost < self.lowerbound:
#             self.threshold = self.threshold - self.delta

#         passed = (self.lowerbound <= self.cost) and (self.cost <= self.upperbound)
#         # TODO: special case
#         if self.count > self.max_count and self.cost <= self.lowerbound:
#             passed = True
#         print(f"passed={passed}")

#         if passed:
#             self.costs.append(self.cost)

#         
self.cost = 0 154 | # self.count += 1 155 | 156 | # return passed 157 | 158 | # def output(self): 159 | # """ 160 | # success_list: list of threshold, (B) 161 | # all_threshold_list: list of list of threshold, (B, *) 162 | # try_count: list of try count, (B) 163 | # """ 164 | # return self.success_list, self.all_threshold_list, self.try_count, self.costs 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /demos/fpdm_inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Fixed Point Diffusion Models (FPDM)\n", 8 | "This notebook shows how to run the image sampling with FPDM." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Set Up\n", 16 | "We provide an environment.yml file that can be used to create a Conda environment. See how to install all required packages in `README.md`." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# Standard library imports\n", 26 | "import json\n", 27 | "import math\n", 28 | "import random\n", 29 | "import sys\n", 30 | "from contextlib import nullcontext\n", 31 | "from pathlib import Path\n", 32 | "from typing import Optional\n", 33 | "\n", 34 | "# Third-party imports\n", 35 | "import torch\n", 36 | "import torch.nn as nn\n", 37 | "from PIL import Image\n", 38 | "from torch import Tensor\n", 39 | "from torch.utils.checkpoint import checkpoint\n", 40 | "from accelerate import Accelerator\n", 41 | "from accelerate.utils import set_seed\n", 42 | "from diffusers.models import AutoencoderKL\n", 43 | "from jaxtyping import Float, Shaped\n", 44 | "from tap import Tap\n", 45 | "from tqdm import trange\n", 46 | "from timm.models.vision_transformer import PatchEmbed, Attention, Mlp\n", 47 | "\n", 48 | "# Local module imports\n", 49 | "sys.path.append(\"..\")\n", 50 | "from diffusion import create_diffusion\n", 51 | "from download import find_model\n", 52 | "from models import DiT_models\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Hyperparameters" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "class Args(Tap):\n", 69 | " \"\"\"\n", 70 | " A class to define and store hyperparameters and configurations for the demo.\n", 71 | " \"\"\"\n", 72 | "\n", 73 | " # File and directory paths.\n", 74 | " output_dir: str = 'demo_samples'\n", 75 | "\n", 76 | " # Dataset configuration.\n", 77 | " dataset_name: str = \"imagenet256\"\n", 78 | "\n", 79 | " # Model specific parameters.\n", 80 | " model: str = \"DiT-XL/2\"\n", 81 | " vae: str = \"mse\"\n", 82 | " num_classes: int = 1000\n", 83 | " image_size: int = 256\n", 84 | " predict_v: bool = False\n", 85 | " use_zero_terminal_snr: bool = False\n", 86 | " unsupervised: bool = False\n", 87 | " dino_supervised: bool = False\n", 88 | " dino_supervised_dim: int = 768\n", 89 | " flow: bool = False\n", 90 | " debug: bool = False\n", 91 | "\n", 92 | " # Fixed Point settings.\n", 93 | " fixed_point: bool = False\n", 94 | " fixed_point_pre_depth: int = 2\n", 95 | " fixed_point_post_depth: int = 2\n", 96 | " fixed_point_iters: Optional[int] = None\n", 97 | " fixed_point_pre_post_timestep_conditioning: bool = False\n", 98 | " fixed_point_reuse_solution: bool = 
False\n", 99 | "\n", 100 | " # Sampling configuration.\n", 101 | " ddim: bool = False\n", 102 | " cfg_scale: float = 4.0\n", 103 | " num_sampling_steps: int = 250\n", 104 | " batch_size: int = 4\n", 105 | " ckpt: str = '/work/xingjian/diff-deq-inference/pretrained/DiT-XL-2/checkpoints/0500000.pt' # replace it with the Path to your checkpoint.\n", 106 | " global_seed: int = 0\n", 107 | "\n", 108 | " # Parallelization settings.\n", 109 | " sample_index_start: int = 0\n", 110 | " sample_index_end: Optional[int] = 32\n", 111 | "\n", 112 | " def process_args(self):\n", 113 | " \"\"\"\n", 114 | " Method for additional argument processing and validation.\n", 115 | " \"\"\"\n", 116 | " # Debug mode configuration.\n", 117 | " if self.debug:\n", 118 | " self.log_with = 'tensorboard'\n", 119 | " self.name = 'debug'\n", 120 | "\n", 121 | " # Set default values and validate image size.\n", 122 | " self.fixed_point_iters = self.fixed_point_iters or (28 - self.fixed_point_pre_depth - self.fixed_point_post_depth)\n", 123 | " assert self.image_size % 8 == 0, \"Image size must be divisible by 8 (for the VAE encoder).\"\n", 124 | " self.latent_size = self.H_lat = self.W_lat = self.image_size // 8\n", 125 | "\n", 126 | " # Additional checks and validations.\n", 127 | " if self.cfg_scale < 1.0:\n", 128 | " raise ValueError(\"In almost all cases, cfg_scale should be >= 1.0\")\n", 129 | " \n", 130 | " if self.unsupervised:\n", 131 | " assert self.cfg_scale == 1.0\n", 132 | " self.num_classes = 1\n", 133 | " elif self.dino_supervised:\n", 134 | " raise NotImplementedError()\n", 135 | " \n", 136 | " if not Path(self.ckpt).is_file():\n", 137 | " raise ValueError(self.ckpt)\n", 138 | " \n", 139 | " # Creating the output directory.\n", 140 | " output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name\n", 141 | " if self.debug:\n", 142 | " output_dirname = 'debug'\n", 143 | " else:\n", 144 | " output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'\n", 145 | " if self.fixed_point:\n", 146 | " output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'\n", 147 | " if self.ddim:\n", 148 | " output_dirname += f'--ddim'\n", 149 | " self.output_dir = str(output_parent / output_dirname)\n", 150 | " Path(self.output_dir).mkdir(exist_ok=True, parents=True)\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "class Args(Tap):\n", 160 | "\n", 161 | " # Paths\n", 162 | " output_dir: str = 'samples'\n", 163 | "\n", 164 | " # Dataset\n", 165 | " dataset_name: str = \"imagenet256\"\n", 166 | "\n", 167 | " # Model\n", 168 | " model: str = \"DiT-XL/2\"\n", 169 | " vae: str = \"mse\"\n", 170 | " num_classes: int = 1000\n", 171 | " image_size: int = 256\n", 172 | " predict_v: bool = False\n", 173 | " use_zero_terminal_snr: bool = False\n", 174 | " unsupervised: bool = False\n", 175 | " dino_supervised: bool = False\n", 176 | " dino_supervised_dim: int = 768\n", 177 | " flow: bool = False\n", 178 | " debug: bool = False\n", 179 | "\n", 180 | " # Fixed Point settings\n", 181 | " fixed_point: bool = False\n", 182 | " fixed_point_pre_depth: int = 2\n", 183 | " fixed_point_post_depth: int = 2\n", 184 | " fixed_point_iters: Optional[int] = None\n", 185 | " fixed_point_pre_post_timestep_conditioning: bool = False\n", 186 | " fixed_point_reuse_solution: bool 
= False\n", 187 | "\n", 188 | " # Sampling\n", 189 | " ddim: bool = False\n", 190 | " cfg_scale: float = 4.0\n", 191 | " num_sampling_steps: int = 250\n", 192 | " batch_size: int = 4\n", 193 | " ckpt: str = '/work/xingjian/diff-deq-inference/pretrained/DiT-XL-2/checkpoints/0500000.pt' # replace with path to checkpoint\n", 194 | " global_seed: int = 0\n", 195 | " \n", 196 | " # Parallelization\n", 197 | " sample_index_start: int = 0\n", 198 | " sample_index_end: Optional[int] = 32\n", 199 | "\n", 200 | " def process_args(self) -> None:\n", 201 | " \"\"\"Additional argument processing\"\"\"\n", 202 | " if self.debug:\n", 203 | " self.log_with = 'tensorboard'\n", 204 | " self.name = 'debug'\n", 205 | "\n", 206 | " # Defaults\n", 207 | " self.fixed_point_iters = self.fixed_point_iters or (28 - self.fixed_point_pre_depth - self.fixed_point_post_depth)\n", 208 | " assert self.image_size % 8 == 0, \"Image size must be divisible by 8 (for the VAE encoder).\"\n", 209 | " self.latent_size = self.H_lat = self.W_lat = self.image_size // 8\n", 210 | " # Checks\n", 211 | " if self.cfg_scale < 1.0:\n", 212 | " raise ValueError(\"In almost all cases, cfg_scale should be >= 1.0\")\n", 213 | " if self.unsupervised:\n", 214 | " assert self.cfg_scale == 1.0\n", 215 | " self.num_classes = 1\n", 216 | " elif self.dino_supervised:\n", 217 | " raise NotImplementedError()\n", 218 | " if not Path(self.ckpt).is_file():\n", 219 | " raise ValueError(self.ckpt)\n", 220 | "\n", 221 | " # Create output directory\n", 222 | " output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name\n", 223 | " if self.debug:\n", 224 | " output_dirname = 'debug'\n", 225 | " else:\n", 226 | " output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'\n", 227 | " if self.fixed_point:\n", 228 | " output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'\n", 229 | " if self.ddim:\n", 230 | " output_dirname += f'--ddim'\n", 231 | " self.output_dir = str(output_parent / output_dirname)\n", 232 | " Path(self.output_dir).mkdir(exist_ok=True, parents=True)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "args = Args()\n", 242 | "args.process_args()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## Network architecture\n", 255 | "We modify the original DiT class to support fixed point blocks." 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "\n", 265 | "from models import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed\n", 266 | "class DiT(nn.Module):\n", 267 | " \"\"\"\n", 268 | " Diffusion model with a Transformer backbone. It includes methods for the forward pass\n", 269 | " and initialization of weights. 
The model can operate in both standard and fixed-point modes.\n", 270 | " \"\"\"\n", 271 | " def __init__(\n", 272 | " self,\n", 273 | " input_size=32,\n", 274 | " patch_size=2,\n", 275 | " in_channels=4,\n", 276 | " hidden_size=1152,\n", 277 | " depth=28,\n", 278 | " num_heads=16,\n", 279 | " mlp_ratio=4.0,\n", 280 | " class_dropout_prob=0.1,\n", 281 | " num_classes=1000,\n", 282 | " learn_sigma=True,\n", 283 | " use_cfg_embedding: bool = True,\n", 284 | " use_gradient_checkpointing: bool = True,\n", 285 | " is_label_continuous: bool = False,\n", 286 | "\n", 287 | " # below are Fixed Point-specific arguments.\n", 288 | " fixed_point: bool = False,\n", 289 | "\n", 290 | " # size\n", 291 | " fixed_point_pre_depth: int = 1, \n", 292 | " fixed_point_post_depth: int = 1, \n", 293 | "\n", 294 | " # iteration counts\n", 295 | " fixed_point_no_grad_min_iters: int = 0, \n", 296 | " fixed_point_no_grad_max_iters: int = 0,\n", 297 | " fixed_point_with_grad_min_iters: int = 28, \n", 298 | " fixed_point_with_grad_max_iters: int = 28,\n", 299 | "\n", 300 | " # solution recycle\n", 301 | " fixed_point_reuse_solution = False,\n", 302 | " \n", 303 | " # pre_post_timestep_conditioning\n", 304 | " fixed_point_pre_post_timestep_conditioning: bool = True,\n", 305 | " ):\n", 306 | " super().__init__()\n", 307 | " self.learn_sigma = learn_sigma\n", 308 | " self.in_channels = in_channels\n", 309 | " self.out_channels = in_channels * 2 if learn_sigma else in_channels\n", 310 | " self.patch_size = patch_size\n", 311 | " self.num_heads = num_heads\n", 312 | " self.use_gradient_checkpointing = use_gradient_checkpointing\n", 313 | "\n", 314 | " self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)\n", 315 | " self.t_embedder = TimestepEmbedder(hidden_size)\n", 316 | " self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, \n", 317 | " use_cfg_embedding=use_cfg_embedding, continuous=is_label_continuous)\n", 318 | " num_patches = self.x_embedder.num_patches\n", 319 | " \n", 320 | " # Will use fixed sin-cos embedding:\n", 321 | " # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)\n", 322 | " self.register_buffer('pos_embed', torch.zeros(1, num_patches, hidden_size))\n", 323 | "\n", 324 | " # New: Fixed Point\n", 325 | " self.fixed_point = fixed_point\n", 326 | " if self.fixed_point:\n", 327 | " self.fixed_point_no_grad_min_iters = fixed_point_no_grad_min_iters\n", 328 | " self.fixed_point_no_grad_max_iters = fixed_point_no_grad_max_iters\n", 329 | " self.fixed_point_with_grad_min_iters = fixed_point_with_grad_min_iters\n", 330 | " self.fixed_point_with_grad_max_iters = fixed_point_with_grad_max_iters\n", 331 | " self.fixed_point_pre_post_timestep_conditioning = fixed_point_pre_post_timestep_conditioning\n", 332 | " self.blocks_pre = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_pre_depth)])\n", 333 | " self.block_pre_projection = nn.Linear(hidden_size, hidden_size)\n", 334 | " self.block_fixed_point_projection_fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size)\n", 335 | " self.block_fixed_point_projection_act = nn.GELU(approximate=\"tanh\")\n", 336 | " self.block_fixed_point_projection_fc2 = nn.Linear(2 * hidden_size, hidden_size)\n", 337 | " self.block_fixed_point = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)\n", 338 | " self.blocks_post = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_post_depth)])\n", 339 | " 
self.blocks = [*self.blocks_pre, self.block_fixed_point, *self.blocks_post]\n", 340 | " self.fixed_point_reuse_solution = fixed_point_reuse_solution\n", 341 | " self.last_solution = None\n", 342 | " else:\n", 343 | " self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])\n", 344 | " self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)\n", 345 | " self.initialize_weights()\n", 346 | "\n", 347 | " def initialize_weights(self):\n", 348 | " # Initialize transformer layers:\n", 349 | " def _basic_init(module):\n", 350 | " if isinstance(module, nn.Linear):\n", 351 | " torch.nn.init.xavier_uniform_(module.weight)\n", 352 | " if module.bias is not None:\n", 353 | " nn.init.constant_(module.bias, 0)\n", 354 | " self.apply(_basic_init)\n", 355 | "\n", 356 | " # Initialize (and freeze) pos_embed by sin-cos embedding:\n", 357 | " pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))\n", 358 | " self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))\n", 359 | "\n", 360 | " # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):\n", 361 | " w = self.x_embedder.proj.weight.data\n", 362 | " nn.init.xavier_uniform_(w.view([w.shape[0], -1]))\n", 363 | " nn.init.constant_(self.x_embedder.proj.bias, 0)\n", 364 | "\n", 365 | " # Initialize label embedding table:\n", 366 | " if self.y_embedder.continuous:\n", 367 | " nn.init.normal_(self.y_embedder.embedding_projection.weight, std=0.02)\n", 368 | " nn.init.constant_(self.y_embedder.embedding_projection.bias, 0)\n", 369 | " else:\n", 370 | " nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)\n", 371 | "\n", 372 | " # Initialize timestep embedding MLP:\n", 373 | " nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)\n", 374 | " nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)\n", 375 | "\n", 376 | " # Zero-out adaLN modulation layers in DiT blocks:\n", 377 | " for block in self.blocks:\n", 378 | " nn.init.constant_(block.adaLN_modulation[-1].weight, 0)\n", 379 | " nn.init.constant_(block.adaLN_modulation[-1].bias, 0)\n", 380 | "\n", 381 | " # Zero-out output layers:\n", 382 | " nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)\n", 383 | " nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)\n", 384 | " nn.init.constant_(self.final_layer.linear.weight, 0)\n", 385 | " nn.init.constant_(self.final_layer.linear.bias, 0)\n", 386 | "\n", 387 | " def unpatchify(self, x):\n", 388 | " \"\"\"\n", 389 | " Reshapes the patches back to image format.\n", 390 | " x: (N, T, patch_size**2 * C)\n", 391 | " imgs: (N, H, W, C)\n", 392 | " \"\"\"\n", 393 | " c = self.out_channels\n", 394 | " p = self.x_embedder.patch_size[0]\n", 395 | " h = w = int(x.shape[1] ** 0.5)\n", 396 | " assert h * w == x.shape[1]\n", 397 | "\n", 398 | " x = x.reshape(shape=(x.shape[0], h, w, p, p, c))\n", 399 | " x = torch.einsum('nhwpqc->nchpwq', x)\n", 400 | " imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))\n", 401 | " return imgs\n", 402 | " \n", 403 | " def ckpt_wrapper(self, module):\n", 404 | " \"\"\"\n", 405 | " Wrapper function for gradient checkpointing.\n", 406 | " \"\"\"\n", 407 | " def ckpt_forward(*inputs):\n", 408 | " outputs = module(*inputs)\n", 409 | " return outputs\n", 410 | " return ckpt_forward\n", 411 | "\n", 412 | " def _forward_dit(self, x, t, y):\n", 413 | " x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2\n", 414 | " t = 
self.t_embedder(t) # (N, D)\n", 415 | " y = self.y_embedder(y, self.training) # (N, D))\n", 416 | " c = t + y # (N, D)\n", 417 | " for block in self.blocks:\n", 418 | " x = checkpoint(self.ckpt_wrapper(block), x, c) if self.use_gradient_checkpointing else block(x, c) # (N, T, D)\n", 419 | " x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)\n", 420 | " x = self.unpatchify(x) # (N, out_channels, H, W)\n", 421 | " return x\n", 422 | " \n", 423 | " def _forward_fixed_point_blocks(\n", 424 | " self, x: Float[Tensor, \"b t d\"], x_input_injection: Float[Tensor, \"b t d\"], c: Float[Tensor, \"b d\"], num_iterations: int\n", 425 | " ) -> Float[Tensor, \"b t d\"]:\n", 426 | " for _ in range(num_iterations):\n", 427 | " x = torch.cat((x, x_input_injection), dim=-1) # (N, T, D * 2)\n", 428 | " x = self.block_fixed_point_projection_fc1(x) # (N, T, D * 2)\n", 429 | " x = self.block_fixed_point_projection_act(x) # (N, T, D * 2)\n", 430 | " x = self.block_fixed_point_projection_fc2(x) # (N, T, D)\n", 431 | " x = self.block_fixed_point(x, c) # (N, T, D)\n", 432 | " return x\n", 433 | " \n", 434 | " def _check_inputs(self, x: Float[Tensor, \"b c h w\"], t: Shaped[Tensor, \"b\"], y: Shaped[Tensor, \"b\"]) -> None:\n", 435 | " if self.fixed_point_reuse_solution:\n", 436 | " if not torch.all(t[0] == t).item():\n", 437 | " raise ValueError(t)\n", 438 | "\n", 439 | " def _forward_fixed_point(self, x: Float[Tensor, \"b c h w\"], t: Shaped[Tensor, \"b\"], y: Shaped[Tensor, \"b\"]) -> Float[Tensor, \"b c h w\"]:\n", 440 | " self._check_inputs(x, t, y)\n", 441 | " x: Float[Tensor, \"b t d\"] = self.x_embedder(x) + self.pos_embed\n", 442 | " t_emb: Float[Tensor, \"b d\"] = self.t_embedder(t)\n", 443 | " y: Float[Tensor, \"b d\"] = self.y_embedder(y, self.training)\n", 444 | " c: Float[Tensor, \"b d\"] = t_emb + y\n", 445 | " c_pre_post_fixed_point: Float[Tensor, \"b d\"] = (t_emb + y) if self.fixed_point_pre_post_timestep_conditioning else y\n", 446 | " \n", 447 | " # Pre-Fixed Point\n", 448 | " # Note: If using DDP with find_unused_parameters=True, checkpoint causes issues. For more \n", 449 | " # information, see https://github.com/allenai/longformer/issues/63#issuecomment-648861503\n", 450 | " for block in self.blocks_pre:\n", 451 | " x: Float[Tensor, \"b t d\"] = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)\n", 452 | " condition = x.clone()\n", 453 | "\n", 454 | " # Whether to reuse the previous solution at the next iteration\n", 455 | " init_solution = self.last_solution if (self.fixed_point_reuse_solution and self.last_solution is not None) else x.clone()\n", 456 | "\n", 457 | " # Fixed Point (we have condition and init_solution)\n", 458 | " x_input_injection = self.block_pre_projection(condition)\n", 459 | "\n", 460 | " # NOTE: This section of code should have no_grad, but cannot due to a DDP bug. 
See\n", 461 | " # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594\n", 462 | " # for more information\n", 463 | " with nullcontext(): # we use x.detach() in place of torch.no_grad due to DDP issue\n", 464 | " num_iterations_no_grad = random.randint(self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters)\n", 465 | " x = self._forward_fixed_point_blocks(x=init_solution.detach(), x_input_injection=x_input_injection.detach(), c=c, num_iterations=num_iterations_no_grad)\n", 466 | " x = x.detach() # no grad\n", 467 | " num_iterations_with_grad = random.randint(self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters)\n", 468 | " x = self._forward_fixed_point_blocks(x=x, x_input_injection=x_input_injection, c=c, num_iterations=num_iterations_with_grad)\n", 469 | "\n", 470 | " # Save solution for reuse at next step\n", 471 | " if self.fixed_point_reuse_solution:\n", 472 | " self.last_solution = x.clone()\n", 473 | " \n", 474 | " # Post-Fixed Point\n", 475 | " for block in self.blocks_post:\n", 476 | " x = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)\n", 477 | " \n", 478 | " # Output\n", 479 | " x: Float[Tensor, \"b t p2c\"] = self.final_layer(x, c_pre_post_fixed_point) # p2c = patch_size ** 2 * out_channels)\n", 480 | " x: Float[Tensor, \"b c h w\"] = self.unpatchify(x)\n", 481 | " return x\n", 482 | " \n", 483 | " def reset(self):\n", 484 | " self.last_solution = None\n", 485 | " \n", 486 | " def forward(self, x, t, y):\n", 487 | " \"\"\"\n", 488 | " General forward pass method which handles both standard and fixed point modes.\n", 489 | " x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)\n", 490 | " t: (N,) tensor of diffusion timesteps\n", 491 | " y: (N,) tensor of class labels\n", 492 | " \"\"\"\n", 493 | " if self.fixed_point:\n", 494 | " return self._forward_fixed_point(x, t, y)\n", 495 | " else:\n", 496 | " return self._forward_dit(x, t, y)\n", 497 | "\n", 498 | " def forward_with_cfg(self, x, t, y, cfg_scale):\n", 499 | " \"\"\"\n", 500 | " Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.\n", 501 | " \"\"\"\n", 502 | " # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb\n", 503 | " half = x[: len(x) // 2]\n", 504 | " combined = torch.cat([half, half], dim=0)\n", 505 | " model_out = self.forward(combined, t, y)\n", 506 | " # For exact reproducibility reasons, we apply classifier-free guidance on only\n", 507 | " # three channels by default. 
The standard approach to cfg applies it to all channels.\n", 508 | " # This can be done by uncommenting the following line and commenting-out the line following that.\n", 509 | " # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]\n", 510 | " eps, rest = model_out[:, :3], model_out[:, 3:]\n", 511 | " cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)\n", 512 | " half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)\n", 513 | " eps = torch.cat([half_eps, half_eps], dim=0)\n", 514 | " return torch.cat([eps, rest], dim=1)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "# Initialize the Accelerator for GPU/CPU acceleration and create the DiT model.\n", 524 | "accelerator = Accelerator()\n", 525 | "model = DiT_models[args.model](\n", 526 | " input_size=args.latent_size,\n", 527 | " num_classes=(args.dino_supervised_dim if args.dino_supervised else args.num_classes),\n", 528 | " is_label_continuous=args.dino_supervised,\n", 529 | " class_dropout_prob=0,\n", 530 | " learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models\n", 531 | " use_gradient_checkpointing=False,\n", 532 | " fixed_point=args.fixed_point,\n", 533 | " fixed_point_pre_depth=args.fixed_point_pre_depth,\n", 534 | " fixed_point_post_depth=args.fixed_point_post_depth,\n", 535 | " fixed_point_no_grad_min_iters=0, \n", 536 | " fixed_point_no_grad_max_iters=0,\n", 537 | " fixed_point_with_grad_min_iters=args.fixed_point_iters, \n", 538 | " fixed_point_with_grad_max_iters=args.fixed_point_iters,\n", 539 | " fixed_point_reuse_solution=args.fixed_point_reuse_solution,\n", 540 | " fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning,\n", 541 | " ).to(accelerator.device)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "markdown", 546 | "metadata": {}, 547 | "source": [ 548 | "## Load Model" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": {}, 555 | "outputs": [], 556 | "source": [ 557 | "# Load the pre-trained model checkpoint.\n", 558 | "state_dict = find_model(args.ckpt)\n", 559 | "model.load_state_dict(state_dict)\n", 560 | "model.eval() \n", 561 | "\n", 562 | "# Initialize the diffusion process with specified parameters.\n", 563 | "diffusion = create_diffusion(\n", 564 | " str(args.num_sampling_steps), \n", 565 | " use_flow=args.flow,\n", 566 | " predict_v=args.predict_v,\n", 567 | " use_zero_terminal_snr=args.use_zero_terminal_snr,\n", 568 | ")\n", 569 | "\n", 570 | "# Load the VAE model and evaluate it.\n", 571 | "vae = AutoencoderKL.from_pretrained(f\"stabilityai/sd-vae-ft-{args.vae}\").to(accelerator.device).eval()\n" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | "## Inference" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "# create generator, class labels, and latents\n", 588 | "N_images = 32\n", 589 | "generator = torch.Generator(device=accelerator.device)\n", 590 | "generator.manual_seed(args.global_seed)\n", 591 | "class_labels = torch.randint(0, args.num_classes, size=(N_images,), device=accelerator.device)\n", 592 | "generator.manual_seed(args.global_seed)\n", 593 | "latents = torch.randn(N_images, model.in_channels, args.H_lat, args.W_lat, device=accelerator.device, generator=generator)\n", 594 | "class_labels 
= class_labels[args.sample_index_start:args.sample_index_end]\n", 595 | "latents = latents[args.sample_index_start:args.sample_index_end]\n", 596 | "indices = list(range(args.sample_index_start, args.sample_index_end))\n", 597 | "print(f'Using pseudorandom class labels and latents (start={args.sample_index_start} and end={args.sample_index_end})')\n", 598 | "\n", 599 | " # Create output path\n", 600 | "output_dir = Path(args.output_dir)\n", 601 | "# if cfg is used\n", 602 | "using_cfg = args.cfg_scale > 1.0" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "# Load class labels for helpful filenames\n", 612 | "if args.dataset_name == 'imagenet256':\n", 613 | " with open(\"../utils/imagenet-labels.json\", \"r\") as f:\n", 614 | " label_names: list[str] = json.load(f)\n", 615 | " label_names = [l.lower().replace(' ', '-').replace('\\'', '') for l in label_names]\n", 616 | "elif args.unsupervised:\n", 617 | " assert args.cfg_scale == 1.0\n", 618 | " label_names = [\"unlabeled\"]\n", 619 | "else:\n", 620 | " raise NotImplementedError()" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": null, 626 | "metadata": {}, 627 | "outputs": [], 628 | "source": [ 629 | "with torch.inference_mode():\n", 630 | " # Sample loop\n", 631 | " num_batches = math.ceil(len(class_labels) / args.batch_size)\n", 632 | " for batch_idx in trange(num_batches, disable=(not accelerator.is_main_process)):\n", 633 | "\n", 634 | " # Get pre-sampled inputs\n", 635 | " z = latents[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n", 636 | " y = class_labels[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n", 637 | " idxs = indices[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]\n", 638 | " output_paths = [output_dir / f'{idx:05d}--{y_i:03d}--{label_names[y_i]}.png' for y_i, idx in zip(y.tolist(), idxs)]\n", 639 | "\n", 640 | " # Skip files that already exist\n", 641 | " if all(output_path.is_file() for output_path in output_paths):\n", 642 | " print(f'Files already exist (batch {batch_idx}). 
Skipping.')\n", 643 | " continue\n", 644 | "\n", 645 | " # Setup classifier-free guidance\n", 646 | " if using_cfg:\n", 647 | " y_null = torch.tensor([1000] * args.batch_size, device=accelerator.device)\n", 648 | " y = torch.cat([y, y_null], 0)\n", 649 | " z = torch.cat([z, z], 0)\n", 650 | " model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)\n", 651 | " sample_fn = model.forward_with_cfg\n", 652 | " else:\n", 653 | " model_kwargs = dict(y=y)\n", 654 | " sample_fn = model.forward\n", 655 | "\n", 656 | " # Sample latent images\n", 657 | " sample_kwargs = dict(model=sample_fn, shape=z.shape, noise=z, clip_denoised=False, model_kwargs=model_kwargs, \n", 658 | " progress=False, device=accelerator.device)\n", 659 | " if args.ddim:\n", 660 | " samples = diffusion.ddim_sample_loop(**sample_kwargs)\n", 661 | " else:\n", 662 | " samples = diffusion.p_sample_loop(**sample_kwargs)\n", 663 | "\n", 664 | " if using_cfg:\n", 665 | " samples, _ = samples.chunk(2, dim=0)\n", 666 | " \n", 667 | " # Reset model (resets the initial solution to None)\n", 668 | " model.reset()\n", 669 | "\n", 670 | " # Decode latents\n", 671 | " samples = vae.decode(samples / vae.config.scaling_factor).sample\n", 672 | " samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(\"cpu\", dtype=torch.uint8).numpy()\n", 673 | "\n", 674 | " # Save samples to disk as individual .png files\n", 675 | " for sample, output_path in zip(samples, output_paths):\n", 676 | " Image.fromarray(sample).save(output_path)" 677 | ] 678 | }, 679 | { 680 | "cell_type": "markdown", 681 | "metadata": {}, 682 | "source": [ 683 | "## Visualization" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": null, 689 | "metadata": {}, 690 | "outputs": [], 691 | "source": [ 692 | "import matplotlib.pyplot as plt\n", 693 | "from PIL import Image\n", 694 | "from pathlib import Path\n", 695 | "import json\n", 696 | "\n", 697 | "# get all files in output_dir\n", 698 | "output_dir = Path(args.output_dir)\n", 699 | "files = list(output_dir.glob('*.png'))\n", 700 | "\n", 701 | "# visualize four random samples\n", 702 | "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n", 703 | "for ax in axs.flatten():\n", 704 | " img = Image.open(files[random.randint(0, len(files) - 1)])\n", 705 | " ax.imshow(img)\n", 706 | " ax.axis('off')\n", 707 | "plt.tight_layout()\n", 708 | "plt.show()" 709 | ] 710 | } 711 | ], 712 | "metadata": { 713 | "kernelspec": { 714 | "display_name": "renaissance3", 715 | "language": "python", 716 | "name": "python3" 717 | }, 718 | "language_info": { 719 | "codemirror_mode": { 720 | "name": "ipython", 721 | "version": 3 722 | }, 723 | "file_extension": ".py", 724 | "mimetype": "text/x-python", 725 | "name": "python", 726 | "nbconvert_exporter": "python", 727 | "pygments_lexer": "ipython3", 728 | "version": "3.10.13" 729 | } 730 | }, 731 | "nbformat": 4, 732 | "nbformat_minor": 2 733 | } 734 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_flow=False, 14 | use_kl=False, 15 | sigma_small=False, 16 | predict_xstart=False, 17 | predict_v=False, 18 | learn_sigma=True, 19 | rescale_learned_sigmas=False, 20 | use_zero_terminal_snr=False, 21 | diffusion_steps=1000 22 | ): 23 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps, ) 24 | if use_zero_terminal_snr: 25 | assert predict_v, "please use v-prediction if using zero terminal snr" 26 | betas = gd.enforce_zero_terminal_snr(betas).numpy() 27 | print('Rescaled betas to enforce zero terminal snr') 28 | if use_kl: 29 | loss_type = gd.LossType.RESCALED_KL 30 | elif use_flow: 31 | loss_type = gd.LossType.FLOW 32 | # TODO: Implement learned variance for flow-based models 33 | learn_sigma = False 34 | elif rescale_learned_sigmas: 35 | loss_type = gd.LossType.RESCALED_MSE 36 | else: 37 | loss_type = gd.LossType.MSE 38 | if timestep_respacing is None or timestep_respacing == "": 39 | timestep_respacing = [diffusion_steps] 40 | return SpacedDiffusion( 41 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 42 | betas=betas, 43 | model_mean_type=( 44 | gd.ModelMeanType.V if predict_v else 45 | gd.ModelMeanType.START_X if predict_xstart else 46 | gd.ModelMeanType.EPSILON 47 | ), 48 | model_var_type=( 49 | ( 50 | gd.ModelVarType.FIXED_LARGE 51 | if not sigma_small 52 | else gd.ModelVarType.FIXED_SMALL 53 | ) 54 | if not learn_sigma 55 | else gd.ModelVarType.LEARNED_RANGE 56 | ), 57 | loss_type=loss_type 58 | # rescale_timesteps=rescale_timesteps, 59 | ) 60 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 
50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch as th 11 | import enum 12 | 13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 14 | 15 | 16 | def mean_flat(tensor): 17 | """ 18 | Take the mean over all non-batch dimensions. 19 | """ 20 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 21 | 22 | 23 | class ModelMeanType(enum.Enum): 24 | """ 25 | Which type of output the model predicts. 26 | """ 27 | 28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 29 | START_X = enum.auto() # the model predicts x_0 30 | EPSILON = enum.auto() # the model predicts epsilon 31 | V = enum.auto() # the model predicts v == sqrt(alpha) * eps - sqrt(1-alpha) * x 32 | 33 | 34 | class ModelVarType(enum.Enum): 35 | """ 36 | What is used as the model's output variance. 37 | The LEARNED_RANGE option has been added to allow the model to predict 38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
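    FIXED_LARGE uses the (larger) beta_t as the per-step variance and
    FIXED_SMALL the posterior variance; LEARNED_RANGE uses the model's extra
    output channels to interpolate between their logs.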
39 | """ 40 | 41 | LEARNED = enum.auto() 42 | FIXED_SMALL = enum.auto() 43 | FIXED_LARGE = enum.auto() 44 | LEARNED_RANGE = enum.auto() 45 | 46 | 47 | class LossType(enum.Enum): 48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 49 | RESCALED_MSE = ( 50 | enum.auto() 51 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 52 | KL = enum.auto() # use the variational lower-bound 53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 54 | FLOW = enum.auto() # new: flow-based loss 55 | 56 | def is_vb(self): 57 | return self == LossType.KL or self == LossType.RESCALED_KL 58 | 59 | 60 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 61 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 62 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 63 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 64 | return betas 65 | 66 | 67 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 68 | """ 69 | This is the deprecated API for creating beta schedules. 70 | See get_named_beta_schedule() for the new library of schedules. 71 | """ 72 | if beta_schedule == "quad": 73 | betas = ( 74 | np.linspace( 75 | beta_start ** 0.5, 76 | beta_end ** 0.5, 77 | num_diffusion_timesteps, 78 | dtype=np.float64, 79 | ) 80 | ** 2 81 | ) 82 | elif beta_schedule == "linear": 83 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 84 | elif beta_schedule == "warmup10": 85 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 86 | elif beta_schedule == "warmup50": 87 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 88 | elif beta_schedule == "const": 89 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 90 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 91 | betas = 1.0 / np.linspace( 92 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 93 | ) 94 | else: 95 | raise NotImplementedError(beta_schedule) 96 | assert betas.shape == (num_diffusion_timesteps,) 97 | return betas 98 | 99 | 100 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 101 | """ 102 | Get a pre-defined beta schedule for the given name. 103 | The beta schedule library consists of beta schedules which remain similar 104 | in the limit of num_diffusion_timesteps. 105 | Beta schedules may be added, but should not be removed or changed once 106 | they are committed to maintain backwards compatibility. 107 | """ 108 | if schedule_name == "linear": 109 | # Linear schedule from Ho et al, extended to work for any number of 110 | # diffusion steps. 
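        # For the default 1000 steps this gives beta_start=1e-4 and
        # beta_end=0.02, as in Ho et al.; for other step counts the endpoints
        # are scaled by 1000/T so the total amount of noise added stays
        # comparable.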
111 | scale = 1000 / num_diffusion_timesteps 112 | betas = get_beta_schedule( 113 | "linear", 114 | beta_start=scale * 0.0001, 115 | beta_end=scale * 0.02, 116 | num_diffusion_timesteps=num_diffusion_timesteps, 117 | ) 118 | elif schedule_name == "squaredcos_cap_v2": 119 | betas = betas_for_alpha_bar( 120 | num_diffusion_timesteps, 121 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 122 | ) 123 | else: 124 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 125 | return betas 126 | 127 | from typing import Union 128 | def enforce_zero_terminal_snr(betas: Union[th.Tensor, np.ndarray]) -> th.Tensor: 129 | # def enforce_zero_terminal_snr(betas: th.Tensor | np.ndarray) -> th.Tensor: 130 | betas = betas if th.is_tensor(betas) else th.from_numpy(betas) 131 | 132 | # Convert betas to alphas_bar_sqrt 133 | alphas = 1 - betas 134 | alphas_bar = alphas.cumprod(0) 135 | alphas_bar_sqrt = alphas_bar.sqrt() 136 | 137 | # Store old values. 138 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 139 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 140 | # Shift so last timestep is zero. 141 | alphas_bar_sqrt -= alphas_bar_sqrt_T 142 | # Scale so first timestep is back to old value. 143 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 144 | 145 | # Convert alphas_bar_sqrt to betas 146 | alphas_bar = alphas_bar_sqrt ** 2 147 | alphas = alphas_bar[1:] / alphas_bar[:-1] 148 | alphas = th.cat([alphas_bar[0:1], alphas]) 149 | betas = 1 - alphas 150 | return betas 151 | 152 | 153 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 154 | """ 155 | Create a beta schedule that discretizes the given alpha_t_bar function, 156 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 157 | :param num_diffusion_timesteps: the number of betas to produce. 158 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 159 | produces the cumulative product of (1-beta) up to that 160 | part of the diffusion process. 161 | :param max_beta: the maximum beta to use; use values lower than 1 to 162 | prevent singularities. 163 | """ 164 | betas = [] 165 | for i in range(num_diffusion_timesteps): 166 | t1 = i / num_diffusion_timesteps 167 | t2 = (i + 1) / num_diffusion_timesteps 168 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 169 | return np.array(betas) 170 | 171 | 172 | class GaussianDiffusion: 173 | """ 174 | Utilities for training and sampling diffusion models. 175 | Original ported from this codebase: 176 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 177 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 178 | starting at T and going to 1. 179 | """ 180 | 181 | def __init__( 182 | self, 183 | *, 184 | betas, 185 | model_mean_type, 186 | model_var_type, 187 | loss_type, 188 | weight_schedule: str = "p2" 189 | ): 190 | 191 | self.model_mean_type = model_mean_type 192 | self.model_var_type = model_var_type 193 | self.loss_type = loss_type 194 | 195 | # Use float64 for accuracy. 
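        # Everything below precomputes per-timestep constants as numpy arrays:
        # alpha-bar_t with its square roots and reciprocals (used by q_sample
        # and x_0 prediction), and the mean/variance coefficients of the
        # posterior q(x_{t-1} | x_t, x_0).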
196 | betas = np.array(betas, dtype=np.float64) 197 | self.betas = betas 198 | assert len(betas.shape) == 1, "betas must be 1-D" 199 | assert (betas > 0).all() and (betas <= 1).all() 200 | 201 | self.num_timesteps = int(betas.shape[0]) 202 | 203 | alphas = 1.0 - betas 204 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 205 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 206 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 207 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 208 | 209 | # calculations for diffusion q(x_t | x_{t-1}) and others 210 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 211 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 212 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 213 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 214 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 215 | 216 | # calculations for posterior q(x_{t-1} | x_t, x_0) 217 | self.posterior_variance = ( 218 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 219 | ) 220 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 221 | self.posterior_log_variance_clipped = np.log( 222 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 223 | ) if len(self.posterior_variance) > 1 else np.array([]) 224 | 225 | self.posterior_mean_coef1 = ( 226 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 227 | ) 228 | self.posterior_mean_coef2 = ( 229 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 230 | ) 231 | 232 | self.snrs = self.alphas_cumprod / self.sqrt_one_minus_alphas_cumprod ** 2 233 | 234 | self.weight_schedule = weight_schedule 235 | 236 | def q_mean_variance(self, x_start, t): 237 | """ 238 | Get the distribution q(x_t | x_0). 239 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 240 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 241 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 242 | """ 243 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 244 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 245 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 246 | return mean, variance, log_variance 247 | 248 | def q_sample(self, x_start, t, noise=None): 249 | """ 250 | Diffuse the data for a given number of diffusion steps. 251 | In other words, sample from q(x_t | x_0). 252 | :param x_start: the initial data batch. 253 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 254 | :param noise: if specified, the split-out normal noise. 255 | :return: A noisy version of x_start. 
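        Example (a sketch; `diffusion` is an instance of this class):
            x_start = th.randn(8, 4, 32, 32)
            t = th.randint(0, diffusion.num_timesteps, (8,))
            x_t = diffusion.q_sample(x_start, t)  # noisier for larger t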
256 | """ 257 | if noise is None: 258 | noise = th.randn_like(x_start) 259 | assert noise.shape == x_start.shape 260 | if self.loss_type == LossType.FLOW: 261 | t_float = t[:, None, None, None] / self.num_timesteps 262 | # return x_start * t_float + (1 - t_float) * noise # old models 263 | return x_start * (1 - t_float) + t_float * noise 264 | else: 265 | return ( 266 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 267 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 268 | ) 269 | 270 | def q_posterior_mean_variance(self, x_start, x_t, t): 271 | """ 272 | Compute the mean and variance of the diffusion posterior: 273 | q(x_{t-1} | x_t, x_0) 274 | """ 275 | assert x_start.shape == x_t.shape 276 | posterior_mean = ( 277 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 278 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 279 | ) 280 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 281 | posterior_log_variance_clipped = _extract_into_tensor( 282 | self.posterior_log_variance_clipped, t, x_t.shape 283 | ) 284 | assert ( 285 | posterior_mean.shape[0] 286 | == posterior_variance.shape[0] 287 | == posterior_log_variance_clipped.shape[0] 288 | == x_start.shape[0] 289 | ) 290 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 291 | 292 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 293 | """ 294 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 295 | the initial x, x_0. 296 | :param model: the model, which takes a signal and a batch of timesteps 297 | as input. 298 | :param x: the [N x C x ...] tensor at time t. 299 | :param t: a 1-D Tensor of timesteps. 300 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 301 | :param denoised_fn: if not None, a function which applies to the 302 | x_start prediction before it is used to sample. Applies before 303 | clip_denoised. 304 | :param model_kwargs: if not None, a dict of extra keyword arguments to 305 | pass to the model. This can be used for conditioning. 306 | :return: a dict with the following keys: 307 | - 'mean': the model mean output. 308 | - 'variance': the model variance output. 309 | - 'log_variance': the log of 'variance'. 310 | - 'pred_xstart': the prediction for x_0. 311 | """ 312 | if model_kwargs is None: 313 | model_kwargs = {} 314 | 315 | B, C = x.shape[:2] 316 | assert t.shape == (B,), f'{t.shape = } and {B = }' 317 | model_output = model(x, t, **model_kwargs) 318 | if isinstance(model_output, tuple): 319 | model_output, extra = model_output 320 | else: 321 | extra = None 322 | 323 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 324 | assert self.loss_type != LossType.FLOW # TODO: Implement learned variance for flow-based models 325 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 326 | model_output, model_var_values = th.split(model_output, C, dim=1) 327 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 328 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 329 | # The model_var_values is [-1, 1] for [min_var, max_var]. 
330 | frac = (model_var_values + 1) / 2 331 | model_log_variance = frac * max_log + (1 - frac) * min_log 332 | model_variance = th.exp(model_log_variance) 333 | else: 334 | model_variance, model_log_variance = { 335 | # for fixedlarge, we set the initial (log-)variance like so 336 | # to get a better decoder log likelihood. 337 | ModelVarType.FIXED_LARGE: ( 338 | np.append(self.posterior_variance[1], self.betas[1:]), 339 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 340 | ), 341 | ModelVarType.FIXED_SMALL: ( 342 | self.posterior_variance, 343 | self.posterior_log_variance_clipped, 344 | ), 345 | }[self.model_var_type] 346 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 347 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 348 | 349 | def process_xstart(x): 350 | if denoised_fn is not None: 351 | x = denoised_fn(x) 352 | if clip_denoised: 353 | return x.clamp(-1, 1) 354 | return x 355 | 356 | if self.loss_type == LossType.FLOW: 357 | velocity = model_output 358 | model_mean = x + velocity / self.num_timesteps # this is the simpler version, without clipping 359 | t_float = t[:, None, None, None] / self.num_timesteps # the input t is (T, 0] 360 | pred_xstart = process_xstart(x + velocity * t_float) 361 | # new_model_mean = x + (pred_xstart - x) / t_float / self.num_timesteps 362 | else: 363 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 364 | pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) 365 | model_mean = model_output 366 | elif self.model_mean_type == ModelMeanType.START_X: 367 | pred_xstart = process_xstart(model_output) 368 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 369 | elif self.model_mean_type == ModelMeanType.V: 370 | pred_xstart = process_xstart(self._predict_xstart_from_v(x_t=x, t=t, v=model_output)) 371 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 372 | elif self.model_mean_type == ModelMeanType.EPSILON: 373 | pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) 374 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 375 | else: 376 | raise ValueError(self.model_mean_type) 377 | 378 | 379 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 380 | return { 381 | "mean": model_mean, 382 | "variance": model_variance, 383 | "model_output": model_output, 384 | "log_variance": model_log_variance, 385 | "pred_xstart": pred_xstart, 386 | "extra": extra, 387 | } 388 | 389 | def _predict_xstart_from_eps(self, x_t, t, eps): 390 | assert x_t.shape == eps.shape 391 | return ( 392 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 393 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 394 | ) 395 | 396 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 397 | return ( 398 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 399 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 400 | 401 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 402 | # (xprev - coef2*x_t) / coef1 403 | assert x_t.shape == xprev.shape 404 | return (_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - 405 | _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t) 406 | 407 | def _predict_xstart_from_v(self, x_t, t, v): 408 | # x0_pred 
= sqrt(alpha) * x_t - sqrt(1-alpha) * v_pred 409 | assert x_t.shape == v.shape 410 | return (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 411 | _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) 412 | 413 | def _predict_eps_from_v(self, x_t, t, v): 414 | # https://github.com/huggingface/diffusers/blob/6a89a6c93ae38927097f5181030e3ceb7de7f43d/src/diffusers/schedulers/scheduling_ddim.py#L416-L429 415 | # eps = (alpha_prod_t**0.5) * model_output_v_pred + (beta_prod_t**0.5) * sample 416 | assert x_t.shape == v.shape 417 | return (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + 418 | _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) 419 | 420 | 421 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 422 | """ 423 | Compute the mean for the previous step, given a function cond_fn that 424 | computes the gradient of a conditional log probability with respect to 425 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 426 | condition on y. 427 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 428 | """ 429 | gradient = cond_fn(x, t, **model_kwargs) 430 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 431 | return new_mean 432 | 433 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 434 | """ 435 | Compute what the p_mean_variance output would have been, should the 436 | model's score function be conditioned by cond_fn. 437 | See condition_mean() for details on cond_fn. 438 | Unlike condition_mean(), this instead uses the conditioning strategy 439 | from Song et al (2020). 440 | """ 441 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 442 | 443 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 444 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 445 | 446 | out = p_mean_var.copy() 447 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 448 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 449 | return out 450 | 451 | def p_sample( 452 | self, 453 | model, 454 | x, 455 | t, 456 | clip_denoised=True, 457 | denoised_fn=None, 458 | cond_fn=None, 459 | model_kwargs=None, 460 | ): 461 | """ 462 | Sample x_{t-1} from the model at the given timestep. 463 | :param model: the model to sample from. 464 | :param x: the current tensor at x_{t-1}. 465 | :param t: the value of t, starting at 0 for the first diffusion step. 466 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 467 | :param denoised_fn: if not None, a function which applies to the 468 | x_start prediction before it is used to sample. 469 | :param cond_fn: if not None, this is a gradient function that acts 470 | similarly to the model. 471 | :param model_kwargs: if not None, a dict of extra keyword arguments to 472 | pass to the model. This can be used for conditioning. 473 | :return: a dict containing the following keys: 474 | - 'sample': a random sample from the model. 475 | - 'pred_xstart': a prediction of x_0. 
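        Example (a sketch; `diffusion` is an instance of this class and `i`
        a Python int timestep):
            t = th.full((x.shape[0],), i, device=x.device, dtype=th.long)
            out = diffusion.p_sample(model, x, t)
            x_prev = out["sample"]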
476 | """ 477 | out = self.p_mean_variance( 478 | model, 479 | x, 480 | t, 481 | clip_denoised=clip_denoised, 482 | denoised_fn=denoised_fn, 483 | model_kwargs=model_kwargs, 484 | ) 485 | # print(f'Running ddpm with {t = }') 486 | noise = th.randn_like(x) 487 | nonzero_mask = ( 488 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 489 | ) # no noise when t == 0 490 | if cond_fn is not None: 491 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 492 | if self.loss_type == LossType.FLOW: 493 | sample = out["mean"] 494 | else: 495 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 496 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 497 | 498 | def p_sample_loop( 499 | self, 500 | model, 501 | shape, 502 | noise=None, 503 | clip_denoised=True, 504 | denoised_fn=None, 505 | cond_fn=None, 506 | model_kwargs=None, 507 | device=None, 508 | progress=False, 509 | return_all=False, 510 | ): 511 | """ 512 | Generate samples from the model. 513 | :param model: the model module. 514 | :param shape: the shape of the samples, (N, C, H, W). 515 | :param noise: if specified, the noise from the encoder to sample. 516 | Should be of the same shape as `shape`. 517 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 518 | :param denoised_fn: if not None, a function which applies to the 519 | x_start prediction before it is used to sample. 520 | :param cond_fn: if not None, this is a gradient function that acts 521 | similarly to the model. 522 | :param model_kwargs: if not None, a dict of extra keyword arguments to 523 | pass to the model. This can be used for conditioning. 524 | :param device: if specified, the device to create the samples on. 525 | If not specified, use a model parameter's device. 526 | :param progress: if True, show a tqdm progress bar. 527 | :return: a non-differentiable batch of samples. 528 | """ 529 | final = None 530 | all_samples = [] 531 | for sample in self.p_sample_loop_progressive( 532 | model, 533 | shape, 534 | noise=noise, 535 | clip_denoised=clip_denoised, 536 | denoised_fn=denoised_fn, 537 | cond_fn=cond_fn, 538 | model_kwargs=model_kwargs, 539 | device=device, 540 | progress=progress, 541 | ): 542 | final = sample 543 | all_samples.append(sample) 544 | return (final["sample"], all_samples) if return_all else final["sample"] 545 | 546 | def p_sample_loop_progressive( 547 | self, 548 | model, 549 | shape, 550 | noise=None, 551 | clip_denoised=True, 552 | denoised_fn=None, 553 | cond_fn=None, 554 | model_kwargs=None, 555 | device=None, 556 | progress=False, 557 | ): 558 | """ 559 | Generate samples from the model and yield intermediate samples from 560 | each timestep of diffusion. 561 | Arguments are the same as p_sample_loop(). 562 | Returns a generator over dicts, where each dict is the return value of 563 | p_sample(). 564 | """ 565 | if device is None: 566 | device = next(model.parameters()).device 567 | assert isinstance(shape, (tuple, list)) 568 | if noise is not None: 569 | img = noise 570 | else: 571 | img = th.randn(*shape, device=device) 572 | indices = list(range(self.num_timesteps))[::-1] 573 | 574 | if progress: 575 | # Lazy import so that we don't depend on tqdm. 
576 | from tqdm.auto import tqdm 577 | 578 | indices = tqdm(indices) 579 | 580 | for i in indices: 581 | t = th.tensor([i] * shape[0], device=device) 582 | with th.no_grad(): 583 | out = self.p_sample( 584 | model, 585 | img, 586 | t, 587 | clip_denoised=clip_denoised, 588 | denoised_fn=denoised_fn, 589 | cond_fn=cond_fn, 590 | model_kwargs=model_kwargs, 591 | ) 592 | yield out 593 | img = out["sample"] 594 | 595 | def ddim_sample( 596 | self, 597 | model, 598 | x, 599 | t, 600 | clip_denoised=True, 601 | denoised_fn=None, 602 | cond_fn=None, 603 | model_kwargs=None, 604 | eta=0.0, 605 | ): 606 | """ 607 | Sample x_{t-1} from the model using DDIM. 608 | Same usage as p_sample(). 609 | """ 610 | # print(f'Running ddim with {t = }') 611 | out = self.p_mean_variance( 612 | model, 613 | x, 614 | t, 615 | clip_denoised=clip_denoised, 616 | denoised_fn=denoised_fn, 617 | model_kwargs=model_kwargs, 618 | ) 619 | if cond_fn is not None: 620 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 621 | 622 | # Usually our model outputs epsilon, but we re-derive it 623 | # in case we used x_start or x_prev prediction. 624 | 625 | # NOTE: For v-pred, these two ways of computing eps should be the same, except that the latter 626 | # way (1) does not work at the first step of zero-snr when alphas = 1, and (2) also has some additional 627 | # post-processing applied to the sample (like clipping, as applied by the `process_xstart` 628 | # function above). For this reason, we will use the former when alphas is 1 and the latter otherwise. 629 | if self.model_mean_type == ModelMeanType.V and _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, (1,)).isinf().any().item(): 630 | if cond_fn is not None: 631 | raise NotImplementedError() 632 | eps = self._predict_eps_from_v(x, t, out["model_output"]) 633 | # eps_check = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 634 | # print(eps - eps_check) # <-- can use this to check that they are the same, for timesteps where alpha != 1 635 | else: 636 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 637 | 638 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 639 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 640 | sigma = ( 641 | eta 642 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 643 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 644 | ) 645 | # Equation 12. 646 | noise = th.randn_like(x) 647 | mean_pred = ( 648 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 649 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 650 | ) 651 | nonzero_mask = ( 652 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 653 | ) # no noise when t == 0 654 | sample = mean_pred + nonzero_mask * sigma * noise 655 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 656 | 657 | def ddim_reverse_sample( 658 | self, 659 | model, 660 | x, 661 | t, 662 | clip_denoised=True, 663 | denoised_fn=None, 664 | cond_fn=None, 665 | model_kwargs=None, 666 | eta=0.0, 667 | ): 668 | """ 669 | Sample x_{t+1} from the model using DDIM reverse ODE. 
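        Since eta must be 0, this runs the deterministic DDIM update forwards
        in time, e.g. to encode a given x_t one step toward the noise space
        (a sketch):
            out = diffusion.ddim_reverse_sample(model, x_t, t)
            x_next = out["sample"]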
670 | """ 671 | assert eta == 0.0, "Reverse ODE only for deterministic path" 672 | out = self.p_mean_variance( 673 | model, 674 | x, 675 | t, 676 | clip_denoised=clip_denoised, 677 | denoised_fn=denoised_fn, 678 | model_kwargs=model_kwargs, 679 | ) 680 | if cond_fn is not None: 681 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 682 | # Usually our model outputs epsilon, but we re-derive it 683 | # in case we used x_start or x_prev prediction. 684 | eps = ( 685 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 686 | - out["pred_xstart"] 687 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 688 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 689 | 690 | # Equation 12. reversed 691 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 692 | 693 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 694 | 695 | def ddim_sample_loop( 696 | self, 697 | model, 698 | shape, 699 | noise=None, 700 | clip_denoised=True, 701 | denoised_fn=None, 702 | cond_fn=None, 703 | model_kwargs=None, 704 | device=None, 705 | progress=False, 706 | eta=0.0, 707 | ): 708 | """ 709 | Generate samples from the model using DDIM. 710 | Same usage as p_sample_loop(). 711 | """ 712 | final = None 713 | for sample in self.ddim_sample_loop_progressive( 714 | model, 715 | shape, 716 | noise=noise, 717 | clip_denoised=clip_denoised, 718 | denoised_fn=denoised_fn, 719 | cond_fn=cond_fn, 720 | model_kwargs=model_kwargs, 721 | device=device, 722 | progress=progress, 723 | eta=eta, 724 | ): 725 | final = sample 726 | return final["sample"] 727 | 728 | def ddim_sample_loop_progressive( 729 | self, 730 | model, 731 | shape, 732 | noise=None, 733 | clip_denoised=True, 734 | denoised_fn=None, 735 | cond_fn=None, 736 | model_kwargs=None, 737 | device=None, 738 | progress=False, 739 | eta=0.0, 740 | ): 741 | """ 742 | Use DDIM to sample from the model and yield intermediate samples from 743 | each timestep of DDIM. 744 | Same usage as p_sample_loop_progressive(). 745 | """ 746 | if device is None: 747 | device = next(model.parameters()).device 748 | assert isinstance(shape, (tuple, list)) 749 | if noise is not None: 750 | img = noise 751 | else: 752 | img = th.randn(*shape, device=device) 753 | indices = list(range(self.num_timesteps))[::-1] 754 | 755 | if progress: 756 | # Lazy import so that we don't depend on tqdm. 757 | from tqdm.auto import tqdm 758 | 759 | indices = tqdm(indices) 760 | 761 | for i in indices: 762 | t = th.tensor([i] * shape[0], device=device) 763 | with th.no_grad(): 764 | out = self.ddim_sample( 765 | model, 766 | img, 767 | t, 768 | clip_denoised=clip_denoised, 769 | denoised_fn=denoised_fn, 770 | cond_fn=cond_fn, 771 | model_kwargs=model_kwargs, 772 | eta=eta, 773 | ) 774 | yield out 775 | img = out["sample"] 776 | 777 | def _vb_terms_bpd( 778 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 779 | ): 780 | """ 781 | Get a term for the variational lower-bound. 782 | The resulting units are bits (rather than nats, as one might expect). 783 | This allows for comparison to other papers. 784 | :return: a dict with the following keys: 785 | - 'output': a shape [N] tensor of NLLs or KLs. 786 | - 'pred_xstart': the x_0 predictions. 
787 | """ 788 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 789 | x_start=x_start, x_t=x_t, t=t 790 | ) 791 | out = self.p_mean_variance( 792 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 793 | ) 794 | kl = normal_kl( 795 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 796 | ) 797 | kl = mean_flat(kl) / np.log(2.0) 798 | 799 | decoder_nll = -discretized_gaussian_log_likelihood( 800 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 801 | ) 802 | assert decoder_nll.shape == x_start.shape 803 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 804 | 805 | # At the first timestep return the decoder NLL, 806 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 807 | output = th.where((t == 0), decoder_nll, kl) 808 | return {"output": output, "pred_xstart": out["pred_xstart"]} 809 | 810 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 811 | """ 812 | Compute training losses for a single timestep. 813 | :param model: the model to evaluate loss on. 814 | :param x_start: the [N x C x ...] tensor of inputs. 815 | :param t: a batch of timestep indices. 816 | :param model_kwargs: if not None, a dict of extra keyword arguments to 817 | pass to the model. This can be used for conditioning. 818 | :param noise: if specified, the specific Gaussian noise to try to remove. 819 | :return: a dict with the key "loss" containing a tensor of shape [N]. 820 | Some mean or variance settings may also have other keys. 821 | """ 822 | if model_kwargs is None: 823 | model_kwargs = {} 824 | if noise is None: 825 | noise = th.randn_like(x_start) 826 | x_t = self.q_sample(x_start, t, noise=noise) 827 | 828 | terms = {} 829 | 830 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 831 | terms["loss"] = self._vb_terms_bpd( 832 | model=model, 833 | x_start=x_start, 834 | x_t=x_t, 835 | t=t, 836 | clip_denoised=False, 837 | model_kwargs=model_kwargs, 838 | )["output"] 839 | if self.loss_type == LossType.RESCALED_KL: 840 | terms["loss"] *= self.num_timesteps 841 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 842 | model_output = model(x_t, t, **model_kwargs) 843 | 844 | if self.model_var_type in [ 845 | ModelVarType.LEARNED, 846 | ModelVarType.LEARNED_RANGE, 847 | ]: 848 | B, C = x_t.shape[:2] 849 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 850 | model_output, model_var_values = th.split(model_output, C, dim=1) 851 | # Learn the variance using the variational bound, but don't let 852 | # it affect our mean prediction. 853 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 854 | terms["vb"] = self._vb_terms_bpd( 855 | model=lambda *args, r=frozen_out: r, 856 | x_start=x_start, 857 | x_t=x_t, 858 | t=t, 859 | clip_denoised=False, 860 | )["output"] 861 | if self.loss_type == LossType.RESCALED_MSE: 862 | # Divide by 1000 for equivalence with initial implementation. 863 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
864 | terms["vb"] *= self.num_timesteps / 1000.0 865 | 866 | if self.model_mean_type == ModelMeanType.V: 867 | # from https://github.com/ericl122333/PatchDiffusion-Pytorch/blob/main/patch_diffusion/gaussian_diffusion.py#L841 868 | # TODO: consider refactoring everything to this format 869 | # Get predicted x_0 870 | if self.model_mean_type == ModelMeanType.START_X: 871 | model_pred_x0 = model_output 872 | elif self.model_mean_type == ModelMeanType.EPSILON: 873 | model_pred_x0 = self._predict_xstart_from_eps(x_t, t, model_output) 874 | elif self.model_mean_type == ModelMeanType.V: 875 | model_pred_x0 = self._predict_xstart_from_v(x_t, t, model_output) 876 | elif self.model_mean_type == ModelMeanType.PREVIOUS_X: 877 | model_pred_x0 = self._predict_xstart_from_xprev(x_t, t, model_output) 878 | else: 879 | raise NotImplementedError() 880 | # Get loss weights depending on timestep 881 | if self.weight_schedule == "sqrt_snr": 882 | weights = np.sqrt(self.snrs) 883 | elif self.weight_schedule == "p2": # from https://arxiv.org/abs/2204.00227 884 | weights = self.snrs * self.sqrt_one_minus_alphas_cumprod 885 | elif self.weight_schedule == "snr": 886 | weights = self.snrs 887 | elif self.weight_schedule == "snr+1": 888 | weights = self.snrs + 1 889 | elif self.weight_schedule == "truncated_snr": 890 | weights = np.maximum(self.snrs, 1.0) 891 | else: 892 | raise NotImplementedError() 893 | weights = _extract_into_tensor(weights, t, x_start.shape) 894 | # Compute loss 895 | assert model_output.shape == x_start.shape 896 | terms["mse"] = mean_flat(weights * (x_start - model_pred_x0)**2) 897 | else: 898 | target = { 899 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], 900 | ModelMeanType.START_X: x_start, 901 | ModelMeanType.EPSILON: noise, 902 | }[self.model_mean_type] 903 | assert model_output.shape == target.shape == x_start.shape 904 | terms["mse"] = mean_flat((target - model_output) ** 2) 905 | if "vb" in terms: 906 | terms["loss"] = terms["mse"] + terms["vb"] 907 | else: 908 | terms["loss"] = terms["mse"] 909 | elif self.loss_type == LossType.FLOW: 910 | model_output = model(x_t, t, **model_kwargs) 911 | 912 | # TODO: Implement learned variance for flow-based models 913 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 914 | raise NotImplementedError(f'{self.model_var_type = } is not implemented for flow yet') 915 | 916 | # Loss 917 | target = x_start - noise 918 | terms["flow"] = mean_flat((model_output - target) ** 2) 919 | terms["loss"] = terms["flow"] 920 | 921 | else: 922 | raise NotImplementedError(self.loss_type) 923 | 924 | return terms 925 | 926 | def _prior_bpd(self, x_start): 927 | """ 928 | Get the prior KL term for the variational lower-bound, measured in 929 | bits-per-dim. 930 | This term can't be optimized, as it only depends on the encoder. 931 | :param x_start: the [N x C x ...] tensor of inputs. 932 | :return: a batch of [N] KL values (in bits), one per batch element. 
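        Concretely, this computes KL(q(x_T | x_0) || N(0, I)), which is close
        to zero when the schedule's terminal alpha_bar is near zero.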
933 | """ 934 | batch_size = x_start.shape[0] 935 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 936 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 937 | kl_prior = normal_kl( 938 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 939 | ) 940 | return mean_flat(kl_prior) / np.log(2.0) 941 | 942 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 943 | """ 944 | Compute the entire variational lower-bound, measured in bits-per-dim, 945 | as well as other related quantities. 946 | :param model: the model to evaluate loss on. 947 | :param x_start: the [N x C x ...] tensor of inputs. 948 | :param clip_denoised: if True, clip denoised samples. 949 | :param model_kwargs: if not None, a dict of extra keyword arguments to 950 | pass to the model. This can be used for conditioning. 951 | :return: a dict containing the following keys: 952 | - total_bpd: the total variational lower-bound, per batch element. 953 | - prior_bpd: the prior term in the lower-bound. 954 | - vb: an [N x T] tensor of terms in the lower-bound. 955 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 956 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 957 | """ 958 | device = x_start.device 959 | batch_size = x_start.shape[0] 960 | 961 | vb = [] 962 | xstart_mse = [] 963 | mse = [] 964 | for t in list(range(self.num_timesteps))[::-1]: 965 | t_batch = th.tensor([t] * batch_size, device=device) 966 | noise = th.randn_like(x_start) 967 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 968 | # Calculate VLB term at the current timestep 969 | with th.no_grad(): 970 | out = self._vb_terms_bpd( 971 | model, 972 | x_start=x_start, 973 | x_t=x_t, 974 | t=t_batch, 975 | clip_denoised=clip_denoised, 976 | model_kwargs=model_kwargs, 977 | ) 978 | vb.append(out["output"]) 979 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 980 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 981 | mse.append(mean_flat((eps - noise) ** 2)) 982 | 983 | vb = th.stack(vb, dim=1) 984 | xstart_mse = th.stack(xstart_mse, dim=1) 985 | mse = th.stack(mse, dim=1) 986 | 987 | prior_bpd = self._prior_bpd(x_start) 988 | total_bpd = vb.sum(dim=1) + prior_bpd 989 | return { 990 | "total_bpd": total_bpd, 991 | "prior_bpd": prior_bpd, 992 | "vb": vb, 993 | "xstart_mse": xstart_mse, 994 | "mse": mse, 995 | } 996 | 997 | def p_sample_loop_ode( 998 | self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, 999 | device=None, progress=False, return_all=False, 1000 | ): 1001 | """ 1002 | Generate samples from the model and yield intermediate samples from 1003 | each timestep of diffusion. 1004 | Arguments are the same as p_sample_loop(). 1005 | Returns a generator over dicts, where each dict is the return value of 1006 | p_sample(). 
1007 | """ 1008 | assert isinstance(shape, (tuple, list)) 1009 | device = device or next(model.parameters()).device 1010 | img = noise if noise is not None else th.randn(*shape, device=device) 1011 | 1012 | intermediates = [] 1013 | 1014 | def ode_func(t: float, x: th.FloatTensor): 1015 | nonlocal intermediates 1016 | print(f'In ode_func, {t = }') 1017 | print(f'In ode_func, {x = }') 1018 | t_input = th.round((1 - t) * self.num_timesteps).expand(shape[0]).to(device=device, dtype=th.long) # first map [0, 1] to [T, 0] 1019 | t_input = th.clip(t_input - 1, min=0, max=None) 1020 | # t_input = th.tensor([t_input] * shape[0], device=device) # convert to tensor 1021 | out = self.p_sample( 1022 | model, x, t_input, clip_denoised=clip_denoised, 1023 | denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, 1024 | ) 1025 | x_new = out["sample"] 1026 | intermediates.append({ 1027 | 't': t, 't_input': th.round((1 - t) * self.num_timesteps), 'sample': x_new, 1028 | 'pred_xstart': x_new, # <-- this should be computed separately but that's fine 1029 | }) 1030 | return x_new 1031 | 1032 | from torchdiffeq import odeint 1033 | 1034 | odeint_timescale = 1.0 1035 | odeint_kwargs = { 1036 | "atol": 1e-5, 1037 | "rtol": 1e-5, 1038 | "options": { 1039 | "step_t": [1.0 + 1e-7], 1040 | } 1041 | } 1042 | 1043 | with th.no_grad(): 1044 | t = th.tensor([0, 1.0], device=device) 1045 | z, log_det = odeint( 1046 | ode_func, 1047 | img, 1048 | t * odeint_timescale, 1049 | **odeint_kwargs, 1050 | ) 1051 | 1052 | final_sample = z 1053 | return final_sample, intermediates 1054 | 1055 | 1056 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1057 | """ 1058 | Extract values from a 1-D numpy array for a batch of indices. 1059 | :param arr: the 1-D numpy array. 1060 | :param timesteps: a tensor of indices into the array to extract. 1061 | :param broadcast_shape: a larger shape of K dimensions with the batch 1062 | dimension equal to the length of timesteps. 1063 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 1064 | """ 1065 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1066 | while len(res.shape) < len(broadcast_shape): 1067 | res = res[..., None] 1068 | return res + th.zeros(broadcast_shape, device=timesteps.device) -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 
20 |     If the stride is a string starting with "ddim", then the fixed striding
21 |     from the DDIM paper is used, and only one section is allowed.
22 |     :param num_timesteps: the number of diffusion steps in the original
23 |                           process to divide up.
24 |     :param section_counts: either a list of numbers, or a string containing
25 |                            comma-separated numbers, indicating the step count
26 |                            per section. As a special case, use "ddimN" where N
27 |                            is a number of steps to use the striding from the
28 |                            DDIM paper.
29 |     :return: a set of diffusion steps from the original process to use.
30 |     """
31 |     if isinstance(section_counts, str):
32 |         if section_counts.startswith("ddim"):
33 |             desired_count = int(section_counts[len("ddim") :])
34 |             for i in range(1, num_timesteps):
35 |                 if len(range(0, num_timesteps, i)) == desired_count:
36 |                     return set(range(0, num_timesteps, i))
37 |             raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 |             )
40 |         section_counts = [int(x) for x in section_counts.split(",")]
41 |     size_per = num_timesteps // len(section_counts)
42 |     extra = num_timesteps % len(section_counts)
43 |     start_idx = 0
44 |     all_steps = []
45 |     for i, section_count in enumerate(section_counts):
46 |         size = size_per + (1 if i < extra else 0)
47 |         if size < section_count:
48 |             raise ValueError(
49 |                 f"cannot divide section of {size} steps into {section_count}"
50 |             )
51 |         if section_count <= 1:
52 |             frac_stride = 1
53 |         else:
54 |             frac_stride = (size - 1) / (section_count - 1)
55 |         cur_idx = 0.0
56 |         taken_steps = []
57 |         for _ in range(section_count):
58 |             taken_steps.append(start_idx + round(cur_idx))
59 |             cur_idx += frac_stride
60 |         all_steps += taken_steps
61 |         start_idx += size
62 |     return set(all_steps)
63 | 
64 | 
65 | class SpacedDiffusion(GaussianDiffusion):
66 |     """
67 |     A diffusion process which can skip steps in a base diffusion process.
68 |     :param use_timesteps: a collection (sequence or set) of timesteps from the
69 |                           original diffusion process to retain.
70 |     :param kwargs: the kwargs to create the base diffusion process.
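    Example (a sketch; `betas`, `mmt`, `mvt`, and `lt` are placeholders for
    your own schedule and enum choices):
        spaced = SpacedDiffusion(
            use_timesteps=space_timesteps(1000, "ddim50"),
            betas=betas, model_mean_type=mmt, model_var_type=mvt, loss_type=lt,
        )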
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | # print(f'Inside _WrappedModel: {ts = } and {new_ts = }') 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 
18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 
111 |         This method directly updates the reweighting without synchronizing
112 |         between workers. It is called by update_with_local_losses from all
113 |         ranks with identical arguments. Thus, it should have deterministic
114 |         behavior to maintain state across workers.
115 |         :param ts: a list of int timesteps.
116 |         :param losses: a list of float losses, one per timestep.
117 |         """
118 | 
119 | 
120 | class LossSecondMomentResampler(LossAwareSampler):
121 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 |         self.diffusion = diffusion
123 |         self.history_per_term = history_per_term
124 |         self.uniform_prob = uniform_prob
125 |         self._loss_history = np.zeros(
126 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 |         )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24; use a concrete dtype
129 | 
130 |     def weights(self):
131 |         if not self._warmed_up():
132 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 |         weights /= np.sum(weights)
135 |         weights *= 1 - self.uniform_prob
136 |         weights += self.uniform_prob / len(weights)
137 |         return weights
138 | 
139 |     def update_with_all_losses(self, ts, losses):
140 |         for t, loss in zip(ts, losses):
141 |             if self._loss_counts[t] == self.history_per_term:
142 |                 # Shift out the oldest loss term.
143 |                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 |                 self._loss_history[t, -1] = loss
145 |             else:
146 |                 self._loss_history[t, self._loss_counts[t]] = loss
147 |                 self._loss_counts[t] += 1
148 | 
149 |     def _warmed_up(self):
150 |         return (self._loss_counts == self.history_per_term).all()
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: DiT
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 | dependencies:
6 |   - python >= 3.8
7 |   - pytorch >= 1.13
8 |   - torchvision
9 |   - pytorch-cuda=11.7
10 |   - pip:
11 |     - timm
12 |     - diffusers
13 |     - accelerate
14 |     - jaxtyping  # imported unconditionally by models.py
15 |     - torchdiffeq  # only needed for GaussianDiffusion.p_sample_loop_ode
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import random 14 | from typing import Optional 15 | from contextlib import nullcontext 16 | 17 | import torch 18 | import torch.nn as nn 19 | import numpy as np 20 | from jaxtyping import Float, Shaped 21 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 22 | from torch import Tensor 23 | from torch.utils.checkpoint import checkpoint 24 | 25 | def modulate(x, shift, scale): 26 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 27 | 28 | 29 | ################################################################################# 30 | # Embedding Layers for Timesteps and Class Labels # 31 | ################################################################################# 32 | 33 | class TimestepEmbedder(nn.Module): 34 | """ 35 | Embeds scalar timesteps into vector representations. 36 | """ 37 | def __init__(self, hidden_size, frequency_embedding_size=256): 38 | super().__init__() 39 | self.mlp = nn.Sequential( 40 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 41 | nn.SiLU(), 42 | nn.Linear(hidden_size, hidden_size, bias=True), 43 | ) 44 | self.frequency_embedding_size = frequency_embedding_size 45 | 46 | @staticmethod 47 | def timestep_embedding(t, dim, max_period=10000): 48 | """ 49 | Create sinusoidal timestep embeddings. 50 | :param t: a 1-D Tensor of N indices, one per batch element. 51 | These may be fractional. 52 | :param dim: the dimension of the output. 53 | :param max_period: controls the minimum frequency of the embeddings. 54 | :return: an (N, D) Tensor of positional embeddings. 55 | """ 56 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 57 | half = dim // 2 58 | freqs = torch.exp( 59 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 60 | ).to(device=t.device) 61 | args = t[:, None].float() * freqs[None] 62 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 63 | if dim % 2: 64 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 65 | return embedding 66 | 67 | def forward(self, t): 68 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 69 | t_emb = self.mlp(t_freq) 70 | return t_emb 71 | 72 | 73 | class LabelEmbedder(nn.Module): 74 | """ 75 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 76 | """ 77 | def __init__(self, num_classes, hidden_size, dropout_prob, use_cfg_embedding: bool = True, continuous: bool = False): 78 | super().__init__() 79 | self.continuous = continuous 80 | self.num_classes = num_classes 81 | self.dropout_prob = dropout_prob 82 | if self.continuous: 83 | self.embedding_projection = nn.Linear(num_classes, hidden_size) 84 | else: 85 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 86 | 87 | def token_drop(self, labels, force_drop_ids=None): 88 | """ 89 | Drops labels to enable classifier-free guidance. 
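        Dropped discrete labels are replaced with the extra `num_classes`
        index (the null embedding used for CFG); dropped continuous labels
        are zeroed out instead.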
90 | """ 91 | if force_drop_ids is None: 92 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 93 | else: 94 | drop_ids = force_drop_ids == 1 95 | if self.continuous: 96 | labels = labels * (1 - drop_ids[:, None].to(labels)) 97 | else: 98 | labels = torch.where(drop_ids, self.num_classes, labels) 99 | return labels 100 | 101 | def forward(self, labels, train, force_drop_ids=None): 102 | use_dropout = self.dropout_prob > 0 103 | if (train and use_dropout) or (force_drop_ids is not None): 104 | labels = self.token_drop(labels, force_drop_ids) 105 | embeddings = self.embedding_projection(labels) if self.continuous else self.embedding_table(labels) 106 | return embeddings 107 | 108 | 109 | 110 | ################################################################################# 111 | # Core DiT Model # 112 | ################################################################################# 113 | 114 | class DiTBlock(nn.Module): 115 | """ 116 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 117 | """ 118 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 119 | super().__init__() 120 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 121 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 122 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 123 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 124 | approx_gelu = lambda: nn.GELU(approximate="tanh") 125 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 126 | self.adaLN_modulation = nn.Sequential( 127 | nn.SiLU(), 128 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 129 | ) 130 | 131 | def forward(self, x, c): 132 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 133 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 134 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 135 | return x 136 | 137 | 138 | class FinalLayer(nn.Module): 139 | """ 140 | The final layer of DiT. 141 | """ 142 | def __init__(self, hidden_size, patch_size, out_channels): 143 | super().__init__() 144 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 145 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 146 | self.adaLN_modulation = nn.Sequential( 147 | nn.SiLU(), 148 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 149 | ) 150 | 151 | def forward(self, x, c): 152 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 153 | x = modulate(self.norm_final(x), shift, scale) 154 | x = self.linear(x) 155 | return x 156 | 157 | 158 | class DiT(nn.Module): 159 | """ 160 | Diffusion model with a Transformer backbone. 161 | """ 162 | def __init__( 163 | self, 164 | input_size=32, 165 | patch_size=2, 166 | in_channels=4, 167 | hidden_size=1152, 168 | depth=28, 169 | num_heads=16, 170 | mlp_ratio=4.0, 171 | class_dropout_prob=0.1, 172 | num_classes=1000, 173 | learn_sigma=True, 174 | use_cfg_embedding: bool = True, 175 | use_gradient_checkpointing: bool = True, 176 | is_label_continuous: bool = False, 177 | 178 | # below are Fixed Point-specific arguments. 
179 | fixed_point: bool = False, 180 | 181 | # size 182 | fixed_point_pre_depth: int = 1, 183 | fixed_point_post_depth: int = 1, 184 | 185 | # iteration counts 186 | fixed_point_no_grad_min_iters: int = 0, 187 | fixed_point_no_grad_max_iters: int = 0, 188 | fixed_point_with_grad_min_iters: int = 28, 189 | fixed_point_with_grad_max_iters: int = 28, 190 | 191 | # solution recycle 192 | fixed_point_reuse_solution = False, 193 | 194 | # pre_post_timestep_conditioning 195 | fixed_point_pre_post_timestep_conditioning: bool = True, 196 | 197 | # adaptively distributing iterations among timesteps. Currently we only support linear distribution. 198 | adaptive: bool = False, 199 | iteration_controller = None 200 | ): 201 | super().__init__() 202 | self.learn_sigma = learn_sigma 203 | self.in_channels = in_channels 204 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 205 | self.patch_size = patch_size 206 | self.num_heads = num_heads 207 | self.use_gradient_checkpointing = use_gradient_checkpointing 208 | 209 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 210 | self.t_embedder = TimestepEmbedder(hidden_size) 211 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, 212 | use_cfg_embedding=use_cfg_embedding, continuous=is_label_continuous) 213 | num_patches = self.x_embedder.num_patches 214 | 215 | # Will use fixed sin-cos embedding: 216 | # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 217 | self.register_buffer('pos_embed', torch.zeros(1, num_patches, hidden_size)) 218 | 219 | # New: Fixed Point 220 | self.fixed_point = fixed_point 221 | if self.fixed_point: 222 | self.fixed_point_no_grad_min_iters = fixed_point_no_grad_min_iters 223 | self.fixed_point_no_grad_max_iters = fixed_point_no_grad_max_iters 224 | self.fixed_point_with_grad_min_iters = fixed_point_with_grad_min_iters 225 | self.fixed_point_with_grad_max_iters = fixed_point_with_grad_max_iters 226 | self.fixed_point_pre_post_timestep_conditioning = fixed_point_pre_post_timestep_conditioning 227 | self.blocks_pre = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_pre_depth)]) 228 | self.block_pre_projection = nn.Linear(hidden_size, hidden_size) 229 | self.block_fixed_point_projection_fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size) 230 | self.block_fixed_point_projection_act = nn.GELU(approximate="tanh") 231 | self.block_fixed_point_projection_fc2 = nn.Linear(2 * hidden_size, hidden_size) 232 | self.block_fixed_point = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) 233 | self.blocks_post = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_post_depth)]) 234 | self.blocks = [*self.blocks_pre, self.block_fixed_point, *self.blocks_post] 235 | self.fixed_point_reuse_solution = fixed_point_reuse_solution 236 | self.last_solution = None 237 | 238 | self.adaptive = adaptive 239 | self.iteration_controller = iteration_controller 240 | else: 241 | self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]) 242 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 243 | self.initialize_weights() 244 | 245 | def initialize_weights(self): 246 | # Initialize transformer layers: 247 | def _basic_init(module): 248 | if isinstance(module, nn.Linear): 249 | torch.nn.init.xavier_uniform_(module.weight) 250 | if module.bias is not None: 251 | 
nn.init.constant_(module.bias, 0) 252 | self.apply(_basic_init) 253 | 254 | # Initialize (and freeze) pos_embed by sin-cos embedding: 255 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 256 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 257 | 258 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 259 | w = self.x_embedder.proj.weight.data 260 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 261 | nn.init.constant_(self.x_embedder.proj.bias, 0) 262 | 263 | # Initialize label embedding table: 264 | if self.y_embedder.continuous: 265 | nn.init.normal_(self.y_embedder.embedding_projection.weight, std=0.02) 266 | nn.init.constant_(self.y_embedder.embedding_projection.bias, 0) 267 | else: 268 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 269 | 270 | # Initialize timestep embedding MLP: 271 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 272 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 273 | 274 | # Zero-out adaLN modulation layers in DiT blocks: 275 | for block in self.blocks: 276 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 277 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 278 | 279 | # Zero-out output layers: 280 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 281 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 282 | nn.init.constant_(self.final_layer.linear.weight, 0) 283 | nn.init.constant_(self.final_layer.linear.bias, 0) 284 | 285 | def unpatchify(self, x): 286 | """ 287 | x: (N, T, patch_size**2 * C) 288 | imgs: (N, H, W, C) 289 | """ 290 | c = self.out_channels 291 | p = self.x_embedder.patch_size[0] 292 | h = w = int(x.shape[1] ** 0.5) 293 | assert h * w == x.shape[1] 294 | 295 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 296 | x = torch.einsum('nhwpqc->nchpwq', x) 297 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 298 | return imgs 299 | 300 | def ckpt_wrapper(self, module): 301 | def ckpt_forward(*inputs): 302 | outputs = module(*inputs) 303 | return outputs 304 | return ckpt_forward 305 | 306 | def _forward_dit(self, x, t, y): 307 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 308 | t = self.t_embedder(t) # (N, D) 309 | y = self.y_embedder(y, self.training) # (N, D)) 310 | c = t + y # (N, D) 311 | for block in self.blocks: 312 | x = checkpoint(self.ckpt_wrapper(block), x, c) if self.use_gradient_checkpointing else block(x, c) # (N, T, D) 313 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 314 | x = self.unpatchify(x) # (N, out_channels, H, W) 315 | return x 316 | 317 | def _forward_fixed_point_blocks( 318 | self, x: Float[Tensor, "b t d"], x_input_injection: Float[Tensor, "b t d"], c: Float[Tensor, "b d"], num_iterations: int 319 | ) -> Float[Tensor, "b t d"]: 320 | def forward_pass(x): 321 | x = torch.cat((x, x_input_injection), dim=-1) # (N, T, D * 2) 322 | x = self.block_fixed_point_projection_fc1(x) # (N, T, D * 2) 323 | x = self.block_fixed_point_projection_act(x) # (N, T, D * 2) 324 | x = self.block_fixed_point_projection_fc2(x) # (N, T, D) 325 | x = self.block_fixed_point(x, c) # (N, T, D) 326 | return x 327 | 328 | if self.adaptive: 329 | num_iterations = self.iteration_controller.get() 330 | 331 | for _ in range(num_iterations): 332 | x = forward_pass(x) 333 | return x 334 | 335 | def _check_inputs(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> 
335 |     def _check_inputs(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> None:
336 |         if self.fixed_point_reuse_solution:
337 |             if not torch.all(t[0] == t).item():
338 |                 raise ValueError(f'fixed_point_reuse_solution requires all timesteps in a batch to be equal, but got t = {t}')
339 | 
340 |     def _forward_fixed_point(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> Float[Tensor, "b c h w"]:
341 |         self._check_inputs(x, t, y)
342 |         x: Float[Tensor, "b t d"] = self.x_embedder(x) + self.pos_embed
343 |         t_emb: Float[Tensor, "b d"] = self.t_embedder(t)
344 |         y: Float[Tensor, "b d"] = self.y_embedder(y, self.training)
345 |         c: Float[Tensor, "b d"] = t_emb + y
346 |         c_pre_post_fixed_point: Float[Tensor, "b d"] = (t_emb + y) if self.fixed_point_pre_post_timestep_conditioning else y
347 | 
348 |         # Pre-Fixed Point
349 |         # Note: If using DDP with find_unused_parameters=True, checkpoint causes issues. For more
350 |         # information, see https://github.com/allenai/longformer/issues/63#issuecomment-648861503
351 |         for block in self.blocks_pre:
352 |             x: Float[Tensor, "b t d"] = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)
353 |         condition = x.clone()
354 | 
355 |         # Whether to reuse the solution from the previous timestep as the initial solution
356 |         init_solution = self.last_solution if (self.fixed_point_reuse_solution and self.last_solution is not None) else x.clone()
357 | 
358 |         # Fixed Point (we have condition and init_solution)
359 |         x_input_injection = self.block_pre_projection(condition)
360 | 
361 |         # NOTE: This section of code should use no_grad, but it cannot due to a DDP bug. See
362 |         # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
363 |         # for more information
364 |         with nullcontext():  # we use x.detach() in place of torch.no_grad due to the DDP issue
365 |             num_iterations_no_grad = random.randint(self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters)
366 |             x = self._forward_fixed_point_blocks(x=init_solution.detach(), x_input_injection=x_input_injection.detach(), c=c, num_iterations=num_iterations_no_grad)
367 |             x = x.detach()  # no grad
368 |         num_iterations_with_grad = random.randint(self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters)
369 |         x = self._forward_fixed_point_blocks(x=x, x_input_injection=x_input_injection, c=c, num_iterations=num_iterations_with_grad)
370 | 
371 |         # Save solution for reuse at the next step
372 |         if self.fixed_point_reuse_solution:
373 |             self.last_solution = x.clone()
374 | 
375 |         # Post-Fixed Point
376 |         for block in self.blocks_post:
377 |             x = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)
378 | 
379 |         # Output
380 |         x: Float[Tensor, "b t p2c"] = self.final_layer(x, c_pre_post_fixed_point)  # p2c = patch_size ** 2 * out_channels
381 |         x: Float[Tensor, "b c h w"] = self.unpatchify(x)
382 |         return x
383 | 
384 |     def reset(self):
385 |         self.last_solution = None
386 | 
387 |     def forward(self, x, t, y):
388 |         """
389 |         Forward pass of DiT.
390 |         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
391 |         t: (N,) tensor of diffusion timesteps
392 |         y: (N,) tensor of class labels
393 |         """
394 |         if self.fixed_point:
395 |             return self._forward_fixed_point(x, t, y)
396 |         else:
397 |             return self._forward_dit(x, t, y)
398 | 
399 |     def forward_with_cfg(self, x, t, y, cfg_scale):
400 |         """
401 |         Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
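        Applies the guidance rule eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond) to the
        predicted noise, computing the conditional and unconditional passes together in one batch.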
402 | """ 403 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 404 | half = x[: len(x) // 2] 405 | combined = torch.cat([half, half], dim=0) 406 | model_out = self.forward(combined, t, y) 407 | # For exact reproducibility reasons, we apply classifier-free guidance on only 408 | # three channels by default. The standard approach to cfg applies it to all channels. 409 | # This can be done by uncommenting the following line and commenting-out the line following that. 410 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 411 | eps, rest = model_out[:, :3], model_out[:, 3:] 412 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 413 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 414 | eps = torch.cat([half_eps, half_eps], dim=0) 415 | return torch.cat([eps, rest], dim=1) 416 | 417 | 418 | ################################################################################# 419 | # Sine/Cosine Positional Embedding Functions # 420 | ################################################################################# 421 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 422 | 423 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 424 | """ 425 | grid_size: int of the grid height and width 426 | return: 427 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 428 | """ 429 | grid_h = np.arange(grid_size, dtype=np.float32) 430 | grid_w = np.arange(grid_size, dtype=np.float32) 431 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 432 | grid = np.stack(grid, axis=0) 433 | 434 | grid = grid.reshape([2, 1, grid_size, grid_size]) 435 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 436 | if cls_token and extra_tokens > 0: 437 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 438 | return pos_embed 439 | 440 | 441 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 442 | assert embed_dim % 2 == 0 443 | 444 | # use half of dimensions to encode grid_h 445 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 446 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 447 | 448 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 449 | return emb 450 | 451 | 452 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 453 | """ 454 | embed_dim: output dimension for each position 455 | pos: a list of positions to be encoded: size (M,) 456 | out: (M, D) 457 | """ 458 | assert embed_dim % 2 == 0 459 | omega = np.arange(embed_dim // 2, dtype=np.float64) 460 | omega /= embed_dim / 2. 461 | omega = 1. 
/ 10000**omega # (D/2,) 462 | 463 | pos = pos.reshape(-1) # (M,) 464 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 465 | 466 | emb_sin = np.sin(out) # (M, D/2) 467 | emb_cos = np.cos(out) # (M, D/2) 468 | 469 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 470 | return emb 471 | 472 | 473 | ################################################################################# 474 | # DiT Configs # 475 | ################################################################################# 476 | 477 | def DiT_XL_2(**kwargs): 478 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 479 | 480 | def DiT_XL_4(**kwargs): 481 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 482 | 483 | def DiT_XL_8(**kwargs): 484 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 485 | 486 | def DiT_L_2(**kwargs): 487 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 488 | 489 | def DiT_L_4(**kwargs): 490 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 491 | 492 | def DiT_L_8(**kwargs): 493 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 494 | 495 | def DiT_B_2(**kwargs): 496 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 497 | 498 | def DiT_B_4(**kwargs): 499 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 500 | 501 | def DiT_B_8(**kwargs): 502 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 503 | 504 | def DiT_S_2(**kwargs): 505 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 506 | 507 | def DiT_S_4(**kwargs): 508 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 509 | 510 | def DiT_S_8(**kwargs): 511 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 512 | 513 | 514 | DiT_models = { 515 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 516 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 517 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 518 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 519 | } 520 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import torch 7 | from accelerate import Accelerator 8 | from accelerate.utils import set_seed 9 | from diffusers.models import AutoencoderKL 10 | from PIL import Image 11 | from tap import Tap 12 | from tqdm import trange 13 | 14 | 15 | from diffusion import create_diffusion 16 | from download import find_model 17 | from models import DiT_models 18 | from adaptive_controller import LinearController 19 | 20 | 21 | class Args(Tap): 22 | 23 | # Paths 24 | output_dir: str = 'samples' 25 | 26 | # Dataset 27 | dataset_name: str = "imagenet256" 28 | 29 | # Model 30 | model: str = "DiT-XL/2" 31 | vae: str = "mse" 32 | num_classes: int = 1000 33 | image_size: int = 256 34 | predict_v: bool = False 35 | use_zero_terminal_snr: bool = False 36 | unsupervised: bool = False 37 | dino_supervised: bool = False 38 | dino_supervised_dim: int = 768 39 | flow: bool = False 40 | debug: bool = False 41 | 42 | # Fixed Point settings 43 | fixed_point: bool = False 44 | fixed_point_pre_depth: int = 2 45 | fixed_point_post_depth: int = 2 46 | 
fixed_point_iters: Optional[int] = None
47 |     fixed_point_pre_post_timestep_conditioning: bool = False
48 |     fixed_point_reuse_solution: bool = False
49 | 
50 |     # Sampling
51 |     ddim: bool = False
52 |     cfg_scale: float = 4.0
53 |     num_sampling_steps: int = 250
54 |     batch_size: int = 32
55 |     ckpt: str = '...'
56 |     global_seed: int = 0
57 | 
58 |     # Parallelization
59 |     sample_index_start: int = 0
60 |     sample_index_end: Optional[int] = 50_000
61 | 
62 |     # Adaptive
63 |     adaptive: bool = False
64 |     adaptive_type: str = "increasing"  # currently supports "increasing", "fixed", and "decreasing"
65 | 
66 |     def process_args(self) -> None:
67 |         """Additional argument processing"""
68 |         if self.debug:
69 |             self.log_with = 'tensorboard'
70 |             self.name = 'debug'
71 | 
72 |         # Defaults
73 |         self.fixed_point_iters = self.fixed_point_iters or (28 - self.fixed_point_pre_depth - self.fixed_point_post_depth)
74 | 
75 |         # Checks
76 |         if self.cfg_scale < 1.0:
77 |             raise ValueError("In almost all cases, cfg_scale should be >= 1.0")
78 |         if self.unsupervised:
79 |             assert self.cfg_scale == 1.0
80 |             self.num_classes = 1
81 |         elif self.dino_supervised:
82 |             raise NotImplementedError()
83 |         if not Path(self.ckpt).is_file():
84 |             raise ValueError(f'Checkpoint not found: {self.ckpt}')
85 | 
86 |         # Create output directory
87 |         output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name
88 |         if self.debug:
89 |             output_dirname = 'debug'
90 |         else:
91 |             output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'
92 |             if self.fixed_point:
93 |                 output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'
94 |             if self.ddim:
95 |                 output_dirname += '--ddim'
96 |         self.output_dir = str(output_parent / output_dirname)
97 |         Path(self.output_dir).mkdir(exist_ok=True, parents=True)
98 | 
99 |         self.iteration_controller = None  # default when adaptive iteration allocation is disabled
100 |         if self.adaptive:
101 |             self.budget = self.num_sampling_steps * self.fixed_point_iters
102 |             if self.adaptive_type not in ("increasing", "decreasing", "fixed"):
103 |                 raise NotImplementedError(f'Unsupported adaptive_type: {self.adaptive_type}')
104 |             self.iteration_controller = LinearController(self.budget, self.num_sampling_steps, type=self.adaptive_type)
105 | 
106 | def main(args: Args):
107 | 
108 |     # Setup accelerator, logging, randomness
109 |     accelerator = Accelerator()
110 |     set_seed(args.global_seed + args.sample_index_start)
111 | 
112 |     # Load model
113 |     assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
114 |     latent_size = H_lat = W_lat = args.image_size // 8
115 |     # print(f"!!! 
Pre model init, fixed_point_reuse_solution: {args.fixed_point_reuse_solution}") 116 | model = DiT_models[args.model]( 117 | input_size=latent_size, 118 | num_classes=(args.dino_supervised_dim if args.dino_supervised else args.num_classes), 119 | is_label_continuous=args.dino_supervised, 120 | class_dropout_prob=0, 121 | learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models 122 | use_gradient_checkpointing=False, 123 | fixed_point=args.fixed_point, 124 | fixed_point_pre_depth=args.fixed_point_pre_depth, 125 | fixed_point_post_depth=args.fixed_point_post_depth, 126 | fixed_point_no_grad_min_iters=0, 127 | fixed_point_no_grad_max_iters=0, 128 | fixed_point_with_grad_min_iters=args.fixed_point_iters, 129 | fixed_point_with_grad_max_iters=args.fixed_point_iters, 130 | fixed_point_reuse_solution=args.fixed_point_reuse_solution, 131 | fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning, 132 | adaptive=args.adaptive, 133 | iteration_controller=args.iteration_controller, 134 | ).to(accelerator.device) 135 | print(f'Loaded model with params: {sum(p.numel() for p in model.parameters()):_}') 136 | 137 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py 138 | state_dict = find_model(args.ckpt) 139 | model.load_state_dict(state_dict) 140 | model.eval() 141 | diffusion = create_diffusion( 142 | str(args.num_sampling_steps), 143 | use_flow=args.flow, 144 | predict_v=args.predict_v, 145 | use_zero_terminal_snr=args.use_zero_terminal_snr, 146 | ) 147 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(accelerator.device).eval() 148 | using_cfg = args.cfg_scale > 1.0 149 | 150 | # Generate pseudorandom class labels and noises. Note that these are generated using the global 151 | # seed, which is shared between all processes. As a result, all processes will generate the same 152 | # list of class labels and noises. Then, we take a subset of these based on the `process_index` 153 | # of the current process. 
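    # The sharding below relies on torch.Generator determinism: reseeding with the same value
    # reproduces the same tensor on every process, so each process can slice a disjoint shard
    # out of identical "global" tensors. A minimal standalone sketch (illustrative values only):
    #
    #     g = torch.Generator().manual_seed(0)
    #     full = torch.randn(10, generator=g)
    #     g.manual_seed(0)
    #     assert torch.equal(full, torch.randn(10, generator=g))  # identical on every process
    #     shard = full[2:5]  # each worker takes its own [start:end] slice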
154 |     N = 50_000  # this assumes we will never sample more than 50K samples, which I think is reasonable
155 |     generator = torch.Generator(device=accelerator.device)
156 |     generator.manual_seed(args.global_seed)
157 |     class_labels = torch.randint(0, args.num_classes, size=(N,), device=accelerator.device, generator=generator)
158 |     generator.manual_seed(args.global_seed)
159 |     latents = torch.randn(N, model.in_channels, H_lat, W_lat, device=accelerator.device, generator=generator)
160 |     class_labels = class_labels[args.sample_index_start:args.sample_index_end]
161 |     latents = latents[args.sample_index_start:args.sample_index_end]
162 |     indices = list(range(args.sample_index_start, args.sample_index_end))
163 |     print(f'Using pseudorandom class labels and latents (start={args.sample_index_start} and end={args.sample_index_end})')
164 | 
165 |     # Create output path
166 |     output_dir = Path(args.output_dir)
167 |     args.save(output_dir / 'args.json')
168 |     print(f'Saving samples to {output_dir.resolve()}')
169 | 
170 |     # Load class labels for helpful filenames
171 |     if args.dataset_name == 'imagenet256':
172 |         with open("utils/imagenet-labels.json", "r") as f:
173 |             label_names: list[str] = json.load(f)
174 |         label_names = [l.lower().replace(' ', '-').replace('\'', '') for l in label_names]
175 |     elif args.unsupervised:
176 |         assert args.cfg_scale == 1.0
177 |         label_names = ["unlabeled"]
178 |     else:
179 |         raise NotImplementedError()
180 | 
181 |     # Disable gradient
182 |     with torch.inference_mode():
183 | 
184 |         # Sample loop
185 |         num_batches = math.ceil(len(class_labels) / args.batch_size)
186 |         for batch_idx in trange(num_batches, disable=(not accelerator.is_main_process)):
187 | 
188 |             # Get pre-sampled inputs
189 |             z = latents[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
190 |             y = class_labels[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
191 |             idxs = indices[batch_idx*args.batch_size:(batch_idx + 1)*args.batch_size]
192 |             output_paths = [output_dir / f'{idx:05d}--{y_i:03d}--{label_names[y_i]}.png' for y_i, idx in zip(y.tolist(), idxs)]
193 | 
194 |             # Skip files that already exist
195 |             if all(output_path.is_file() for output_path in output_paths):
196 |                 print(f'Files already exist (batch {batch_idx}). Skipping.')
197 |                 continue
198 | 
199 |             # Setup classifier-free guidance
200 |             if using_cfg:
201 |                 y_null = torch.tensor([1000] * len(y), device=accelerator.device)  # 1000 is the null-class index; len(y) handles a smaller final batch
202 |                 y = torch.cat([y, y_null], 0)
203 |                 z = torch.cat([z, z], 0)
204 |                 model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
205 |                 sample_fn = model.forward_with_cfg
206 |             else:
207 |                 model_kwargs = dict(y=y)
208 |                 sample_fn = model.forward
209 | 
210 |             if args.adaptive:
211 |                 model.iteration_controller.init_image()
212 | 
213 |             # Sample latent images
214 |             sample_kwargs = dict(model=sample_fn, shape=z.shape, noise=z, clip_denoised=False, model_kwargs=model_kwargs,
215 |                                  progress=False, device=accelerator.device)
216 |             if args.ddim:
217 |                 samples = diffusion.ddim_sample_loop(**sample_kwargs)
218 |             else:
219 |                 samples = diffusion.p_sample_loop(**sample_kwargs)
220 | 
221 |             if using_cfg:
222 |                 samples, _ = samples.chunk(2, dim=0)
223 | 
224 |             # Reset model (resets the initial solution to None)
225 |             model.reset()
226 | 
227 |             # Decode latents
228 |             samples = vae.decode(samples / vae.config.scaling_factor).sample
229 |             samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
230 | 
231 |             # Save samples to disk as individual .png files
232 |             for sample, output_path in zip(samples, output_paths):
233 |                 Image.fromarray(sample).save(output_path)
234 | 
235 | 
236 | if __name__ == "__main__":
237 |     args = Args(explicit_bool=True).parse_args()
238 |     main(args)
239 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from collections import OrderedDict
4 | from copy import deepcopy
5 | from glob import glob
6 | from pathlib import Path
7 | from time import time
8 | from typing import Callable, Optional
9 | 
10 | import numpy as np
11 | import torch
12 | from accelerate import Accelerator, DistributedDataParallelKwargs
13 | from accelerate.utils import ProjectConfiguration
14 | from PIL import Image
15 | from tap import Tap
16 | from torch.utils.data import DataLoader, Dataset
17 | from tqdm import tqdm
18 | 
19 | from diffusion import create_diffusion
20 | from models import DiT_models
21 | 
22 | try:
23 |     from streaming import StreamingDataset
24 | except ImportError:
25 |     StreamingDataset = Dataset
26 | 
27 | 
28 | class Args(Tap):
29 | 
30 |     # Paths
31 |     feature_path: Optional[str] = None
32 |     dataset_name: str = "imagenet256"
33 |     name: Optional[str] = None
34 |     output_dir: str = "results"
35 |     output_subdir: str = "runs"
36 | 
37 |     # Model
38 |     model: str = "DiT-XL/2"
39 |     num_classes: int = 1000
40 |     image_size: int = 256
41 |     predict_v: bool = False
42 |     use_zero_terminal_snr: bool = False
43 |     unsupervised: bool = False
44 |     dino_supervised: bool = False
45 |     dino_supervised_dim: int = 768
46 |     flow: bool = False
47 |     fixed_point: bool = False
48 |     fixed_point_pre_depth: int = 1
49 |     fixed_point_post_depth: int = 1
50 |     fixed_point_no_grad_min_iters: int = 0
51 |     fixed_point_no_grad_max_iters: int = 10
52 |     fixed_point_with_grad_min_iters: int = 1
53 |     fixed_point_with_grad_max_iters: int = 12
54 |     fixed_point_pre_post_timestep_conditioning: bool = False
55 | 
56 |     # Training
57 |     epochs: int = 1400
58 |     global_batch_size: int = 512
59 |     global_seed: int = 0
60 |     num_workers: int = 4
61 |     log_every: int = 100
62 |     ckpt_every: int = 100_000
63 |     lr: float = 1e-4
64 |     log_with: str = "wandb"
65 |     resume: 
Optional[str] = None 66 | # use_streaming_dataset: bool = False 67 | compile: bool = False 68 | debug: bool = False 69 | 70 | def process_args(self) -> None: 71 | """Additional argument processing""" 72 | if self.debug: 73 | self.log_with = 'tensorboard' 74 | self.output_subdir = 'debug' 75 | 76 | # Auto-generated name 77 | if self.name is None: 78 | experiment_index = len(glob(os.path.join(self.output_dir, self.output_subdir, "*"))) 79 | model_string_name = self.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders) 80 | model_string_name += '--flow' if self.flow else '--diff' 81 | model_string_name += '--unsupervised' if self.unsupervised else '--dino_supervised' if self.dino_supervised else '' 82 | model_string_name += f'--{self.lr:f}' 83 | model_string_name += f'--v' if self.predict_v else '' 84 | model_string_name += f'--zero_snr' if self.use_zero_terminal_snr else '' 85 | ngmin, ngmax = self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters 86 | gmin, gmax = self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters 87 | model_string_name += (f'--fixed_point-pre_depth-{self.fixed_point_pre_depth}-post_depth-{self.fixed_point_post_depth}' + 88 | f'-no_grad_iters-{ngmin:02d}-{ngmax:02d}-with_grad_iters-{gmin:02d}-{gmax:02d}' + 89 | f'-pre_post_time_cond_{self.fixed_point_pre_post_timestep_conditioning}' 90 | if self.fixed_point else '--dit') 91 | self.name = f'{experiment_index:03d}-{model_string_name}' 92 | 93 | # Copy data to scratch 94 | if self.feature_path is None: 95 | assert os.getenv('SLURM_JOBID', None) is not None 96 | os.environ['TMPDIR'] = TMPDIR = os.path.join('/opt/dlami/nvme/slurm_tmpdir', os.getenv('SLURM_JOBID')) 97 | self.feature_path = os.path.join('/opt/dlami/nvme/slurm_tmpdir', os.getenv('SLURM_JOBID')) 98 | features_dir = f"{self.feature_path}/{self.dataset_name}_features" 99 | labels_dir = f"{self.feature_path}/{self.dataset_name}_{'dino_vitb8' if self.dino_supervised else 'labels'}" 100 | assert Path(features_dir).is_dir() == Path(labels_dir).is_dir() 101 | if Path(features_dir).is_dir(): 102 | print(f'Features already exist in {TMPDIR}') 103 | else: 104 | start = time() 105 | print(f'Copying features to {TMPDIR}') 106 | copy_cmd_1 = f'cp ./features/{self.dataset_name}_npy.tar {TMPDIR}' 107 | copy_cmd_2 = f'tar xf {os.path.join(TMPDIR, self.dataset_name)}_npy.tar -C {TMPDIR}' 108 | print(copy_cmd_1) 109 | os.system(copy_cmd_1) 110 | print(copy_cmd_2) 111 | os.system(copy_cmd_2) 112 | print(f'Finished copying features to {TMPDIR} in {time() - start:.2f}s') 113 | 114 | # Create output directory 115 | self.output_dir = os.path.join(self.output_dir, self.output_subdir, self.name) 116 | Path(self.output_dir).mkdir(exist_ok=True, parents=True) 117 | 118 | ################################################################################# 119 | # Training Helper Functions # 120 | ################################################################################# 121 | 122 | @torch.no_grad() 123 | def update_ema(ema_model, model, decay=0.9999): 124 | """ 125 | Step the EMA model towards the current model. 
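    Concretely, each EMA parameter is updated in place as
        ema_param <- decay * ema_param + (1 - decay) * model_param,
    so the default decay of 0.9999 averages over a horizon of roughly 1 / (1 - decay) = 10,000 steps.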
126 | """ 127 | ema_params = OrderedDict(ema_model.named_parameters()) 128 | model_params = OrderedDict(model.named_parameters()) 129 | 130 | for name, param in model_params.items(): 131 | name = name.replace("module.", "") 132 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 133 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 134 | 135 | 136 | def requires_grad(model, flag=True): 137 | """ 138 | Set requires_grad flag for all parameters in a model. 139 | """ 140 | for p in model.parameters(): 141 | p.requires_grad = flag 142 | 143 | 144 | def create_logger(logging_dir): 145 | """ 146 | Create a logger that writes to a log file and stdout. 147 | """ 148 | logging.basicConfig( 149 | level=logging.INFO, 150 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 151 | datefmt='%Y-%m-%d %H:%M:%S', 152 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 153 | ) 154 | logger = logging.getLogger(__name__) 155 | return logger 156 | 157 | 158 | def center_crop_arr(pil_image, image_size): 159 | """ 160 | Center cropping implementation from ADM. 161 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 162 | """ 163 | while min(*pil_image.size) >= 2 * image_size: 164 | pil_image = pil_image.resize( 165 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 166 | ) 167 | 168 | scale = image_size / min(*pil_image.size) 169 | pil_image = pil_image.resize( 170 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 171 | ) 172 | 173 | arr = np.array(pil_image) 174 | crop_y = (arr.shape[0] - image_size) // 2 175 | crop_x = (arr.shape[1] - image_size) // 2 176 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 177 | 178 | 179 | class CustomDataset(Dataset): 180 | def __init__(self, features_dir, labels_dir): 181 | self.features_dir = features_dir 182 | self.labels_dir = labels_dir 183 | 184 | self.features_files = os.listdir(features_dir) 185 | self.labels_files = os.listdir(labels_dir) 186 | 187 | def __len__(self): 188 | assert len(self.features_files) == len(self.labels_files), \ 189 | "Number of feature files and label files should be same" 190 | return len(self.features_files) 191 | 192 | def __getitem__(self, idx): 193 | feature_file = self.features_files[idx] 194 | label_file = self.labels_files[idx] 195 | 196 | features = np.load(os.path.join(self.features_dir, feature_file)) 197 | labels = np.load(os.path.join(self.labels_dir, label_file)) 198 | return torch.from_numpy(features), torch.from_numpy(labels) 199 | 200 | 201 | class CustomStreamingDataset(StreamingDataset): 202 | def __init__( 203 | self, 204 | local: str, 205 | remote: Optional[str] = None, 206 | shuffle: bool = False, 207 | batch_size: int = 1, 208 | transform: Optional[Callable] = None, 209 | ): 210 | remote = local if remote is None else remote 211 | super().__init__(remote=remote, local=local, shuffle=shuffle, batch_size=batch_size) 212 | self.transform = transform 213 | 214 | def __getitem__(self, idx): 215 | item = super().__getitem__(idx) 216 | feats = item['features'].squeeze(0) 217 | label = item['class'] 218 | if self.transform is not None: 219 | feats = self.transform(feats) 220 | return feats, label 221 | 222 | 223 | ################################################################################# 224 | # Training Loop # 225 | 
################################################################################# 226 | 227 | def main(args: Args): 228 | """ 229 | Trains a new DiT model. 230 | """ 231 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 232 | 233 | # Setup an experiment folder: 234 | checkpoint_dir = f"{args.output_dir}/checkpoints" # Stores saved model checkpoints 235 | os.makedirs(checkpoint_dir, exist_ok=True) 236 | args.save(f"{args.output_dir}/args.json") 237 | 238 | # Setup accelerator: 239 | find_unused_parameters = False # args.fixed_point and args.fixed_point_b_solver != 'backprop' 240 | print(f'Using {find_unused_parameters = }') 241 | accelerator_ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters) 242 | accelerator = Accelerator( 243 | log_with=args.log_with, 244 | project_config=ProjectConfiguration(project_dir=args.output_dir), 245 | kwargs_handlers=[accelerator_ddp_kwargs], 246 | dynamo_backend=("inductor" if args.compile else None), 247 | ) 248 | device = accelerator.device 249 | 250 | # Create trackers 251 | if accelerator.is_main_process: 252 | accelerator.init_trackers("dit", config=args.as_dict(), init_kwargs={"wandb": {"name": args.name}}) 253 | if args.log_with == 'wandb': 254 | accelerator.get_tracker("wandb", unwrap=True).log_code() 255 | print(args) 256 | 257 | # Create model 258 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 259 | latent_size = args.image_size // 8 260 | model = DiT_models[args.model]( 261 | input_size=latent_size, 262 | num_classes=(1 if args.unsupervised else (args.dino_supervised_dim if args.dino_supervised else args.num_classes)), 263 | is_label_continuous=args.dino_supervised, 264 | class_dropout_prob=(0.0 if args.unsupervised else 0.1), 265 | learn_sigma=(not args.flow), # TODO: Implement learned variance for flow-based models 266 | use_gradient_checkpointing=(not args.compile and not find_unused_parameters), 267 | fixed_point=args.fixed_point, 268 | fixed_point_pre_depth=args.fixed_point_pre_depth, 269 | fixed_point_post_depth=args.fixed_point_post_depth, 270 | fixed_point_no_grad_min_iters=args.fixed_point_no_grad_min_iters, 271 | fixed_point_no_grad_max_iters=args.fixed_point_no_grad_max_iters, 272 | fixed_point_with_grad_min_iters=args.fixed_point_with_grad_min_iters, 273 | fixed_point_with_grad_max_iters=args.fixed_point_with_grad_max_iters, 274 | fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning, 275 | ).to(device) 276 | print(f'Loaded model with params: {sum(p.numel() for p in model.parameters()):_}') 277 | 278 | # Note that parameter initialization is done within the DiT constructor 279 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 280 | requires_grad(ema, False) 281 | diffusion = create_diffusion( 282 | timestep_respacing="", 283 | use_flow=args.flow, 284 | predict_v=args.predict_v, 285 | use_zero_terminal_snr=args.use_zero_terminal_snr, 286 | ) # default: 1000 steps, linear noise schedule 287 | # # Note: the VAE is not used because we assume all images are already preprocessed 288 | # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 289 | print(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 290 | 291 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 292 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0) 293 | 294 
| # Setup data: 295 | batch_size = int(args.global_batch_size // accelerator.num_processes) 296 | # if args.use_streaming_dataset: 297 | # data_dir = f"{args.feature_path}/{args.dataset_name}_streaming" 298 | # dataset = CustomStreamingDataset(data_dir, shuffle=True, batch_size=batch_size) 299 | # load_kwargs = dict() 300 | # else: 301 | features_dir = f"{args.feature_path}/{args.dataset_name}_features" 302 | labels_dir = f"{args.feature_path}/{args.dataset_name}_{'dino_vitb8' if args.dino_supervised else 'labels'}" 303 | dataset = CustomDataset(features_dir, labels_dir) 304 | load_kwargs = dict(shuffle=True, pin_memory=True, drop_last=True) 305 | loader = DataLoader( 306 | dataset, batch_size=batch_size, num_workers=args.num_workers, **load_kwargs 307 | ) 308 | print(f"Dataset contains {len(dataset):,} images ({args.feature_path})") 309 | 310 | # Load from checkpoint 311 | train_steps = 0 312 | if args.resume is not None: 313 | checkpoint: dict = torch.load(args.resume, map_location='cpu') 314 | model.load_state_dict(checkpoint['model']) 315 | ema.load_state_dict(checkpoint['ema']) 316 | opt.load_state_dict(checkpoint['opt']) 317 | train_steps = checkpoint['train_steps'] if 'train_steps' in checkpoint else int(Path(args.resume).stem) 318 | print(f'Resuming from checkpoint: {args.resume}') 319 | 320 | # Prepare models for training: 321 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights 322 | model.train() # important! This enables embedding dropout for classifier-free guidance 323 | ema.eval() # EMA model should always be in eval mode 324 | model, opt, loader = accelerator.prepare(model, opt, loader) 325 | 326 | # Train 327 | log_steps = 0 328 | running_loss = 0 329 | steps_per_sec = 0 330 | start_time = time() 331 | progress_bar = tqdm() 332 | print(f"Training for {args.epochs} epochs...") 333 | for epoch in range(args.epochs): 334 | if accelerator.is_main_process: 335 | print(f"Beginning epoch {epoch}...") 336 | for x, y in loader: 337 | x = x.to(device) 338 | y = y.to(device) 339 | x = x.squeeze(dim=1) 340 | y = y.squeeze(dim=-1) 341 | if args.unsupervised: # replace class labels with zeros 342 | y = torch.zeros_like(y) 343 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) 344 | model_kwargs = dict(y=y) 345 | loss_dict = diffusion.training_losses(model, x, t, model_kwargs) 346 | loss = loss_dict["loss"].mean() 347 | loss_float = loss.item() 348 | opt.zero_grad() 349 | accelerator.backward(loss) 350 | if train_steps < 5 and accelerator.is_main_process: # debug 351 | print(f'[Step {train_steps}] Params total: {sum(p.numel() for p in model.parameters()):_}') 352 | print(f'[Step {train_steps}] Params req. 
grad: {sum(p.numel() for p in model.parameters() if p.requires_grad):_}')
353 |                 print(f'[Step {train_steps}] Params with grad: {sum(p.numel() for p in model.parameters() if p.requires_grad and p.grad is not None):_}')
354 |             opt.step()
355 |             update_ema(ema, model)
356 | 
357 |             # Log every few steps
358 |             if train_steps % 5 == 0 and accelerator.is_main_process:
359 |                 accelerator.log({
360 |                     "train/step": train_steps, "train/loss": loss_float,
361 |                 }, step=train_steps)
362 |                 progress_bar.set_description_str(f"train/step: {train_steps}, train/steps_per_sec: {steps_per_sec:.2f}, train/loss: {loss_float:.4f}")
363 | 
364 |             # Print periodically
365 |             running_loss += loss_float
366 |             log_steps += 1
367 |             train_steps += 1
368 |             progress_bar.update()
369 |             if train_steps % args.log_every == 0:
370 |                 # Measure training speed:
371 |                 torch.cuda.synchronize()
372 |                 end_time = time()
373 |                 steps_per_sec = log_steps / (end_time - start_time)
374 |                 # Reduce (average) the running loss over all processes:
375 |                 avg_loss = torch.tensor(running_loss / log_steps, device=device)
376 |                 avg_loss = accelerator.reduce(avg_loss, reduction="mean").item()
377 |                 print(f"\n(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
378 |                 accelerator.log({"train/steps_per_sec": steps_per_sec}, step=train_steps)  # also log steps per second
379 |                 # Reset monitoring variables:
380 |                 running_loss = 0
381 |                 log_steps = 0
382 |                 start_time = time()
383 | 
384 |             # Save DiT checkpoint:
385 |             if train_steps > 0 and accelerator.is_main_process and (train_steps % 5000 == 0 or train_steps % args.ckpt_every == 0):
386 |                 checkpoint = {
387 |                     "model": model.module.state_dict(),
388 |                     "ema": ema.state_dict(),
389 |                     "opt": opt.state_dict(),
390 |                     "args": args.as_dict(),
391 |                     "train_steps": train_steps,
392 |                 }
393 |                 if train_steps % 5000 == 0:
394 |                     checkpoint_path = f"{checkpoint_dir}/latest.pt"
395 |                     torch.save(checkpoint, checkpoint_path)
396 |                 if train_steps % args.ckpt_every == 0:
397 |                     checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
398 |                     torch.save(checkpoint, checkpoint_path)
399 |                     print(f"Saved checkpoint to {checkpoint_path}")
400 | 
401 |     model.eval()  # important! This disables randomized embedding dropout
402 |     # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... 
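    # A hedged sketch of what such an evaluation could look like, mirroring the sampling call
    # used in sample.py (the sample count and labels below are illustrative placeholders):
    #
    #     with torch.inference_mode():
    #         noise = torch.randn(8, ema.in_channels, latent_size, latent_size, device=device)
    #         y_vis = torch.randint(0, args.num_classes, (8,), device=device)
    #         vis = diffusion.p_sample_loop(model=ema.forward, shape=noise.shape, noise=noise,
    #                                       clip_denoised=False, model_kwargs=dict(y=y_vis),
    #                                       progress=False, device=device)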
403 | 404 | accelerator.end_training() 405 | print("Done!") 406 | 407 | 408 | if __name__ == "__main__": 409 | torch.backends.cuda.matmul.allow_tf32 = True 410 | torch.backends.cudnn.allow_tf32 = True 411 | args = Args(explicit_bool=True).parse_args() 412 | main(args) -------------------------------------------------------------------------------- /utils/imagenet-labels.json: -------------------------------------------------------------------------------- 1 | ["tench", 2 | "goldfish", 3 | "great white shark", 4 | "tiger shark", 5 | "hammerhead shark", 6 | "electric ray", 7 | "stingray", 8 | "cock", 9 | "hen", 10 | "ostrich", 11 | "brambling", 12 | "goldfinch", 13 | "house finch", 14 | "junco", 15 | "indigo bunting", 16 | "American robin", 17 | "bulbul", 18 | "jay", 19 | "magpie", 20 | "chickadee", 21 | "American dipper", 22 | "kite", 23 | "bald eagle", 24 | "vulture", 25 | "great grey owl", 26 | "fire salamander", 27 | "smooth newt", 28 | "newt", 29 | "spotted salamander", 30 | "axolotl", 31 | "American bullfrog", 32 | "tree frog", 33 | "tailed frog", 34 | "loggerhead sea turtle", 35 | "leatherback sea turtle", 36 | "mud turtle", 37 | "terrapin", 38 | "box turtle", 39 | "banded gecko", 40 | "green iguana", 41 | "Carolina anole", 42 | "desert grassland whiptail lizard", 43 | "agama", 44 | "frilled-necked lizard", 45 | "alligator lizard", 46 | "Gila monster", 47 | "European green lizard", 48 | "chameleon", 49 | "Komodo dragon", 50 | "Nile crocodile", 51 | "American alligator", 52 | "triceratops", 53 | "worm snake", 54 | "ring-necked snake", 55 | "eastern hog-nosed snake", 56 | "smooth green snake", 57 | "kingsnake", 58 | "garter snake", 59 | "water snake", 60 | "vine snake", 61 | "night snake", 62 | "boa constrictor", 63 | "African rock python", 64 | "Indian cobra", 65 | "green mamba", 66 | "sea snake", 67 | "Saharan horned viper", 68 | "eastern diamondback rattlesnake", 69 | "sidewinder", 70 | "trilobite", 71 | "harvestman", 72 | "scorpion", 73 | "yellow garden spider", 74 | "barn spider", 75 | "European garden spider", 76 | "southern black widow", 77 | "tarantula", 78 | "wolf spider", 79 | "tick", 80 | "centipede", 81 | "black grouse", 82 | "ptarmigan", 83 | "ruffed grouse", 84 | "prairie grouse", 85 | "peacock", 86 | "quail", 87 | "partridge", 88 | "grey parrot", 89 | "macaw", 90 | "sulphur-crested cockatoo", 91 | "lorikeet", 92 | "coucal", 93 | "bee eater", 94 | "hornbill", 95 | "hummingbird", 96 | "jacamar", 97 | "toucan", 98 | "duck", 99 | "red-breasted merganser", 100 | "goose", 101 | "black swan", 102 | "tusker", 103 | "echidna", 104 | "platypus", 105 | "wallaby", 106 | "koala", 107 | "wombat", 108 | "jellyfish", 109 | "sea anemone", 110 | "brain coral", 111 | "flatworm", 112 | "nematode", 113 | "conch", 114 | "snail", 115 | "slug", 116 | "sea slug", 117 | "chiton", 118 | "chambered nautilus", 119 | "Dungeness crab", 120 | "rock crab", 121 | "fiddler crab", 122 | "red king crab", 123 | "American lobster", 124 | "spiny lobster", 125 | "crayfish", 126 | "hermit crab", 127 | "isopod", 128 | "white stork", 129 | "black stork", 130 | "spoonbill", 131 | "flamingo", 132 | "little blue heron", 133 | "great egret", 134 | "bittern", 135 | "crane (bird)", 136 | "limpkin", 137 | "common gallinule", 138 | "American coot", 139 | "bustard", 140 | "ruddy turnstone", 141 | "dunlin", 142 | "common redshank", 143 | "dowitcher", 144 | "oystercatcher", 145 | "pelican", 146 | "king penguin", 147 | "albatross", 148 | "grey whale", 149 | "killer whale", 150 | "dugong", 151 | "sea lion", 152 | "Chihuahua", 
153 | "Japanese Chin", 154 | "Maltese", 155 | "Pekingese", 156 | "Shih Tzu", 157 | "King Charles Spaniel", 158 | "Papillon", 159 | "toy terrier", 160 | "Rhodesian Ridgeback", 161 | "Afghan Hound", 162 | "Basset Hound", 163 | "Beagle", 164 | "Bloodhound", 165 | "Bluetick Coonhound", 166 | "Black and Tan Coonhound", 167 | "Treeing Walker Coonhound", 168 | "English foxhound", 169 | "Redbone Coonhound", 170 | "borzoi", 171 | "Irish Wolfhound", 172 | "Italian Greyhound", 173 | "Whippet", 174 | "Ibizan Hound", 175 | "Norwegian Elkhound", 176 | "Otterhound", 177 | "Saluki", 178 | "Scottish Deerhound", 179 | "Weimaraner", 180 | "Staffordshire Bull Terrier", 181 | "American Staffordshire Terrier", 182 | "Bedlington Terrier", 183 | "Border Terrier", 184 | "Kerry Blue Terrier", 185 | "Irish Terrier", 186 | "Norfolk Terrier", 187 | "Norwich Terrier", 188 | "Yorkshire Terrier", 189 | "Wire Fox Terrier", 190 | "Lakeland Terrier", 191 | "Sealyham Terrier", 192 | "Airedale Terrier", 193 | "Cairn Terrier", 194 | "Australian Terrier", 195 | "Dandie Dinmont Terrier", 196 | "Boston Terrier", 197 | "Miniature Schnauzer", 198 | "Giant Schnauzer", 199 | "Standard Schnauzer", 200 | "Scottish Terrier", 201 | "Tibetan Terrier", 202 | "Australian Silky Terrier", 203 | "Soft-coated Wheaten Terrier", 204 | "West Highland White Terrier", 205 | "Lhasa Apso", 206 | "Flat-Coated Retriever", 207 | "Curly-coated Retriever", 208 | "Golden Retriever", 209 | "Labrador Retriever", 210 | "Chesapeake Bay Retriever", 211 | "German Shorthaired Pointer", 212 | "Vizsla", 213 | "English Setter", 214 | "Irish Setter", 215 | "Gordon Setter", 216 | "Brittany", 217 | "Clumber Spaniel", 218 | "English Springer Spaniel", 219 | "Welsh Springer Spaniel", 220 | "Cocker Spaniels", 221 | "Sussex Spaniel", 222 | "Irish Water Spaniel", 223 | "Kuvasz", 224 | "Schipperke", 225 | "Groenendael", 226 | "Malinois", 227 | "Briard", 228 | "Australian Kelpie", 229 | "Komondor", 230 | "Old English Sheepdog", 231 | "Shetland Sheepdog", 232 | "collie", 233 | "Border Collie", 234 | "Bouvier des Flandres", 235 | "Rottweiler", 236 | "German Shepherd Dog", 237 | "Dobermann", 238 | "Miniature Pinscher", 239 | "Greater Swiss Mountain Dog", 240 | "Bernese Mountain Dog", 241 | "Appenzeller Sennenhund", 242 | "Entlebucher Sennenhund", 243 | "Boxer", 244 | "Bullmastiff", 245 | "Tibetan Mastiff", 246 | "French Bulldog", 247 | "Great Dane", 248 | "St. 
Bernard", 249 | "husky", 250 | "Alaskan Malamute", 251 | "Siberian Husky", 252 | "Dalmatian", 253 | "Affenpinscher", 254 | "Basenji", 255 | "pug", 256 | "Leonberger", 257 | "Newfoundland", 258 | "Pyrenean Mountain Dog", 259 | "Samoyed", 260 | "Pomeranian", 261 | "Chow Chow", 262 | "Keeshond", 263 | "Griffon Bruxellois", 264 | "Pembroke Welsh Corgi", 265 | "Cardigan Welsh Corgi", 266 | "Toy Poodle", 267 | "Miniature Poodle", 268 | "Standard Poodle", 269 | "Mexican hairless dog", 270 | "grey wolf", 271 | "Alaskan tundra wolf", 272 | "red wolf", 273 | "coyote", 274 | "dingo", 275 | "dhole", 276 | "African wild dog", 277 | "hyena", 278 | "red fox", 279 | "kit fox", 280 | "Arctic fox", 281 | "grey fox", 282 | "tabby cat", 283 | "tiger cat", 284 | "Persian cat", 285 | "Siamese cat", 286 | "Egyptian Mau", 287 | "cougar", 288 | "lynx", 289 | "leopard", 290 | "snow leopard", 291 | "jaguar", 292 | "lion", 293 | "tiger", 294 | "cheetah", 295 | "brown bear", 296 | "American black bear", 297 | "polar bear", 298 | "sloth bear", 299 | "mongoose", 300 | "meerkat", 301 | "tiger beetle", 302 | "ladybug", 303 | "ground beetle", 304 | "longhorn beetle", 305 | "leaf beetle", 306 | "dung beetle", 307 | "rhinoceros beetle", 308 | "weevil", 309 | "fly", 310 | "bee", 311 | "ant", 312 | "grasshopper", 313 | "cricket", 314 | "stick insect", 315 | "cockroach", 316 | "mantis", 317 | "cicada", 318 | "leafhopper", 319 | "lacewing", 320 | "dragonfly", 321 | "damselfly", 322 | "red admiral", 323 | "ringlet", 324 | "monarch butterfly", 325 | "small white", 326 | "sulphur butterfly", 327 | "gossamer-winged butterfly", 328 | "starfish", 329 | "sea urchin", 330 | "sea cucumber", 331 | "cottontail rabbit", 332 | "hare", 333 | "Angora rabbit", 334 | "hamster", 335 | "porcupine", 336 | "fox squirrel", 337 | "marmot", 338 | "beaver", 339 | "guinea pig", 340 | "common sorrel", 341 | "zebra", 342 | "pig", 343 | "wild boar", 344 | "warthog", 345 | "hippopotamus", 346 | "ox", 347 | "water buffalo", 348 | "bison", 349 | "ram", 350 | "bighorn sheep", 351 | "Alpine ibex", 352 | "hartebeest", 353 | "impala", 354 | "gazelle", 355 | "dromedary", 356 | "llama", 357 | "weasel", 358 | "mink", 359 | "European polecat", 360 | "black-footed ferret", 361 | "otter", 362 | "skunk", 363 | "badger", 364 | "armadillo", 365 | "three-toed sloth", 366 | "orangutan", 367 | "gorilla", 368 | "chimpanzee", 369 | "gibbon", 370 | "siamang", 371 | "guenon", 372 | "patas monkey", 373 | "baboon", 374 | "macaque", 375 | "langur", 376 | "black-and-white colobus", 377 | "proboscis monkey", 378 | "marmoset", 379 | "white-headed capuchin", 380 | "howler monkey", 381 | "titi", 382 | "Geoffroy's spider monkey", 383 | "common squirrel monkey", 384 | "ring-tailed lemur", 385 | "indri", 386 | "Asian elephant", 387 | "African bush elephant", 388 | "red panda", 389 | "giant panda", 390 | "snoek", 391 | "eel", 392 | "coho salmon", 393 | "rock beauty", 394 | "clownfish", 395 | "sturgeon", 396 | "garfish", 397 | "lionfish", 398 | "pufferfish", 399 | "abacus", 400 | "abaya", 401 | "academic gown", 402 | "accordion", 403 | "acoustic guitar", 404 | "aircraft carrier", 405 | "airliner", 406 | "airship", 407 | "altar", 408 | "ambulance", 409 | "amphibious vehicle", 410 | "analog clock", 411 | "apiary", 412 | "apron", 413 | "waste container", 414 | "assault rifle", 415 | "backpack", 416 | "bakery", 417 | "balance beam", 418 | "balloon", 419 | "ballpoint pen", 420 | "Band-Aid", 421 | "banjo", 422 | "baluster", 423 | "barbell", 424 | "barber chair", 425 | "barbershop", 426 | "barn", 
427 | "barometer", 428 | "barrel", 429 | "wheelbarrow", 430 | "baseball", 431 | "basketball", 432 | "bassinet", 433 | "bassoon", 434 | "swimming cap", 435 | "bath towel", 436 | "bathtub", 437 | "station wagon", 438 | "lighthouse", 439 | "beaker", 440 | "military cap", 441 | "beer bottle", 442 | "beer glass", 443 | "bell-cot", 444 | "bib", 445 | "tandem bicycle", 446 | "bikini", 447 | "ring binder", 448 | "binoculars", 449 | "birdhouse", 450 | "boathouse", 451 | "bobsleigh", 452 | "bolo tie", 453 | "poke bonnet", 454 | "bookcase", 455 | "bookstore", 456 | "bottle cap", 457 | "bow", 458 | "bow tie", 459 | "brass", 460 | "bra", 461 | "breakwater", 462 | "breastplate", 463 | "broom", 464 | "bucket", 465 | "buckle", 466 | "bulletproof vest", 467 | "high-speed train", 468 | "butcher shop", 469 | "taxicab", 470 | "cauldron", 471 | "candle", 472 | "cannon", 473 | "canoe", 474 | "can opener", 475 | "cardigan", 476 | "car mirror", 477 | "carousel", 478 | "tool kit", 479 | "carton", 480 | "car wheel", 481 | "automated teller machine", 482 | "cassette", 483 | "cassette player", 484 | "castle", 485 | "catamaran", 486 | "CD player", 487 | "cello", 488 | "mobile phone", 489 | "chain", 490 | "chain-link fence", 491 | "chain mail", 492 | "chainsaw", 493 | "chest", 494 | "chiffonier", 495 | "chime", 496 | "china cabinet", 497 | "Christmas stocking", 498 | "church", 499 | "movie theater", 500 | "cleaver", 501 | "cliff dwelling", 502 | "cloak", 503 | "clogs", 504 | "cocktail shaker", 505 | "coffee mug", 506 | "coffeemaker", 507 | "coil", 508 | "combination lock", 509 | "computer keyboard", 510 | "confectionery store", 511 | "container ship", 512 | "convertible", 513 | "corkscrew", 514 | "cornet", 515 | "cowboy boot", 516 | "cowboy hat", 517 | "cradle", 518 | "crane (machine)", 519 | "crash helmet", 520 | "crate", 521 | "infant bed", 522 | "Crock Pot", 523 | "croquet ball", 524 | "crutch", 525 | "cuirass", 526 | "dam", 527 | "desk", 528 | "desktop computer", 529 | "rotary dial telephone", 530 | "diaper", 531 | "digital clock", 532 | "digital watch", 533 | "dining table", 534 | "dishcloth", 535 | "dishwasher", 536 | "disc brake", 537 | "dock", 538 | "dog sled", 539 | "dome", 540 | "doormat", 541 | "drilling rig", 542 | "drum", 543 | "drumstick", 544 | "dumbbell", 545 | "Dutch oven", 546 | "electric fan", 547 | "electric guitar", 548 | "electric locomotive", 549 | "entertainment center", 550 | "envelope", 551 | "espresso machine", 552 | "face powder", 553 | "feather boa", 554 | "filing cabinet", 555 | "fireboat", 556 | "fire engine", 557 | "fire screen sheet", 558 | "flagpole", 559 | "flute", 560 | "folding chair", 561 | "football helmet", 562 | "forklift", 563 | "fountain", 564 | "fountain pen", 565 | "four-poster bed", 566 | "freight car", 567 | "French horn", 568 | "frying pan", 569 | "fur coat", 570 | "garbage truck", 571 | "gas mask", 572 | "gas pump", 573 | "goblet", 574 | "go-kart", 575 | "golf ball", 576 | "golf cart", 577 | "gondola", 578 | "gong", 579 | "gown", 580 | "grand piano", 581 | "greenhouse", 582 | "grille", 583 | "grocery store", 584 | "guillotine", 585 | "barrette", 586 | "hair spray", 587 | "half-track", 588 | "hammer", 589 | "hamper", 590 | "hair dryer", 591 | "hand-held computer", 592 | "handkerchief", 593 | "hard disk drive", 594 | "harmonica", 595 | "harp", 596 | "harvester", 597 | "hatchet", 598 | "holster", 599 | "home theater", 600 | "honeycomb", 601 | "hook", 602 | "hoop skirt", 603 | "horizontal bar", 604 | "horse-drawn vehicle", 605 | "hourglass", 606 | "iPod", 607 | "clothes 
iron", 608 | "jack-o'-lantern", 609 | "jeans", 610 | "jeep", 611 | "T-shirt", 612 | "jigsaw puzzle", 613 | "pulled rickshaw", 614 | "joystick", 615 | "kimono", 616 | "knee pad", 617 | "knot", 618 | "lab coat", 619 | "ladle", 620 | "lampshade", 621 | "laptop computer", 622 | "lawn mower", 623 | "lens cap", 624 | "paper knife", 625 | "library", 626 | "lifeboat", 627 | "lighter", 628 | "limousine", 629 | "ocean liner", 630 | "lipstick", 631 | "slip-on shoe", 632 | "lotion", 633 | "speaker", 634 | "loupe", 635 | "sawmill", 636 | "magnetic compass", 637 | "mail bag", 638 | "mailbox", 639 | "tights", 640 | "tank suit", 641 | "manhole cover", 642 | "maraca", 643 | "marimba", 644 | "mask", 645 | "match", 646 | "maypole", 647 | "maze", 648 | "measuring cup", 649 | "medicine chest", 650 | "megalith", 651 | "microphone", 652 | "microwave oven", 653 | "military uniform", 654 | "milk can", 655 | "minibus", 656 | "miniskirt", 657 | "minivan", 658 | "missile", 659 | "mitten", 660 | "mixing bowl", 661 | "mobile home", 662 | "Model T", 663 | "modem", 664 | "monastery", 665 | "monitor", 666 | "moped", 667 | "mortar", 668 | "square academic cap", 669 | "mosque", 670 | "mosquito net", 671 | "scooter", 672 | "mountain bike", 673 | "tent", 674 | "computer mouse", 675 | "mousetrap", 676 | "moving van", 677 | "muzzle", 678 | "nail", 679 | "neck brace", 680 | "necklace", 681 | "nipple", 682 | "notebook computer", 683 | "obelisk", 684 | "oboe", 685 | "ocarina", 686 | "odometer", 687 | "oil filter", 688 | "organ", 689 | "oscilloscope", 690 | "overskirt", 691 | "bullock cart", 692 | "oxygen mask", 693 | "packet", 694 | "paddle", 695 | "paddle wheel", 696 | "padlock", 697 | "paintbrush", 698 | "pajamas", 699 | "palace", 700 | "pan flute", 701 | "paper towel", 702 | "parachute", 703 | "parallel bars", 704 | "park bench", 705 | "parking meter", 706 | "passenger car", 707 | "patio", 708 | "payphone", 709 | "pedestal", 710 | "pencil case", 711 | "pencil sharpener", 712 | "perfume", 713 | "Petri dish", 714 | "photocopier", 715 | "plectrum", 716 | "Pickelhaube", 717 | "picket fence", 718 | "pickup truck", 719 | "pier", 720 | "piggy bank", 721 | "pill bottle", 722 | "pillow", 723 | "ping-pong ball", 724 | "pinwheel", 725 | "pirate ship", 726 | "pitcher", 727 | "hand plane", 728 | "planetarium", 729 | "plastic bag", 730 | "plate rack", 731 | "plow", 732 | "plunger", 733 | "Polaroid camera", 734 | "pole", 735 | "police van", 736 | "poncho", 737 | "billiard table", 738 | "soda bottle", 739 | "pot", 740 | "potter's wheel", 741 | "power drill", 742 | "prayer rug", 743 | "printer", 744 | "prison", 745 | "projectile", 746 | "projector", 747 | "hockey puck", 748 | "punching bag", 749 | "purse", 750 | "quill", 751 | "quilt", 752 | "race car", 753 | "racket", 754 | "radiator", 755 | "radio", 756 | "radio telescope", 757 | "rain barrel", 758 | "recreational vehicle", 759 | "reel", 760 | "reflex camera", 761 | "refrigerator", 762 | "remote control", 763 | "restaurant", 764 | "revolver", 765 | "rifle", 766 | "rocking chair", 767 | "rotisserie", 768 | "eraser", 769 | "rugby ball", 770 | "ruler", 771 | "running shoe", 772 | "safe", 773 | "safety pin", 774 | "salt shaker", 775 | "sandal", 776 | "sarong", 777 | "saxophone", 778 | "scabbard", 779 | "weighing scale", 780 | "school bus", 781 | "schooner", 782 | "scoreboard", 783 | "CRT screen", 784 | "screw", 785 | "screwdriver", 786 | "seat belt", 787 | "sewing machine", 788 | "shield", 789 | "shoe store", 790 | "shoji", 791 | "shopping basket", 792 | "shopping cart", 793 | "shovel", 794 | 
"shower cap", 795 | "shower curtain", 796 | "ski", 797 | "ski mask", 798 | "sleeping bag", 799 | "slide rule", 800 | "sliding door", 801 | "slot machine", 802 | "snorkel", 803 | "snowmobile", 804 | "snowplow", 805 | "soap dispenser", 806 | "soccer ball", 807 | "sock", 808 | "solar thermal collector", 809 | "sombrero", 810 | "soup bowl", 811 | "space bar", 812 | "space heater", 813 | "space shuttle", 814 | "spatula", 815 | "motorboat", 816 | "spider web", 817 | "spindle", 818 | "sports car", 819 | "spotlight", 820 | "stage", 821 | "steam locomotive", 822 | "through arch bridge", 823 | "steel drum", 824 | "stethoscope", 825 | "scarf", 826 | "stone wall", 827 | "stopwatch", 828 | "stove", 829 | "strainer", 830 | "tram", 831 | "stretcher", 832 | "couch", 833 | "stupa", 834 | "submarine", 835 | "suit", 836 | "sundial", 837 | "sunglass", 838 | "sunglasses", 839 | "sunscreen", 840 | "suspension bridge", 841 | "mop", 842 | "sweatshirt", 843 | "swimsuit", 844 | "swing", 845 | "switch", 846 | "syringe", 847 | "table lamp", 848 | "tank", 849 | "tape player", 850 | "teapot", 851 | "teddy bear", 852 | "television", 853 | "tennis ball", 854 | "thatched roof", 855 | "front curtain", 856 | "thimble", 857 | "threshing machine", 858 | "throne", 859 | "tile roof", 860 | "toaster", 861 | "tobacco shop", 862 | "toilet seat", 863 | "torch", 864 | "totem pole", 865 | "tow truck", 866 | "toy store", 867 | "tractor", 868 | "semi-trailer truck", 869 | "tray", 870 | "trench coat", 871 | "tricycle", 872 | "trimaran", 873 | "tripod", 874 | "triumphal arch", 875 | "trolleybus", 876 | "trombone", 877 | "tub", 878 | "turnstile", 879 | "typewriter keyboard", 880 | "umbrella", 881 | "unicycle", 882 | "upright piano", 883 | "vacuum cleaner", 884 | "vase", 885 | "vault", 886 | "velvet", 887 | "vending machine", 888 | "vestment", 889 | "viaduct", 890 | "violin", 891 | "volleyball", 892 | "waffle iron", 893 | "wall clock", 894 | "wallet", 895 | "wardrobe", 896 | "military aircraft", 897 | "sink", 898 | "washing machine", 899 | "water bottle", 900 | "water jug", 901 | "water tower", 902 | "whiskey jug", 903 | "whistle", 904 | "wig", 905 | "window screen", 906 | "window shade", 907 | "Windsor tie", 908 | "wine bottle", 909 | "wing", 910 | "wok", 911 | "wooden spoon", 912 | "wool", 913 | "split-rail fence", 914 | "shipwreck", 915 | "yawl", 916 | "yurt", 917 | "website", 918 | "comic book", 919 | "crossword", 920 | "traffic sign", 921 | "traffic light", 922 | "dust jacket", 923 | "menu", 924 | "plate", 925 | "guacamole", 926 | "consomme", 927 | "hot pot", 928 | "trifle", 929 | "ice cream", 930 | "ice pop", 931 | "baguette", 932 | "bagel", 933 | "pretzel", 934 | "cheeseburger", 935 | "hot dog", 936 | "mashed potato", 937 | "cabbage", 938 | "broccoli", 939 | "cauliflower", 940 | "zucchini", 941 | "spaghetti squash", 942 | "acorn squash", 943 | "butternut squash", 944 | "cucumber", 945 | "artichoke", 946 | "bell pepper", 947 | "cardoon", 948 | "mushroom", 949 | "Granny Smith", 950 | "strawberry", 951 | "orange", 952 | "lemon", 953 | "fig", 954 | "pineapple", 955 | "banana", 956 | "jackfruit", 957 | "custard apple", 958 | "pomegranate", 959 | "hay", 960 | "carbonara", 961 | "chocolate syrup", 962 | "dough", 963 | "meatloaf", 964 | "pizza", 965 | "pot pie", 966 | "burrito", 967 | "red wine", 968 | "espresso", 969 | "cup", 970 | "eggnog", 971 | "alp", 972 | "bubble", 973 | "cliff", 974 | "coral reef", 975 | "geyser", 976 | "lakeshore", 977 | "promontory", 978 | "shoal", 979 | "seashore", 980 | "valley", 981 | "volcano", 982 | "baseball 
player", 983 | "bridegroom", 984 | "scuba diver", 985 | "rapeseed", 986 | "daisy", 987 | "yellow lady's slipper", 988 | "corn", 989 | "acorn", 990 | "rose hip", 991 | "horse chestnut seed", 992 | "coral fungus", 993 | "agaric", 994 | "gyromitra", 995 | "stinkhorn mushroom", 996 | "earth star", 997 | "hen-of-the-woods", 998 | "bolete", 999 | "ear", 1000 | "toilet paper"] 1001 | -------------------------------------------------------------------------------- /visuals/splash-figure-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/fixed-point-diffusion-models/519e1286ba27c34e177e05962c5d9e66edce31e6/visuals/splash-figure-v1.png --------------------------------------------------------------------------------