├── .gitignore ├── LICENSE ├── README.md ├── datasets.py ├── distributions.py ├── evaluate.py ├── fid ├── LICENSE_pytorch_fid ├── __init__.py ├── fid_score.py └── inception.py ├── img ├── celebahq.png └── model_diagram.png ├── lmdb_datasets.py ├── model.py ├── neural_ar_operations.py ├── neural_operations.py ├── requirements.txt ├── scripts ├── Dockerfile ├── convert_tfrecord_to_lmdb.py ├── create_celeba64_lmdb.py ├── create_ffhq_lmdb.py └── precompute_fid_statistics.py ├── thirdparty ├── LICENSE_PyTorch ├── LICENSE_apache ├── LICENSE_torchvision ├── __init__.py ├── adamax.py ├── functions.py ├── inplaced_sync_batchnorm.py ├── lsun.py └── swish.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | visualizations/ 2 | .idea/ 3 | .run/ 4 | *.pyc 5 | *.pdf 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for NVAE 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are “made available” under this License by including in or with the Work either 17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 29 | copyright, patent, trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 35 | requirements in Section 3.1) will continue to apply to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 39 | works commercially. 
As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 46 | or trademarks, except as necessary to reproduce the notices described in this License. 47 | 48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 49 | grant in Section 2.1) will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 54 | WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 65 | 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" [(NeurIPS 2020 Spotlight Paper)](https://arxiv.org/abs/2007.03898) 2 | 3 |
 4 | Arash Vahdat · 5 | Jan Kautz 6 |
7 |
8 |
9 | 10 | [NVAE](https://arxiv.org/abs/2007.03898) is a deep hierarchical variational autoencoder that enables training SOTA 11 | likelihood-based generative models on several image datasets. 12 | 13 |

14 | 15 |

16 | 17 | ## Requirements 18 | NVAE is built in Python 3.7 using PyTorch 1.6.0. Use the following command to install the requirements: 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## Set up file paths and data 24 | We have examined NVAE on several datasets. For large datasets, we store the data in LMDB datasets 25 | for I/O efficiency. Click below on each dataset to see how you can prepare your data. Below, `$DATA_DIR` indicates 26 | the path to a data directory that will contain all the datasets and `$CODE_DIR` refers to the code directory: 27 | 28 |
 MNIST and CIFAR-10 29 | 30 | These datasets will be downloaded automatically when you run the main NVAE training using `train.py` 31 | for the first time. You can use `--data=$DATA_DIR/mnist` or `--data=$DATA_DIR/cifar10`, so that the datasets 32 | are downloaded to the corresponding directories. 33 |
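For reference, this automatic download is handled by torchvision inside `datasets.py`. A minimal sketch of what happens under the hood; the root path below is a placeholder for the value you pass via `--data`:

```python
# Minimal sketch of the automatic download performed in datasets.py via torchvision.
# The root path is a placeholder for the --data argument (e.g., $DATA_DIR/cifar10).
import torchvision.datasets as dset
import torchvision.transforms as transforms

data_dir = '/path/to/data/cifar10'
train_data = dset.CIFAR10(root=data_dir, train=True, download=True,
                          transform=transforms.ToTensor())
valid_data = dset.CIFAR10(root=data_dir, train=False, download=True,
                          transform=transforms.ToTensor())
print(len(train_data), len(valid_data))  # 50000 10000
```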
34 | 35 |
 CelebA 64 36 | Run the following commands to download the CelebA images and store them in an LMDB dataset: 37 | 38 | ```shell script 39 | cd $CODE_DIR/scripts 40 | python create_celeba64_lmdb.py --split train --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb 41 | python create_celeba64_lmdb.py --split valid --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb 42 | python create_celeba64_lmdb.py --split test --img_path $DATA_DIR/celeba_org --lmdb_path $DATA_DIR/celeba64_lmdb 43 | ``` 44 | Above, the images will be downloaded to `$DATA_DIR/celeba_org` automatically and then the LMDB datasets are created 45 | at `$DATA_DIR/celeba64_lmdb`. 46 |
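Once the LMDBs are built, `datasets.py` reads them back through the `LMDBDataset` class with `is_encoded=True`, since the CelebA 64 images are stored as encoded image bytes. A minimal loading sketch, assuming it is run from `$CODE_DIR` so the repo modules are importable, with a placeholder LMDB path:

```python
# Minimal sketch of reading the CelebA 64 LMDB back, mirroring datasets.py.
# Run from $CODE_DIR; the LMDB path is a placeholder for $DATA_DIR/celeba64_lmdb.
import torchvision.transforms as transforms
from datasets import CropCelebA64
from lmdb_datasets import LMDBDataset

transform = transforms.Compose([CropCelebA64(), transforms.Resize(64), transforms.ToTensor()])
train_data = LMDBDataset(root='/path/to/data/celeba64_lmdb', name='celeba64',
                         train=True, transform=transform, is_encoded=True)
img, _ = train_data[0]  # a 3x64x64 tensor in [0, 1]
```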
47 | 48 |
ImageNet 32x32 49 | 50 | Run the following commands to download tfrecord files from [GLOW](https://github.com/openai/glow) and to convert them 51 | to LMDB datasets 52 | ```shell script 53 | mkdir -p $DATA_DIR/imagenet-oord 54 | cd $DATA_DIR/imagenet-oord 55 | wget https://storage.googleapis.com/glow-demo/data/imagenet-oord-tfr.tar 56 | tar -xvf imagenet-oord-tfr.tar 57 | cd $CODE_DIR/scripts 58 | python convert_tfrecord_to_lmdb.py --dataset=imagenet-oord_32 --tfr_path=$DATA_DIR/imagenet-oord/mnt/host/imagenet-oord-tfr --lmdb_path=$DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --split=train 59 | python convert_tfrecord_to_lmdb.py --dataset=imagenet-oord_32 --tfr_path=$DATA_DIR/imagenet-oord/mnt/host/imagenet-oord-tfr --lmdb_path=$DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --split=validation 60 | ``` 61 |
62 | 63 |
CelebA HQ 256 64 | 65 | Run the following commands to download tfrecord files from [GLOW](https://github.com/openai/glow) and to convert them 66 | to LMDB datasets 67 | ```shell script 68 | mkdir -p $DATA_DIR/celeba 69 | cd $DATA_DIR/celeba 70 | wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar 71 | tar -xvf celeba-tfr.tar 72 | cd $CODE_DIR/scripts 73 | python convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=$DATA_DIR/celeba/celeba-tfr --lmdb_path=$DATA_DIR/celeba/celeba-lmdb --split=train 74 | python convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=$DATA_DIR/celeba/celeba-tfr --lmdb_path=$DATA_DIR/celeba/celeba-lmdb --split=validation 75 | ``` 76 |
77 | 78 | 79 |
FFHQ 256 80 | 81 | Visit [this Google drive location](https://drive.google.com/drive/folders/1WocxvZ4GEZ1DI8dOz30aSj2zT6pkATYS) and download 82 | `images1024x1024.zip`. Run the following commands to unzip the images and to store them in LMDB datasets: 83 | ```shell script 84 | mkdir -p $DATA_DIR/ffhq 85 | unzip images1024x1024.zip -d $DATA_DIR/ffhq/ 86 | cd $CODE_DIR/scripts 87 | python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffhq_lmdb_path=$DATA_DIR/ffhq/ffhq-lmdb --split=train 88 | python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffhq_lmdb_path=$DATA_DIR/ffhq/ffhq-lmdb --split=validation 89 | ``` 90 |
91 | 92 |
LSUN 93 | 94 | We use LSUN datasets in our follow-up works. Visit [LSUN](https://www.yf.io/p/lsun) for 95 | instructions on how to download this dataset. Since the LSUN scene datasets come in the 96 | LMDB format, they are ready to be loaded using torchvision data loaders. 97 | 98 |
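For example, `datasets.py` loads the LSUN church scenes through the bundled `thirdparty/lsun.py` wrapper. A minimal sketch with a placeholder data root:

```python
# Minimal sketch of loading an LSUN scene category, mirroring datasets.py.
# The root is a placeholder; point it at the directory containing the LSUN LMDB folders.
import torchvision.transforms as transforms
from thirdparty.lsun import LSUN

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])
valid_data = LSUN(root='/path/to/data/lsun', classes=['church_outdoor_val'],
                  transform=transform)
img, _ = valid_data[0]  # a 3x256x256 tensor in [0, 1]
```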
 99 | 100 | 101 | ## Running the main NVAE training and evaluation scripts 102 | We use the following commands to train the NVAE models reported in 103 | Table 1 of the [paper](https://arxiv.org/pdf/2007.03898.pdf). Normalizing flows are enabled on all datasets except MNIST. 104 | Check Table 6 in the paper for more information on the training 105 | details. Note that for multinode training (experiments with more than 8 GPUs), we use the `mpirun` 106 | command to run the training scripts on multiple nodes. Please adjust the commands below according to your setup. 107 | Below, `IP_ADDR` is the IP address of the machine that will host the process with rank 0 108 | (see [here](https://pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods)), and 109 | `NODE_RANK` is the index of each node among all the nodes that are running the job. 110 | 111 |
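For reference, each training process is wrapped in `init_processes` (see `train.py`), which reduces to standard PyTorch distributed initialization. A rough sketch of the idea; the port and backend below are illustrative assumptions, not necessarily the exact values used in the code:

```python
# Rough sketch of multinode initialization in the spirit of init_processes in train.py.
# MASTER_PORT and the backend are illustrative assumptions.
import os
import torch
import torch.distributed as dist

def init_processes(rank, size, fn, args):
    os.environ['MASTER_ADDR'] = args.master_address  # IP_ADDR of the machine hosting rank 0
    os.environ['MASTER_PORT'] = '6020'                # assumed port
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=rank, world_size=size)
    torch.cuda.set_device(args.local_rank)
    fn(args)
    dist.destroy_process_group()
```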
MNIST 112 | 113 | Two 16-GB V100 GPUs are used for training NVAE on dynamically binarized MNIST. Training takes about 21 hours. 114 | 115 | ```shell script 116 | export EXPR_ID=UNIQUE_EXPR_ID 117 | export DATA_DIR=PATH_TO_DATA_DIR 118 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 119 | export CODE_DIR=PATH_TO_CODE_DIR 120 | cd $CODE_DIR 121 | python train.py --data $DATA_DIR/mnist --root $CHECKPOINT_DIR --save $EXPR_ID --dataset mnist --batch_size 200 \ 122 | --epochs 400 --num_latent_scales 2 --num_groups_per_scale 10 --num_postprocess_cells 3 --num_preprocess_cells 3 \ 123 | --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_latent_per_group 20 --num_preprocess_blocks 2 \ 124 | --num_postprocess_blocks 2 --weight_decay_norm 1e-2 --num_channels_enc 32 --num_channels_dec 32 --num_nf 0 \ 125 | --ada_groups --num_process_per_node 2 --use_se --res_dist --fast_adamax 126 | ``` 127 |
128 | 129 |
CIFAR-10 130 | 131 | Eight 16-GB V100 GPUs are used for training NVAE on CIFAR-10. Training takes about 55 hours. 132 | 133 | ```shell script 134 | export EXPR_ID=UNIQUE_EXPR_ID 135 | export DATA_DIR=PATH_TO_DATA_DIR 136 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 137 | export CODE_DIR=PATH_TO_CODE_DIR 138 | cd $CODE_DIR 139 | python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset cifar10 \ 140 | --num_channels_enc 128 --num_channels_dec 128 --epochs 400 --num_postprocess_cells 2 --num_preprocess_cells 2 \ 141 | --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \ 142 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 30 --batch_size 32 \ 143 | --weight_decay_norm 1e-2 --num_nf 1 --num_process_per_node 8 --use_se --res_dist --fast_adamax 144 | ``` 145 |
146 | 147 |
CelebA 64 148 | 149 | Eight 16-GB V100 GPUs are used for training NVAE on CelebA 64. Training takes about 92 hours. 150 | 151 | ```shell script 152 | export EXPR_ID=UNIQUE_EXPR_ID 153 | export DATA_DIR=PATH_TO_DATA_DIR 154 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 155 | export CODE_DIR=PATH_TO_CODE_DIR 156 | cd $CODE_DIR 157 | python train.py --data $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_64 \ 158 | --num_channels_enc 64 --num_channels_dec 64 --epochs 90 --num_postprocess_cells 2 --num_preprocess_cells 2 \ 159 | --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \ 160 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 20 \ 161 | --batch_size 16 --num_nf 1 --ada_groups --num_process_per_node 8 --use_se --res_dist --fast_adamax 162 | ``` 163 |
164 | 165 |
ImageNet 32x32 166 | 167 | 24 16-GB V100 GPUs are used for training NVAE on ImageNet 32x32. Training takes about 70 hours. 168 | 169 | ```shell script 170 | export EXPR_ID=UNIQUE_EXPR_ID 171 | export DATA_DIR=PATH_TO_DATA_DIR 172 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 173 | export CODE_DIR=PATH_TO_CODE_DIR 174 | export IP_ADDR=IP_ADDRESS 175 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2 176 | cd $CODE_DIR 177 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \ 178 | 'python train.py --data $DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset imagenet_32 \ 179 | --num_channels_enc 192 --num_channels_dec 192 --epochs 45 --num_postprocess_cells 2 --num_preprocess_cells 2 \ 180 | --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \ 181 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 28 \ 182 | --batch_size 24 --num_nf 1 --warmup_epochs 1 \ 183 | --weight_decay_norm 1e-2 --weight_decay_norm_anneal --weight_decay_norm_init 1e0 \ 184 | --num_process_per_node 8 --use_se --res_dist \ 185 | --fast_adamax --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR ' 186 | ``` 187 |
188 | 189 |
CelebA HQ 256 190 | 191 | 24 32-GB V100 GPUs are used for training NVAE on CelebA HQ 256. Training takes about 94 hours. 192 | 193 | ```shell script 194 | export EXPR_ID=UNIQUE_EXPR_ID 195 | export DATA_DIR=PATH_TO_DATA_DIR 196 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 197 | export CODE_DIR=PATH_TO_CODE_DIR 198 | export IP_ADDR=IP_ADDRESS 199 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2 200 | cd $CODE_DIR 201 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \ 202 | 'python train.py --data $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \ 203 | --num_channels_enc 30 --num_channels_dec 30 --epochs 300 --num_postprocess_cells 2 --num_preprocess_cells 2 \ 204 | --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \ 205 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_groups_per_scale 16 \ 206 | --batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \ 207 | --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \ 208 | --fast_adamax --num_x_bits 5 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR ' 209 | ``` 210 | 211 | In our early experiments, a smaller model with 24 channels instead of 30, could be trained on only 8 GPUs in 212 | the same time (with the batch size of 6). The smaller models obtain only 0.01 bpd higher 213 | negative log-likelihood. 214 |
215 | 216 |
FFHQ 256 217 | 218 | 24 32-GB V100 GPUs are used for training NVAE on FFHQ 256. Training takes about 160 hours. 219 | 220 | ```shell script 221 | export EXPR_ID=UNIQUE_EXPR_ID 222 | export DATA_DIR=PATH_TO_DATA_DIR 223 | export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR 224 | export CODE_DIR=PATH_TO_CODE_DIR 225 | export IP_ADDR=IP_ADDRESS 226 | export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2 227 | cd $CODE_DIR 228 | mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \ 229 | 'python train.py --data $DATA_DIR/ffhq/ffhq-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset ffhq \ 230 | --num_channels_enc 30 --num_channels_dec 30 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \ 231 | --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \ 232 | --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 16 \ 233 | --batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \ 234 | --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \ 235 | --fast_adamax --num_x_bits 5 --learning_rate 8e-3 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR ' 236 | ``` 237 | 238 | In our early experiments, a smaller model with 24 channels instead of 30, could be trained on only 8 GPUs in 239 | the same time (with the batch size of 6). The smaller models obtain only 0.01 bpd higher 240 | negative log-likelihood. 241 |
 242 | 243 | **If for any reason your training is stopped, use the exact same command with the addition of `--cont_training` 244 | to continue training from the last saved checkpoint. If you observe NaN, continuing the training using this flag 245 | will usually not fix the NaN issue.** 246 | 247 | ## Known Issues 248 |
 Cannot build CelebA 64 or training gives NaN right at the beginning on this dataset 249 | 250 | Several users have reported issues building CelebA 64 or have encountered NaN at the beginning of training on this dataset. 251 | If you face similar issues, you can download the dataset manually and build the LMDBs using the instructions 252 | in this issue: https://github.com/NVlabs/NVAE/issues/2. 253 |
254 | 255 |
 Getting NaN after a few epochs of training 256 | 257 | One of the main challenges in training very deep hierarchical VAEs is the training instability that we discussed in the paper. 258 | We have verified that the settings in the commands above can be trained in a stable way. If you modify the settings 259 | above and encounter NaN after a few epochs of training, you can use these tricks to stabilize your training: 260 | i) Increase the spectral regularization coefficient, `--weight_decay_norm`. ii) Use exponential decay on 261 | `--weight_decay_norm` using `--weight_decay_norm_anneal` and `--weight_decay_norm_init`. iii) Decrease the learning rate. 262 |
263 | 264 |
 Training freezes with no NaN 265 | 266 | In some very rare cases, we observed that training freezes after 2-3 days of training. We believe the root cause 267 | is a race condition in one of the low-level libraries. If for any reason the training 268 | is stopped, kill your current run, and use the exact same command with the addition of `--cont_training` 269 | to continue training from the last saved checkpoint. 270 |
271 | 272 | ## Monitoring the training progress 273 | While running any of the commands above, you can monitor the training progress using Tensorboard: 274 | 275 |
Click here 276 | 277 | ```shell script 278 | tensorboard --logdir $CHECKPOINT_DIR/eval-$EXPR_ID/ 279 | ``` 280 | Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script. 281 | 282 |
283 | 284 | ## Post-training sampling, evaluation, and checkpoints 285 | 286 |
Evaluating Log-Likelihood 287 | 288 | You can use the following command to load a trained model and evaluate it on the test datasets: 289 | 290 | ```shell script 291 | cd $CODE_DIR 292 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --data $DATA_DIR/mnist --eval_mode=evaluate --num_iw_samples=1000 293 | ``` 294 | Above, `--num_iw_samples` indicates the number of importance weighted samples used in evaluation. 295 | `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script. 296 | Set `--data` to the same argument that was used when training NVAE (our example is for MNIST). 297 | 298 |
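As a reminder of what `--num_iw_samples` controls: with K importance-weighted samples, the reported log-likelihood is based on the standard IWAE estimator, which lower-bounds log p(x) and tightens as K grows:

$$\log p(x) \;\ge\; \mathbb{E}_{z_1,\dots,z_K \sim q(z\mid x)}\Big[\log \tfrac{1}{K}\textstyle\sum_{k=1}^{K} \tfrac{p(x, z_k)}{q(z_k\mid x)}\Big].$$

The bits-per-dim values printed by `evaluate.py` are this quantity in nats divided by (ln 2 × the number of data dimensions).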
299 | 300 |
 Sampling 301 | 302 | You can also use the following command to generate samples from a trained model: 303 | 304 | ```shell script 305 | cd $CODE_DIR 306 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --eval_mode=sample --temp=0.6 --readjust_bn 307 | ``` 308 | where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics 309 | as described in the paper. If you remove `--readjust_bn`, sampling will proceed with the BN layers in eval mode 310 | (i.e., BN layers will use the running means and variances computed during training). 311 | 312 |
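If you prefer to sample programmatically, the relevant calls in `evaluate.py` look roughly like the sketch below; it assumes `model` is an `AutoEncoder` that has already been restored from a checkpoint and moved to the GPU:

```python
# Rough sampling sketch following the sampling loop in evaluate.py.
# Assumes `model` is a loaded AutoEncoder on the GPU (see the Checkpoints section below).
import torch
from torch.cuda.amp import autocast

with torch.no_grad(), autocast():
    logits = model.sample(16, 0.6)                    # 16 samples at temperature 0.6
    output = model.decoder_output(logits)             # distribution over pixels
    images = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
        else output.sample()                          # B x C x H x W tensor in [0, 1]
```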
313 | 314 |
 Computing FID 315 | 316 | You can compute the FID score using 50K samples. To do so, you will need to create 317 | a mean and covariance statistics file on the training data using a command like: 318 | 319 | ```shell script 320 | cd $CODE_DIR 321 | python scripts/precompute_fid_statistics.py --data $DATA_DIR/cifar10 --dataset cifar10 --fid_dir /tmp/fid-stats/ 322 | ``` 323 | The command above computes the reference statistics on the CIFAR-10 dataset and stores them in the `--fid_dir` directory. 324 | Given the reference statistics file, we can run the following command to compute the FID score: 325 | 326 | ```shell script 327 | cd $CODE_DIR 328 | python evaluate.py --checkpoint $CHECKPOINT_DIR/eval-$EXPR_ID/checkpoint.pt --data $DATA_DIR/cifar10 --eval_mode=evaluate_fid --fid_dir /tmp/fid-stats/ --temp=0.6 --readjust_bn 329 | ``` 330 | where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics 331 | as described in the paper. If you remove `--readjust_bn`, sampling will proceed with the BN layers in eval mode 332 | (i.e., BN layers will use the running means and variances computed during training). 333 | Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the main training script. 334 | Set `--data` to the same argument that was used when training NVAE (our example is for CIFAR-10). 335 | 336 |
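For reference, the precomputed statistics are the mean and covariance of Inception features on the reference data, and the score is the Fréchet distance between that Gaussian and the one fitted to the generated samples:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}\big),$$

where (μ_r, Σ_r) come from the reference data and (μ_g, Σ_g) from the 50K generated samples.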
337 | 338 |
 Checkpoints 339 | 340 | We provide checkpoints for MNIST, CIFAR-10, CelebA 64, CelebA HQ 256, and FFHQ in 341 | [this Google drive directory](https://drive.google.com/drive/folders/1KVpw12AzdVjvbfEYM_6_3sxTy93wWkbe?usp=sharing). 342 | For CIFAR-10, we provide two checkpoints, as we observed that a multiscale NVAE provides better qualitative 343 | results than a single-scale model on this dataset. The multiscale model is only slightly worse in terms 344 | of log-likelihood (0.01 bpd). We also observed that one of our early models on CelebA HQ 256, with 0.01 bpd 345 | worse likelihood, generates much better images at low temperature on this dataset. 346 | 347 | You can use the commands above to evaluate or sample from these checkpoints. 348 | 349 |
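If you want to load one of these checkpoints outside of `evaluate.py`, the steps mirror the beginning of its `main()` function. A minimal sketch with a placeholder checkpoint path:

```python
# Minimal checkpoint-loading sketch mirroring evaluate.py; the path is a placeholder.
import torch
import utils
from model import AutoEncoder

checkpoint = torch.load('/path/to/checkpoint.pt', map_location='cpu')
args = checkpoint['args']                         # training args are stored in the checkpoint
arch_instance = utils.get_arch_cells(args.arch_instance)
model = AutoEncoder(args, None, arch_instance)
# strict=False: some older checkpoints miss buffers used only for spectral regularization
model.load_state_dict(checkpoint['state_dict'], strict=False)
model = model.cuda().eval()
```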
 350 | 351 | ## How to construct smaller NVAE models 352 | In the commands above, we are constructing big NVAE models that require several days of training 353 | in most cases. If you'd like to construct smaller NVAEs, you can use these tricks: 354 | 355 | * Reduce the network width: `--num_channels_enc` and `--num_channels_dec` control the number 356 | of initial channels in the bottom-up and top-down networks respectively. Recall that we double the 357 | number of channels with every spatial downsampling layer in the bottom-up network, and we halve the number of 358 | channels with every upsampling layer in the top-down network. By reducing 359 | `--num_channels_enc` and `--num_channels_dec`, you can reduce the overall width of the networks. 360 | 361 | * Reduce the number of residual cells in the hierarchy: `--num_cell_per_cond_enc` and 362 | `--num_cell_per_cond_dec` control the number of residual cells used between every latent variable 363 | group in the bottom-up and top-down networks respectively. In most of our experiments, we use 364 | two cells per group for both networks. You can reduce the number of residual cells to one to make the model 365 | smaller. 366 | 367 | * Reduce the number of epochs: You can reduce the training time by reducing `--epochs`. 368 | 369 | * Reduce the number of groups: You can make NVAE smaller by using a smaller number of latent variable groups. 370 | We use two schemes for setting the number of groups: 371 | 1. An equal number of groups: This is set by `--num_groups_per_scale`, which indicates the number of groups 372 | in each scale of latent variables. Reduce this number to get a smaller NVAE. 373 | 374 | 2. An adaptive number of groups: This is enabled by `--ada_groups`. In this case, the highest 375 | resolution of latent variables will have `--num_groups_per_scale` groups and 376 | the smaller scales will get half the number of groups successively (see groups_per_scale in utils.py). 377 | We don't let the number of groups go below `--min_groups_per_scale`. You can reduce 378 | the total number of groups by reducing `--num_groups_per_scale` and `--min_groups_per_scale` 379 | when `--ada_groups` is enabled. 380 | 381 | ## Understanding the implementation 382 | If you are modifying the code, you can use the following figure to map the code to the paper. 383 |

385 | 386 |

387 | 388 | 389 | ## Traversing the latent space 390 | We can generate images by traversing in the latent space of NVAE. This sequence is generated using our model 391 | trained on CelebA HQ, by interpolating between samples generated with temperature 0.6. 392 | Some artifacts are due to color quantization in GIFs. 393 | 394 |

395 | 396 |

397 | 398 | ## License 399 | Please check the LICENSE file. NVAE may be used non-commercially, meaning for research or 400 | evaluation purposes only. For business inquiries, please contact 401 | [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com). 402 | 403 | You should take into consideration that VAEs are trained to mimic the training data distribution, and, any 404 | bias introduced in data collection will make VAEs generate samples with a similar bias. Additional bias could be 405 | introduced during model design, training, or when VAEs are sampled using small temperatures. Bias correction in 406 | generative learning is an active area of research, and we recommend interested readers to check this area before 407 | building applications using NVAE. 408 | 409 | ## Bibtex: 410 | Please cite our paper, if you happen to use this codebase: 411 | 412 | ``` 413 | @inproceedings{vahdat2020NVAE, 414 | title={{NVAE}: A Deep Hierarchical Variational Autoencoder}, 415 | author={Vahdat, Arash and Kautz, Jan}, 416 | booktitle={Neural Information Processing Systems (NeurIPS)}, 417 | year={2020} 418 | } 419 | ``` 420 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | """Code for getting the data loaders.""" 9 | 10 | import numpy as np 11 | from PIL import Image 12 | import torch 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | from torch.utils.data import Dataset 16 | from scipy.io import loadmat 17 | import os 18 | import urllib 19 | from lmdb_datasets import LMDBDataset 20 | from thirdparty.lsun import LSUN 21 | 22 | 23 | class StackedMNIST(dset.MNIST): 24 | def __init__(self, root, train=True, transform=None, target_transform=None, 25 | download=False): 26 | super(StackedMNIST, self).__init__(root=root, train=train, transform=transform, 27 | target_transform=target_transform, download=download) 28 | 29 | index1 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))]) 30 | index2 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))]) 31 | index3 = np.hstack([np.random.permutation(len(self.data)), np.random.permutation(len(self.data))]) 32 | self.num_images = 2 * len(self.data) 33 | 34 | self.index = [] 35 | for i in range(self.num_images): 36 | self.index.append((index1[i], index2[i], index3[i])) 37 | 38 | def __len__(self): 39 | return self.num_images 40 | 41 | def __getitem__(self, index): 42 | img = np.zeros((28, 28, 3), dtype=np.uint8) 43 | target = 0 44 | for i in range(3): 45 | img_, target_ = self.data[self.index[index][i]], int(self.targets[self.index[index][i]]) 46 | img[:, :, i] = img_ 47 | target += target_ * 10 ** (2 - i) 48 | 49 | img = Image.fromarray(img, mode="RGB") 50 | 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | 54 | if self.target_transform is not None: 55 | target = self.target_transform(target) 56 | 57 | return img, target 58 | 59 | 60 | 61 | class Binarize(object): 62 | """ This class introduces a binarization transformation 63 | """ 64 | def 
__call__(self, pic): 65 | return torch.Tensor(pic.size()).bernoulli_(pic) 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + '()' 69 | 70 | 71 | class CropCelebA64(object): 72 | """ This class applies cropping for CelebA64. This is a simplified implementation of: 73 | https://github.com/andersbll/autoencoding_beyond_pixels/blob/master/dataset/celeba.py 74 | """ 75 | def __call__(self, pic): 76 | new_pic = pic.crop((15, 40, 178 - 15, 218 - 30)) 77 | return new_pic 78 | 79 | def __repr__(self): 80 | return self.__class__.__name__ + '()' 81 | 82 | 83 | def get_loaders(args): 84 | """Get data loaders for required dataset.""" 85 | return get_loaders_eval(args.dataset, args) 86 | 87 | def download_omniglot(data_dir): 88 | filename = 'chardata.mat' 89 | if not os.path.exists(data_dir): 90 | os.mkdir(data_dir) 91 | url = 'https://raw.github.com/yburda/iwae/master/datasets/OMNIGLOT/chardata.mat' 92 | 93 | filepath = os.path.join(data_dir, filename) 94 | if not os.path.exists(filepath): 95 | filepath, _ = urllib.request.urlretrieve(url, filepath) 96 | print('Downloaded', filename) 97 | 98 | return 99 | 100 | 101 | def load_omniglot(data_dir): 102 | download_omniglot(data_dir) 103 | 104 | data_path = os.path.join(data_dir, 'chardata.mat') 105 | 106 | omni = loadmat(data_path) 107 | train_data = 255 * omni['data'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0)) 108 | test_data = 255 * omni['testdata'].astype('float32').reshape((28, 28, -1)).transpose((2, 1, 0)) 109 | 110 | train_data = train_data.astype('uint8') 111 | test_data = test_data.astype('uint8') 112 | 113 | return train_data, test_data 114 | 115 | 116 | class OMNIGLOT(Dataset): 117 | def __init__(self, data, transform): 118 | self.data = data 119 | self.transform = transform 120 | 121 | def __getitem__(self, index): 122 | d = self.data[index] 123 | img = Image.fromarray(d) 124 | return self.transform(img), 0 # return zero as label. 
125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def get_loaders_eval(dataset, args): 130 | """Get train and valid loaders for cifar10/tiny imagenet.""" 131 | 132 | if dataset == 'cifar10': 133 | num_classes = 10 134 | train_transform, valid_transform = _data_transforms_cifar10(args) 135 | train_data = dset.CIFAR10( 136 | root=args.data, train=True, download=True, transform=train_transform) 137 | valid_data = dset.CIFAR10( 138 | root=args.data, train=False, download=True, transform=valid_transform) 139 | elif dataset == 'mnist': 140 | num_classes = 10 141 | train_transform, valid_transform = _data_transforms_mnist(args) 142 | train_data = dset.MNIST( 143 | root=args.data, train=True, download=True, transform=train_transform) 144 | valid_data = dset.MNIST( 145 | root=args.data, train=False, download=True, transform=valid_transform) 146 | elif dataset == 'stacked_mnist': 147 | num_classes = 1000 148 | train_transform, valid_transform = _data_transforms_stacked_mnist(args) 149 | train_data = StackedMNIST( 150 | root=args.data, train=True, download=True, transform=train_transform) 151 | valid_data = StackedMNIST( 152 | root=args.data, train=False, download=True, transform=valid_transform) 153 | elif dataset == 'omniglot': 154 | num_classes = 0 155 | download_omniglot(args.data) 156 | train_transform, valid_transform = _data_transforms_mnist(args) 157 | train_data, valid_data = load_omniglot(args.data) 158 | train_data = OMNIGLOT(train_data, train_transform) 159 | valid_data = OMNIGLOT(valid_data, valid_transform) 160 | elif dataset.startswith('celeba'): 161 | if dataset == 'celeba_64': 162 | resize = 64 163 | num_classes = 40 164 | train_transform, valid_transform = _data_transforms_celeba64(resize) 165 | train_data = LMDBDataset(root=args.data, name='celeba64', train=True, transform=train_transform, is_encoded=True) 166 | valid_data = LMDBDataset(root=args.data, name='celeba64', train=False, transform=valid_transform, is_encoded=True) 167 | elif dataset in {'celeba_256'}: 168 | num_classes = 1 169 | resize = int(dataset.split('_')[1]) 170 | train_transform, valid_transform = _data_transforms_generic(resize) 171 | train_data = LMDBDataset(root=args.data, name='celeba', train=True, transform=train_transform) 172 | valid_data = LMDBDataset(root=args.data, name='celeba', train=False, transform=valid_transform) 173 | else: 174 | raise NotImplementedError 175 | elif dataset.startswith('lsun'): 176 | if dataset.startswith('lsun_bedroom'): 177 | resize = int(dataset.split('_')[-1]) 178 | num_classes = 1 179 | train_transform, valid_transform = _data_transforms_lsun(resize) 180 | train_data = LSUN(root=args.data, classes=['bedroom_train'], transform=train_transform) 181 | valid_data = LSUN(root=args.data, classes=['bedroom_val'], transform=valid_transform) 182 | elif dataset.startswith('lsun_church'): 183 | resize = int(dataset.split('_')[-1]) 184 | num_classes = 1 185 | train_transform, valid_transform = _data_transforms_lsun(resize) 186 | train_data = LSUN(root=args.data, classes=['church_outdoor_train'], transform=train_transform) 187 | valid_data = LSUN(root=args.data, classes=['church_outdoor_val'], transform=valid_transform) 188 | elif dataset.startswith('lsun_tower'): 189 | resize = int(dataset.split('_')[-1]) 190 | num_classes = 1 191 | train_transform, valid_transform = _data_transforms_lsun(resize) 192 | train_data = LSUN(root=args.data, classes=['tower_train'], transform=train_transform) 193 | valid_data = LSUN(root=args.data, classes=['tower_val'], 
transform=valid_transform) 194 | else: 195 | raise NotImplementedError 196 | elif dataset.startswith('imagenet'): 197 | num_classes = 1 198 | resize = int(dataset.split('_')[1]) 199 | assert args.data.replace('/', '')[-3:] == dataset.replace('/', '')[-3:], 'the size should match' 200 | train_transform, valid_transform = _data_transforms_generic(resize) 201 | train_data = LMDBDataset(root=args.data, name='imagenet-oord', train=True, transform=train_transform) 202 | valid_data = LMDBDataset(root=args.data, name='imagenet-oord', train=False, transform=valid_transform) 203 | elif dataset.startswith('ffhq'): 204 | num_classes = 1 205 | resize = 256 206 | train_transform, valid_transform = _data_transforms_generic(resize) 207 | train_data = LMDBDataset(root=args.data, name='ffhq', train=True, transform=train_transform) 208 | valid_data = LMDBDataset(root=args.data, name='ffhq', train=False, transform=valid_transform) 209 | else: 210 | raise NotImplementedError 211 | 212 | train_sampler, valid_sampler = None, None 213 | if args.distributed: 214 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) 215 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data) 216 | 217 | train_queue = torch.utils.data.DataLoader( 218 | train_data, batch_size=args.batch_size, 219 | shuffle=(train_sampler is None), 220 | sampler=train_sampler, pin_memory=True, num_workers=8, drop_last=True) 221 | 222 | valid_queue = torch.utils.data.DataLoader( 223 | valid_data, batch_size=args.batch_size, 224 | shuffle=(valid_sampler is None), 225 | sampler=valid_sampler, pin_memory=True, num_workers=1, drop_last=False) 226 | 227 | return train_queue, valid_queue, num_classes 228 | 229 | 230 | def _data_transforms_cifar10(args): 231 | """Get data transforms for cifar10.""" 232 | 233 | train_transform = transforms.Compose([ 234 | transforms.RandomHorizontalFlip(), 235 | transforms.ToTensor() 236 | ]) 237 | 238 | valid_transform = transforms.Compose([ 239 | transforms.ToTensor() 240 | ]) 241 | 242 | return train_transform, valid_transform 243 | 244 | 245 | def _data_transforms_mnist(args): 246 | """Get data transforms for cifar10.""" 247 | train_transform = transforms.Compose([ 248 | transforms.Pad(padding=2), 249 | transforms.ToTensor(), 250 | Binarize(), 251 | ]) 252 | 253 | valid_transform = transforms.Compose([ 254 | transforms.Pad(padding=2), 255 | transforms.ToTensor(), 256 | Binarize(), 257 | ]) 258 | 259 | return train_transform, valid_transform 260 | 261 | 262 | def _data_transforms_stacked_mnist(args): 263 | """Get data transforms for cifar10.""" 264 | train_transform = transforms.Compose([ 265 | transforms.Pad(padding=2), 266 | transforms.ToTensor() 267 | ]) 268 | 269 | valid_transform = transforms.Compose([ 270 | transforms.Pad(padding=2), 271 | transforms.ToTensor() 272 | ]) 273 | 274 | return train_transform, valid_transform 275 | 276 | 277 | def _data_transforms_generic(size): 278 | train_transform = transforms.Compose([ 279 | transforms.Resize(size), 280 | transforms.RandomHorizontalFlip(), 281 | transforms.ToTensor(), 282 | ]) 283 | 284 | valid_transform = transforms.Compose([ 285 | transforms.Resize(size), 286 | transforms.ToTensor(), 287 | ]) 288 | 289 | return train_transform, valid_transform 290 | 291 | 292 | def _data_transforms_celeba64(size): 293 | train_transform = transforms.Compose([ 294 | CropCelebA64(), 295 | transforms.Resize(size), 296 | transforms.RandomHorizontalFlip(), 297 | transforms.ToTensor(), 298 | ]) 299 | 300 | valid_transform = transforms.Compose([ 
301 | CropCelebA64(), 302 | transforms.Resize(size), 303 | transforms.ToTensor(), 304 | ]) 305 | 306 | return train_transform, valid_transform 307 | 308 | 309 | def _data_transforms_lsun(size): 310 | train_transform = transforms.Compose([ 311 | transforms.Resize(size), 312 | transforms.RandomCrop(size), 313 | transforms.RandomHorizontalFlip(), 314 | transforms.ToTensor(), 315 | ]) 316 | 317 | valid_transform = transforms.Compose([ 318 | transforms.Resize(size), 319 | transforms.CenterCrop(size), 320 | transforms.ToTensor(), 321 | ]) 322 | 323 | return train_transform, valid_transform 324 | -------------------------------------------------------------------------------- /distributions.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | from utils import one_hot 14 | 15 | @torch.jit.script 16 | def soft_clamp5(x: torch.Tensor): 17 | return x.div(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5] 18 | 19 | 20 | @torch.jit.script 21 | def sample_normal_jit(mu, sigma): 22 | eps = mu.mul(0).normal_() 23 | z = eps.mul_(sigma).add_(mu) 24 | return z, eps 25 | 26 | 27 | class Normal: 28 | def __init__(self, mu, log_sigma, temp=1.): 29 | self.mu = soft_clamp5(mu) 30 | log_sigma = soft_clamp5(log_sigma) 31 | self.sigma = torch.exp(log_sigma) + 1e-2 # we don't need this after soft clamp 32 | if temp != 1.: 33 | self.sigma *= temp 34 | 35 | def sample(self): 36 | return sample_normal_jit(self.mu, self.sigma) 37 | 38 | def sample_given_eps(self, eps): 39 | return eps * self.sigma + self.mu 40 | 41 | def log_p(self, samples): 42 | normalized_samples = (samples - self.mu) / self.sigma 43 | log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - torch.log(self.sigma) 44 | return log_p 45 | 46 | def kl(self, normal_dist): 47 | term1 = (self.mu - normal_dist.mu) / normal_dist.sigma 48 | term2 = self.sigma / normal_dist.sigma 49 | 50 | return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2) 51 | 52 | 53 | class NormalDecoder: 54 | def __init__(self, param, num_bits=8): 55 | B, C, H, W = param.size() 56 | self.num_c = C // 2 57 | mu = param[:, :self.num_c, :, :] # B, 3, H, W 58 | log_sigma = param[:, self.num_c:, :, :] # B, 3, H, W 59 | self.dist = Normal(mu, log_sigma) 60 | 61 | def log_prob(self, samples): 62 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0 63 | # convert samples to be in [-1, 1] 64 | samples = 2 * samples - 1.0 65 | 66 | return self.dist.log_p(samples) 67 | 68 | def sample(self, t=1.): 69 | x, _ = self.dist.sample() 70 | x = torch.clamp(x, -1, 1.) 71 | x = x / 2. 
+ 0.5 72 | return x 73 | 74 | 75 | class DiscLogistic: 76 | def __init__(self, param): 77 | B, C, H, W = param.size() 78 | self.num_c = C // 2 79 | self.means = param[:, :self.num_c, :, :] # B, 3, H, W 80 | self.log_scales = torch.clamp(param[:, self.num_c:, :, :], min=-8.0) # B, 3, H, W 81 | 82 | def log_prob(self, samples): 83 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0 84 | # convert samples to be in [-1, 1] 85 | samples = 2 * samples - 1.0 86 | 87 | B, C, H, W = samples.size() 88 | assert C == self.num_c 89 | 90 | centered = samples - self.means # B, 3, H, W 91 | inv_stdv = torch.exp(- self.log_scales) 92 | plus_in = inv_stdv * (centered + 1. / 255.) 93 | cdf_plus = torch.sigmoid(plus_in) 94 | min_in = inv_stdv * (centered - 1. / 255.) 95 | cdf_min = torch.sigmoid(min_in) 96 | log_cdf_plus = plus_in - F.softplus(plus_in) 97 | log_one_minus_cdf_min = - F.softplus(min_in) 98 | cdf_delta = cdf_plus - cdf_min 99 | mid_in = inv_stdv * centered 100 | log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in) 101 | 102 | log_prob_mid_safe = torch.where(cdf_delta > 1e-5, 103 | torch.log(torch.clamp(cdf_delta, min=1e-10)), 104 | log_pdf_mid - np.log(127.5)) 105 | # woow the original implementation uses samples > 0.999, this ignores the largest possible pixel value (255) 106 | # which is mapped to 0.9922 107 | log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.99, log_one_minus_cdf_min, 108 | log_prob_mid_safe)) # B, 3, H, W 109 | 110 | return log_probs 111 | 112 | def sample(self): 113 | u = torch.Tensor(self.means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W 114 | x = self.means + torch.exp(self.log_scales) * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W 115 | x = torch.clamp(x, -1, 1.) 116 | x = x / 2. + 0.5 117 | return x 118 | 119 | 120 | class DiscMixLogistic: 121 | def __init__(self, param, num_mix=10, num_bits=8): 122 | B, C, H, W = param.size() 123 | self.num_mix = num_mix 124 | self.logit_probs = param[:, :num_mix, :, :] # B, M, H, W 125 | l = param[:, num_mix:, :, :].view(B, 3, 3 * num_mix, H, W) # B, 3, 3 * M, H, W 126 | self.means = l[:, :, :num_mix, :, :] # B, 3, M, H, W 127 | self.log_scales = torch.clamp(l[:, :, num_mix:2 * num_mix, :, :], min=-7.0) # B, 3, M, H, W 128 | self.coeffs = torch.tanh(l[:, :, 2 * num_mix:3 * num_mix, :, :]) # B, 3, M, H, W 129 | self.max_val = 2. ** num_bits - 1 130 | 131 | def log_prob(self, samples): 132 | assert torch.max(samples) <= 1.0 and torch.min(samples) >= 0.0 133 | # convert samples to be in [-1, 1] 134 | samples = 2 * samples - 1.0 135 | 136 | B, C, H, W = samples.size() 137 | assert C == 3, 'only RGB images are considered.' 
138 | 139 | samples = samples.unsqueeze(4) # B, 3, H , W 140 | samples = samples.expand(-1, -1, -1, -1, self.num_mix).permute(0, 1, 4, 2, 3) # B, 3, M, H, W 141 | mean1 = self.means[:, 0, :, :, :] # B, M, H, W 142 | mean2 = self.means[:, 1, :, :, :] + \ 143 | self.coeffs[:, 0, :, :, :] * samples[:, 0, :, :, :] # B, M, H, W 144 | mean3 = self.means[:, 2, :, :, :] + \ 145 | self.coeffs[:, 1, :, :, :] * samples[:, 0, :, :, :] + \ 146 | self.coeffs[:, 2, :, :, :] * samples[:, 1, :, :, :] # B, M, H, W 147 | 148 | mean1 = mean1.unsqueeze(1) # B, 1, M, H, W 149 | mean2 = mean2.unsqueeze(1) # B, 1, M, H, W 150 | mean3 = mean3.unsqueeze(1) # B, 1, M, H, W 151 | means = torch.cat([mean1, mean2, mean3], dim=1) # B, 3, M, H, W 152 | centered = samples - means # B, 3, M, H, W 153 | 154 | inv_stdv = torch.exp(- self.log_scales) 155 | plus_in = inv_stdv * (centered + 1. / self.max_val) 156 | cdf_plus = torch.sigmoid(plus_in) 157 | min_in = inv_stdv * (centered - 1. / self.max_val) 158 | cdf_min = torch.sigmoid(min_in) 159 | log_cdf_plus = plus_in - F.softplus(plus_in) 160 | log_one_minus_cdf_min = - F.softplus(min_in) 161 | cdf_delta = cdf_plus - cdf_min 162 | mid_in = inv_stdv * centered 163 | log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in) 164 | 165 | log_prob_mid_safe = torch.where(cdf_delta > 1e-5, 166 | torch.log(torch.clamp(cdf_delta, min=1e-10)), 167 | log_pdf_mid - np.log(self.max_val / 2)) 168 | # the original implementation uses samples > 0.999, this ignores the largest possible pixel value (255) 169 | # which is mapped to 0.9922 170 | log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.99, log_one_minus_cdf_min, 171 | log_prob_mid_safe)) # B, 3, M, H, W 172 | 173 | log_probs = torch.sum(log_probs, 1) + F.log_softmax(self.logit_probs, dim=1) # B, M, H, W 174 | return torch.logsumexp(log_probs, dim=1) # B, H, W 175 | 176 | def sample(self, t=1.): 177 | gumbel = -torch.log(- torch.log(torch.Tensor(self.logit_probs.size()).uniform_(1e-5, 1. - 1e-5).cuda())) # B, M, H, W 178 | sel = one_hot(torch.argmax(self.logit_probs / t + gumbel, 1), self.num_mix, dim=1) # B, M, H, W 179 | sel = sel.unsqueeze(1) # B, 1, M, H, W 180 | 181 | # select logistic parameters 182 | means = torch.sum(self.means * sel, dim=2) # B, 3, H, W 183 | log_scales = torch.sum(self.log_scales * sel, dim=2) # B, 3, H, W 184 | coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W 185 | 186 | # cells from logistic & clip to interval 187 | # we don't actually round to the nearest 8bit value when sampling 188 | u = torch.Tensor(means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W 189 | x = means + torch.exp(log_scales) / t * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W 190 | 191 | x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W 192 | x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W 193 | x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W 194 | 195 | x0 = x0.unsqueeze(1) 196 | x1 = x1.unsqueeze(1) 197 | x2 = x2.unsqueeze(1) 198 | 199 | x = torch.cat([x0, x1, x2], 1) 200 | x = x / 2. 
+ 0.5 201 | return x 202 | 203 | def mean(self): 204 | sel = torch.softmax(self.logit_probs, dim=1) # B, M, H, W 205 | sel = sel.unsqueeze(1) # B, 1, M, H, W 206 | 207 | # select logistic parameters 208 | means = torch.sum(self.means * sel, dim=2) # B, 3, H, W 209 | coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W 210 | 211 | # we don't sample from logistic components, because of the linear dependencies, we use mean 212 | x = means # B, 3, H, W 213 | x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W 214 | x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W 215 | x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W 216 | 217 | x0 = x0.unsqueeze(1) 218 | x1 = x1.unsqueeze(1) 219 | x2 = x2.unsqueeze(1) 220 | 221 | x = torch.cat([x0, x1, x2], 1) 222 | x = x / 2. + 0.5 223 | return x 224 | 225 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import argparse 9 | import torch 10 | import numpy as np 11 | import os 12 | import matplotlib.pyplot as plt 13 | from time import time 14 | 15 | from torch.multiprocessing import Process 16 | from torch.cuda.amp import autocast 17 | 18 | from model import AutoEncoder 19 | import utils 20 | import datasets 21 | from train import test, init_processes, test_vae_fid 22 | 23 | 24 | def set_bn(model, bn_eval_mode, num_samples=1, t=1.0, iter=100): 25 | if bn_eval_mode: 26 | model.eval() 27 | else: 28 | model.train() 29 | with autocast(): 30 | for i in range(iter): 31 | if i % 10 == 0: 32 | print('setting BN statistics iter %d out of %d' % (i+1, iter)) 33 | model.sample(num_samples, t) 34 | model.eval() 35 | 36 | 37 | def main(eval_args): 38 | # ensures that weight initializations are all the same 39 | logging = utils.Logger(eval_args.local_rank, eval_args.save) 40 | 41 | # load a checkpoint 42 | logging.info('loading the model at:') 43 | logging.info(eval_args.checkpoint) 44 | checkpoint = torch.load(eval_args.checkpoint, map_location='cpu') 45 | args = checkpoint['args'] 46 | 47 | if not hasattr(args, 'ada_groups'): 48 | logging.info('old model, no ada groups was found.') 49 | args.ada_groups = False 50 | 51 | if not hasattr(args, 'min_groups_per_scale'): 52 | logging.info('old model, no min_groups_per_scale was found.') 53 | args.min_groups_per_scale = 1 54 | 55 | if not hasattr(args, 'num_mixture_dec'): 56 | logging.info('old model, no num_mixture_dec was found.') 57 | args.num_mixture_dec = 10 58 | 59 | if eval_args.batch_size > 0: 60 | args.batch_size = eval_args.batch_size 61 | 62 | logging.info('loaded the model at epoch %d', checkpoint['epoch']) 63 | arch_instance = utils.get_arch_cells(args.arch_instance) 64 | model = AutoEncoder(args, None, arch_instance) 65 | # Loading is not strict because of self.weight_normalized in Conv2D class in neural_operations. This variable 66 | # is only used for computing the spectral normalization and it is safe not to load it. Some of our earlier models 67 | # did not have this variable. 
68 | model.load_state_dict(checkpoint['state_dict'], strict=False) 69 | model = model.cuda() 70 | 71 | logging.info('args = %s', args) 72 | logging.info('num conv layers: %d', len(model.all_conv_layers)) 73 | logging.info('param size = %fM ', utils.count_parameters_in_M(model)) 74 | 75 | if eval_args.eval_mode == 'evaluate': 76 | # load train valid queue 77 | args.data = eval_args.data 78 | train_queue, valid_queue, num_classes = datasets.get_loaders(args) 79 | 80 | if eval_args.eval_on_train: 81 | logging.info('Using the training data for eval.') 82 | valid_queue = train_queue 83 | 84 | # get number of bits 85 | num_output = utils.num_output(args.dataset) 86 | bpd_coeff = 1. / np.log(2.) / num_output 87 | 88 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=eval_args.num_iw_samples, args=args, logging=logging) 89 | logging.info('final valid nelbo %f', valid_nelbo) 90 | logging.info('final valid neg log p %f', valid_neg_log_p) 91 | logging.info('final valid nelbo in bpd %f', valid_nelbo * bpd_coeff) 92 | logging.info('final valid neg log p in bpd %f', valid_neg_log_p * bpd_coeff) 93 | elif eval_args.eval_mode == 'evaluate_fid': 94 | bn_eval_mode = not eval_args.readjust_bn 95 | set_bn(model, bn_eval_mode, num_samples=2, t=eval_args.temp, iter=500) 96 | args.fid_dir = eval_args.fid_dir 97 | args.num_process_per_node, args.num_proc_node = eval_args.world_size, 1 # evaluate only one 1 node 98 | fid = test_vae_fid(model, args, total_fid_samples=50000) 99 | logging.info('fid is %f' % fid) 100 | else: 101 | bn_eval_mode = not eval_args.readjust_bn 102 | total_samples = 50000 // eval_args.world_size # num images per gpu 103 | num_samples = 100 # sampling batch size 104 | num_iter = int(np.ceil(total_samples / num_samples)) # num iterations per gpu 105 | 106 | with torch.no_grad(): 107 | n = int(np.floor(np.sqrt(num_samples))) 108 | set_bn(model, bn_eval_mode, num_samples=16, t=eval_args.temp, iter=500) 109 | for ind in range(num_iter): # sampling is repeated. 110 | torch.cuda.synchronize() 111 | start = time() 112 | with autocast(): 113 | logits = model.sample(num_samples, eval_args.temp) 114 | output = model.decoder_output(logits) 115 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \ 116 | else output.sample() 117 | torch.cuda.synchronize() 118 | end = time() 119 | logging.info('sampling time per batch: %0.3f sec', (end - start)) 120 | 121 | visualize = False 122 | if visualize: 123 | output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0) 124 | output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8) 125 | output_tiled = np.squeeze(output_tiled) 126 | 127 | plt.imshow(output_tiled) 128 | plt.show() 129 | else: 130 | file_path = os.path.join(eval_args.save, 'gpu_%d_samples_%d.npz' % (eval_args.local_rank, ind)) 131 | np.savez_compressed(file_path, samples=output_img.cpu().numpy()) 132 | logging.info('Saved at: {}'.format(file_path)) 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser('encoder decoder examiner') 137 | # experimental results 138 | parser.add_argument('--checkpoint', type=str, default='/tmp/expr/checkpoint.pt', 139 | help='location of the checkpoint') 140 | parser.add_argument('--save', type=str, default='/tmp/expr', 141 | help='location of the checkpoint') 142 | parser.add_argument('--eval_mode', type=str, default='sample', choices=['sample', 'evaluate', 'evaluate_fid'], 143 | help='evaluation mode. 
you can choose between sample or evaluate.') 144 | parser.add_argument('--eval_on_train', action='store_true', default=False, 145 | help='Settings this to true will evaluate the model on training data.') 146 | parser.add_argument('--data', type=str, default='/tmp/data', 147 | help='location of the data corpus') 148 | parser.add_argument('--readjust_bn', action='store_true', default=False, 149 | help='adding this flag will enable readjusting BN statistics.') 150 | parser.add_argument('--temp', type=float, default=0.7, 151 | help='The temperature used for sampling.') 152 | parser.add_argument('--num_iw_samples', type=int, default=1000, 153 | help='The number of IW samples used in test_ll mode.') 154 | parser.add_argument('--fid_dir', type=str, default='/tmp/fid-stats', 155 | help='path to directory where fid related files are stored') 156 | parser.add_argument('--batch_size', type=int, default=0, 157 | help='Batch size used during evaluation. If set to zero, training batch size is used.') 158 | # DDP. 159 | parser.add_argument('--local_rank', type=int, default=0, 160 | help='rank of process') 161 | parser.add_argument('--world_size', type=int, default=1, 162 | help='number of gpus') 163 | parser.add_argument('--seed', type=int, default=1, 164 | help='seed used for initialization') 165 | parser.add_argument('--master_address', type=str, default='127.0.0.1', 166 | help='address for master') 167 | 168 | args = parser.parse_args() 169 | utils.create_exp_dir(args.save) 170 | 171 | size = args.world_size 172 | 173 | if size > 1: 174 | args.distributed = True 175 | processes = [] 176 | for rank in range(size): 177 | args.local_rank = rank 178 | p = Process(target=init_processes, args=(rank, size, main, args)) 179 | p.start() 180 | processes.append(p) 181 | 182 | for p in processes: 183 | p.join() 184 | else: 185 | # for debugging 186 | print('starting in debug mode') 187 | args.distributed = True 188 | init_processes(0, size, main, args) -------------------------------------------------------------------------------- /fid/LICENSE_pytorch_fid: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/fid/__init__.py -------------------------------------------------------------------------------- /fid/fid_score.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the pytorch_fid library 5 | # which was released under the Apache License v2.0 License. 6 | # 7 | # Source: 8 | # https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/fid_score.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in this directory (LICENSE_pytorch_fid). The modifications 12 | # to this file are subject to the same Apache License. 13 | # --------------------------------------------------------------- 14 | 15 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 16 | 17 | The FID metric calculates the distance between two distributions of images. 18 | Typically, we have summary statistics (mean & covariance matrix) of one 19 | of these distributions, while the 2nd distribution is given by a GAN. 20 | 21 | When run as a stand-alone program, it compares the distribution of 22 | images that are stored as PNG/JPEG at a specified location with a 23 | distribution given by summary statistics (in pickle format). 24 | 25 | The FID is calculated by assuming that X_1 and X_2 are the activations of 26 | the pool_3 layer of the inception net for generated samples and real world 27 | samples respectively. 28 | 29 | See --help to see further details. 30 | 31 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 32 | of Tensorflow 33 | 34 | Copyright 2018 Institute of Bioinformatics, JKU Linz 35 | 36 | Licensed under the Apache License, Version 2.0 (the "License"); 37 | you may not use this file except in compliance with the License. 38 | You may obtain a copy of the License at 39 | 40 | http://www.apache.org/licenses/LICENSE-2.0 41 | 42 | Unless required by applicable law or agreed to in writing, software 43 | distributed under the License is distributed on an "AS IS" BASIS, 44 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 45 | See the License for the specific language governing permissions and 46 | limitations under the License. 47 | """ 48 | import os 49 | import pathlib 50 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 51 | from multiprocessing import cpu_count 52 | 53 | import numpy as np 54 | import torch 55 | from torch.utils.data import DataLoader 56 | import torchvision.transforms as TF 57 | from scipy import linalg 58 | from torch.nn.functional import adaptive_avg_pool2d 59 | from PIL import Image 60 | 61 | from fid.inception import InceptionV3 62 | 63 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 64 | parser.add_argument('--batch-size', type=int, default=50, 65 | help='Batch size to use') 66 | parser.add_argument('--device', type=str, default=None, 67 | help='Device to use. Like cuda, cuda:0 or cpu') 68 | parser.add_argument('--dims', type=int, default=2048, 69 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 70 | help=('Dimensionality of Inception features to use. 
' 71 | 'By default, uses pool3 features')) 72 | parser.add_argument('path', type=str, nargs=2, 73 | help=('Paths to the generated images or ' 74 | 'to .npz statistic files')) 75 | 76 | 77 | class ImagesPathDataset(torch.utils.data.Dataset): 78 | def __init__(self, files, transforms=None): 79 | self.files = files 80 | self.transforms = transforms 81 | 82 | def __len__(self): 83 | return len(self.files) 84 | 85 | def __getitem__(self, i): 86 | path = self.files[i] 87 | img = Image.open(path).convert('RGB') 88 | if self.transforms is not None: 89 | img = self.transforms(img) 90 | return img 91 | 92 | 93 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', max_samples=None): 94 | """Calculates the activations of the pool_3 layer for all images. 95 | 96 | Params: 97 | -- files : List of image files paths or pytorch data loader 98 | -- model : Instance of inception model 99 | -- batch_size : Batch size of images for the model to process at once. 100 | Make sure that the number of samples is a multiple of 101 | the batch size, otherwise some samples are ignored. This 102 | behavior is retained to match the original FID score 103 | implementation. 104 | -- dims : Dimensionality of features returned by Inception 105 | -- device : Device to run calculations 106 | -- max_samples : Setting this value will stop activation when max_samples is reached 107 | 108 | Returns: 109 | -- A numpy array of dimension (num images, dims) that contains the 110 | activations of the given tensor when feeding inception with the 111 | query tensor. 112 | """ 113 | model.eval() 114 | 115 | if isinstance(files, list): 116 | if batch_size > len(files): 117 | print(('Warning: batch size is bigger than the data size. ' 118 | 'Setting batch size to data size')) 119 | batch_size = len(files) 120 | 121 | ds = ImagesPathDataset(files, transforms=TF.ToTensor()) 122 | dl = DataLoader(ds, batch_size=batch_size, 123 | drop_last=False, num_workers=cpu_count()) 124 | else: 125 | dl = files 126 | 127 | pred_arr = [] 128 | total_processed = 0 129 | 130 | print('Starting to sample.') 131 | for batch in dl: 132 | # ignore labels 133 | if isinstance(batch, list): 134 | batch = batch[0] 135 | 136 | batch = batch.to(device) 137 | if batch.shape[1] == 1: # if image is gray scale 138 | batch = batch.repeat(1, 3, 1, 1) 139 | 140 | with torch.no_grad(): 141 | pred = model(batch)[0] 142 | 143 | # If model output is not scalar, apply global spatial average pooling. 144 | # This happens if you choose a dimensionality not equal 2048. 145 | if pred.size(2) != 1 or pred.size(3) != 1: 146 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 147 | 148 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 149 | pred_arr.append(pred) 150 | total_processed += pred.shape[0] 151 | if max_samples is not None and total_processed > max_samples: 152 | print('Max Samples Reached.') 153 | break 154 | 155 | pred_arr = np.concatenate(pred_arr, axis=0) 156 | if max_samples is not None: 157 | pred_arr = pred_arr[:max_samples] 158 | 159 | return pred_arr 160 | 161 | 162 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 163 | """Numpy implementation of the Frechet Distance. 164 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 165 | and X_2 ~ N(mu_2, C_2) is 166 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 167 | 168 | Stable version by Dougal J. Sutherland. 
169 | 170 | Params: 171 | -- mu1 : Numpy array containing the activations of a layer of the 172 | inception net (like returned by the function 'get_predictions') 173 | for generated samples. 174 | -- mu2 : The sample mean over activations, precalculated on an 175 | representative data set. 176 | -- sigma1: The covariance matrix over activations for generated samples. 177 | -- sigma2: The covariance matrix over activations, precalculated on an 178 | representative data set. 179 | 180 | Returns: 181 | -- : The Frechet Distance. 182 | """ 183 | 184 | mu1 = np.atleast_1d(mu1) 185 | mu2 = np.atleast_1d(mu2) 186 | 187 | sigma1 = np.atleast_2d(sigma1) 188 | sigma2 = np.atleast_2d(sigma2) 189 | 190 | assert mu1.shape == mu2.shape, \ 191 | 'Training and test mean vectors have different lengths' 192 | assert sigma1.shape == sigma2.shape, \ 193 | 'Training and test covariances have different dimensions' 194 | 195 | diff = mu1 - mu2 196 | 197 | # Product might be almost singular 198 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 199 | if not np.isfinite(covmean).all(): 200 | msg = ('fid calculation produces singular product; ' 201 | 'adding %s to diagonal of cov estimates') % eps 202 | print(msg) 203 | offset = np.eye(sigma1.shape[0]) * eps 204 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 205 | 206 | # Numerical error might give slight imaginary component 207 | if np.iscomplexobj(covmean): 208 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 209 | m = np.max(np.abs(covmean.imag)) 210 | raise ValueError('Imaginary component {}'.format(m)) 211 | covmean = covmean.real 212 | 213 | tr_covmean = np.trace(covmean) 214 | 215 | return (diff.dot(diff) + np.trace(sigma1) + 216 | np.trace(sigma2) - 2 * tr_covmean) 217 | 218 | 219 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device='cpu', max_samples=None): 220 | """Calculation of the statistics used by the FID. 221 | Params: 222 | -- files : List of image files paths or pytorch data loader 223 | -- model : Instance of inception model 224 | -- batch_size : The images numpy array is split into batches with 225 | batch size batch_size. A reasonable batch size 226 | depends on the hardware. 227 | -- dims : Dimensionality of features returned by Inception 228 | -- device : Device to run calculations 229 | -- max_samples : Setting this value will stop activation when max_samples is reached 230 | 231 | Returns: 232 | -- mu : The mean over samples of the activations of the pool_3 layer of 233 | the inception model. 234 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 235 | the inception model. 
236 | """ 237 | act = get_activations(files, model, batch_size, dims, device, max_samples) 238 | mu = np.mean(act, axis=0) 239 | sigma = np.cov(act, rowvar=False) 240 | return mu, sigma 241 | 242 | 243 | def _compute_statistics_of_path(path, model, batch_size, dims, device): 244 | if path.endswith('.npz'): 245 | f = np.load(path) 246 | m, s = f['mu'][:], f['sigma'][:] 247 | f.close() 248 | else: 249 | path = pathlib.Path(path) 250 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 251 | m, s = calculate_activation_statistics(files, model, batch_size, 252 | dims, device) 253 | 254 | return m, s 255 | 256 | 257 | def compute_statistics_of_generator(data_loader, model, batch_size, dims, device, max_samples=None): 258 | m, s = calculate_activation_statistics(data_loader, model, batch_size, 259 | dims, device, max_samples) 260 | 261 | return m, s 262 | 263 | 264 | def save_statistics(path, m, s): 265 | assert path.endswith('.npz') 266 | np.savez(path, mu=m, sigma=s) 267 | 268 | 269 | def load_statistics(path): 270 | assert path.endswith('.npz') 271 | f = np.load(path) 272 | m, s = f['mu'][:], f['sigma'][:] 273 | f.close() 274 | return m, s 275 | 276 | 277 | def calculate_fid_given_paths(paths, batch_size, device, dims): 278 | """Calculates the FID of two paths""" 279 | for p in paths: 280 | if not os.path.exists(p): 281 | raise RuntimeError('Invalid path: %s' % p) 282 | 283 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 284 | 285 | model = InceptionV3([block_idx]).to(device) 286 | 287 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 288 | dims, device) 289 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 290 | dims, device) 291 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 292 | 293 | return fid_value 294 | 295 | 296 | def main(): 297 | args = parser.parse_args() 298 | 299 | if args.device is None: 300 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 301 | else: 302 | device = torch.device(args.device) 303 | 304 | fid_value = calculate_fid_given_paths(args.path, 305 | args.batch_size, 306 | device, 307 | args.dims) 308 | print('FID: ', fid_value) 309 | 310 | 311 | if __name__ == '__main__': 312 | main() -------------------------------------------------------------------------------- /fid/inception.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the pytorch_fid library 5 | # which was released under the Apache License v2.0 License. 6 | # 7 | # Source: 8 | # https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in this directory (LICENSE_pytorch_fid). The modifications 12 | # to this file are subject to the same Apache License. 
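As a usage note for the fid_score.py module listed above (an editor-added sketch, not part of the original sources): calculate_fid_given_paths accepts two paths, each either a directory of .jpg/.png images or a precomputed .npz statistics file, and returns the FID between them. The directory paths below are placeholders, and the snippet assumes the repository root is on the Python path so that the fid package can be imported.

import torch
from fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Each entry may be an image directory or an .npz file holding 'mu' and 'sigma'
# arrays (as written by save_statistics above). Both paths are placeholders.
fid_value = calculate_fid_given_paths(['/path/to/real_images', '/path/to/generated_images'],
                                      batch_size=50, device=device, dims=2048)
print('FID:', fid_value)

The same computation is available from the command line, e.g. python -m fid.fid_score /path/to/real_images /path/to/generated_images, run from the repository root.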
13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | 20 | try: 21 | from torchvision.models.utils import load_state_dict_from_url 22 | except ImportError: 23 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 24 | 25 | # Inception weights ported to Pytorch from 26 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 27 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 28 | 29 | 30 | class InceptionV3(nn.Module): 31 | """Pretrained InceptionV3 network returning feature maps""" 32 | 33 | # Index of default block of inception to return, 34 | # corresponds to output of final average pooling 35 | DEFAULT_BLOCK_INDEX = 3 36 | 37 | # Maps feature dimensionality to their output blocks indices 38 | BLOCK_INDEX_BY_DIM = { 39 | 64: 0, # First max pooling features 40 | 192: 1, # Second max pooling featurs 41 | 768: 2, # Pre-aux classifier features 42 | 2048: 3 # Final average pooling features 43 | } 44 | 45 | def __init__(self, 46 | output_blocks=[DEFAULT_BLOCK_INDEX], 47 | resize_input=True, 48 | normalize_input=True, 49 | requires_grad=False, 50 | use_fid_inception=True, 51 | model_dir=None): 52 | """Build pretrained InceptionV3 53 | 54 | Parameters 55 | ---------- 56 | output_blocks : list of int 57 | Indices of blocks to return features of. Possible values are: 58 | - 0: corresponds to output of first max pooling 59 | - 1: corresponds to output of second max pooling 60 | - 2: corresponds to output which is fed to aux classifier 61 | - 3: corresponds to output of final average pooling 62 | resize_input : bool 63 | If true, bilinearly resizes input to width and height 299 before 64 | feeding input to model. As the network without fully connected 65 | layers is fully convolutional, it should be able to handle inputs 66 | of arbitrary size, so resizing might not be strictly needed 67 | normalize_input : bool 68 | If true, scales the input from range (0, 1) to the range the 69 | pretrained Inception network expects, namely (-1, 1) 70 | requires_grad : bool 71 | If true, parameters of the model require gradients. Possibly useful 72 | for finetuning the network 73 | use_fid_inception : bool 74 | If true, uses the pretrained Inception model used in Tensorflow's 75 | FID implementation. If false, uses the pretrained Inception model 76 | available in torchvision. The FID Inception model has different 77 | weights and a slightly different structure from torchvision's 78 | Inception model. If you want to compute FID scores, you are 79 | strongly advised to set this parameter to true to get comparable 80 | results. 
81 | model_dir: is used for storing pretrained checkpoints 82 | """ 83 | super(InceptionV3, self).__init__() 84 | 85 | self.resize_input = resize_input 86 | self.normalize_input = normalize_input 87 | self.output_blocks = sorted(output_blocks) 88 | self.last_needed_block = max(output_blocks) 89 | 90 | assert self.last_needed_block <= 3, \ 91 | 'Last possible output block index is 3' 92 | 93 | self.blocks = nn.ModuleList() 94 | 95 | if use_fid_inception: 96 | inception = fid_inception_v3(model_dir) 97 | else: 98 | inception = _inception_v3(pretrained=True) 99 | 100 | # Block 0: input to maxpool1 101 | block0 = [ 102 | inception.Conv2d_1a_3x3, 103 | inception.Conv2d_2a_3x3, 104 | inception.Conv2d_2b_3x3, 105 | nn.MaxPool2d(kernel_size=3, stride=2) 106 | ] 107 | self.blocks.append(nn.Sequential(*block0)) 108 | 109 | # Block 1: maxpool1 to maxpool2 110 | if self.last_needed_block >= 1: 111 | block1 = [ 112 | inception.Conv2d_3b_1x1, 113 | inception.Conv2d_4a_3x3, 114 | nn.MaxPool2d(kernel_size=3, stride=2) 115 | ] 116 | self.blocks.append(nn.Sequential(*block1)) 117 | 118 | # Block 2: maxpool2 to aux classifier 119 | if self.last_needed_block >= 2: 120 | block2 = [ 121 | inception.Mixed_5b, 122 | inception.Mixed_5c, 123 | inception.Mixed_5d, 124 | inception.Mixed_6a, 125 | inception.Mixed_6b, 126 | inception.Mixed_6c, 127 | inception.Mixed_6d, 128 | inception.Mixed_6e, 129 | ] 130 | self.blocks.append(nn.Sequential(*block2)) 131 | 132 | # Block 3: aux classifier to final avgpool 133 | if self.last_needed_block >= 3: 134 | block3 = [ 135 | inception.Mixed_7a, 136 | inception.Mixed_7b, 137 | inception.Mixed_7c, 138 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 139 | ] 140 | self.blocks.append(nn.Sequential(*block3)) 141 | 142 | for param in self.parameters(): 143 | param.requires_grad = requires_grad 144 | 145 | def forward(self, inp): 146 | """Get Inception feature maps 147 | 148 | Parameters 149 | ---------- 150 | inp : torch.autograd.Variable 151 | Input tensor of shape Bx3xHxW. Values are expected to be in 152 | range (0, 1) 153 | 154 | Returns 155 | ------- 156 | List of torch.autograd.Variable, corresponding to the selected output 157 | block, sorted ascending by index 158 | """ 159 | outp = [] 160 | x = inp 161 | 162 | if self.resize_input: 163 | x = F.interpolate(x, 164 | size=(299, 299), 165 | mode='bilinear', 166 | align_corners=False) 167 | 168 | if self.normalize_input: 169 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 170 | 171 | for idx, block in enumerate(self.blocks): 172 | x = block(x) 173 | if idx in self.output_blocks: 174 | outp.append(x) 175 | 176 | if idx == self.last_needed_block: 177 | break 178 | 179 | return outp 180 | 181 | 182 | def _inception_v3(*args, **kwargs): 183 | """Wraps `torchvision.models.inception_v3` 184 | 185 | Skips default weight inititialization if supported by torchvision version. 186 | See https://github.com/mseitzer/pytorch-fid/issues/28. 187 | """ 188 | try: 189 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 190 | except ValueError: 191 | # Just a caution against weird version strings 192 | version = (0,) 193 | 194 | if version >= (0, 6): 195 | kwargs['init_weights'] = False 196 | 197 | return torchvision.models.inception_v3(*args, **kwargs) 198 | 199 | 200 | def fid_inception_v3(model_dir=None): 201 | """Build pretrained Inception model for FID computation 202 | 203 | The Inception model for FID computation uses a different set of weights 204 | and has a slightly different structure than torchvision's Inception. 
205 | 206 | This method first constructs torchvision's Inception and then patches the 207 | necessary parts that are different in the FID Inception model. 208 | """ 209 | inception = _inception_v3(num_classes=1008, 210 | aux_logits=False, 211 | pretrained=False) 212 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 213 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 214 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 215 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 216 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 217 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 218 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 219 | inception.Mixed_7b = FIDInceptionE_1(1280) 220 | inception.Mixed_7c = FIDInceptionE_2(2048) 221 | 222 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, model_dir=model_dir, progress=True) 223 | inception.load_state_dict(state_dict) 224 | return inception 225 | 226 | 227 | class FIDInceptionA(torchvision.models.inception.InceptionA): 228 | """InceptionA block patched for FID computation""" 229 | def __init__(self, in_channels, pool_features): 230 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 231 | 232 | def forward(self, x): 233 | branch1x1 = self.branch1x1(x) 234 | 235 | branch5x5 = self.branch5x5_1(x) 236 | branch5x5 = self.branch5x5_2(branch5x5) 237 | 238 | branch3x3dbl = self.branch3x3dbl_1(x) 239 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 240 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 241 | 242 | # Patch: Tensorflow's average pool does not use the padded zero's in 243 | # its average calculation 244 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 245 | count_include_pad=False) 246 | branch_pool = self.branch_pool(branch_pool) 247 | 248 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 249 | return torch.cat(outputs, 1) 250 | 251 | 252 | class FIDInceptionC(torchvision.models.inception.InceptionC): 253 | """InceptionC block patched for FID computation""" 254 | def __init__(self, in_channels, channels_7x7): 255 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 256 | 257 | def forward(self, x): 258 | branch1x1 = self.branch1x1(x) 259 | 260 | branch7x7 = self.branch7x7_1(x) 261 | branch7x7 = self.branch7x7_2(branch7x7) 262 | branch7x7 = self.branch7x7_3(branch7x7) 263 | 264 | branch7x7dbl = self.branch7x7dbl_1(x) 265 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 266 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 267 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 268 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 269 | 270 | # Patch: Tensorflow's average pool does not use the padded zero's in 271 | # its average calculation 272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 273 | count_include_pad=False) 274 | branch_pool = self.branch_pool(branch_pool) 275 | 276 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 277 | return torch.cat(outputs, 1) 278 | 279 | 280 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 281 | """First InceptionE block patched for FID computation""" 282 | def __init__(self, in_channels): 283 | super(FIDInceptionE_1, self).__init__(in_channels) 284 | 285 | def forward(self, x): 286 | branch1x1 = self.branch1x1(x) 287 | 288 | branch3x3 = self.branch3x3_1(x) 289 | branch3x3 = [ 290 | self.branch3x3_2a(branch3x3), 291 | self.branch3x3_2b(branch3x3), 292 | ] 293 | branch3x3 = torch.cat(branch3x3, 1) 294 | 295 | 
branch3x3dbl = self.branch3x3dbl_1(x) 296 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 297 | branch3x3dbl = [ 298 | self.branch3x3dbl_3a(branch3x3dbl), 299 | self.branch3x3dbl_3b(branch3x3dbl), 300 | ] 301 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 302 | 303 | # Patch: Tensorflow's average pool does not use the padded zero's in 304 | # its average calculation 305 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 306 | count_include_pad=False) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | 312 | 313 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 314 | """Second InceptionE block patched for FID computation""" 315 | def __init__(self, in_channels): 316 | super(FIDInceptionE_2, self).__init__(in_channels) 317 | 318 | def forward(self, x): 319 | branch1x1 = self.branch1x1(x) 320 | 321 | branch3x3 = self.branch3x3_1(x) 322 | branch3x3 = [ 323 | self.branch3x3_2a(branch3x3), 324 | self.branch3x3_2b(branch3x3), 325 | ] 326 | branch3x3 = torch.cat(branch3x3, 1) 327 | 328 | branch3x3dbl = self.branch3x3dbl_1(x) 329 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 330 | branch3x3dbl = [ 331 | self.branch3x3dbl_3a(branch3x3dbl), 332 | self.branch3x3dbl_3b(branch3x3dbl), 333 | ] 334 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 335 | 336 | # Patch: The FID Inception model uses max pooling instead of average 337 | # pooling. This is likely an error in this specific Inception 338 | # implementation, as other Inception models use average pooling here 339 | # (which matches the description in the paper). 340 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 341 | branch_pool = self.branch_pool(branch_pool) 342 | 343 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 344 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /img/celebahq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/img/celebahq.png -------------------------------------------------------------------------------- /img/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/img/model_diagram.png -------------------------------------------------------------------------------- /lmdb_datasets.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import torch.utils.data as data 9 | import numpy as np 10 | import lmdb 11 | import os 12 | import io 13 | from PIL import Image 14 | 15 | 16 | def num_samples(dataset, train): 17 | if dataset == 'celeba': 18 | return 27000 if train else 3000 19 | elif dataset == 'celeba64': 20 | return 162770 if train else 19867 21 | elif dataset == 'imagenet-oord': 22 | return 1281147 if train else 50000 23 | elif dataset == 'ffhq': 24 | return 63000 if train else 7000 25 | else: 26 | raise NotImplementedError('dataset %s is unknown' % dataset) 27 | 28 | 29 | class LMDBDataset(data.Dataset): 30 | def __init__(self, root, name='', train=True, transform=None, is_encoded=False): 31 | self.train = train 32 | self.name = name 33 | self.transform = transform 34 | if self.train: 35 | lmdb_path = os.path.join(root, 'train.lmdb') 36 | else: 37 | lmdb_path = os.path.join(root, 'validation.lmdb') 38 | self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1, 39 | lock=False, readahead=False, meminit=False) 40 | self.is_encoded = is_encoded 41 | 42 | def __getitem__(self, index): 43 | target = [0] 44 | with self.data_lmdb.begin(write=False, buffers=True) as txn: 45 | data = txn.get(str(index).encode()) 46 | if self.is_encoded: 47 | img = Image.open(io.BytesIO(data)) 48 | img = img.convert('RGB') 49 | else: 50 | img = np.asarray(data, dtype=np.uint8) 51 | # assume data is RGB 52 | size = int(np.sqrt(len(img) / 3)) 53 | img = np.reshape(img, (size, size, 3)) 54 | img = Image.fromarray(img, mode='RGB') 55 | 56 | if self.transform is not None: 57 | img = self.transform(img) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return num_samples(self.name, self.train) 63 | -------------------------------------------------------------------------------- /neural_ar_operations.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 
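A usage note for the LMDBDataset class in lmdb_datasets.py above (an editor-added sketch, not part of the original sources): the dataset can be wrapped in a standard PyTorch DataLoader. The root path is a placeholder, and the sketch assumes a CelebA-64 LMDB whose values are stored as encoded JPEG bytes (hence is_encoded=True), as produced by scripts/create_celeba64_lmdb.py; a real training pipeline would normally crop the 178x218 CelebA frames before resizing, so the plain resize here is only for illustration.

from torch.utils.data import DataLoader
from torchvision import transforms
from lmdb_datasets import LMDBDataset

# Placeholder root directory containing train.lmdb / validation.lmdb.
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
train_data = LMDBDataset(root='/data/celeba64_lmdb', name='celeba64', train=True,
                         transform=transform, is_encoded=True)

loader = DataLoader(train_data, batch_size=32, shuffle=True)
images, _ = next(iter(loader))  # float tensor of shape (32, 3, 64, 64) with values in [0, 1]

Note that __len__ is looked up from the hard-coded num_samples table rather than from the LMDB itself, so the name argument must match one of the datasets listed there.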
6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import numpy as np 13 | from collections import OrderedDict 14 | 15 | from neural_operations import ConvBNSwish, normalize_weight_jit 16 | 17 | AROPS = OrderedDict([ 18 | ('conv_3x3', lambda C, masked, zero_diag: ELUConv(C, C, 3, 1, 1, masked=masked, zero_diag=zero_diag)) 19 | ]) 20 | 21 | 22 | class Identity(nn.Module): 23 | def __init__(self, masked, zero_diag): 24 | super(Identity, self).__init__() 25 | if zero_diag: 26 | raise ValueError('Skip connection with zero diag is just a zero operation.') 27 | 28 | def forward(self, x): 29 | return x 30 | 31 | 32 | def channel_mask(c_in, g_in, c_out, zero_diag): 33 | assert c_in % c_out == 0 or c_out % c_in == 0, "%d - %d" % (c_in, c_out) 34 | assert g_in == 1 or g_in == c_in 35 | 36 | if g_in == 1: 37 | mask = np.ones([c_out, c_in], dtype=np.float32) 38 | if c_out >= c_in: 39 | ratio = c_out // c_in 40 | for i in range(c_in): 41 | mask[i * ratio:(i + 1) * ratio, i + 1:] = 0 42 | if zero_diag: 43 | mask[i * ratio:(i + 1) * ratio, i:i + 1] = 0 44 | else: 45 | ratio = c_in // c_out 46 | for i in range(c_out): 47 | mask[i:i + 1, (i + 1) * ratio:] = 0 48 | if zero_diag: 49 | mask[i:i + 1, i * ratio:(i + 1) * ratio:] = 0 50 | elif g_in == c_in: 51 | mask = np.ones([c_out, c_in // g_in], dtype=np.float32) 52 | if zero_diag: 53 | mask = 0. * mask 54 | 55 | return mask 56 | 57 | 58 | def create_conv_mask(kernel_size, c_in, g_in, c_out, zero_diag, mirror): 59 | m = (kernel_size - 1) // 2 60 | mask = np.ones([c_out, c_in // g_in, kernel_size, kernel_size], dtype=np.float32) 61 | mask[:, :, m:, :] = 0 62 | mask[:, :, m, :m] = 1 63 | mask[:, :, m, m] = channel_mask(c_in, g_in, c_out, zero_diag) 64 | if mirror: 65 | mask = np.copy(mask[:, :, ::-1, ::-1]) 66 | return mask 67 | 68 | 69 | def norm(t, dim): 70 | return torch.sqrt(torch.sum(t * t, dim)) 71 | 72 | 73 | class ARConv2d(nn.Conv2d): 74 | """Allows for weights as input.""" 75 | 76 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, 77 | masked=False, zero_diag=False, mirror=False): 78 | """ 79 | Args: 80 | use_shared (bool): Use weights for this layer or not? 81 | """ 82 | super(ARConv2d, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias) 83 | 84 | self.masked = masked 85 | if self.masked: 86 | assert kernel_size % 2 == 1, 'kernel size should be an odd value.' 87 | self.mask = torch.from_numpy(create_conv_mask(kernel_size, C_in, groups, C_out, zero_diag, mirror)).cuda() 88 | init_mask = self.mask.cpu() 89 | else: 90 | self.mask = 1.0 91 | init_mask = 1.0 92 | 93 | # init weight normalizaition parameters 94 | init = torch.log(norm(self.weight * init_mask, dim=[1, 2, 3]).view(-1, 1, 1, 1) + 1e-2) 95 | self.log_weight_norm = nn.Parameter(init, requires_grad=True) 96 | self.weight_normalized = None 97 | 98 | def normalize_weight(self): 99 | weight = self.weight 100 | if self.masked: 101 | assert self.mask.size() == weight.size() 102 | weight = weight * self.mask 103 | 104 | # weight normalization 105 | weight = normalize_weight_jit(self.log_weight_norm, weight) 106 | return weight 107 | 108 | def forward(self, x): 109 | """ 110 | Args: 111 | x (torch.Tensor): of size (B, C_in, H, W). 112 | params (ConvParam): containing `weight` and `bias` (optional) of conv operation. 
113 | """ 114 | self.weight_normalized = self.normalize_weight() 115 | bias = self.bias 116 | return F.conv2d(x, self.weight_normalized, bias, self.stride, 117 | self.padding, self.dilation, self.groups) 118 | 119 | 120 | class ELUConv(nn.Module): 121 | """ReLU + Conv2d + BN.""" 122 | 123 | def __init__(self, C_in, C_out, kernel_size, padding=0, dilation=1, masked=True, zero_diag=True, 124 | weight_init_coeff=1.0, mirror=False): 125 | super(ELUConv, self).__init__() 126 | self.conv_0 = ARConv2d(C_in, C_out, kernel_size, stride=1, padding=padding, bias=True, dilation=dilation, 127 | masked=masked, zero_diag=zero_diag, mirror=mirror) 128 | # change the initialized log weight norm 129 | self.conv_0.log_weight_norm.data += np.log(weight_init_coeff) 130 | 131 | def forward(self, x): 132 | """ 133 | Args: 134 | x (torch.Tensor): of size (B, C_in, H, W) 135 | """ 136 | out = F.elu(x) 137 | out = self.conv_0(out) 138 | return out 139 | 140 | 141 | class ARInvertedResidual(nn.Module): 142 | def __init__(self, inz, inf, ex=6, dil=1, k=5, mirror=False): 143 | super(ARInvertedResidual, self).__init__() 144 | hidden_dim = int(round(inz * ex)) 145 | padding = dil * (k - 1) // 2 146 | layers = [] 147 | layers.extend([ARConv2d(inz, hidden_dim, kernel_size=3, padding=1, masked=True, mirror=mirror, zero_diag=True), 148 | nn.ELU(inplace=True)]) 149 | layers.extend([ARConv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=k, padding=padding, dilation=dil, 150 | masked=True, mirror=mirror, zero_diag=False), 151 | nn.ELU(inplace=True)]) 152 | self.convz = nn.Sequential(*layers) 153 | self.hidden_dim = hidden_dim 154 | 155 | def forward(self, z, ftr): 156 | z = self.convz(z) 157 | return z 158 | 159 | 160 | class MixLogCDFParam(nn.Module): 161 | def __init__(self, num_z, num_mix, num_ftr, mirror): 162 | super(MixLogCDFParam, self).__init__() 163 | 164 | num_out = num_z * (3 * num_mix + 3) 165 | self.conv = ELUConv(num_ftr, num_out, kernel_size=1, padding=0, masked=True, zero_diag=False, 166 | weight_init_coeff=0.1, mirror=mirror) 167 | self.num_z = num_z 168 | self.num_mix = num_mix 169 | 170 | def forward(self, ftr): 171 | out = self.conv(ftr) 172 | b, c, h, w = out.size() 173 | out = out.view(b, self.num_z, c // self.num_z, h, w) 174 | m = self.num_mix 175 | logit_pi, mu, log_s, log_a, b, _ = torch.split(out, [m, m, m, 1, 1, 1], dim=2) # the last one is dummy 176 | return logit_pi, mu, log_s, log_a, b 177 | 178 | 179 | def mix_log_cdf_flow(z1, logit_pi, mu, log_s, log_a, b): 180 | # z b, n, 1, h, w 181 | # logit_pi b, n, k, h, w 182 | # mu b, n, k, h, w 183 | # log_s b, n, k, h, w 184 | # log_a b, n, 1, h, w 185 | # b b, n, 1, h, w 186 | 187 | log_s = torch.clamp(log_s, min=-7) 188 | 189 | z = z1.unsqueeze(dim=2) 190 | log_pi = torch.log_softmax(logit_pi, dim=2) # normalize log_pi 191 | u = - (z - mu) * torch.exp(-log_s) 192 | softplus_u = F.softplus(u) 193 | log_mix_cdf = log_pi - softplus_u 194 | log_one_minus_mix_cdf = log_mix_cdf + u 195 | log_mix_cdf = torch.logsumexp(log_mix_cdf, dim=2) 196 | log_one_minus_mix_cdf = torch.logsumexp(log_one_minus_mix_cdf, dim=2) 197 | 198 | log_a = log_a.squeeze_(dim=2) 199 | b = b.squeeze_(dim=2) 200 | new_z = torch.exp(log_a) * (log_mix_cdf - log_one_minus_mix_cdf) + b 201 | 202 | # compute log determinant Jac 203 | log_mix_pdf = torch.logsumexp(log_pi + u - log_s - 2 * softplus_u, dim=2) 204 | log_det = log_a - log_mix_cdf - log_one_minus_mix_cdf + log_mix_pdf 205 | 206 | return new_z, log_det 207 | 
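The mixture-of-logistics CDF flow above transforms each latent element through the logit of a mixture CDF, scales it by exp(log_a) and shifts it by b, and returns the elementwise log-determinant of that transformation. The snippet below is an editor-added sanity check, not part of the original sources: it draws random parameters with the shapes documented in mix_log_cdf_flow and compares the analytic log-determinant against a central finite-difference estimate. It assumes the repository's dependencies are installed so that neural_ar_operations can be imported.

import torch
from neural_ar_operations import mix_log_cdf_flow

torch.manual_seed(0)
B, n, k, H, W = 2, 4, 3, 8, 8      # batch, latent channels, mixture components, spatial size
dt = torch.float64                 # double precision keeps the finite-difference check tight

z = torch.randn(B, n, H, W, dtype=dt)
logit_pi = torch.randn(B, n, k, H, W, dtype=dt)
mu = torch.randn(B, n, k, H, W, dtype=dt)
log_s = 0.5 * torch.randn(B, n, k, H, W, dtype=dt)
log_a = 0.1 * torch.randn(B, n, 1, H, W, dtype=dt)
b = 0.1 * torch.randn(B, n, 1, H, W, dtype=dt)

# mix_log_cdf_flow squeezes log_a and b in place, so pass fresh clones to every call.
eps = 1e-5
_, log_det = mix_log_cdf_flow(z, logit_pi, mu, log_s, log_a.clone(), b.clone())
z_plus, _ = mix_log_cdf_flow(z + eps, logit_pi, mu, log_s, log_a.clone(), b.clone())
z_minus, _ = mix_log_cdf_flow(z - eps, logit_pi, mu, log_s, log_a.clone(), b.clone())

# The flow acts elementwise and is strictly increasing, so its Jacobian is diagonal and
# the log-determinant can be checked per element against the finite-difference slope.
fd_log_det = torch.log((z_plus - z_minus) / (2 * eps))
print('max |analytic - finite difference|:', (log_det - fd_log_det).abs().max().item())

With the moderate parameter scales used here, the two estimates typically agree to about 1e-6 or better.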
-------------------------------------------------------------------------------- /neural_operations.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from thirdparty.swish import Swish as SwishFN 13 | from thirdparty.inplaced_sync_batchnorm import SyncBatchNormSwish 14 | 15 | from utils import average_tensor 16 | from collections import OrderedDict 17 | 18 | BN_EPS = 1e-5 19 | SYNC_BN = True 20 | 21 | OPS = OrderedDict([ 22 | ('res_elu', lambda Cin, Cout, stride: ELUConv(Cin, Cout, 3, stride, 1)), 23 | ('res_bnelu', lambda Cin, Cout, stride: BNELUConv(Cin, Cout, 3, stride, 1)), 24 | ('res_bnswish', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 1)), 25 | ('res_bnswish5', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 2, 2)), 26 | ('mconv_e6k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=5, g=0)), 27 | ('mconv_e3k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=0)), 28 | ('mconv_e3k5g8', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=8)), 29 | ('mconv_e6k11g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=11, g=0)), 30 | ]) 31 | 32 | 33 | def get_skip_connection(C, stride, affine, channel_mult): 34 | if stride == 1: 35 | return Identity() 36 | elif stride == 2: 37 | return FactorizedReduce(C, int(channel_mult * C)) 38 | elif stride == -1: 39 | return nn.Sequential(UpSample(), Conv2D(C, int(C / channel_mult), kernel_size=1)) 40 | 41 | 42 | def norm(t, dim): 43 | return torch.sqrt(torch.sum(t * t, dim)) 44 | 45 | 46 | def logit(t): 47 | return torch.log(t) - torch.log(1 - t) 48 | 49 | 50 | def act(t): 51 | # The following implementation has lower memory. 52 | return SwishFN.apply(t) 53 | 54 | 55 | class Swish(nn.Module): 56 | def __init__(self): 57 | super(Swish, self).__init__() 58 | 59 | def forward(self, x): 60 | return act(x) 61 | 62 | @torch.jit.script 63 | def normalize_weight_jit(log_weight_norm, weight): 64 | n = torch.exp(log_weight_norm) 65 | wn = torch.sqrt(torch.sum(weight * weight, dim=[1, 2, 3])) # norm(w) 66 | weight = n * weight / (wn.view(-1, 1, 1, 1) + 1e-5) 67 | return weight 68 | 69 | 70 | class Conv2D(nn.Conv2d): 71 | """Allows for weights as input.""" 72 | 73 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, data_init=False, 74 | weight_norm=True): 75 | """ 76 | Args: 77 | use_shared (bool): Use weights for this layer or not? 78 | """ 79 | super(Conv2D, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias) 80 | 81 | self.log_weight_norm = None 82 | if weight_norm: 83 | init = norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1) 84 | self.log_weight_norm = nn.Parameter(torch.log(init + 1e-2), requires_grad=True) 85 | 86 | self.data_init = data_init 87 | self.init_done = False 88 | self.weight_normalized = self.normalize_weight() 89 | 90 | def forward(self, x): 91 | """ 92 | Args: 93 | x (torch.Tensor): of size (B, C_in, H, W). 
94 | params (ConvParam): containing `weight` and `bias` (optional) of conv operation. 95 | """ 96 | # do data based initialization 97 | if self.data_init and not self.init_done: 98 | with torch.no_grad(): 99 | weight = self.weight / (norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1) + 1e-5) 100 | bias = None 101 | out = F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) 102 | mn = torch.mean(out, dim=[0, 2, 3]) 103 | st = 5 * torch.std(out, dim=[0, 2, 3]) 104 | 105 | # get mn and st from other GPUs 106 | average_tensor(mn, is_distributed=True) 107 | average_tensor(st, is_distributed=True) 108 | 109 | if self.bias is not None: 110 | self.bias.data = - mn / (st + 1e-5) 111 | self.log_weight_norm.data = -torch.log((st.view(-1, 1, 1, 1) + 1e-5)) 112 | self.init_done = True 113 | 114 | self.weight_normalized = self.normalize_weight() 115 | 116 | bias = self.bias 117 | return F.conv2d(x, self.weight_normalized, bias, self.stride, 118 | self.padding, self.dilation, self.groups) 119 | 120 | def normalize_weight(self): 121 | """ applies weight normalization """ 122 | if self.log_weight_norm is not None: 123 | weight = normalize_weight_jit(self.log_weight_norm, self.weight) 124 | else: 125 | weight = self.weight 126 | 127 | return weight 128 | 129 | 130 | class Identity(nn.Module): 131 | def __init__(self): 132 | super(Identity, self).__init__() 133 | 134 | def forward(self, x): 135 | return x 136 | 137 | 138 | class SyncBatchNorm(nn.Module): 139 | def __init__(self, *args, **kwargs): 140 | super(SyncBatchNorm, self).__init__() 141 | self.bn = nn.SyncBatchNorm(*args, **kwargs) 142 | 143 | def forward(self, x): 144 | # Sync BN only works with distributed data parallel with 1 GPU per process. I don't use DDP, so I need to let 145 | # Sync BN to know that I have 1 gpu per process. 
146 | self.bn.ddp_gpu_size = 1 147 | return self.bn(x) 148 | 149 | 150 | # quick switch between multi-gpu, single-gpu batch norm 151 | def get_batchnorm(*args, **kwargs): 152 | if SYNC_BN: 153 | return SyncBatchNorm(*args, **kwargs) 154 | else: 155 | return nn.BatchNorm2d(*args, **kwargs) 156 | 157 | 158 | class ELUConv(nn.Module): 159 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1): 160 | super(ELUConv, self).__init__() 161 | self.upsample = stride == -1 162 | stride = abs(stride) 163 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation, 164 | data_init=True) 165 | 166 | def forward(self, x): 167 | out = F.elu(x) 168 | if self.upsample: 169 | out = F.interpolate(out, scale_factor=2, mode='nearest') 170 | out = self.conv_0(out) 171 | return out 172 | 173 | 174 | class BNELUConv(nn.Module): 175 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1): 176 | super(BNELUConv, self).__init__() 177 | self.upsample = stride == -1 178 | stride = abs(stride) 179 | self.bn = get_batchnorm(C_in, eps=BN_EPS, momentum=0.05) 180 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation) 181 | 182 | def forward(self, x): 183 | x = self.bn(x) 184 | out = F.elu(x) 185 | if self.upsample: 186 | out = F.interpolate(out, scale_factor=2, mode='nearest') 187 | out = self.conv_0(out) 188 | return out 189 | 190 | 191 | class BNSwishConv(nn.Module): 192 | """ReLU + Conv2d + BN.""" 193 | 194 | def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1): 195 | super(BNSwishConv, self).__init__() 196 | self.upsample = stride == -1 197 | stride = abs(stride) 198 | self.bn_act = SyncBatchNormSwish(C_in, eps=BN_EPS, momentum=0.05) 199 | self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation) 200 | 201 | def forward(self, x): 202 | """ 203 | Args: 204 | x (torch.Tensor): of size (B, C_in, H, W) 205 | """ 206 | out = self.bn_act(x) 207 | if self.upsample: 208 | out = F.interpolate(out, scale_factor=2, mode='nearest') 209 | out = self.conv_0(out) 210 | return out 211 | 212 | 213 | 214 | class FactorizedReduce(nn.Module): 215 | def __init__(self, C_in, C_out): 216 | super(FactorizedReduce, self).__init__() 217 | assert C_out % 2 == 0 218 | self.conv_1 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True) 219 | self.conv_2 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True) 220 | self.conv_3 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True) 221 | self.conv_4 = Conv2D(C_in, C_out - 3 * (C_out // 4), 1, stride=2, padding=0, bias=True) 222 | 223 | def forward(self, x): 224 | out = act(x) 225 | conv1 = self.conv_1(out) 226 | conv2 = self.conv_2(out[:, :, 1:, 1:]) 227 | conv3 = self.conv_3(out[:, :, :, 1:]) 228 | conv4 = self.conv_4(out[:, :, 1:, :]) 229 | out = torch.cat([conv1, conv2, conv3, conv4], dim=1) 230 | return out 231 | 232 | 233 | class UpSample(nn.Module): 234 | def __init__(self): 235 | super(UpSample, self).__init__() 236 | pass 237 | 238 | def forward(self, x): 239 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 240 | 241 | 242 | class EncCombinerCell(nn.Module): 243 | def __init__(self, Cin1, Cin2, Cout, cell_type): 244 | super(EncCombinerCell, self).__init__() 245 | self.cell_type = cell_type 246 | # Cin = Cin1 + Cin2 247 | self.conv = Conv2D(Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True) 248 | 249 | def 
forward(self, x1, x2): 250 | x2 = self.conv(x2) 251 | out = x1 + x2 252 | return out 253 | 254 | 255 | # original combiner 256 | class DecCombinerCell(nn.Module): 257 | def __init__(self, Cin1, Cin2, Cout, cell_type): 258 | super(DecCombinerCell, self).__init__() 259 | self.cell_type = cell_type 260 | self.conv = Conv2D(Cin1 + Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True) 261 | 262 | def forward(self, x1, x2): 263 | out = torch.cat([x1, x2], dim=1) 264 | out = self.conv(out) 265 | return out 266 | 267 | 268 | class ConvBNSwish(nn.Module): 269 | def __init__(self, Cin, Cout, k=3, stride=1, groups=1, dilation=1): 270 | padding = dilation * (k - 1) // 2 271 | super(ConvBNSwish, self).__init__() 272 | 273 | self.conv = nn.Sequential( 274 | Conv2D(Cin, Cout, k, stride, padding, groups=groups, bias=False, dilation=dilation, weight_norm=False), 275 | SyncBatchNormSwish(Cout, eps=BN_EPS, momentum=0.05) # drop in replacement for BN + Swish 276 | ) 277 | 278 | def forward(self, x): 279 | return self.conv(x) 280 | 281 | 282 | class SE(nn.Module): 283 | def __init__(self, Cin, Cout): 284 | super(SE, self).__init__() 285 | num_hidden = max(Cout // 16, 4) 286 | self.se = nn.Sequential(nn.Linear(Cin, num_hidden), nn.ReLU(inplace=True), 287 | nn.Linear(num_hidden, Cout), nn.Sigmoid()) 288 | 289 | def forward(self, x): 290 | se = torch.mean(x, dim=[2, 3]) 291 | se = se.view(se.size(0), -1) 292 | se = self.se(se) 293 | se = se.view(se.size(0), -1, 1, 1) 294 | return x * se 295 | 296 | 297 | class InvertedResidual(nn.Module): 298 | def __init__(self, Cin, Cout, stride, ex, dil, k, g): 299 | super(InvertedResidual, self).__init__() 300 | self.stride = stride 301 | assert stride in [1, 2, -1] 302 | 303 | hidden_dim = int(round(Cin * ex)) 304 | self.use_res_connect = self.stride == 1 and Cin == Cout 305 | self.upsample = self.stride == -1 306 | self.stride = abs(self.stride) 307 | groups = hidden_dim if g == 0 else g 308 | 309 | layers0 = [nn.UpsamplingNearest2d(scale_factor=2)] if self.upsample else [] 310 | layers = [get_batchnorm(Cin, eps=BN_EPS, momentum=0.05), 311 | ConvBNSwish(Cin, hidden_dim, k=1), 312 | ConvBNSwish(hidden_dim, hidden_dim, stride=self.stride, groups=groups, k=k, dilation=dil), 313 | Conv2D(hidden_dim, Cout, 1, 1, 0, bias=False, weight_norm=False), 314 | get_batchnorm(Cout, momentum=0.05)] 315 | 316 | layers0.extend(layers) 317 | self.conv = nn.Sequential(*layers0) 318 | 319 | def forward(self, x): 320 | return self.conv(x) 321 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | pillow 4 | matplotlib 5 | tensorboard 6 | tensorboardX 7 | lmdb 8 | tfrecord -------------------------------------------------------------------------------- /scripts/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.07-py3 2 | 3 | RUN pip install pillow 4 | RUN pip install cython 5 | RUN pip install matplotlib 6 | RUN pip install tensorboard 7 | RUN pip install tensorboardX 8 | RUN pip install tfrecord 9 | RUN apt update 10 | RUN apt install screen -y 11 | -------------------------------------------------------------------------------- /scripts/convert_tfrecord_to_lmdb.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, 
NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | import argparse 8 | import torch 9 | import lmdb 10 | import os 11 | 12 | from tfrecord.torch.dataset import TFRecordDataset 13 | 14 | 15 | def main(dataset, split, tfr_path, lmdb_path): 16 | assert split in {'train', 'validation'} 17 | 18 | # create target directory 19 | if not os.path.exists(lmdb_path): 20 | os.makedirs(lmdb_path, exist_ok=True) 21 | if dataset == 'celeba' and split in {'train', 'validation'}: 22 | num_shards = {'train': 120, 'validation': 40}[split] 23 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split) 24 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r08-s-%04d-of-%04d.tfrecords') 25 | elif dataset == 'imagenet-oord_32': 26 | num_shards = {'train': 2000, 'validation': 80}[split] 27 | # imagenet_oord_lmdb_path += '_32' 28 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split) 29 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r05-s-%04d-of-%04d.tfrecords') 30 | elif dataset == 'imagenet-oord_64': 31 | num_shards = {'train': 2000, 'validation': 80}[split] 32 | # imagenet_oord_lmdb_path += '_64' 33 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % split) 34 | tfrecord_path_template = os.path.join(tfr_path, '%s/%s-r06-s-%04d-of-%04d.tfrecords') 35 | else: 36 | raise NotImplementedError 37 | 38 | # create lmdb 39 | env = lmdb.open(lmdb_path, map_size=1e12) 40 | count = 0 41 | with env.begin(write=True) as txn: 42 | for tf_ind in range(num_shards): 43 | # read tf_record 44 | tfrecord_path = tfrecord_path_template % (split, split, tf_ind, num_shards) 45 | index_path = None 46 | description = {'shape': 'int', 'data': 'byte', 'label': 'int'} 47 | dataset = TFRecordDataset(tfrecord_path, index_path, description) 48 | loader = torch.utils.data.DataLoader(dataset, batch_size=1) 49 | 50 | # put the data in lmdb 51 | for data in loader: 52 | im = data['data'][0].cpu().numpy() 53 | txn.put(str(count).encode(), im) 54 | count += 1 55 | if count % 100 == 0: 56 | print(count) 57 | 58 | print('added %d items to the LMDB dataset.' % count) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser('LMDB creator using TFRecords from GLOW.') 63 | # experimental results 64 | parser.add_argument('--dataset', type=str, default='imagenet-oord_32', 65 | help='dataset name', choices=['imagenet-oord_32', 'imagenet-oord_64', 'celeba']) 66 | parser.add_argument('--tfr_path', type=str, default='/data1/datasets/imagenet-oord/mnt/host/imagenet-oord-tfr', 67 | help='location of TFRecords') 68 | parser.add_argument('--lmdb_path', type=str, default='/data1/datasets/imagenet-oord/imagenet-oord-lmdb_32', 69 | help='target location for storing lmdb files') 70 | parser.add_argument('--split', type=str, default='train', 71 | help='training or validation split', choices=['train', 'validation']) 72 | args = parser.parse_args() 73 | main(args.dataset, args.split, args.tfr_path, args.lmdb_path) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /scripts/create_celeba64_lmdb.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import argparse 9 | import lmdb 10 | import os 11 | import torchvision.datasets as dset 12 | 13 | 14 | def main(split, img_path, lmdb_path): 15 | assert split in {"train", "valid", "test"} 16 | # create target directory 17 | if not os.path.exists(lmdb_path): 18 | os.makedirs(lmdb_path, exist_ok=True) 19 | 20 | lmdb_split = {'train': 'train', 'valid': 'validation', 'test': 'test'}[split] 21 | lmdb_path = os.path.join(lmdb_path, '%s.lmdb' % lmdb_split) 22 | 23 | # if you don't have the data locally, this will download it 24 | data = dset.celeba.CelebA(root=img_path, split=split, target_type='attr', transform=None, download=True) 25 | print('total data: %d' % len(data)) 26 | 27 | # create lmdb 28 | env = lmdb.open(lmdb_path, map_size=1e12) 29 | with env.begin(write=True) as txn: 30 | for i in range(len(data)): 31 | file_path = os.path.join(data.root, data.base_folder, "img_align_celeba", data.filename[i]) 32 | attr = data.attr[i, :] 33 | with open(file_path, 'rb') as f: 34 | file_data = f.read() 35 | 36 | txn.put(str(i).encode(), file_data) 37 | print(i) 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser('CelebA 64 LMDB creator.') 42 | # experimental results 43 | parser.add_argument('--img_path', type=str, default='/data1/datasets/celeba_org/', 44 | help='location of images for CelebA dataset') 45 | parser.add_argument('--lmdb_path', type=str, default='/data1/datasets/celeba_org/celeba64_lmdb', 46 | help='target location for storing lmdb files') 47 | parser.add_argument('--split', type=str, default='train', 48 | help='train, valid, or test split', choices=["train", "valid", "test"]) 49 | args = parser.parse_args() 50 | main(args.split, args.img_path, args.lmdb_path) 51 | 52 | -------------------------------------------------------------------------------- /scripts/create_ffhq_lmdb.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file.
6 | # --------------------------------------------------------------- 7 | 8 | import argparse 9 | import torch 10 | import numpy as np 11 | import lmdb 12 | import os 13 | 14 | from PIL import Image 15 | 16 | 17 | def main(split, ffhq_img_path, ffhq_lmdb_path): 18 | assert split in {'train', 'validation'} 19 | num_images = 70000 20 | num_train = 63000 21 | 22 | # create target directory 23 | if not os.path.exists(ffhq_lmdb_path): 24 | os.makedirs(ffhq_lmdb_path, exist_ok=True) 25 | 26 | ind_path = os.path.join(ffhq_lmdb_path, 'train_test_ind.pt') 27 | if os.path.exists(ind_path): 28 | ind_dat = torch.load(ind_path) 29 | train_ind = ind_dat['train'] 30 | test_ind = ind_dat['test'] 31 | else: 32 | rand = np.random.permutation(num_images) 33 | train_ind = rand[:num_train] 34 | test_ind = rand[num_train:] 35 | torch.save({'train': train_ind, 'test': test_ind}, ind_path) 36 | 37 | file_ind = train_ind if split == 'train' else test_ind 38 | lmdb_path = os.path.join(ffhq_lmdb_path, '%s.lmdb' % split) 39 | 40 | # create lmdb 41 | env = lmdb.open(lmdb_path, map_size=1e12) 42 | count = 0 43 | with env.begin(write=True) as txn: 44 | for i in file_ind: 45 | img_path = os.path.join(ffhq_img_path, '%05d.png' % i) 46 | im = Image.open(img_path) 47 | im = im.resize(size=(256, 256), resample=Image.BILINEAR) 48 | im = np.array(im.getdata(), dtype=np.uint8).reshape(im.size[1], im.size[0], 3) 49 | 50 | txn.put(str(count).encode(), im) 51 | count += 1 52 | if count % 100 == 0: 53 | print(count) 54 | 55 | print('added %d items to the LMDB dataset.' % count) 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser('FFHQ LMDB creator. Download images1024x1024.zip from here and unzip it \n' 60 | 'https://drive.google.com/drive/folders/1WocxvZ4GEZ1DI8dOz30aSj2zT6pkATYS') 61 | # experimental results 62 | parser.add_argument('--ffhq_img_path', type=str, default='/data1/datasets/ffhq/images1024x1024', 63 | help='location of images from FFHQ') 64 | parser.add_argument('--ffhq_lmdb_path', type=str, default='/data1/datasets/ffhq/ffhq-lmdb', 65 | help='target location for storing lmdb files') 66 | parser.add_argument('--split', type=str, default='train', 67 | help='training or validation split', choices=['train', 'validation']) 68 | args = parser.parse_args() 69 | 70 | main(args.split, args.ffhq_img_path, args.ffhq_lmdb_path) 71 | 72 | -------------------------------------------------------------------------------- /scripts/precompute_fid_statistics.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | import os 8 | import argparse 9 | from fid.fid_score import compute_statistics_of_generator, save_statistics 10 | from datasets import get_loaders_eval 11 | from fid.inception import InceptionV3 12 | from itertools import chain 13 | 14 | 15 | def main(args): 16 | device = 'cuda' 17 | dims = 2048 18 | # for binary datasets including MNIST and OMNIGLOT, we don't apply binarization for FID computation 19 | train_queue, valid_queue, _ = get_loaders_eval(args.dataset, args) 20 | print('len train queue', len(train_queue), 'len val queue', len(valid_queue), 'batch size', args.batch_size) 21 | if args.dataset in {'celeba_256', 'omniglot'}: 22 | train_queue = chain(train_queue, valid_queue) 23 | 24 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 25 | model = InceptionV3([block_idx], model_dir=args.fid_dir).to(device) 26 | m, s = compute_statistics_of_generator(train_queue, model, args.batch_size, dims, device, args.max_samples) 27 | file_path = os.path.join(args.fid_dir, args.dataset + '.npz') 28 | print('saving fid stats at %s' % file_path) 29 | save_statistics(file_path, m, s) 30 | 31 | 32 | if __name__ == '__main__': 33 | # python precompute_fid_statistics.py --dataset cifar10 34 | parser = argparse.ArgumentParser('') 35 | parser.add_argument('--dataset', type=str, default='cifar10', 36 | choices=['cifar10', 'celeba_64', 'celeba_256', 'omniglot', 'mnist', 37 | 'imagenet_32', 'ffhq', 'lsun_bedroom_128', 'lsun_church_256'], 38 | help='which dataset to use') 39 | parser.add_argument('--data', type=str, default='/tmp/nvae-diff/data', 40 | help='location of the data corpus') 41 | parser.add_argument('--batch_size', type=int, default=64, 42 | help='batch size per GPU') 43 | parser.add_argument('--max_samples', type=int, default=50000, 44 | help='batch size per GPU') 45 | parser.add_argument('--fid_dir', type=str, default='/tmp/fid-stats', 46 | help='A dir to store fid related files') 47 | 48 | args = parser.parse_args() 49 | args.distributed = False 50 | 51 | main(args) -------------------------------------------------------------------------------- /thirdparty/LICENSE_PyTorch: -------------------------------------------------------------------------------- 1 | From PyTorch: 2 | 3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 8 | Copyright (c) 2011-2013 NYU (Clement Farabet) 9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 12 | 13 | From Caffe2: 14 | 15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 16 | 17 | All contributions by Facebook: 18 | Copyright (c) 2016 Facebook Inc. 19 | 20 | All contributions by Google: 21 | Copyright (c) 2015 Google Inc. 22 | All rights reserved. 23 | 24 | All contributions by Yangqing Jia: 25 | Copyright (c) 2015 Yangqing Jia 26 | All rights reserved. 27 | 28 | All contributions from Caffe: 29 | Copyright(c) 2013, 2014, 2015, the respective contributors 30 | All rights reserved. 
31 | 32 | All other contributions: 33 | Copyright(c) 2015, 2016 the respective contributors 34 | All rights reserved. 35 | 36 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 37 | copyright over their contributions to Caffe2. The project versioning records 38 | all such contribution and copyright details. If a contributor wants to further 39 | mark their specific copyright on a particular contribution, they should 40 | indicate their copyright solely in the commit message of the change when it is 41 | committed. 42 | 43 | All rights reserved. 44 | 45 | Redistribution and use in source and binary forms, with or without 46 | modification, are permitted provided that the following conditions are met: 47 | 48 | 1. Redistributions of source code must retain the above copyright 49 | notice, this list of conditions and the following disclaimer. 50 | 51 | 2. Redistributions in binary form must reproduce the above copyright 52 | notice, this list of conditions and the following disclaimer in the 53 | documentation and/or other materials provided with the distribution. 54 | 55 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 56 | and IDIAP Research Institute nor the names of its contributors may be 57 | used to endorse or promote products derived from this software without 58 | specific prior written permission. 59 | 60 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 61 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 62 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 63 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 64 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 65 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 66 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 67 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 68 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 69 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 70 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /thirdparty/LICENSE_apache: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /thirdparty/LICENSE_torchvision: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) Soumith Chintala 2016, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /thirdparty/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/NVAE/9fc1a288fb831c87d93a4e2663bc30ccf9225b29/thirdparty/__init__.py -------------------------------------------------------------------------------- /thirdparty/adamax.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the PyTorch library. 5 | # 6 | # Source: 7 | # https://github.com/pytorch/pytorch/blob/6e2bb1c05442010aff90b413e21fce99f0393727/torch/optim/adamax.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_PyTorch). The modifications 11 | # to this file are subject to the NVIDIA Source Code License for 12 | # NVAE located at the root directory. 13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | from torch.optim import Optimizer 17 | 18 | @torch.jit.script 19 | def fusion1(exp_avg :torch.Tensor, grad :torch.Tensor, beta1: float): 20 | return exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 21 | 22 | 23 | class Adamax(Optimizer): 24 | """Implements Adamax algorithm (a variant of Adam based on infinity norm). 25 | 26 | It has been proposed in `Adam: A Method for Stochastic Optimization`__. 
27 | 28 | Arguments: 29 | params (iterable): iterable of parameters to optimize or dicts defining 30 | parameter groups 31 | lr (float, optional): learning rate (default: 2e-3) 32 | betas (Tuple[float, float], optional): coefficients used for computing 33 | running averages of gradient and its square 34 | eps (float, optional): term added to the denominator to improve 35 | numerical stability (default: 1e-8) 36 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 37 | 38 | __ https://arxiv.org/abs/1412.6980 39 | """ 40 | 41 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 42 | weight_decay=0): 43 | if not 0.0 <= lr: 44 | raise ValueError("Invalid learning rate: {}".format(lr)) 45 | if not 0.0 <= eps: 46 | raise ValueError("Invalid epsilon value: {}".format(eps)) 47 | if not 0.0 <= betas[0] < 1.0: 48 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 49 | if not 0.0 <= betas[1] < 1.0: 50 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 51 | if not 0.0 <= weight_decay: 52 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 53 | 54 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 55 | super(Adamax, self).__init__(params, defaults) 56 | 57 | def step(self, closure=None): 58 | """Performs a single optimization step. 59 | 60 | Arguments: 61 | closure (callable, optional): A closure that reevaluates the model 62 | and returns the loss. 63 | """ 64 | loss = None 65 | if closure is not None: 66 | loss = closure() 67 | 68 | params, grads, exp_avg, exp_inf = {},{},{},{} 69 | 70 | for group in self.param_groups: 71 | for i, p in enumerate(group['params']): 72 | if p.grad is None: 73 | continue 74 | grad = p.grad.data 75 | if grad.is_sparse: 76 | raise RuntimeError('Adamax does not support sparse gradients') 77 | state = self.state[p] 78 | 79 | # State initialization 80 | if len(state) == 0: 81 | state['step'] = 0 82 | # state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) 83 | # state['exp_inf'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) 84 | state['exp_avg'] = torch.zeros_like(p.data) 85 | state['exp_inf'] = torch.zeros_like(p.data) 86 | 87 | state['step'] += 1 88 | 89 | if p.shape not in params: 90 | params[p.shape] = {'idx': 0, 'data': []} 91 | grads[p.shape] = [] 92 | exp_avg[p.shape] = [] 93 | exp_inf[p.shape] = [] 94 | 95 | params[p.shape]['data'].append(p.data) 96 | grads[p.shape].append(grad) 97 | exp_avg[p.shape].append(state['exp_avg']) 98 | exp_inf[p.shape].append(state['exp_inf']) 99 | 100 | for i in params: 101 | params[i]['data'] = torch.stack(params[i]['data'], dim=0) 102 | grads[i] = torch.stack(grads[i], dim=0) 103 | exp_avg[i] = torch.stack(exp_avg[i], dim=0) 104 | exp_inf[i] = torch.stack(exp_inf[i], dim=0) 105 | 106 | for group in self.param_groups: 107 | beta1, beta2 = group['betas'] 108 | eps = group['eps'] 109 | bias_correction = 1 - beta1 ** self.state[group['params'][0]]['step'] 110 | clr = group['lr'] / bias_correction 111 | 112 | for i in params: 113 | if group['weight_decay'] != 0: 114 | grads[i] = grads[i].add_(params[i]['data'], alpha=group['weight_decay']) 115 | # Update biased first moment estimate. 116 | exp_avg[i].mul_(beta1).add_(grads[i], alpha=1 - beta1) 117 | # Update the exponentially weighted infinity norm. 
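# Adamax replaces Adam's second-moment estimate with an infinity norm,
#   u_t = max(beta2 * u_{t-1}, |g_t| + eps),
# so the update below is p <- p - (lr / (1 - beta1^t)) * m_t / u_t.
# Because parameters of the same shape were stacked into one tensor above,
# the element-wise ops below run once per shape group rather than once per
# parameter, which is what makes this implementation faster than torch.optim.Adamax.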
118 | torch.max(exp_inf[i].mul_(beta2), grads[i].abs_().add_(eps), out=exp_inf[i]) 119 | params[i]['data'].addcdiv_(exp_avg[i], exp_inf[i], value=-clr) 120 | 121 | for group in self.param_groups: 122 | for p in group['params']: 123 | idx = params[p.shape]['idx'] 124 | p.data = params[p.shape]['data'][idx, :] 125 | self.state[p]['exp_avg'] = exp_avg[p.shape][idx, :] 126 | self.state[p]['exp_inf'] = exp_inf[p.shape][idx, :] 127 | params[p.shape]['idx'] += 1 128 | 129 | return loss 130 | -------------------------------------------------------------------------------- /thirdparty/functions.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the PyTorch library. 5 | # 6 | # Source: 7 | # https://github.com/pytorch/pytorch/blob/2a54533c64c409b626b6c209ed78258f67aec194/torch/nn/modules/_functions.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_PyTorch). The modifications 11 | # to this file are subject to the NVIDIA Source Code License for 12 | # NVAE located at the root directory. 13 | # --------------------------------------------------------------- 14 | 15 | import torch 16 | from torch.autograd.function import Function 17 | import torch.distributed as dist 18 | 19 | 20 | class SyncBatchNorm(Function): 21 | 22 | @staticmethod 23 | def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size): 24 | input = input.contiguous() 25 | 26 | count = torch.empty(1, 27 | dtype=running_mean.dtype, 28 | device=input.device).fill_(input.numel() // input.size(1)) 29 | 30 | # calculate mean/invstd for input. 31 | mean, invstd = torch.batch_norm_stats(input, eps) 32 | 33 | num_channels = input.shape[1] 34 | # C, C, 1 -> (2C + 1) 35 | combined = torch.cat([mean, invstd, count], dim=0) 36 | # world_size * (2C + 1) 37 | combined_list = [ 38 | torch.empty_like(combined) for k in range(world_size) 39 | ] 40 | # Use allgather instead of allreduce since I don't trust in-place operations .. 
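# Each rank contributes its per-channel mean, inverse std, and element count;
# after the all_gather below, batch_norm_gather_stats_with_counts combines the
# per-GPU statistics into global ones and updates the running mean/var buffers.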
41 | dist.all_gather(combined_list, combined, async_op=False) 42 | combined = torch.stack(combined_list, dim=0) 43 | # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 44 | mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) 45 | 46 | size = count_all.view(-1).long().sum() 47 | if size == 1: 48 | raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size)) 49 | 50 | # calculate global mean & invstd 51 | mean, invstd = torch.batch_norm_gather_stats_with_counts( 52 | input, 53 | mean_all, 54 | invstd_all, 55 | running_mean, 56 | running_var, 57 | momentum, 58 | eps, 59 | count_all.view(-1) 60 | ) 61 | 62 | self.save_for_backward(input, weight, mean, invstd, bias, count_all) 63 | self.process_group = process_group 64 | 65 | # apply element-wise normalization 66 | out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps) 67 | 68 | # av: apply swish 69 | assert eps == 1e-5, "I assumed below that eps is 1e-5" 70 | out = out * torch.sigmoid(out) 71 | # av: end 72 | 73 | return out 74 | 75 | @staticmethod 76 | def backward(self, grad_output): 77 | grad_output = grad_output.contiguous() 78 | saved_input, weight, mean, invstd, bias, count_tensor = self.saved_tensors 79 | 80 | # av: re-compute batch normalized out 81 | eps = 1e-5 82 | out = torch.batch_norm_elemt(saved_input, weight, bias, mean, invstd, eps) 83 | sigmoid_out = torch.sigmoid(out) 84 | grad_output *= (sigmoid_out * (1 + out * (1 - sigmoid_out))) 85 | # av: end 86 | 87 | grad_input = grad_weight = grad_bias = None 88 | process_group = self.process_group 89 | 90 | # calculate local stats as well as grad_weight / grad_bias 91 | sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce( 92 | grad_output, 93 | saved_input, 94 | mean, 95 | invstd, 96 | weight, 97 | self.needs_input_grad[0], 98 | self.needs_input_grad[1], 99 | self.needs_input_grad[2] 100 | ) 101 | 102 | if self.needs_input_grad[0]: 103 | # synchronizing stats used to calculate input gradient. 104 | # TODO: move div_ into batch_norm_backward_elemt kernel 105 | num_channels = sum_dy.shape[0] 106 | combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) 107 | torch.distributed.all_reduce( 108 | combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) 109 | sum_dy, sum_dy_xmu = torch.split(combined, num_channels) 110 | 111 | divisor = count_tensor.sum() 112 | mean_dy = sum_dy / divisor 113 | mean_dy_xmu = sum_dy_xmu / divisor 114 | # backward pass for gradient calculation 115 | grad_input = torch.batch_norm_backward_elemt( 116 | grad_output, 117 | saved_input, 118 | mean, 119 | invstd, 120 | weight, 121 | mean_dy, 122 | mean_dy_xmu 123 | ) 124 | 125 | # synchronizing of grad_weight / grad_bias is not needed as distributed 126 | # training would handle all reduce. 127 | if weight is None or not self.needs_input_grad[1]: 128 | grad_weight = None 129 | 130 | if weight is None or not self.needs_input_grad[2]: 131 | grad_bias = None 132 | 133 | return grad_input, grad_weight, grad_bias, None, None, None, None, None, None 134 | -------------------------------------------------------------------------------- /thirdparty/inplaced_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the PyTorch library. 
5 | # 6 | # Source: 7 | # https://github.com/pytorch/pytorch/blob/881c1adfcd916b6cd5de91bc343eb86aff88cc80/torch/nn/modules/batchnorm.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_PyTorch). The modifications 11 | # to this file are subject to the NVIDIA Source Code License for 12 | # NVAE located at the root directory. 13 | # --------------------------------------------------------------- 14 | 15 | from __future__ import division 16 | 17 | import torch 18 | from torch.nn.modules.batchnorm import _BatchNorm 19 | import torch.nn.functional as F 20 | 21 | from .functions import SyncBatchNorm as sync_batch_norm 22 | from .swish import Swish as swish 23 | 24 | 25 | class SyncBatchNormSwish(_BatchNorm): 26 | r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs 27 | with additional channel dimension) as described in the paper 28 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 29 | 30 | .. math:: 31 | 32 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 33 | 34 | The mean and standard-deviation are calculated per-dimension over all 35 | mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` 36 | are learnable parameter vectors of size `C` (where `C` is the input size). 37 | By default, the elements of :math:`\gamma` are sampled from 38 | :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. 39 | 40 | Also by default, during training this layer keeps running estimates of its 41 | computed mean and variance, which are then used for normalization during 42 | evaluation. The running estimates are kept with a default :attr:`momentum` 43 | of 0.1. 44 | 45 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 46 | keep running estimates, and batch statistics are instead used during 47 | evaluation time as well. 48 | 49 | .. note:: 50 | This :attr:`momentum` argument is different from one used in optimizer 51 | classes and the conventional notion of momentum. Mathematically, the 52 | update rule for running statistics here is 53 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 54 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 55 | new observed value. 56 | 57 | Because the Batch Normalization is done over the `C` dimension, computing statistics 58 | on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization 59 | or Spatio-temporal Batch Normalization. 60 | 61 | Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use 62 | torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping 63 | Network with DDP. 64 | 65 | Args: 66 | num_features: :math:`C` from an expected input of size 67 | :math:`(N, C, +)` 68 | eps: a value added to the denominator for numerical stability. 69 | Default: 1e-5 70 | momentum: the value used for the running_mean and running_var 71 | computation. Can be set to ``None`` for cumulative moving average 72 | (i.e. simple average). Default: 0.1 73 | affine: a boolean value that when set to ``True``, this module has 74 | learnable affine parameters. 
Default: ``True`` 75 | track_running_stats: a boolean value that when set to ``True``, this 76 | module tracks the running mean and variance, and when set to ``False``, 77 | this module does not track such statistics and always uses batch 78 | statistics in both training and eval modes. Default: ``True`` 79 | process_group: synchronization of stats happen within each process group 80 | individually. Default behavior is synchronization across the whole 81 | world 82 | 83 | Shape: 84 | - Input: :math:`(N, C, +)` 85 | - Output: :math:`(N, C, +)` (same shape as input) 86 | 87 | Examples:: 88 | 89 | >>> # With Learnable Parameters 90 | >>> m = nn.SyncBatchNorm(100) 91 | >>> # creating process group (optional) 92 | >>> # process_ids is a list of int identifying rank ids. 93 | >>> process_group = torch.distributed.new_group(process_ids) 94 | >>> # Without Learnable Parameters 95 | >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) 96 | >>> input = torch.randn(20, 100, 35, 45, 10) 97 | >>> output = m(input) 98 | 99 | >>> # network is nn.BatchNorm layer 100 | >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) 101 | >>> # only single gpu per process is currently supported 102 | >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( 103 | >>> sync_bn_network, 104 | >>> device_ids=[args.local_rank], 105 | >>> output_device=args.local_rank) 106 | 107 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 108 | https://arxiv.org/abs/1502.03167 109 | """ 110 | 111 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 112 | track_running_stats=True, process_group=None): 113 | super(SyncBatchNormSwish, self).__init__(num_features, eps, momentum, affine, track_running_stats) 114 | self.process_group = process_group 115 | # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used 116 | # under supported condition (single GPU per process) 117 | self.ddp_gpu_size = None 118 | 119 | def _check_input_dim(self, input): 120 | if input.dim() < 2: 121 | raise ValueError('expected at least 2D input (got {}D input)' 122 | .format(input.dim())) 123 | 124 | def _specify_ddp_gpu_num(self, gpu_size): 125 | if gpu_size > 1: 126 | raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process') 127 | self.ddp_gpu_size = gpu_size 128 | 129 | def forward(self, input): 130 | # currently only GPU input is supported 131 | if not input.is_cuda: 132 | raise ValueError('SyncBatchNorm expected input tensor to be on GPU') 133 | 134 | self._check_input_dim(input) 135 | 136 | # exponential_average_factor is set to self.momentum 137 | # (when it is available) only so that it gets updated 138 | # in ONNX graph when this node is exported to ONNX. 
139 | if self.momentum is None: 140 | exponential_average_factor = 0.0 141 | else: 142 | exponential_average_factor = self.momentum 143 | 144 | if self.training and self.track_running_stats: 145 | self.num_batches_tracked = self.num_batches_tracked + 1 146 | if self.momentum is None: # use cumulative moving average 147 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 148 | else: # use exponential moving average 149 | exponential_average_factor = self.momentum 150 | 151 | need_sync = self.training or not self.track_running_stats 152 | if need_sync: 153 | process_group = torch.distributed.group.WORLD 154 | if self.process_group: 155 | process_group = self.process_group 156 | world_size = torch.distributed.get_world_size(process_group) 157 | need_sync = world_size > 1 158 | 159 | # fallback to framework BN when synchronization is not necessary 160 | if not need_sync: 161 | out = F.batch_norm( 162 | input, self.running_mean, self.running_var, self.weight, self.bias, 163 | self.training or not self.track_running_stats, 164 | exponential_average_factor, self.eps) 165 | return swish.apply(out) 166 | else: 167 | # av: I only use it in this setting. 168 | if not self.ddp_gpu_size and False: 169 | raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel') 170 | 171 | return sync_batch_norm.apply( 172 | input, self.weight, self.bias, self.running_mean, self.running_var, 173 | self.eps, exponential_average_factor, process_group, world_size) -------------------------------------------------------------------------------- /thirdparty/lsun.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the torchvision library 5 | # which was released under the BSD 3-Clause License. 6 | # 7 | # Source: 8 | # https://github.com/pytorch/vision/blob/ea6b879e90459006e71a164dc76b7e2cc3bff9d9/torchvision/datasets/lsun.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in this directory (LICENSE_torchvision). The modifications 12 | # to this file are subject to the same BSD 3-Clause License. 13 | # --------------------------------------------------------------- 14 | 15 | from torchvision.datasets.vision import VisionDataset 16 | from PIL import Image 17 | import os 18 | import os.path 19 | import io 20 | import string 21 | from collections.abc import Iterable 22 | import pickle 23 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 24 | 25 | 26 | class LSUNClass(VisionDataset): 27 | def __init__(self, root, transform=None, target_transform=None): 28 | import lmdb 29 | super(LSUNClass, self).__init__(root, transform=transform, 30 | target_transform=target_transform) 31 | 32 | self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, 33 | readahead=False, meminit=False) 34 | with self.env.begin(write=False) as txn: 35 | self.length = txn.stat()['entries'] 36 | # cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters) 37 | # av begin 38 | # We only modified the location of cache_file. 
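# The cache file stores the list of LMDB keys, so the full cursor scan over the
# database (slow for large LSUN categories) only runs the first time this
# dataset directory is opened.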
39 | cache_file = os.path.join(self.root, '_cache_') 40 | # av end 41 | if os.path.isfile(cache_file): 42 | self.keys = pickle.load(open(cache_file, "rb")) 43 | else: 44 | with self.env.begin(write=False) as txn: 45 | self.keys = [key for key, _ in txn.cursor()] 46 | pickle.dump(self.keys, open(cache_file, "wb")) 47 | 48 | def __getitem__(self, index): 49 | img, target = None, None 50 | env = self.env 51 | with env.begin(write=False) as txn: 52 | imgbuf = txn.get(self.keys[index]) 53 | 54 | buf = io.BytesIO() 55 | buf.write(imgbuf) 56 | buf.seek(0) 57 | img = Image.open(buf).convert('RGB') 58 | 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | 65 | return img, target 66 | 67 | def __len__(self): 68 | return self.length 69 | 70 | 71 | class LSUN(VisionDataset): 72 | """ 73 | `LSUN `_ dataset. 74 | 75 | Args: 76 | root (string): Root directory for the database files. 77 | classes (string or list): One of {'train', 'val', 'test'} or a list of 78 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 79 | transform (callable, optional): A function/transform that takes in an PIL image 80 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 81 | target_transform (callable, optional): A function/transform that takes in the 82 | target and transforms it. 83 | """ 84 | 85 | def __init__(self, root, classes='train', transform=None, target_transform=None): 86 | super(LSUN, self).__init__(root, transform=transform, 87 | target_transform=target_transform) 88 | self.classes = self._verify_classes(classes) 89 | 90 | # for each class, create an LSUNClassDataset 91 | self.dbs = [] 92 | for c in self.classes: 93 | self.dbs.append(LSUNClass( 94 | root=root + '/' + c + '_lmdb', 95 | transform=transform)) 96 | 97 | self.indices = [] 98 | count = 0 99 | for db in self.dbs: 100 | count += len(db) 101 | self.indices.append(count) 102 | 103 | self.length = count 104 | 105 | def _verify_classes(self, classes): 106 | categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 107 | 'conference_room', 'dining_room', 'kitchen', 108 | 'living_room', 'restaurant', 'tower'] 109 | dset_opts = ['train', 'val', 'test'] 110 | 111 | try: 112 | verify_str_arg(classes, "classes", dset_opts) 113 | if classes == 'test': 114 | classes = [classes] 115 | else: 116 | classes = [c + '_' + classes for c in categories] 117 | except ValueError: 118 | if not isinstance(classes, Iterable): 119 | msg = ("Expected type str or Iterable for argument classes, " 120 | "but got type {}.") 121 | raise ValueError(msg.format(type(classes))) 122 | 123 | classes = list(classes) 124 | msg_fmtstr = ("Expected type str for elements in argument classes, " 125 | "but got type {}.") 126 | for c in classes: 127 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 128 | c_short = c.split('_') 129 | category, dset_opt = '_'.join(c_short[:-1]), c_short[-1] 130 | 131 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 
132 | msg = msg_fmtstr.format(category, "LSUN class", 133 | iterable_to_str(categories)) 134 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 135 | 136 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 137 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 138 | 139 | return classes 140 | 141 | def __getitem__(self, index): 142 | """ 143 | Args: 144 | index (int): Index 145 | 146 | Returns: 147 | tuple: Tuple (image, target) where target is the index of the target category. 148 | """ 149 | target = 0 150 | sub = 0 151 | for ind in self.indices: 152 | if index < ind: 153 | break 154 | target += 1 155 | sub = ind 156 | 157 | db = self.dbs[target] 158 | index = index - sub 159 | 160 | if self.target_transform is not None: 161 | target = self.target_transform(target) 162 | 163 | img, _ = db[index] 164 | return img, target 165 | 166 | def __len__(self): 167 | return self.length 168 | 169 | def extra_repr(self): 170 | return "Classes: {classes}".format(**self.__dict__) 171 | -------------------------------------------------------------------------------- /thirdparty/swish.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the following repo 5 | # (released under the Apache License 2.0). 6 | # 7 | # Source: 8 | # https://github.com/ceshine/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in this directory (LICENSE_apache). 12 | # --------------------------------------------------------------- 13 | 14 | import torch 15 | 16 | 17 | class Swish(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, i): 20 | result = i * torch.sigmoid(i) 21 | ctx.save_for_backward(i) 22 | return result 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | i = ctx.saved_variables[0] 27 | sigmoid_i = torch.sigmoid(i) 28 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import os 13 | 14 | import torch.distributed as dist 15 | from torch.multiprocessing import Process 16 | from torch.cuda.amp import autocast, GradScaler 17 | 18 | from model import AutoEncoder 19 | from thirdparty.adamax import Adamax 20 | import utils 21 | import datasets 22 | 23 | from fid.fid_score import compute_statistics_of_generator, load_statistics, calculate_frechet_distance 24 | from fid.inception import InceptionV3 25 | 26 | 27 | def main(args): 28 | # ensures that weight initializations are all the same 29 | torch.manual_seed(args.seed) 30 | np.random.seed(args.seed) 31 | torch.cuda.manual_seed(args.seed) 32 | torch.cuda.manual_seed_all(args.seed) 33 | 34 | logging = utils.Logger(args.global_rank, args.save) 35 | writer = utils.Writer(args.global_rank, args.save) 36 | 37 | # Get data loaders. 38 | train_queue, valid_queue, num_classes = datasets.get_loaders(args) 39 | args.num_total_iter = len(train_queue) * args.epochs 40 | warmup_iters = len(train_queue) * args.warmup_epochs 41 | swa_start = len(train_queue) * (args.epochs - 1) 42 | 43 | arch_instance = utils.get_arch_cells(args.arch_instance) 44 | 45 | model = AutoEncoder(args, writer, arch_instance) 46 | model = model.cuda() 47 | 48 | logging.info('args = %s', args) 49 | logging.info('param size = %fM ', utils.count_parameters_in_M(model)) 50 | logging.info('groups per scale: %s, total_groups: %d', model.groups_per_scale, sum(model.groups_per_scale)) 51 | 52 | if args.fast_adamax: 53 | # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster. 54 | cnn_optimizer = Adamax(model.parameters(), args.learning_rate, 55 | weight_decay=args.weight_decay, eps=1e-3) 56 | else: 57 | cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate, 58 | weight_decay=args.weight_decay, eps=1e-3) 59 | 60 | cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 61 | cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min) 62 | grad_scalar = GradScaler(2**10) 63 | 64 | num_output = utils.num_output(args.dataset) 65 | bpd_coeff = 1. / np.log(2.) / num_output 66 | 67 | # if load 68 | checkpoint_file = os.path.join(args.save, 'checkpoint.pt') 69 | if args.cont_training: 70 | logging.info('loading the model.') 71 | checkpoint = torch.load(checkpoint_file, map_location='cpu') 72 | init_epoch = checkpoint['epoch'] 73 | model.load_state_dict(checkpoint['state_dict']) 74 | model = model.cuda() 75 | cnn_optimizer.load_state_dict(checkpoint['optimizer']) 76 | grad_scalar.load_state_dict(checkpoint['grad_scalar']) 77 | cnn_scheduler.load_state_dict(checkpoint['scheduler']) 78 | global_step = checkpoint['global_step'] 79 | else: 80 | global_step, init_epoch = 0, 0 81 | 82 | for epoch in range(init_epoch, args.epochs): 83 | # update lrs. 84 | if args.distributed: 85 | train_queue.sampler.set_epoch(global_step + args.seed) 86 | valid_queue.sampler.set_epoch(0) 87 | 88 | if epoch > args.warmup_epochs: 89 | cnn_scheduler.step() 90 | 91 | # Logging. 92 | logging.info('epoch %d', epoch) 93 | 94 | # Training. 
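# train() runs one full epoch: it warms up the learning rate for the first
# warmup_iters steps, anneals the KL coefficient, and optimizes the NELBO plus
# the spectral/batch-norm regularization under mixed precision (GradScaler).
# It returns the epoch-averaged NELBO and the updated global step.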
95 | train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging) 96 | logging.info('train_nelbo %f', train_nelbo) 97 | writer.add_scalar('train/nelbo', train_nelbo, global_step) 98 | 99 | model.eval() 100 | # generate samples less frequently 101 | eval_freq = 1 if args.epochs <= 50 else 20 102 | if epoch % eval_freq == 0 or epoch == (args.epochs - 1): 103 | with torch.no_grad(): 104 | num_samples = 16 105 | n = int(np.floor(np.sqrt(num_samples))) 106 | for t in [0.7, 0.8, 0.9, 1.0]: 107 | logits = model.sample(num_samples, t) 108 | output = model.decoder_output(logits) 109 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.sample(t) 110 | output_tiled = utils.tile_image(output_img, n) 111 | writer.add_image('generated_%0.1f' % t, output_tiled, global_step) 112 | 113 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging) 114 | logging.info('valid_nelbo %f', valid_nelbo) 115 | logging.info('valid neg log p %f', valid_neg_log_p) 116 | logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff) 117 | logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff) 118 | writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch) 119 | writer.add_scalar('val/nelbo', valid_nelbo, epoch) 120 | writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch) 121 | writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch) 122 | 123 | save_freq = int(np.ceil(args.epochs / 100)) 124 | if epoch % save_freq == 0 or epoch == (args.epochs - 1): 125 | if args.global_rank == 0: 126 | logging.info('saving the model.') 127 | torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(), 128 | 'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step, 129 | 'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(), 130 | 'grad_scalar': grad_scalar.state_dict()}, checkpoint_file) 131 | 132 | # Final validation 133 | valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging) 134 | logging.info('final valid nelbo %f', valid_nelbo) 135 | logging.info('final valid neg log p %f', valid_neg_log_p) 136 | writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1) 137 | writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1) 138 | writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1) 139 | writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1) 140 | writer.close() 141 | 142 | 143 | def train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging): 144 | alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales, 145 | groups_per_scale=model.groups_per_scale, fun='square') 146 | nelbo = utils.AvgrageMeter() 147 | model.train() 148 | for step, x in enumerate(train_queue): 149 | x = x[0] if len(x) > 1 else x 150 | x = x.cuda() 151 | 152 | # change bit length 153 | x = utils.pre_process(x, args.num_x_bits) 154 | 155 | # warm-up lr 156 | if global_step < warmup_iters: 157 | lr = args.learning_rate * float(global_step) / warmup_iters 158 | for param_group in cnn_optimizer.param_groups: 159 | param_group['lr'] = lr 160 | 161 | # sync parameters, it may not be necessary 162 | if step % 100 == 0: 163 | utils.average_params(model.parameters(), args.distributed) 164 | 165 | cnn_optimizer.zero_grad() 166 | with autocast(): 167 | logits, log_q, log_p, kl_all, kl_diag = model(x) 168 | 169 | 
output = model.decoder_output(logits) 170 | kl_coeff = utils.kl_coeff(global_step, args.kl_anneal_portion * args.num_total_iter, 171 | args.kl_const_portion * args.num_total_iter, args.kl_const_coeff) 172 | 173 | recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output) 174 | balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i) 175 | 176 | nelbo_batch = recon_loss + balanced_kl 177 | loss = torch.mean(nelbo_batch) 178 | norm_loss = model.spectral_norm_parallel() 179 | bn_loss = model.batchnorm_loss() 180 | # get spectral regularization coefficient (lambda) 181 | if args.weight_decay_norm_anneal: 182 | assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.' 183 | wdn_coeff = (1. - kl_coeff) * np.log(args.weight_decay_norm_init) + kl_coeff * np.log(args.weight_decay_norm) 184 | wdn_coeff = np.exp(wdn_coeff) 185 | else: 186 | wdn_coeff = args.weight_decay_norm 187 | 188 | loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff 189 | 190 | grad_scalar.scale(loss).backward() 191 | utils.average_gradients(model.parameters(), args.distributed) 192 | grad_scalar.step(cnn_optimizer) 193 | grad_scalar.update() 194 | nelbo.update(loss.data, 1) 195 | 196 | if (global_step + 1) % 100 == 0: 197 | if (global_step + 1) % 1000 == 0: # reduced frequency 198 | n = int(np.floor(np.sqrt(x.size(0)))) 199 | x_img = x[:n*n] 200 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.sample() 201 | output_img = output_img[:n*n] 202 | x_tiled = utils.tile_image(x_img, n) 203 | output_tiled = utils.tile_image(output_img, n) 204 | in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2) 205 | writer.add_image('reconstruction', in_out_tiled, global_step) 206 | 207 | # norm 208 | writer.add_scalar('train/norm_loss', norm_loss, global_step) 209 | writer.add_scalar('train/bn_loss', bn_loss, global_step) 210 | writer.add_scalar('train/norm_coeff', wdn_coeff, global_step) 211 | 212 | utils.average_tensor(nelbo.avg, args.distributed) 213 | logging.info('train %d %f', global_step, nelbo.avg) 214 | writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step) 215 | writer.add_scalar('train/lr', cnn_optimizer.state_dict()[ 216 | 'param_groups'][0]['lr'], global_step) 217 | writer.add_scalar('train/nelbo_iter', loss, global_step) 218 | writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)), global_step) 219 | writer.add_scalar('train/recon_iter', torch.mean(utils.reconstruction_loss(output, x, crop=model.crop_output)), global_step) 220 | writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step) 221 | total_active = 0 222 | for i, kl_diag_i in enumerate(kl_diag): 223 | utils.average_tensor(kl_diag_i, args.distributed) 224 | num_active = torch.sum(kl_diag_i > 0.1).detach() 225 | total_active += num_active 226 | 227 | # kl_ceoff 228 | writer.add_scalar('kl/active_%d' % i, num_active, global_step) 229 | writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i], global_step) 230 | writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i], global_step) 231 | writer.add_scalar('kl/total_active', total_active, global_step) 232 | 233 | global_step += 1 234 | 235 | utils.average_tensor(nelbo.avg, args.distributed) 236 | return nelbo.avg, global_step 237 | 238 | 239 | def test(valid_queue, model, num_samples, args, logging): 240 | if args.distributed: 241 | dist.barrier() 242 | nelbo_avg = utils.AvgrageMeter() 243 | neg_log_p_avg = utils.AvgrageMeter() 244 | model.eval() 245 | 
for step, x in enumerate(valid_queue): 246 | x = x[0] if len(x) > 1 else x 247 | x = x.cuda() 248 | 249 | # change bit length 250 | x = utils.pre_process(x, args.num_x_bits) 251 | 252 | with torch.no_grad(): 253 | nelbo, log_iw = [], [] 254 | for k in range(num_samples): 255 | logits, log_q, log_p, kl_all, _ = model(x) 256 | output = model.decoder_output(logits) 257 | recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output) 258 | balanced_kl, _, _ = utils.kl_balancer(kl_all, kl_balance=False) 259 | nelbo_batch = recon_loss + balanced_kl 260 | nelbo.append(nelbo_batch) 261 | log_iw.append(utils.log_iw(output, x, log_q, log_p, crop=model.crop_output)) 262 | 263 | nelbo = torch.mean(torch.stack(nelbo, dim=1)) 264 | log_p = torch.mean(torch.logsumexp(torch.stack(log_iw, dim=1), dim=1) - np.log(num_samples)) 265 | 266 | nelbo_avg.update(nelbo.data, x.size(0)) 267 | neg_log_p_avg.update(- log_p.data, x.size(0)) 268 | 269 | utils.average_tensor(nelbo_avg.avg, args.distributed) 270 | utils.average_tensor(neg_log_p_avg.avg, args.distributed) 271 | if args.distributed: 272 | # block to sync 273 | dist.barrier() 274 | logging.info('val, step: %d, NELBO: %f, neg Log p %f', step, nelbo_avg.avg, neg_log_p_avg.avg) 275 | return neg_log_p_avg.avg, nelbo_avg.avg 276 | 277 | 278 | def create_generator_vae(model, batch_size, num_total_samples): 279 | num_iters = int(np.ceil(num_total_samples / batch_size)) 280 | for i in range(num_iters): 281 | with torch.no_grad(): 282 | logits = model.sample(batch_size, 1.0) 283 | output = model.decoder_output(logits) 284 | output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.mean() 285 | yield output_img.float() 286 | 287 | 288 | def test_vae_fid(model, args, total_fid_samples): 289 | dims = 2048 290 | device = 'cuda' 291 | num_gpus = args.num_process_per_node * args.num_proc_node 292 | num_sample_per_gpu = int(np.ceil(total_fid_samples / num_gpus)) 293 | 294 | g = create_generator_vae(model, args.batch_size, num_sample_per_gpu) 295 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 296 | model = InceptionV3([block_idx], model_dir=args.fid_dir).to(device) 297 | m, s = compute_statistics_of_generator(g, model, args.batch_size, dims, device, max_samples=num_sample_per_gpu) 298 | 299 | # share m and s 300 | m = torch.from_numpy(m).cuda() 301 | s = torch.from_numpy(s).cuda() 302 | # take average across gpus 303 | utils.average_tensor(m, args.distributed) 304 | utils.average_tensor(s, args.distributed) 305 | 306 | # convert m, s 307 | m = m.cpu().numpy() 308 | s = s.cpu().numpy() 309 | 310 | # load precomputed m, s 311 | path = os.path.join(args.fid_dir, args.dataset + '.npz') 312 | m0, s0 = load_statistics(path) 313 | 314 | fid = calculate_frechet_distance(m0, s0, m, s) 315 | return fid 316 | 317 | 318 | def init_processes(rank, size, fn, args): 319 | """ Initialize the distributed environment. 
""" 320 | os.environ['MASTER_ADDR'] = args.master_address 321 | os.environ['MASTER_PORT'] = '6020' 322 | torch.cuda.set_device(args.local_rank) 323 | dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size) 324 | fn(args) 325 | cleanup() 326 | 327 | 328 | def cleanup(): 329 | dist.destroy_process_group() 330 | 331 | 332 | if __name__ == '__main__': 333 | parser = argparse.ArgumentParser('encoder decoder examiner') 334 | # experimental results 335 | parser.add_argument('--root', type=str, default='/tmp/nasvae/expr', 336 | help='location of the results') 337 | parser.add_argument('--save', type=str, default='exp', 338 | help='id used for storing intermediate results') 339 | # data 340 | parser.add_argument('--dataset', type=str, default='mnist', 341 | choices=['cifar10', 'mnist', 'omniglot', 'celeba_64', 'celeba_256', 342 | 'imagenet_32', 'ffhq', 'lsun_bedroom_128', 'stacked_mnist', 343 | 'lsun_church_128', 'lsun_church_64'], 344 | help='which dataset to use') 345 | parser.add_argument('--data', type=str, default='/tmp/nasvae/data', 346 | help='location of the data corpus') 347 | # optimization 348 | parser.add_argument('--batch_size', type=int, default=200, 349 | help='batch size per GPU') 350 | parser.add_argument('--learning_rate', type=float, default=1e-2, 351 | help='init learning rate') 352 | parser.add_argument('--learning_rate_min', type=float, default=1e-4, 353 | help='min learning rate') 354 | parser.add_argument('--weight_decay', type=float, default=3e-4, 355 | help='weight decay') 356 | parser.add_argument('--weight_decay_norm', type=float, default=0., 357 | help='The lambda parameter for spectral regularization.') 358 | parser.add_argument('--weight_decay_norm_init', type=float, default=10., 359 | help='The initial lambda parameter') 360 | parser.add_argument('--weight_decay_norm_anneal', action='store_true', default=False, 361 | help='This flag enables annealing the lambda coefficient from ' 362 | '--weight_decay_norm_init to --weight_decay_norm.') 363 | parser.add_argument('--epochs', type=int, default=200, 364 | help='num of training epochs') 365 | parser.add_argument('--warmup_epochs', type=int, default=5, 366 | help='num of training epochs in which lr is warmed up') 367 | parser.add_argument('--fast_adamax', action='store_true', default=False, 368 | help='This flag enables using our optimized adamax.') 369 | parser.add_argument('--arch_instance', type=str, default='res_mbconv', 370 | help='path to the architecture instance') 371 | # KL annealing 372 | parser.add_argument('--kl_anneal_portion', type=float, default=0.3, 373 | help='The portions epochs that KL is annealed') 374 | parser.add_argument('--kl_const_portion', type=float, default=0.0001, 375 | help='The portions epochs that KL is constant at kl_const_coeff') 376 | parser.add_argument('--kl_const_coeff', type=float, default=0.0001, 377 | help='The constant value used for min KL coeff') 378 | # Flow params 379 | parser.add_argument('--num_nf', type=int, default=0, 380 | help='The number of normalizing flow cells per groups. 
378 | # Flow params
379 | parser.add_argument('--num_nf', type=int, default=0,
380 | help='The number of normalizing flow cells per group. Set this to zero to disable flows.')
381 | parser.add_argument('--num_x_bits', type=int, default=8,
382 | help='The number of bits used to represent pixel values for colored images.')
383 | # latent variables
384 | parser.add_argument('--num_latent_scales', type=int, default=1,
385 | help='the number of latent scales')
386 | parser.add_argument('--num_groups_per_scale', type=int, default=10,
387 | help='number of groups of latent variables per scale')
388 | parser.add_argument('--num_latent_per_group', type=int, default=20,
389 | help='number of channels in latent variables per group')
390 | parser.add_argument('--ada_groups', action='store_true', default=False,
391 | help='Setting this to true will use a different (decreasing) number of groups per scale.')
392 | parser.add_argument('--min_groups_per_scale', type=int, default=1,
393 | help='the minimum number of groups per scale.')
394 | # encoder parameters
395 | parser.add_argument('--num_channels_enc', type=int, default=32,
396 | help='number of channels in encoder')
397 | parser.add_argument('--num_preprocess_blocks', type=int, default=2,
398 | help='number of preprocessing blocks')
399 | parser.add_argument('--num_preprocess_cells', type=int, default=3,
400 | help='number of cells per block')
401 | parser.add_argument('--num_cell_per_cond_enc', type=int, default=1,
402 | help='number of cells for each conditional in the encoder')
403 | # decoder parameters
404 | parser.add_argument('--num_channels_dec', type=int, default=32,
405 | help='number of channels in decoder')
406 | parser.add_argument('--num_postprocess_blocks', type=int, default=2,
407 | help='number of postprocessing blocks')
408 | parser.add_argument('--num_postprocess_cells', type=int, default=3,
409 | help='number of cells per block')
410 | parser.add_argument('--num_cell_per_cond_dec', type=int, default=1,
411 | help='number of cells for each conditional in the decoder')
412 | parser.add_argument('--num_mixture_dec', type=int, default=10,
413 | help='number of mixture components in the decoder. Set to 1 for a Normal decoder.')
414 | # NAS
415 | parser.add_argument('--use_se', action='store_true', default=False,
416 | help='This flag enables squeeze and excitation.')
417 | parser.add_argument('--res_dist', action='store_true', default=False,
418 | help='This flag enables residual distributions.')
419 | parser.add_argument('--cont_training', action='store_true', default=False,
420 | help='This flag enables training from an existing checkpoint.')
421 | # DDP.
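# The flags below configure distributed (DDP-style) training: __main__ spawns one process
# per GPU with torch.multiprocessing.Process, and each process calls init_processes, which
# sets MASTER_ADDR from --master_address (MASTER_PORT is hardcoded to 6020) and joins an
# NCCL process group before running main(). An illustrative single-node, two-GPU launch
# (paths and dataset are placeholders, not recommended settings):
#   python train.py --dataset cifar10 --data /tmp/nasvae/data --root /tmp/nasvae/expr \
#       --save exp --num_process_per_node 2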
422 | parser.add_argument('--num_proc_node', type=int, default=1, 423 | help='The number of nodes in multi node env.') 424 | parser.add_argument('--node_rank', type=int, default=0, 425 | help='The index of node.') 426 | parser.add_argument('--local_rank', type=int, default=0, 427 | help='rank of process in the node') 428 | parser.add_argument('--global_rank', type=int, default=0, 429 | help='rank of process among all the processes') 430 | parser.add_argument('--num_process_per_node', type=int, default=1, 431 | help='number of gpus') 432 | parser.add_argument('--master_address', type=str, default='127.0.0.1', 433 | help='address for master') 434 | parser.add_argument('--seed', type=int, default=1, 435 | help='seed used for initialization') 436 | args = parser.parse_args() 437 | args.save = args.root + '/eval-' + args.save 438 | utils.create_exp_dir(args.save) 439 | 440 | size = args.num_process_per_node 441 | 442 | if size > 1: 443 | args.distributed = True 444 | processes = [] 445 | for rank in range(size): 446 | args.local_rank = rank 447 | global_rank = rank + args.node_rank * args.num_process_per_node 448 | global_size = args.num_proc_node * args.num_process_per_node 449 | args.global_rank = global_rank 450 | print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank)) 451 | p = Process(target=init_processes, args=(global_rank, global_size, main, args)) 452 | p.start() 453 | processes.append(p) 454 | 455 | for p in processes: 456 | p.join() 457 | else: 458 | # for debugging 459 | print('starting in debug mode') 460 | args.distributed = True 461 | init_processes(0, size, main, args) 462 | 463 | 464 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for NVAE. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import logging 9 | import os 10 | import shutil 11 | import time 12 | from datetime import timedelta 13 | import sys 14 | 15 | import torch 16 | import torch.nn as nn 17 | import numpy as np 18 | import torch.distributed as dist 19 | 20 | import torch.nn.functional as F 21 | from tensorboardX import SummaryWriter 22 | 23 | 24 | class AvgrageMeter(object): 25 | 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.avg = 0 31 | self.sum = 0 32 | self.cnt = 0 33 | 34 | def update(self, val, n=1): 35 | self.sum += val * n 36 | self.cnt += n 37 | self.avg = self.sum / self.cnt 38 | 39 | 40 | class ExpMovingAvgrageMeter(object): 41 | 42 | def __init__(self, momentum=0.9): 43 | self.momentum = momentum 44 | self.reset() 45 | 46 | def reset(self): 47 | self.avg = 0 48 | 49 | def update(self, val): 50 | self.avg = (1. 
- self.momentum) * self.avg + self.momentum * val 51 | 52 | 53 | class DummyDDP(nn.Module): 54 | def __init__(self, model): 55 | super(DummyDDP, self).__init__() 56 | self.module = model 57 | 58 | def forward(self, *input, **kwargs): 59 | return self.module(*input, **kwargs) 60 | 61 | 62 | def count_parameters_in_M(model): 63 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 64 | 65 | 66 | def save_checkpoint(state, is_best, save): 67 | filename = os.path.join(save, 'checkpoint.pth.tar') 68 | torch.save(state, filename) 69 | if is_best: 70 | best_filename = os.path.join(save, 'model_best.pth.tar') 71 | shutil.copyfile(filename, best_filename) 72 | 73 | 74 | def save(model, model_path): 75 | torch.save(model.state_dict(), model_path) 76 | 77 | 78 | def load(model, model_path): 79 | model.load_state_dict(torch.load(model_path)) 80 | 81 | 82 | def create_exp_dir(path, scripts_to_save=None): 83 | if not os.path.exists(path): 84 | os.makedirs(path, exist_ok=True) 85 | print('Experiment dir : {}'.format(path)) 86 | 87 | if scripts_to_save is not None: 88 | if not os.path.exists(os.path.join(path, 'scripts')): 89 | os.mkdir(os.path.join(path, 'scripts')) 90 | for script in scripts_to_save: 91 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 92 | shutil.copyfile(script, dst_file) 93 | 94 | 95 | class Logger(object): 96 | def __init__(self, rank, save): 97 | # other libraries may set logging before arriving at this line. 98 | # by reloading logging, we can get rid of previous configs set by other libraries. 99 | from importlib import reload 100 | reload(logging) 101 | self.rank = rank 102 | if self.rank == 0: 103 | log_format = '%(asctime)s %(message)s' 104 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 105 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 106 | fh = logging.FileHandler(os.path.join(save, 'log.txt')) 107 | fh.setFormatter(logging.Formatter(log_format)) 108 | logging.getLogger().addHandler(fh) 109 | self.start_time = time.time() 110 | 111 | def info(self, string, *args): 112 | if self.rank == 0: 113 | elapsed_time = time.time() - self.start_time 114 | elapsed_time = time.strftime( 115 | '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time)) 116 | if isinstance(string, str): 117 | string = elapsed_time + string 118 | else: 119 | logging.info(elapsed_time) 120 | logging.info(string, *args) 121 | 122 | 123 | class Writer(object): 124 | def __init__(self, rank, save): 125 | self.rank = rank 126 | if self.rank == 0: 127 | self.writer = SummaryWriter(log_dir=save, flush_secs=20) 128 | 129 | def add_scalar(self, *args, **kwargs): 130 | if self.rank == 0: 131 | self.writer.add_scalar(*args, **kwargs) 132 | 133 | def add_figure(self, *args, **kwargs): 134 | if self.rank == 0: 135 | self.writer.add_figure(*args, **kwargs) 136 | 137 | def add_image(self, *args, **kwargs): 138 | if self.rank == 0: 139 | self.writer.add_image(*args, **kwargs) 140 | 141 | def add_histogram(self, *args, **kwargs): 142 | if self.rank == 0: 143 | self.writer.add_histogram(*args, **kwargs) 144 | 145 | def add_histogram_if(self, write, *args, **kwargs): 146 | if write and False: # Used for debugging. 
147 | self.add_histogram(*args, **kwargs) 148 | 149 | def close(self, *args, **kwargs): 150 | if self.rank == 0: 151 | self.writer.close() 152 | 153 | 154 | def reduce_tensor(tensor, world_size): 155 | rt = tensor.clone() 156 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 157 | rt /= world_size 158 | return rt 159 | 160 | 161 | def get_stride_for_cell_type(cell_type): 162 | if cell_type.startswith('normal') or cell_type.startswith('combiner'): 163 | stride = 1 164 | elif cell_type.startswith('down'): 165 | stride = 2 166 | elif cell_type.startswith('up'): 167 | stride = -1 168 | else: 169 | raise NotImplementedError(cell_type) 170 | 171 | return stride 172 | 173 | 174 | def get_cout(cin, stride): 175 | if stride == 1: 176 | cout = cin 177 | elif stride == -1: 178 | cout = cin // 2 179 | elif stride == 2: 180 | cout = 2 * cin 181 | 182 | return cout 183 | 184 | 185 | def kl_balancer_coeff(num_scales, groups_per_scale, fun): 186 | if fun == 'equal': 187 | coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() 188 | elif fun == 'linear': 189 | coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() 190 | elif fun == 'sqrt': 191 | coeff = torch.cat([np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() 192 | elif fun == 'square': 193 | coeff = torch.cat([np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() 194 | else: 195 | raise NotImplementedError 196 | # convert min to 1. 197 | coeff /= torch.min(coeff) 198 | return coeff 199 | 200 | 201 | def kl_per_group(kl_all): 202 | kl_vals = torch.mean(kl_all, dim=0) 203 | kl_coeff_i = torch.abs(kl_all) 204 | kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01 205 | 206 | return kl_coeff_i, kl_vals 207 | 208 | 209 | def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None): 210 | if kl_balance and kl_coeff < 1.0: 211 | alpha_i = alpha_i.unsqueeze(0) 212 | 213 | kl_all = torch.stack(kl_all, dim=1) 214 | kl_coeff_i, kl_vals = kl_per_group(kl_all) 215 | total_kl = torch.sum(kl_coeff_i) 216 | 217 | kl_coeff_i = kl_coeff_i / alpha_i * total_kl 218 | kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True) 219 | kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1) 220 | 221 | # for reporting 222 | kl_coeffs = kl_coeff_i.squeeze(0) 223 | else: 224 | kl_all = torch.stack(kl_all, dim=1) 225 | kl_vals = torch.mean(kl_all, dim=0) 226 | kl = torch.sum(kl_all, dim=1) 227 | kl_coeffs = torch.ones(size=(len(kl_vals),)) 228 | 229 | return kl_coeff * kl, kl_coeffs, kl_vals 230 | 231 | 232 | def kl_coeff(step, total_step, constant_step, min_kl_coeff): 233 | return max(min((step - constant_step) / total_step, 1.0), min_kl_coeff) 234 | 235 | 236 | def log_iw(decoder, x, log_q, log_p, crop=False): 237 | recon = reconstruction_loss(decoder, x, crop) 238 | return - recon - log_q + log_p 239 | 240 | 241 | def reconstruction_loss(decoder, x, crop=False): 242 | from distributions import Normal, DiscMixLogistic 243 | 244 | recon = decoder.log_prob(x) 245 | if crop: 246 | recon = recon[:, :, 2:30, 2:30] 247 | 248 | if isinstance(decoder, DiscMixLogistic): 249 | return - torch.sum(recon, dim=[1, 2]) # summation over RGB is done. 
250 | else: 251 | return - torch.sum(recon, dim=[1, 2, 3]) 252 | 253 | 254 | def tile_image(batch_image, n): 255 | assert n * n == batch_image.size(0) 256 | channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3) 257 | batch_image = batch_image.view(n, n, channels, height, width) 258 | batch_image = batch_image.permute(2, 0, 3, 1, 4) # n, height, n, width, c 259 | batch_image = batch_image.contiguous().view(channels, n * height, n * width) 260 | return batch_image 261 | 262 | 263 | def average_gradients(params, is_distributed): 264 | """ Gradient averaging. """ 265 | if is_distributed: 266 | size = float(dist.get_world_size()) 267 | for param in params: 268 | if param.requires_grad: 269 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 270 | param.grad.data /= size 271 | 272 | 273 | def average_params(params, is_distributed): 274 | """ parameter averaging. """ 275 | if is_distributed: 276 | size = float(dist.get_world_size()) 277 | for param in params: 278 | dist.all_reduce(param.data, op=dist.ReduceOp.SUM) 279 | param.data /= size 280 | 281 | 282 | def average_tensor(t, is_distributed): 283 | if is_distributed: 284 | size = float(dist.get_world_size()) 285 | dist.all_reduce(t.data, op=dist.ReduceOp.SUM) 286 | t.data /= size 287 | 288 | 289 | def one_hot(indices, depth, dim): 290 | indices = indices.unsqueeze(dim) 291 | size = list(indices.size()) 292 | size[dim] = depth 293 | y_onehot = torch.zeros(size).cuda() 294 | y_onehot.zero_() 295 | y_onehot.scatter_(dim, indices, 1) 296 | 297 | return y_onehot 298 | 299 | 300 | def num_output(dataset): 301 | if dataset in {'mnist', 'omniglot'}: 302 | return 28 * 28 303 | elif dataset == 'cifar10': 304 | return 3 * 32 * 32 305 | elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): 306 | size = int(dataset.split('_')[-1]) 307 | return 3 * size * size 308 | elif dataset == 'ffhq': 309 | return 3 * 256 * 256 310 | else: 311 | raise NotImplementedError 312 | 313 | 314 | def get_input_size(dataset): 315 | if dataset in {'mnist', 'omniglot'}: 316 | return 32 317 | elif dataset == 'cifar10': 318 | return 32 319 | elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): 320 | size = int(dataset.split('_')[-1]) 321 | return size 322 | elif dataset == 'ffhq': 323 | return 256 324 | else: 325 | raise NotImplementedError 326 | 327 | 328 | def pre_process(x, num_bits): 329 | if num_bits != 8: 330 | x = torch.floor(x * 255 / 2 ** (8 - num_bits)) 331 | x /= (2 ** num_bits - 1) 332 | return x 333 | 334 | 335 | def get_arch_cells(arch_type): 336 | if arch_type == 'res_elu': 337 | arch_cells = dict() 338 | arch_cells['normal_enc'] = ['res_elu', 'res_elu'] 339 | arch_cells['down_enc'] = ['res_elu', 'res_elu'] 340 | arch_cells['normal_dec'] = ['res_elu', 'res_elu'] 341 | arch_cells['up_dec'] = ['res_elu', 'res_elu'] 342 | arch_cells['normal_pre'] = ['res_elu', 'res_elu'] 343 | arch_cells['down_pre'] = ['res_elu', 'res_elu'] 344 | arch_cells['normal_post'] = ['res_elu', 'res_elu'] 345 | arch_cells['up_post'] = ['res_elu', 'res_elu'] 346 | arch_cells['ar_nn'] = [''] 347 | elif arch_type == 'res_bnelu': 348 | arch_cells = dict() 349 | arch_cells['normal_enc'] = ['res_bnelu', 'res_bnelu'] 350 | arch_cells['down_enc'] = ['res_bnelu', 'res_bnelu'] 351 | arch_cells['normal_dec'] = ['res_bnelu', 'res_bnelu'] 352 | arch_cells['up_dec'] = ['res_bnelu', 'res_bnelu'] 353 | arch_cells['normal_pre'] = ['res_bnelu', 'res_bnelu'] 354 | arch_cells['down_pre'] 
= ['res_bnelu', 'res_bnelu'] 355 | arch_cells['normal_post'] = ['res_bnelu', 'res_bnelu'] 356 | arch_cells['up_post'] = ['res_bnelu', 'res_bnelu'] 357 | arch_cells['ar_nn'] = [''] 358 | elif arch_type == 'res_bnswish': 359 | arch_cells = dict() 360 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish'] 361 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish'] 362 | arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish'] 363 | arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish'] 364 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish'] 365 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish'] 366 | arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish'] 367 | arch_cells['up_post'] = ['res_bnswish', 'res_bnswish'] 368 | arch_cells['ar_nn'] = [''] 369 | elif arch_type == 'mbconv_sep': 370 | arch_cells = dict() 371 | arch_cells['normal_enc'] = ['mconv_e6k5g0'] 372 | arch_cells['down_enc'] = ['mconv_e6k5g0'] 373 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 374 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 375 | arch_cells['normal_pre'] = ['mconv_e3k5g0'] 376 | arch_cells['down_pre'] = ['mconv_e3k5g0'] 377 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 378 | arch_cells['up_post'] = ['mconv_e3k5g0'] 379 | arch_cells['ar_nn'] = [''] 380 | elif arch_type == 'mbconv_sep11': 381 | arch_cells = dict() 382 | arch_cells['normal_enc'] = ['mconv_e6k11g0'] 383 | arch_cells['down_enc'] = ['mconv_e6k11g0'] 384 | arch_cells['normal_dec'] = ['mconv_e6k11g0'] 385 | arch_cells['up_dec'] = ['mconv_e6k11g0'] 386 | arch_cells['normal_pre'] = ['mconv_e3k5g0'] 387 | arch_cells['down_pre'] = ['mconv_e3k5g0'] 388 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 389 | arch_cells['up_post'] = ['mconv_e3k5g0'] 390 | arch_cells['ar_nn'] = [''] 391 | elif arch_type == 'res_mbconv': 392 | arch_cells = dict() 393 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish'] 394 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish'] 395 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 396 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 397 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish'] 398 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish'] 399 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 400 | arch_cells['up_post'] = ['mconv_e3k5g0'] 401 | arch_cells['ar_nn'] = [''] 402 | elif arch_type == 'res53_mbconv': 403 | arch_cells = dict() 404 | arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish'] 405 | arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish'] 406 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 407 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 408 | arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish'] 409 | arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish'] 410 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 411 | arch_cells['up_post'] = ['mconv_e3k5g0'] 412 | arch_cells['ar_nn'] = [''] 413 | elif arch_type == 'res35_mbconv': 414 | arch_cells = dict() 415 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish5'] 416 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish5'] 417 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 418 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 419 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish5'] 420 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish5'] 421 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 422 | arch_cells['up_post'] = ['mconv_e3k5g0'] 423 | arch_cells['ar_nn'] = [''] 424 | elif arch_type == 'res55_mbconv': 425 | arch_cells = dict() 426 | arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish5'] 427 | 
arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish5'] 428 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 429 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 430 | arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish5'] 431 | arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish5'] 432 | arch_cells['normal_post'] = ['mconv_e3k5g0'] 433 | arch_cells['up_post'] = ['mconv_e3k5g0'] 434 | arch_cells['ar_nn'] = [''] 435 | elif arch_type == 'res_mbconv9': 436 | arch_cells = dict() 437 | arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish'] 438 | arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish'] 439 | arch_cells['normal_dec'] = ['mconv_e6k9g0'] 440 | arch_cells['up_dec'] = ['mconv_e6k9g0'] 441 | arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish'] 442 | arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish'] 443 | arch_cells['normal_post'] = ['mconv_e3k9g0'] 444 | arch_cells['up_post'] = ['mconv_e3k9g0'] 445 | arch_cells['ar_nn'] = [''] 446 | elif arch_type == 'mbconv_res': 447 | arch_cells = dict() 448 | arch_cells['normal_enc'] = ['mconv_e6k5g0'] 449 | arch_cells['down_enc'] = ['mconv_e6k5g0'] 450 | arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish'] 451 | arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish'] 452 | arch_cells['normal_pre'] = ['mconv_e3k5g0'] 453 | arch_cells['down_pre'] = ['mconv_e3k5g0'] 454 | arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish'] 455 | arch_cells['up_post'] = ['res_bnswish', 'res_bnswish'] 456 | arch_cells['ar_nn'] = [''] 457 | elif arch_type == 'mbconv_den': 458 | arch_cells = dict() 459 | arch_cells['normal_enc'] = ['mconv_e6k5g0'] 460 | arch_cells['down_enc'] = ['mconv_e6k5g0'] 461 | arch_cells['normal_dec'] = ['mconv_e6k5g0'] 462 | arch_cells['up_dec'] = ['mconv_e6k5g0'] 463 | arch_cells['normal_pre'] = ['mconv_e3k5g8'] 464 | arch_cells['down_pre'] = ['mconv_e3k5g8'] 465 | arch_cells['normal_post'] = ['mconv_e3k5g8'] 466 | arch_cells['up_post'] = ['mconv_e3k5g8'] 467 | arch_cells['ar_nn'] = [''] 468 | else: 469 | raise NotImplementedError 470 | 471 | return arch_cells 472 | 473 | 474 | def groups_per_scale(num_scales, num_groups_per_scale, is_adaptive, divider=2, minimum_groups=1): 475 | g = [] 476 | n = num_groups_per_scale 477 | for s in range(num_scales): 478 | assert n >= 1 479 | g.append(n) 480 | if is_adaptive: 481 | n = n // divider 482 | n = max(minimum_groups, n) 483 | return g 484 | --------------------------------------------------------------------------------
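Two of the helpers defined in utils.py above, kl_coeff and groups_per_scale, are pure Python and easy to sanity-check in isolation. The following minimal sketch is not part of the repository: it assumes the repository root is on the Python path and that the packages imported at the top of utils.py (torch, tensorboardX) are installed; all numeric settings are illustrative only.

# sanity_check_utils.py -- a minimal usage sketch (hypothetical file name), assuming the
# repository root is on the Python path.
import utils

# KL warm-up schedule as used in train.py: the coefficient ramps linearly from
# kl_const_coeff to 1.0 over kl_anneal_portion of all training iterations.
num_total_iter = 10000                      # e.g. len(train_queue) * args.epochs
total_step = 0.3 * num_total_iter           # args.kl_anneal_portion * num_total_iter
constant_step = 0.0001 * num_total_iter     # args.kl_const_portion * num_total_iter
for step in [0, 1500, 3000, 9000]:
    coeff = utils.kl_coeff(step, total_step, constant_step, min_kl_coeff=0.0001)
    print(step, round(coeff, 4))            # 0.0001, 0.4997, 0.9997, 1.0

# Adaptive group allocation: with --ada_groups, the number of latent-variable groups is
# halved at each additional scale, but never drops below minimum_groups.
print(utils.groups_per_scale(num_scales=3, num_groups_per_scale=10, is_adaptive=True))   # [10, 5, 2]
print(utils.groups_per_scale(num_scales=3, num_groups_per_scale=10, is_adaptive=False))  # [10, 10, 10]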