├── .gitignore
├── LICENSE
├── README.md
├── datasets.py
├── distributions.py
├── evaluate.py
├── fid
│   ├── LICENSE_pytorch_fid
│   ├── __init__.py
│   ├── fid_score.py
│   └── inception.py
├── img
│   ├── celebahq.png
│   └── model_diagram.png
├── lmdb_datasets.py
├── model.py
├── neural_ar_operations.py
├── neural_operations.py
├── requirements.txt
├── scripts
│   ├── Dockerfile
│   ├── convert_tfrecord_to_lmdb.py
│   ├── create_celeba64_lmdb.py
│   ├── create_ffhq_lmdb.py
│   └── precompute_fid_statistics.py
├── thirdparty
│   ├── LICENSE_PyTorch
│   ├── LICENSE_apache
│   ├── LICENSE_torchvision
│   ├── __init__.py
│   ├── adamax.py
│   ├── functions.py
│   ├── inplaced_sync_batchnorm.py
│   ├── lsun.py
│   └── swish.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | visualizations/
2 | .idea/
3 | .run/
4 | *.pyc
5 | *.pdf
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code License for NVAE
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 |
7 | “Software” means the original work of authorship made available under this License.
8 |
9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under
10 | this License.
11 |
12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under
13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include
14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
15 |
16 | Works, including the Software, are “made available” under this License by including in or with the Work either
17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
18 |
19 | 2. License Grant
20 |
21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual,
22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly
23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
24 |
25 | 3. Limitations
26 |
27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you
28 | include a complete copy of this License with your distribution, and (c) you retain without modification any
29 | copyright, patent, trademark, or attribution notices that are present in the Work.
30 |
31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and
32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use
33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works
34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution
35 | requirements in Section 3.1) will continue to apply to the Work itself.
36 |
37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use
38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative
39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
40 |
41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim,
42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then
43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
44 |
45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos,
46 | or trademarks, except as necessary to reproduce the notices described in this License.
47 |
48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the
49 | grant in Section 2.1) will terminate immediately.
50 |
51 | 4. Disclaimer of Warranty.
52 |
53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
54 | WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU
55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
56 |
57 | 5. Limitation of Liability.
58 |
59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING
60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR
62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR
63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
65 |
66 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" [(NeurIPS 2020 Spotlight Paper)](https://arxiv.org/abs/2007.03898)
2 |
3 |
7 |
8 |
9 |
10 | [NVAE](https://arxiv.org/abs/2007.03898) is a deep hierarchical variational autoencoder that enables training SOTA
11 | likelihood-based generative models on several image datasets.
12 |
13 |
14 |
15 |
16 |
17 | ## Requirements
18 | NVAE is built in Python 3.7 using PyTorch 1.6.0. Use the following command to install the requirements:
19 | ```
20 | pip install -r requirements.txt
21 | ```
22 |
23 | ## Set up file paths and data
24 | We have examined NVAE on several datasets. For large datasets, we store the data in LMDB datasets
25 | for I/O efficiency. Each dataset section below describes how to prepare the data. `$DATA_DIR` indicates
26 | the path to a data directory that will contain all the datasets and `$CODE_DIR` refers to the code directory:
27 |
28 | **MNIST and CIFAR-10**
29 |
30 | These datasets will be downloaded automatically when you run the main training for NVAE using `train.py`
31 | for the first time. You can use `--data=$DATA_DIR/mnist` or `--data=$DATA_DIR/cifar10`, so that the datasets
32 | are downloaded to the corresponding directories.
33 |
34 |
35 | **CelebA 64**
36 | Run the following commands to download the CelebA images and store them in an LMDB dataset:
37 |
38 | ```shell script
39 | cd $CODE_DIR/scripts
40 | python create_celeba64_lmdb.py --split train --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb
41 | python create_celeba64_lmdb.py --split valid --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb
42 | python create_celeba64_lmdb.py --split test --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb
43 | ```
44 | Above, the images will be downloaded to `$DATA_DIR/celeba_org` automatically, and then the LMDB datasets are created
45 | at `$DATA_DIR/celeba64_lmdb`.
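
After the scripts finish, you can sanity-check that an LMDB was actually written, for example by counting its entries. This is an illustrative snippet only; the `train.lmdb` file name below is an assumption, so check `lmdb_datasets.py` and the creation scripts for the exact layout:

```python
import lmdb

# Hypothetical path: adjust to wherever the creation script placed the LMDB.
lmdb_file = '/path/to/DATA_DIR/celeba64_lmdb/train.lmdb'

# Open read-only and report how many records were stored.
env = lmdb.open(lmdb_file, readonly=True, lock=False)
with env.begin() as txn:
    print('number of entries:', txn.stat()['entries'])
env.close()
```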
46 |
47 |
48 | **ImageNet 32x32**
49 |
50 | Run the following commands to download tfrecord files from [GLOW](https://github.com/openai/glow) and to convert them
51 | to LMDB datasets:
52 | ```shell script
53 | mkdir -p $DATA_DIR/imagenet-oord
54 | cd $DATA_DIR/imagenet-oord
55 | wget https://storage.googleapis.com/glow-demo/data/imagenet-oord-tfr.tar
56 | tar -xvf imagenet-oord-tfr.tar
57 | cd $CODE_DIR/scripts
58 | python convert_tfrecord_to_lmdb.py --dataset=imagenet-oord_32 --tfr_path=$DATA_DIR/imagenet-oord/mnt/host/imagenet-oord-tfr --lmdb_path=$DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --split=train
59 | python convert_tfrecord_to_lmdb.py --dataset=imagenet-oord_32 --tfr_path=$DATA_DIR/imagenet-oord/mnt/host/imagenet-oord-tfr --lmdb_path=$DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --split=validation
60 | ```
61 |
62 |
63 | **CelebA HQ 256**
64 |
65 | Run the following commands to download tfrecord files from [GLOW](https://github.com/openai/glow) and to convert them
66 | to LMDB datasets:
67 | ```shell script
68 | mkdir -p $DATA_DIR/celeba
69 | cd $DATA_DIR/celeba
70 | wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
71 | tar -xvf celeba-tfr.tar
72 | cd $CODE_DIR/scripts
73 | python convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=$DATA_DIR/celeba/celeba-tfr --lmdb_path=$DATA_DIR/celeba/celeba-lmdb --split=train
74 | python convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=$DATA_DIR/celeba/celeba-tfr --lmdb_path=$DATA_DIR/celeba/celeba-lmdb --split=validation
75 | ```
76 |
77 |
78 |
79 | **FFHQ 256**
80 |
81 | Visit [this Google drive location](https://drive.google.com/drive/folders/1WocxvZ4GEZ1DI8dOz30aSj2zT6pkATYS) and download
82 | `images1024x1024.zip`. Run the following commands to unzip the images and to store them in LMDB datasets:
83 | ```shell script
84 | mkdir -p $DATA_DIR/ffhq
85 | unzip images1024x1024.zip -d $DATA_DIR/ffhq/
86 | cd $CODE_DIR/scripts
87 | python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffhq_lmdb_path=$DATA_DIR/ffhq/ffhq-lmdb --split=train
88 | python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffhq_lmdb_path=$DATA_DIR/ffhq/ffhq-lmdb --split=validation
89 | ```
90 |
91 |
92 | **LSUN**
93 |
94 | We use LSUN datasets in our follow-up works. Visit [LSUN](https://www.yf.io/p/lsun) for
95 | instructions on how to download this dataset. Since the LSUN scene datasets come in the
96 | LMDB format, they are ready to be loaded using torchvision data loaders.
97 |
98 |
99 |
100 |
101 | ## Running the main NVAE training and evaluation scripts
102 | We use the following commands to train NVAE on each dataset for
103 | Table 1 in the [paper](https://arxiv.org/pdf/2007.03898.pdf). Normalizing flows are enabled on all datasets except MNIST.
104 | Check Table 6 in the paper for more information on training
105 | details. Note that for the multinode training (more than 8-GPU experiments), we use the `mpirun`
106 | command to run the training scripts on multiple nodes. Please adjust the commands below according to your setup.
107 | Below `IP_ADDR` is the IP address of the machine that will host the process with rank 0
108 | (see [here](https://pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods)).
109 | `NODE_RANK` is the index of each node among all the nodes that are running the job.
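
For reference, the sketch below shows roughly how these pieces fit together: each node runs `--num_process_per_node` processes, the global rank of a process is derived from `NODE_RANK` and its local GPU index, and all processes rendezvous at `IP_ADDR`. This is an illustrative sketch only (the port is a placeholder); the actual logic lives in `init_processes` in `train.py`.

```python
import os
import torch.distributed as dist

def global_rank(node_rank, local_rank, num_process_per_node):
    # Illustrative: the rank of this process across all nodes.
    return node_rank * num_process_per_node + local_rank

def init_distributed(master_address, rank, world_size, port='6020'):
    # master_address is the machine hosting the rank-0 process ($IP_ADDR above).
    os.environ['MASTER_ADDR'] = master_address
    os.environ['MASTER_PORT'] = port  # placeholder port
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=rank, world_size=world_size)
```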
110 |
111 | **MNIST**
112 |
113 | Two 16-GB V100 GPUs are used for training NVAE on dynamically binarized MNIST. Training takes about 21 hours.
114 |
115 | ```shell script
116 | export EXPR_ID=UNIQUE_EXPR_ID
117 | export DATA_DIR=PATH_TO_DATA_DIR
118 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
119 | export CODE_DIR=PATH_TO_CODE_DIR
120 | cd $CODE_DIR
121 | python train.py --data $DATA_DIR/mnist --root $CHECKPOINT_DIR --save $EXPR_ID --dataset mnist --batch_size 200 \
122 | --epochs 400 --num_latent_scales 2 --num_groups_per_scale 10 --num_postprocess_cells 3 --num_preprocess_cells 3 \
123 | --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_latent_per_group 20 --num_preprocess_blocks 2 \
124 | --num_postprocess_blocks 2 --weight_decay_norm 1e-2 --num_channels_enc 32 --num_channels_dec 32 --num_nf 0 \
125 | --ada_groups --num_process_per_node 2 --use_se --res_dist --fast_adamax
126 | ```
127 |
128 |
129 | **CIFAR-10**
130 |
131 | Eight 16-GB V100 GPUs are used for training NVAE on CIFAR-10. Training takes about 55 hours.
132 |
133 | ```shell script
134 | export EXPR_ID=UNIQUE_EXPR_ID
135 | export DATA_DIR=PATH_TO_DATA_DIR
136 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
137 | export CODE_DIR=PATH_TO_CODE_DIR
138 | cd $CODE_DIR
139 | python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset cifar10 \
140 | --num_channels_enc 128 --num_channels_dec 128 --epochs 400 --num_postprocess_cells 2 --num_preprocess_cells 2 \
141 | --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
142 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 30 --batch_size 32 \
143 | --weight_decay_norm 1e-2 --num_nf 1 --num_process_per_node 8 --use_se --res_dist --fast_adamax
144 | ```
145 |
146 |
147 | **CelebA 64**
148 |
149 | Eight 16-GB V100 GPUs are used for training NVAE on CelebA 64. Training takes about 92 hours.
150 |
151 | ```shell script
152 | export EXPR_ID=UNIQUE_EXPR_ID
153 | export DATA_DIR=PATH_TO_DATA_DIR
154 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
155 | export CODE_DIR=PATH_TO_CODE_DIR
156 | cd $CODE_DIR
157 | python train.py --data $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_64 \
158 | --num_channels_enc 64 --num_channels_dec 64 --epochs 90 --num_postprocess_cells 2 --num_preprocess_cells 2 \
159 | --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
160 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 20 \
161 | --batch_size 16 --num_nf 1 --ada_groups --num_process_per_node 8 --use_se --res_dist --fast_adamax
162 | ```
163 |
164 |
165 | **ImageNet 32x32**
166 |
167 | 24 16-GB V100 GPUs are used for training NVAE on ImageNet 32x32. Training takes about 70 hours.
168 |
169 | ```shell script
170 | export EXPR_ID=UNIQUE_EXPR_ID
171 | export DATA_DIR=PATH_TO_DATA_DIR
172 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
173 | export CODE_DIR=PATH_TO_CODE_DIR
174 | export IP_ADDR=IP_ADDRESS
175 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
176 | cd $CODE_DIR
177 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
178 | 'python train.py --data $DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset imagenet_32 \
179 | --num_channels_enc 192 --num_channels_dec 192 --epochs 45 --num_postprocess_cells 2 --num_preprocess_cells 2 \
180 | --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
181 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 28 \
182 | --batch_size 24 --num_nf 1 --warmup_epochs 1 \
183 | --weight_decay_norm 1e-2 --weight_decay_norm_anneal --weight_decay_norm_init 1e0 \
184 | --num_process_per_node 8 --use_se --res_dist \
185 | --fast_adamax --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
186 | ```
187 |
188 |
189 | **CelebA HQ 256**
190 |
191 | 24 32-GB V100 GPUs are used for training NVAE on CelebA HQ 256. Training takes about 94 hours.
192 |
193 | ```shell script
194 | export EXPR_ID=UNIQUE_EXPR_ID
195 | export DATA_DIR=PATH_TO_DATA_DIR
196 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
197 | export CODE_DIR=PATH_TO_CODE_DIR
198 | export IP_ADDR=IP_ADDRESS
199 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
200 | cd $CODE_DIR
201 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
202 | 'python train.py --data $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \
203 | --num_channels_enc 30 --num_channels_dec 30 --epochs 300 --num_postprocess_cells 2 --num_preprocess_cells 2 \
204 | --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
205 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_groups_per_scale 16 \
206 | --batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \
207 | --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \
208 | --fast_adamax --num_x_bits 5 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
209 | ```
210 |
211 | In our early experiments, a smaller model with 24 channels instead of 30 could be trained on only 8 GPUs in
212 | the same amount of time (with a batch size of 6). The smaller model obtains only 0.01 bpd higher
213 | negative log-likelihood.
214 |
215 |
216 | **FFHQ 256**
217 |
218 | 24 32-GB V100 GPUs are used for training NVAE on FFHQ 256. Training takes about 160 hours.
219 |
220 | ```shell script
221 | export EXPR_ID=UNIQUE_EXPR_ID
222 | export DATA_DIR=PATH_TO_DATA_DIR
223 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
224 | export CODE_DIR=PATH_TO_CODE_DIR
225 | export IP_ADDR=IP_ADDRESS
226 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
227 | cd $CODE_DIR
228 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
229 | 'python train.py --data $DATA_DIR/ffhq/ffhq-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset ffhq \
230 | --num_channels_enc 30 --num_channels_dec 30 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \
231 | --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
232 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 16 \
233 | --batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \
234 | --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \
235 | --fast_adamax --num_x_bits 5 --learning_rate 8e-3 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
236 | ```
237 |
238 | In our early experiments, a smaller model with 24 channels instead of 30 could be trained on only 8 GPUs in
239 | the same amount of time (with a batch size of 6). The smaller model obtains only 0.01 bpd higher
240 | negative log-likelihood.
241 |
242 |
243 | **If for any reason your training is stopped, use the exact same command with the addition of `--cont_training`
244 | to continue training from the last saved checkpoint. If you observe NaN, continuing the training using this flag
245 | usually will not fix the NaN issue.**
246 |
247 | ## Known Issues
248 | **Cannot build CelebA 64 or training gives NaN right at the beginning on this dataset**
249 |
250 | Several users have reported issues building CelebA 64 or have encountered NaN at the beginning of training on this dataset.
251 | If you face similar issues, you can download the dataset manually and build the LMDBs using the instructions
252 | in this issue: https://github.com/NVlabs/NVAE/issues/2.
253 |
254 |
255 | **Getting NaN after a few epochs of training**
256 |
257 | One of the main challenges in training very deep hierarchical VAEs is the training instability that we discussed in the paper.
258 | We have verified that the settings in the commands above can be trained in a stable way. If you modify the settings
259 | above and you encounter NaN after a few epochs of training, you can use these tricks to stabilize your training:
260 | i) Increase the spectral regularization coefficient, `--weight_decay_norm`. ii) Use exponential decay on
261 | `--weight_decay_norm` via `--weight_decay_norm_anneal` and `--weight_decay_norm_init`. iii) Decrease the learning rate.
262 |
263 |
264 | **Training freezes with no NaN**
265 |
266 | In some very rare cases, we observed that training freezes after 2-3 days of training. We believe the root cause
267 | is a race condition in one of the low-level libraries. If for any reason the training
268 | is stopped, kill your current run and use the exact same command with the addition of `--cont_training`
269 | to continue training from the last saved checkpoint.
270 |
271 |
272 | ## Monitoring the training progress
273 | While running any of the commands above, you can monitor the training progress using Tensorboard:
274 |
276 |
277 | ```shell script
278 | tensorboard --logdir $CHECKPOINT_DIR/eval-$EXPR_ID/
279 | ```
280 | Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script.
281 |
282 |
283 |
284 | ## Post-training sampling, evaluation, and checkpoints
285 |
286 | **Evaluating Log-Likelihood**
287 |
288 | You can use the following command to load a trained model and evaluate it on the test datasets:
289 |
290 | ```shell script
291 | cd $CODE_DIR
292 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --data $DATA_DIR/mnist --eval_mode=evaluate --num_iw_samples=1000
293 | ```
294 | Above, `--num_iw_samples` indicates the number of importance weighted samples used in evaluation.
295 | `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script.
296 | Set `--data` to the same argument that was used when training NVAE (our example is for MNIST).
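
As a reminder of what this estimate computes: with `K` importance samples `z_k ~ q(z|x)`, the estimator is `log p(x) ≈ logsumexp_k [log p(x, z_k) - log q(z_k|x)] - log K`. A minimal, illustrative sketch (the actual evaluation is implemented in `test()` in `train.py`):

```python
import math
import torch

def iw_log_likelihood(log_p_xz, log_q_zx):
    """Importance-weighted log-likelihood estimate.

    log_p_xz, log_q_zx: hypothetical tensors of shape (K, batch) holding
    log p(x, z_k) and log q(z_k | x) for K importance samples.
    """
    k = log_p_xz.size(0)
    log_w = log_p_xz - log_q_zx  # log importance weights
    return torch.logsumexp(log_w, dim=0) - math.log(k)
```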
297 |
298 |
299 |
300 | **Sampling**
301 |
302 | You can also use the following command to generate samples from a trained model:
303 |
304 | ```shell script
305 | cd $CODE_DIR
306 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --eval_mode=sample --temp=0.6 --readjust_bn
307 | ```
308 | where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics
309 | as described in the paper. If you remove `--readjust_bn`, sampling will proceed with BN layers in eval mode
310 | (i.e., BN layers will use the running means and variances collected during training).
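
For intuition, the temperature simply scales the standard deviation of every latent Normal distribution before sampling (see the `Normal` class in `distributions.py`); a toy illustration:

```python
import torch

def sample_with_temperature(mu, sigma, temp=0.6):
    # Lower temperature shrinks the standard deviation, so samples stay
    # closer to the mean of each latent distribution (toy illustration).
    eps = torch.randn_like(mu)
    return mu + temp * sigma * eps

mu, sigma = torch.zeros(4), torch.ones(4)
print(sample_with_temperature(mu, sigma, temp=0.6))
```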
311 |
312 |
313 |
314 | **Computing FID**
315 |
316 | You can compute the FID score using 50K samples. To do so, you will need to create
317 | a mean and covariance statistics file on the training data using a command like:
318 |
319 | ```shell script
320 | cd $CODE_DIR
321 | python scripts/precompute_fid_statistics.py --data $DATA_DIR/cifar10 --dataset cifar10 --fid_dir /tmp/fid-stats/
322 | ```
323 | The command above computes the reference statistics on the CIFAR-10 dataset and stores them in the `--fid_dir` directory.
324 | Given the reference statistics file, we can run the following command to compute the FID score:
325 |
326 | ```shell script
327 | cd $CODE_DIR
328 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --data $DATA_DIR/cifar10 --eval_mode=evaluate_fid --fid_dir /tmp/fid-stats/ --temp=0.6 --readjust_bn
329 | ```
330 | where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics
331 | as described in the paper. If you remove `--readjust_bn`, sampling will proceed with BN layers in eval mode
332 | (i.e., BN layers will use the running means and variances collected during training).
333 | Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script.
334 | Set `--data` to the same argument that was used when training NVAE (our example is for CIFAR-10).
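
For reference, FID compares the Gaussian statistics (mean and covariance) of Inception features extracted from the reference data and from generated samples. A minimal sketch of the Fréchet distance itself, assuming the feature statistics are already available (the repository's actual implementation lives in `fid/fid_score.py`):

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # matrix square root of the covariance product
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # add a small diagonal offset for numerical stability
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
```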
335 |
336 |
337 |
338 | **Checkpoints**
339 |
340 | We provide checkpoints on MNIST, CIFAR-10, CelebA 64, CelebA HQ 256, FFHQ in
341 | [this Google drive directory](https://drive.google.com/drive/folders/1KVpw12AzdVjvbfEYM_6_3sxTy93wWkbe?usp=sharing).
342 | For CIFAR-10, we provide two checkpoints because we observed that a multiscale NVAE provides better qualitative
343 | results than a single-scale model on this dataset. The multiscale model is only slightly worse in terms
344 | of log-likelihood (0.01 bpd). We also observe that one of our early models on CelebA HQ 256, with 0.01 bpd
345 | worse likelihood, generates much better images at low temperature on this dataset.
346 |
347 | You can use the commands above to evaluate or sample from these checkpoints.
348 |
349 |
350 |
351 | ## How to construct smaller NVAE models
352 | In the commands above, we are constructing big NVAE models that require several days of training
353 | in most cases. If you'd like to construct smaller NVAEs, you can use these tricks:
354 |
355 | * Reduce the network width: `--num_channels_enc` and `--num_channels_dec` control the number
356 | of initial channels in the bottom-up and top-down networks, respectively. Recall that we double the
357 | number of channels with every spatial downsampling layer in the bottom-up network, and we halve the number of
358 | channels with every upsampling layer in the top-down network. By reducing
359 | `--num_channels_enc` and `--num_channels_dec`, you can reduce the overall width of the networks.
360 |
361 | * Reduce the number of residual cells in the hierarchy: `--num_cell_per_cond_enc` and
362 | `--num_cell_per_cond_dec` control the number of residual cells used between every latent variable
363 | group in the bottom-up and top-down networks respectively. In most of our experiments, we are using
364 | two cells per group for both networks. You can reduce the number of residual cells to one to make the model
365 | smaller.
366 |
367 | * Reduce the number of epochs: You can reduce the training time by reducing `--epochs`.
368 |
369 | * Reduce the number of groups: You can make NVAE smaller by using a smaller number of latent variable groups.
370 | We use two schemes for setting the number of groups:
371 | 1. An equal number of groups: This is set by `--num_groups_per_scale`, which indicates the number of groups
372 | in each scale of latent variables. Reduce this number to have a smaller NVAE.
373 |
374 | 2. An adaptive number of groups: This is enabled by `--ada_groups`. In this case, the highest
375 | resolution of latent variables will have `--num_groups_per_scale` groups and
376 | the smaller scales will get half the number of groups successively (see `groups_per_scale` in `utils.py` and the sketch below).
377 | We don't let the number of groups go below `--min_groups_per_scale`. You can reduce
378 | the total number of groups by reducing `--num_groups_per_scale` and `--min_groups_per_scale`
379 | when `--ada_groups` is enabled.
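
The sketch below illustrates the adaptive schedule (a simplified illustration; the authoritative version is `groups_per_scale` in `utils.py`):

```python
def groups_per_scale(num_scales, num_groups_per_scale, ada_groups, min_groups_per_scale=1):
    # Simplified illustration: with --ada_groups, the largest scale keeps
    # num_groups_per_scale groups and every smaller scale gets half as many,
    # never dropping below min_groups_per_scale.
    groups = []
    g = num_groups_per_scale
    for _ in range(num_scales):
        groups.append(g)
        if ada_groups:
            g = max(g // 2, min_groups_per_scale)
    return groups

print(groups_per_scale(5, 16, ada_groups=True, min_groups_per_scale=4))  # [16, 8, 4, 4, 4]
```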
380 |
381 | ## Understanding the implementation
382 | If you are modifying the code, you can use the figure in `img/model_diagram.png` to map the code to the paper.
383 |
384 |
385 |
386 |
387 |
388 |
389 | ## Traversing the latent space
390 | We can generate images by traversing the latent space of NVAE. This sequence is generated using our model
391 | trained on CelebA HQ by interpolating between samples generated with temperature 0.6.
392 | Some artifacts are due to color quantization in GIFs.
393 |
394 |
395 |
396 |
397 |
398 | ## License
399 | Please check the LICENSE file. NVAE may be used non-commercially, meaning for research or
400 | evaluation purposes only. For business inquiries, please contact
401 | [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
402 |
403 | You should take into consideration that VAEs are trained to mimic the training data distribution, and any
404 | bias introduced in data collection will make VAEs generate samples with a similar bias. Additional bias could be
405 | introduced during model design, training, or when VAEs are sampled using small temperatures. Bias correction in
406 | generative learning is an active area of research, and we recommend that interested readers check this area before
407 | building applications using NVAE.
408 |
409 | ## Bibtex:
410 | Please cite our paper if you happen to use this codebase:
411 |
412 | ```
413 | @inproceedings{vahdat2020NVAE,
414 | title={{NVAE}: A Deep Hierarchical Variational Autoencoder},
415 | author={Vahdat, Arash and Kautz, Jan},
416 | booktitle={Neural Information Processing Systems (NeurIPS)},
417 | year={2020}
418 | }
419 | ```
420 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | """Code for getting the data loaders."""
9 |
10 | import numpy as np
11 | from PIL import Image
12 | import torch
13 | import torchvision.datasets as dset
14 | import torchvision.transforms as transforms
15 | from torch.utils.data import Dataset
16 | from scipy.io import loadmat
17 | import os
18 | import urllib
19 | from lmdb_datasets import LMDBDataset
20 | from thirdparty.lsun import LSUN
21 |
22 |
23 | class StackedMNIST(dset.MNIST):
24 | def __init__(self, root, train=True, transform=None, target_transform=None,
25 | download=False):
26 | super(StackedMNIST, self).__init__(root=root, train=train, transform=transform,
27 | target_transform=target_transform, download=download)
28 |
29 | index1 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
30 | index2 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
31 | index3 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))])
32 | self.num_images = 2 * len(self.data)
33 |
34 | self.index = []
35 | for i in range(self.num_images):
36 | self.index.append((index1[i], index2[i], index3[i]))
37 |
38 | def __len__(self):
39 | return self.num_images
40 |
41 | def __getitem__(self, index):
42 | img = np.zeros((28, 28, 3), dtype=np.uint8)
43 | target = 0
44 | for i in range(3):
45 | img_, target_ = self.data[self.index[index][i]], int(self.targets[self.index[index][i]])
46 | img[:, :, i] = img_
47 | target += target_ * 10 ** (2 - i)
48 |
49 | img = Image.fromarray(img, mode="RGB")
50 |
51 | if self.transform is not None:
52 | img = self.transform(img)
53 |
54 | if self.target_transform is not None:
55 | target = self.target_transform(target)
56 |
57 | return img, target
58 |
59 |
60 |
61 | class Binarize(object):
62 | """ This class introduces a binarization transformation
63 | """
64 | def __call__(self, pic):
65 | return torch.Tensor(pic.size()).bernoulli_(pic)
66 |
67 | def __repr__(self):
68 | return self.__class__.__name__ + '()'
69 |
70 |
71 | class CropCelebA64(object):
72 | """ This class applies cropping for CelebA64. This is a simplified implementation of:
73 | https://github.com/andersbll/autoencoding_beyond_pixels/blob/master/dataset/celeba.py
74 | """
75 | def __call__(self, pic):
76 | new_pic = pic.crop((15, 40, 178 - 15, 218 - 30))
77 | return new_pic
78 |
79 | def __repr__(self):
80 | return self.__class__.__name__ + '()'
81 |
82 |
83 | def get_loaders(args):
84 | """Get data loaders for required dataset."""
85 | return get_loaders_eval(args.dataset, args)
86 |
87 | def download_omniglot(data_dir):
88 | filename = 'chardata.mat'
89 | if not os.path.exists(data_dir):
90 | os.mkdir(data_dir)
91 | url = 'https://raw.github.com/yburda/iwae/master/datasets/OMNIGLOT/chardata.mat'
92 |
93 | filepath = os.path.join(data_dir, filename)
94 | if not os.path.exists(filepath):
95 | filepath, _ = urllib.request.urlretrieve(url, filepath)
96 | print('Downloaded', filename)
97 |
98 | return
99 |
100 |
101 | def load_omniglot(data_dir):
102 | download_omniglot(data_dir)
103 |
104 | data_path = os.path.join(data_dir, 'chardata.mat')
105 |
106 | omni = loadmat(data_path)
107 | train_data = 255 * omni['data'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0))
108 | test_data = 255 * omni['testdata'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0))
109 |
110 | train_data = train_data.astype('uint8')
111 | test_data = test_data.astype('uint8')
112 |
113 | return train_data, test_data
114 |
115 |
116 | class OMNIGLOT(Dataset):
117 | def __init__(self, data, transform):
118 | self.data = data
119 | self.transform = transform
120 |
121 | def __getitem__(self, index):
122 | d = self.data[index]
123 | img = Image.fromarray(d)
124 | return self.transform(img), 0 # return zero as label.
125 |
126 | def __len__(self):
127 | return len(self.data)
128 |
129 | def get_loaders_eval(dataset, args):
130 |     """Get train and valid loaders for the requested dataset."""
131 |
132 | if dataset == 'cifar10':
133 | num_classes = 10
134 | train_transform, valid_transform = _data_transforms_cifar10(args)
135 | train_data = dset.CIFAR10(
136 | root=args.data, train=True, download=True, transform=train_transform)
137 | valid_data = dset.CIFAR10(
138 | root=args.data, train=False, download=True, transform=valid_transform)
139 | elif dataset == 'mnist':
140 | num_classes = 10
141 | train_transform, valid_transform = _data_transforms_mnist(args)
142 | train_data = dset.MNIST(
143 | root=args.data, train=True, download=True, transform=train_transform)
144 | valid_data = dset.MNIST(
145 | root=args.data, train=False, download=True, transform=valid_transform)
146 | elif dataset == 'stacked_mnist':
147 | num_classes = 1000
148 | train_transform, valid_transform = _data_transforms_stacked_mnist(args)
149 | train_data = StackedMNIST(
150 | root=args.data, train=True, download=True, transform=train_transform)
151 | valid_data = StackedMNIST(
152 | root=args.data, train=False, download=True, transform=valid_transform)
153 | elif dataset == 'omniglot':
154 | num_classes = 0
155 | download_omniglot(args.data)
156 | train_transform, valid_transform = _data_transforms_mnist(args)
157 | train_data, valid_data = load_omniglot(args.data)
158 | train_data = OMNIGLOT(train_data, train_transform)
159 | valid_data = OMNIGLOT(valid_data, valid_transform)
160 | elif dataset.startswith('celeba'):
161 | if dataset == 'celeba_64':
162 | resize = 64
163 | num_classes = 40
164 | train_transform, valid_transform = _data_transforms_celeba64(resize)
165 | train_data = LMDBDataset(root=args.data, name='celeba64', train=True, transform=train_transform, is_encoded=True)
166 | valid_data = LMDBDataset(root=args.data, name='celeba64', train=False, transform=valid_transform, is_encoded=True)
167 | elif dataset in {'celeba_256'}:
168 | num_classes = 1
169 | resize = int(dataset.split('_')[1])
170 | train_transform, valid_transform = _data_transforms_generic(resize)
171 | train_data = LMDBDataset(root=args.data, name='celeba', train=True, transform=train_transform)
172 | valid_data = LMDBDataset(root=args.data, name='celeba', train=False, transform=valid_transform)
173 | else:
174 | raise NotImplementedError
175 | elif dataset.startswith('lsun'):
176 | if dataset.startswith('lsun_bedroom'):
177 | resize = int(dataset.split('_')[-1])
178 | num_classes = 1
179 | train_transform, valid_transform = _data_transforms_lsun(resize)
180 | train_data = LSUN(root=args.data, classes=['bedroom_train'], transform=train_transform)
181 | valid_data = LSUN(root=args.data, classes=['bedroom_val'], transform=valid_transform)
182 | elif dataset.startswith('lsun_church'):
183 | resize = int(dataset.split('_')[-1])
184 | num_classes = 1
185 | train_transform, valid_transform = _data_transforms_lsun(resize)
186 | train_data = LSUN(root=args.data, classes=['church_outdoor_train'], transform=train_transform)
187 | valid_data = LSUN(root=args.data, classes=['church_outdoor_val'], transform=valid_transform)
188 | elif dataset.startswith('lsun_tower'):
189 | resize = int(dataset.split('_')[-1])
190 | num_classes = 1
191 | train_transform, valid_transform = _data_transforms_lsun(resize)
192 | train_data = LSUN(root=args.data, classes=['tower_train'], transform=train_transform)
193 | valid_data = LSUN(root=args.data, classes=['tower_val'], transform=valid_transform)
194 | else:
195 | raise NotImplementedError
196 | elif dataset.startswith('imagenet'):
197 | num_classes = 1
198 | resize = int(dataset.split('_')[1])
199 | assert args.data.replace('/', '')[-3:] == dataset.replace('/', '')[-3:], 'the size should match'
200 | train_transform, valid_transform = _data_transforms_generic(resize)
201 | train_data = LMDBDataset(root=args.data, name='imagenet-oord', train=True, transform=train_transform)
202 | valid_data = LMDBDataset(root=args.data, name='imagenet-oord', train=False, transform=valid_transform)
203 | elif dataset.startswith('ffhq'):
204 | num_classes = 1
205 | resize = 256
206 | train_transform, valid_transform = _data_transforms_generic(resize)
207 | train_data = LMDBDataset(root=args.data, name='ffhq', train=True, transform=train_transform)
208 | valid_data = LMDBDataset(root=args.data, name='ffhq', train=False, transform=valid_transform)
209 | else:
210 | raise NotImplementedError
211 |
212 | train_sampler, valid_sampler = None, None
213 | if args.distributed:
214 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
215 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data)
216 |
217 | train_queue = torch.utils.data.DataLoader(
218 | train_data, batch_size=args.batch_size,
219 | shuffle=(train_sampler is None),
220 | sampler=train_sampler, pin_memory=True, num_workers=8, drop_last=True)
221 |
222 | valid_queue = torch.utils.data.DataLoader(
223 | valid_data, batch_size=args.batch_size,
224 | shuffle=(valid_sampler is None),
225 | sampler=valid_sampler, pin_memory=True, num_workers=1, drop_last=False)
226 |
227 | return train_queue, valid_queue, num_classes
228 |
229 |
230 | def _data_transforms_cifar10(args):
231 | """Get data transforms for cifar10."""
232 |
233 | train_transform = transforms.Compose([
234 | transforms.RandomHorizontalFlip(),
235 | transforms.ToTensor()
236 | ])
237 |
238 | valid_transform = transforms.Compose([
239 | transforms.ToTensor()
240 | ])
241 |
242 | return train_transform, valid_transform
243 |
244 |
245 | def _data_transforms_mnist(args):
246 |     """Get data transforms for MNIST."""
247 | train_transform = transforms.Compose([
248 | transforms.Pad(padding=2),
249 | transforms.ToTensor(),
250 | Binarize(),
251 | ])
252 |
253 | valid_transform = transforms.Compose([
254 | transforms.Pad(padding=2),
255 | transforms.ToTensor(),
256 | Binarize(),
257 | ])
258 |
259 | return train_transform, valid_transform
260 |
261 |
262 | def _data_transforms_stacked_mnist(args):
263 |     """Get data transforms for Stacked MNIST."""
264 | train_transform = transforms.Compose([
265 | transforms.Pad(padding=2),
266 | transforms.ToTensor()
267 | ])
268 |
269 | valid_transform = transforms.Compose([
270 | transforms.Pad(padding=2),
271 | transforms.ToTensor()
272 | ])
273 |
274 | return train_transform, valid_transform
275 |
276 |
277 | def _data_transforms_generic(size):
278 | train_transform = transforms.Compose([
279 | transforms.Resize(size),
280 | transforms.RandomHorizontalFlip(),
281 | transforms.ToTensor(),
282 | ])
283 |
284 | valid_transform = transforms.Compose([
285 | transforms.Resize(size),
286 | transforms.ToTensor(),
287 | ])
288 |
289 | return train_transform, valid_transform
290 |
291 |
292 | def _data_transforms_celeba64(size):
293 | train_transform = transforms.Compose([
294 | CropCelebA64(),
295 | transforms.Resize(size),
296 | transforms.RandomHorizontalFlip(),
297 | transforms.ToTensor(),
298 | ])
299 |
300 | valid_transform = transforms.Compose([
301 | CropCelebA64(),
302 | transforms.Resize(size),
303 | transforms.ToTensor(),
304 | ])
305 |
306 | return train_transform, valid_transform
307 |
308 |
309 | def _data_transforms_lsun(size):
310 | train_transform = transforms.Compose([
311 | transforms.Resize(size),
312 | transforms.RandomCrop(size),
313 | transforms.RandomHorizontalFlip(),
314 | transforms.ToTensor(),
315 | ])
316 |
317 | valid_transform = transforms.Compose([
318 | transforms.Resize(size),
319 | transforms.CenterCrop(size),
320 | transforms.ToTensor(),
321 | ])
322 |
323 | return train_transform, valid_transform
324 |
--------------------------------------------------------------------------------
/distributions.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
12 |
13 | from utils import one_hot
14 |
15 | @torch.jit.script
16 | def soft_clamp5(x: torch.Tensor):
17 | return x.div(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
18 |
19 |
20 | @torch.jit.script
21 | def sample_normal_jit(mu, sigma):
22 | eps = mu.mul(0).normal_()
23 | z = eps.mul_(sigma).add_(mu)
24 | return z, eps
25 |
26 |
27 | class Normal:
28 | def __init__(self, mu, log_sigma, temp=1.):
29 | self.mu = soft_clamp5(mu)
30 | log_sigma = soft_clamp5(log_sigma)
31 | self.sigma = torch.exp(log_sigma) + 1e-2 # we don't need this after soft clamp
32 | if temp != 1.:
33 | self.sigma *= temp
34 |
35 | def sample(self):
36 | return sample_normal_jit(self.mu, self.sigma)
37 |
38 | def sample_given_eps(self, eps):
39 | return eps * self.sigma + self.mu
40 |
41 | def log_p(self, samples):
42 | normalized_samples = (samples - self.mu) / self.sigma
43 | log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - torch.log(self.sigma)
44 | return log_p
45 |
46 | def kl(self, normal_dist):
47 | term1 = (self.mu - normal_dist.mu) / normal_dist.sigma
48 | term2 = self.sigma / normal_dist.sigma
49 |
50 | return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)
51 |
52 |
53 | class NormalDecoder:
54 | def __init__(self, param, num_bits=8):
55 | B, C, H, W = param.size()
56 | self.num_c = C // 2
57 | mu = param[:, :self.num_c, :, :] # B, 3, H, W
58 | log_sigma = param[:, self.num_c:, :, :] # B, 3, H, W
59 | self.dist = Normal(mu, log_sigma)
60 |
61 | def log_prob(self, samples):
62 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0
63 | # convert samples to be in [-1, 1]
64 | samples = 2 * samples - 1.0
65 |
66 | return self.dist.log_p(samples)
67 |
68 | def sample(self, t=1.):
69 | x, _ = self.dist.sample()
70 | x = torch.clamp(x, -1, 1.)
71 | x = x / 2. + 0.5
72 | return x
73 |
74 |
75 | class DiscLogistic:
76 | def __init__(self, param):
77 | B, C, H, W = param.size()
78 | self.num_c = C // 2
79 | self.means = param[:, :self.num_c, :, :] # B, 3, H, W
80 | self.log_scales = torch.clamp(param[:, self.num_c:, :, :], min=-8.0) # B, 3, H, W
81 |
82 | def log_prob(self, samples):
83 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0
84 | # convert samples to be in [-1, 1]
85 | samples = 2 * samples - 1.0
86 |
87 | B, C, H, W = samples.size()
88 | assert C == self.num_c
89 |
90 | centered = samples - self.means # B, 3, H, W
91 | inv_stdv = torch.exp(- self.log_scales)
92 | plus_in = inv_stdv * (centered + 1. / 255.)
93 | cdf_plus = torch.sigmoid(plus_in)
94 | min_in = inv_stdv * (centered - 1. / 255.)
95 | cdf_min = torch.sigmoid(min_in)
96 | log_cdf_plus = plus_in - F.softplus(plus_in)
97 | log_one_minus_cdf_min = - F.softplus(min_in)
98 | cdf_delta = cdf_plus - cdf_min
99 | mid_in = inv_stdv * centered
100 | log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in)
101 |
102 | log_prob_mid_safe = torch.where(cdf_delta > 1e-5,
103 | torch.log(torch.clamp(cdf_delta, min=1e-10)),
104 | log_pdf_mid - np.log(127.5))
105 |         # note: the original implementation uses samples > 0.999; this ignores the largest possible pixel value (255)
106 | # which is mapped to 0.9922
107 | log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.99, log_one_minus_cdf_min,
108 | log_prob_mid_safe)) # B, 3, H, W
109 |
110 | return log_probs
111 |
112 | def sample(self):
113 | u = torch.Tensor(self.means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W
114 | x = self.means + torch.exp(self.log_scales) * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W
115 | x = torch.clamp(x, -1, 1.)
116 | x = x / 2. + 0.5
117 | return x
118 |
119 |
120 | class DiscMixLogistic:
121 | def __init__(self, param, num_mix=10, num_bits=8):
122 | B, C, H, W = param.size()
123 | self.num_mix = num_mix
124 | self.logit_probs = param[:, :num_mix, :, :] # B, M, H, W
125 | l = param[:, num_mix:, :, :].view(B, 3, 3 * num_mix, H, W) # B, 3, 3 * M, H, W
126 | self.means = l[:, :, :num_mix, :, :] # B, 3, M, H, W
127 | self.log_scales = torch.clamp(l[:, :, num_mix:2 * num_mix, :, :], min=-7.0) # B, 3, M, H, W
128 | self.coeffs = torch.tanh(l[:, :, 2 * num_mix:3 * num_mix, :, :]) # B, 3, M, H, W
129 | self.max_val = 2. ** num_bits - 1
130 |
131 | def log_prob(self, samples):
132 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0
133 | # convert samples to be in [-1, 1]
134 | samples = 2 * samples - 1.0
135 |
136 | B, C, H, W = samples.size()
137 | assert C == 3, 'only RGB images are considered.'
138 |
139 |         samples = samples.unsqueeze(4)                                  # B, 3, H, W, 1
140 | samples = samples.expand(-1, -1, -1, -1, self.num_mix).permute(0, 1, 4, 2, 3) # B, 3, M, H, W
141 | mean1 = self.means[:, 0, :, :, :] # B, M, H, W
142 | mean2 = self.means[:, 1, :, :, :] + \
143 | self.coeffs[:, 0, :, :, :] * samples[:, 0, :, :, :] # B, M, H, W
144 | mean3 = self.means[:, 2, :, :, :] + \
145 | self.coeffs[:, 1, :, :, :] * samples[:, 0, :, :, :] + \
146 | self.coeffs[:, 2, :, :, :] * samples[:, 1, :, :, :] # B, M, H, W
147 |
148 | mean1 = mean1.unsqueeze(1) # B, 1, M, H, W
149 | mean2 = mean2.unsqueeze(1) # B, 1, M, H, W
150 | mean3 = mean3.unsqueeze(1) # B, 1, M, H, W
151 | means = torch.cat([mean1, mean2, mean3], dim=1) # B, 3, M, H, W
152 | centered = samples - means # B, 3, M, H, W
153 |
154 | inv_stdv = torch.exp(- self.log_scales)
155 | plus_in = inv_stdv * (centered + 1. / self.max_val)
156 | cdf_plus = torch.sigmoid(plus_in)
157 | min_in = inv_stdv * (centered - 1. / self.max_val)
158 | cdf_min = torch.sigmoid(min_in)
159 | log_cdf_plus = plus_in - F.softplus(plus_in)
160 | log_one_minus_cdf_min = - F.softplus(min_in)
161 | cdf_delta = cdf_plus - cdf_min
162 | mid_in = inv_stdv * centered
163 | log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in)
164 |
165 | log_prob_mid_safe = torch.where(cdf_delta > 1e-5,
166 | torch.log(torch.clamp(cdf_delta, min=1e-10)),
167 | log_pdf_mid - np.log(self.max_val / 2))
168 | # the original implementation uses samples > 0.999, this ignores the largest possible pixel value (255)
169 | # which is mapped to 0.9922
170 | log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.99, log_one_minus_cdf_min,
171 | log_prob_mid_safe)) # B, 3, M, H, W
172 |
173 | log_probs = torch.sum(log_probs, 1) + F.log_softmax(self.logit_probs, dim=1) # B, M, H, W
174 | return torch.logsumexp(log_probs, dim=1) # B, H, W
175 |
176 | def sample(self, t=1.):
177 | gumbel = -torch.log(- torch.log(torch.Tensor(self.logit_probs.size()).uniform_(1e-5, 1. - 1e-5).cuda())) # B, M, H, W
178 | sel = one_hot(torch.argmax(self.logit_probs / t + gumbel, 1), self.num_mix, dim=1) # B, M, H, W
179 | sel = sel.unsqueeze(1) # B, 1, M, H, W
180 |
181 | # select logistic parameters
182 | means = torch.sum(self.means * sel, dim=2) # B, 3, H, W
183 | log_scales = torch.sum(self.log_scales * sel, dim=2) # B, 3, H, W
184 | coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W
185 |
186 |         # sample from logistic & clip to interval
187 | # we don't actually round to the nearest 8bit value when sampling
188 | u = torch.Tensor(means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W
189 | x = means + torch.exp(log_scales) / t * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W
190 |
191 | x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W
192 | x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W
193 | x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W
194 |
195 | x0 = x0.unsqueeze(1)
196 | x1 = x1.unsqueeze(1)
197 | x2 = x2.unsqueeze(1)
198 |
199 | x = torch.cat([x0, x1, x2], 1)
200 | x = x / 2. + 0.5
201 | return x
202 |
203 | def mean(self):
204 | sel = torch.softmax(self.logit_probs, dim=1) # B, M, H, W
205 | sel = sel.unsqueeze(1) # B, 1, M, H, W
206 |
207 | # select logistic parameters
208 | means = torch.sum(self.means * sel, dim=2) # B, 3, H, W
209 | coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W
210 |
211 | # we don't sample from logistic components, because of the linear dependencies, we use mean
212 | x = means # B, 3, H, W
213 | x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W
214 | x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W
215 | x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W
216 |
217 | x0 = x0.unsqueeze(1)
218 | x1 = x1.unsqueeze(1)
219 | x2 = x2.unsqueeze(1)
220 |
221 | x = torch.cat([x0, x1, x2], 1)
222 | x = x / 2. + 0.5
223 | return x
224 |
225 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import argparse
9 | import torch
10 | import numpy as np
11 | import os
12 | import matplotlib.pyplot as plt
13 | from time import time
14 |
15 | from torch.multiprocessing import Process
16 | from torch.cuda.amp import autocast
17 |
18 | from model import AutoEncoder
19 | import utils
20 | import datasets
21 | from train import test, init_processes, test_vae_fid
22 |
23 |
24 | def set_bn(model, bn_eval_mode, num_samples=1, t=1.0, iter=100):
25 | if bn_eval_mode:
26 | model.eval()
27 | else:
28 | model.train()
29 | with autocast():
30 | for i in range(iter):
31 | if i % 10 == 0:
32 | print('setting BN statistics iter %d out of %d' % (i+1, iter))
33 | model.sample(num_samples, t)
34 | model.eval()
35 |
36 |
37 | def main(eval_args):
38 | # ensures that weight initializations are all the same
39 | logging = utils.Logger(eval_args.local_rank, eval_args.save)
40 |
41 | # load a checkpoint
42 | logging.info('loading the model at:')
43 | logging.info(eval_args.checkpoint)
44 | checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
45 | args = checkpoint['args']
46 |
47 | if not hasattr(args, 'ada_groups'):
48 | logging.info('old model, no ada groups was found.')
49 | args.ada_groups = False
50 |
51 | if not hasattr(args, 'min_groups_per_scale'):
52 | logging.info('old model, no min_groups_per_scale was found.')
53 | args.min_groups_per_scale = 1
54 |
55 | if not hasattr(args, 'num_mixture_dec'):
56 | logging.info('old model, no num_mixture_dec was found.')
57 | args.num_mixture_dec = 10
58 |
59 | if eval_args.batch_size > 0:
60 | args.batch_size = eval_args.batch_size
61 |
62 | logging.info('loaded the model at epoch %d', checkpoint['epoch'])
63 | arch_instance = utils.get_arch_cells(args.arch_instance)
64 | model = AutoEncoder(args, None, arch_instance)
65 | # Loading is not strict because of self.weight_normalized in Conv2D class in neural_operations. This variable
66 | # is only used for computing the spectral normalization and it is safe not to load it. Some of our earlier models
67 | # did not have this variable.
68 | model.load_state_dict(checkpoint['state_dict'], strict=False)
69 | model = model.cuda()
70 |
71 | logging.info('args = %s', args)
72 | logging.info('num conv layers: %d', len(model.all_conv_layers))
73 | logging.info('param size = %fM ', utils.count_parameters_in_M(model))
74 |
75 | if eval_args.eval_mode == 'evaluate':
76 | # load train valid queue
77 | args.data = eval_args.data
78 | train_queue, valid_queue, num_classes = datasets.get_loaders(args)
79 |
80 | if eval_args.eval_on_train:
81 | logging.info('Using the training data for eval.')
82 | valid_queue = train_queue
83 |
84 | # get number of bits
85 | num_output = utils.num_output(args.dataset)
86 | bpd_coeff = 1. / np.log(2.) / num_output
87 |
88 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=eval_args.num_iw_samples, args=args, logging=logging)
89 | logging.info('final valid nelbo %f', valid_nelbo)
90 | logging.info('final valid neg log p %f', valid_neg_log_p)
91 | logging.info('final valid nelbo in bpd %f', valid_nelbo * bpd_coeff)
92 | logging.info('final valid neg log p in bpd %f', valid_neg_log_p * bpd_coeff)
93 | elif eval_args.eval_mode == 'evaluate_fid':
94 | bn_eval_mode = not eval_args.readjust_bn
95 | set_bn(model, bn_eval_mode, num_samples=2, t=eval_args.temp, iter=500)
96 | args.fid_dir = eval_args.fid_dir
97 |         args.num_process_per_node, args.num_proc_node = eval_args.world_size, 1   # evaluate on only 1 node
98 | fid = test_vae_fid(model, args, total_fid_samples=50000)
99 | logging.info('fid is %f' % fid)
100 | else:
101 | bn_eval_mode = not eval_args.readjust_bn
102 | total_samples = 50000 // eval_args.world_size # num images per gpu
103 | num_samples = 100 # sampling batch size
104 | num_iter = int(np.ceil(total_samples / num_samples)) # num iterations per gpu
105 |
106 | with torch.no_grad():
107 | n = int(np.floor(np.sqrt(num_samples)))
108 | set_bn(model, bn_eval_mode, num_samples=16, t=eval_args.temp, iter=500)
109 | for ind in range(num_iter): # sampling is repeated.
110 | torch.cuda.synchronize()
111 | start = time()
112 | with autocast():
113 | logits = model.sample(num_samples, eval_args.temp)
114 | output = model.decoder_output(logits)
115 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
116 | else output.sample()
117 | torch.cuda.synchronize()
118 | end = time()
119 | logging.info('sampling time per batch: %0.3f sec', (end - start))
120 |
121 | visualize = False
122 | if visualize:
123 | output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0)
124 | output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8)
125 | output_tiled = np.squeeze(output_tiled)
126 |
127 | plt.imshow(output_tiled)
128 | plt.show()
129 | else:
130 | file_path = os.path.join(eval_args.save, 'gpu_%d_samples_%d.npz' % (eval_args.local_rank, ind))
131 | np.savez_compressed(file_path, samples=output_img.cpu().numpy())
132 | logging.info('Saved at: {}'.format(file_path))
133 |
134 |
135 | if __name__ == '__main__':
136 | parser = argparse.ArgumentParser('encoder decoder examiner')
137 | # experimental results
138 | parser.add_argument('--checkpoint', type=str, default='/tmp/expr/checkpoint.pt',
139 | help='location of the checkpoint')
140 | parser.add_argument('--save', type=str, default='/tmp/expr',
141 |                         help='directory where evaluation results are saved')
142 | parser.add_argument('--eval_mode', type=str, default='sample', choices=['sample', 'evaluate', 'evaluate_fid'],
143 |                         help='evaluation mode. you can choose between sample, evaluate, or evaluate_fid.')
144 | parser.add_argument('--eval_on_train', action='store_true', default=False,
145 |                         help='Setting this to true will evaluate the model on training data.')
146 | parser.add_argument('--data', type=str, default='/tmp/data',
147 | help='location of the data corpus')
148 | parser.add_argument('--readjust_bn', action='store_true', default=False,
149 | help='adding this flag will enable readjusting BN statistics.')
150 | parser.add_argument('--temp', type=float, default=0.7,
151 | help='The temperature used for sampling.')
152 | parser.add_argument('--num_iw_samples', type=int, default=1000,
153 | help='The number of IW samples used in test_ll mode.')
154 | parser.add_argument('--fid_dir', type=str, default='/tmp/fid-stats',
155 | help='path to directory where fid related files are stored')
156 | parser.add_argument('--batch_size', type=int, default=0,
157 | help='Batch size used during evaluation. If set to zero, training batch size is used.')
158 | # DDP.
159 | parser.add_argument('--local_rank', type=int, default=0,
160 | help='rank of process')
161 | parser.add_argument('--world_size', type=int, default=1,
162 | help='number of gpus')
163 | parser.add_argument('--seed', type=int, default=1,
164 | help='seed used for initialization')
165 | parser.add_argument('--master_address', type=str, default='127.0.0.1',
166 | help='address for master')
167 |
168 | args = parser.parse_args()
169 | utils.create_exp_dir(args.save)
170 |
171 | size = args.world_size
172 |
173 | if size > 1:
174 | args.distributed = True
175 | processes = []
176 | for rank in range(size):
177 | args.local_rank = rank
178 | p = Process(target=init_processes, args=(rank, size, main, args))
179 | p.start()
180 | processes.append(p)
181 |
182 | for p in processes:
183 | p.join()
184 | else:
185 | # for debugging
186 | print('starting in debug mode')
187 | args.distributed = True
188 | init_processes(0, size, main, args)
--------------------------------------------------------------------------------
/fid/LICENSE_pytorch_fid:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/fid/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/fid/__init__.py
--------------------------------------------------------------------------------
/fid/fid_score.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the pytorch_fid library
5 | # which was released under the Apache License v2.0 License.
6 | #
7 | # Source:
8 | # https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/fid_score.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in this directory (LICENSE_pytorch_fid). The modifications
12 | # to this file are subject to the same Apache License.
13 | # ---------------------------------------------------------------
14 |
15 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
16 |
17 | The FID metric calculates the distance between two distributions of images.
18 | Typically, we have summary statistics (mean & covariance matrix) of one
19 | of these distributions, while the 2nd distribution is given by a GAN.
20 |
21 | When run as a stand-alone program, it compares the distribution of
22 | images that are stored as PNG/JPEG at a specified location with a
23 | distribution given by summary statistics (in pickle format).
24 |
25 | The FID is calculated by assuming that X_1 and X_2 are the activations of
26 | the pool_3 layer of the inception net for generated samples and real world
27 | samples respectively.
28 |
29 | See --help to see further details.
30 |
31 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
32 | of TensorFlow
33 |
34 | Copyright 2018 Institute of Bioinformatics, JKU Linz
35 |
36 | Licensed under the Apache License, Version 2.0 (the "License");
37 | you may not use this file except in compliance with the License.
38 | You may obtain a copy of the License at
39 |
40 | http://www.apache.org/licenses/LICENSE-2.0
41 |
42 | Unless required by applicable law or agreed to in writing, software
43 | distributed under the License is distributed on an "AS IS" BASIS,
44 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45 | See the License for the specific language governing permissions and
46 | limitations under the License.
47 | """
48 | import os
49 | import pathlib
50 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
51 | from multiprocessing import cpu_count
52 |
53 | import numpy as np
54 | import torch
55 | from torch.utils.data import DataLoader
56 | import torchvision.transforms as TF
57 | from scipy import linalg
58 | from torch.nn.functional import adaptive_avg_pool2d
59 | from PIL import Image
60 |
61 | from fid.inception import InceptionV3
62 |
63 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
64 | parser.add_argument('--batch-size', type=int, default=50,
65 | help='Batch size to use')
66 | parser.add_argument('--device', type=str, default=None,
67 | help='Device to use. Like cuda, cuda:0 or cpu')
68 | parser.add_argument('--dims', type=int, default=2048,
69 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
70 | help=('Dimensionality of Inception features to use. '
71 | 'By default, uses pool3 features'))
72 | parser.add_argument('path', type=str, nargs=2,
73 | help=('Paths to the generated images or '
74 | 'to .npz statistic files'))
75 |
76 |
77 | class ImagesPathDataset(torch.utils.data.Dataset):
78 | def __init__(self, files, transforms=None):
79 | self.files = files
80 | self.transforms = transforms
81 |
82 | def __len__(self):
83 | return len(self.files)
84 |
85 | def __getitem__(self, i):
86 | path = self.files[i]
87 | img = Image.open(path).convert('RGB')
88 | if self.transforms is not None:
89 | img = self.transforms(img)
90 | return img
91 |
92 |
93 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', max_samples=None):
94 | """Calculates the activations of the pool_3 layer for all images.
95 |
96 | Params:
97 | -- files : List of image files paths or pytorch data loader
98 | -- model : Instance of inception model
99 | -- batch_size : Batch size of images for the model to process at once.
100 | Make sure that the number of samples is a multiple of
101 | the batch size, otherwise some samples are ignored. This
102 | behavior is retained to match the original FID score
103 | implementation.
104 | -- dims : Dimensionality of features returned by Inception
105 | -- device : Device to run calculations
106 |     -- max_samples : If set, activation computation stops once max_samples images have been processed
107 |
108 | Returns:
109 | -- A numpy array of dimension (num images, dims) that contains the
110 | activations of the given tensor when feeding inception with the
111 | query tensor.
112 | """
113 | model.eval()
114 |
115 | if isinstance(files, list):
116 | if batch_size > len(files):
117 | print(('Warning: batch size is bigger than the data size. '
118 | 'Setting batch size to data size'))
119 | batch_size = len(files)
120 |
121 | ds = ImagesPathDataset(files, transforms=TF.ToTensor())
122 | dl = DataLoader(ds, batch_size=batch_size,
123 | drop_last=False, num_workers=cpu_count())
124 | else:
125 | dl = files
126 |
127 | pred_arr = []
128 | total_processed = 0
129 |
130 | print('Starting to sample.')
131 | for batch in dl:
132 | # ignore labels
133 | if isinstance(batch, list):
134 | batch = batch[0]
135 |
136 | batch = batch.to(device)
137 | if batch.shape[1] == 1: # if image is gray scale
138 | batch = batch.repeat(1, 3, 1, 1)
139 |
140 | with torch.no_grad():
141 | pred = model(batch)[0]
142 |
143 | # If model output is not scalar, apply global spatial average pooling.
144 | # This happens if you choose a dimensionality not equal 2048.
145 | if pred.size(2) != 1 or pred.size(3) != 1:
146 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
147 |
148 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
149 | pred_arr.append(pred)
150 | total_processed += pred.shape[0]
151 | if max_samples is not None and total_processed > max_samples:
152 | print('Max Samples Reached.')
153 | break
154 |
155 | pred_arr = np.concatenate(pred_arr, axis=0)
156 | if max_samples is not None:
157 | pred_arr = pred_arr[:max_samples]
158 |
159 | return pred_arr
160 |
161 |
162 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
163 | """Numpy implementation of the Frechet Distance.
164 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
165 | and X_2 ~ N(mu_2, C_2) is
166 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
167 |
168 | Stable version by Dougal J. Sutherland.
169 |
170 | Params:
171 | -- mu1 : Numpy array containing the activations of a layer of the
172 | inception net (like returned by the function 'get_predictions')
173 | for generated samples.
174 |     -- mu2   : The sample mean over activations, precalculated on a
175 | representative data set.
176 | -- sigma1: The covariance matrix over activations for generated samples.
177 |     -- sigma2: The covariance matrix over activations, precalculated on a
178 | representative data set.
179 |
180 | Returns:
181 | -- : The Frechet Distance.
182 | """
183 |
184 | mu1 = np.atleast_1d(mu1)
185 | mu2 = np.atleast_1d(mu2)
186 |
187 | sigma1 = np.atleast_2d(sigma1)
188 | sigma2 = np.atleast_2d(sigma2)
189 |
190 | assert mu1.shape == mu2.shape, \
191 | 'Training and test mean vectors have different lengths'
192 | assert sigma1.shape == sigma2.shape, \
193 | 'Training and test covariances have different dimensions'
194 |
195 | diff = mu1 - mu2
196 |
197 | # Product might be almost singular
198 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
199 | if not np.isfinite(covmean).all():
200 | msg = ('fid calculation produces singular product; '
201 | 'adding %s to diagonal of cov estimates') % eps
202 | print(msg)
203 | offset = np.eye(sigma1.shape[0]) * eps
204 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
205 |
206 | # Numerical error might give slight imaginary component
207 | if np.iscomplexobj(covmean):
208 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
209 | m = np.max(np.abs(covmean.imag))
210 | raise ValueError('Imaginary component {}'.format(m))
211 | covmean = covmean.real
212 |
213 | tr_covmean = np.trace(covmean)
214 |
215 | return (diff.dot(diff) + np.trace(sigma1) +
216 | np.trace(sigma2) - 2 * tr_covmean)
217 |
218 |
219 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device='cpu', max_samples=None):
220 | """Calculation of the statistics used by the FID.
221 | Params:
222 | -- files : List of image files paths or pytorch data loader
223 | -- model : Instance of inception model
224 | -- batch_size : The images numpy array is split into batches with
225 | batch size batch_size. A reasonable batch size
226 | depends on the hardware.
227 | -- dims : Dimensionality of features returned by Inception
228 | -- device : Device to run calculations
229 |     -- max_samples : If set, activation computation stops once max_samples images have been processed
230 |
231 | Returns:
232 | -- mu : The mean over samples of the activations of the pool_3 layer of
233 | the inception model.
234 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
235 | the inception model.
236 | """
237 | act = get_activations(files, model, batch_size, dims, device, max_samples)
238 | mu = np.mean(act, axis=0)
239 | sigma = np.cov(act, rowvar=False)
240 | return mu, sigma
241 |
242 |
243 | def _compute_statistics_of_path(path, model, batch_size, dims, device):
244 | if path.endswith('.npz'):
245 | f = np.load(path)
246 | m, s = f['mu'][:], f['sigma'][:]
247 | f.close()
248 | else:
249 | path = pathlib.Path(path)
250 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
251 | m, s = calculate_activation_statistics(files, model, batch_size,
252 | dims, device)
253 |
254 | return m, s
255 |
256 |
257 | def compute_statistics_of_generator(data_loader, model, batch_size, dims, device, max_samples=None):
258 | m, s = calculate_activation_statistics(data_loader, model, batch_size,
259 | dims, device, max_samples)
260 |
261 | return m, s
262 |
263 |
264 | def save_statistics(path, m, s):
265 | assert path.endswith('.npz')
266 | np.savez(path, mu=m, sigma=s)
267 |
268 |
269 | def load_statistics(path):
270 | assert path.endswith('.npz')
271 | f = np.load(path)
272 | m, s = f['mu'][:], f['sigma'][:]
273 | f.close()
274 | return m, s
275 |
276 |
277 | def calculate_fid_given_paths(paths, batch_size, device, dims):
278 | """Calculates the FID of two paths"""
279 | for p in paths:
280 | if not os.path.exists(p):
281 | raise RuntimeError('Invalid path: %s' % p)
282 |
283 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
284 |
285 | model = InceptionV3([block_idx]).to(device)
286 |
287 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
288 | dims, device)
289 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
290 | dims, device)
291 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
292 |
293 | return fid_value
294 |
295 |
296 | def main():
297 | args = parser.parse_args()
298 |
299 | if args.device is None:
300 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
301 | else:
302 | device = torch.device(args.device)
303 |
304 | fid_value = calculate_fid_given_paths(args.path,
305 | args.batch_size,
306 | device,
307 | args.dims)
308 | print('FID: ', fid_value)
309 |
310 |
311 | if __name__ == '__main__':
312 | main()
--------------------------------------------------------------------------------
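
For orientation, the functions in fid_score.py compose into an FID evaluation roughly as follows. This is a sketch: data_loader is assumed to exist and to yield image batches with values in [0, 1], and the reference-statistics path is a placeholder (such files can be produced with scripts/precompute_fid_statistics.py).

    import torch
    from fid.inception import InceptionV3
    from fid.fid_score import (compute_statistics_of_generator, load_statistics,
                               calculate_frechet_distance)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dims = 2048  # pool3 features
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)

    # data_loader: any DataLoader yielding image batches in [0, 1] (assumed to exist).
    m_gen, s_gen = compute_statistics_of_generator(data_loader, model, batch_size=50,
                                                   dims=dims, device=device,
                                                   max_samples=10000)

    # Precomputed statistics of the reference set (placeholder file name).
    m_ref, s_ref = load_statistics('/tmp/fid-stats/reference_stats.npz')
    print('FID:', calculate_frechet_distance(m_gen, s_gen, m_ref, s_ref))
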
/fid/inception.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the pytorch_fid library
5 | # which was released under the Apache License v2.0 License.
6 | #
7 | # Source:
8 | # https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in this directory (LICENSE_pytorch_fid). The modifications
12 | # to this file are subject to the same Apache License.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torchvision
19 |
20 | try:
21 | from torchvision.models.utils import load_state_dict_from_url
22 | except ImportError:
23 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
24 |
25 | # Inception weights ported to Pytorch from
26 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
27 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
28 |
29 |
30 | class InceptionV3(nn.Module):
31 | """Pretrained InceptionV3 network returning feature maps"""
32 |
33 | # Index of default block of inception to return,
34 | # corresponds to output of final average pooling
35 | DEFAULT_BLOCK_INDEX = 3
36 |
37 | # Maps feature dimensionality to their output blocks indices
38 | BLOCK_INDEX_BY_DIM = {
39 | 64: 0, # First max pooling features
40 |         192: 1,   # Second max pooling features
41 | 768: 2, # Pre-aux classifier features
42 | 2048: 3 # Final average pooling features
43 | }
44 |
45 | def __init__(self,
46 | output_blocks=[DEFAULT_BLOCK_INDEX],
47 | resize_input=True,
48 | normalize_input=True,
49 | requires_grad=False,
50 | use_fid_inception=True,
51 | model_dir=None):
52 | """Build pretrained InceptionV3
53 |
54 | Parameters
55 | ----------
56 | output_blocks : list of int
57 | Indices of blocks to return features of. Possible values are:
58 | - 0: corresponds to output of first max pooling
59 | - 1: corresponds to output of second max pooling
60 | - 2: corresponds to output which is fed to aux classifier
61 | - 3: corresponds to output of final average pooling
62 | resize_input : bool
63 | If true, bilinearly resizes input to width and height 299 before
64 | feeding input to model. As the network without fully connected
65 | layers is fully convolutional, it should be able to handle inputs
66 | of arbitrary size, so resizing might not be strictly needed
67 | normalize_input : bool
68 | If true, scales the input from range (0, 1) to the range the
69 | pretrained Inception network expects, namely (-1, 1)
70 | requires_grad : bool
71 | If true, parameters of the model require gradients. Possibly useful
72 | for finetuning the network
73 | use_fid_inception : bool
74 | If true, uses the pretrained Inception model used in Tensorflow's
75 | FID implementation. If false, uses the pretrained Inception model
76 | available in torchvision. The FID Inception model has different
77 | weights and a slightly different structure from torchvision's
78 | Inception model. If you want to compute FID scores, you are
79 | strongly advised to set this parameter to true to get comparable
80 | results.
81 |         model_dir : directory used for storing pretrained checkpoints
82 | """
83 | super(InceptionV3, self).__init__()
84 |
85 | self.resize_input = resize_input
86 | self.normalize_input = normalize_input
87 | self.output_blocks = sorted(output_blocks)
88 | self.last_needed_block = max(output_blocks)
89 |
90 | assert self.last_needed_block <= 3, \
91 | 'Last possible output block index is 3'
92 |
93 | self.blocks = nn.ModuleList()
94 |
95 | if use_fid_inception:
96 | inception = fid_inception_v3(model_dir)
97 | else:
98 | inception = _inception_v3(pretrained=True)
99 |
100 | # Block 0: input to maxpool1
101 | block0 = [
102 | inception.Conv2d_1a_3x3,
103 | inception.Conv2d_2a_3x3,
104 | inception.Conv2d_2b_3x3,
105 | nn.MaxPool2d(kernel_size=3, stride=2)
106 | ]
107 | self.blocks.append(nn.Sequential(*block0))
108 |
109 | # Block 1: maxpool1 to maxpool2
110 | if self.last_needed_block >= 1:
111 | block1 = [
112 | inception.Conv2d_3b_1x1,
113 | inception.Conv2d_4a_3x3,
114 | nn.MaxPool2d(kernel_size=3, stride=2)
115 | ]
116 | self.blocks.append(nn.Sequential(*block1))
117 |
118 | # Block 2: maxpool2 to aux classifier
119 | if self.last_needed_block >= 2:
120 | block2 = [
121 | inception.Mixed_5b,
122 | inception.Mixed_5c,
123 | inception.Mixed_5d,
124 | inception.Mixed_6a,
125 | inception.Mixed_6b,
126 | inception.Mixed_6c,
127 | inception.Mixed_6d,
128 | inception.Mixed_6e,
129 | ]
130 | self.blocks.append(nn.Sequential(*block2))
131 |
132 | # Block 3: aux classifier to final avgpool
133 | if self.last_needed_block >= 3:
134 | block3 = [
135 | inception.Mixed_7a,
136 | inception.Mixed_7b,
137 | inception.Mixed_7c,
138 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
139 | ]
140 | self.blocks.append(nn.Sequential(*block3))
141 |
142 | for param in self.parameters():
143 | param.requires_grad = requires_grad
144 |
145 | def forward(self, inp):
146 | """Get Inception feature maps
147 |
148 | Parameters
149 | ----------
150 | inp : torch.autograd.Variable
151 | Input tensor of shape Bx3xHxW. Values are expected to be in
152 | range (0, 1)
153 |
154 | Returns
155 | -------
156 | List of torch.autograd.Variable, corresponding to the selected output
157 | block, sorted ascending by index
158 | """
159 | outp = []
160 | x = inp
161 |
162 | if self.resize_input:
163 | x = F.interpolate(x,
164 | size=(299, 299),
165 | mode='bilinear',
166 | align_corners=False)
167 |
168 | if self.normalize_input:
169 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
170 |
171 | for idx, block in enumerate(self.blocks):
172 | x = block(x)
173 | if idx in self.output_blocks:
174 | outp.append(x)
175 |
176 | if idx == self.last_needed_block:
177 | break
178 |
179 | return outp
180 |
181 |
182 | def _inception_v3(*args, **kwargs):
183 | """Wraps `torchvision.models.inception_v3`
184 |
185 |     Skips default weight initialization if supported by torchvision version.
186 | See https://github.com/mseitzer/pytorch-fid/issues/28.
187 | """
188 | try:
189 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
190 | except ValueError:
191 | # Just a caution against weird version strings
192 | version = (0,)
193 |
194 | if version >= (0, 6):
195 | kwargs['init_weights'] = False
196 |
197 | return torchvision.models.inception_v3(*args, **kwargs)
198 |
199 |
200 | def fid_inception_v3(model_dir=None):
201 | """Build pretrained Inception model for FID computation
202 |
203 | The Inception model for FID computation uses a different set of weights
204 | and has a slightly different structure than torchvision's Inception.
205 |
206 | This method first constructs torchvision's Inception and then patches the
207 | necessary parts that are different in the FID Inception model.
208 | """
209 | inception = _inception_v3(num_classes=1008,
210 | aux_logits=False,
211 | pretrained=False)
212 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
213 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
214 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
215 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
216 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
217 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
218 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
219 | inception.Mixed_7b = FIDInceptionE_1(1280)
220 | inception.Mixed_7c = FIDInceptionE_2(2048)
221 |
222 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, model_dir=model_dir, progress=True)
223 | inception.load_state_dict(state_dict)
224 | return inception
225 |
226 |
227 | class FIDInceptionA(torchvision.models.inception.InceptionA):
228 | """InceptionA block patched for FID computation"""
229 | def __init__(self, in_channels, pool_features):
230 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
231 |
232 | def forward(self, x):
233 | branch1x1 = self.branch1x1(x)
234 |
235 | branch5x5 = self.branch5x5_1(x)
236 | branch5x5 = self.branch5x5_2(branch5x5)
237 |
238 | branch3x3dbl = self.branch3x3dbl_1(x)
239 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
240 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
241 |
242 |         # Patch: Tensorflow's average pool does not use the padded zeros in
243 | # its average calculation
244 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
245 | count_include_pad=False)
246 | branch_pool = self.branch_pool(branch_pool)
247 |
248 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
249 | return torch.cat(outputs, 1)
250 |
251 |
252 | class FIDInceptionC(torchvision.models.inception.InceptionC):
253 | """InceptionC block patched for FID computation"""
254 | def __init__(self, in_channels, channels_7x7):
255 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
256 |
257 | def forward(self, x):
258 | branch1x1 = self.branch1x1(x)
259 |
260 | branch7x7 = self.branch7x7_1(x)
261 | branch7x7 = self.branch7x7_2(branch7x7)
262 | branch7x7 = self.branch7x7_3(branch7x7)
263 |
264 | branch7x7dbl = self.branch7x7dbl_1(x)
265 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
266 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
267 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
268 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
269 |
270 |         # Patch: Tensorflow's average pool does not use the padded zeros in
271 | # its average calculation
272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
273 | count_include_pad=False)
274 | branch_pool = self.branch_pool(branch_pool)
275 |
276 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
277 | return torch.cat(outputs, 1)
278 |
279 |
280 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
281 | """First InceptionE block patched for FID computation"""
282 | def __init__(self, in_channels):
283 | super(FIDInceptionE_1, self).__init__(in_channels)
284 |
285 | def forward(self, x):
286 | branch1x1 = self.branch1x1(x)
287 |
288 | branch3x3 = self.branch3x3_1(x)
289 | branch3x3 = [
290 | self.branch3x3_2a(branch3x3),
291 | self.branch3x3_2b(branch3x3),
292 | ]
293 | branch3x3 = torch.cat(branch3x3, 1)
294 |
295 | branch3x3dbl = self.branch3x3dbl_1(x)
296 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
297 | branch3x3dbl = [
298 | self.branch3x3dbl_3a(branch3x3dbl),
299 | self.branch3x3dbl_3b(branch3x3dbl),
300 | ]
301 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
302 |
303 |         # Patch: Tensorflow's average pool does not use the padded zeros in
304 | # its average calculation
305 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
306 | count_include_pad=False)
307 | branch_pool = self.branch_pool(branch_pool)
308 |
309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310 | return torch.cat(outputs, 1)
311 |
312 |
313 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
314 | """Second InceptionE block patched for FID computation"""
315 | def __init__(self, in_channels):
316 | super(FIDInceptionE_2, self).__init__(in_channels)
317 |
318 | def forward(self, x):
319 | branch1x1 = self.branch1x1(x)
320 |
321 | branch3x3 = self.branch3x3_1(x)
322 | branch3x3 = [
323 | self.branch3x3_2a(branch3x3),
324 | self.branch3x3_2b(branch3x3),
325 | ]
326 | branch3x3 = torch.cat(branch3x3, 1)
327 |
328 | branch3x3dbl = self.branch3x3dbl_1(x)
329 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
330 | branch3x3dbl = [
331 | self.branch3x3dbl_3a(branch3x3dbl),
332 | self.branch3x3dbl_3b(branch3x3dbl),
333 | ]
334 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
335 |
336 | # Patch: The FID Inception model uses max pooling instead of average
337 | # pooling. This is likely an error in this specific Inception
338 | # implementation, as other Inception models use average pooling here
339 | # (which matches the description in the paper).
340 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
341 | branch_pool = self.branch_pool(branch_pool)
342 |
343 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
344 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
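
A small sketch of the feature-extraction interface of inception.py: requesting two output blocks returns their feature maps in ascending block order. Note that the default use_fid_inception=True downloads the FID Inception weights from FID_WEIGHTS_URL on first use.

    import torch
    from fid.inception import InceptionV3

    # Blocks 2 and 3: 768-d pre-aux-classifier maps and 2048-d pooled features.
    model = InceptionV3(output_blocks=[2, 3]).eval()
    x = torch.rand(4, 3, 64, 64)              # values in (0, 1); resized to 299x299 internally
    with torch.no_grad():
        feats = model(x)
    print([tuple(f.shape) for f in feats])    # [(4, 768, 17, 17), (4, 2048, 1, 1)]
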
/img/celebahq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/img/celebahq.png
--------------------------------------------------------------------------------
/img/model_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/img/model_diagram.png
--------------------------------------------------------------------------------
/lmdb_datasets.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import torch.utils.data as data
9 | import numpy as np
10 | import lmdb
11 | import os
12 | import io
13 | from PIL import Image
14 |
15 |
16 | def num_samples(dataset, train):
17 | if dataset == 'celeba':
18 | return 27000 if train else 3000
19 | elif dataset == 'celeba64':
20 | return 162770 if train else 19867
21 | elif dataset == 'imagenet-oord':
22 | return 1281147 if train else 50000
23 | elif dataset == 'ffhq':
24 | return 63000 if train else 7000
25 | else:
26 | raise NotImplementedError('dataset %s is unknown' % dataset)
27 |
28 |
29 | class LMDBDataset(data.Dataset):
30 | def __init__(self, root, name='', train=True, transform=None, is_encoded=False):
31 | self.train = train
32 | self.name = name
33 | self.transform = transform
34 | if self.train:
35 | lmdb_path = os.path.join(root, 'train.lmdb')
36 | else:
37 | lmdb_path = os.path.join(root, 'validation.lmdb')
38 | self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1,
39 | lock=False, readahead=False, meminit=False)
40 | self.is_encoded = is_encoded
41 |
42 | def __getitem__(self, index):
43 | target = [0]
44 | with self.data_lmdb.begin(write=False, buffers=True) as txn:
45 | data = txn.get(str(index).encode())
46 | if self.is_encoded:
47 | img = Image.open(io.BytesIO(data))
48 | img = img.convert('RGB')
49 | else:
50 | img = np.asarray(data, dtype=np.uint8)
51 | # assume data is RGB
52 | size = int(np.sqrt(len(img) / 3))
53 | img = np.reshape(img, (size, size, 3))
54 | img = Image.fromarray(img, mode='RGB')
55 |
56 | if self.transform is not None:
57 | img = self.transform(img)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return num_samples(self.name, self.train)
63 |
--------------------------------------------------------------------------------
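
A usage sketch for LMDBDataset: the root path is a placeholder and assumes train.lmdb / validation.lmdb were built beforehand with the scripts under scripts/ (e.g. create_celeba64_lmdb.py).

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from lmdb_datasets import LMDBDataset

    data = LMDBDataset(root='/tmp/data/celeba64_lmdb',  # placeholder path
                       name='celeba64', train=True, transform=transforms.ToTensor())
    loader = DataLoader(data, batch_size=32, shuffle=True, num_workers=4)

    img, target = data[0]        # img is a [0, 1] tensor; target is the dummy label [0]
    print(img.shape, len(data))  # e.g. torch.Size([3, 64, 64]) 162770
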
/neural_ar_operations.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | import numpy as np
13 | from collections import OrderedDict
14 |
15 | from neural_operations import ConvBNSwish, normalize_weight_jit
16 |
17 | AROPS = OrderedDict([
18 | ('conv_3x3', lambda C, masked, zero_diag: ELUConv(C, C, 3, 1, 1, masked=masked, zero_diag=zero_diag))
19 | ])
20 |
21 |
22 | class Identity(nn.Module):
23 | def __init__(self, masked, zero_diag):
24 | super(Identity, self).__init__()
25 | if zero_diag:
26 | raise ValueError('Skip connection with zero diag is just a zero operation.')
27 |
28 | def forward(self, x):
29 | return x
30 |
31 |
32 | def channel_mask(c_in, g_in, c_out, zero_diag):
33 | assert c_in % c_out == 0 or c_out % c_in == 0, "%d - %d" % (c_in, c_out)
34 | assert g_in == 1 or g_in == c_in
35 |
36 | if g_in == 1:
37 | mask = np.ones([c_out, c_in], dtype=np.float32)
38 | if c_out >= c_in:
39 | ratio = c_out // c_in
40 | for i in range(c_in):
41 | mask[i * ratio:(i + 1) * ratio, i + 1:] = 0
42 | if zero_diag:
43 | mask[i * ratio:(i + 1) * ratio, i:i + 1] = 0
44 | else:
45 | ratio = c_in // c_out
46 | for i in range(c_out):
47 | mask[i:i + 1, (i + 1) * ratio:] = 0
48 | if zero_diag:
49 |                     mask[i:i + 1, i * ratio:(i + 1) * ratio] = 0
50 | elif g_in == c_in:
51 | mask = np.ones([c_out, c_in // g_in], dtype=np.float32)
52 | if zero_diag:
53 | mask = 0. * mask
54 |
55 | return mask
56 |
57 |
58 | def create_conv_mask(kernel_size, c_in, g_in, c_out, zero_diag, mirror):
59 | m = (kernel_size - 1) // 2
60 | mask = np.ones([c_out, c_in // g_in, kernel_size, kernel_size], dtype=np.float32)
61 | mask[:, :, m:, :] = 0
62 | mask[:, :, m, :m] = 1
63 | mask[:, :, m, m] = channel_mask(c_in, g_in, c_out, zero_diag)
64 | if mirror:
65 | mask = np.copy(mask[:, :, ::-1, ::-1])
66 | return mask
67 |
68 |
69 | def norm(t, dim):
70 | return torch.sqrt(torch.sum(t * t, dim))
71 |
72 |
73 | class ARConv2d(nn.Conv2d):
74 |     """Conv2d with weight normalization and an optional autoregressive mask."""
75 |
76 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False,
77 | masked=False, zero_diag=False, mirror=False):
78 | """
79 | Args:
80 |             masked, zero_diag, mirror (bool): autoregressive-mask options (see create_conv_mask).
81 | """
82 | super(ARConv2d, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias)
83 |
84 | self.masked = masked
85 | if self.masked:
86 | assert kernel_size % 2 == 1, 'kernel size should be an odd value.'
87 | self.mask = torch.from_numpy(create_conv_mask(kernel_size, C_in, groups, C_out, zero_diag, mirror)).cuda()
88 | init_mask = self.mask.cpu()
89 | else:
90 | self.mask = 1.0
91 | init_mask = 1.0
92 |
93 |         # init weight normalization parameters
94 | init = torch.log(norm(self.weight * init_mask, dim=[1, 2, 3]).view(-1, 1, 1, 1) + 1e-2)
95 | self.log_weight_norm = nn.Parameter(init, requires_grad=True)
96 | self.weight_normalized = None
97 |
98 | def normalize_weight(self):
99 | weight = self.weight
100 | if self.masked:
101 | assert self.mask.size() == weight.size()
102 | weight = weight * self.mask
103 |
104 | # weight normalization
105 | weight = normalize_weight_jit(self.log_weight_norm, weight)
106 | return weight
107 |
108 | def forward(self, x):
109 | """
110 | Args:
111 |             x (torch.Tensor): of size (B, C_in, H, W). The autoregressive mask and weight
112 |                 normalization are applied to the kernel before the convolution.
113 | """
114 | self.weight_normalized = self.normalize_weight()
115 | bias = self.bias
116 | return F.conv2d(x, self.weight_normalized, bias, self.stride,
117 | self.padding, self.dilation, self.groups)
118 |
119 |
120 | class ELUConv(nn.Module):
121 |     """ELU + masked Conv2d."""
122 |
123 | def __init__(self, C_in, C_out, kernel_size, padding=0, dilation=1, masked=True, zero_diag=True,
124 | weight_init_coeff=1.0, mirror=False):
125 | super(ELUConv, self).__init__()
126 | self.conv_0 = ARConv2d(C_in, C_out, kernel_size, stride=1, padding=padding, bias=True, dilation=dilation,
127 | masked=masked, zero_diag=zero_diag, mirror=mirror)
128 | # change the initialized log weight norm
129 | self.conv_0.log_weight_norm.data += np.log(weight_init_coeff)
130 |
131 | def forward(self, x):
132 | """
133 | Args:
134 | x (torch.Tensor): of size (B, C_in, H, W)
135 | """
136 | out = F.elu(x)
137 | out = self.conv_0(out)
138 | return out
139 |
140 |
141 | class ARInvertedResidual(nn.Module):
142 | def __init__(self, inz, inf, ex=6, dil=1, k=5, mirror=False):
143 | super(ARInvertedResidual, self).__init__()
144 | hidden_dim = int(round(inz * ex))
145 | padding = dil * (k - 1) // 2
146 | layers = []
147 | layers.extend([ARConv2d(inz, hidden_dim, kernel_size=3, padding=1, masked=True, mirror=mirror, zero_diag=True),
148 | nn.ELU(inplace=True)])
149 | layers.extend([ARConv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=k, padding=padding, dilation=dil,
150 | masked=True, mirror=mirror, zero_diag=False),
151 | nn.ELU(inplace=True)])
152 | self.convz = nn.Sequential(*layers)
153 | self.hidden_dim = hidden_dim
154 |
155 | def forward(self, z, ftr):
156 | z = self.convz(z)
157 | return z
158 |
159 |
160 | class MixLogCDFParam(nn.Module):
161 | def __init__(self, num_z, num_mix, num_ftr, mirror):
162 | super(MixLogCDFParam, self).__init__()
163 |
164 | num_out = num_z * (3 * num_mix + 3)
165 | self.conv = ELUConv(num_ftr, num_out, kernel_size=1, padding=0, masked=True, zero_diag=False,
166 | weight_init_coeff=0.1, mirror=mirror)
167 | self.num_z = num_z
168 | self.num_mix = num_mix
169 |
170 | def forward(self, ftr):
171 | out = self.conv(ftr)
172 | b, c, h, w = out.size()
173 | out = out.view(b, self.num_z, c // self.num_z, h, w)
174 | m = self.num_mix
175 | logit_pi, mu, log_s, log_a, b, _ = torch.split(out, [m, m, m, 1, 1, 1], dim=2) # the last one is dummy
176 | return logit_pi, mu, log_s, log_a, b
177 |
178 |
179 | def mix_log_cdf_flow(z1, logit_pi, mu, log_s, log_a, b):
180 | # z b, n, 1, h, w
181 | # logit_pi b, n, k, h, w
182 | # mu b, n, k, h, w
183 | # log_s b, n, k, h, w
184 | # log_a b, n, 1, h, w
185 | # b b, n, 1, h, w
186 |
187 | log_s = torch.clamp(log_s, min=-7)
188 |
189 | z = z1.unsqueeze(dim=2)
190 | log_pi = torch.log_softmax(logit_pi, dim=2) # normalize log_pi
191 | u = - (z - mu) * torch.exp(-log_s)
192 | softplus_u = F.softplus(u)
193 | log_mix_cdf = log_pi - softplus_u
194 | log_one_minus_mix_cdf = log_mix_cdf + u
195 | log_mix_cdf = torch.logsumexp(log_mix_cdf, dim=2)
196 | log_one_minus_mix_cdf = torch.logsumexp(log_one_minus_mix_cdf, dim=2)
197 |
198 | log_a = log_a.squeeze_(dim=2)
199 | b = b.squeeze_(dim=2)
200 | new_z = torch.exp(log_a) * (log_mix_cdf - log_one_minus_mix_cdf) + b
201 |
202 | # compute log determinant Jac
203 | log_mix_pdf = torch.logsumexp(log_pi + u - log_s - 2 * softplus_u, dim=2)
204 | log_det = log_a - log_mix_cdf - log_one_minus_mix_cdf + log_mix_pdf
205 |
206 | return new_z, log_det
207 |
--------------------------------------------------------------------------------
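
To make the autoregressive masking in neural_ar_operations.py concrete, the sketch below builds the 3x3 kernel mask that ARConv2d applies to its weights; it uses only create_conv_mask and NumPy, so no model construction is required.

    import numpy as np
    from neural_ar_operations import create_conv_mask

    # 2 input channels, 2 output channels, ungrouped (g_in=1), zeroed diagonal:
    # pixels after the centre are masked out, and zero_diag also removes the
    # centre tap on the current channel, so output i never sees input i at the
    # same spatial location.
    mask = create_conv_mask(kernel_size=3, c_in=2, g_in=1, c_out=2,
                            zero_diag=True, mirror=False)
    print(mask.shape)  # (2, 2, 3, 3)
    print(mask[0, 0])  # [[1. 1. 1.], [1. 0. 0.], [0. 0. 0.]] -- only "past" positions survive
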
/neural_operations.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from thirdparty.swish import Swish as SwishFN
13 | from thirdparty.inplaced_sync_batchnorm import SyncBatchNormSwish
14 |
15 | from utils import average_tensor
16 | from collections import OrderedDict
17 |
18 | BN_EPS = 1e-5
19 | SYNC_BN = True
20 |
21 | OPS = OrderedDict([
22 | ('res_elu', lambda Cin, Cout, stride: ELUConv(Cin, Cout, 3, stride, 1)),
23 | ('res_bnelu', lambda Cin, Cout, stride: BNELUConv(Cin, Cout, 3, stride, 1)),
24 | ('res_bnswish', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 1)),
25 | ('res_bnswish5', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 2, 2)),
26 | ('mconv_e6k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=5, g=0)),
27 | ('mconv_e3k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=0)),
28 | ('mconv_e3k5g8', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=8)),
29 | ('mconv_e6k11g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=11, g=0)),
30 | ])
31 |
32 |
33 | def get_skip_connection(C, stride, affine, channel_mult):
34 | if stride == 1:
35 | return Identity()
36 | elif stride == 2:
37 | return FactorizedReduce(C, int(channel_mult * C))
38 | elif stride == -1:
39 | return nn.Sequential(UpSample(), Conv2D(C, int(C / channel_mult), kernel_size=1))
40 |
41 |
42 | def norm(t, dim):
43 | return torch.sqrt(torch.sum(t * t, dim))
44 |
45 |
46 | def logit(t):
47 | return torch.log(t) - torch.log(1 - t)
48 |
49 |
50 | def act(t):
51 | # The following implementation has lower memory.
52 | return SwishFN.apply(t)
53 |
54 |
55 | class Swish(nn.Module):
56 | def __init__(self):
57 | super(Swish, self).__init__()
58 |
59 | def forward(self, x):
60 | return act(x)
61 |
62 | @torch.jit.script
63 | def normalize_weight_jit(log_weight_norm, weight):
64 | n = torch.exp(log_weight_norm)
65 | wn = torch.sqrt(torch.sum(weight * weight, dim=[1, 2, 3])) # norm(w)
66 | weight = n * weight / (wn.view(-1, 1, 1, 1) + 1e-5)
67 | return weight
68 |
69 |
70 | class Conv2D(nn.Conv2d):
71 |     """Conv2d with weight normalization and optional data-dependent initialization."""
72 |
73 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, data_init=False,
74 | weight_norm=True):
75 | """
76 | Args:
77 |             data_init, weight_norm (bool): enable data-dependent initialization and weight normalization.
78 | """
79 | super(Conv2D, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias)
80 |
81 | self.log_weight_norm = None
82 | if weight_norm:
83 | init = norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1)
84 | self.log_weight_norm = nn.Parameter(torch.log(init + 1e-2), requires_grad=True)
85 |
86 | self.data_init = data_init
87 | self.init_done = False
88 | self.weight_normalized = self.normalize_weight()
89 |
90 | def forward(self, x):
91 | """
92 | Args:
93 |             x (torch.Tensor): of size (B, C_in, H, W). Weight normalization (and, on the first
94 |                 call, optional data-dependent initialization) is applied internally.
95 | """
96 | # do data based initialization
97 | if self.data_init and not self.init_done:
98 | with torch.no_grad():
99 | weight = self.weight / (norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1) + 1e-5)
100 | bias = None
101 | out = F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
102 | mn = torch.mean(out, dim=[0, 2, 3])
103 | st = 5 * torch.std(out, dim=[0, 2, 3])
104 |
105 | # get mn and st from other GPUs
106 | average_tensor(mn, is_distributed=True)
107 | average_tensor(st, is_distributed=True)
108 |
109 | if self.bias is not None:
110 | self.bias.data = - mn / (st + 1e-5)
111 | self.log_weight_norm.data = -torch.log((st.view(-1, 1, 1, 1) + 1e-5))
112 | self.init_done = True
113 |
114 | self.weight_normalized = self.normalize_weight()
115 |
116 | bias = self.bias
117 | return F.conv2d(x, self.weight_normalized, bias, self.stride,
118 | self.padding, self.dilation, self.groups)
119 |
120 | def normalize_weight(self):
121 | """ applies weight normalization """
122 | if self.log_weight_norm is not None:
123 | weight = normalize_weight_jit(self.log_weight_norm, self.weight)
124 | else:
125 | weight = self.weight
126 |
127 | return weight
128 |
129 |
130 | class Identity(nn.Module):
131 | def __init__(self):
132 | super(Identity, self).__init__()
133 |
134 | def forward(self, x):
135 | return x
136 |
137 |
138 | class SyncBatchNorm(nn.Module):
139 | def __init__(self, *args, **kwargs):
140 | super(SyncBatchNorm, self).__init__()
141 | self.bn = nn.SyncBatchNorm(*args, **kwargs)
142 |
143 | def forward(self, x):
144 |         # Sync BN only works with distributed data parallel with 1 GPU per process. I don't use DDP,
145 |         # so I need to let Sync BN know that I have 1 GPU per process.
146 | self.bn.ddp_gpu_size = 1
147 | return self.bn(x)
148 |
149 |
150 | # quick switch between multi-gpu, single-gpu batch norm
151 | def get_batchnorm(*args, **kwargs):
152 | if SYNC_BN:
153 | return SyncBatchNorm(*args, **kwargs)
154 | else:
155 | return nn.BatchNorm2d(*args, **kwargs)
156 |
157 |
158 | class ELUConv(nn.Module):
159 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
160 | super(ELUConv, self).__init__()
161 | self.upsample = stride == -1
162 | stride = abs(stride)
163 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation,
164 | data_init=True)
165 |
166 | def forward(self, x):
167 | out = F.elu(x)
168 | if self.upsample:
169 | out = F.interpolate(out, scale_factor=2, mode='nearest')
170 | out = self.conv_0(out)
171 | return out
172 |
173 |
174 | class BNELUConv(nn.Module):
175 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
176 | super(BNELUConv, self).__init__()
177 | self.upsample = stride == -1
178 | stride = abs(stride)
179 | self.bn = get_batchnorm(C_in, eps=BN_EPS, momentum=0.05)
180 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
181 |
182 | def forward(self, x):
183 | x = self.bn(x)
184 | out = F.elu(x)
185 | if self.upsample:
186 | out = F.interpolate(out, scale_factor=2, mode='nearest')
187 | out = self.conv_0(out)
188 | return out
189 |
190 |
191 | class BNSwishConv(nn.Module):
192 |     """BN + Swish + Conv2d."""
193 |
194 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
195 | super(BNSwishConv, self).__init__()
196 | self.upsample = stride == -1
197 | stride = abs(stride)
198 | self.bn_act = SyncBatchNormSwish(C_in, eps=BN_EPS, momentum=0.05)
199 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
200 |
201 | def forward(self, x):
202 | """
203 | Args:
204 | x (torch.Tensor): of size (B, C_in, H, W)
205 | """
206 | out = self.bn_act(x)
207 | if self.upsample:
208 | out = F.interpolate(out, scale_factor=2, mode='nearest')
209 | out = self.conv_0(out)
210 | return out
211 |
212 |
213 |
214 | class FactorizedReduce(nn.Module):
215 | def __init__(self, C_in, C_out):
216 | super(FactorizedReduce, self).__init__()
217 | assert C_out % 2 == 0
218 | self.conv_1 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
219 | self.conv_2 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
220 | self.conv_3 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
221 | self.conv_4 = Conv2D(C_in, C_out - 3 * (C_out // 4), 1, stride=2, padding=0, bias=True)
222 |
223 | def forward(self, x):
224 | out = act(x)
225 | conv1 = self.conv_1(out)
226 | conv2 = self.conv_2(out[:, :, 1:, 1:])
227 | conv3 = self.conv_3(out[:, :, :, 1:])
228 | conv4 = self.conv_4(out[:, :, 1:, :])
229 | out = torch.cat([conv1, conv2, conv3, conv4], dim=1)
230 | return out
231 |
232 |
233 | class UpSample(nn.Module):
234 | def __init__(self):
235 | super(UpSample, self).__init__()
236 | pass
237 |
238 | def forward(self, x):
239 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
240 |
241 |
242 | class EncCombinerCell(nn.Module):
243 | def __init__(self, Cin1, Cin2, Cout, cell_type):
244 | super(EncCombinerCell, self).__init__()
245 | self.cell_type = cell_type
246 | # Cin = Cin1 + Cin2
247 | self.conv = Conv2D(Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
248 |
249 | def forward(self, x1, x2):
250 | x2 = self.conv(x2)
251 | out = x1 + x2
252 | return out
253 |
254 |
255 | # original combiner
256 | class DecCombinerCell(nn.Module):
257 | def __init__(self, Cin1, Cin2, Cout, cell_type):
258 | super(DecCombinerCell, self).__init__()
259 | self.cell_type = cell_type
260 | self.conv = Conv2D(Cin1 + Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
261 |
262 | def forward(self, x1, x2):
263 | out = torch.cat([x1, x2], dim=1)
264 | out = self.conv(out)
265 | return out
266 |
267 |
268 | class ConvBNSwish(nn.Module):
269 | def __init__(self, Cin, Cout, k=3, stride=1, groups=1, dilation=1):
270 | padding = dilation * (k - 1) // 2
271 | super(ConvBNSwish, self).__init__()
272 |
273 | self.conv = nn.Sequential(
274 | Conv2D(Cin, Cout, k, stride, padding, groups=groups, bias=False, dilation=dilation, weight_norm=False),
275 | SyncBatchNormSwish(Cout, eps=BN_EPS, momentum=0.05) # drop in replacement for BN + Swish
276 | )
277 |
278 | def forward(self, x):
279 | return self.conv(x)
280 |
281 |
282 | class SE(nn.Module):
283 | def __init__(self, Cin, Cout):
284 | super(SE, self).__init__()
285 | num_hidden = max(Cout // 16, 4)
286 | self.se = nn.Sequential(nn.Linear(Cin, num_hidden), nn.ReLU(inplace=True),
287 | nn.Linear(num_hidden, Cout), nn.Sigmoid())
288 |
289 | def forward(self, x):
290 | se = torch.mean(x, dim=[2, 3])
291 | se = se.view(se.size(0), -1)
292 | se = self.se(se)
293 | se = se.view(se.size(0), -1, 1, 1)
294 | return x * se
295 |
296 |
297 | class InvertedResidual(nn.Module):
298 | def __init__(self, Cin, Cout, stride, ex, dil, k, g):
299 | super(InvertedResidual, self).__init__()
300 | self.stride = stride
301 | assert stride in [1, 2, -1]
302 |
303 | hidden_dim = int(round(Cin * ex))
304 | self.use_res_connect = self.stride == 1 and Cin == Cout
305 | self.upsample = self.stride == -1
306 | self.stride = abs(self.stride)
307 | groups = hidden_dim if g == 0 else g
308 |
309 | layers0 = [nn.UpsamplingNearest2d(scale_factor=2)] if self.upsample else []
310 | layers = [get_batchnorm(Cin, eps=BN_EPS, momentum=0.05),
311 | ConvBNSwish(Cin, hidden_dim, k=1),
312 | ConvBNSwish(hidden_dim, hidden_dim, stride=self.stride, groups=groups, k=k, dilation=dil),
313 | Conv2D(hidden_dim, Cout, 1, 1, 0, bias=False, weight_norm=False),
314 | get_batchnorm(Cout, momentum=0.05)]
315 |
316 | layers0.extend(layers)
317 | self.conv = nn.Sequential(*layers0)
318 |
319 | def forward(self, x):
320 | return self.conv(x)
321 |
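A quick shape check of the FactorizedReduce pattern defined above, using plain nn.Conv2d as a stand-in for this module's weight-normalized Conv2D; the batch, channel, and spatial sizes are illustrative only. The four stride-2 branches read complementary pixel phases and their channel counts sum to C_out:

import torch
import torch.nn as nn

C_in, C_out = 8, 6  # C_out only has to be even; the last branch absorbs the remainder
branch_channels = [C_out // 4] * 3 + [C_out - 3 * (C_out // 4)]
convs = nn.ModuleList([nn.Conv2d(C_in, c, kernel_size=1, stride=2) for c in branch_channels])

x = torch.randn(2, C_in, 16, 16)
outs = [convs[0](x),
        convs[1](x[:, :, 1:, 1:]),
        convs[2](x[:, :, :, 1:]),
        convs[3](x[:, :, 1:, :])]
y = torch.cat(outs, dim=1)
assert y.shape == (2, C_out, 8, 8)  # spatial resolution halved, channels equal C_out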
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchvision==0.7.0
3 | pillow
4 | matplotlib
5 | tensorboard
6 | tensorboardX
7 | lmdb
8 | tfrecord
--------------------------------------------------------------------------------
/scripts/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:20.07-py3
2 |
3 | RUN pip install pillow
4 | RUN pip install cython
5 | RUN pip install matplotlib
6 | RUN pip install tensorboard
7 | RUN pip install tensorboardX
8 | RUN pip install tfrecord
9 | RUN apt update
10 | RUN apt install screen -y
11 |
--------------------------------------------------------------------------------
/scripts/convert_tfrecord_to_lmdb.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 | import argparse
8 | import torch
9 | import lmdb
10 | import os
11 |
12 | from tfrecord.torch.dataset import TFRecordDataset
13 |
14 |
15 | def main(dataset, split, tfr_path, lmdb_path):
16 | assert split in {'train', 'validation'}
17 |
18 | # create target directory
19 | if not os.path.exists(lmdb_path):
20 | os.makedirs(lmdb_path, exist_ok=True)
21 | if dataset == 'celeba' and split in {'train', 'validation'}:
22 | num_shards = {'train': 120, 'validation': 40}[split]
23 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split)
24 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r08-s-%04d-of-%04d.tfrecords')
25 | elif dataset == 'imagenet-oord_32':
26 | num_shards = {'train': 2000, 'validation': 80}[split]
27 | # imagenet_oord_lmdb_path += '_32'
28 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split)
29 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r05-s-%04d-of-%04d.tfrecords')
30 | elif dataset == 'imagenet-oord_64':
31 | num_shards = {'train': 2000, 'validation': 80}[split]
32 | # imagenet_oord_lmdb_path += '_64'
33 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split)
34 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r06-s-%04d-of-%04d.tfrecords')
35 | else:
36 | raise NotImplementedError
37 |
38 | # create lmdb
39 | env = lmdb.open(lmdb_path, map_size=1e12)
40 | count = 0
41 | with env.begin(write=True) as txn:
42 | for tf_ind in range(num_shards):
43 | # read tf_record
44 | tfrecord_path = tfrecord_path_template % (split, split, tf_ind, num_shards)
45 | index_path = None
46 | description = {'shape': 'int', 'data': 'byte', 'label': 'int'}
47 |             tf_dataset = TFRecordDataset(tfrecord_path, index_path, description)
48 |             loader = torch.utils.data.DataLoader(tf_dataset, batch_size=1)
49 |
50 | # put the data in lmdb
51 | for data in loader:
52 | im = data['data'][0].cpu().numpy()
53 | txn.put(str(count).encode(), im)
54 | count += 1
55 | if count % 100 == 0:
56 | print(count)
57 |
58 | print('added %d items to the LMDB dataset.' % count)
59 |
60 |
61 | if __name__ == '__main__':
62 | parser = argparse.ArgumentParser('LMDB creator using TFRecords from GLOW.')
63 | # experimental results
64 | parser.add_argument('--dataset', type=str, default='imagenet-oord_32',
65 |                         help='dataset name', choices=['imagenet-oord_32', 'imagenet-oord_64', 'celeba'])
66 | parser.add_argument('--tfr_path', type=str, default='/data1/datasets/imagenet-oord/mnt/host/imagenet-oord-tfr',
67 | help='location of TFRecords')
68 | parser.add_argument('--lmdb_path', type=str, default='/data1/datasets/imagenet-oord/imagenet-oord-lmdb_32',
69 | help='target location for storing lmdb files')
70 | parser.add_argument('--split', type=str, default='train',
71 | help='training or validation split', choices=['train', 'validation'])
72 | args = parser.parse_args()
73 | main(args.dataset, args.split, args.tfr_path, args.lmdb_path)
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/scripts/create_celeba64_lmdb.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import argparse
9 | import lmdb
10 | import os
11 | import torchvision.datasets as dset
12 |
13 |
14 | def main(split, img_path, lmdb_path):
15 | assert split in {"train", "valid", "test"}
16 | # create target directory
17 | if not os.path.exists(lmdb_path):
18 | os.makedirs(lmdb_path, exist_ok=True)
19 |
20 | lmdb_split = {'train': 'train', 'valid': 'validation', 'test': 'test'}[split]
21 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % lmdb_split)
22 |
23 |     # download the data if it is not already available locally
24 | data = dset.celeba.CelebA(root=img_path, split=split, target_type='attr', transform=None, download=True)
25 |     print('total data: %d' % len(data))
26 |
27 | # create lmdb
28 | env = lmdb.open(lmdb_path, map_size=1e12)
29 | with env.begin(write=True) as txn:
30 | for i in range(len(data)):
31 | file_path = os.path.join(data.root, data.base_folder, "img_align_celeba", data.filename[i])
32 | attr = data.attr[i, :]
33 | with open(file_path, 'rb') as f:
34 | file_data = f.read()
35 |
36 | txn.put(str(i).encode(), file_data)
37 | print(i)
38 |
39 |
40 | if __name__ == '__main__':
41 | parser = argparse.ArgumentParser('CelebA 64 LMDB creator.')
42 | # experimental results
43 | parser.add_argument('--img_path', type=str, default='/data1/datasets/celeba_org/',
44 | help='location of images for CelebA dataset')
45 | parser.add_argument('--lmdb_path', type=str, default='/data1/datasets/celeba_org/celeba64_lmdb',
46 | help='target location for storing lmdb files')
47 | parser.add_argument('--split', type=str, default='train',
48 | help='training or validation split', choices=["train", "valid", "test"])
49 | args = parser.parse_args()
50 | main(args.split, args.img_path, args.lmdb_path)
51 |
52 |
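Each value stored above is the raw, still-encoded img_align_celeba file, so a consumer has to decode it after retrieval (the 64x64 resize presumably happens later in the data pipeline). A minimal read-back sketch, using this script's default lmdb_path and the key convention str(i).encode() from the loop above:

import io
import lmdb
from PIL import Image

env = lmdb.open('/data1/datasets/celeba_org/celeba64_lmdb/train.lmdb', readonly=True, lock=False)
with env.begin(write=False) as txn:
    img_bytes = txn.get(str(0).encode())  # first image written by the loop above
img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
print(img.size)  # original aligned CelebA resolution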
--------------------------------------------------------------------------------
/scripts/create_ffhq_lmdb.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import argparse
9 | import torch
10 | import numpy as np
11 | import lmdb
12 | import os
13 |
14 | from PIL import Image
15 |
16 |
17 | def main(split, ffhq_img_path, ffhq_lmdb_path):
18 | assert split in {'train', 'validation'}
19 | num_images = 70000
20 | num_train = 63000
21 |
22 | # create target directory
23 | if not os.path.exists(ffhq_lmdb_path):
24 | os.makedirs(ffhq_lmdb_path, exist_ok=True)
25 |
26 | ind_path = os.path.join(ffhq_lmdb_path, 'train_test_ind.pt')
27 | if os.path.exists(ind_path):
28 | ind_dat = torch.load(ind_path)
29 | train_ind = ind_dat['train']
30 | test_ind = ind_dat['test']
31 | else:
32 | rand = np.random.permutation(num_images)
33 | train_ind = rand[:num_train]
34 | test_ind = rand[num_train:]
35 | torch.save({'train': train_ind, 'test': test_ind}, ind_path)
36 |
37 | file_ind = train_ind if split == 'train' else test_ind
38 | lmdb_path = os.path.join(ffhq_lmdb_path, '%s.lmdb' % split)
39 |
40 | # create lmdb
41 | env = lmdb.open(lmdb_path, map_size=1e12)
42 | count = 0
43 | with env.begin(write=True) as txn:
44 | for i in file_ind:
45 | img_path = os.path.join(ffhq_img_path, '%05d.png' % i)
46 | im = Image.open(img_path)
47 | im = im.resize(size=(256, 256), resample=Image.BILINEAR)
48 | im = np.array(im.getdata(), dtype=np.uint8).reshape(im.size[1], im.size[0], 3)
49 |
50 | txn.put(str(count).encode(), im)
51 | count += 1
52 | if count % 100 == 0:
53 | print(count)
54 |
55 | print('added %d items to the LMDB dataset.' % count)
56 |
57 |
58 | if __name__ == '__main__':
59 | parser = argparse.ArgumentParser('FFHQ LMDB creator. Download images1024x1024.zip from here and unzip it \n'
60 | 'https://drive.google.com/drive/folders/1WocxvZ4GEZ1DI8dOz30aSj2zT6pkATYS')
61 | # experimental results
62 | parser.add_argument('--ffhq_img_path', type=str, default='/data1/datasets/ffhq/images1024x1024',
63 | help='location of images from FFHQ')
64 | parser.add_argument('--ffhq_lmdb_path', type=str, default='/data1/datasets/ffhq/ffhq-lmdb',
65 | help='target location for storing lmdb files')
66 | parser.add_argument('--split', type=str, default='train',
67 | help='training or validation split', choices=['train', 'validation'])
68 | args = parser.parse_args()
69 |
70 | main(args.split, args.ffhq_img_path, args.ffhq_lmdb_path)
71 |
72 |
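The values written above are the raw bytes of a (256, 256, 3) uint8 array rather than an encoded image, so reading an entry back requires a reshape. A minimal sketch, using this script's default ffhq_lmdb_path:

import lmdb
import numpy as np

env = lmdb.open('/data1/datasets/ffhq/ffhq-lmdb/train.lmdb', readonly=True, lock=False)
with env.begin(write=False) as txn:
    buf = txn.get(str(0).encode())
img = np.frombuffer(buf, dtype=np.uint8).reshape(256, 256, 3)
print(img.shape, img.dtype)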
--------------------------------------------------------------------------------
/scripts/precompute_fid_statistics.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 | import os
8 | import argparse
9 | from fid.fid_score import compute_statistics_of_generator, save_statistics
10 | from datasets import get_loaders_eval
11 | from fid.inception import InceptionV3
12 | from itertools import chain
13 |
14 |
15 | def main(args):
16 | device = 'cuda'
17 | dims = 2048
18 | # for binary datasets including MNIST and OMNIGLOT, we don't apply binarization for FID computation
19 | train_queue, valid_queue, _ = get_loaders_eval(args.dataset, args)
20 | print('len train queue', len(train_queue), 'len val queue', len(valid_queue), 'batch size', args.batch_size)
21 | if args.dataset in {'celeba_256', 'omniglot'}:
22 | train_queue = chain(train_queue, valid_queue)
23 |
24 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
25 | model = InceptionV3([block_idx], model_dir=args.fid_dir).to(device)
26 | m, s = compute_statistics_of_generator(train_queue, model, args.batch_size, dims, device, args.max_samples)
27 | file_path = os.path.join(args.fid_dir, args.dataset + '.npz')
28 | print('saving fid stats at %s' % file_path)
29 | save_statistics(file_path, m, s)
30 |
31 |
32 | if __name__ == '__main__':
33 | # python precompute_fid_statistics.py --dataset cifar10
34 | parser = argparse.ArgumentParser('')
35 | parser.add_argument('--dataset', type=str, default='cifar10',
36 | choices=['cifar10', 'celeba_64', 'celeba_256', 'omniglot', 'mnist',
37 | 'imagenet_32', 'ffhq', 'lsun_bedroom_128', 'lsun_church_256'],
38 | help='which dataset to use')
39 | parser.add_argument('--data', type=str, default='/tmp/nvae-diff/data',
40 | help='location of the data corpus')
41 | parser.add_argument('--batch_size', type=int, default=64,
42 | help='batch size per GPU')
43 | parser.add_argument('--max_samples', type=int, default=50000,
44 |                         help='maximum number of samples drawn for computing FID statistics')
45 | parser.add_argument('--fid_dir', type=str, default='/tmp/fid-stats',
46 | help='A dir to store fid related files')
47 |
48 | args = parser.parse_args()
49 | args.distributed = False
50 |
51 | main(args)
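At evaluation time the saved .npz is paired with statistics computed from model samples. A sketch of that pairing, assuming load_statistics returns the (mu, sigma) pair written by save_statistics and that calculate_frechet_distance takes (mu1, sigma1, mu2, sigma2) as in pytorch-fid; the placeholder fake statistics below only illustrate the call:

import numpy as np
from fid.fid_score import load_statistics, calculate_frechet_distance

m_real, s_real = load_statistics('/tmp/fid-stats/cifar10.npz')  # file produced by this script
# these would normally come from compute_statistics_of_generator on model samples
m_fake, s_fake = np.zeros_like(m_real), np.eye(m_real.shape[0])
print('FID: %f' % calculate_frechet_distance(m_real, s_real, m_fake, s_fake))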
--------------------------------------------------------------------------------
/thirdparty/LICENSE_PyTorch:
--------------------------------------------------------------------------------
1 | From PyTorch:
2 |
3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
8 | Copyright (c) 2011-2013 NYU (Clement Farabet)
9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
12 |
13 | From Caffe2:
14 |
15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
16 |
17 | All contributions by Facebook:
18 | Copyright (c) 2016 Facebook Inc.
19 |
20 | All contributions by Google:
21 | Copyright (c) 2015 Google Inc.
22 | All rights reserved.
23 |
24 | All contributions by Yangqing Jia:
25 | Copyright (c) 2015 Yangqing Jia
26 | All rights reserved.
27 |
28 | All contributions from Caffe:
29 | Copyright(c) 2013, 2014, 2015, the respective contributors
30 | All rights reserved.
31 |
32 | All other contributions:
33 | Copyright(c) 2015, 2016 the respective contributors
34 | All rights reserved.
35 |
36 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
37 | copyright over their contributions to Caffe2. The project versioning records
38 | all such contribution and copyright details. If a contributor wants to further
39 | mark their specific copyright on a particular contribution, they should
40 | indicate their copyright solely in the commit message of the change when it is
41 | committed.
42 |
43 | All rights reserved.
44 |
45 | Redistribution and use in source and binary forms, with or without
46 | modification, are permitted provided that the following conditions are met:
47 |
48 | 1. Redistributions of source code must retain the above copyright
49 | notice, this list of conditions and the following disclaimer.
50 |
51 | 2. Redistributions in binary form must reproduce the above copyright
52 | notice, this list of conditions and the following disclaimer in the
53 | documentation and/or other materials provided with the distribution.
54 |
55 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
56 | and IDIAP Research Institute nor the names of its contributors may be
57 | used to endorse or promote products derived from this software without
58 | specific prior written permission.
59 |
60 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
61 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
62 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
63 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
64 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
65 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
66 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
67 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
68 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
69 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
70 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/thirdparty/LICENSE_apache:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/thirdparty/LICENSE_torchvision:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) Soumith Chintala 2016,
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/thirdparty/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/thirdparty/__init__.py
--------------------------------------------------------------------------------
/thirdparty/adamax.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the PyTorch library.
5 | #
6 | # Source:
7 | # https://github.com/pytorch/pytorch/blob/6e2bb1c05442010aff90b413e21fce99f0393727/torch/optim/adamax.py
8 | #
9 | # The license for the original version of this file can be
10 | # found in this directory (LICENSE_PyTorch). The modifications
11 | # to this file are subject to the NVIDIA Source Code License for
12 | # NVAE located at the root directory.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | from torch.optim import Optimizer
17 |
18 | @torch.jit.script
19 | def fusion1(exp_avg: torch.Tensor, grad: torch.Tensor, beta1: float):
20 | return exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
21 |
22 |
23 | class Adamax(Optimizer):
24 | """Implements Adamax algorithm (a variant of Adam based on infinity norm).
25 |
26 | It has been proposed in `Adam: A Method for Stochastic Optimization`__.
27 |
28 | Arguments:
29 | params (iterable): iterable of parameters to optimize or dicts defining
30 | parameter groups
31 | lr (float, optional): learning rate (default: 2e-3)
32 | betas (Tuple[float, float], optional): coefficients used for computing
33 | running averages of gradient and its square
34 | eps (float, optional): term added to the denominator to improve
35 | numerical stability (default: 1e-8)
36 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
37 |
38 | __ https://arxiv.org/abs/1412.6980
39 | """
40 |
41 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
42 | weight_decay=0):
43 | if not 0.0 <= lr:
44 | raise ValueError("Invalid learning rate: {}".format(lr))
45 | if not 0.0 <= eps:
46 | raise ValueError("Invalid epsilon value: {}".format(eps))
47 | if not 0.0 <= betas[0] < 1.0:
48 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
49 | if not 0.0 <= betas[1] < 1.0:
50 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
51 | if not 0.0 <= weight_decay:
52 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
53 |
54 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
55 | super(Adamax, self).__init__(params, defaults)
56 |
57 | def step(self, closure=None):
58 | """Performs a single optimization step.
59 |
60 | Arguments:
61 | closure (callable, optional): A closure that reevaluates the model
62 | and returns the loss.
63 | """
64 | loss = None
65 | if closure is not None:
66 | loss = closure()
67 |
68 | params, grads, exp_avg, exp_inf = {},{},{},{}
69 |
70 | for group in self.param_groups:
71 | for i, p in enumerate(group['params']):
72 | if p.grad is None:
73 | continue
74 | grad = p.grad.data
75 | if grad.is_sparse:
76 | raise RuntimeError('Adamax does not support sparse gradients')
77 | state = self.state[p]
78 |
79 | # State initialization
80 | if len(state) == 0:
81 | state['step'] = 0
82 | # state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
83 | # state['exp_inf'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
84 | state['exp_avg'] = torch.zeros_like(p.data)
85 | state['exp_inf'] = torch.zeros_like(p.data)
86 |
87 | state['step'] += 1
88 |
89 | if p.shape not in params:
90 | params[p.shape] = {'idx': 0, 'data': []}
91 | grads[p.shape] = []
92 | exp_avg[p.shape] = []
93 | exp_inf[p.shape] = []
94 |
95 | params[p.shape]['data'].append(p.data)
96 | grads[p.shape].append(grad)
97 | exp_avg[p.shape].append(state['exp_avg'])
98 | exp_inf[p.shape].append(state['exp_inf'])
99 |
100 | for i in params:
101 | params[i]['data'] = torch.stack(params[i]['data'], dim=0)
102 | grads[i] = torch.stack(grads[i], dim=0)
103 | exp_avg[i] = torch.stack(exp_avg[i], dim=0)
104 | exp_inf[i] = torch.stack(exp_inf[i], dim=0)
105 |
106 | for group in self.param_groups:
107 | beta1, beta2 = group['betas']
108 | eps = group['eps']
109 | bias_correction = 1 - beta1 ** self.state[group['params'][0]]['step']
110 | clr = group['lr'] / bias_correction
111 |
112 | for i in params:
113 | if group['weight_decay'] != 0:
114 | grads[i] = grads[i].add_(params[i]['data'], alpha=group['weight_decay'])
115 | # Update biased first moment estimate.
116 | exp_avg[i].mul_(beta1).add_(grads[i], alpha=1 - beta1)
117 | # Update the exponentially weighted infinity norm.
118 | torch.max(exp_inf[i].mul_(beta2), grads[i].abs_().add_(eps), out=exp_inf[i])
119 | params[i]['data'].addcdiv_(exp_avg[i], exp_inf[i], value=-clr)
120 |
121 | for group in self.param_groups:
122 | for p in group['params']:
123 | idx = params[p.shape]['idx']
124 | p.data = params[p.shape]['data'][idx, :]
125 | self.state[p]['exp_avg'] = exp_avg[p.shape][idx, :]
126 | self.state[p]['exp_inf'] = exp_inf[p.shape][idx, :]
127 | params[p.shape]['idx'] += 1
128 |
129 | return loss
130 |
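train.py constructs this optimizer exactly like torch.optim.Adamax (see its --fast_adamax branch). A minimal stand-alone usage sketch with a toy model; the learning rate and weight decay below are illustrative:

import torch
import torch.nn as nn
from thirdparty.adamax import Adamax

model = nn.Linear(10, 10)
opt = Adamax(model.parameters(), lr=1e-2, weight_decay=3e-4, eps=1e-3)
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()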
--------------------------------------------------------------------------------
/thirdparty/functions.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the PyTorch library.
5 | #
6 | # Source:
7 | # https://github.com/pytorch/pytorch/blob/2a54533c64c409b626b6c209ed78258f67aec194/torch/nn/modules/_functions.py
8 | #
9 | # The license for the original version of this file can be
10 | # found in this directory (LICENSE_PyTorch). The modifications
11 | # to this file are subject to the NVIDIA Source Code License for
12 | # NVAE located at the root directory.
13 | # ---------------------------------------------------------------
14 |
15 | import torch
16 | from torch.autograd.function import Function
17 | import torch.distributed as dist
18 |
19 |
20 | class SyncBatchNorm(Function):
21 |
22 | @staticmethod
23 | def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
24 | input = input.contiguous()
25 |
26 | count = torch.empty(1,
27 | dtype=running_mean.dtype,
28 | device=input.device).fill_(input.numel() // input.size(1))
29 |
30 | # calculate mean/invstd for input.
31 | mean, invstd = torch.batch_norm_stats(input, eps)
32 |
33 | num_channels = input.shape[1]
34 | # C, C, 1 -> (2C + 1)
35 | combined = torch.cat([mean, invstd, count], dim=0)
36 | # world_size * (2C + 1)
37 | combined_list = [
38 | torch.empty_like(combined) for k in range(world_size)
39 | ]
40 | # Use allgather instead of allreduce since I don't trust in-place operations ..
41 | dist.all_gather(combined_list, combined, async_op=False)
42 | combined = torch.stack(combined_list, dim=0)
43 | # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
44 | mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
45 |
46 | size = count_all.view(-1).long().sum()
47 | if size == 1:
48 | raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
49 |
50 | # calculate global mean & invstd
51 | mean, invstd = torch.batch_norm_gather_stats_with_counts(
52 | input,
53 | mean_all,
54 | invstd_all,
55 | running_mean,
56 | running_var,
57 | momentum,
58 | eps,
59 | count_all.view(-1)
60 | )
61 |
62 | self.save_for_backward(input, weight, mean, invstd, bias, count_all)
63 | self.process_group = process_group
64 |
65 | # apply element-wise normalization
66 | out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
67 |
68 | # av: apply swish
69 | assert eps == 1e-5, "I assumed below that eps is 1e-5"
70 | out = out * torch.sigmoid(out)
71 | # av: end
72 |
73 | return out
74 |
75 | @staticmethod
76 | def backward(self, grad_output):
77 | grad_output = grad_output.contiguous()
78 | saved_input, weight, mean, invstd, bias, count_tensor = self.saved_tensors
79 |
80 | # av: re-compute batch normalized out
81 | eps = 1e-5
82 | out = torch.batch_norm_elemt(saved_input, weight, bias, mean, invstd, eps)
83 | sigmoid_out = torch.sigmoid(out)
84 | grad_output *= (sigmoid_out * (1 + out * (1 - sigmoid_out)))
85 | # av: end
86 |
87 | grad_input = grad_weight = grad_bias = None
88 | process_group = self.process_group
89 |
90 | # calculate local stats as well as grad_weight / grad_bias
91 | sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
92 | grad_output,
93 | saved_input,
94 | mean,
95 | invstd,
96 | weight,
97 | self.needs_input_grad[0],
98 | self.needs_input_grad[1],
99 | self.needs_input_grad[2]
100 | )
101 |
102 | if self.needs_input_grad[0]:
103 | # synchronizing stats used to calculate input gradient.
104 | # TODO: move div_ into batch_norm_backward_elemt kernel
105 | num_channels = sum_dy.shape[0]
106 | combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
107 | torch.distributed.all_reduce(
108 | combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
109 | sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
110 |
111 | divisor = count_tensor.sum()
112 | mean_dy = sum_dy / divisor
113 | mean_dy_xmu = sum_dy_xmu / divisor
114 | # backward pass for gradient calculation
115 | grad_input = torch.batch_norm_backward_elemt(
116 | grad_output,
117 | saved_input,
118 | mean,
119 | invstd,
120 | weight,
121 | mean_dy,
122 | mean_dy_xmu
123 | )
124 |
125 | # synchronizing of grad_weight / grad_bias is not needed as distributed
126 | # training would handle all reduce.
127 | if weight is None or not self.needs_input_grad[1]:
128 | grad_weight = None
129 |
130 | if weight is None or not self.needs_input_grad[2]:
131 | grad_bias = None
132 |
133 | return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
134 |
--------------------------------------------------------------------------------
/thirdparty/inplaced_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the PyTorch library.
5 | #
6 | # Source:
7 | # https://github.com/pytorch/pytorch/blob/881c1adfcd916b6cd5de91bc343eb86aff88cc80/torch/nn/modules/batchnorm.py
8 | #
9 | # The license for the original version of this file can be
10 | # found in this directory (LICENSE_PyTorch). The modifications
11 | # to this file are subject to the NVIDIA Source Code License for
12 | # NVAE located at the root directory.
13 | # ---------------------------------------------------------------
14 |
15 | from __future__ import division
16 |
17 | import torch
18 | from torch.nn.modules.batchnorm import _BatchNorm
19 | import torch.nn.functional as F
20 |
21 | from .functions import SyncBatchNorm as sync_batch_norm
22 | from .swish import Swish as swish
23 |
24 |
25 | class SyncBatchNormSwish(_BatchNorm):
26 | r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
27 | with additional channel dimension) as described in the paper
28 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
29 |
30 | .. math::
31 |
32 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
33 |
34 | The mean and standard-deviation are calculated per-dimension over all
35 | mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
36 | are learnable parameter vectors of size `C` (where `C` is the input size).
37 | By default, the elements of :math:`\gamma` are sampled from
38 | :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
39 |
40 | Also by default, during training this layer keeps running estimates of its
41 | computed mean and variance, which are then used for normalization during
42 | evaluation. The running estimates are kept with a default :attr:`momentum`
43 | of 0.1.
44 |
45 | If :attr:`track_running_stats` is set to ``False``, this layer then does not
46 | keep running estimates, and batch statistics are instead used during
47 | evaluation time as well.
48 |
49 | .. note::
50 | This :attr:`momentum` argument is different from one used in optimizer
51 | classes and the conventional notion of momentum. Mathematically, the
52 | update rule for running statistics here is
53 |     :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
54 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
55 | new observed value.
56 |
57 | Because the Batch Normalization is done over the `C` dimension, computing statistics
58 | on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
59 | or Spatio-temporal Batch Normalization.
60 |
61 | Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
62 | torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
63 | Network with DDP.
64 |
65 | Args:
66 | num_features: :math:`C` from an expected input of size
67 | :math:`(N, C, +)`
68 | eps: a value added to the denominator for numerical stability.
69 | Default: 1e-5
70 | momentum: the value used for the running_mean and running_var
71 | computation. Can be set to ``None`` for cumulative moving average
72 | (i.e. simple average). Default: 0.1
73 | affine: a boolean value that when set to ``True``, this module has
74 | learnable affine parameters. Default: ``True``
75 | track_running_stats: a boolean value that when set to ``True``, this
76 | module tracks the running mean and variance, and when set to ``False``,
77 | this module does not track such statistics and always uses batch
78 | statistics in both training and eval modes. Default: ``True``
79 | process_group: synchronization of stats happen within each process group
80 | individually. Default behavior is synchronization across the whole
81 | world
82 |
83 | Shape:
84 | - Input: :math:`(N, C, +)`
85 | - Output: :math:`(N, C, +)` (same shape as input)
86 |
87 | Examples::
88 |
89 | >>> # With Learnable Parameters
90 | >>> m = nn.SyncBatchNorm(100)
91 | >>> # creating process group (optional)
92 | >>> # process_ids is a list of int identifying rank ids.
93 | >>> process_group = torch.distributed.new_group(process_ids)
94 | >>> # Without Learnable Parameters
95 |         >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
96 | >>> input = torch.randn(20, 100, 35, 45, 10)
97 | >>> output = m(input)
98 |
99 | >>> # network is nn.BatchNorm layer
100 | >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
101 | >>> # only single gpu per process is currently supported
102 | >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
103 | >>> sync_bn_network,
104 | >>> device_ids=[args.local_rank],
105 | >>> output_device=args.local_rank)
106 |
107 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
108 | https://arxiv.org/abs/1502.03167
109 | """
110 |
111 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
112 | track_running_stats=True, process_group=None):
113 | super(SyncBatchNormSwish, self).__init__(num_features, eps, momentum, affine, track_running_stats)
114 | self.process_group = process_group
115 | # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
116 | # under supported condition (single GPU per process)
117 | self.ddp_gpu_size = None
118 |
119 | def _check_input_dim(self, input):
120 | if input.dim() < 2:
121 | raise ValueError('expected at least 2D input (got {}D input)'
122 | .format(input.dim()))
123 |
124 | def _specify_ddp_gpu_num(self, gpu_size):
125 | if gpu_size > 1:
126 | raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
127 | self.ddp_gpu_size = gpu_size
128 |
129 | def forward(self, input):
130 | # currently only GPU input is supported
131 | if not input.is_cuda:
132 | raise ValueError('SyncBatchNorm expected input tensor to be on GPU')
133 |
134 | self._check_input_dim(input)
135 |
136 | # exponential_average_factor is set to self.momentum
137 | # (when it is available) only so that it gets updated
138 | # in ONNX graph when this node is exported to ONNX.
139 | if self.momentum is None:
140 | exponential_average_factor = 0.0
141 | else:
142 | exponential_average_factor = self.momentum
143 |
144 | if self.training and self.track_running_stats:
145 | self.num_batches_tracked = self.num_batches_tracked + 1
146 | if self.momentum is None: # use cumulative moving average
147 | exponential_average_factor = 1.0 / self.num_batches_tracked.item()
148 | else: # use exponential moving average
149 | exponential_average_factor = self.momentum
150 |
151 | need_sync = self.training or not self.track_running_stats
152 | if need_sync:
153 | process_group = torch.distributed.group.WORLD
154 | if self.process_group:
155 | process_group = self.process_group
156 | world_size = torch.distributed.get_world_size(process_group)
157 | need_sync = world_size > 1
158 |
159 | # fallback to framework BN when synchronization is not necessary
160 | if not need_sync:
161 | out = F.batch_norm(
162 | input, self.running_mean, self.running_var, self.weight, self.bias,
163 | self.training or not self.track_running_stats,
164 | exponential_average_factor, self.eps)
165 | return swish.apply(out)
166 | else:
167 | # av: I only use it in this setting.
168 | if not self.ddp_gpu_size and False:
169 | raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
170 |
171 | return sync_batch_norm.apply(
172 | input, self.weight, self.bias, self.running_mean, self.running_var,
173 | self.eps, exponential_average_factor, process_group, world_size)
--------------------------------------------------------------------------------
/thirdparty/lsun.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the torchvision library
5 | # which was released under the BSD 3-Clause License.
6 | #
7 | # Source:
8 | # https://github.com/pytorch/vision/blob/ea6b879e90459006e71a164dc76b7e2cc3bff9d9/torchvision/datasets/lsun.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in this directory (LICENSE_torchvision). The modifications
12 | # to this file are subject to the same BSD 3-Clause License.
13 | # ---------------------------------------------------------------
14 |
15 | from torchvision.datasets.vision import VisionDataset
16 | from PIL import Image
17 | import os
18 | import os.path
19 | import io
20 | import string
21 | from collections.abc import Iterable
22 | import pickle
23 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str
24 |
25 |
26 | class LSUNClass(VisionDataset):
27 | def __init__(self, root, transform=None, target_transform=None):
28 | import lmdb
29 | super(LSUNClass, self).__init__(root, transform=transform,
30 | target_transform=target_transform)
31 |
32 | self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
33 | readahead=False, meminit=False)
34 | with self.env.begin(write=False) as txn:
35 | self.length = txn.stat()['entries']
36 | # cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
37 | # av begin
38 | # We only modified the location of cache_file.
39 | cache_file = os.path.join(self.root, '_cache_')
40 | # av end
41 | if os.path.isfile(cache_file):
42 | self.keys = pickle.load(open(cache_file, "rb"))
43 | else:
44 | with self.env.begin(write=False) as txn:
45 | self.keys = [key for key, _ in txn.cursor()]
46 | pickle.dump(self.keys, open(cache_file, "wb"))
47 |
48 | def __getitem__(self, index):
49 | img, target = None, None
50 | env = self.env
51 | with env.begin(write=False) as txn:
52 | imgbuf = txn.get(self.keys[index])
53 |
54 | buf = io.BytesIO()
55 | buf.write(imgbuf)
56 | buf.seek(0)
57 | img = Image.open(buf).convert('RGB')
58 |
59 | if self.transform is not None:
60 | img = self.transform(img)
61 |
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 |
65 | return img, target
66 |
67 | def __len__(self):
68 | return self.length
69 |
70 |
71 | class LSUN(VisionDataset):
72 | """
73 | `LSUN `_ dataset.
74 |
75 | Args:
76 | root (string): Root directory for the database files.
77 | classes (string or list): One of {'train', 'val', 'test'} or a list of
78 |             categories to load. e.g. ['bedroom_train', 'church_outdoor_train'].
79 |         transform (callable, optional): A function/transform that takes in a PIL image
80 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
81 | target_transform (callable, optional): A function/transform that takes in the
82 | target and transforms it.
83 | """
84 |
85 | def __init__(self, root, classes='train', transform=None, target_transform=None):
86 | super(LSUN, self).__init__(root, transform=transform,
87 | target_transform=target_transform)
88 | self.classes = self._verify_classes(classes)
89 |
90 | # for each class, create an LSUNClassDataset
91 | self.dbs = []
92 | for c in self.classes:
93 | self.dbs.append(LSUNClass(
94 | root=root + '/' + c + '_lmdb',
95 | transform=transform))
96 |
97 | self.indices = []
98 | count = 0
99 | for db in self.dbs:
100 | count += len(db)
101 | self.indices.append(count)
102 |
103 | self.length = count
104 |
105 | def _verify_classes(self, classes):
106 | categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
107 | 'conference_room', 'dining_room', 'kitchen',
108 | 'living_room', 'restaurant', 'tower']
109 | dset_opts = ['train', 'val', 'test']
110 |
111 | try:
112 | verify_str_arg(classes, "classes", dset_opts)
113 | if classes == 'test':
114 | classes = [classes]
115 | else:
116 | classes = [c + '_' + classes for c in categories]
117 | except ValueError:
118 | if not isinstance(classes, Iterable):
119 | msg = ("Expected type str or Iterable for argument classes, "
120 | "but got type {}.")
121 | raise ValueError(msg.format(type(classes)))
122 |
123 | classes = list(classes)
124 | msg_fmtstr = ("Expected type str for elements in argument classes, "
125 | "but got type {}.")
126 | for c in classes:
127 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
128 | c_short = c.split('_')
129 | category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
130 |
131 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
132 | msg = msg_fmtstr.format(category, "LSUN class",
133 | iterable_to_str(categories))
134 | verify_str_arg(category, valid_values=categories, custom_msg=msg)
135 |
136 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
137 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
138 |
139 | return classes
140 |
141 | def __getitem__(self, index):
142 | """
143 | Args:
144 | index (int): Index
145 |
146 | Returns:
147 | tuple: Tuple (image, target) where target is the index of the target category.
148 | """
149 | target = 0
150 | sub = 0
151 | for ind in self.indices:
152 | if index < ind:
153 | break
154 | target += 1
155 | sub = ind
156 |
157 | db = self.dbs[target]
158 | index = index - sub
159 |
160 | if self.target_transform is not None:
161 | target = self.target_transform(target)
162 |
163 | img, _ = db[index]
164 | return img, target
165 |
166 | def __len__(self):
167 | return self.length
168 |
169 | def extra_repr(self):
170 | return "Classes: {classes}".format(**self.__dict__)
171 |
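A minimal usage sketch of the dataset class above; the root path and category are illustrative, and the corresponding <category>_<split>_lmdb directory must already exist under root:

from torchvision import transforms
from thirdparty.lsun import LSUN

lsun_data = LSUN(root='/data1/datasets/lsun', classes=['church_outdoor_train'],
                 transform=transforms.Compose([transforms.Resize(256),
                                               transforms.CenterCrop(256),
                                               transforms.ToTensor()]))
img, target = lsun_data[0]  # target is the index of the category in classes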
--------------------------------------------------------------------------------
/thirdparty/swish.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the following repo
5 | # (released under the Apache License 2.0).
6 | #
7 | # Source:
8 | # https://github.com/ceshine/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in this directory (LICENSE_apache).
12 | # ---------------------------------------------------------------
13 |
14 | import torch
15 |
16 |
17 | class Swish(torch.autograd.Function):
18 | @staticmethod
19 | def forward(ctx, i):
20 | result = i * torch.sigmoid(i)
21 | ctx.save_for_backward(i)
22 | return result
23 |
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | i = ctx.saved_variables[0]
27 | sigmoid_i = torch.sigmoid(i)
28 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
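The backward above uses the closed-form derivative of x * sigmoid(x), namely sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is the same expression re-derived inside thirdparty/functions.py. A quick sketch checking the custom Function against numerical differentiation (gradcheck requires double-precision inputs):

import torch
from thirdparty.swish import Swish

x = torch.randn(32, dtype=torch.double, requires_grad=True)
# verifies both the forward and the hand-written backward of Swish
assert torch.autograd.gradcheck(Swish.apply, (x,))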
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import argparse
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | import os
13 |
14 | import torch.distributed as dist
15 | from torch.multiprocessing import Process
16 | from torch.cuda.amp import autocast, GradScaler
17 |
18 | from model import AutoEncoder
19 | from thirdparty.adamax import Adamax
20 | import utils
21 | import datasets
22 |
23 | from fid.fid_score import compute_statistics_of_generator, load_statistics, calculate_frechet_distance
24 | from fid.inception import InceptionV3
25 |
26 |
27 | def main(args):
28 | # ensures that weight initializations are all the same
29 | torch.manual_seed(args.seed)
30 | np.random.seed(args.seed)
31 | torch.cuda.manual_seed(args.seed)
32 | torch.cuda.manual_seed_all(args.seed)
33 |
34 | logging = utils.Logger(args.global_rank, args.save)
35 | writer = utils.Writer(args.global_rank, args.save)
36 |
37 | # Get data loaders.
38 | train_queue, valid_queue, num_classes = datasets.get_loaders(args)
39 | args.num_total_iter = len(train_queue) * args.epochs
40 | warmup_iters = len(train_queue) * args.warmup_epochs
41 | swa_start = len(train_queue) * (args.epochs - 1)
42 |
43 | arch_instance = utils.get_arch_cells(args.arch_instance)
44 |
45 | model = AutoEncoder(args, writer, arch_instance)
46 | model = model.cuda()
47 |
48 | logging.info('args = %s', args)
49 | logging.info('param size = %fM ', utils.count_parameters_in_M(model))
50 | logging.info('groups per scale: %s, total_groups: %d', model.groups_per_scale, sum(model.groups_per_scale))
51 |
52 | if args.fast_adamax:
53 | # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster.
54 | cnn_optimizer = Adamax(model.parameters(), args.learning_rate,
55 | weight_decay=args.weight_decay, eps=1e-3)
56 | else:
57 | cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate,
58 | weight_decay=args.weight_decay, eps=1e-3)
59 |
60 | cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
61 | cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min)
62 | grad_scalar = GradScaler(2**10)
63 |
64 | num_output = utils.num_output(args.dataset)
65 | bpd_coeff = 1. / np.log(2.) / num_output
66 |
67 |     # load from a checkpoint if we are resuming a previous training run.
68 | checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
69 | if args.cont_training:
70 | logging.info('loading the model.')
71 | checkpoint = torch.load(checkpoint_file, map_location='cpu')
72 | init_epoch = checkpoint['epoch']
73 | model.load_state_dict(checkpoint['state_dict'])
74 | model = model.cuda()
75 | cnn_optimizer.load_state_dict(checkpoint['optimizer'])
76 | grad_scalar.load_state_dict(checkpoint['grad_scalar'])
77 | cnn_scheduler.load_state_dict(checkpoint['scheduler'])
78 | global_step = checkpoint['global_step']
79 | else:
80 | global_step, init_epoch = 0, 0
81 |
82 | for epoch in range(init_epoch, args.epochs):
83 | # update lrs.
84 | if args.distributed:
85 | train_queue.sampler.set_epoch(global_step + args.seed)
86 | valid_queue.sampler.set_epoch(0)
87 |
88 | if epoch > args.warmup_epochs:
89 | cnn_scheduler.step()
90 |
91 | # Logging.
92 | logging.info('epoch %d', epoch)
93 |
94 | # Training.
95 | train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging)
96 | logging.info('train_nelbo %f', train_nelbo)
97 | writer.add_scalar('train/nelbo', train_nelbo, global_step)
98 |
99 | model.eval()
100 |         # evaluate and generate samples less frequently for long runs
101 | eval_freq = 1 if args.epochs <= 50 else 20
102 | if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
103 | with torch.no_grad():
104 | num_samples = 16
105 | n = int(np.floor(np.sqrt(num_samples)))
106 | for t in [0.7, 0.8, 0.9, 1.0]:
107 | logits = model.sample(num_samples, t)
108 | output = model.decoder_output(logits)
109 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.sample(t)
110 | output_tiled = utils.tile_image(output_img, n)
111 | writer.add_image('generated_%0.1f' % t, output_tiled, global_step)
112 |
113 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging)
114 | logging.info('valid_nelbo %f', valid_nelbo)
115 | logging.info('valid neg log p %f', valid_neg_log_p)
116 | logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
117 | logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)
118 | writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch)
119 | writer.add_scalar('val/nelbo', valid_nelbo, epoch)
120 | writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch)
121 | writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch)
122 |
123 | save_freq = int(np.ceil(args.epochs / 100))
124 | if epoch % save_freq == 0 or epoch == (args.epochs - 1):
125 | if args.global_rank == 0:
126 | logging.info('saving the model.')
127 | torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(),
128 | 'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step,
129 | 'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(),
130 | 'grad_scalar': grad_scalar.state_dict()}, checkpoint_file)
131 |
132 | # Final validation
133 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging)
134 | logging.info('final valid nelbo %f', valid_nelbo)
135 | logging.info('final valid neg log p %f', valid_neg_log_p)
136 | writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1)
137 | writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1)
138 | writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1)
139 | writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1)
140 | writer.close()
141 |
142 |
143 | def train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging):
144 | alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
145 | groups_per_scale=model.groups_per_scale, fun='square')
146 | nelbo = utils.AvgrageMeter()
147 | model.train()
148 | for step, x in enumerate(train_queue):
149 | x = x[0] if len(x) > 1 else x
150 | x = x.cuda()
151 |
152 | # change bit length
153 | x = utils.pre_process(x, args.num_x_bits)
154 |
155 | # warm-up lr
156 | if global_step < warmup_iters:
157 | lr = args.learning_rate * float(global_step) / warmup_iters
158 | for param_group in cnn_optimizer.param_groups:
159 | param_group['lr'] = lr
160 |
161 |         # periodically sync parameters across processes; this may not be strictly necessary.
162 | if step % 100 == 0:
163 | utils.average_params(model.parameters(), args.distributed)
164 |
165 | cnn_optimizer.zero_grad()
166 | with autocast():
167 | logits, log_q, log_p, kl_all, kl_diag = model(x)
168 |
169 | output = model.decoder_output(logits)
170 | kl_coeff = utils.kl_coeff(global_step, args.kl_anneal_portion * args.num_total_iter,
171 | args.kl_const_portion * args.num_total_iter, args.kl_const_coeff)
172 |
173 | recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output)
174 | balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)
175 |
176 | nelbo_batch = recon_loss + balanced_kl
177 | loss = torch.mean(nelbo_batch)
178 | norm_loss = model.spectral_norm_parallel()
179 | bn_loss = model.batchnorm_loss()
180 | # get spectral regularization coefficient (lambda)
181 | if args.weight_decay_norm_anneal:
182 | assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
183 | wdn_coeff = (1. - kl_coeff) * np.log(args.weight_decay_norm_init) + kl_coeff * np.log(args.weight_decay_norm)
184 | wdn_coeff = np.exp(wdn_coeff)
185 | else:
186 | wdn_coeff = args.weight_decay_norm
187 |
188 | loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff
189 |
190 | grad_scalar.scale(loss).backward()
191 | utils.average_gradients(model.parameters(), args.distributed)
192 | grad_scalar.step(cnn_optimizer)
193 | grad_scalar.update()
194 | nelbo.update(loss.data, 1)
195 |
196 | if (global_step + 1) % 100 == 0:
197 | if (global_step + 1) % 1000 == 0: # reduced frequency
198 | n = int(np.floor(np.sqrt(x.size(0))))
199 | x_img = x[:n*n]
200 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.sample()
201 | output_img = output_img[:n*n]
202 | x_tiled = utils.tile_image(x_img, n)
203 | output_tiled = utils.tile_image(output_img, n)
204 | in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2)
205 | writer.add_image('reconstruction', in_out_tiled, global_step)
206 |
207 | # norm
208 | writer.add_scalar('train/norm_loss', norm_loss, global_step)
209 | writer.add_scalar('train/bn_loss', bn_loss, global_step)
210 | writer.add_scalar('train/norm_coeff', wdn_coeff, global_step)
211 |
212 | utils.average_tensor(nelbo.avg, args.distributed)
213 | logging.info('train %d %f', global_step, nelbo.avg)
214 | writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step)
215 | writer.add_scalar('train/lr', cnn_optimizer.state_dict()[
216 | 'param_groups'][0]['lr'], global_step)
217 | writer.add_scalar('train/nelbo_iter', loss, global_step)
218 | writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)), global_step)
219 | writer.add_scalar('train/recon_iter', torch.mean(utils.reconstruction_loss(output, x, crop=model.crop_output)), global_step)
220 | writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step)
221 | total_active = 0
222 | for i, kl_diag_i in enumerate(kl_diag):
223 | utils.average_tensor(kl_diag_i, args.distributed)
224 | num_active = torch.sum(kl_diag_i > 0.1).detach()
225 | total_active += num_active
226 |
227 |                 # kl_coeff
228 | writer.add_scalar('kl/active_%d' % i, num_active, global_step)
229 | writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i], global_step)
230 | writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i], global_step)
231 | writer.add_scalar('kl/total_active', total_active, global_step)
232 |
233 | global_step += 1
234 |
235 | utils.average_tensor(nelbo.avg, args.distributed)
236 | return nelbo.avg, global_step
237 |
238 |
239 | def test(valid_queue, model, num_samples, args, logging):
240 | if args.distributed:
241 | dist.barrier()
242 | nelbo_avg = utils.AvgrageMeter()
243 | neg_log_p_avg = utils.AvgrageMeter()
244 | model.eval()
245 | for step, x in enumerate(valid_queue):
246 | x = x[0] if len(x) > 1 else x
247 | x = x.cuda()
248 |
249 | # change bit length
250 | x = utils.pre_process(x, args.num_x_bits)
251 |
252 | with torch.no_grad():
253 | nelbo, log_iw = [], []
254 | for k in range(num_samples):
255 | logits, log_q, log_p, kl_all, _ = model(x)
256 | output = model.decoder_output(logits)
257 | recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output)
258 | balanced_kl, _, _ = utils.kl_balancer(kl_all, kl_balance=False)
259 | nelbo_batch = recon_loss + balanced_kl
260 | nelbo.append(nelbo_batch)
261 | log_iw.append(utils.log_iw(output, x, log_q, log_p, crop=model.crop_output))
262 |
263 | nelbo = torch.mean(torch.stack(nelbo, dim=1))
264 | log_p = torch.mean(torch.logsumexp(torch.stack(log_iw, dim=1), dim=1) - np.log(num_samples))
265 |
266 | nelbo_avg.update(nelbo.data, x.size(0))
267 | neg_log_p_avg.update(- log_p.data, x.size(0))
268 |
269 | utils.average_tensor(nelbo_avg.avg, args.distributed)
270 | utils.average_tensor(neg_log_p_avg.avg, args.distributed)
271 | if args.distributed:
272 | # block to sync
273 | dist.barrier()
274 | logging.info('val, step: %d, NELBO: %f, neg Log p %f', step, nelbo_avg.avg, neg_log_p_avg.avg)
275 | return neg_log_p_avg.avg, nelbo_avg.avg
276 |
277 |
278 | def create_generator_vae(model, batch_size, num_total_samples):
279 | num_iters = int(np.ceil(num_total_samples / batch_size))
280 | for i in range(num_iters):
281 | with torch.no_grad():
282 | logits = model.sample(batch_size, 1.0)
283 | output = model.decoder_output(logits)
284 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.mean()
285 | yield output_img.float()
286 |
287 |
288 | def test_vae_fid(model, args, total_fid_samples):
289 | dims = 2048
290 | device = 'cuda'
291 | num_gpus = args.num_process_per_node * args.num_proc_node
292 | num_sample_per_gpu = int(np.ceil(total_fid_samples / num_gpus))
293 |
294 | g = create_generator_vae(model, args.batch_size, num_sample_per_gpu)
295 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
296 | model = InceptionV3([block_idx], model_dir=args.fid_dir).to(device)
297 | m, s = compute_statistics_of_generator(g, model, args.batch_size, dims, device, max_samples=num_sample_per_gpu)
298 |
299 | # share m and s
300 | m = torch.from_numpy(m).cuda()
301 | s = torch.from_numpy(s).cuda()
302 | # take average across gpus
303 | utils.average_tensor(m, args.distributed)
304 | utils.average_tensor(s, args.distributed)
305 |
306 |     # convert m, s back to numpy arrays
307 | m = m.cpu().numpy()
308 | s = s.cpu().numpy()
309 |
310 | # load precomputed m, s
311 | path = os.path.join(args.fid_dir, args.dataset + '.npz')
312 | m0, s0 = load_statistics(path)
313 |
314 | fid = calculate_frechet_distance(m0, s0, m, s)
315 | return fid
316 |
317 |
318 | def init_processes(rank, size, fn, args):
319 | """ Initialize the distributed environment. """
320 | os.environ['MASTER_ADDR'] = args.master_address
321 | os.environ['MASTER_PORT'] = '6020'
322 | torch.cuda.set_device(args.local_rank)
323 | dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size)
324 | fn(args)
325 | cleanup()
326 |
327 |
328 | def cleanup():
329 | dist.destroy_process_group()
330 |
331 |
332 | if __name__ == '__main__':
333 | parser = argparse.ArgumentParser('encoder decoder examiner')
334 | # experimental results
335 | parser.add_argument('--root', type=str, default='/tmp/nasvae/expr',
336 | help='location of the results')
337 | parser.add_argument('--save', type=str, default='exp',
338 | help='id used for storing intermediate results')
339 | # data
340 | parser.add_argument('--dataset', type=str, default='mnist',
341 | choices=['cifar10', 'mnist', 'omniglot', 'celeba_64', 'celeba_256',
342 | 'imagenet_32', 'ffhq', 'lsun_bedroom_128', 'stacked_mnist',
343 | 'lsun_church_128', 'lsun_church_64'],
344 | help='which dataset to use')
345 | parser.add_argument('--data', type=str, default='/tmp/nasvae/data',
346 | help='location of the data corpus')
347 | # optimization
348 | parser.add_argument('--batch_size', type=int, default=200,
349 | help='batch size per GPU')
350 | parser.add_argument('--learning_rate', type=float, default=1e-2,
351 | help='init learning rate')
352 | parser.add_argument('--learning_rate_min', type=float, default=1e-4,
353 | help='min learning rate')
354 | parser.add_argument('--weight_decay', type=float, default=3e-4,
355 | help='weight decay')
356 | parser.add_argument('--weight_decay_norm', type=float, default=0.,
357 | help='The lambda parameter for spectral regularization.')
358 | parser.add_argument('--weight_decay_norm_init', type=float, default=10.,
359 | help='The initial lambda parameter')
360 | parser.add_argument('--weight_decay_norm_anneal', action='store_true', default=False,
361 | help='This flag enables annealing the lambda coefficient from '
362 | '--weight_decay_norm_init to --weight_decay_norm.')
363 | parser.add_argument('--epochs', type=int, default=200,
364 | help='num of training epochs')
365 | parser.add_argument('--warmup_epochs', type=int, default=5,
366 | help='num of training epochs in which lr is warmed up')
367 | parser.add_argument('--fast_adamax', action='store_true', default=False,
368 | help='This flag enables using our optimized adamax.')
369 | parser.add_argument('--arch_instance', type=str, default='res_mbconv',
370 |                         help='architecture instance to use (a key understood by utils.get_arch_cells)')
371 | # KL annealing
372 | parser.add_argument('--kl_anneal_portion', type=float, default=0.3,
373 |                         help='The portion of training during which the KL coefficient is annealed')
374 | parser.add_argument('--kl_const_portion', type=float, default=0.0001,
375 |                         help='The portion of training during which the KL coefficient is held constant at kl_const_coeff')
376 | parser.add_argument('--kl_const_coeff', type=float, default=0.0001,
377 | help='The constant value used for min KL coeff')
378 | # Flow params
379 | parser.add_argument('--num_nf', type=int, default=0,
380 |                         help='The number of normalizing flow cells per group. Set this to zero to disable flows.')
381 | parser.add_argument('--num_x_bits', type=int, default=8,
382 | help='The number of bits used for representing data for colored images.')
383 | # latent variables
384 | parser.add_argument('--num_latent_scales', type=int, default=1,
385 | help='the number of latent scales')
386 | parser.add_argument('--num_groups_per_scale', type=int, default=10,
387 | help='number of groups of latent variables per scale')
388 | parser.add_argument('--num_latent_per_group', type=int, default=20,
389 | help='number of channels in latent variables per group')
390 | parser.add_argument('--ada_groups', action='store_true', default=False,
391 |                         help='Setting this to true will use a different number of groups per scale.')
392 | parser.add_argument('--min_groups_per_scale', type=int, default=1,
393 | help='the minimum number of groups per scale.')
394 | # encoder parameters
395 | parser.add_argument('--num_channels_enc', type=int, default=32,
396 | help='number of channels in encoder')
397 | parser.add_argument('--num_preprocess_blocks', type=int, default=2,
398 | help='number of preprocessing blocks')
399 | parser.add_argument('--num_preprocess_cells', type=int, default=3,
400 | help='number of cells per block')
401 | parser.add_argument('--num_cell_per_cond_enc', type=int, default=1,
402 |                         help='number of cells for each conditional in the encoder')
403 | # decoder parameters
404 | parser.add_argument('--num_channels_dec', type=int, default=32,
405 | help='number of channels in decoder')
406 | parser.add_argument('--num_postprocess_blocks', type=int, default=2,
407 | help='number of postprocessing blocks')
408 | parser.add_argument('--num_postprocess_cells', type=int, default=3,
409 | help='number of cells per block')
410 | parser.add_argument('--num_cell_per_cond_dec', type=int, default=1,
411 |                         help='number of cells for each conditional in the decoder')
412 | parser.add_argument('--num_mixture_dec', type=int, default=10,
413 |                         help='number of mixture components in the decoder. Set to 1 for a Normal decoder.')
414 | # NAS
415 | parser.add_argument('--use_se', action='store_true', default=False,
416 | help='This flag enables squeeze and excitation.')
417 | parser.add_argument('--res_dist', action='store_true', default=False,
418 |                         help='This flag enables the residual distribution parameterization in the decoder.')
419 | parser.add_argument('--cont_training', action='store_true', default=False,
420 | help='This flag enables training from an existing checkpoint.')
421 | # DDP.
422 | parser.add_argument('--num_proc_node', type=int, default=1,
423 | help='The number of nodes in multi node env.')
424 | parser.add_argument('--node_rank', type=int, default=0,
425 | help='The index of node.')
426 | parser.add_argument('--local_rank', type=int, default=0,
427 | help='rank of process in the node')
428 | parser.add_argument('--global_rank', type=int, default=0,
429 | help='rank of process among all the processes')
430 | parser.add_argument('--num_process_per_node', type=int, default=1,
431 |                         help='number of processes (GPUs) per node')
432 | parser.add_argument('--master_address', type=str, default='127.0.0.1',
433 | help='address for master')
434 | parser.add_argument('--seed', type=int, default=1,
435 | help='seed used for initialization')
436 | args = parser.parse_args()
437 | args.save = args.root + '/eval-' + args.save
438 | utils.create_exp_dir(args.save)
439 |
440 | size = args.num_process_per_node
441 |
442 | if size > 1:
443 | args.distributed = True
444 | processes = []
445 | for rank in range(size):
446 | args.local_rank = rank
447 | global_rank = rank + args.node_rank * args.num_process_per_node
448 | global_size = args.num_proc_node * args.num_process_per_node
449 | args.global_rank = global_rank
450 | print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank))
451 | p = Process(target=init_processes, args=(global_rank, global_size, main, args))
452 | p.start()
453 | processes.append(p)
454 |
455 | for p in processes:
456 | p.join()
457 | else:
458 | # for debugging
459 | print('starting in debug mode')
460 | args.distributed = True
461 | init_processes(0, size, main, args)
462 |
463 |
464 |
--------------------------------------------------------------------------------
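
For orientation, here is a minimal sketch (not part of the repo) of the importance-weighted estimate that `test()` above computes: with K posterior samples per image, log p(x) is approximated by logsumexp_k(log w_k) - log K, where log w_k = log p(x, z_k) - log q(z_k | x). The `iw_log_p` helper and the random `log_iw` values are illustrative only; the repo obtains the log weights from `utils.log_iw`.

import math
import torch

def iw_log_p(log_iw):
    # log_iw: [batch, K] log importance weights, log w_k = log p(x, z_k) - log q(z_k | x)
    k = log_iw.size(1)
    return torch.logsumexp(log_iw, dim=1) - math.log(k)

log_iw = torch.randn(8, 1000)          # stand-in values; the repo stacks utils.log_iw(...) outputs
print(iw_log_p(log_iw).mean().item())  # mirrors torch.mean(logsumexp(...) - log(num_samples)) in test()
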
/utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import logging
9 | import os
10 | import shutil
11 | import time
12 | from datetime import timedelta
13 | import sys
14 |
15 | import torch
16 | import torch.nn as nn
17 | import numpy as np
18 | import torch.distributed as dist
19 |
20 | import torch.nn.functional as F
21 | from tensorboardX import SummaryWriter
22 |
23 |
24 | class AvgrageMeter(object):
25 |
26 | def __init__(self):
27 | self.reset()
28 |
29 | def reset(self):
30 | self.avg = 0
31 | self.sum = 0
32 | self.cnt = 0
33 |
34 | def update(self, val, n=1):
35 | self.sum += val * n
36 | self.cnt += n
37 | self.avg = self.sum / self.cnt
38 |
39 |
40 | class ExpMovingAvgrageMeter(object):
41 |
42 | def __init__(self, momentum=0.9):
43 | self.momentum = momentum
44 | self.reset()
45 |
46 | def reset(self):
47 | self.avg = 0
48 |
49 | def update(self, val):
50 | self.avg = (1. - self.momentum) * self.avg + self.momentum * val
51 |
52 |
53 | class DummyDDP(nn.Module):
54 | def __init__(self, model):
55 | super(DummyDDP, self).__init__()
56 | self.module = model
57 |
58 | def forward(self, *input, **kwargs):
59 | return self.module(*input, **kwargs)
60 |
61 |
62 | def count_parameters_in_M(model):
63 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
64 |
65 |
66 | def save_checkpoint(state, is_best, save):
67 | filename = os.path.join(save, 'checkpoint.pth.tar')
68 | torch.save(state, filename)
69 | if is_best:
70 | best_filename = os.path.join(save, 'model_best.pth.tar')
71 | shutil.copyfile(filename, best_filename)
72 |
73 |
74 | def save(model, model_path):
75 | torch.save(model.state_dict(), model_path)
76 |
77 |
78 | def load(model, model_path):
79 | model.load_state_dict(torch.load(model_path))
80 |
81 |
82 | def create_exp_dir(path, scripts_to_save=None):
83 | if not os.path.exists(path):
84 | os.makedirs(path, exist_ok=True)
85 | print('Experiment dir : {}'.format(path))
86 |
87 | if scripts_to_save is not None:
88 | if not os.path.exists(os.path.join(path, 'scripts')):
89 | os.mkdir(os.path.join(path, 'scripts'))
90 | for script in scripts_to_save:
91 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
92 | shutil.copyfile(script, dst_file)
93 |
94 |
95 | class Logger(object):
96 | def __init__(self, rank, save):
97 | # other libraries may set logging before arriving at this line.
98 | # by reloading logging, we can get rid of previous configs set by other libraries.
99 | from importlib import reload
100 | reload(logging)
101 | self.rank = rank
102 | if self.rank == 0:
103 | log_format = '%(asctime)s %(message)s'
104 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
105 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
106 | fh = logging.FileHandler(os.path.join(save, 'log.txt'))
107 | fh.setFormatter(logging.Formatter(log_format))
108 | logging.getLogger().addHandler(fh)
109 | self.start_time = time.time()
110 |
111 | def info(self, string, *args):
112 | if self.rank == 0:
113 | elapsed_time = time.time() - self.start_time
114 | elapsed_time = time.strftime(
115 | '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time))
116 | if isinstance(string, str):
117 | string = elapsed_time + string
118 | else:
119 | logging.info(elapsed_time)
120 | logging.info(string, *args)
121 |
122 |
123 | class Writer(object):
124 | def __init__(self, rank, save):
125 | self.rank = rank
126 | if self.rank == 0:
127 | self.writer = SummaryWriter(log_dir=save, flush_secs=20)
128 |
129 | def add_scalar(self, *args, **kwargs):
130 | if self.rank == 0:
131 | self.writer.add_scalar(*args, **kwargs)
132 |
133 | def add_figure(self, *args, **kwargs):
134 | if self.rank == 0:
135 | self.writer.add_figure(*args, **kwargs)
136 |
137 | def add_image(self, *args, **kwargs):
138 | if self.rank == 0:
139 | self.writer.add_image(*args, **kwargs)
140 |
141 | def add_histogram(self, *args, **kwargs):
142 | if self.rank == 0:
143 | self.writer.add_histogram(*args, **kwargs)
144 |
145 | def add_histogram_if(self, write, *args, **kwargs):
146 | if write and False: # Used for debugging.
147 | self.add_histogram(*args, **kwargs)
148 |
149 | def close(self, *args, **kwargs):
150 | if self.rank == 0:
151 | self.writer.close()
152 |
153 |
154 | def reduce_tensor(tensor, world_size):
155 | rt = tensor.clone()
156 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
157 | rt /= world_size
158 | return rt
159 |
160 |
161 | def get_stride_for_cell_type(cell_type):
162 | if cell_type.startswith('normal') or cell_type.startswith('combiner'):
163 | stride = 1
164 | elif cell_type.startswith('down'):
165 | stride = 2
166 | elif cell_type.startswith('up'):
167 | stride = -1
168 | else:
169 | raise NotImplementedError(cell_type)
170 |
171 | return stride
172 |
173 |
174 | def get_cout(cin, stride):
175 | if stride == 1:
176 | cout = cin
177 | elif stride == -1:
178 | cout = cin // 2
179 | elif stride == 2:
180 | cout = 2 * cin
181 |
182 | return cout
183 |
184 |
185 | def kl_balancer_coeff(num_scales, groups_per_scale, fun):
186 | if fun == 'equal':
187 | coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
188 | elif fun == 'linear':
189 | coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
190 | elif fun == 'sqrt':
191 | coeff = torch.cat([np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
192 | elif fun == 'square':
193 | coeff = torch.cat([np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
194 | else:
195 | raise NotImplementedError
196 |     # normalize so that the smallest coefficient is 1.
197 | coeff /= torch.min(coeff)
198 | return coeff
199 |
200 |
201 | def kl_per_group(kl_all):
202 | kl_vals = torch.mean(kl_all, dim=0)
203 | kl_coeff_i = torch.abs(kl_all)
204 | kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01
205 |
206 | return kl_coeff_i, kl_vals
207 |
208 |
209 | def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None):
210 | if kl_balance and kl_coeff < 1.0:
211 | alpha_i = alpha_i.unsqueeze(0)
212 |
213 | kl_all = torch.stack(kl_all, dim=1)
214 | kl_coeff_i, kl_vals = kl_per_group(kl_all)
215 | total_kl = torch.sum(kl_coeff_i)
216 |
217 | kl_coeff_i = kl_coeff_i / alpha_i * total_kl
218 | kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True)
219 | kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1)
220 |
221 | # for reporting
222 | kl_coeffs = kl_coeff_i.squeeze(0)
223 | else:
224 | kl_all = torch.stack(kl_all, dim=1)
225 | kl_vals = torch.mean(kl_all, dim=0)
226 | kl = torch.sum(kl_all, dim=1)
227 | kl_coeffs = torch.ones(size=(len(kl_vals),))
228 |
229 | return kl_coeff * kl, kl_coeffs, kl_vals
230 |
231 |
232 | def kl_coeff(step, total_step, constant_step, min_kl_coeff):
233 | return max(min((step - constant_step) / total_step, 1.0), min_kl_coeff)
234 |
235 |
236 | def log_iw(decoder, x, log_q, log_p, crop=False):
237 | recon = reconstruction_loss(decoder, x, crop)
238 | return - recon - log_q + log_p
239 |
240 |
241 | def reconstruction_loss(decoder, x, crop=False):
242 | from distributions import Normal, DiscMixLogistic
243 |
244 | recon = decoder.log_prob(x)
245 | if crop:
246 | recon = recon[:, :, 2:30, 2:30]
247 |
248 | if isinstance(decoder, DiscMixLogistic):
249 | return - torch.sum(recon, dim=[1, 2]) # summation over RGB is done.
250 | else:
251 | return - torch.sum(recon, dim=[1, 2, 3])
252 |
253 |
254 | def tile_image(batch_image, n):
255 | assert n * n == batch_image.size(0)
256 | channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3)
257 | batch_image = batch_image.view(n, n, channels, height, width)
258 |     batch_image = batch_image.permute(2, 0, 3, 1, 4)  # channels, n, height, n, width
259 | batch_image = batch_image.contiguous().view(channels, n * height, n * width)
260 | return batch_image
261 |
262 |
263 | def average_gradients(params, is_distributed):
264 | """ Gradient averaging. """
265 | if is_distributed:
266 | size = float(dist.get_world_size())
267 | for param in params:
268 | if param.requires_grad:
269 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
270 | param.grad.data /= size
271 |
272 |
273 | def average_params(params, is_distributed):
274 | """ parameter averaging. """
275 | if is_distributed:
276 | size = float(dist.get_world_size())
277 | for param in params:
278 | dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
279 | param.data /= size
280 |
281 |
282 | def average_tensor(t, is_distributed):
283 | if is_distributed:
284 | size = float(dist.get_world_size())
285 | dist.all_reduce(t.data, op=dist.ReduceOp.SUM)
286 | t.data /= size
287 |
288 |
289 | def one_hot(indices, depth, dim):
290 | indices = indices.unsqueeze(dim)
291 | size = list(indices.size())
292 | size[dim] = depth
293 | y_onehot = torch.zeros(size).cuda()
294 | y_onehot.zero_()
295 | y_onehot.scatter_(dim, indices, 1)
296 |
297 | return y_onehot
298 |
299 |
300 | def num_output(dataset):
301 | if dataset in {'mnist', 'omniglot'}:
302 | return 28 * 28
303 | elif dataset == 'cifar10':
304 | return 3 * 32 * 32
305 | elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
306 | size = int(dataset.split('_')[-1])
307 | return 3 * size * size
308 | elif dataset == 'ffhq':
309 | return 3 * 256 * 256
310 | else:
311 | raise NotImplementedError
312 |
313 |
314 | def get_input_size(dataset):
315 | if dataset in {'mnist', 'omniglot'}:
316 | return 32
317 | elif dataset == 'cifar10':
318 | return 32
319 | elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
320 | size = int(dataset.split('_')[-1])
321 | return size
322 | elif dataset == 'ffhq':
323 | return 256
324 | else:
325 | raise NotImplementedError
326 |
327 |
328 | def pre_process(x, num_bits):
329 | if num_bits != 8:
330 | x = torch.floor(x * 255 / 2 ** (8 - num_bits))
331 | x /= (2 ** num_bits - 1)
332 | return x
333 |
334 |
335 | def get_arch_cells(arch_type):
336 | if arch_type == 'res_elu':
337 | arch_cells = dict()
338 | arch_cells['normal_enc'] = ['res_elu', 'res_elu']
339 | arch_cells['down_enc'] = ['res_elu', 'res_elu']
340 | arch_cells['normal_dec'] = ['res_elu', 'res_elu']
341 | arch_cells['up_dec'] = ['res_elu', 'res_elu']
342 | arch_cells['normal_pre'] = ['res_elu', 'res_elu']
343 | arch_cells['down_pre'] = ['res_elu', 'res_elu']
344 | arch_cells['normal_post'] = ['res_elu', 'res_elu']
345 | arch_cells['up_post'] = ['res_elu', 'res_elu']
346 | arch_cells['ar_nn'] = ['']
347 | elif arch_type == 'res_bnelu':
348 | arch_cells = dict()
349 | arch_cells['normal_enc'] = ['res_bnelu', 'res_bnelu']
350 | arch_cells['down_enc'] = ['res_bnelu', 'res_bnelu']
351 | arch_cells['normal_dec'] = ['res_bnelu', 'res_bnelu']
352 | arch_cells['up_dec'] = ['res_bnelu', 'res_bnelu']
353 | arch_cells['normal_pre'] = ['res_bnelu', 'res_bnelu']
354 | arch_cells['down_pre'] = ['res_bnelu', 'res_bnelu']
355 | arch_cells['normal_post'] = ['res_bnelu', 'res_bnelu']
356 | arch_cells['up_post'] = ['res_bnelu', 'res_bnelu']
357 | arch_cells['ar_nn'] = ['']
358 | elif arch_type == 'res_bnswish':
359 | arch_cells = dict()
360 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
361 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
362 | arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
363 | arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
364 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
365 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
366 | arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
367 | arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
368 | arch_cells['ar_nn'] = ['']
369 | elif arch_type == 'mbconv_sep':
370 | arch_cells = dict()
371 | arch_cells['normal_enc'] = ['mconv_e6k5g0']
372 | arch_cells['down_enc'] = ['mconv_e6k5g0']
373 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
374 | arch_cells['up_dec'] = ['mconv_e6k5g0']
375 | arch_cells['normal_pre'] = ['mconv_e3k5g0']
376 | arch_cells['down_pre'] = ['mconv_e3k5g0']
377 | arch_cells['normal_post'] = ['mconv_e3k5g0']
378 | arch_cells['up_post'] = ['mconv_e3k5g0']
379 | arch_cells['ar_nn'] = ['']
380 | elif arch_type == 'mbconv_sep11':
381 | arch_cells = dict()
382 | arch_cells['normal_enc'] = ['mconv_e6k11g0']
383 | arch_cells['down_enc'] = ['mconv_e6k11g0']
384 | arch_cells['normal_dec'] = ['mconv_e6k11g0']
385 | arch_cells['up_dec'] = ['mconv_e6k11g0']
386 | arch_cells['normal_pre'] = ['mconv_e3k5g0']
387 | arch_cells['down_pre'] = ['mconv_e3k5g0']
388 | arch_cells['normal_post'] = ['mconv_e3k5g0']
389 | arch_cells['up_post'] = ['mconv_e3k5g0']
390 | arch_cells['ar_nn'] = ['']
391 | elif arch_type == 'res_mbconv':
392 | arch_cells = dict()
393 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
394 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
395 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
396 | arch_cells['up_dec'] = ['mconv_e6k5g0']
397 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
398 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
399 | arch_cells['normal_post'] = ['mconv_e3k5g0']
400 | arch_cells['up_post'] = ['mconv_e3k5g0']
401 | arch_cells['ar_nn'] = ['']
402 | elif arch_type == 'res53_mbconv':
403 | arch_cells = dict()
404 | arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish']
405 | arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish']
406 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
407 | arch_cells['up_dec'] = ['mconv_e6k5g0']
408 | arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish']
409 | arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish']
410 | arch_cells['normal_post'] = ['mconv_e3k5g0']
411 | arch_cells['up_post'] = ['mconv_e3k5g0']
412 | arch_cells['ar_nn'] = ['']
413 | elif arch_type == 'res35_mbconv':
414 | arch_cells = dict()
415 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish5']
416 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish5']
417 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
418 | arch_cells['up_dec'] = ['mconv_e6k5g0']
419 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish5']
420 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish5']
421 | arch_cells['normal_post'] = ['mconv_e3k5g0']
422 | arch_cells['up_post'] = ['mconv_e3k5g0']
423 | arch_cells['ar_nn'] = ['']
424 | elif arch_type == 'res55_mbconv':
425 | arch_cells = dict()
426 | arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish5']
427 | arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish5']
428 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
429 | arch_cells['up_dec'] = ['mconv_e6k5g0']
430 | arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish5']
431 | arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish5']
432 | arch_cells['normal_post'] = ['mconv_e3k5g0']
433 | arch_cells['up_post'] = ['mconv_e3k5g0']
434 | arch_cells['ar_nn'] = ['']
435 | elif arch_type == 'res_mbconv9':
436 | arch_cells = dict()
437 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
438 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
439 | arch_cells['normal_dec'] = ['mconv_e6k9g0']
440 | arch_cells['up_dec'] = ['mconv_e6k9g0']
441 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
442 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
443 | arch_cells['normal_post'] = ['mconv_e3k9g0']
444 | arch_cells['up_post'] = ['mconv_e3k9g0']
445 | arch_cells['ar_nn'] = ['']
446 | elif arch_type == 'mbconv_res':
447 | arch_cells = dict()
448 | arch_cells['normal_enc'] = ['mconv_e6k5g0']
449 | arch_cells['down_enc'] = ['mconv_e6k5g0']
450 | arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
451 | arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
452 | arch_cells['normal_pre'] = ['mconv_e3k5g0']
453 | arch_cells['down_pre'] = ['mconv_e3k5g0']
454 | arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
455 | arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
456 | arch_cells['ar_nn'] = ['']
457 | elif arch_type == 'mbconv_den':
458 | arch_cells = dict()
459 | arch_cells['normal_enc'] = ['mconv_e6k5g0']
460 | arch_cells['down_enc'] = ['mconv_e6k5g0']
461 | arch_cells['normal_dec'] = ['mconv_e6k5g0']
462 | arch_cells['up_dec'] = ['mconv_e6k5g0']
463 | arch_cells['normal_pre'] = ['mconv_e3k5g8']
464 | arch_cells['down_pre'] = ['mconv_e3k5g8']
465 | arch_cells['normal_post'] = ['mconv_e3k5g8']
466 | arch_cells['up_post'] = ['mconv_e3k5g8']
467 | arch_cells['ar_nn'] = ['']
468 | else:
469 | raise NotImplementedError
470 |
471 | return arch_cells
472 |
473 |
474 | def groups_per_scale(num_scales, num_groups_per_scale, is_adaptive, divider=2, minimum_groups=1):
475 | g = []
476 | n = num_groups_per_scale
477 | for s in range(num_scales):
478 | assert n >= 1
479 | g.append(n)
480 | if is_adaptive:
481 | n = n // divider
482 | n = max(minimum_groups, n)
483 | return g
484 |
--------------------------------------------------------------------------------
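
As a final hedged illustration (not part of the repo), the snippet below traces how `utils.kl_coeff` above ramps the KL weight: it stays at `min_kl_coeff` while `step` is below `constant_step`, then grows linearly until it saturates at 1.0 around `total_step`. The step values and constants are made-up examples.

# A few made-up steps through utils.kl_coeff(step, total_step, constant_step, min_kl_coeff):
total_step, constant_step, min_kl_coeff = 10_000, 100, 1e-4
for step in [0, 1_000, 5_000, 10_000, 20_000]:
    coeff = max(min((step - constant_step) / total_step, 1.0), min_kl_coeff)
    print(step, round(coeff, 4))   # 0 -> 0.0001, 5000 -> 0.49, 20000 -> 1.0
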