├── .gitignore ├── LICENSE ├── README.md ├── STAT.md ├── assets └── GET.png ├── data ├── cond_generate.py ├── dataset.sh └── generate.py ├── datasets.py ├── environment.yml ├── eval.py ├── eval.sh ├── losses.py ├── models ├── __init__.py ├── get.py └── vit.py ├── run.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # results 132 | eval-results/ 133 | results/ 134 | 135 | # test scripts 136 | eval_test.sh 137 | run_test.sh 138 | 139 | # TODO list 140 | TODO.md 141 | 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CMU Locus Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Equilibrium Transformer 2 | 3 | This is the official repo for the paper [*One-Step Diffusion Distillation via Deep Equilibrium Models*](), 4 | by [Zhengyang Geng](https://gsunshine.github.io/)\*, [Ashwini Pokle](https://ashwinipokle.github.io/)\*, and [J. Zico Kolter](http://zicokolter.com/). 5 | 6 |
7 | 
8 | ## Environment
9 | Create the conda environment from `environment.yml` and activate it via `conda env create -f environment.yml` followed by `conda activate GET`.
10 | ## Dataset
11 | 
12 | First, download the datasets **EDM-Uncond-CIFAR** and **EDM-Cond-CIFAR** from [this link](https://drive.google.com/drive/folders/1dlFiS5ahwu7xne6fUNVNELaG12-AKn4I?usp=sharing).
13 | Set `--data_path` in `run.sh` to the directory where you store the datasets, e.g., `--data_path DATA_DIR/EDM-Uncond-CIFAR-1M`.
14 | 
15 | In addition, download the precomputed dataset statistics from [this link](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC).
16 | Set `--stat_path` in `run.sh` and `eval.sh` to your download directory plus the statistics file name, e.g., `cifar10.test.npz`.
17 | 
18 | ## Training
19 | 
20 | To train a GET, run this command:
21 | 
22 | ```bash
23 | bash run.sh N_GPU DDP_PORT --model MODEL_NAME --name EXP_NAME
24 | ```
25 | 
26 | `N_GPU` is the number of GPUs used for training.
27 | `DDP_PORT` is the port number for syncing gradients during distributed training.
28 | `MODEL_NAME` is the model's name.
29 | See all available models using `python train.py -h`.
30 | The training log, checkpoints, and sampled images will be saved to `./results` using your `EXP_NAME`.
31 | 
32 | For example, this command trains a GET-S/2 (patch size 2) on 4 GPUs:
33 | 
34 | ```bash
35 | bash run.sh 4 12345 --model GET-S/2 --name test-GET
36 | ```
37 | 
38 | To train a ViT, run this command:
39 | 
40 | ```bash
41 | bash run.sh N_GPU DDP_PORT --model ViT-B/2 --name EXP_NAME
42 | ```
43 | 
44 | To train **conditional** models, add the `--cond` flag.
45 | 
46 | For **O(1)-memory** training, add the `--mem` flag.
47 | 
48 | ## Evaluation
49 | 
50 | Download pretrained models from [this link](https://drive.google.com/drive/u/1/folders/1g998S6moSQhybD9poDJHmXP85QF3zz4g).
51 | 
52 | To load a checkpoint for evaluation, run this command:
53 | 
54 | ```bash
55 | bash eval.sh N_GPU DDP_PORT --model MODEL_NAME --resume CKPT_PATH --name EXP_NAME
56 | ```
57 | 
58 | The evaluation log and sampled images will be saved to `./eval-results` using your `EXP_NAME`.
59 | 
60 | For evaluating conditional models, add the `--cond` flag. Here is an example:
61 | 
62 | ```bash
63 | bash eval.sh 4 12345 --model GET-B/2 --cond --resume CKPT_DIR/GET-B-cond-2M-data-bs256.pth
64 | ```
65 | 
66 | ## Generative Performance
67 | 
68 | You can see [the generative performance here](STAT.md). The discussion there might be interesting.
69 | 
70 | ## Data Generation
71 | 
72 | First, clone the [EDM](https://github.com/NVlabs/edm) repo. Then copy the files under `data/` into the `edm/` directory.
73 | 
74 | Set the output paths (`--outdir`) in `dataset.sh` to where the synthetic datasets should be stored.
75 | Run the following command to generate both the conditional and unconditional training sets.
76 | 
77 | ```bash
78 | bash dataset.sh
79 | ```
80 | 
81 | To generate more data pairs, enlarge the range of `--seeds=0-MAX_SAMPLES`. The generated data uses the same `latent/` + `image/` shard layout as the downloaded datasets; see the inspection sketch at the end of this README.
82 | 
83 | ## Bibtex
84 | 
85 | If you find our work helpful to your research, please consider citing this paper. :)
86 | 
87 | ```bib
88 | @inproceedings{
89 | geng2023onestep,
90 | title={One-Step Diffusion Distillation via Deep Equilibrium Models},
91 | author={Zhengyang Geng and Ashwini Pokle and J Zico Kolter},
92 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
93 | year={2023}
94 | }
95 | ```
96 | 
97 | ## Contact
98 | 
99 | Feel free to contact us if you have additional questions!
100 | Please drop an email to zhengyanggeng@gmail.com (or [Twitter](https://twitter.com/ZhengyangGeng))
101 | or apokle@andrew.cmu.edu.
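## Inspecting the Paired Data

As a quick sanity check of downloaded or freshly generated data, the minimal sketch below (paths are placeholders) peeks at one shard. The `latent/` and `image/` layout follows `data/generate.py` and `datasets.py`; note that the `.pkl` shards are written with `torch.save`, so load them with `torch.load`. For the conditional set, each latent shard instead stores a `[latents, labels]` pair (see `data/cond_generate.py`).

```python
import glob

import torch

# Each shard under latent/ pairs row-for-row with the shard of the same name under image/.
shard = sorted(glob.glob('DATA_DIR/EDM-Uncond-CIFAR-1M/latent/*.pkl'))[0]
z = torch.load(shard)                                 # Gaussian latents, [N, 3, 32, 32]
x = torch.load(shard.replace('latent', 'image'))      # EDM-sampled images, same shape and order
print(z.shape, x.shape)
```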
102 | 
103 | ## Acknowledgment
104 | 
105 | This project is built upon [TorchDEQ](https://github.com/locuslab/torchdeq),
106 | [DiT](https://arxiv.org/abs/2212.09748),
107 | and [timm](https://github.com/huggingface/pytorch-image-models).
108 | Thanks for the awesome projects!
109 | 
--------------------------------------------------------------------------------
/STAT.md:
--------------------------------------------------------------------------------
1 | 
2 | # Model Performance
3 | 
4 | Here is the generative performance under different training settings and model configurations.
5 | 
6 | Generative performance of unconditional models.
7 | 
8 | | Model Name | Type | Params | FID | IS | Training Data | BS | Iters |
9 | | :--------- | :-- | :----: | :-: | :-: | :----------: | :-: | :---: |
10 | | [GET-T](https://drive.google.com/file/d/1rDw5A34ZnTajQZLSb_7fGkUfQU6viwq8/view?usp=sharing) | Uncond | 8.6M | 15.23 | 8.40 | 1M | 128 | 800k |
11 | | [GET-M](https://drive.google.com/file/d/1bAcRl0dWDxzIkm3sZBzBMABw6y78mVcQ/view?usp=sharing) | Uncond | 19.2M | 10.81 | 8.77 | 1M | 128 | 800k |
12 | | [GET-S](https://drive.google.com/file/d/1rN2rD7WUDaJaL3uRU5eX14wKQAg7Zy8z/view?usp=sharing) | Uncond | 37.2M | 7.99 | 9.05 | 1M | 128 | 800k |
13 | | [GET-B](https://drive.google.com/file/d/1k7qMLfqxctFNldsUSuapLP96oIllwZ1H/view?usp=sharing) | Uncond | 62.2M | 7.39 | 9.17 | 1M | 128 | 800k |
14 | | [GET-B+](https://drive.google.com/file/d/1jUE1lqs0qsbqbLyl9nmROcxrETDUXx25/view?usp=sharing) | Uncond | 83.5M | 7.21 | 9.07 | 1M | 128 | 800k |
15 | 
16 | Generative performance of class-conditional models.
17 | 
18 | | Model Name | Type | Params | FID | IS | Training Data | BS | Iters |
19 | | :--------- | :-- | :----: | :-: | :-: | :----------: | :-: | :---: |
20 | | [GET-B](https://drive.google.com/file/d/1BPPtWpoXVexgozaKAiZRx0N5egweHrFH/view?usp=sharing) | Cond | 62.2M | 6.23 | 9.42 | 1M | 256 | 800k |
21 | | [GET-B](https://drive.google.com/file/d/1DH8cN70OucFRoWsXJvK4vgyIrctAcoFN/view?usp=sharing) | Cond | 62.2M | 5.66 | 9.63 | 2M | 256 | 1.2M |
22 | 
23 | There is a clear scaling trend for Generative Equilibrium Transformers. In most cases, FID has a **log-linear** relation w.r.t. the invested budget (training FLOPs/time/data/params) when the other dimensions are held fixed. As demonstrated above, scaling up the training data, providing better supervision (from the perceptual loss or the teacher model), and spending more training compute (a larger batch size or a longer training schedule) all lead to better results.
24 | 
25 | Ideally, when scaling up the training data, the model size and training FLOPs should be adjusted accordingly to achieve the best training efficiency. Restricted by our computing resources, we cannot derive the *exact* scaling law for compute-optimal models as in [Chinchilla](https://arxiv.org/abs/2203.15556). Despite this restriction, our observations still show that GET's scaling law suggests **much more compact compute-optimal models** than ViT's, which is ideal for memory-bounded deployment, like today's LLMs.
26 | 
27 | In particular, this work shows that **memorizing sufficiently many "regular" data pairs can lead to a good generative model**, whether the model is a GET or a ViT. The differences lie in data efficiency and training efficiency (we suspect there are much better training strategies, though).
28 | 
29 | Here, the term "regular" means the latent-image pairs are easy to learn, e.g., sampled from a pretrained model. Note that randomly paired latents and images, obtained for example by shuffling the latent-image pairs in the training set, cannot be memorized by a model trained at a constant learning rate (we use a fixed learning rate to train models over massive numbers of pairs in the above experiments); the loss curve is non-decreasing. This implies that **the learnability of a pairing can be a strong measurement of pairing quality**.
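To make the learnability probe concrete, here is a minimal, hedged sketch (run from the repo root; the data path is a placeholder, and the tiny stand-in regressor, L1 objective, and learning rate are illustrative assumptions rather than the GET/ViT setup used above). It deliberately breaks the pairing by shuffling the images, trains at a constant learning rate, and watches whether the loss decreases:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

from datasets import PairedDataset  # loads the latent/ and image/ shards into memory

ds = PairedDataset('DATA_DIR/EDM-Uncond-CIFAR-1M')       # point this at a small subset for a quick check
ds.X = ds.X[torch.randperm(len(ds.X))]                   # shuffle images -> a "random" pairing
loader = DataLoader(ds, batch_size=128, shuffle=True)

net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.GELU(), nn.Conv2d(64, 3, 3, padding=1))
opt = torch.optim.Adam(net.parameters(), lr=1e-4)        # fixed learning rate
for z, x in loader:
    loss = (net(z) - x.float()).abs().mean()             # L1 loss, as in losses.py
    opt.zero_grad(); loss.backward(); opt.step()
    print(loss.item())  # a non-decreasing curve signals a hard-to-memorize, low-quality pairing
```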
30 | 31 | -------------------------------------------------------------------------------- /assets/GET.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/get/53377d3047b8ab677ef4e84252b15b5919fe70dd/assets/GET.png -------------------------------------------------------------------------------- /data/cond_generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Generate random images using the techniques described in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import click 14 | import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import PIL.Image 19 | import dnnlib 20 | from torch_utils import distributed as dist 21 | 22 | #---------------------------------------------------------------------------- 23 | # Proposed EDM sampler (Algorithm 2). 24 | 25 | def edm_sampler( 26 | net, latents, class_labels=None, randn_like=torch.randn_like, 27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 29 | ): 30 | # Adjust noise levels based on what's supported by the network. 31 | sigma_min = max(sigma_min, net.sigma_min) 32 | sigma_max = min(sigma_max, net.sigma_max) 33 | 34 | # Time step discretization. 35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 38 | 39 | # Main sampling loop. 40 | x_next = latents.to(torch.float64) * t_steps[0] 41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 42 | x_cur = x_next 43 | 44 | # Increase noise temporarily. 45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 46 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 48 | 49 | # Euler step. 50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 51 | d_cur = (x_hat - denoised) / t_hat 52 | x_next = x_hat + (t_next - t_hat) * d_cur 53 | 54 | # Apply 2nd order correction.
55 | if i < num_steps - 1: 56 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 57 | d_prime = (x_next - denoised) / t_next 58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 59 | 60 | return x_next 61 | 62 | #---------------------------------------------------------------------------- 63 | # Generalized ablation sampler, representing the superset of all sampling 64 | # methods discussed in the paper. 65 | 66 | def ablation_sampler( 67 | net, latents, class_labels=None, randn_like=torch.randn_like, 68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 69 | solver='heun', discretization='edm', schedule='linear', scaling='none', 70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, 71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 72 | ): 73 | assert solver in ['euler', 'heun'] 74 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 75 | assert schedule in ['vp', 've', 'linear'] 76 | assert scaling in ['vp', 'none'] 77 | 78 | # Helper functions for VP & VE noise level schedules. 79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 82 | ve_sigma = lambda t: t.sqrt() 83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 84 | ve_sigma_inv = lambda sigma: sigma ** 2 85 | 86 | # Select default noise level range based on the specified time step discretization. 87 | if sigma_min is None: 88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) 89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 90 | if sigma_max is None: 91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) 92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 93 | 94 | # Adjust noise levels based on what's supported by the network. 95 | sigma_min = max(sigma_min, net.sigma_min) 96 | sigma_max = min(sigma_max, net.sigma_max) 97 | 98 | # Compute corresponding betas for VP. 99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 101 | 102 | # Define time steps in terms of noise level. 
103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 104 | if discretization == 'vp': 105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 107 | elif discretization == 've': 108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 109 | sigma_steps = ve_sigma(orig_t_steps) 110 | elif discretization == 'iddpm': 111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 117 | else: 118 | assert discretization == 'edm' 119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 120 | 121 | # Define noise level schedule. 122 | if schedule == 'vp': 123 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 126 | elif schedule == 've': 127 | sigma = ve_sigma 128 | sigma_deriv = ve_sigma_deriv 129 | sigma_inv = ve_sigma_inv 130 | else: 131 | assert schedule == 'linear' 132 | sigma = lambda t: t 133 | sigma_deriv = lambda t: 1 134 | sigma_inv = lambda sigma: sigma 135 | 136 | # Define scaling schedule. 137 | if scaling == 'vp': 138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 140 | else: 141 | assert scaling == 'none' 142 | s = lambda t: 1 143 | s_deriv = lambda t: 0 144 | 145 | # Compute final time steps based on the corresponding noise levels. 146 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 148 | 149 | # Main sampling loop. 150 | t_next = t_steps[0] 151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 153 | x_cur = x_next 154 | 155 | # Increase noise temporarily. 156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) 158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) 159 | 160 | # Euler step. 161 | h = t_next - t_hat 162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) 163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised 164 | x_prime = x_hat + alpha * h * d_cur 165 | t_prime = t_hat + alpha * h 166 | 167 | # Apply 2nd order correction. 
168 | if solver == 'euler' or i == num_steps - 1: 169 | x_next = x_hat + h * d_cur 170 | else: 171 | assert solver == 'heun' 172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) 173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised 174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 175 | 176 | return x_next 177 | 178 | #---------------------------------------------------------------------------- 179 | # Wrapper for torch.Generator that allows specifying a different random seed 180 | # for each sample in a minibatch. 181 | 182 | class StackedRandomGenerator: 183 | def __init__(self, device, seeds): 184 | super().__init__() 185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 186 | 187 | def randn(self, size, **kwargs): 188 | assert size[0] == len(self.generators) 189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 190 | 191 | def randn_like(self, input): 192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 193 | 194 | def randint(self, *args, size, **kwargs): 195 | assert size[0] == len(self.generators) 196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 197 | 198 | #---------------------------------------------------------------------------- 199 | # Parse a comma separated list of numbers or ranges and return a list of ints. 200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 201 | 202 | def parse_int_list(s): 203 | if isinstance(s, list): return s 204 | ranges = [] 205 | range_re = re.compile(r'^(\d+)-(\d+)$') 206 | for p in s.split(','): 207 | m = range_re.match(p) 208 | if m: 209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 210 | else: 211 | ranges.append(int(p)) 212 | return ranges 213 | 214 | #---------------------------------------------------------------------------- 215 | 216 | @click.command() 217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) 218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 219 | @click.option('--seeds', help='Random seeds (e.g. 
1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) 220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 223 | 224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) 225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 232 | 233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun'])) 234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm'])) 235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear'])) 236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none'])) 237 | 238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs): 239 | """Generate random images using the techniques described in the paper 240 | "Elucidating the Design Space of Diffusion-Based Generative Models". 241 | 242 | Examples: 243 | 244 | \b 245 | # Generate 64 images and save them as out/*.png 246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\ 247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 248 | 249 | \b 250 | # Generate 1024 images using 2 GPUs 251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\ 252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 253 | """ 254 | dist.init() 255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches) 257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 258 | 259 | # Rank 0 goes first. 260 | if dist.get_rank() != 0: 261 | torch.distributed.barrier() 262 | 263 | # Load network. 
264 | dist.print0(f'Loading network from "{network_pkl}"...') 265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: 266 | net = pickle.load(f)['ema'].to(device) 267 | 268 | # Other ranks follow. 269 | if dist.get_rank() == 0: 270 | torch.distributed.barrier() 271 | 272 | latent_list = [] 273 | label_list = [] 274 | image_list = [] 275 | 276 | # Loop over batches. 277 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') 278 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): 279 | torch.distributed.barrier() 280 | batch_size = len(batch_seeds) 281 | if batch_size == 0: 282 | continue 283 | 284 | # Pick latents and labels. 285 | rnd = StackedRandomGenerator(device, batch_seeds) 286 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 287 | class_labels = None 288 | if net.label_dim: 289 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] 290 | if class_idx is not None: 291 | class_labels[:, :] = 0 292 | class_labels[:, class_idx] = 1 293 | 294 | # Generate images. 295 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 296 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling']) 297 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler 298 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs) 299 | 300 | # Save images. 301 | latent_list.append(latents.float().cpu()) 302 | label_list.append(class_labels.float().cpu()) 303 | image_list.append(images.float().cpu()) 304 | 305 | if len(image_list) % 50 == 0: 306 | save_dir = outdir 307 | os.makedirs(save_dir, exist_ok=True) 308 | 309 | latent_dir = os.path.join(save_dir, 'latent') 310 | os.makedirs(latent_dir, exist_ok=True) 311 | 312 | image_dir = os.path.join(save_dir, 'image') 313 | os.makedirs(image_dir, exist_ok=True) 314 | 315 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl') 316 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl') 317 | 318 | torch.save([torch.cat(latent_list, dim=0), torch.cat(label_list, dim=0)], latent_path) 319 | torch.save(torch.cat(image_list, dim=0), image_path) 320 | 321 | latent_list = [] 322 | label_list = [] 323 | image_list = [] 324 | 325 | if len(image_list) > 0: 326 | save_dir = outdir 327 | os.makedirs(save_dir, exist_ok=True) 328 | 329 | latent_dir = os.path.join(save_dir, 'latent') 330 | os.makedirs(latent_dir, exist_ok=True) 331 | 332 | image_dir = os.path.join(save_dir, 'image') 333 | os.makedirs(image_dir, exist_ok=True) 334 | 335 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl') 336 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl') 337 | 338 | torch.save([torch.cat(latent_list, dim=0), torch.cat(label_list, dim=0)], latent_path) 339 | torch.save(torch.cat(image_list, dim=0), image_path) 340 | 341 | # Done. 
342 | torch.distributed.barrier() 343 | dist.print0('Done.') 344 | 345 | #---------------------------------------------------------------------------- 346 | 347 | if __name__ == "__main__": 348 | main() 349 | 350 | #---------------------------------------------------------------------------- 351 | -------------------------------------------------------------------------------- /data/dataset.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=4 generate.py \ 2 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl \ 3 | --outdir=YOUR_UNCOND_DATA_PATH \ 4 | --seeds=0-999999 \ 5 | --batch 250 6 | 7 | torchrun --standalone --nproc_per_node=4 cond_generate.py \ 8 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl \ 9 | --outdir=YOUR_COND_DATA_PATH \ 10 | --seeds=0-999999 \ 11 | --batch 250 12 | 13 | -------------------------------------------------------------------------------- /data/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Generate random images using the techniques described in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import click 14 | import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import PIL.Image 19 | import dnnlib 20 | from torch_utils import distributed as dist 21 | 22 | #---------------------------------------------------------------------------- 23 | # Proposed EDM sampler (Algorithm 2). 24 | 25 | def edm_sampler( 26 | net, latents, class_labels=None, randn_like=torch.randn_like, 27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 29 | ): 30 | # Adjust noise levels based on what's supported by the network. 31 | sigma_min = max(sigma_min, net.sigma_min) 32 | sigma_max = min(sigma_max, net.sigma_max) 33 | 34 | # Time step discretization. 35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 38 | 39 | # Main sampling loop. 40 | x_next = latents.to(torch.float64) * t_steps[0] 41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 42 | x_cur = x_next 43 | 44 | # Increase noise temporarily. 45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 46 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 48 | 49 | # Euler step. 50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 51 | d_cur = (x_hat - denoised) / t_hat 52 | x_next = x_hat + (t_next - t_hat) * d_cur 53 | 54 | # Apply 2nd order correction. 
55 | if i < num_steps - 1: 56 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 57 | d_prime = (x_next - denoised) / t_next 58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 59 | 60 | return x_next 61 | 62 | #---------------------------------------------------------------------------- 63 | # Generalized ablation sampler, representing the superset of all sampling 64 | # methods discussed in the paper. 65 | 66 | def ablation_sampler( 67 | net, latents, class_labels=None, randn_like=torch.randn_like, 68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 69 | solver='heun', discretization='edm', schedule='linear', scaling='none', 70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, 71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 72 | ): 73 | assert solver in ['euler', 'heun'] 74 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 75 | assert schedule in ['vp', 've', 'linear'] 76 | assert scaling in ['vp', 'none'] 77 | 78 | # Helper functions for VP & VE noise level schedules. 79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 82 | ve_sigma = lambda t: t.sqrt() 83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 84 | ve_sigma_inv = lambda sigma: sigma ** 2 85 | 86 | # Select default noise level range based on the specified time step discretization. 87 | if sigma_min is None: 88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) 89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 90 | if sigma_max is None: 91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) 92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 93 | 94 | # Adjust noise levels based on what's supported by the network. 95 | sigma_min = max(sigma_min, net.sigma_min) 96 | sigma_max = min(sigma_max, net.sigma_max) 97 | 98 | # Compute corresponding betas for VP. 99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 101 | 102 | # Define time steps in terms of noise level. 
103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 104 | if discretization == 'vp': 105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 107 | elif discretization == 've': 108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 109 | sigma_steps = ve_sigma(orig_t_steps) 110 | elif discretization == 'iddpm': 111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 117 | else: 118 | assert discretization == 'edm' 119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 120 | 121 | # Define noise level schedule. 122 | if schedule == 'vp': 123 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 126 | elif schedule == 've': 127 | sigma = ve_sigma 128 | sigma_deriv = ve_sigma_deriv 129 | sigma_inv = ve_sigma_inv 130 | else: 131 | assert schedule == 'linear' 132 | sigma = lambda t: t 133 | sigma_deriv = lambda t: 1 134 | sigma_inv = lambda sigma: sigma 135 | 136 | # Define scaling schedule. 137 | if scaling == 'vp': 138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 140 | else: 141 | assert scaling == 'none' 142 | s = lambda t: 1 143 | s_deriv = lambda t: 0 144 | 145 | # Compute final time steps based on the corresponding noise levels. 146 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 148 | 149 | # Main sampling loop. 150 | t_next = t_steps[0] 151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 153 | x_cur = x_next 154 | 155 | # Increase noise temporarily. 156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) 158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) 159 | 160 | # Euler step. 161 | h = t_next - t_hat 162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) 163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised 164 | x_prime = x_hat + alpha * h * d_cur 165 | t_prime = t_hat + alpha * h 166 | 167 | # Apply 2nd order correction. 
168 | if solver == 'euler' or i == num_steps - 1: 169 | x_next = x_hat + h * d_cur 170 | else: 171 | assert solver == 'heun' 172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) 173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised 174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 175 | 176 | return x_next 177 | 178 | #---------------------------------------------------------------------------- 179 | # Wrapper for torch.Generator that allows specifying a different random seed 180 | # for each sample in a minibatch. 181 | 182 | class StackedRandomGenerator: 183 | def __init__(self, device, seeds): 184 | super().__init__() 185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 186 | 187 | def randn(self, size, **kwargs): 188 | assert size[0] == len(self.generators) 189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 190 | 191 | def randn_like(self, input): 192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 193 | 194 | def randint(self, *args, size, **kwargs): 195 | assert size[0] == len(self.generators) 196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 197 | 198 | #---------------------------------------------------------------------------- 199 | # Parse a comma separated list of numbers or ranges and return a list of ints. 200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 201 | 202 | def parse_int_list(s): 203 | if isinstance(s, list): return s 204 | ranges = [] 205 | range_re = re.compile(r'^(\d+)-(\d+)$') 206 | for p in s.split(','): 207 | m = range_re.match(p) 208 | if m: 209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 210 | else: 211 | ranges.append(int(p)) 212 | return ranges 213 | 214 | #---------------------------------------------------------------------------- 215 | 216 | @click.command() 217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) 218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 219 | @click.option('--seeds', help='Random seeds (e.g. 
1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) 220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 223 | 224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) 225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 232 | 233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun'])) 234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm'])) 235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear'])) 236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none'])) 237 | 238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs): 239 | """Generate random images using the techniques described in the paper 240 | "Elucidating the Design Space of Diffusion-Based Generative Models". 241 | 242 | Examples: 243 | 244 | \b 245 | # Generate 64 images and save them as out/*.png 246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\ 247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 248 | 249 | \b 250 | # Generate 1024 images using 2 GPUs 251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\ 252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 253 | """ 254 | dist.init() 255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches) 257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 258 | 259 | # Rank 0 goes first. 260 | if dist.get_rank() != 0: 261 | torch.distributed.barrier() 262 | 263 | # Load network. 
264 | dist.print0(f'Loading network from "{network_pkl}"...') 265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: 266 | net = pickle.load(f)['ema'].to(device) 267 | 268 | # Other ranks follow. 269 | if dist.get_rank() == 0: 270 | torch.distributed.barrier() 271 | 272 | latent_list = [] 273 | image_list = [] 274 | 275 | # Loop over batches. 276 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') 277 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): 278 | torch.distributed.barrier() 279 | batch_size = len(batch_seeds) 280 | if batch_size == 0: 281 | continue 282 | 283 | # Pick latents and labels. 284 | rnd = StackedRandomGenerator(device, batch_seeds) 285 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 286 | class_labels = None 287 | if net.label_dim: 288 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] 289 | if class_idx is not None: 290 | class_labels[:, :] = 0 291 | class_labels[:, class_idx] = 1 292 | 293 | # Generate images. 294 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 295 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling']) 296 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler 297 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs) 298 | 299 | # Save images. 300 | latent_list.append(latents.cpu()) 301 | image_list.append(images.cpu()) 302 | 303 | if len(image_list) % 50 == 0: 304 | save_dir = outdir 305 | os.makedirs(save_dir, exist_ok=True) 306 | 307 | image_dir = os.path.join(save_dir, 'image') 308 | os.makedirs(image_dir, exist_ok=True) 309 | 310 | latent_dir = os.path.join(save_dir, 'latent') 311 | os.makedirs(latent_dir, exist_ok=True) 312 | 313 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl') 314 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl') 315 | 316 | torch.save(torch.cat(image_list, dim=0), image_path) 317 | torch.save(torch.cat(latent_list, dim=0), latent_path) 318 | 319 | image_list = [] 320 | latent_list = [] 321 | 322 | if len(image_list) > 0: 323 | save_dir = outdir 324 | os.makedirs(save_dir, exist_ok=True) 325 | 326 | image_dir = os.path.join(save_dir, 'image') 327 | os.makedirs(image_dir, exist_ok=True) 328 | 329 | latent_dir = os.path.join(save_dir, 'latent') 330 | os.makedirs(latent_dir, exist_ok=True) 331 | 332 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl') 333 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl') 334 | 335 | torch.save(torch.cat(image_list, dim=0), image_path) 336 | torch.save(torch.cat(latent_list, dim=0), latent_path) 337 | 338 | # Done. 
339 | torch.distributed.barrier() 340 | dist.print0('Done.') 341 | 342 | #---------------------------------------------------------------------------- 343 | 344 | if __name__ == "__main__": 345 | main() 346 | 347 | #---------------------------------------------------------------------------- 348 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def find_files(base_dir, world_size=1, rank=0): 6 | path_list = os.listdir(base_dir) 7 | 8 | sort_key = lambda f_name: int(f_name.split('.')[0]) 9 | path_list.sort(key=sort_key) 10 | 11 | assert len(path_list) % world_size == 0 12 | 13 | for i, f in enumerate(path_list): 14 | if i % world_size == rank: 15 | f_path = os.path.join(base_dir, f) 16 | yield f_path 17 | 18 | 19 | class PairedDataset(torch.utils.data.Dataset): 20 | def __init__(self, 21 | data_dir, 22 | world_size=1, 23 | rank=0, 24 | ): 25 | super().__init__() 26 | 27 | self.world_size = world_size 28 | self.data_dir = data_dir 29 | 30 | latent_dir = os.path.join(data_dir, 'latent') 31 | image_dir = os.path.join(data_dir, 'image') 32 | 33 | Z = list() 34 | for f in find_files(latent_dir, world_size, rank): 35 | z = torch.load(f) 36 | Z.append(z) 37 | Z = torch.cat(Z, dim=0) 38 | 39 | X = list() 40 | for f in find_files(image_dir, world_size, rank): 41 | x = torch.load(f) 42 | X.append(x) 43 | X = torch.cat(X, dim=0) 44 | 45 | assert len(Z) == len(X) 46 | 47 | self.Z = Z 48 | self.X = X 49 | 50 | def __len__(self): 51 | return len(self.X) * self.world_size 52 | 53 | def __getitem__(self, idx): 54 | idx = idx // self.world_size 55 | z, x = self.Z[idx], self.X[idx] 56 | 57 | return z, x 58 | 59 | 60 | class PairedCondDataset(torch.utils.data.Dataset): 61 | def __init__(self, 62 | data_dir, 63 | world_size=1, 64 | rank=0, 65 | ): 66 | super().__init__() 67 | 68 | self.world_size = world_size 69 | self.data_dir = data_dir 70 | 71 | latent_dir = os.path.join(data_dir, 'latent') 72 | image_dir = os.path.join(data_dir, 'image') 73 | 74 | Z = list() 75 | C = list() 76 | for f in find_files(latent_dir, world_size, rank): 77 | z, c = torch.load(f) 78 | Z.append(z) 79 | C.append(c) 80 | Z = torch.cat(Z, dim=0) 81 | C = torch.cat(C, dim=0) 82 | 83 | X = list() 84 | for f in find_files(image_dir, world_size, rank): 85 | x = torch.load(f) 86 | X.append(x) 87 | X = torch.cat(X, dim=0) 88 | 89 | assert len(Z) == len(X) 90 | 91 | self.Z = Z 92 | self.C = C 93 | self.X = X 94 | 95 | def __len__(self): 96 | return len(self.X) * self.world_size 97 | 98 | def __getitem__(self, idx): 99 | idx = idx // self.world_size 100 | z, c, x = self.Z[idx], self.C[idx], self.X[idx] 101 | 102 | return z, x, c 103 | 104 | 105 | class EpochPairedDataset(torch.utils.data.Dataset): 106 | def __init__(self, 107 | data_dir, 108 | total=10, 109 | epoch=0, 110 | world_size=1, 111 | rank=0, 112 | ): 113 | super().__init__() 114 | 115 | self.world_size = world_size 116 | 117 | data_dir = data_dir + f'-{epoch%total}' 118 | self.data_dir = data_dir 119 | 120 | latent_dir = os.path.join(data_dir, 'latent') 121 | image_dir = os.path.join(data_dir, 'image') 122 | 123 | Z = list() 124 | for f in find_files(latent_dir, world_size, rank): 125 | z = torch.load(f) 126 | Z.append(z) 127 | Z = torch.cat(Z, dim=0) 128 | 129 | X = list() 130 | for f in find_files(image_dir, world_size, rank): 131 | x = torch.load(f) 132 | X.append(x) 133 | X = torch.cat(X, dim=0) 134 | 135 | 
assert len(Z) == len(X) 136 | 137 | self.Z = Z 138 | self.X = X 139 | 140 | def __len__(self): 141 | return len(self.X) * self.world_size 142 | 143 | def __getitem__(self, idx): 144 | idx = idx // self.world_size 145 | z, x = self.Z[idx], self.X[idx] 146 | 147 | return z, x 148 | 149 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: GET 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pytorch >= 1.11 8 | - torchvision 9 | - cudatoolkit >= 11.3.1 10 | - pip: 11 | - torchdeq 12 | - timm 13 | - diffusers 14 | - pytorch-gan-metrics 15 | - piq 16 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import time 5 | 6 | import torch 7 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 8 | torch.backends.cuda.matmul.allow_tf32 = True 9 | torch.backends.cudnn.allow_tf32 = True 10 | 11 | import torch.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.data import DataLoader 14 | 15 | from utils import ( 16 | create_logger, requires_grad, 17 | sample_image, sample_fid, compute_fid_is 18 | ) 19 | from models import model_dict 20 | 21 | from torchdeq import add_deq_args 22 | 23 | 24 | def main(args): 25 | ''' 26 | Model evaluation. 27 | ''' 28 | # Setup DDP 29 | dist.init_process_group('nccl') 30 | rank = dist.get_rank() 31 | device = rank % torch.cuda.device_count() 32 | seed = args.global_seed * dist.get_world_size() + rank 33 | torch.manual_seed(seed) 34 | torch.cuda.set_device(device) 35 | print(f'Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.') 36 | 37 | # Setup an experiment folder 38 | if rank == 0: 39 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 40 | resume_dir = re.split('/|\.', args.resume) 41 | folder_name = f'eval-{resume_dir[-4]}-{resume_dir[-2]}-{args.name}' 42 | experiment_dir = f'{args.results_dir}/{folder_name}' # Create an experiment folder 43 | os.makedirs(experiment_dir, exist_ok=True) 44 | 45 | logger = create_logger(experiment_dir) 46 | logger.info(f'Experiment directory created at {experiment_dir}') 47 | else: 48 | logger = create_logger() 49 | 50 | # Create model 51 | model = model_dict[args.model]( 52 | args=args, 53 | input_size=args.input_size, 54 | num_classes=args.num_classes, 55 | cond=args.cond 56 | ) 57 | ema = model_dict[args.model]( 58 | args=args, 59 | input_size=args.input_size, 60 | num_classes=args.num_classes, 61 | cond=args.cond 62 | ).to(device) 63 | requires_grad(ema, False) 64 | 65 | # Setup DDP 66 | model = DDP(model.to(device), device_ids=[rank]) 67 | logger.info(f'Model Parameters: {sum(p.numel() for p in model.parameters()):,}') 68 | 69 | model.eval() 70 | ema.eval() 71 | 72 | # Resume from the given checkpoint 73 | if args.resume: 74 | ckpt = torch.load(args.resume, map_location=torch.device('cpu')) 75 | model.module.load_state_dict(ckpt['model']) 76 | ema.load_state_dict(ckpt['ema']) 77 | logger.info(f'Resume from {args.resume}..') 78 | 79 | # Sample images 80 | if rank == 0: 81 | image_path = f'{experiment_dir}/samples.png' 82 | sample_image(args, ema, device, image_path, cond=args.cond) 83 | logger.info(f'Saved samples 
to {image_path}') 84 | dist.barrier() 85 | 86 | # Compute FID and IS 87 | start_time = time.time() 88 | images = sample_fid(args, ema, device, rank, cond=args.cond) 89 | end_time = time.time() 90 | logger.info(f'Time for sampling 50k images {end_time-start_time:.2f}s.') 91 | 92 | # DDP sync for FID evaluation 93 | all_images = [torch.zeros_like(images) for _ in range(dist.get_world_size())] 94 | dist.gather(images, all_images if rank == 0 else None, dst=0) 95 | if rank == 0: 96 | FID, IS = compute_fid_is(args, all_images, rank) 97 | logger.info(f'FID {FID:0.2f}, IS {IS:0.2f}.') 98 | 99 | dist.barrier() 100 | dist.destroy_process_group() 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--results_dir', type=str, default='eval-results') 106 | parser.add_argument('--name', type=str, default='debug') 107 | 108 | parser.add_argument('--model', type=str, choices=list(model_dict.keys()), default='GET-S/2') 109 | parser.add_argument('--input_size', type=int, default=32) 110 | 111 | parser.add_argument('--cond', action='store_true', help='Run conditional model.') 112 | parser.add_argument('--num_classes', type=int, default=10) 113 | 114 | parser.add_argument('--global_seed', type=int, default=42) 115 | parser.add_argument('--num_workers', type=int, default=4) 116 | 117 | parser.add_argument('--mem', action='store_true', help='Enable O1 memory.') 118 | 119 | parser.add_argument('--eval_batch_size', type=int, default=128) 120 | parser.add_argument('--eval_samples', type=int, default=50000) 121 | parser.add_argument('--stat_path', type=str, default='YOUR_STAT_PATH/cifar10.test.npz') 122 | 123 | parser.add_argument('--resume', help="restore checkpoint for training") 124 | 125 | # Add for DEQs 126 | add_deq_args(parser) 127 | 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=1 --nproc_per_node=$1 --rdzv_endpoint=localhost:$2 \ 2 | eval.py \ 3 | --eval_f_max_iter 6 \ 4 | --norm_type none \ 5 | --stat_path YOUR_STAT_PATH \ 6 | ${@:3} 7 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from piq import LPIPS, DISTS 6 | 7 | 8 | class L1Loss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x_pred, x): 13 | return (x_pred - x).abs().mean() 14 | 15 | 16 | class L2Loss(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, x_pred, x): 21 | return ((x_pred - x) ** 2).mean() 22 | 23 | 24 | class LPIPSLoss(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | self.loss = LPIPS() 29 | 30 | def forward(self, x_pred, x): 31 | x_pred = F.interpolate(x_pred, size=224, mode="bilinear") 32 | x = F.interpolate(x.float(), size=224, mode="bilinear") 33 | 34 | x_pred = (x_pred + 1) / 2 35 | x = (x + 1) / 2 36 | 37 | return self.loss(x_pred, x) 38 | 39 | 40 | class DISTSLoss(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | 44 | self.loss = DISTS() 45 | 46 | def forward(self, x_pred, x): 47 | x_pred = F.interpolate(x_pred, size=224, mode="bilinear") 48 | x = F.interpolate(x.float(), size=224, mode="bilinear") 49 | 50 | x_pred = (x_pred + 1) / 2 51 | x = (x 
+ 1) / 2 52 | 53 | return self.loss(x_pred, x) 54 | 55 | 56 | loss_dict = { 57 | 'l1': L1Loss, 58 | 'l2': L2Loss, 59 | 'lpips': LPIPSLoss, 60 | 'dists': DISTSLoss 61 | } 62 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .get import GET_models 3 | from .vit import ViT_models 4 | 5 | 6 | model_dict = {} 7 | model_dict.update(GET_models) 8 | model_dict.update(ViT_models) 9 | 10 | 11 | -------------------------------------------------------------------------------- /models/get.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # DiT: https://github.com/facebookresearch/DiT 4 | # MAE: https://github.com/facebookresearch/mae 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import numpy as np 12 | import math 13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 14 | 15 | from torchdeq import get_deq 16 | from torchdeq.norm import apply_norm, reset_norm 17 | from torchdeq.utils import mem_gc 18 | 19 | 20 | # Postional Embedding 21 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 22 | 23 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 24 | """ 25 | grid_size: int of the grid height and width 26 | return: 27 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 28 | """ 29 | grid_h = np.arange(grid_size, dtype=np.float32) 30 | grid_w = np.arange(grid_size, dtype=np.float32) 31 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 32 | grid = np.stack(grid, axis=0) 33 | 34 | grid = grid.reshape([2, 1, grid_size, grid_size]) 35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 36 | if cls_token and extra_tokens > 0: 37 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 38 | return pos_embed 39 | 40 | 41 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 42 | assert embed_dim % 2 == 0 43 | 44 | # use half of dimensions to encode grid_h 45 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 46 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 47 | 48 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 49 | return emb 50 | 51 | 52 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 53 | """ 54 | embed_dim: output dimension for each position 55 | pos: a list of positions to be encoded: size (M,) 56 | out: (M, D) 57 | """ 58 | assert embed_dim % 2 == 0 59 | omega = np.arange(embed_dim // 2, dtype=np.float64) 60 | omega /= embed_dim / 2. 61 | omega = 1. 
/ 10000**omega # (D/2,) 62 | 63 | pos = pos.reshape(-1) # (M,) 64 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 65 | 66 | emb_sin = np.sin(out) # (M, D/2) 67 | emb_cos = np.cos(out) # (M, D/2) 68 | 69 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 70 | return emb 71 | 72 | 73 | class ClassEmbedding(nn.Module): 74 | def __init__(self, num_classes, hidden_size): 75 | super().__init__() 76 | self.embedding_table = nn.Embedding(num_classes, hidden_size) 77 | 78 | def forward(self, labels): 79 | return self.embedding_table(labels) 80 | 81 | 82 | class AttnInterface(nn.Module): 83 | def __init__( 84 | self, 85 | dim, 86 | num_heads=8, 87 | qkv_bias=False, 88 | qk_norm=False, 89 | attn_drop=0., 90 | proj_drop=0., 91 | norm_layer=nn.LayerNorm, 92 | cond=False 93 | ): 94 | super().__init__() 95 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 96 | self.num_heads = num_heads 97 | self.head_dim = dim // num_heads 98 | self.scale = self.head_dim ** -0.5 99 | self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME 100 | 101 | self.cond = cond 102 | 103 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 104 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 105 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | def forward(self, x, c, u=None): 111 | B, N, C = x.shape 112 | qkv = self.qkv(x) 113 | 114 | # Injection 115 | if self.cond: 116 | qkv = qkv + c 117 | if u is not None: 118 | qkv = qkv + u 119 | 120 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 121 | 122 | q, k, v = qkv.unbind(0) 123 | q, k = self.q_norm(q), self.k_norm(k) 124 | 125 | if self.fast_attn: 126 | x = F.scaled_dot_product_attention( 127 | q, k, v, 128 | dropout_p=self.attn_drop.p, 129 | ) 130 | else: 131 | q = q * self.scale 132 | attn = q @ k.transpose(-2, -1) 133 | attn = attn.softmax(dim=-1) 134 | attn = self.attn_drop(attn) 135 | x = attn @ v 136 | 137 | x = x.transpose(1, 2).reshape(B, N, C) 138 | x = self.proj(x) 139 | x = self.proj_drop(x) 140 | return x 141 | 142 | 143 | class GETBlock(nn.Module): 144 | """ 145 | A GET block with additive attention injection. 146 | """ 147 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, cond=False, **block_kwargs): 148 | super().__init__() 149 | # Attention 150 | self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6) 151 | self.attn = AttnInterface(hidden_size, num_heads=num_heads, qkv_bias=True, cond=cond, **block_kwargs) 152 | 153 | # MLP 154 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6) 155 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 156 | 157 | act = lambda: nn.GELU() 158 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=act, drop=0) 159 | 160 | def forward(self, x, c, u=None): 161 | x = x + self.attn(self.norm1(x), c, u) 162 | x = x + self.mlp(self.norm2(x)) 163 | return x 164 | 165 | 166 | class FinalLayer(nn.Module): 167 | """ 168 | The final projection layer. 
169 | """ 170 | def __init__(self, hidden_size, patch_size, out_channels, cond=False): 171 | super().__init__() 172 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6) 173 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 174 | 175 | def forward(self, x): 176 | x = self.norm_final(x) 177 | x = self.linear(x) 178 | return x 179 | 180 | 181 | class GET(nn.Module): 182 | """ 183 | Diffusion model with a Transformer backbone. 184 | """ 185 | def __init__( 186 | self, 187 | args, 188 | input_size=32, 189 | patch_size=2, 190 | in_channels=3, 191 | hidden_size=1152, 192 | depth=28, 193 | deq_depth=3, 194 | num_heads=16, 195 | mlp_ratio=4.0, 196 | deq_mlp_ratio=16.0, 197 | num_classes=10, 198 | cond=False 199 | ): 200 | super().__init__() 201 | self.in_channels = in_channels 202 | self.out_channels = in_channels 203 | self.patch_size = patch_size 204 | self.num_heads = num_heads 205 | self.deq_depth = deq_depth 206 | 207 | self.cond = cond 208 | 209 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 210 | 211 | if self.cond: 212 | self.y_embedder = ClassEmbedding(num_classes, 3*hidden_size) 213 | 214 | num_patches = self.x_embedder.num_patches 215 | # Will use fixed sin-cos embedding: 216 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 217 | 218 | self.blocks = nn.ModuleList([ 219 | GETBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, cond=cond) for _ in range(depth) 220 | ]) 221 | 222 | # injection 223 | self.qkv_inj = nn.Linear(hidden_size, hidden_size*3*deq_depth, bias=False) 224 | 225 | # DEQ blocks 226 | self.deq_blocks = nn.ModuleList([ 227 | GETBlock(hidden_size, num_heads, mlp_ratio=deq_mlp_ratio, cond=cond) for _ in range(deq_depth) 228 | ]) 229 | 230 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, cond=cond) 231 | self.initialize_weights() 232 | 233 | self.mem = args.mem 234 | self.deq = get_deq(args) 235 | apply_norm(self.deq_blocks, args=args) 236 | 237 | def initialize_weights(self): 238 | # Initialize transformer layers: 239 | def _basic_init(module): 240 | if isinstance(module, nn.Linear): 241 | torch.nn.init.xavier_uniform_(module.weight) 242 | if module.bias is not None: 243 | nn.init.constant_(module.bias, 0) 244 | if isinstance(module, nn.LayerNorm): 245 | if module.weight is not None: 246 | nn.init.constant_(module.weight, 1) 247 | nn.init.constant_(module.bias, 0) 248 | self.apply(_basic_init) 249 | 250 | # Initialize (and freeze) pos_embed by sin-cos embedding: 251 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 252 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 253 | 254 | # Initialize patch embedding: 255 | w = self.x_embedder.proj.weight.data 256 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 257 | nn.init.constant_(self.x_embedder.proj.bias, 0) 258 | 259 | # Initialize class embedding table: 260 | if self.cond: 261 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 262 | 263 | nn.init.constant_(self.final_layer.linear.weight, 0) 264 | nn.init.constant_(self.final_layer.linear.bias, 0) 265 | 266 | def unpatchify(self, x): 267 | """ 268 | x: (B, N, P ** 2 * C) 269 | imgs: (B, H, W, C) 270 | """ 271 | c = self.out_channels 272 | p = self.x_embedder.patch_size[0] 273 | h = w = int(x.shape[1] ** 0.5) 274 | assert h * w == x.shape[1] 275 | 276 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 277 | x = 
torch.einsum('nhwpqc->nchpwq', x) 278 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 279 | return imgs 280 | 281 | def decode(self, z): 282 | x = self.final_layer(z) # (B, N, P ** 2 * C_out) 283 | x = self.unpatchify(x) # (B, C_out, H, W) 284 | return x 285 | 286 | def forward(self, x, y=None): 287 | """ 288 | Forward pass of GET. 289 | x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images) 290 | y: (B,) tensor of class labels 291 | """ 292 | reset_norm(self) 293 | 294 | x = self.x_embedder(x) + self.pos_embed # (B, N, D), where N = H * W / P ** 2 295 | B, N, C = x.shape 296 | 297 | c = None 298 | if self.cond: 299 | c = self.y_embedder(y).view(B, 1, 3*C) 300 | 301 | # Injection T 302 | for block in self.blocks: 303 | x = block(x, c) # (B, N, D) 304 | 305 | u = self.qkv_inj(x) 306 | u_list = u.chunk(self.deq_depth, dim=-1) 307 | 308 | def func(z): 309 | for block, u in zip(self.deq_blocks, u_list): 310 | if self.mem: 311 | z = mem_gc(block, (z, c, u)) 312 | else: 313 | z = block(z, c, u) 314 | return z 315 | 316 | # Equilibrium T 317 | z = torch.randn_like(x) 318 | z_out, info = self.deq(func, z) 319 | 320 | if self.training: 321 | # For fixed point correction 322 | return [self.decode(z) for z in z_out] 323 | else: 324 | return self.decode(z_out[-1]) 325 | 326 | 327 | ################################################################################# 328 | # Current GET Configs # 329 | ################################################################################# 330 | 331 | 332 | def GET_T_2_L6_L3_H6(args, **kwargs): 333 | return GET(args, depth=6, hidden_size=256, patch_size=2, num_heads=4, deq_mlp_ratio=6, **kwargs) 334 | 335 | def GET_M_2_L6_L3_H6(args, **kwargs): 336 | return GET(args, depth=6, hidden_size=384, patch_size=2, num_heads=6, deq_mlp_ratio=6, **kwargs) 337 | 338 | def GET_S_2_L6_L3_H8(args, **kwargs): 339 | return GET(args, depth=6, hidden_size=512, patch_size=2, num_heads=8, deq_mlp_ratio=8, **kwargs) 340 | 341 | def GET_B_2_L1_L3_H12(args, **kwargs): 342 | return GET(args, depth=1, hidden_size=768, patch_size=2, num_heads=12, deq_mlp_ratio=12, **kwargs) 343 | 344 | def GET_B_2_L6_L3_H8(args, **kwargs): 345 | return GET(args, depth=6, hidden_size=768, patch_size=2, num_heads=12, deq_mlp_ratio=8, **kwargs) 346 | 347 | 348 | GET_models = { 349 | 'GET-T/2': GET_T_2_L6_L3_H6, 350 | 'GET-M/2': GET_M_2_L6_L3_H6, 351 | 'GET-S/2': GET_S_2_L6_L3_H8, 352 | 'GET-B/2': GET_B_2_L1_L3_H12, 353 | 'GET-B/2+': GET_B_2_L6_L3_H8, 354 | } 355 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # DiT: https://github.com/facebookresearch/DiT 4 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import numpy as np 12 | import math 13 | from timm.models.vision_transformer import PatchEmbed, Mlp 14 | 15 | # Postional Embedding 16 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 17 | 18 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 19 | """ 20 | grid_size: int of the grid height and width 21 | return: 22 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 23 
| """ 24 | grid_h = np.arange(grid_size, dtype=np.float32) 25 | grid_w = np.arange(grid_size, dtype=np.float32) 26 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 27 | grid = np.stack(grid, axis=0) 28 | 29 | grid = grid.reshape([2, 1, grid_size, grid_size]) 30 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 31 | if cls_token and extra_tokens > 0: 32 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 33 | return pos_embed 34 | 35 | 36 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 37 | assert embed_dim % 2 == 0 38 | 39 | # use half of dimensions to encode grid_h 40 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 41 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 42 | 43 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 44 | return emb 45 | 46 | 47 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 48 | """ 49 | embed_dim: output dimension for each position 50 | pos: a list of positions to be encoded: size (M,) 51 | out: (M, D) 52 | """ 53 | assert embed_dim % 2 == 0 54 | omega = np.arange(embed_dim // 2, dtype=np.float64) 55 | omega /= embed_dim / 2. 56 | omega = 1. / 10000**omega # (D/2,) 57 | 58 | pos = pos.reshape(-1) # (M,) 59 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 60 | 61 | emb_sin = np.sin(out) # (M, D/2) 62 | emb_cos = np.cos(out) # (M, D/2) 63 | 64 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 65 | return emb 66 | 67 | 68 | class ClassEmbedding(nn.Module): 69 | def __init__(self, num_classes, hidden_size): 70 | super().__init__() 71 | self.embedding_table = nn.Embedding(num_classes, hidden_size) 72 | 73 | def forward(self, labels): 74 | return self.embedding_table(labels) 75 | 76 | 77 | class DEQAttention(nn.Module): 78 | def __init__( 79 | self, 80 | dim, 81 | num_heads=8, 82 | qkv_bias=False, 83 | qk_norm=False, 84 | attn_drop=0., 85 | proj_drop=0., 86 | norm_layer=nn.LayerNorm, 87 | cond=False 88 | ): 89 | super().__init__() 90 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 91 | self.num_heads = num_heads 92 | self.head_dim = dim // num_heads 93 | self.scale = self.head_dim ** -0.5 94 | self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME 95 | 96 | self.cond = cond 97 | 98 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 99 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 100 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 101 | self.attn_drop = nn.Dropout(attn_drop) 102 | self.proj = nn.Linear(dim, dim) 103 | self.proj_drop = nn.Dropout(proj_drop) 104 | 105 | def forward(self, x, c=None): 106 | B, N, C = x.shape 107 | qkv = self.qkv(x) 108 | 109 | # Injection 110 | if self.cond: 111 | qkv = qkv + c 112 | 113 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 114 | 115 | q, k, v = qkv.unbind(0) 116 | q, k = self.q_norm(q), self.k_norm(k) 117 | 118 | if self.fast_attn: 119 | x = F.scaled_dot_product_attention( 120 | q, k, v, 121 | dropout_p=self.attn_drop.p, 122 | ) 123 | else: 124 | q = q * self.scale 125 | attn = q @ k.transpose(-2, -1) 126 | attn = attn.softmax(dim=-1) 127 | attn = self.attn_drop(attn) 128 | x = attn @ v 129 | 130 | x = x.transpose(1, 2).reshape(B, N, C) 131 | x = self.proj(x) 132 | x = self.proj_drop(x) 133 | return x 134 | 135 | 136 | class ViTBlock(nn.Module): 137 | """ 138 | A standard ViT block. 
139 | """ 140 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, cond=False, **block_kwargs): 141 | super().__init__() 142 | self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6) 143 | self.attn = DEQAttention(hidden_size, num_heads=num_heads, qkv_bias=True, cond=cond, **block_kwargs) 144 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6) 145 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 146 | 147 | # For Pytorch 1.13 148 | act = lambda: nn.GELU() 149 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=act, drop=0) 150 | 151 | self.cond = cond 152 | 153 | def forward(self, x, c): 154 | if self.cond: 155 | x = x + self.attn(self.norm1(x), c) 156 | x = x + self.mlp(self.norm2(x)) 157 | else: 158 | x = x + self.attn(self.norm1(x)) 159 | x = x + self.mlp(self.norm2(x)) 160 | 161 | return x 162 | 163 | 164 | class FinalLayer(nn.Module): 165 | def __init__(self, hidden_size, patch_size, out_channels, cond=False): 166 | super().__init__() 167 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6) 168 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 169 | 170 | def forward(self, x): 171 | x = self.norm_final(x) 172 | x = self.linear(x) 173 | return x 174 | 175 | 176 | class ViT(nn.Module): 177 | """ 178 | Learning fast image generation using ViT. 179 | """ 180 | def __init__( 181 | self, 182 | args, 183 | input_size=32, 184 | patch_size=2, 185 | in_channels=3, 186 | hidden_size=1152, 187 | depth=28, 188 | num_heads=16, 189 | mlp_ratio=4.0, 190 | num_classes=10, 191 | cond=False 192 | ): 193 | super().__init__() 194 | self.in_channels = in_channels 195 | self.out_channels = in_channels 196 | self.patch_size = patch_size 197 | self.num_heads = num_heads 198 | 199 | self.cond = cond 200 | 201 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 202 | 203 | if self.cond: 204 | self.y_embedder = ClassEmbedding(num_classes, 3*hidden_size) 205 | 206 | num_patches = self.x_embedder.num_patches 207 | # Will use fixed sin-cos embedding: 208 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 209 | 210 | self.blocks = nn.ModuleList([ 211 | ViTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, cond=cond) for _ in range(depth) 212 | ]) 213 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, cond=cond) 214 | self.initialize_weights() 215 | 216 | def initialize_weights(self): 217 | # Initialize transformer layers: 218 | def _basic_init(module): 219 | if isinstance(module, nn.Linear): 220 | torch.nn.init.xavier_uniform_(module.weight) 221 | if module.bias is not None: 222 | nn.init.constant_(module.bias, 0) 223 | if isinstance(module, nn.LayerNorm): 224 | if hasattr(module, 'weight') and module.weight is not None: 225 | nn.init.constant_(module.weight, 1) 226 | nn.init.constant_(module.bias, 0) 227 | self.apply(_basic_init) 228 | 229 | # Initialize (and freeze) pos_embed by sin-cos embedding: 230 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 231 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 232 | 233 | # Initialize patch_embed like nn.Linear: 234 | w = self.x_embedder.proj.weight.data 235 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 236 | nn.init.constant_(self.x_embedder.proj.bias, 0) 237 | 238 | # Initialize label embedding: 239 | if self.cond: 240 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 241 | 242 | 
nn.init.constant_(self.final_layer.linear.weight, 0) 243 | nn.init.constant_(self.final_layer.linear.bias, 0) 244 | 245 | def unpatchify(self, x): 246 | """ 247 | x: (B, N, P ** 2 * C) 248 | imgs: (B, C, H, W) 249 | """ 250 | c = self.out_channels 251 | p = self.x_embedder.patch_size[0] 252 | h = w = int(x.shape[1] ** 0.5) 253 | assert h * w == x.shape[1] 254 | 255 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 256 | x = torch.einsum('nhwpqc->nchpwq', x) 257 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 258 | return imgs 259 | 260 | def forward(self, x, y=None): 261 | """ 262 | Forward pass of ViT. 263 | x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images) 264 | y: (B,) tensor of class labels 265 | """ 266 | x = self.x_embedder(x) + self.pos_embed # (B, N, D), where N = H * W / P ** 2 267 | 268 | c = None 269 | if self.cond: 270 | c = self.y_embedder(y).view(y.shape[0], 1, -1) # (B, 1, 3*D) qkv injection 271 | 272 | for block in self.blocks: 273 | x = block(x, c) # (B, N, D) 274 | x = self.final_layer(x) # (B, N, P ** 2 * C_out) 275 | 276 | x = self.unpatchify(x) # (B, C_out, H, W) 277 | return x 278 | 279 | 280 | 281 | ################################################################################# 282 | # ViT Configs # 283 | ################################################################################# 284 | 285 | def ViT_XL_2(args, **kwargs): 286 | return ViT(args, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 287 | 288 | def ViT_XL_4(args, **kwargs): 289 | return ViT(args, depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 290 | 291 | def ViT_XL_8(args, **kwargs): 292 | return ViT(args, depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 293 | 294 | def ViT_L_2(args, **kwargs): 295 | return ViT(args, depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 296 | 297 | def ViT_L_4(args, **kwargs): 298 | return ViT(args, depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 299 | 300 | def ViT_L_8(args, **kwargs): 301 | return ViT(args, depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 302 | 303 | def ViT_B_2(args, **kwargs): 304 | return ViT(args, depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 305 | 306 | def ViT_B_4(args, **kwargs): 307 | return ViT(args, depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 308 | 309 | def ViT_B_8(args, **kwargs): 310 | return ViT(args, depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 311 | 312 | def ViT_S_2(args, **kwargs): 313 | return ViT(args, depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 314 | 315 | def ViT_S_4(args, **kwargs): 316 | return ViT(args, depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 317 | 318 | def ViT_S_8(args, **kwargs): 319 | return ViT(args, depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 320 | 321 | 322 | ViT_models = { 323 | 'ViT-XL/2': ViT_XL_2, 'ViT-XL/4': ViT_XL_4, 'ViT-XL/8': ViT_XL_8, 324 | 'ViT-L/2': ViT_L_2, 'ViT-L/4': ViT_L_4, 'ViT-L/8': ViT_L_8, 325 | 'ViT-B/2': ViT_B_2, 'ViT-B/4': ViT_B_4, 'ViT-B/8': ViT_B_8, 326 | 'ViT-S/2': ViT_S_2, 'ViT-S/4': ViT_S_4, 'ViT-S/8': ViT_S_8, 327 | } 328 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=1 --nproc_per_node=$1 --rdzv_endpoint=localhost:$2 \ 2 | train.py \ 3 | --epochs 1000 \ 4 | --global_batch_size 128 \ 5 | --grad 6 \ 6 | --sup_gap 1 
\ 7 | --f_max_iter 0 \ 8 | --eval_f_max_iter 6 \ 9 | --norm_type none \ 10 | --data_path YOUR_DATA_PATH \ 11 | --stat_path YOUR_STAT_PATH \ 12 | ${@:3} 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | from glob import glob 7 | 8 | import torch 9 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 10 | torch.backends.cuda.matmul.allow_tf32 = True 11 | torch.backends.cudnn.allow_tf32 = True 12 | 13 | import torch.distributed as dist 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data.distributed import DistributedSampler 17 | 18 | from torchprofile import profile_macs 19 | 20 | from utils import ( 21 | create_logger, save_ckpt, 22 | update_ema, requires_grad, 23 | sample_image, sample_fid, compute_fid_is 24 | ) 25 | from models import model_dict 26 | from losses import loss_dict 27 | from datasets import PairedDataset, PairedCondDataset 28 | 29 | # For future ImageNet training & sampling 30 | # from diffusers.models import AutoencoderKL 31 | 32 | from torchdeq import add_deq_args 33 | from torchdeq.loss import fp_correction 34 | 35 | 36 | def main(args): 37 | ''' 38 | Model training. 39 | ''' 40 | # Setup DDP 41 | dist.init_process_group('nccl') 42 | world_size = dist.get_world_size() 43 | rank = dist.get_rank() 44 | assert args.global_batch_size % world_size == 0, f'Batch size must be divisible by world size.' 45 | device = rank % torch.cuda.device_count() 46 | seed = args.global_seed * world_size + rank 47 | torch.manual_seed(seed) 48 | torch.cuda.set_device(device) 49 | print(f'Starting rank={rank}, seed={seed}, world_size={world_size}.') 50 | 51 | # Setup an experiment folder 52 | if rank == 0: 53 | os.makedirs(args.results_dir, exist_ok=True) 54 | experiment_index = len(glob(f'{args.results_dir}/*')) 55 | model_string_name = args.model.replace('/', '-') 56 | experiment_dir = f'{args.results_dir}/{experiment_index:03d}-{model_string_name}-{args.name}' 57 | 58 | checkpoint_dir = f'{experiment_dir}/checkpoints' 59 | os.makedirs(checkpoint_dir, exist_ok=True) 60 | 61 | sample_dir = f'{experiment_dir}/samples' 62 | os.makedirs(sample_dir, exist_ok=True) 63 | 64 | logger = create_logger(experiment_dir) 65 | logger.info(f'Experiment directory created at {experiment_dir}') 66 | else: 67 | logger = create_logger() 68 | 69 | # Create model 70 | model = model_dict[args.model]( 71 | args=args, 72 | input_size=args.input_size, 73 | num_classes=args.num_classes, 74 | cond=args.cond 75 | ) 76 | ema = model_dict[args.model]( 77 | args=args, 78 | input_size=args.input_size, 79 | num_classes=args.num_classes, 80 | cond=args.cond 81 | ).to(device) 82 | requires_grad(ema, False) 83 | 84 | # Setup DDP 85 | model = DDP(model.to(device), device_ids=[rank]) 86 | logger.info(f'Model Parameters: {sum(p.numel() for p in model.parameters()):,}') 87 | 88 | # Test FLOPs 89 | if rank == 0: 90 | test_case = torch.randn(1, 3, args.input_size, args.input_size).to(device) 91 | if args.cond: 92 | test_c = torch.randint(0, 10, (1,1)).to(device) 93 | macs = profile_macs(model, (test_case, test_c)) 94 | del test_case, test_c 95 | else: 96 | macs = profile_macs(model, test_case) 97 | del test_case 98 | logger.info(f'Model MACs: {macs:,}') 99 | dist.barrier() 100 | 101 | # For future 
ImageNet training 102 | # vae = AutoencoderKL.from_pretrained(f'stabilityai/sd-vae-ft-{args.vae}').to(device) 103 | 104 | # Setup optimizer 105 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0) 106 | 107 | # Setup data 108 | if args.cond: 109 | dataset = PairedCondDataset(args.data_path, world_size=world_size, rank=rank) 110 | else: 111 | dataset = PairedDataset(args.data_path, world_size=world_size, rank=rank) 112 | sampler = DistributedSampler( 113 | dataset, 114 | num_replicas=world_size, 115 | rank=rank, 116 | shuffle=True, 117 | seed=args.global_seed 118 | ) 119 | loader = DataLoader( 120 | dataset, 121 | batch_size=int(args.global_batch_size // world_size), 122 | shuffle=False, 123 | sampler=sampler, 124 | num_workers=args.num_workers, 125 | pin_memory=True, 126 | drop_last=True 127 | ) 128 | logger.info(f'Dataset contains {len(dataset):,} images ({dataset.data_dir})') 129 | 130 | # Prepare models for training 131 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights 132 | model.train() 133 | ema.eval() # EMA model should always be in eval mode 134 | 135 | # Loss fn 136 | loss_fn = loss_dict[args.loss]().to(device) 137 | 138 | # Variables for monitoring/logging purposes 139 | train_steps = 0 140 | log_steps = 0 141 | running_loss = 0 142 | total_steps = args.epochs * (len(dataset) / args.global_batch_size) 143 | 144 | # Resume from the prev checkpoint 145 | if args.resume: 146 | ckpt = torch.load(args.resume, map_location=torch.device('cpu')) 147 | model.module.load_state_dict(ckpt['model']) 148 | ema.load_state_dict(ckpt['ema']) 149 | opt.load_state_dict(ckpt['opt']) 150 | train_steps = max(args.resume_iter, 0) 151 | 152 | logger.info(f'Resume from {args.resume}..') 153 | 154 | start_time = time.time() 155 | logger.info(f'Training for {args.epochs} epochs...') 156 | for epoch in range(args.epochs): 157 | sampler.set_epoch(epoch) 158 | logger.info(f'Beginning epoch {epoch}...') 159 | 160 | for data in loader: 161 | # Unpack data 162 | if args.cond: 163 | z, x, c = data 164 | z, x, c = z.to(device), x.to(device), c.to(device).max(dim=1)[1] 165 | else: 166 | z, x = data 167 | z, x, c = z.to(device), x.to(device), None 168 | 169 | # Loss & Grad 170 | x_pred = model(z, c) 171 | loss, loss_list = fp_correction(loss_fn, (x_pred, x), return_loss_values=True) 172 | opt.zero_grad() 173 | loss.backward() 174 | 175 | # LR Warmup 176 | if train_steps < args.warmup_iter: 177 | curr_lr = args.lr * (train_steps+1) / args.warmup_iter 178 | opt.param_groups[0]['lr'] = curr_lr 179 | 180 | opt.step() 181 | update_ema(ema, model.module, decay=args.ema_decay) 182 | 183 | running_loss += loss_list[-1] 184 | log_steps += 1 185 | train_steps += 1 186 | 187 | # Log training progress 188 | if train_steps % args.log_every == 0: 189 | # Measure training speed 190 | torch.cuda.synchronize() 191 | end_time = time.time() 192 | steps_per_sec = log_steps / (end_time - start_time) 193 | 194 | # Reduce loss history over all processes 195 | avg_loss = torch.tensor(running_loss / log_steps, device=device) 196 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) 197 | avg_loss = avg_loss.item() / world_size 198 | logger.info(f'(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}') 199 | 200 | # Reset monitoring variables 201 | running_loss = 0 202 | log_steps = 0 203 | start_time = time.time() 204 | 205 | # Save checkpoint 206 | if train_steps % args.ckpt_every == 0 and train_steps > 0: 207 | if rank == 0: 208 | checkpoint_path = 
f'{checkpoint_dir}/{train_steps:07d}.pth' 209 | save_ckpt(args, model, ema, opt, checkpoint_path) 210 | logger.info(f'Saved checkpoint to {checkpoint_path}') 211 | dist.barrier() 212 | 213 | # Save the latest checkpoint 214 | if train_steps % args.save_latest_every == 0 and train_steps > 0: 215 | if rank == 0: 216 | checkpoint_path = f'{checkpoint_dir}/latest.pth' 217 | save_ckpt(args, model, ema, opt, checkpoint_path) 218 | logger.info(f'Saved latest checkpoint to {checkpoint_path}') 219 | dist.barrier() 220 | 221 | # Sample images 222 | if train_steps % args.sample_every == 0 and train_steps > 0: 223 | if rank == 0: 224 | image_path = f'{sample_dir}/{train_steps}.png' 225 | sample_image(args, ema, device, image_path, cond=args.cond) 226 | logger.info(f'Saved samples to {image_path}') 227 | dist.barrier() 228 | 229 | # Compute FID and IS 230 | if train_steps % args.eval_every == 0 and train_steps > 0: 231 | images = sample_fid(args, ema, device, rank, cond=args.cond) 232 | 233 | # In case you want to sample from the online model 234 | # images = sample_fid(args, model.module, device, rank, cond=args.cond, set_grad=True) 235 | 236 | # DDP sync 237 | all_images = [torch.zeros_like(images) for _ in range(world_size)] 238 | dist.gather(images, all_images if rank == 0 else None, dst=0) 239 | if rank == 0: 240 | FID, IS = compute_fid_is(args, all_images, rank) 241 | logger.info(f'FID {FID:0.2f}, IS {IS:0.2f} at iters {train_steps}.') 242 | 243 | del images, all_images 244 | dist.barrier() 245 | 246 | # Check training schedule 247 | if train_steps > total_steps: 248 | break 249 | 250 | if rank == 0: 251 | checkpoint_path = f'{checkpoint_dir}/final.pth' 252 | save_ckpt(args, model, ema, opt, checkpoint_path) 253 | logger.info(f'Saved final checkpoint to {checkpoint_path}') 254 | dist.barrier() 255 | 256 | # Finish training 257 | dist.destroy_process_group() 258 | 259 | 260 | if __name__ == '__main__': 261 | parser = argparse.ArgumentParser() 262 | parser.add_argument('--data_path', type=str, required=True) 263 | parser.add_argument('--name', type=str, default='debug') 264 | parser.add_argument('--results_dir', type=str, default='results') 265 | 266 | parser.add_argument('--model', type=str, choices=list(model_dict.keys()), default='GET-S/2') 267 | parser.add_argument('--input_size', type=int, default=32) 268 | 269 | parser.add_argument('--cond', action='store_true', help='Run conditional model.') 270 | parser.add_argument('--num_classes', type=int, default=10) 271 | 272 | parser.add_argument('--loss', type=str, choices=['l1', 'l2', 'lpips', 'dists'], default='l1') 273 | parser.add_argument('--vae', type=str, choices=['ema', 'mse'], default='ema') 274 | 275 | parser.add_argument('--lr', type=float, default=1e-4) 276 | parser.add_argument('--warmup_iter', type=int, default=0, help="warmup for the given iterations") 277 | parser.add_argument('--ema_decay', type=float, default=0.9999) 278 | 279 | parser.add_argument('--epochs', type=int, default=1000) 280 | parser.add_argument('--global_batch_size', type=int, default=256) 281 | parser.add_argument('--global_seed', type=int, default=42) 282 | parser.add_argument('--num_workers', type=int, default=4) 283 | 284 | parser.add_argument('--mem', action='store_true', help='Enable O(1) memory.') 285 | 286 | parser.add_argument('--log_every', type=int, default=100) 287 | parser.add_argument('--ckpt_every', type=int, default=50000) 288 | parser.add_argument('--save_latest_every', type=int, default=10000) 289 | parser.add_argument('--sample_every', type=int, 
default=10000) 290 | 291 | parser.add_argument('--eval_every', type=int, default=50000) 292 | parser.add_argument('--eval_samples', type=int, default=50000) 293 | parser.add_argument('--eval_batch_size', type=int, default=128) 294 | parser.add_argument('--stat_path', type=str, default='YOUR_STAT_PATH/cifar10.test.npz') 295 | 296 | parser.add_argument('--resume', help="restore checkpoint for training") 297 | parser.add_argument('--resume_iter', type=int, default=-1, help="resume from the given iterations") 298 | 299 | # Add for DEQs 300 | add_deq_args(parser) 301 | 302 | args = parser.parse_args() 303 | main(args) 304 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import logging 4 | import torch 5 | 6 | import torch.distributed as dist 7 | 8 | from PIL import Image 9 | from pytorch_gan_metrics import get_inception_score_and_fid 10 | 11 | 12 | def create_logger(logging_dir=None): 13 | """ 14 | Create a logger that writes to a log file and stdout. 15 | """ 16 | if dist.get_rank() == 0: # real logger 17 | logging.basicConfig( 18 | level=logging.INFO, 19 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 20 | datefmt='%Y-%m-%d %H:%M:%S', 21 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 22 | ) 23 | logger = logging.getLogger(__name__) 24 | else: # dummy logger (does nothing) 25 | logger = logging.getLogger(__name__) 26 | logger.addHandler(logging.NullHandler()) 27 | return logger 28 | 29 | 30 | @torch.no_grad() 31 | def update_ema(ema_model, model, decay=0.9999): 32 | ''' 33 | Step the EMA model towards the current model. 34 | ''' 35 | ema_params = OrderedDict(ema_model.named_parameters()) 36 | model_params = OrderedDict(model.named_parameters()) 37 | 38 | for name, param in model_params.items(): 39 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 40 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 41 | 42 | 43 | def requires_grad(model, flag=True): 44 | ''' 45 | Set requires_grad flag for all parameters in a model. 46 | ''' 47 | for p in model.parameters(): 48 | p.requires_grad = flag 49 | 50 | 51 | def save_ckpt(args, model, ema, opt, checkpoint_path): 52 | ''' 53 | Save a checkpoint containing the online model, EMA, and optimizer states. 54 | ''' 55 | checkpoint = { 56 | 'args': args, 57 | 'model': model.module.state_dict(), 58 | 'ema': ema.state_dict(), 59 | 'opt': opt.state_dict(), 60 | } 61 | torch.save(checkpoint, checkpoint_path) 62 | 63 | 64 | def sample_image(args, model, device, image_path, set_train=False, cond=False): 65 | ''' 66 | sample a batch of images for visualization. 67 | set set_train to true if you are using the online model for sampling. 
68 | ''' 69 | model.eval() 70 | 71 | n_row = 16 72 | size = args.input_size 73 | 74 | z = torch.randn(n_row*n_row, 3, size, size).to(device) 75 | c = torch.randint(0, args.num_classes, (n_row*n_row,)).to(device) if cond else None 76 | with torch.no_grad(): 77 | x = model(z, c) 78 | 79 | x = x.view(n_row, n_row, 3, size, size) 80 | x = (x * 127.5 + 128).clip(0, 255).to(torch.uint8) 81 | images = x.permute(0, 3, 1, 4, 2).reshape(n_row*size, n_row*size, 3).cpu().numpy() 82 | 83 | Image.fromarray(images, 'RGB').save(image_path) 84 | del images, x, z, c 85 | torch.cuda.empty_cache() 86 | 87 | if set_train: 88 | model.train() 89 | 90 | 91 | def num_to_groups(num, divisor): 92 | ''' 93 | Compute number of samples in each batch to evenly divide the total eval samples. 94 | ''' 95 | groups = num // divisor 96 | remainder = num % divisor 97 | arr = [divisor] * groups 98 | if remainder > 0: 99 | arr.append(remainder) 100 | return arr 101 | 102 | 103 | def sample_fid(args, model, device, rank, set_train=False, cond=False): 104 | ''' 105 | Sample args.eval_samples images in parallel for FID and IS calculation. Default 50k images. 106 | Set set_train to True if you are using the online model for sampling. 107 | ''' 108 | # Setup batches for each node 109 | assert args.eval_samples % dist.get_world_size() == 0 110 | samples_per_node = args.eval_samples // dist.get_world_size() 111 | batches = num_to_groups(samples_per_node, args.eval_batch_size) 112 | 113 | # Dist EMA/online evaluation 114 | # No need to use the DDP wrapper here 115 | # As we do not need grad sycn (by DDP) 116 | model.eval() 117 | model = model.to(device) 118 | 119 | n_cls = args.num_classes 120 | size = args.input_size 121 | 122 | images = [] 123 | with torch.no_grad(): 124 | for n in batches: 125 | z = torch.randn(n, 3, size, size).to(device) 126 | c = torch.randint(0, n_cls, (n,)).to(device) if cond else None 127 | x = model(z, c) 128 | images.append(x) 129 | images = torch.cat(images, dim=0) 130 | 131 | torch.cuda.empty_cache() 132 | if set_train: 133 | model.train() 134 | 135 | return images 136 | 137 | 138 | def compute_fid_is(args, all_images, rank): 139 | ''' 140 | Compute FID and IS using provided images. 141 | ''' 142 | # Post-process to images. 143 | all_images = torch.cat(all_images, dim=0) 144 | all_images = (all_images * 127.5 + 128).clip(0, 255).to(torch.uint8).float().div(255).cpu() 145 | 146 | # Compute FID & IS 147 | (IS, IS_std), FID = get_inception_score_and_fid(all_images, args.stat_path) 148 | torch.cuda.empty_cache() 149 | 150 | return FID, IS 151 | 152 | --------------------------------------------------------------------------------