├── .gitignore
├── LICENSE
├── README.md
├── STAT.md
├── assets
│   └── GET.png
├── data
│   ├── cond_generate.py
│   ├── dataset.sh
│   └── generate.py
├── datasets.py
├── environment.yml
├── eval.py
├── eval.sh
├── losses.py
├── models
│   ├── __init__.py
│   ├── get.py
│   └── vit.py
├── run.sh
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # results
132 | eval-results/
133 | results/
134 |
135 | # test scripts
136 | eval_test.sh
137 | run_test.sh
138 |
139 | # TODO list
140 | TODO.md
141 |
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 CMU Locus Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generative Equilibrium Transformer
2 |
3 | This is the official repo for the paper [*One-Step Diffusion Distillation via Deep Equilibrium Models*](),
4 | by [Zhengyang Geng](https://gsunshine.github.io/)\*, [Ashwini Pokle](https://ashwinipokle.github.io/)\*, and [J. Zico Kolter](http://zicokolter.com/).
5 |
6 | ![GET](assets/GET.png)
7 |
8 | ## Environment
9 |
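The dependencies are listed in `environment.yml`. A minimal setup sketch, assuming a working `conda` installation (the environment name `GET` comes from that file):

```bash
conda env create -f environment.yml
conda activate GET
```
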
10 | ## Dataset
11 |
12 | First, download the datasets **EDM-Uncond-CIFAR** and **EDM-Cond-CIFAR** from [this link](https://drive.google.com/drive/folders/1dlFiS5ahwu7xne6fUNVNELaG12-AKn4I?usp=sharing).
13 | Set `--data_path` in `run.sh` to the directory where you store the datasets, e.g., `--data_path DATA_DIR/EDM-Uncond-CIFAR-1M`.
14 |
15 | In addition, download the precomputed dataset statistics from [this link](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC).
16 | Set `--stat_path` in `run.sh` and `eval.sh` to your download directory plus the statistics filename.
17 |
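For instance (a hypothetical excerpt of those scripts, with placeholder directories; `cifar10.test.npz` is the default statistics filename used by `eval.py`), the two flags could read:

```bash
--data_path DATA_DIR/EDM-Uncond-CIFAR-1M \
--stat_path STAT_DIR/cifar10.test.npz \
```
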
18 | ## Training
19 |
20 | To train a GET, run this command:
21 |
22 | ```bash
23 | bash run.sh N_GPU DDP_PORT --model MODEL_NAME --name EXP_NAME
24 | ```
25 |
26 | `N_GPU` is the number of GPUs used for training.
27 | `DDP_PORT` is the port number for syncing gradients during distributed training.
28 | `MODEL_NAME` is the model's name.
29 | See all available models using `python train.py -h`.
30 | The training log, checkpoints, and sampled images will be saved to `./results` using your `EXP_NAME`.
31 |
32 | For example, this command trains a GET-S/2 (patch size 2) on 4 GPUs:
33 |
34 | ```bash
35 | bash run.sh 4 12345 --model GET-S/2 --name test-GET
36 | ```
37 |
38 | To train a ViT, run this command:
39 |
40 | ```bash
41 | bash run.sh N_GPU DDP_PORT --model ViT-B/2 --name EXP_NAME
42 | ```
43 |
44 | For training **conditional** models, add the `--cond` flag.
45 |
46 | For **O(1)-memory** training, add the `--mem` flag.
47 |
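For instance, a conditional, O(1)-memory training run combining both flags with the command pattern above (the experiment name here is arbitrary):

```bash
bash run.sh 4 12345 --model GET-B/2 --cond --mem --name test-GET-B-cond
```
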
48 | ## Evaluation
49 |
50 | Download pretrained models from [this link](https://drive.google.com/drive/u/1/folders/1g998S6moSQhybD9poDJHmXP85QF3zz4g).
51 |
52 | To load a checkpoint for evaluation, run this command:
53 |
54 | ```bash
55 | bash eval.sh N_GPU DDP_PORT --model MODEL_NAME --resume CKPT_PATH --name EXP_NAME
56 | ```
57 |
58 | The evaluation log and sampled images will be saved to `./eval-results` using your `EXP_NAME`.
59 |
60 | For evaluating conditional models, add the `--cond` flag. Here is an example:
61 |
62 | ```bash
63 | bash eval.sh 4 12345 --model GET-B/2 --cond --resume CKPT_DIR/GET-B-cond-2M-data-bs256.pth
64 | ```
65 |
66 | ## Generative Performance
67 |
68 | You can see [the generative performance here](STAT.md). The discussion there might be interesting.
69 |
70 | ## Data Generation
71 |
72 | First, clone the [EDM repo](https://github.com/NVlabs/edm). Then, copy the files under `/data` to the `/edm` directory.
73 |
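A sketch of these two steps (assuming the official EDM repository at `https://github.com/NVlabs/edm`, cloned alongside this repo):

```bash
git clone https://github.com/NVlabs/edm.git
cp data/generate.py data/cond_generate.py data/dataset.sh edm/
```
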
74 | Set the data path placeholders in `dataset.sh` (the `--outdir=YOUR_UNCOND_DATA_PATH` and `--outdir=YOUR_COND_DATA_PATH` arguments) to where you want to store the synthetic datasets.
75 | Run the following command to generate both conditional and unconditional training sets.
76 |
77 | ```bash
78 | bash dataset.sh
79 | ```
80 |
81 | If you want to generate more data pairs, adjust the range of `--seeds=0-MAX_SAMPLES`.
82 |
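For example, a hypothetical edit of the conditional command in `dataset.sh` that generates roughly 2M pairs instead of 1M:

```bash
torchrun --standalone --nproc_per_node=4 cond_generate.py \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl \
    --outdir=YOUR_COND_DATA_PATH \
    --seeds=0-1999999 \
    --batch 250
```
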
83 | ## Bibtex
84 |
85 | If you find our work helpful to your research, please consider citing this paper. :)
86 |
87 | ```bib
88 | @inproceedings{
89 | geng2023onestep,
90 | title={One-Step Diffusion Distillation via Deep Equilibrium Models},
91 | author={Zhengyang Geng and Ashwini Pokle and J Zico Kolter},
92 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
93 | year={2023}
94 | }
95 | ```
96 |
97 | ## Contact
98 |
99 | Feel free to contact us if you have additional questions!
100 | Please drop an email to zhengyanggeng@gmail.com (or [Twitter](https://twitter.com/ZhengyangGeng))
101 | or apokle@andrew.cmu.edu.
102 |
103 | ## Acknowledgment
104 |
105 | This project is built upon [TorchDEQ](https://github.com/locuslab/torchdeq),
106 | [DiT](https://arxiv.org/abs/2212.09748),
107 | and [timm](https://github.com/huggingface/pytorch-image-models).
108 | Thanks for the awesome projects!
109 |
--------------------------------------------------------------------------------
/STAT.md:
--------------------------------------------------------------------------------
1 |
2 | # Model Performance
3 |
4 | Here is the generative performance under different training settings and model configurations.
5 |
6 | Generative performance of unconditional models.
7 |
8 | | Model Name | Type | Params | FID | IS | Training Data | BS | Iters |
9 | | :--------- | :-- | :----: | :-: | :-: | :----------: | :-: | :---: |
10 | | [GET-T](https://drive.google.com/file/d/1rDw5A34ZnTajQZLSb_7fGkUfQU6viwq8/view?usp=sharing) | Uncond | 8.6M | 15.23 | 8.40 | 1M | 128 | 800k |
11 | | [GET-M](https://drive.google.com/file/d/1bAcRl0dWDxzIkm3sZBzBMABw6y78mVcQ/view?usp=sharing) | Uncond | 19.2M | 10.81 | 8.77 | 1M | 128 | 800k |
12 | | [GET-S](https://drive.google.com/file/d/1rN2rD7WUDaJaL3uRU5eX14wKQAg7Zy8z/view?usp=sharing) | Uncond | 37.2M | 7.99 | 9.05 | 1M | 128 | 800k |
13 | | [GET-B](https://drive.google.com/file/d/1k7qMLfqxctFNldsUSuapLP96oIllwZ1H/view?usp=sharing) | Uncond | 62.2M | 7.39 | 9.17 | 1M | 128 | 800k |
14 | | [GET-B+](https://drive.google.com/file/d/1jUE1lqs0qsbqbLyl9nmROcxrETDUXx25/view?usp=sharing) | Uncond | 83.5M | 7.21 | 9.07 | 1M | 128 | 800k |
15 |
16 | Generative performance of class-conditional models.
17 |
18 | | Model Name | Type | Params | FID | IS | Training Data | BS | Iters |
19 | | :--------- | :-- | :----: | :-: | :-: | :----------: | :-: | :---: |
20 | | [GET-B](https://drive.google.com/file/d/1BPPtWpoXVexgozaKAiZRx0N5egweHrFH/view?usp=sharing) | Cond | 62.2M | 6.23 | 9.42 | 1M | 256 | 800k |
21 | | [GET-B](https://drive.google.com/file/d/1DH8cN70OucFRoWsXJvK4vgyIrctAcoFN/view?usp=sharing) | Cond | 62.2M | 5.66 | 9.63 | 2M | 256 | 1.2M |
22 |
23 | There is a clear scaling trend for Generative Equilibrium Transformers. Roughly, FID has a **log-linear** relation w.r.t. the invested "money" (training FLOPs/time/data/params), i.e., roughly FID ≈ a - b * log(budget) for some constants a, b, when the other dimensions are fixed. Scaling up the training data, providing better supervision (from the perceptual loss or the teacher model), and investing more training FLOPs (a larger batch size or a longer training schedule) all lead to better results, as demonstrated above.
24 |
25 | Ideally, when scaling up the training data, the model size and training FLOPs need to be adjusted accordingly to achieve the best training efficiency. Nonetheless, restricted by our computing resources, we cannot derive the *exact* scaling law for compute-optimal models as done in [Chinchilla](https://arxiv.org/abs/2203.15556). Despite the compute restriction, our observations still show that GET's scaling law suggests **much more compact compute-optimal models** than ViTs, which is ideal for memory-bound deployment, as with today's LLMs.
26 |
27 | In particular, this work shows that **memorizing sufficiently many "regular" data pairs can lead to a good generative model**, whether for GET or ViT. The differences lie in data efficiency and training efficiency (though we assume there are much better training strategies).
28 |
29 | Here, the term "regular" means the latent-image pairs are easy to learn, e.g., sampled from a pretrained model. Please note that randomly paired latents and images, obtained for example by shuffling the latent-image pairs in the training set, cannot be memorized by a model at a constant learning rate (we use a fixed learning rate to train models over massive numbers of pairs in the above experiments), as the loss curve is non-decreasing. This implies that **the learnability of the pairing can be a strong measure of pairing quality**.
30 |
31 |
--------------------------------------------------------------------------------
/assets/GET.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/get/53377d3047b8ab677ef4e84252b15b5919fe70dd/assets/GET.png
--------------------------------------------------------------------------------
/data/cond_generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Generate random images using the techniques described in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import click
14 | import tqdm
15 | import pickle
16 | import numpy as np
17 | import torch
18 | import PIL.Image
19 | import dnnlib
20 | from torch_utils import distributed as dist
21 |
22 | #----------------------------------------------------------------------------
23 | # Proposed EDM sampler (Algorithm 2).
24 |
25 | def edm_sampler(
26 | net, latents, class_labels=None, randn_like=torch.randn_like,
27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
29 | ):
30 | # Adjust noise levels based on what's supported by the network.
31 | sigma_min = max(sigma_min, net.sigma_min)
32 | sigma_max = min(sigma_max, net.sigma_max)
33 |
34 | # Time step discretization.
35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
38 |
39 | # Main sampling loop.
40 | x_next = latents.to(torch.float64) * t_steps[0]
41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
42 | x_cur = x_next
43 |
44 | # Increase noise temporarily.
45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
46 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
48 |
49 | # Euler step.
50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
51 | d_cur = (x_hat - denoised) / t_hat
52 | x_next = x_hat + (t_next - t_hat) * d_cur
53 |
54 | # Apply 2nd order correction.
55 | if i < num_steps - 1:
56 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
57 | d_prime = (x_next - denoised) / t_next
58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59 |
60 | return x_next
61 |
62 | #----------------------------------------------------------------------------
63 | # Generalized ablation sampler, representing the superset of all sampling
64 | # methods discussed in the paper.
65 |
66 | def ablation_sampler(
67 | net, latents, class_labels=None, randn_like=torch.randn_like,
68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7,
69 | solver='heun', discretization='edm', schedule='linear', scaling='none',
70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
72 | ):
73 | assert solver in ['euler', 'heun']
74 | assert discretization in ['vp', 've', 'iddpm', 'edm']
75 | assert schedule in ['vp', 've', 'linear']
76 | assert scaling in ['vp', 'none']
77 |
78 | # Helper functions for VP & VE noise level schedules.
79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
82 | ve_sigma = lambda t: t.sqrt()
83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
84 | ve_sigma_inv = lambda sigma: sigma ** 2
85 |
86 | # Select default noise level range based on the specified time step discretization.
87 | if sigma_min is None:
88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
90 | if sigma_max is None:
91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
93 |
94 | # Adjust noise levels based on what's supported by the network.
95 | sigma_min = max(sigma_min, net.sigma_min)
96 | sigma_max = min(sigma_max, net.sigma_max)
97 |
98 | # Compute corresponding betas for VP.
99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
101 |
102 | # Define time steps in terms of noise level.
103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
104 | if discretization == 'vp':
105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
107 | elif discretization == 've':
108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
109 | sigma_steps = ve_sigma(orig_t_steps)
110 | elif discretization == 'iddpm':
111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
117 | else:
118 | assert discretization == 'edm'
119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
120 |
121 | # Define noise level schedule.
122 | if schedule == 'vp':
123 | sigma = vp_sigma(vp_beta_d, vp_beta_min)
124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
126 | elif schedule == 've':
127 | sigma = ve_sigma
128 | sigma_deriv = ve_sigma_deriv
129 | sigma_inv = ve_sigma_inv
130 | else:
131 | assert schedule == 'linear'
132 | sigma = lambda t: t
133 | sigma_deriv = lambda t: 1
134 | sigma_inv = lambda sigma: sigma
135 |
136 | # Define scaling schedule.
137 | if scaling == 'vp':
138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
140 | else:
141 | assert scaling == 'none'
142 | s = lambda t: 1
143 | s_deriv = lambda t: 0
144 |
145 | # Compute final time steps based on the corresponding noise levels.
146 | t_steps = sigma_inv(net.round_sigma(sigma_steps))
147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
148 |
149 | # Main sampling loop.
150 | t_next = t_steps[0]
151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
153 | x_cur = x_next
154 |
155 | # Increase noise temporarily.
156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)
159 |
160 | # Euler step.
161 | h = t_next - t_hat
162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
164 | x_prime = x_hat + alpha * h * d_cur
165 | t_prime = t_hat + alpha * h
166 |
167 | # Apply 2nd order correction.
168 | if solver == 'euler' or i == num_steps - 1:
169 | x_next = x_hat + h * d_cur
170 | else:
171 | assert solver == 'heun'
172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
175 |
176 | return x_next
177 |
178 | #----------------------------------------------------------------------------
179 | # Wrapper for torch.Generator that allows specifying a different random seed
180 | # for each sample in a minibatch.
181 |
182 | class StackedRandomGenerator:
183 | def __init__(self, device, seeds):
184 | super().__init__()
185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
186 |
187 | def randn(self, size, **kwargs):
188 | assert size[0] == len(self.generators)
189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
190 |
191 | def randn_like(self, input):
192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
193 |
194 | def randint(self, *args, size, **kwargs):
195 | assert size[0] == len(self.generators)
196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
197 |
198 | #----------------------------------------------------------------------------
199 | # Parse a comma separated list of numbers or ranges and return a list of ints.
200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
201 |
202 | def parse_int_list(s):
203 | if isinstance(s, list): return s
204 | ranges = []
205 | range_re = re.compile(r'^(\d+)-(\d+)$')
206 | for p in s.split(','):
207 | m = range_re.match(p)
208 | if m:
209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
210 | else:
211 | ranges.append(int(p))
212 | return ranges
213 |
214 | #----------------------------------------------------------------------------
215 |
216 | @click.command()
217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
219 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
223 |
224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
232 |
233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
237 |
238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
239 | """Generate random images using the techniques described in the paper
240 | "Elucidating the Design Space of Diffusion-Based Generative Models".
241 |
242 | Examples:
243 |
244 | \b
245 | # Generate 64 images and save them as out/*.png
246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\
247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
248 |
249 | \b
250 | # Generate 1024 images using 2 GPUs
251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
253 | """
254 | dist.init()
255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
258 |
259 | # Rank 0 goes first.
260 | if dist.get_rank() != 0:
261 | torch.distributed.barrier()
262 |
263 | # Load network.
264 | dist.print0(f'Loading network from "{network_pkl}"...')
265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
266 | net = pickle.load(f)['ema'].to(device)
267 |
268 | # Other ranks follow.
269 | if dist.get_rank() == 0:
270 | torch.distributed.barrier()
271 |
272 | latent_list = []
273 | label_list = []
274 | image_list = []
275 |
276 | # Loop over batches.
277 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
278 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
279 | torch.distributed.barrier()
280 | batch_size = len(batch_seeds)
281 | if batch_size == 0:
282 | continue
283 |
284 | # Pick latents and labels.
285 | rnd = StackedRandomGenerator(device, batch_seeds)
286 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
287 | class_labels = None
288 | if net.label_dim:
289 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
290 | if class_idx is not None:
291 | class_labels[:, :] = 0
292 | class_labels[:, class_idx] = 1
293 |
294 | # Generate images.
295 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
296 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
297 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
298 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
299 |
300 | # Save images.
301 | latent_list.append(latents.float().cpu())
302 | label_list.append(class_labels.float().cpu())
303 | image_list.append(images.float().cpu())
304 |
305 | if len(image_list) % 50 == 0:
306 | save_dir = outdir
307 | os.makedirs(save_dir, exist_ok=True)
308 |
309 | latent_dir = os.path.join(save_dir, 'latent')
310 | os.makedirs(latent_dir, exist_ok=True)
311 |
312 | image_dir = os.path.join(save_dir, 'image')
313 | os.makedirs(image_dir, exist_ok=True)
314 |
315 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl')
316 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl')
317 |
318 | torch.save([torch.cat(latent_list, dim=0), torch.cat(label_list, dim=0)], latent_path)
319 | torch.save(torch.cat(image_list, dim=0), image_path)
320 |
321 | latent_list = []
322 | label_list = []
323 | image_list = []
324 |
325 | if len(image_list) > 0:
326 | save_dir = outdir
327 | os.makedirs(save_dir, exist_ok=True)
328 |
329 | latent_dir = os.path.join(save_dir, 'latent')
330 | os.makedirs(latent_dir, exist_ok=True)
331 |
332 | image_dir = os.path.join(save_dir, 'image')
333 | os.makedirs(image_dir, exist_ok=True)
334 |
335 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl')
336 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl')
337 |
338 | torch.save([torch.cat(latent_list, dim=0), torch.cat(label_list, dim=0)], latent_path)
339 | torch.save(torch.cat(image_list, dim=0), image_path)
340 |
341 | # Done.
342 | torch.distributed.barrier()
343 | dist.print0('Done.')
344 |
345 | #----------------------------------------------------------------------------
346 |
347 | if __name__ == "__main__":
348 | main()
349 |
350 | #----------------------------------------------------------------------------
351 |
--------------------------------------------------------------------------------
/data/dataset.sh:
--------------------------------------------------------------------------------
1 | torchrun --standalone --nproc_per_node=4 generate.py \
2 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl \
3 | --outdir=YOUR_UNCOND_DATA_PATH \
4 | --seeds=0-999999 \
5 | --batch 250
6 |
7 | torchrun --standalone --nproc_per_node=4 cond_generate.py \
8 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl \
9 | --outdir=YOUR_COND_DATA_PATH \
10 | --seeds=0-999999 \
11 | --batch 250
12 |
13 |
--------------------------------------------------------------------------------
/data/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Generate random images using the techniques described in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import click
14 | import tqdm
15 | import pickle
16 | import numpy as np
17 | import torch
18 | import PIL.Image
19 | import dnnlib
20 | from torch_utils import distributed as dist
21 |
22 | #----------------------------------------------------------------------------
23 | # Proposed EDM sampler (Algorithm 2).
24 |
25 | def edm_sampler(
26 | net, latents, class_labels=None, randn_like=torch.randn_like,
27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
29 | ):
30 | # Adjust noise levels based on what's supported by the network.
31 | sigma_min = max(sigma_min, net.sigma_min)
32 | sigma_max = min(sigma_max, net.sigma_max)
33 |
34 | # Time step discretization.
35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
38 |
39 | # Main sampling loop.
40 | x_next = latents.to(torch.float64) * t_steps[0]
41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
42 | x_cur = x_next
43 |
44 | # Increase noise temporarily.
45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
46 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
48 |
49 | # Euler step.
50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
51 | d_cur = (x_hat - denoised) / t_hat
52 | x_next = x_hat + (t_next - t_hat) * d_cur
53 |
54 | # Apply 2nd order correction.
55 | if i < num_steps - 1:
56 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
57 | d_prime = (x_next - denoised) / t_next
58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59 |
60 | return x_next
61 |
62 | #----------------------------------------------------------------------------
63 | # Generalized ablation sampler, representing the superset of all sampling
64 | # methods discussed in the paper.
65 |
66 | def ablation_sampler(
67 | net, latents, class_labels=None, randn_like=torch.randn_like,
68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7,
69 | solver='heun', discretization='edm', schedule='linear', scaling='none',
70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
72 | ):
73 | assert solver in ['euler', 'heun']
74 | assert discretization in ['vp', 've', 'iddpm', 'edm']
75 | assert schedule in ['vp', 've', 'linear']
76 | assert scaling in ['vp', 'none']
77 |
78 | # Helper functions for VP & VE noise level schedules.
79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
82 | ve_sigma = lambda t: t.sqrt()
83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
84 | ve_sigma_inv = lambda sigma: sigma ** 2
85 |
86 | # Select default noise level range based on the specified time step discretization.
87 | if sigma_min is None:
88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
90 | if sigma_max is None:
91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
93 |
94 | # Adjust noise levels based on what's supported by the network.
95 | sigma_min = max(sigma_min, net.sigma_min)
96 | sigma_max = min(sigma_max, net.sigma_max)
97 |
98 | # Compute corresponding betas for VP.
99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
101 |
102 | # Define time steps in terms of noise level.
103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
104 | if discretization == 'vp':
105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
107 | elif discretization == 've':
108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
109 | sigma_steps = ve_sigma(orig_t_steps)
110 | elif discretization == 'iddpm':
111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
117 | else:
118 | assert discretization == 'edm'
119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
120 |
121 | # Define noise level schedule.
122 | if schedule == 'vp':
123 | sigma = vp_sigma(vp_beta_d, vp_beta_min)
124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
126 | elif schedule == 've':
127 | sigma = ve_sigma
128 | sigma_deriv = ve_sigma_deriv
129 | sigma_inv = ve_sigma_inv
130 | else:
131 | assert schedule == 'linear'
132 | sigma = lambda t: t
133 | sigma_deriv = lambda t: 1
134 | sigma_inv = lambda sigma: sigma
135 |
136 | # Define scaling schedule.
137 | if scaling == 'vp':
138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
140 | else:
141 | assert scaling == 'none'
142 | s = lambda t: 1
143 | s_deriv = lambda t: 0
144 |
145 | # Compute final time steps based on the corresponding noise levels.
146 | t_steps = sigma_inv(net.round_sigma(sigma_steps))
147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
148 |
149 | # Main sampling loop.
150 | t_next = t_steps[0]
151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
153 | x_cur = x_next
154 |
155 | # Increase noise temporarily.
156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)
159 |
160 | # Euler step.
161 | h = t_next - t_hat
162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
164 | x_prime = x_hat + alpha * h * d_cur
165 | t_prime = t_hat + alpha * h
166 |
167 | # Apply 2nd order correction.
168 | if solver == 'euler' or i == num_steps - 1:
169 | x_next = x_hat + h * d_cur
170 | else:
171 | assert solver == 'heun'
172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
175 |
176 | return x_next
177 |
178 | #----------------------------------------------------------------------------
179 | # Wrapper for torch.Generator that allows specifying a different random seed
180 | # for each sample in a minibatch.
181 |
182 | class StackedRandomGenerator:
183 | def __init__(self, device, seeds):
184 | super().__init__()
185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
186 |
187 | def randn(self, size, **kwargs):
188 | assert size[0] == len(self.generators)
189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
190 |
191 | def randn_like(self, input):
192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
193 |
194 | def randint(self, *args, size, **kwargs):
195 | assert size[0] == len(self.generators)
196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
197 |
198 | #----------------------------------------------------------------------------
199 | # Parse a comma separated list of numbers or ranges and return a list of ints.
200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
201 |
202 | def parse_int_list(s):
203 | if isinstance(s, list): return s
204 | ranges = []
205 | range_re = re.compile(r'^(\d+)-(\d+)$')
206 | for p in s.split(','):
207 | m = range_re.match(p)
208 | if m:
209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
210 | else:
211 | ranges.append(int(p))
212 | return ranges
213 |
214 | #----------------------------------------------------------------------------
215 |
216 | @click.command()
217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
219 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
223 |
224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
232 |
233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
237 |
238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
239 | """Generate random images using the techniques described in the paper
240 | "Elucidating the Design Space of Diffusion-Based Generative Models".
241 |
242 | Examples:
243 |
244 | \b
245 | # Generate 64 images and save them as out/*.png
246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\
247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
248 |
249 | \b
250 | # Generate 1024 images using 2 GPUs
251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
253 | """
254 | dist.init()
255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
258 |
259 | # Rank 0 goes first.
260 | if dist.get_rank() != 0:
261 | torch.distributed.barrier()
262 |
263 | # Load network.
264 | dist.print0(f'Loading network from "{network_pkl}"...')
265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
266 | net = pickle.load(f)['ema'].to(device)
267 |
268 | # Other ranks follow.
269 | if dist.get_rank() == 0:
270 | torch.distributed.barrier()
271 |
272 | latent_list = []
273 | image_list = []
274 |
275 | # Loop over batches.
276 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
277 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
278 | torch.distributed.barrier()
279 | batch_size = len(batch_seeds)
280 | if batch_size == 0:
281 | continue
282 |
283 | # Pick latents and labels.
284 | rnd = StackedRandomGenerator(device, batch_seeds)
285 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
286 | class_labels = None
287 | if net.label_dim:
288 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
289 | if class_idx is not None:
290 | class_labels[:, :] = 0
291 | class_labels[:, class_idx] = 1
292 |
293 | # Generate images.
294 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
295 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
296 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
297 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
298 |
299 | # Save images.
300 | latent_list.append(latents.cpu())
301 | image_list.append(images.cpu())
302 |
303 | if len(image_list) % 50 == 0:
304 | save_dir = outdir
305 | os.makedirs(save_dir, exist_ok=True)
306 |
307 | image_dir = os.path.join(save_dir, 'image')
308 | os.makedirs(image_dir, exist_ok=True)
309 |
310 | latent_dir = os.path.join(save_dir, 'latent')
311 | os.makedirs(latent_dir, exist_ok=True)
312 |
313 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl')
314 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl')
315 |
316 | torch.save(torch.cat(image_list, dim=0), image_path)
317 | torch.save(torch.cat(latent_list, dim=0), latent_path)
318 |
319 | image_list = []
320 | latent_list = []
321 |
322 | if len(image_list) > 0:
323 | save_dir = outdir
324 | os.makedirs(save_dir, exist_ok=True)
325 |
326 | image_dir = os.path.join(save_dir, 'image')
327 | os.makedirs(image_dir, exist_ok=True)
328 |
329 | latent_dir = os.path.join(save_dir, 'latent')
330 | os.makedirs(latent_dir, exist_ok=True)
331 |
332 | image_path = os.path.join(image_dir, f'{batch_seeds[-1]:09d}.pkl')
333 | latent_path = os.path.join(latent_dir, f'{batch_seeds[-1]:09d}.pkl')
334 |
335 | torch.save(torch.cat(image_list, dim=0), image_path)
336 | torch.save(torch.cat(latent_list, dim=0), latent_path)
337 |
338 | # Done.
339 | torch.distributed.barrier()
340 | dist.print0('Done.')
341 |
342 | #----------------------------------------------------------------------------
343 |
344 | if __name__ == "__main__":
345 | main()
346 |
347 | #----------------------------------------------------------------------------
348 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | def find_files(base_dir, world_size=1, rank=0):
6 | path_list = os.listdir(base_dir)
7 |
8 | sort_key = lambda f_name: int(f_name.split('.')[0])
9 | path_list.sort(key=sort_key)
10 |
11 | assert len(path_list) % world_size == 0
12 |
13 | for i, f in enumerate(path_list):
14 | if i % world_size == rank:
15 | f_path = os.path.join(base_dir, f)
16 | yield f_path
17 |
18 |
19 | class PairedDataset(torch.utils.data.Dataset):
20 | def __init__(self,
21 | data_dir,
22 | world_size=1,
23 | rank=0,
24 | ):
25 | super().__init__()
26 |
27 | self.world_size = world_size
28 | self.data_dir = data_dir
29 |
30 | latent_dir = os.path.join(data_dir, 'latent')
31 | image_dir = os.path.join(data_dir, 'image')
32 |
33 | Z = list()
34 | for f in find_files(latent_dir, world_size, rank):
35 | z = torch.load(f)
36 | Z.append(z)
37 | Z = torch.cat(Z, dim=0)
38 |
39 | X = list()
40 | for f in find_files(image_dir, world_size, rank):
41 | x = torch.load(f)
42 | X.append(x)
43 | X = torch.cat(X, dim=0)
44 |
45 | assert len(Z) == len(X)
46 |
47 | self.Z = Z
48 | self.X = X
49 |
50 | def __len__(self):
51 | return len(self.X) * self.world_size
52 |
53 | def __getitem__(self, idx):
54 | idx = idx // self.world_size
55 | z, x = self.Z[idx], self.X[idx]
56 |
57 | return z, x
58 |
59 |
60 | class PairedCondDataset(torch.utils.data.Dataset):
61 | def __init__(self,
62 | data_dir,
63 | world_size=1,
64 | rank=0,
65 | ):
66 | super().__init__()
67 |
68 | self.world_size = world_size
69 | self.data_dir = data_dir
70 |
71 | latent_dir = os.path.join(data_dir, 'latent')
72 | image_dir = os.path.join(data_dir, 'image')
73 |
74 | Z = list()
75 | C = list()
76 | for f in find_files(latent_dir, world_size, rank):
77 | z, c = torch.load(f)
78 | Z.append(z)
79 | C.append(c)
80 | Z = torch.cat(Z, dim=0)
81 | C = torch.cat(C, dim=0)
82 |
83 | X = list()
84 | for f in find_files(image_dir, world_size, rank):
85 | x = torch.load(f)
86 | X.append(x)
87 | X = torch.cat(X, dim=0)
88 |
89 | assert len(Z) == len(X)
90 |
91 | self.Z = Z
92 | self.C = C
93 | self.X = X
94 |
95 | def __len__(self):
96 | return len(self.X) * self.world_size
97 |
98 | def __getitem__(self, idx):
99 | idx = idx // self.world_size
100 | z, c, x = self.Z[idx], self.C[idx], self.X[idx]
101 |
102 | return z, x, c
103 |
104 |
105 | class EpochPairedDataset(torch.utils.data.Dataset):
106 | def __init__(self,
107 | data_dir,
108 | total=10,
109 | epoch=0,
110 | world_size=1,
111 | rank=0,
112 | ):
113 | super().__init__()
114 |
115 | self.world_size = world_size
116 |
117 | data_dir = data_dir + f'-{epoch%total}'
118 | self.data_dir = data_dir
119 |
120 | latent_dir = os.path.join(data_dir, 'latent')
121 | image_dir = os.path.join(data_dir, 'image')
122 |
123 | Z = list()
124 | for f in find_files(latent_dir, world_size, rank):
125 | z = torch.load(f)
126 | Z.append(z)
127 | Z = torch.cat(Z, dim=0)
128 |
129 | X = list()
130 | for f in find_files(image_dir, world_size, rank):
131 | x = torch.load(f)
132 | X.append(x)
133 | X = torch.cat(X, dim=0)
134 |
135 | assert len(Z) == len(X)
136 |
137 | self.Z = Z
138 | self.X = X
139 |
140 | def __len__(self):
141 | return len(self.X) * self.world_size
142 |
143 | def __getitem__(self, idx):
144 | idx = idx // self.world_size
145 | z, x = self.Z[idx], self.X[idx]
146 |
147 | return z, x
148 |
149 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: GET
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python >= 3.8
7 | - pytorch >= 1.11
8 | - torchvision
9 | - cudatoolkit >= 11.3.1
10 | - pip:
11 | - torchdeq
12 | - timm
13 | - diffusers
14 | - pytorch-gan-metrics
15 | - piq
16 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import time
5 |
6 | import torch
7 | # the first flag below was False when we tested this script but True makes A100 training a lot faster:
8 | torch.backends.cuda.matmul.allow_tf32 = True
9 | torch.backends.cudnn.allow_tf32 = True
10 |
11 | import torch.distributed as dist
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.utils.data import DataLoader
14 |
15 | from utils import (
16 | create_logger, requires_grad,
17 | sample_image, sample_fid, compute_fid_is
18 | )
19 | from models import model_dict
20 |
21 | from torchdeq import add_deq_args
22 |
23 |
24 | def main(args):
25 | '''
26 | Model evaluation.
27 | '''
28 | # Setup DDP
29 | dist.init_process_group('nccl')
30 | rank = dist.get_rank()
31 | device = rank % torch.cuda.device_count()
32 | seed = args.global_seed * dist.get_world_size() + rank
33 | torch.manual_seed(seed)
34 | torch.cuda.set_device(device)
35 | print(f'Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.')
36 |
37 | # Setup an experiment folder
38 | if rank == 0:
39 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
40 | resume_dir = re.split(r'/|\.', args.resume)
41 | folder_name = f'eval-{resume_dir[-4]}-{resume_dir[-2]}-{args.name}'
42 | experiment_dir = f'{args.results_dir}/{folder_name}' # Create an experiment folder
43 | os.makedirs(experiment_dir, exist_ok=True)
44 |
45 | logger = create_logger(experiment_dir)
46 | logger.info(f'Experiment directory created at {experiment_dir}')
47 | else:
48 | logger = create_logger()
49 |
50 | # Create model
51 | model = model_dict[args.model](
52 | args=args,
53 | input_size=args.input_size,
54 | num_classes=args.num_classes,
55 | cond=args.cond
56 | )
57 | ema = model_dict[args.model](
58 | args=args,
59 | input_size=args.input_size,
60 | num_classes=args.num_classes,
61 | cond=args.cond
62 | ).to(device)
63 | requires_grad(ema, False)
64 |
65 | # Setup DDP
66 | model = DDP(model.to(device), device_ids=[rank])
67 | logger.info(f'Model Parameters: {sum(p.numel() for p in model.parameters()):,}')
68 |
69 | model.eval()
70 | ema.eval()
71 |
72 | # Resume from the given checkpoint
73 | if args.resume:
74 | ckpt = torch.load(args.resume, map_location=torch.device('cpu'))
75 | model.module.load_state_dict(ckpt['model'])
76 | ema.load_state_dict(ckpt['ema'])
77 | logger.info(f'Resume from {args.resume}..')
78 |
79 | # Sample images
80 | if rank == 0:
81 | image_path = f'{experiment_dir}/samples.png'
82 | sample_image(args, ema, device, image_path, cond=args.cond)
83 | logger.info(f'Saved samples to {image_path}')
84 | dist.barrier()
85 |
86 | # Compute FID and IS
87 | start_time = time.time()
88 | images = sample_fid(args, ema, device, rank, cond=args.cond)
89 | end_time = time.time()
90 | logger.info(f'Time for sampling 50k images {end_time-start_time:.2f}s.')
91 |
92 | # DDP sync for FID evaluation
93 | all_images = [torch.zeros_like(images) for _ in range(dist.get_world_size())]
94 | dist.gather(images, all_images if rank == 0 else None, dst=0)
95 | if rank == 0:
96 | FID, IS = compute_fid_is(args, all_images, rank)
97 | logger.info(f'FID {FID:0.2f}, IS {IS:0.2f}.')
98 |
99 | dist.barrier()
100 | dist.destroy_process_group()
101 |
102 |
103 | if __name__ == '__main__':
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument('--results_dir', type=str, default='eval-results')
106 | parser.add_argument('--name', type=str, default='debug')
107 |
108 | parser.add_argument('--model', type=str, choices=list(model_dict.keys()), default='GET-S/2')
109 | parser.add_argument('--input_size', type=int, default=32)
110 |
111 | parser.add_argument('--cond', action='store_true', help='Run conditional model.')
112 | parser.add_argument('--num_classes', type=int, default=10)
113 |
114 | parser.add_argument('--global_seed', type=int, default=42)
115 | parser.add_argument('--num_workers', type=int, default=4)
116 |
117 | parser.add_argument('--mem', action='store_true', help='Enable O1 memory.')
118 |
119 | parser.add_argument('--eval_batch_size', type=int, default=128)
120 | parser.add_argument('--eval_samples', type=int, default=50000)
121 | parser.add_argument('--stat_path', type=str, default='YOUR_STAT_PATH/cifar10.test.npz')
122 |
123 | parser.add_argument('--resume', help="restore checkpoint for evaluation")
124 |
125 | # Add for DEQs
126 | add_deq_args(parser)
127 |
128 | args = parser.parse_args()
129 | main(args)
130 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=1 --nproc_per_node=$1 --rdzv_endpoint=localhost:$2 \
2 | eval.py \
3 | --eval_f_max_iter 6 \
4 | --norm_type none \
5 | --stat_path YOUR_STAT_PATH \
6 | ${@:3}
7 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from piq import LPIPS, DISTS
6 |
7 |
8 | class L1Loss(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x_pred, x):
13 | return (x_pred - x).abs().mean()
14 |
15 |
16 | class L2Loss(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 |
20 | def forward(self, x_pred, x):
21 | return ((x_pred - x) ** 2).mean()
22 |
23 |
24 | class LPIPSLoss(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 |
28 | self.loss = LPIPS()
29 |
30 | def forward(self, x_pred, x):
31 | x_pred = F.interpolate(x_pred, size=224, mode="bilinear")
32 | x = F.interpolate(x.float(), size=224, mode="bilinear")
33 |
34 | x_pred = (x_pred + 1) / 2
35 | x = (x + 1) / 2
36 |
37 | return self.loss(x_pred, x)
38 |
39 |
40 | class DISTSLoss(nn.Module):
41 | def __init__(self):
42 | super().__init__()
43 |
44 | self.loss = DISTS()
45 |
46 | def forward(self, x_pred, x):
47 | x_pred = F.interpolate(x_pred, size=224, mode="bilinear")
48 | x = F.interpolate(x.float(), size=224, mode="bilinear")
49 |
50 | x_pred = (x_pred + 1) / 2
51 | x = (x + 1) / 2
52 |
53 | return self.loss(x_pred, x)
54 |
55 |
56 | loss_dict = {
57 | 'l1': L1Loss,
58 | 'l2': L2Loss,
59 | 'lpips': LPIPSLoss,
60 | 'dists': DISTSLoss
61 | }
62 |
--------------------------------------------------------------------------------
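All four losses above consume model outputs and targets in [-1, 1]; the perceptual LPIPS/DISTS variants additionally resize both tensors to 224x224 and rescale them to [0, 1] before calling into piq. A minimal sketch of picking a loss from loss_dict (assuming piq is installed; the random tensors only stand in for a generated batch and its target):

    import torch
    from losses import loss_dict

    x_pred = torch.rand(4, 3, 32, 32) * 2 - 1   # stand-in for a generated batch in [-1, 1]
    x_gt = torch.rand(4, 3, 32, 32) * 2 - 1     # stand-in for the target batch in [-1, 1]

    loss_fn = loss_dict['lpips']()               # or 'l1' / 'l2' / 'dists'
    loss = loss_fn(x_pred, x_gt)                 # scalar tensor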
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .get import GET_models
3 | from .vit import ViT_models
4 |
5 |
6 | model_dict = {}
7 | model_dict.update(GET_models)
8 | model_dict.update(ViT_models)
9 |
10 |
11 |
--------------------------------------------------------------------------------
/models/get.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # References:
3 | # DiT: https://github.com/facebookresearch/DiT
4 | # MAE: https://github.com/facebookresearch/mae
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | import numpy as np
12 | import math
13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14 |
15 | from torchdeq import get_deq
16 | from torchdeq.norm import apply_norm, reset_norm
17 | from torchdeq.utils import mem_gc
18 |
19 |
20 | # Positional Embedding
21 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
22 |
23 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
24 | """
25 | grid_size: int of the grid height and width
26 | return:
27 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
28 | """
29 | grid_h = np.arange(grid_size, dtype=np.float32)
30 | grid_w = np.arange(grid_size, dtype=np.float32)
31 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
32 | grid = np.stack(grid, axis=0)
33 |
34 | grid = grid.reshape([2, 1, grid_size, grid_size])
35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
36 | if cls_token and extra_tokens > 0:
37 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
38 | return pos_embed
39 |
40 |
41 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
42 | assert embed_dim % 2 == 0
43 |
44 | # use half of dimensions to encode grid_h
45 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
46 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
47 |
48 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
49 | return emb
50 |
51 |
52 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
53 | """
54 | embed_dim: output dimension for each position
55 | pos: a list of positions to be encoded: size (M,)
56 | out: (M, D)
57 | """
58 | assert embed_dim % 2 == 0
59 | omega = np.arange(embed_dim // 2, dtype=np.float64)
60 | omega /= embed_dim / 2.
61 | omega = 1. / 10000**omega # (D/2,)
62 |
63 | pos = pos.reshape(-1) # (M,)
64 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
65 |
66 | emb_sin = np.sin(out) # (M, D/2)
67 | emb_cos = np.cos(out) # (M, D/2)
68 |
69 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
70 | return emb
71 |
72 |
73 | class ClassEmbedding(nn.Module):
74 | def __init__(self, num_classes, hidden_size):
75 | super().__init__()
76 | self.embedding_table = nn.Embedding(num_classes, hidden_size)
77 |
78 | def forward(self, labels):
79 | return self.embedding_table(labels)
80 |
81 |
82 | class AttnInterface(nn.Module):
83 | def __init__(
84 | self,
85 | dim,
86 | num_heads=8,
87 | qkv_bias=False,
88 | qk_norm=False,
89 | attn_drop=0.,
90 | proj_drop=0.,
91 | norm_layer=nn.LayerNorm,
92 | cond=False
93 | ):
94 | super().__init__()
95 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
96 | self.num_heads = num_heads
97 | self.head_dim = dim // num_heads
98 | self.scale = self.head_dim ** -0.5
99 | self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
100 |
101 | self.cond = cond
102 |
103 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
104 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
105 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
106 | self.attn_drop = nn.Dropout(attn_drop)
107 | self.proj = nn.Linear(dim, dim)
108 | self.proj_drop = nn.Dropout(proj_drop)
109 |
110 | def forward(self, x, c, u=None):
111 | B, N, C = x.shape
112 | qkv = self.qkv(x)
113 |
114 | # Injection
115 | if self.cond:
116 | qkv = qkv + c
117 | if u is not None:
118 | qkv = qkv + u
119 |
120 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
121 |
122 | q, k, v = qkv.unbind(0)
123 | q, k = self.q_norm(q), self.k_norm(k)
124 |
125 | if self.fast_attn:
126 | x = F.scaled_dot_product_attention(
127 | q, k, v,
128 | dropout_p=self.attn_drop.p,
129 | )
130 | else:
131 | q = q * self.scale
132 | attn = q @ k.transpose(-2, -1)
133 | attn = attn.softmax(dim=-1)
134 | attn = self.attn_drop(attn)
135 | x = attn @ v
136 |
137 | x = x.transpose(1, 2).reshape(B, N, C)
138 | x = self.proj(x)
139 | x = self.proj_drop(x)
140 | return x
141 |
142 |
143 | class GETBlock(nn.Module):
144 | """
145 | A GET block with additive attention injection.
146 | """
147 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, cond=False, **block_kwargs):
148 | super().__init__()
149 | # Attention
150 | self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
151 | self.attn = AttnInterface(hidden_size, num_heads=num_heads, qkv_bias=True, cond=cond, **block_kwargs)
152 |
153 | # MLP
154 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
155 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
156 |
157 | act = lambda: nn.GELU()
158 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=act, drop=0)
159 |
160 | def forward(self, x, c, u=None):
161 | x = x + self.attn(self.norm1(x), c, u)
162 | x = x + self.mlp(self.norm2(x))
163 | return x
164 |
165 |
166 | class FinalLayer(nn.Module):
167 | """
168 | The final projection layer.
169 | """
170 | def __init__(self, hidden_size, patch_size, out_channels, cond=False):
171 | super().__init__()
172 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6)
173 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
174 |
175 | def forward(self, x):
176 | x = self.norm_final(x)
177 | x = self.linear(x)
178 | return x
179 |
180 |
181 | class GET(nn.Module):
182 | """
183 | Generative Equilibrium Transformer (GET): an injection transformer followed by a weight-tied equilibrium (DEQ) transformer for one-step generation.
184 | """
185 | def __init__(
186 | self,
187 | args,
188 | input_size=32,
189 | patch_size=2,
190 | in_channels=3,
191 | hidden_size=1152,
192 | depth=28,
193 | deq_depth=3,
194 | num_heads=16,
195 | mlp_ratio=4.0,
196 | deq_mlp_ratio=16.0,
197 | num_classes=10,
198 | cond=False
199 | ):
200 | super().__init__()
201 | self.in_channels = in_channels
202 | self.out_channels = in_channels
203 | self.patch_size = patch_size
204 | self.num_heads = num_heads
205 | self.deq_depth = deq_depth
206 |
207 | self.cond = cond
208 |
209 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
210 |
211 | if self.cond:
212 | self.y_embedder = ClassEmbedding(num_classes, 3*hidden_size)
213 |
214 | num_patches = self.x_embedder.num_patches
215 | # Will use fixed sin-cos embedding:
216 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
217 |
218 | self.blocks = nn.ModuleList([
219 | GETBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, cond=cond) for _ in range(depth)
220 | ])
221 |
222 | # injection
223 | self.qkv_inj = nn.Linear(hidden_size, hidden_size*3*deq_depth, bias=False)
224 |
225 | # DEQ blocks
226 | self.deq_blocks = nn.ModuleList([
227 | GETBlock(hidden_size, num_heads, mlp_ratio=deq_mlp_ratio, cond=cond) for _ in range(deq_depth)
228 | ])
229 |
230 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, cond=cond)
231 | self.initialize_weights()
232 |
233 | self.mem = args.mem
234 | self.deq = get_deq(args)
235 | apply_norm(self.deq_blocks, args=args)
236 |
237 | def initialize_weights(self):
238 | # Initialize transformer layers:
239 | def _basic_init(module):
240 | if isinstance(module, nn.Linear):
241 | torch.nn.init.xavier_uniform_(module.weight)
242 | if module.bias is not None:
243 | nn.init.constant_(module.bias, 0)
244 | if isinstance(module, nn.LayerNorm):
245 | if module.weight is not None:
246 | nn.init.constant_(module.weight, 1)
247 | nn.init.constant_(module.bias, 0)
248 | self.apply(_basic_init)
249 |
250 | # Initialize (and freeze) pos_embed by sin-cos embedding:
251 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
252 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
253 |
254 | # Initialize patch embedding:
255 | w = self.x_embedder.proj.weight.data
256 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
257 | nn.init.constant_(self.x_embedder.proj.bias, 0)
258 |
259 | # Initialize class embedding table:
260 | if self.cond:
261 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
262 |
263 | nn.init.constant_(self.final_layer.linear.weight, 0)
264 | nn.init.constant_(self.final_layer.linear.bias, 0)
265 |
266 | def unpatchify(self, x):
267 | """
268 | x: (B, N, P ** 2 * C)
269 | imgs: (B, C, H, W)
270 | """
271 | c = self.out_channels
272 | p = self.x_embedder.patch_size[0]
273 | h = w = int(x.shape[1] ** 0.5)
274 | assert h * w == x.shape[1]
275 |
276 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
277 | x = torch.einsum('nhwpqc->nchpwq', x)
278 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
279 | return imgs
280 |
281 | def decode(self, z):
282 | x = self.final_layer(z) # (B, N, P ** 2 * C_out)
283 | x = self.unpatchify(x) # (B, C_out, H, W)
284 | return x
285 |
286 | def forward(self, x, y=None):
287 | """
288 | Forward pass of GET.
289 | x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images)
290 | y: (B,) tensor of class labels
291 | """
292 | reset_norm(self)
293 |
294 | x = self.x_embedder(x) + self.pos_embed # (B, N, D), where N = H * W / P ** 2
295 | B, N, C = x.shape
296 |
297 | c = None
298 | if self.cond:
299 | c = self.y_embedder(y).view(B, 1, 3*C)
300 |
301 | # Injection transformer
302 | for block in self.blocks:
303 | x = block(x, c) # (B, N, D)
304 |
305 | u = self.qkv_inj(x)
306 | u_list = u.chunk(self.deq_depth, dim=-1)
307 |
308 | def func(z):
309 | for block, u in zip(self.deq_blocks, u_list):
310 | if self.mem:
311 | z = mem_gc(block, (z, c, u))
312 | else:
313 | z = block(z, c, u)
314 | return z
315 |
316 | # Equilibrium transformer (DEQ)
317 | z = torch.randn_like(x)
318 | z_out, info = self.deq(func, z)
319 |
320 | if self.training:
321 | # Decode every fixed-point iterate for fixed-point correction
322 | return [self.decode(z) for z in z_out]
323 | else:
324 | return self.decode(z_out[-1])
325 |
326 |
327 | #################################################################################
328 | # Current GET Configs #
329 | #################################################################################
330 |
331 |
332 | def GET_T_2_L6_L3_H6(args, **kwargs):
333 | return GET(args, depth=6, hidden_size=256, patch_size=2, num_heads=4, deq_mlp_ratio=6, **kwargs)
334 |
335 | def GET_M_2_L6_L3_H6(args, **kwargs):
336 | return GET(args, depth=6, hidden_size=384, patch_size=2, num_heads=6, deq_mlp_ratio=6, **kwargs)
337 |
338 | def GET_S_2_L6_L3_H8(args, **kwargs):
339 | return GET(args, depth=6, hidden_size=512, patch_size=2, num_heads=8, deq_mlp_ratio=8, **kwargs)
340 |
341 | def GET_B_2_L1_L3_H12(args, **kwargs):
342 | return GET(args, depth=1, hidden_size=768, patch_size=2, num_heads=12, deq_mlp_ratio=12, **kwargs)
343 |
344 | def GET_B_2_L6_L3_H8(args, **kwargs):
345 | return GET(args, depth=6, hidden_size=768, patch_size=2, num_heads=12, deq_mlp_ratio=8, **kwargs)
346 |
347 |
348 | GET_models = {
349 | 'GET-T/2': GET_T_2_L6_L3_H6,
350 | 'GET-M/2': GET_M_2_L6_L3_H6,
351 | 'GET-S/2': GET_S_2_L6_L3_H8,
352 | 'GET-B/2': GET_B_2_L1_L3_H12,
353 | 'GET-B/2+': GET_B_2_L6_L3_H8,
354 | }
355 |
--------------------------------------------------------------------------------
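A minimal sketch of building and sampling from a GET model outside the training script, assuming the DEQ-related flags come from torchdeq's add_deq_args with their defaults (train.py and eval.py construct args the same way; the tiny parser here exists only to produce such a namespace):

    import argparse
    import torch
    from torchdeq import add_deq_args
    from models import model_dict

    parser = argparse.ArgumentParser()
    parser.add_argument('--mem', action='store_true')   # GET reads args.mem
    add_deq_args(parser)
    args = parser.parse_args([])                         # library defaults

    model = model_dict['GET-T/2'](args=args, input_size=32, num_classes=10, cond=True).eval()
    z = torch.randn(2, 3, 32, 32)                        # Gaussian noise input
    y = torch.randint(0, 10, (2,))                       # class labels
    with torch.no_grad():
        imgs = model(z, y)                               # (2, 3, 32, 32), roughly in [-1, 1]

In training mode the same forward instead returns one decoded image per fixed-point iterate, which is what fp_correction in train.py consumes.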
/models/vit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # References:
3 | # DiT: https://github.com/facebookresearch/DiT
4 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | import numpy as np
12 | import math
13 | from timm.models.vision_transformer import PatchEmbed, Mlp
14 |
15 | # Positional Embedding
16 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
17 |
18 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
19 | """
20 | grid_size: int of the grid height and width
21 | return:
22 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
23 | """
24 | grid_h = np.arange(grid_size, dtype=np.float32)
25 | grid_w = np.arange(grid_size, dtype=np.float32)
26 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
27 | grid = np.stack(grid, axis=0)
28 |
29 | grid = grid.reshape([2, 1, grid_size, grid_size])
30 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
31 | if cls_token and extra_tokens > 0:
32 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
33 | return pos_embed
34 |
35 |
36 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
37 | assert embed_dim % 2 == 0
38 |
39 | # use half of dimensions to encode grid_h
40 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
41 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
42 |
43 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
44 | return emb
45 |
46 |
47 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
48 | """
49 | embed_dim: output dimension for each position
50 | pos: a list of positions to be encoded: size (M,)
51 | out: (M, D)
52 | """
53 | assert embed_dim % 2 == 0
54 | omega = np.arange(embed_dim // 2, dtype=np.float64)
55 | omega /= embed_dim / 2.
56 | omega = 1. / 10000**omega # (D/2,)
57 |
58 | pos = pos.reshape(-1) # (M,)
59 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
60 |
61 | emb_sin = np.sin(out) # (M, D/2)
62 | emb_cos = np.cos(out) # (M, D/2)
63 |
64 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
65 | return emb
66 |
67 |
68 | class ClassEmbedding(nn.Module):
69 | def __init__(self, num_classes, hidden_size):
70 | super().__init__()
71 | self.embedding_table = nn.Embedding(num_classes, hidden_size)
72 |
73 | def forward(self, labels):
74 | return self.embedding_table(labels)
75 |
76 |
77 | class DEQAttention(nn.Module):
78 | def __init__(
79 | self,
80 | dim,
81 | num_heads=8,
82 | qkv_bias=False,
83 | qk_norm=False,
84 | attn_drop=0.,
85 | proj_drop=0.,
86 | norm_layer=nn.LayerNorm,
87 | cond=False
88 | ):
89 | super().__init__()
90 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
91 | self.num_heads = num_heads
92 | self.head_dim = dim // num_heads
93 | self.scale = self.head_dim ** -0.5
94 | self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
95 |
96 | self.cond = cond
97 |
98 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
99 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
100 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
101 | self.attn_drop = nn.Dropout(attn_drop)
102 | self.proj = nn.Linear(dim, dim)
103 | self.proj_drop = nn.Dropout(proj_drop)
104 |
105 | def forward(self, x, c=None):
106 | B, N, C = x.shape
107 | qkv = self.qkv(x)
108 |
109 | # Injection
110 | if self.cond:
111 | qkv = qkv + c
112 |
113 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
114 |
115 | q, k, v = qkv.unbind(0)
116 | q, k = self.q_norm(q), self.k_norm(k)
117 |
118 | if self.fast_attn:
119 | x = F.scaled_dot_product_attention(
120 | q, k, v,
121 | dropout_p=self.attn_drop.p,
122 | )
123 | else:
124 | q = q * self.scale
125 | attn = q @ k.transpose(-2, -1)
126 | attn = attn.softmax(dim=-1)
127 | attn = self.attn_drop(attn)
128 | x = attn @ v
129 |
130 | x = x.transpose(1, 2).reshape(B, N, C)
131 | x = self.proj(x)
132 | x = self.proj_drop(x)
133 | return x
134 |
135 |
136 | class ViTBlock(nn.Module):
137 | """
138 | A standard ViT block.
139 | """
140 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, cond=False, **block_kwargs):
141 | super().__init__()
142 | self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
143 | self.attn = DEQAttention(hidden_size, num_heads=num_heads, qkv_bias=True, cond=cond, **block_kwargs)
144 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
145 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
146 |
147 | # For Pytorch 1.13
148 | act = lambda: nn.GELU()
149 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=act, drop=0)
150 |
151 | self.cond = cond
152 |
153 | def forward(self, x, c):
154 | if self.cond:
155 | x = x + self.attn(self.norm1(x), c)
156 | x = x + self.mlp(self.norm2(x))
157 | else:
158 | x = x + self.attn(self.norm1(x))
159 | x = x + self.mlp(self.norm2(x))
160 |
161 | return x
162 |
163 |
164 | class FinalLayer(nn.Module):
165 | def __init__(self, hidden_size, patch_size, out_channels, cond=False):
166 | super().__init__()
167 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6)
168 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
169 |
170 | def forward(self, x):
171 | x = self.norm_final(x)
172 | x = self.linear(x)
173 | return x
174 |
175 |
176 | class ViT(nn.Module):
177 | """
178 | A ViT baseline for one-step image generation.
179 | """
180 | def __init__(
181 | self,
182 | args,
183 | input_size=32,
184 | patch_size=2,
185 | in_channels=3,
186 | hidden_size=1152,
187 | depth=28,
188 | num_heads=16,
189 | mlp_ratio=4.0,
190 | num_classes=10,
191 | cond=False
192 | ):
193 | super().__init__()
194 | self.in_channels = in_channels
195 | self.out_channels = in_channels
196 | self.patch_size = patch_size
197 | self.num_heads = num_heads
198 |
199 | self.cond = cond
200 |
201 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
202 |
203 | if self.cond:
204 | self.y_embedder = ClassEmbedding(num_classes, 3*hidden_size)
205 |
206 | num_patches = self.x_embedder.num_patches
207 | # Will use fixed sin-cos embedding:
208 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
209 |
210 | self.blocks = nn.ModuleList([
211 | ViTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, cond=cond) for _ in range(depth)
212 | ])
213 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, cond=cond)
214 | self.initialize_weights()
215 |
216 | def initialize_weights(self):
217 | # Initialize transformer layers:
218 | def _basic_init(module):
219 | if isinstance(module, nn.Linear):
220 | torch.nn.init.xavier_uniform_(module.weight)
221 | if module.bias is not None:
222 | nn.init.constant_(module.bias, 0)
223 | if isinstance(module, nn.LayerNorm):
224 | if hasattr(module, 'weight') and module.weight is not None:
225 | nn.init.constant_(module.weight, 1)
226 | nn.init.constant_(module.bias, 0)
227 | self.apply(_basic_init)
228 |
229 | # Initialize (and freeze) pos_embed by sin-cos embedding:
230 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
231 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
232 |
233 | # Initialize patch_embed like nn.Linear:
234 | w = self.x_embedder.proj.weight.data
235 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
236 | nn.init.constant_(self.x_embedder.proj.bias, 0)
237 |
238 | # Initialize label embedding:
239 | if self.cond:
240 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
241 |
242 | nn.init.constant_(self.final_layer.linear.weight, 0)
243 | nn.init.constant_(self.final_layer.linear.bias, 0)
244 |
245 | def unpatchify(self, x):
246 | """
247 | x: (B, N, P ** 2 * C)
248 | imgs: (B, C, H, W)
249 | """
250 | c = self.out_channels
251 | p = self.x_embedder.patch_size[0]
252 | h = w = int(x.shape[1] ** 0.5)
253 | assert h * w == x.shape[1]
254 |
255 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
256 | x = torch.einsum('nhwpqc->nchpwq', x)
257 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
258 | return imgs
259 |
260 | def forward(self, x, y=None):
261 | """
262 | Forward pass of ViT.
263 | x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images)
264 | y: (B,) tensor of class labels
265 | """
266 | x = self.x_embedder(x) + self.pos_embed # (B, N, D), where N = H * W / P ** 2
267 |
268 | c = None
269 | if self.cond:
270 | c = self.y_embedder(y).view(x.shape[0], 1, -1) # (B, 1, 3*D)
271 |
272 | for block in self.blocks:
273 | x = block(x, c) # (B, N, D)
274 | x = self.final_layer(x) # (B, N, P ** 2 * C_out)
275 |
276 | x = self.unpatchify(x) # (B, C_out, H, W)
277 | return x
278 |
279 |
280 |
281 | #################################################################################
282 | # ViT Configs #
283 | #################################################################################
284 |
285 | def ViT_XL_2(args, **kwargs):
286 | return ViT(args, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
287 |
288 | def ViT_XL_4(args, **kwargs):
289 | return ViT(args, depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
290 |
291 | def ViT_XL_8(args, **kwargs):
292 | return ViT(args, depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
293 |
294 | def ViT_L_2(args, **kwargs):
295 | return ViT(args, depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
296 |
297 | def ViT_L_4(args, **kwargs):
298 | return ViT(args, depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
299 |
300 | def ViT_L_8(args, **kwargs):
301 | return ViT(args, depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
302 |
303 | def ViT_B_2(args, **kwargs):
304 | return ViT(args, depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
305 |
306 | def ViT_B_4(args, **kwargs):
307 | return ViT(args, depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
308 |
309 | def ViT_B_8(args, **kwargs):
310 | return ViT(args, depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
311 |
312 | def ViT_S_2(args, **kwargs):
313 | return ViT(args, depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
314 |
315 | def ViT_S_4(args, **kwargs):
316 | return ViT(args, depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
317 |
318 | def ViT_S_8(args, **kwargs):
319 | return ViT(args, depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
320 |
321 |
322 | ViT_models = {
323 | 'ViT-XL/2': ViT_XL_2, 'ViT-XL/4': ViT_XL_4, 'ViT-XL/8': ViT_XL_8,
324 | 'ViT-L/2': ViT_L_2, 'ViT-L/4': ViT_L_4, 'ViT-L/8': ViT_L_8,
325 | 'ViT-B/2': ViT_B_2, 'ViT-B/4': ViT_B_4, 'ViT-B/8': ViT_B_8,
326 | 'ViT-S/2': ViT_S_2, 'ViT-S/4': ViT_S_4, 'ViT-S/8': ViT_S_8,
327 | }
328 |
--------------------------------------------------------------------------------
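One detail shared by vit.py and get.py above: the class embedding is sized 3*hidden_size because it is added to the concatenated query/key/value projection (whose last dimension is 3*dim) before the split into heads. A shape-only sketch of that additive injection, using ViT-S/2 numbers (hidden size 384, so qkv and the label embedding both end in 1152):

    import torch
    import torch.nn as nn

    hidden, num_cls, B, N = 384, 10, 8, 256
    y_embedder = nn.Embedding(num_cls, 3 * hidden)     # mirrors ClassEmbedding(num_classes, 3*hidden_size)
    c = y_embedder(torch.randint(0, num_cls, (B,))).view(B, 1, 3 * hidden)
    qkv = torch.randn(B, N, 3 * hidden)                # stand-in for self.qkv(x) on (B, N, hidden) tokens
    qkv = qkv + c                                       # broadcast over all N tokens, as in DEQAttention.forward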
/run.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=1 --nproc_per_node=$1 --rdzv_endpoint=localhost:$2 \
2 | train.py \
3 | --epochs 1000 \
4 | --global_batch_size 128 \
5 | --grad 6 \
6 | --sup_gap 1 \
7 | --f_max_iter 0 \
8 | --eval_f_max_iter 6 \
9 | --norm_type none \
10 | --data_path YOUR_DATA_PATH \
11 | --stat_path YOUR_STAT_PATH \
12 | ${@:3}
13 |
--------------------------------------------------------------------------------
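Like eval.sh, run.sh above takes the process count and rendezvous port as its first two arguments and forwards the rest to train.py; YOUR_DATA_PATH and YOUR_STAT_PATH are placeholders to edit in the script (or to override on the command line, since argparse keeps the last value for a repeated flag). A sketch of a conditional CIFAR-10 run on 4 GPUs:

    bash run.sh 4 29500 --model GET-T/2 --cond --name cifar10-cond --data_path /path/to/pairs --stat_path /path/to/cifar10.test.npz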
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | from glob import glob
7 |
8 | import torch
9 | # the first flag below was False when we tested this script but True makes A100 training a lot faster:
10 | torch.backends.cuda.matmul.allow_tf32 = True
11 | torch.backends.cudnn.allow_tf32 = True
12 |
13 | import torch.distributed as dist
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | from torch.utils.data import DataLoader
16 | from torch.utils.data.distributed import DistributedSampler
17 |
18 | from torchprofile import profile_macs
19 |
20 | from utils import (
21 | create_logger, save_ckpt,
22 | update_ema, requires_grad,
23 | sample_image, sample_fid, compute_fid_is
24 | )
25 | from models import model_dict
26 | from losses import loss_dict
27 | from datasets import PairedDataset, PairedCondDataset
28 |
29 | # For future ImageNet training & sampling
30 | # from diffusers.models import AutoencoderKL
31 |
32 | from torchdeq import add_deq_args
33 | from torchdeq.loss import fp_correction
34 |
35 |
36 | def main(args):
37 | '''
38 | Model training.
39 | '''
40 | # Setup DDP
41 | dist.init_process_group('nccl')
42 | world_size = dist.get_world_size()
43 | rank = dist.get_rank()
44 | assert args.global_batch_size % world_size == 0, 'Batch size must be divisible by world size.'
45 | device = rank % torch.cuda.device_count()
46 | seed = args.global_seed * world_size + rank
47 | torch.manual_seed(seed)
48 | torch.cuda.set_device(device)
49 | print(f'Starting rank={rank}, seed={seed}, world_size={world_size}.')
50 |
51 | # Setup an experiment folder
52 | if rank == 0:
53 | os.makedirs(args.results_dir, exist_ok=True)
54 | experiment_index = len(glob(f'{args.results_dir}/*'))
55 | model_string_name = args.model.replace('/', '-')
56 | experiment_dir = f'{args.results_dir}/{experiment_index:03d}-{model_string_name}-{args.name}'
57 |
58 | checkpoint_dir = f'{experiment_dir}/checkpoints'
59 | os.makedirs(checkpoint_dir, exist_ok=True)
60 |
61 | sample_dir = f'{experiment_dir}/samples'
62 | os.makedirs(sample_dir, exist_ok=True)
63 |
64 | logger = create_logger(experiment_dir)
65 | logger.info(f'Experiment directory created at {experiment_dir}')
66 | else:
67 | logger = create_logger()
68 |
69 | # Create model
70 | model = model_dict[args.model](
71 | args=args,
72 | input_size=args.input_size,
73 | num_classes=args.num_classes,
74 | cond=args.cond
75 | )
76 | ema = model_dict[args.model](
77 | args=args,
78 | input_size=args.input_size,
79 | num_classes=args.num_classes,
80 | cond=args.cond
81 | ).to(device)
82 | requires_grad(ema, False)
83 |
84 | # Setup DDP
85 | model = DDP(model.to(device), device_ids=[rank])
86 | logger.info(f'Model Parameters: {sum(p.numel() for p in model.parameters()):,}')
87 |
88 | # Test FLOPs
89 | if rank == 0:
90 | test_case = torch.randn(1, 3, args.input_size, args.input_size).to(device)
91 | if args.cond:
92 | test_c = torch.randint(0, args.num_classes, (1,1)).to(device)
93 | macs = profile_macs(model, (test_case, test_c))
94 | del test_case, test_c
95 | else:
96 | macs = profile_macs(model, test_case)
97 | del test_case
98 | logger.info(f'Model MACs: {macs:,}')
99 | dist.barrier()
100 |
101 | # For future ImageNet training
102 | # vae = AutoencoderKL.from_pretrained(f'stabilityai/sd-vae-ft-{args.vae}').to(device)
103 |
104 | # Setup optimizer
105 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0)
106 |
107 | # Setup data
108 | if args.cond:
109 | dataset = PairedCondDataset(args.data_path, world_size=world_size, rank=rank)
110 | else:
111 | dataset = PairedDataset(args.data_path, world_size=world_size, rank=rank)
112 | sampler = DistributedSampler(
113 | dataset,
114 | num_replicas=world_size,
115 | rank=rank,
116 | shuffle=True,
117 | seed=args.global_seed
118 | )
119 | loader = DataLoader(
120 | dataset,
121 | batch_size=int(args.global_batch_size // world_size),
122 | shuffle=False,
123 | sampler=sampler,
124 | num_workers=args.num_workers,
125 | pin_memory=True,
126 | drop_last=True
127 | )
128 | logger.info(f'Dataset contains {len(dataset):,} images ({dataset.data_dir})')
129 |
130 | # Prepare models for training
131 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
132 | model.train()
133 | ema.eval() # EMA model should always be in eval mode
134 |
135 | # Loss fn
136 | loss_fn = loss_dict[args.loss]().to(device)
137 |
138 | # Variables for monitoring/logging purposes
139 | train_steps = 0
140 | log_steps = 0
141 | running_loss = 0
142 | total_steps = args.epochs * (len(dataset) / args.global_batch_size)
143 |
144 | # Resume from the prev checkpoint
145 | if args.resume:
146 | ckpt = torch.load(args.resume, map_location=torch.device('cpu'))
147 | model.module.load_state_dict(ckpt['model'])
148 | ema.load_state_dict(ckpt['ema'])
149 | opt.load_state_dict(ckpt['opt'])
150 | train_steps = max(args.resume_iter, 0)
151 |
152 | logger.info(f'Resumed from {args.resume}.')
153 |
154 | start_time = time.time()
155 | logger.info(f'Training for {args.epochs} epochs...')
156 | for epoch in range(args.epochs):
157 | sampler.set_epoch(epoch)
158 | logger.info(f'Beginning epoch {epoch}...')
159 |
160 | for data in loader:
161 | # Unpack data
162 | if args.cond:
163 | z, x, c = data
164 | z, x, c = z.to(device), x.to(device), c.to(device).max(dim=1)[1]
165 | else:
166 | z, x = data
167 | z, x, c = z.to(device), x.to(device), None
168 |
169 | # Loss & Grad
170 | x_pred = model(z, c)
171 | loss, loss_list = fp_correction(loss_fn, (x_pred, x), return_loss_values=True)
172 | opt.zero_grad()
173 | loss.backward()
174 |
175 | # LR Warmup
176 | if train_steps < args.warmup_iter:
177 | curr_lr = args.lr * (train_steps+1) / args.warmup_iter
178 | opt.param_groups[0]['lr'] = curr_lr
179 |
180 | opt.step()
181 | update_ema(ema, model.module, decay=args.ema_decay)
182 |
183 | running_loss += loss_list[-1]
184 | log_steps += 1
185 | train_steps += 1
186 |
187 | # Log training progress
188 | if train_steps % args.log_every == 0:
189 | # Measure training speed
190 | torch.cuda.synchronize()
191 | end_time = time.time()
192 | steps_per_sec = log_steps / (end_time - start_time)
193 |
194 | # Reduce loss history over all processes
195 | avg_loss = torch.tensor(running_loss / log_steps, device=device)
196 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
197 | avg_loss = avg_loss.item() / world_size
198 | logger.info(f'(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}')
199 |
200 | # Reset monitoring variables
201 | running_loss = 0
202 | log_steps = 0
203 | start_time = time.time()
204 |
205 | # Save checkpoint
206 | if train_steps % args.ckpt_every == 0 and train_steps > 0:
207 | if rank == 0:
208 | checkpoint_path = f'{checkpoint_dir}/{train_steps:07d}.pth'
209 | save_ckpt(args, model, ema, opt, checkpoint_path)
210 | logger.info(f'Saved checkpoint to {checkpoint_path}')
211 | dist.barrier()
212 |
213 | # Save the latest checkpoint
214 | if train_steps % args.save_latest_every == 0 and train_steps > 0:
215 | if rank == 0:
216 | checkpoint_path = f'{checkpoint_dir}/latest.pth'
217 | save_ckpt(args, model, ema, opt, checkpoint_path)
218 | logger.info(f'Saved latest checkpoint to {checkpoint_path}')
219 | dist.barrier()
220 |
221 | # Sample images
222 | if train_steps % args.sample_every == 0 and train_steps > 0:
223 | if rank == 0:
224 | image_path = f'{sample_dir}/{train_steps}.png'
225 | sample_image(args, ema, device, image_path, cond=args.cond)
226 | logger.info(f'Saved samples to {image_path}')
227 | dist.barrier()
228 |
229 | # Compute FID and IS
230 | if train_steps % args.eval_every == 0 and train_steps > 0:
231 | images = sample_fid(args, ema, device, rank, cond=args.cond)
232 |
233 | # In case you want to sample from the online model
234 | # images = sample_fid(args, model.module, device, rank, cond=args.cond, set_grad=True)
235 |
236 | # DDP sync
237 | all_images = [torch.zeros_like(images) for _ in range(world_size)]
238 | dist.gather(images, all_images if rank == 0 else None, dst=0)
239 | if rank == 0:
240 | FID, IS = compute_fid_is(args, all_images, rank)
241 | logger.info(f'FID {FID:0.2f}, IS {IS:0.2f} at iters {train_steps}.')
242 |
243 | del images, all_images
244 | dist.barrier()
245 |
246 | # Check training schedule
247 | if train_steps > total_steps:
248 | break
249 |
250 | if rank == 0:
251 | checkpoint_path = f'{checkpoint_dir}/final.pth'
252 | save_ckpt(args, model, ema, opt, checkpoint_path)
253 | logger.info(f'Saved final checkpoint to {checkpoint_path}')
254 | dist.barrier()
255 |
256 | # Finish training
257 | dist.destroy_process_group()
258 |
259 |
260 | if __name__ == '__main__':
261 | parser = argparse.ArgumentParser()
262 | parser.add_argument('--data_path', type=str, required=True)
263 | parser.add_argument('--name', type=str, default='debug')
264 | parser.add_argument('--results_dir', type=str, default='results')
265 |
266 | parser.add_argument('--model', type=str, choices=list(model_dict.keys()), default='GET-S/2')
267 | parser.add_argument('--input_size', type=int, default=32)
268 |
269 | parser.add_argument('--cond', action='store_true', help='Run conditional model.')
270 | parser.add_argument('--num_classes', type=int, default=10)
271 |
272 | parser.add_argument('--loss', type=str, choices=['l1', 'l2', 'lpips', 'dists'], default='l1')
273 | parser.add_argument('--vae', type=str, choices=['ema', 'mse'], default='ema')
274 |
275 | parser.add_argument('--lr', type=float, default=1e-4)
276 | parser.add_argument('--warmup_iter', type=int, default=0, help="warmup for the given iterations")
277 | parser.add_argument('--ema_decay', type=float, default=0.9999)
278 |
279 | parser.add_argument('--epochs', type=int, default=1000)
280 | parser.add_argument('--global_batch_size', type=int, default=256)
281 | parser.add_argument('--global_seed', type=int, default=42)
282 | parser.add_argument('--num_workers', type=int, default=4)
283 |
284 | parser.add_argument('--mem', action='store_true', help='Enable O(1) memory.')
285 |
286 | parser.add_argument('--log_every', type=int, default=100)
287 | parser.add_argument('--ckpt_every', type=int, default=50000)
288 | parser.add_argument('--save_latest_every', type=int, default=10000)
289 | parser.add_argument('--sample_every', type=int, default=10000)
290 |
291 | parser.add_argument('--eval_every', type=int, default=50000)
292 | parser.add_argument('--eval_samples', type=int, default=50000)
293 | parser.add_argument('--eval_batch_size', type=int, default=128)
294 | parser.add_argument('--stat_path', type=str, default='YOUR_STAT_PATH/cifar10.test.npz')
295 |
296 | parser.add_argument('--resume', help="restore checkpoint for training")
297 | parser.add_argument('--resume_iter', type=int, default=-1, help="resume training from the given iteration")
298 |
299 | # Add for DEQs
300 | add_deq_args(parser)
301 |
302 | args = parser.parse_args()
303 | main(args)
304 |
--------------------------------------------------------------------------------
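One small arithmetic note on the seeding in train.py above: every process derives its seed as global_seed * world_size + rank, so with the default global_seed of 42 on a 4-GPU node the ranks use 168, 169, 170, and 171; each rank therefore draws different noise and labels while the run as a whole stays reproducible.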
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import logging
4 | import torch
5 |
6 | import torch.distributed as dist
7 |
8 | from PIL import Image
9 | from pytorch_gan_metrics import get_inception_score_and_fid
10 |
11 |
12 | def create_logger(logging_dir=None):
13 | """
14 | Create a logger that writes to a log file and stdout.
15 | """
16 | if dist.get_rank() == 0: # real logger
17 | logging.basicConfig(
18 | level=logging.INFO,
19 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
20 | datefmt='%Y-%m-%d %H:%M:%S',
21 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
22 | )
23 | logger = logging.getLogger(__name__)
24 | else: # dummy logger (does nothing)
25 | logger = logging.getLogger(__name__)
26 | logger.addHandler(logging.NullHandler())
27 | return logger
28 |
29 |
30 | @torch.no_grad()
31 | def update_ema(ema_model, model, decay=0.9999):
32 | '''
33 | Step the EMA model towards the current model.
34 | '''
35 | ema_params = OrderedDict(ema_model.named_parameters())
36 | model_params = OrderedDict(model.named_parameters())
37 |
38 | for name, param in model_params.items():
39 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
40 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
41 |
42 |
43 | def requires_grad(model, flag=True):
44 | '''
45 | Set requires_grad flag for all parameters in a model.
46 | '''
47 | for p in model.parameters():
48 | p.requires_grad = flag
49 |
50 |
51 | def save_ckpt(args, model, ema, opt, checkpoint_path):
52 | '''
53 | Save a checkpoint containing the online model, EMA, and optimizer states.
54 | '''
55 | checkpoint = {
56 | 'args': args,
57 | 'model': model.module.state_dict(),
58 | 'ema': ema.state_dict(),
59 | 'opt': opt.state_dict(),
60 | }
61 | torch.save(checkpoint, checkpoint_path)
62 |
63 |
64 | def sample_image(args, model, device, image_path, set_train=False, cond=False):
65 | '''
66 | Sample a batch of images for visualization.
67 | Set set_train to True if you are sampling from the online (non-EMA) model.
68 | '''
69 | model.eval()
70 |
71 | n_row = 16
72 | size = args.input_size
73 |
74 | z = torch.randn(n_row*n_row, 3, size, size).to(device)
75 | c = torch.randint(0, args.num_classes, (n_row*n_row,)).to(device) if cond else None
76 | with torch.no_grad():
77 | x = model(z, c)
78 |
79 | x = x.view(n_row, n_row, 3, size, size)
80 | x = (x * 127.5 + 128).clip(0, 255).to(torch.uint8)
81 | images = x.permute(0, 3, 1, 4, 2).reshape(n_row*size, n_row*size, 3).cpu().numpy()
82 |
83 | Image.fromarray(images, 'RGB').save(image_path)
84 | del images, x, z, c
85 | torch.cuda.empty_cache()
86 |
87 | if set_train:
88 | model.train()
89 |
90 |
91 | def num_to_groups(num, divisor):
92 | '''
93 | Split num samples into batches of size divisor, with a smaller final batch for any remainder.
94 | '''
95 | groups = num // divisor
96 | remainder = num % divisor
97 | arr = [divisor] * groups
98 | if remainder > 0:
99 | arr.append(remainder)
100 | return arr
101 |
102 |
103 | def sample_fid(args, model, device, rank, set_train=False, cond=False):
104 | '''
105 | Sample args.eval_samples images in parallel for FID and IS calculation. Default 50k images.
106 | Set set_train to True if you are using the online model for sampling.
107 | '''
108 | # Setup batches for each node
109 | assert args.eval_samples % dist.get_world_size() == 0
110 | samples_per_node = args.eval_samples // dist.get_world_size()
111 | batches = num_to_groups(samples_per_node, args.eval_batch_size)
112 |
113 | # Distributed EMA/online model evaluation.
114 | # No need for the DDP wrapper here,
115 | # as we do not need gradient sync.
116 | model.eval()
117 | model = model.to(device)
118 |
119 | n_cls = args.num_classes
120 | size = args.input_size
121 |
122 | images = []
123 | with torch.no_grad():
124 | for n in batches:
125 | z = torch.randn(n, 3, size, size).to(device)
126 | c = torch.randint(0, n_cls, (n,)).to(device) if cond else None
127 | x = model(z, c)
128 | images.append(x)
129 | images = torch.cat(images, dim=0)
130 |
131 | torch.cuda.empty_cache()
132 | if set_train:
133 | model.train()
134 |
135 | return images
136 |
137 |
138 | def compute_fid_is(args, all_images, rank):
139 | '''
140 | Compute FID and IS using provided images.
141 | '''
142 | # Post-process to images.
143 | all_images = torch.cat(all_images, dim=0)
144 | all_images = (all_images * 127.5 + 128).clip(0, 255).to(torch.uint8).float().div(255).cpu()
145 |
146 | # Compute FID & IS
147 | (IS, IS_std), FID = get_inception_score_and_fid(all_images, args.stat_path)
148 | torch.cuda.empty_cache()
149 |
150 | return FID, IS
151 |
152 |
--------------------------------------------------------------------------------
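A worked example of how sample_fid above splits the evaluation set across ranks, assuming the repo's dependencies are installed, 8 processes, and the defaults of 50,000 samples with an eval batch size of 128:

    from utils import num_to_groups

    samples_per_node = 50000 // 8            # 6250 images per rank
    batches = num_to_groups(6250, 128)       # 48 batches of 128 plus a final batch of 106
    assert sum(batches) == 6250
    assert len(batches) == 49 and batches[-1] == 106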