├── .gitignore
├── BigGAN_PyTorch
├── BigGAN.py
├── BigGANdeep.py
├── LICENSE
├── README.md
├── TFHub
│   ├── README.md
│   ├── biggan_v1.py
│   └── converter.py
├── animal_hash.py
├── config_files
│   ├── COCO_Stuff
│   │   ├── BigGAN
│   │   │   ├── unconditional_biggan_res128.json
│   │   │   └── unconditional_biggan_res256.json
│   │   └── IC-GAN
│   │   │   ├── icgan_res128_ddp.json
│   │   │   └── icgan_res256_ddp.json
│   ├── ImageNet-LT
│   │   ├── BigGAN
│   │   │   ├── biggan_res128.json
│   │   │   ├── biggan_res256.json
│   │   │   └── biggan_res64.json
│   │   └── cc_IC-GAN
│   │   │   ├── cc_icgan_res128.json
│   │   │   ├── cc_icgan_res256.json
│   │   │   └── cc_icgan_res64.json
│   └── ImageNet
│   │   ├── BigGAN
│   │   ├── biggan_res128.json
│   │   ├── biggan_res256_half_cap.json
│   │   └── biggan_res64.json
│   │   ├── IC-GAN
│   │   ├── icgan_res128.json
│   │   ├── icgan_res256.json
│   │   ├── icgan_res256_halfcap.json
│   │   └── icgan_res64.json
│   │   └── cc_IC-GAN
│   │   ├── cc_icgan_res128.json
│   │   ├── cc_icgan_res256.json
│   │   ├── cc_icgan_res256_halfcap.json
│   │   └── cc_icgan_res64.json
├── diffaugment_utils.py
├── imagenet_lt
│   ├── ImageNet_LT_train.txt
│   └── ImageNet_LT_val.txt
├── imgs
│   ├── D Singular Values.png
│   ├── DeepSamples.png
│   ├── DogBall.png
│   ├── G Singular Values.png
│   ├── IS_FID.png
│   ├── Losses.png
│   ├── header_image.jpg
│   └── interp_sample.jpg
├── layers.py
├── logs
│   ├── BigGAN_ch96_bs256x8.jsonl
│   ├── compare_IS.m
│   ├── metalog.txt
│   ├── process_inception_log.m
│   └── process_training.m
├── losses.py
├── make_hdf5.py
├── run.py
├── scripts
│   ├── launch_BigGAN_bs256x8.sh
│   ├── launch_BigGAN_bs512x4.sh
│   ├── launch_BigGAN_ch64_bs256x8.sh
│   ├── launch_BigGAN_deep.sh
│   ├── launch_SAGAN_bs128x2_ema.sh
│   ├── launch_SNGAN.sh
│   ├── launch_cifar_ema.sh
│   ├── sample_BigGAN_bs256x8.sh
│   ├── sample_cifar_ema.sh
│   └── utils
│   │   ├── duplicate.sh
│   │   └── prepare_data.sh
├── sync_batchnorm
│   ├── __init__.py
│   ├── batchnorm.py
│   ├── batchnorm_reimpl.py
│   ├── comm.py
│   ├── replicate.py
│   └── unittest.py
├── train_fns.py
├── trainer.py
└── utils.py
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cog.yaml
├── data_utils
├── __init__.py
├── calculate_inception_moments.py
├── cocostuff_dataset.py
├── compute_pdrc.py
├── datasets_common.py
├── inception_tf13.py
├── inception_utils.py
├── make_hdf5.py
├── make_hdf5_nns.py
├── prepare_data.sh
├── resnet.py
├── store_coco_jpeg_images.py
├── store_kmeans_indexes.py
└── utils.py
├── download-weights.sh
├── environment.yml
├── figures
├── github_image.png
├── icgan_clip.png
└── icgan_transfer_all_github.png
├── inference
├── .ipynb_checkpoints
│   └── icgan_colab-checkpoint.ipynb
├── generate_images.py
├── icgan_colab.ipynb
├── sample.py
├── test.py
└── utils.py
├── predict.py
└── stylegan2_ada_pytorch
├── .github
└── ISSUE_TEMPLATE
└── bug_report.md
├── .gitignore
├── Dockerfile
├── LICENSE.txt
├── README.md
├── calc_metrics.py
├── config_files
└── COCO_Stuff
│   ├── IC-GAN
│   ├── icgan_stylegan_res128.json
│   └── icgan_stylegan_res256.json
│   └── StyleGAN2
│   ├── unconditional_stylegan_res128.json
│   └── unconditional_stylegan_res256.json
├── dataset_tool.py
├── dnnlib
├── __init__.py
└── util.py
├── docker_run.sh
├── docs
├── dataset-tool-help.txt
├── license.html
├── stylegan2-ada-teaser-1024x252.png
├── stylegan2-ada-training-curves.png
└── train-help.txt
├── generate.py
├── legacy.py
├── metrics
├── __init__.py
├── frechet_inception_distance.py
├── inception_score.py
├── kernel_inception_distance.py
├── metric_main.py
├── metric_utils.py
├── perceptual_path_length.py
└── precision_recall.py
├── parser.py
├── projector.py
├── run.py
├── style_mixing.py
├── torch_utils
├── __init__.py
├── custom_ops.py
├── misc.py
├── ops
│   ├── __init__.py
│   ├── bias_act.cpp
│   ├── bias_act.cu
│   ├── bias_act.h
│   ├── bias_act.py
│   ├── conv2d_gradfix.py
│   ├── conv2d_resample.py
│   ├── fma.py
│   ├── grid_sample_gradfix.py
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.cu
│   ├── upfirdn2d.h
│   └── upfirdn2d.py
├── persistence.py
└── training_stats.py
├── train.py
└── training
├── __init__.py
├── augment.py
├── dataset.py
├── loss.py
├── networks.py
└── training_loop.py
/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | weight_norms/ 3 | *pyc 4 | inception_net/ 5 | *.npy 6 | *.npz 7 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Andy Brock 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/TFHub/README.md: -------------------------------------------------------------------------------- 1 | # BigGAN-PyTorch TFHub converter 2 | This dir contains scripts for taking the [pre-trained generator weights from TFHub](https://tfhub.dev/s?q=biggan) and porting them to BigGAN-PyTorch. 3 | 4 | In addition to the base libraries for BigGAN-PyTorch, to run this code you will need: 5 | 6 | TensorFlow 7 | TFHub 8 | parse 9 | 10 | Note that this code is only presently set up to run the ported models without truncation--you'll need to accumulate standing stats at each truncation level yourself if you wish to employ it.
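As a rough illustration of what accumulating standing statistics involves, here is a minimal sketch. It assumes the `initiate_standing_stats` helper from BigGAN-PyTorch's `utils.py`, a ported generator `G` that takes `(z, G.shared(y))` as in BigGAN-PyTorch, the 128x128 TFHub model's dimensions (`dim_z=120`, 1000 classes), and uses plain scaling of normal samples as a stand-in for truncated-normal sampling; it is not the repo's exact procedure:

```python
import torch
import utils  # BigGAN-PyTorch utils.py (assumed available on the path)

def accumulate_standing_stats(G, truncation=0.5, n_batches=16,
                              batch_size=64, dim_z=120, n_classes=1000):
    utils.initiate_standing_stats(G)   # put BN layers into standing-stat mode
    G.train()                          # BN must see fresh activations
    with torch.no_grad():
        for _ in range(n_batches):
            # Scaled N(0, I) samples as an illustrative stand-in for
            # truncated-normal sampling at this truncation level.
            z = truncation * torch.randn(batch_size, dim_z, device="cuda")
            y = torch.randint(0, n_classes, (batch_size,), device="cuda")
            G(z, G.shared(y))          # forward passes accumulate the stats
    G.eval()                           # then sample with the standing stats
```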
11 | 12 | To port the 128x128 model from tfhub, produce a pretrained weights .pth file, and generate samples using all your GPUs, run 13 | 14 | `python converter.py -r 128 --generate_samples --parallel` -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "unconditional_biggan_class_cond_res128_COCO", 3 | "which_dataset": "coco", 4 | "run_setup": "local_debug", 5 | "deterministic_run": true, 6 | "num_workers": 10, 7 | 8 | "ddp_train": true, 9 | "n_nodes": 1, 10 | "n_gpus_per_node": 4, 11 | "hflips": true, 12 | "DA": true, 13 | "DiffAugment": "translation", 14 | 15 | "test_every": 1, 16 | "save_every": 1, 17 | "num_epochs": 3000, 18 | "es_patience": 50, 19 | "shuffle": true, 20 | 21 | "G_eval_mode": true, 22 | "ema": true, 23 | "use_ema": true, 24 | "num_G_accumulations": 1, 25 | "num_D_accumulations": 1, 26 | "num_D_steps": 2, 27 | 28 | "constant_conditioning": true, 29 | "class_cond": true, 30 | "hier": true, 31 | "resolution": 128, 32 | "G_attn": "64", 33 | "D_attn": "64", 34 | "shared_dim": 128, 35 | "G_shared": true, 36 | "batch_size": 64, 37 | "D_lr": 4e-4, 38 | "G_lr": 1e-4, 39 | "G_ch": 48, 40 | "D_ch": 48, 41 | 42 | "load_weights": "" 43 | 44 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "unconditional_biggan_class_cond_res256_COCO", 3 | "which_dataset": "coco", 4 | "run_setup": "local_debug", 5 | "deterministic_run": true, 6 | "num_workers": 10, 7 | 8 | "ddp_train": true, 9 | "n_nodes": 2, 10 | "n_gpus_per_node": 8, 11 | "hflips": true, 12 | "DA": true, 13 | "DiffAugment": "translation", 14 | 15 | "test_every": 1, 16 | "save_every": 1, 17 | "num_epochs": 3000, 18 | "es_patience": 50, 19 | "shuffle": true, 20 | 21 | "G_eval_mode": true, 22 | "ema": true, 23 | "use_ema": true, 24 | "num_G_accumulations": 1, 25 | "num_D_accumulations": 1, 26 | "num_D_steps": 2, 27 | 28 | "constant_conditioning": true, 29 | "class_cond": true, 30 | "hier": true, 31 | "resolution": 256, 32 | "G_attn": "64", 33 | "D_attn": "64", 34 | "shared_dim": 128, 35 | "G_shared": true, 36 | "batch_size": 16, 37 | "D_lr": 1e-4, 38 | "G_lr": 1e-4, 39 | "G_ch": 48, 40 | "D_ch": 48, 41 | 42 | "load_weights": "" 43 | 44 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res128_ddp.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_res128_COCO", 3 | "which_dataset": "coco", 4 | "run_setup": "local_debug", 5 | "deterministic_run": true, 6 | "num_workers": 10, 7 | 8 | "ddp_train": true, 9 | "n_nodes": 1, 10 | "n_gpus_per_node": 4, 11 | "hflips": true, 12 | "DA": true, 13 | "DiffAugment": "translation", 14 | "feature_augmentation": true, 15 | 16 | "test_every": 5, 17 | "save_every": 1, 18 | "num_epochs": 3000, 19 | "es_patience": 50, 20 | "shuffle": true, 21 | 22 | "G_eval_mode": true, 23 | "ema": true, 24 | "use_ema": true, 25 | "num_G_accumulations": 1, 26 | "num_D_accumulations": 1, 27 | "num_D_steps": 1, 28 | 29 | "class_cond": false, 30 | "instance_cond": true, 31 | "hier": true, 32 | "resolution": 128, 33 | "G_attn": 
"64", 34 | "D_attn": "64", 35 | "shared_dim": 128, 36 | "shared_dim_feat": 512, 37 | "G_shared": true, 38 | "G_shared_feat": true, 39 | 40 | "k_nn": 5, 41 | "feature_extractor": "selfsupervised", 42 | 43 | "batch_size": 64, 44 | "D_lr": 4e-4, 45 | "G_lr": 1e-4, 46 | "G_ch": 48, 47 | "D_ch": 48, 48 | 49 | "load_weights": "" 50 | 51 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res256_ddp.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_res256_COCO", 3 | "which_dataset": "coco", 4 | "run_setup": "local_debug", 5 | "deterministic_run": true, 6 | "num_workers": 10, 7 | 8 | "ddp_train": true, 9 | "n_nodes": 2, 10 | "n_gpus_per_node": 8, 11 | "hflips": true, 12 | "DA": true, 13 | "DiffAugment": "translation", 14 | "feature_augmentation": true, 15 | 16 | "test_every": 5, 17 | "save_every": 1, 18 | "num_epochs": 3000, 19 | "es_patience": 50, 20 | "shuffle": true, 21 | 22 | "G_eval_mode": true, 23 | "ema": true, 24 | "use_ema": true, 25 | "num_G_accumulations": 1, 26 | "num_D_accumulations": 1, 27 | "num_D_steps": 1, 28 | 29 | "class_cond": false, 30 | "instance_cond": true, 31 | "hier": true, 32 | "resolution": 256, 33 | "G_attn": "64", 34 | "D_attn": "64", 35 | "shared_dim": 128, 36 | "shared_dim_feat": 512, 37 | "G_shared": true, 38 | "G_shared_feat": true, 39 | 40 | "k_nn": 5, 41 | "feature_extractor": "selfsupervised", 42 | 43 | "batch_size": 16, 44 | "D_lr": 1e-4, 45 | "G_lr": 1e-4, 46 | "G_ch": 48, 47 | "D_ch": 48, 48 | 49 | "load_weights": "" 50 | 51 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_imagenet_lt_class_cond_res128", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 2, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 10, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 2, 26 | 27 | "class_cond": true, 28 | "hier": true, 29 | "resolution": 128, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "G_shared": true, 34 | "batch_size": 64, 35 | "D_lr": 1e-4, 36 | "G_lr": 1e-4, 37 | "G_ch": 64, 38 | "D_ch": 64, 39 | 40 | "longtail": true, 41 | "longtail_gen": true, 42 | "use_balanced_sampler": false, 43 | "custom_distrib_gen": false, 44 | "longtail_temperature": 1, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_imagenet_lt_class_cond_res256", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 10, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 
19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 2, 26 | 27 | "class_cond": true, 28 | "hier": true, 29 | "resolution": 256, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "G_shared": true, 34 | "batch_size": 16, 35 | "D_lr": 1e-4, 36 | "G_lr": 1e-4, 37 | "G_ch": 64, 38 | "D_ch": 64, 39 | 40 | "longtail": true, 41 | "longtail_gen": true, 42 | "use_balanced_sampler": false, 43 | "custom_distrib_gen": false, 44 | "longtail_temperature": 1, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res64.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_imagenet_lt_class_cond_res64", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 1, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 1, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 1, 26 | 27 | "class_cond": true, 28 | "hier": true, 29 | "resolution": 64, 30 | "G_attn": "32", 31 | "D_attn": "32", 32 | "shared_dim": 128, 33 | "G_shared": true, 34 | "batch_size": 128, 35 | "D_lr": 1e-3, 36 | "G_lr": 1e-5, 37 | "G_ch": 64, 38 | "D_ch": 64, 39 | 40 | "longtail": true, 41 | "longtail_gen": true, 42 | "use_balanced_sampler": false, 43 | "custom_distrib_gen": false, 44 | "longtail_temperature": 1, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res128", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 2, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 10, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 2, 26 | 27 | "class_cond": true, 28 | "instance_cond": true, 29 | "which_knn_balance": "instance_balance", 30 | "hier": true, 31 | "resolution": 128, 32 | "G_attn": "64", 33 | "D_attn": "64", 34 | "shared_dim": 128, 35 | "shared_dim_feat": 512, 36 | "G_shared": true, 37 | "G_shared_feat": true, 38 | 39 | "k_nn": 5, 40 | "feature_extractor": "classification", 41 | 42 | "batch_size": 64, 43 | "D_lr": 1e-4, 44 | "G_lr": 1e-4, 45 | "G_ch": 64, 46 | "D_ch": 64, 47 | 48 | "longtail": true, 49 | "longtail_gen": true, 50 | "use_balanced_sampler": false, 51 | "custom_distrib_gen": false, 52 | "longtail_temperature": 1, 53 | 54 | "load_weights": "" 55 | 56 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res256.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res256", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 10, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 2, 26 | 27 | "class_cond": true, 28 | "instance_cond": true, 29 | "which_knn_balance": "instance_balance", 30 | "hier": true, 31 | "resolution": 256, 32 | "G_attn": "64", 33 | "D_attn": "64", 34 | "shared_dim": 128, 35 | "shared_dim_feat": 512, 36 | "G_shared": true, 37 | "G_shared_feat": true, 38 | 39 | "k_nn": 5, 40 | "feature_extractor": "classification", 41 | 42 | "batch_size": 16, 43 | "D_lr": 1e-4, 44 | "G_lr": 1e-4, 45 | "G_ch": 64, 46 | "D_ch": 64, 47 | 48 | "longtail": true, 49 | "longtail_gen": true, 50 | "use_balanced_sampler": false, 51 | "custom_distrib_gen": false, 52 | "longtail_temperature": 1, 53 | 54 | "load_weights": "" 55 | 56 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res64.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res64", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 1, 10 | "hflips": true, 11 | "DA": true, 12 | "DiffAugment": "translation", 13 | 14 | "test_every": 1, 15 | "save_every": 1, 16 | "num_epochs": 3000, 17 | "es_patience": 50, 18 | "shuffle": true, 19 | 20 | "G_eval_mode": true, 21 | "ema": true, 22 | "use_ema": true, 23 | "num_G_accumulations": 1, 24 | "num_D_accumulations": 1, 25 | "num_D_steps": 1, 26 | 27 | "class_cond": true, 28 | "instance_cond": true, 29 | "which_knn_balance": "instance_balance", 30 | "hier": true, 31 | "resolution": 64, 32 | "G_attn": "32", 33 | "D_attn": "32", 34 | "shared_dim": 128, 35 | "shared_dim_feat": 512, 36 | "G_shared": true, 37 | "G_shared_feat": true, 38 | 39 | "k_nn": 5, 40 | "feature_extractor": "classification", 41 | 42 | "batch_size": 128, 43 | "D_lr": 1e-3, 44 | "G_lr": 1e-5, 45 | "G_ch": 64, 46 | "D_ch": 64, 47 | 48 | "longtail": true, 49 | "longtail_gen": true, 50 | "use_balanced_sampler": false, 51 | "custom_distrib_gen": false, 52 | "longtail_temperature": 1, 53 | 54 | "load_weights": "" 55 | 56 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_imagenet_res128", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | 12 | "test_every": 5, 13 | "save_every": 2, 14 | "num_epochs": 3000, 15 | "es_patience": 50, 16 | "shuffle": true, 17 | 18 | "G_eval_mode": true, 19 | "ema": true, 20 | "use_ema": true, 21 | "num_G_accumulations": 1, 22 | "num_D_accumulations": 1, 23 | "num_D_steps": 1, 24 | 25 | "class_cond": true, 26 | "hier": true, 27 
| "resolution": 128, 28 | "G_attn": "64", 29 | "D_attn": "64", 30 | "shared_dim": 128, 31 | "G_shared": true, 32 | "batch_size": 64, 33 | "D_lr": 4e-4, 34 | "G_lr": 1e-4, 35 | "G_ch": 96, 36 | "D_ch": 96, 37 | 38 | "load_weights": "" 39 | 40 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res256_half_cap.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_class_cond_res256_half_cap_noflips", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": false, 11 | 12 | "test_every": 5, 13 | "save_every": 1, 14 | "num_epochs": 3000, 15 | "es_patience": 50, 16 | "shuffle": true, 17 | 18 | "G_eval_mode": true, 19 | "ema": true, 20 | "use_ema": true, 21 | "num_G_accumulations": 4, 22 | "num_D_accumulations": 4, 23 | "num_D_steps": 1, 24 | 25 | "class_cond": true, 26 | "hier": true, 27 | "resolution": 256, 28 | "G_attn": "64", 29 | "D_attn": "64", 30 | "shared_dim": 128, 31 | "G_shared": true, 32 | "batch_size": 16, 33 | "D_lr": 4e-4, 34 | "G_lr": 1e-4, 35 | "G_ch": 64, 36 | "D_ch": 64, 37 | 38 | "load_weights": "" 39 | 40 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res64.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "biggan_imagenet_res64", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 1, 10 | "hflips": true, 11 | 12 | "test_every": 1, 13 | "save_every": 1, 14 | "num_epochs": 3000, 15 | "es_patience": 50, 16 | "shuffle": true, 17 | 18 | "G_eval_mode": true, 19 | "ema": true, 20 | "use_ema": true, 21 | "num_G_accumulations": 1, 22 | "num_D_accumulations": 1, 23 | "num_D_steps": 1, 24 | 25 | "class_cond": true, 26 | "hier": true, 27 | "resolution": 64, 28 | "G_attn": "32", 29 | "D_attn": "32", 30 | "shared_dim": 128, 31 | "G_shared": true, 32 | "batch_size": 256, 33 | "D_lr": 1e-4, 34 | "G_lr": 1e-4, 35 | "G_ch": 64, 36 | "D_ch": 64, 37 | 38 | "load_weights": "" 39 | 40 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_biggan_imagenet_res128", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 1, 23 | "num_D_accumulations": 1, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": false, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 128, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "selfsupervised", 39 | 40 | "batch_size": 64, 41 | "D_lr": 1e-4, 42 | "G_lr": 4e-5, 43 | "G_ch": 96, 44 | "D_ch": 96, 45 | 46 | 
"load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_biggan_imagenet_res256", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": false, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 4, 23 | "num_D_accumulations": 4, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": false, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 256, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "selfsupervised", 39 | 40 | "batch_size": 16, 41 | "D_lr": 1e-4, 42 | "G_lr": 4e-5, 43 | "G_ch": 96, 44 | "D_ch": 96, 45 | 46 | "load_weights": "" 47 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256_halfcap.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_biggan_imagenet_res256_halfcap", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 4, 23 | "num_D_accumulations": 4, 24 | "num_D_steps": 2, 25 | 26 | "class_cond": false, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 256, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "selfsupervised", 39 | 40 | "batch_size": 16, 41 | "D_lr": 1e-4, 42 | "G_lr": 1e-4, 43 | "G_ch": 64, 44 | "D_ch": 64, 45 | 46 | "load_weights": "" 47 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res64.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "icgan_biggan_imagenet_res64", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 1, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 1, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 1, 23 | "num_D_accumulations": 1, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": false, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 64, 30 | "G_attn": "32", 31 | "D_attn": "32", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | 
"feature_extractor": "selfsupervised", 39 | 40 | "batch_size": 256, 41 | "D_lr": 1e-4, 42 | "G_lr": 1e-4, 43 | "G_ch": 64, 44 | "D_ch": 64, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res128", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 1, 23 | "num_D_accumulations": 1, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": true, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 128, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "classification", 39 | 40 | "batch_size": 64, 41 | "D_lr": 1e-4, 42 | "G_lr": 4e-5, 43 | "G_ch": 96, 44 | "D_ch": 96, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res256", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": false, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 4, 23 | "num_D_accumulations": 4, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": true, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 256, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "classification", 39 | 40 | "batch_size": 16, 41 | "D_lr": 1e-4, 42 | "G_lr": 4e-5, 43 | "G_ch": 96, 44 | "D_ch": 96, 45 | 46 | "load_weights": "" 47 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256_halfcap.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res256_halfcap", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 4, 9 | "n_gpus_per_node": 8, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 5, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 4, 23 | "num_D_accumulations": 4, 24 | "num_D_steps": 2, 25 | 26 | "class_cond": true, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 
256, 30 | "G_attn": "64", 31 | "D_attn": "64", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "classification", 39 | 40 | "batch_size": 16, 41 | "D_lr": 1e-4, 42 | "G_lr": 1e-4, 43 | "G_ch": 64, 44 | "D_ch": 64, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res64.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "cc_icgan_biggan_imagenet_res64", 3 | "run_setup": "local_debug", 4 | "deterministic_run": true, 5 | "num_workers": 10, 6 | 7 | "ddp_train": true, 8 | "n_nodes": 1, 9 | "n_gpus_per_node": 1, 10 | "hflips": true, 11 | "feature_augmentation": true, 12 | 13 | "test_every": 1, 14 | "save_every": 1, 15 | "num_epochs": 3000, 16 | "es_patience": 50, 17 | "shuffle": true, 18 | 19 | "G_eval_mode": true, 20 | "ema": true, 21 | "use_ema": true, 22 | "num_G_accumulations": 1, 23 | "num_D_accumulations": 1, 24 | "num_D_steps": 1, 25 | 26 | "class_cond": true, 27 | "instance_cond": true, 28 | "hier": true, 29 | "resolution": 64, 30 | "G_attn": "32", 31 | "D_attn": "32", 32 | "shared_dim": 128, 33 | "shared_dim_feat": 512, 34 | "G_shared": true, 35 | "G_shared_feat": true, 36 | 37 | "k_nn": 50, 38 | "feature_extractor": "classification", 39 | 40 | "batch_size": 256, 41 | "D_lr": 1e-4, 42 | "G_lr": 1e-4, 43 | "G_ch": 64, 44 | "D_ch": 64, 45 | 46 | "load_weights": "" 47 | 48 | } -------------------------------------------------------------------------------- /BigGAN_PyTorch/diffaugment_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | 28 | import torch 29 | import torch.nn.functional as F 30 | 31 | 32 | def DiffAugment(x, policy="", channels_first=True): 33 | if policy: 34 | if not channels_first: 35 | x = x.permute(0, 3, 1, 2) 36 | for p in policy.split(","): 37 | for f in AUGMENT_FNS[p]: 38 | x = f(x) 39 | if not channels_first: 40 | x = x.permute(0, 2, 3, 1) 41 | x = x.contiguous() 42 | return x 43 | 44 | 45 | def rand_brightness(x): 46 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 47 | return x 48 | 49 | 50 | def rand_saturation(x): 51 | x_mean = x.mean(dim=1, keepdim=True) 52 | x = (x - x_mean) * ( 53 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2 54 | ) + x_mean 55 | return x 56 | 57 | 58 | def rand_contrast(x): 59 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 60 | x = (x - x_mean) * ( 61 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5 62 | ) + x_mean 63 | return x 64 | 65 | 66 | def rand_translation(x, ratio=0.125): 67 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 68 | translation_x = torch.randint( 69 | -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device 70 | ) 71 | translation_y = torch.randint( 72 | -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device 73 | ) 74 | grid_batch, grid_x, grid_y = torch.meshgrid( 75 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 76 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 77 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 78 | ) 79 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 80 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 81 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 82 | x = ( 83 | x_pad.permute(0, 2, 3, 1) 84 | .contiguous()[grid_batch, grid_x, grid_y] 85 | .permute(0, 3, 1, 2) 86 | ) 87 | return x 88 | 89 | 90 | def rand_cutout(x, ratio=0.5): 91 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 92 | offset_x = torch.randint( 93 | 0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device 94 | ) 95 | offset_y = torch.randint( 96 | 0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device 97 | ) 98 | grid_batch, grid_x, grid_y = torch.meshgrid( 99 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 100 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 101 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 102 | ) 103 | grid_x = torch.clamp( 104 | grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1 105 | ) 106 | grid_y = torch.clamp( 107 | grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1 108 | ) 109 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 110 | mask[grid_batch, grid_x, grid_y] = 0 111 | x = x * mask.unsqueeze(1) 112 | return x 113 | 114 | 115 | AUGMENT_FNS = { 116 | "color": [rand_brightness, rand_saturation, rand_contrast], 117 | "translation": [rand_translation], 118 | "cutout": [rand_cutout], 119 | } 120 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/D Singular Values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/D Singular Values.png -------------------------------------------------------------------------------- 
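For context on how `diffaugment_utils.py` above is consumed: both real and generated batches are passed through `DiffAugment` before the discriminator, so the discriminator only ever sees augmented images while gradients still reach the generator through the augmented fakes. A minimal usage sketch (the discriminator `D` and the input batches are placeholders; `"translation"` mirrors the `"DiffAugment": "translation"` entry in the config files above):

```python
import torch
from diffaugment_utils import DiffAugment

policy = "translation"  # any comma-separated subset of "color", "translation", "cutout"

def discriminator_outputs(D, x_real, x_fake):
    # Apply the same differentiable augmentation to real and fake images;
    # DiffAugment expects NCHW tensors when channels_first=True (the default).
    d_real = D(DiffAugment(x_real, policy=policy))
    d_fake = D(DiffAugment(x_fake, policy=policy))
    return d_real, d_fake
```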
/BigGAN_PyTorch/imgs/DeepSamples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/DeepSamples.png -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/DogBall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/DogBall.png -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/G Singular Values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/G Singular Values.png -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/IS_FID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/IS_FID.png -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/Losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/Losses.png -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/header_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/header_image.jpg -------------------------------------------------------------------------------- /BigGAN_PyTorch/imgs/interp_sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/BigGAN_PyTorch/imgs/interp_sample.jpg -------------------------------------------------------------------------------- /BigGAN_PyTorch/logs/compare_IS.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) Facebook, Inc. and its affiliates. 2 | % All rights reserved. 3 | % 4 | % All contributions by Andy Brock: 5 | % Copyright (c) 2019 Andy Brock 6 | % 7 | % MIT License 8 | 9 | clc 10 | clear all 11 | close all 12 | fclose all; 13 | 14 | 15 | 16 | %% Get All logs and sort them 17 | s = {}; 18 | d = dir(); 19 | j = 1; 20 | for i = 1:length(d) 21 | if any(strfind(d(i).name,'.jsonl')) 22 | s = [s; d(i).name]; 23 | end 24 | end 25 | 26 | 27 | j = 1; 28 | for i = 1:length(s) 29 | fname = s{i,1}; 30 | % Check if the Inception metrics log exists, and if so, plot it 31 | [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl'); 32 | s{i,2} = itr; 33 | s{i,3} = IS; 34 | s{i,4} = FID; 35 | s{i,5} = max(IS); 36 | s{i,6} = min(FID); 37 | s{i,7} = t; 38 | end 39 | % Sort by Inception Score? 40 | [IS_sorted, IS_index] = sort(cell2mat(s(:,5))); 41 | % Cutoff inception scores below a certain value? 42 | threshold = 22; 43 | IS_index = IS_index(IS_sorted > threshold); 44 | 45 | % Sort by FID? 
46 | [FID_sorted, FID_index] = sort(cell2mat(s(:,6))); 47 | % Cutoff also based on IS? 48 | % threshold = 0; 49 | FID_index = FID_index(IS_sorted > threshold); 50 | 51 | 52 | 53 | %% Plot things? 54 | cc = hsv(length(IS_index)); 55 | legend1 = {}; 56 | legend2 = {}; 57 | make_axis=true;%false % Turn this on to see the axis out to 1e6 iterations 58 | for i=1:length(IS_index) 59 | legend1 = [legend1; s{IS_index(i), 1}]; 60 | figure(1) 61 | plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:),'linewidth',2) 62 | hold on; 63 | xlabel('itr'); ylabel('IS'); 64 | grid on; 65 | if make_axis 66 | axis([0,1e6,0,80]); % 50% grid on; 67 | end 68 | legend(legend1,'Interpreter','none') 69 | %pause(1) % Turn this on to animate stuff 70 | legend2 = [legend2; s{IS_index(i), 1}]; 71 | figure(2) 72 | plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:),'linewidth',2) 73 | hold on; 74 | xlabel('itr'); ylabel('FID'); 75 | j = j + 1; 76 | grid on; 77 | if make_axis 78 | axis([0,1e6,0,50]);% grid on; 79 | end 80 | legend(legend2, 'Interpreter','none') 81 | 82 | end 83 | 84 | %% Quick script to plot IS versus timesteps 85 | if 0 86 | figure(3); 87 | this_index=4; 88 | subplot(2,1,1); 89 | %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*'); 90 | % xlabel('Iteration');ylabel('\Delta T') 91 | plot(s{this_index, 2}, s{this_index, 7}, 'r*'); 92 | xlabel('Iteration');ylabel('T') 93 | subplot(2,1,2); 94 | plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth',2); 95 | xlabel('Iteration'), ylabel('Inception score') 96 | title(s{this_index,1}) 97 | end 98 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/logs/metalog.txt: -------------------------------------------------------------------------------- 1 | datetime: 2019-03-18 13:27:59.181225 2 | config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': 
'%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)} 3 | state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711, 'config': {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}} 4 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/logs/process_inception_log.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) Facebook, Inc. and its affiliates. 2 | % All rights reserved. 3 | % 4 | % All contributions by Andy Brock: 5 | % Copyright (c) 2019 Andy Brock 6 | % 7 | % MIT License 8 | % 9 | function [itr, IS, FID, t] = process_inception_log(fname, which_log) 10 | f = sprintf('%s_%s',fname, which_log);%'G_loss.log'); 11 | fid = fopen(f,'r'); 12 | itr = []; 13 | IS = []; 14 | FID = []; 15 | t = []; 16 | i = 1; 17 | while ~feof(fid); 18 | s = fgets(fid); 19 | parsed = sscanf(s,'{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}'); 20 | itr(i) = parsed(1); 21 | IS(i) = parsed(2); 22 | FID(i) = parsed(4); 23 | t(i) = parsed(5); 24 | i = i + 1; 25 | end 26 | fclose(fid); 27 | end 28 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/logs/process_training.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) Facebook, Inc. and its affiliates. 2 | % All rights reserved. 
3 | % 4 | % All contributions by Andy Brock: 5 | % Copyright (c) 2019 Andy Brock 6 | % 7 | % MIT License 8 | % 9 | clc 10 | clear all 11 | close all 12 | fclose all; 13 | 14 | 15 | 16 | %% Get all training logs for a given run 17 | target_dir = '.'; 18 | s = {}; 19 | nm = {}; 20 | d = dir(target_dir); 21 | j = 1; 22 | for i = 1:length(d) 23 | if any(strfind(d(i).name,'.log')) 24 | s = [s; sprintf('%s\\%s', target_dir, d(i).name)]; 25 | nm = [nm; d(i).name]; 26 | end 27 | end 28 | %% Loop over training logs and acquire data 29 | D_count = 0; 30 | G_count = 0; 31 | for i = 1:length(s) 32 | fname = s{i,1}; 33 | fid = fopen(s{i,1},'r'); 34 | % Prepare bookkeeping for sv0 35 | if any(strfind(s{i,1},'sv')) 36 | if any(strfind(s{i,1},'G_')) 37 | G_count = G_count +1; 38 | else 39 | D_count = D_count + 1; 40 | end 41 | end 42 | itr = []; 43 | val = []; 44 | j = 1; 45 | while ~feof(fid); 46 | line = fgets(fid); 47 | parsed = sscanf(line, '%d: %e'); 48 | itr(j) = parsed(1); 49 | val(j) = parsed(2); 50 | j = j + 1; 51 | end 52 | s{i,2} = itr; 53 | s{i,3} = val; 54 | fclose(fid); 55 | end 56 | 57 | %% Plot SVs and losses 58 | close all; 59 | Gcc = hsv(G_count); 60 | Dcc = hsv(D_count); 61 | gi = 1; 62 | di = 1; 63 | li = 1; 64 | legendG = {}; 65 | legendD = {}; 66 | legendL = {}; 67 | thresh=2; % wavelet denoising threshold 68 | losses = {}; 69 | for i=1:length(s) 70 | if any(strfind(s{i,1},'D_loss_real.log')) || any(strfind(s{i,1},'D_loss_fake.log')) || any(strfind(s{i,1},'G_loss.log')) 71 | % Select colors 72 | if any(strfind(s{i,1},'D_loss_real.log')) 73 | color1 = [0.7,0.7,1.0]; 74 | color2 = [0, 0, 1]; 75 | dlr = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; 76 | losses = [losses; dlr]; 77 | elseif any(strfind(s{i,1},'D_loss_fake.log')) 78 | color1 = [0.7,1.0,0.7]; 79 | color2 = [0, 1, 0]; 80 | dlf = {s{i,2},s{i,3} wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; 81 | losses = [losses; dlf]; 82 | else % g loss 83 | color1 = [1.0, 0.7,0.7]; 84 | color2 = [1, 0, 0]; 85 | gl = {s{i,2},s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1 color2}; 86 | losses = [losses; gl]; 87 | end 88 | figure(1); hold on; 89 | % Plot the unsmoothed losses; we'll plot the smoothed losses later 90 | plot(s{i,2},s{i,3},'color', color1, 'HandleVisibility','off'); 91 | legendL = [legendL; nm{i}]; 92 | continue 93 | end 94 | if any(strfind(s{i,1},'G_')) 95 | legendG = [legendG; nm{i}]; 96 | figure(2); hold on; 97 | plot(s{i,2},s{i,3},'color',Gcc(gi,:),'linewidth',2); 98 | gi = gi+1; 99 | elseif any(strfind(s{i,1},'D_')) 100 | legendD = [legendD; nm{i}]; 101 | figure(3); hold on; 102 | plot(s{i,2},s{i,3},'color',Dcc(di,:),'linewidth',2); 103 | di = di+1; 104 | else 105 | s{i,1} % Debug print to show the name of the log that was not processed. 
106 | end 107 | end 108 | figure(1); 109 | % Plot the smoothed losses last 110 | for i = 1:3 111 | % plot(losses{i,1}, losses{i,2},'color', losses{i,4}, 'HandleVisibility','off'); 112 | plot(losses{i,1},losses{i,3},'color',losses{i,5}); 113 | end 114 | legend(legendL, 'Interpreter', 'none'); title('Losses'); xlabel('Generator itr'); ylabel('loss'); axis([0, max(s{end,2}), -1, 4]); 115 | 116 | figure(2); legend(legendG,'Interpreter','none'); title('Singular Values in G'); xlabel('Generator itr'); ylabel('SV0'); 117 | figure(3); legend(legendD, 'Interpreter', 'none'); title('Singular Values in D'); xlabel('Generator itr'); ylabel('SV0'); 118 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # MIT License 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | # DCGAN loss 12 | def loss_dcgan_dis(dis_fake, dis_real): 13 | L1 = torch.mean(F.softplus(-dis_real)) 14 | L2 = torch.mean(F.softplus(dis_fake)) 15 | return L1, L2 16 | 17 | 18 | def loss_dcgan_gen(dis_fake): 19 | loss = torch.mean(F.softplus(-dis_fake)) 20 | return loss 21 | 22 | 23 | # Hinge Loss 24 | def loss_hinge_dis(dis_fake, dis_real): 25 | loss_real = torch.mean(F.relu(1.0 - dis_real)) 26 | loss_fake = torch.mean(F.relu(1.0 + dis_fake)) 27 | return loss_real, loss_fake 28 | 29 | 30 | # def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss 31 | # loss = torch.mean(F.relu(1. - dis_real)) 32 | # loss += torch.mean(F.relu(1. + dis_fake)) 33 | # return loss 34 | 35 | 36 | def loss_hinge_gen(dis_fake): 37 | loss = -torch.mean(dis_fake) 38 | return loss 39 | 40 | 41 | # Default to hinge loss 42 | generator_loss = loss_hinge_gen 43 | discriminator_loss = loss_hinge_dis 44 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/make_hdf5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # MIT License 8 | """ Convert dataset to HDF5 9 | This script preprocesses a dataset and saves it (images and labels) to 10 | an HDF5 file for improved I/O. """ 11 | import os 12 | import sys 13 | from argparse import ArgumentParser 14 | from tqdm import tqdm, trange 15 | import h5py as h5 16 | 17 | import numpy as np 18 | import torch 19 | import torchvision.datasets as dset 20 | import torchvision.transforms as transforms 21 | from torchvision.utils import save_image 22 | import torchvision.transforms as transforms 23 | from torch.utils.data import DataLoader 24 | 25 | import utils 26 | 27 | 28 | def prepare_parser(): 29 | usage = "Parser for ImageNet HDF5 scripts." 
30 | parser = ArgumentParser(description=usage) 31 | parser.add_argument( 32 | "--resolution", 33 | type=int, 34 | default=128, 35 | help="Which Dataset resolution to train on, out of 64, 128, 256, 512 (default: %(default)s)", 36 | ) 37 | parser.add_argument( 38 | "--split", 39 | type=str, 40 | default="train", 41 | help="Which Dataset to convert: train, val (default: %(default)s)", 42 | ) 43 | parser.add_argument( 44 | "--data_root", 45 | type=str, 46 | default="data", 47 | help="Default location where data is stored (default: %(default)s)", 48 | ) 49 | parser.add_argument( 50 | "--out_path", 51 | type=str, 52 | default="data", 53 | help="Default location where data in hdf5 format will be stored (default: %(default)s)", 54 | ) 55 | parser.add_argument( 56 | "--longtail", 57 | action="store_true", 58 | default=False, 59 | help="Use long-tail version of the dataset", 60 | ) 61 | parser.add_argument( 62 | "--batch_size", 63 | type=int, 64 | default=256, 65 | help="Default overall batchsize (default: %(default)s)", 66 | ) 67 | parser.add_argument( 68 | "--num_workers", 69 | type=int, 70 | default=16, 71 | help="Number of dataloader workers (default: %(default)s)", 72 | ) 73 | parser.add_argument( 74 | "--chunk_size", 75 | type=int, 76 | default=500, 77 | help="Chunk size for the HDF5 dataset (default: %(default)s)", 78 | ) 79 | parser.add_argument( 80 | "--compression", 81 | action="store_true", 82 | default=False, 83 | help="Use LZF compression? (default: %(default)s)", 84 | ) 85 | return parser 86 | 87 | 88 | def run(config): 89 | # Get image size 90 | 91 | # Update compression entry 92 | config["compression"] = ( 93 | "lzf" if config["compression"] else None 94 | ) # No compression; can also use 'lzf' 95 | 96 | # Get dataset 97 | kwargs = { 98 | "num_workers": config["num_workers"], 99 | "pin_memory": False, 100 | "drop_last": False, 101 | } 102 | dataset = utils.get_dataset_images( 103 | config["resolution"], 104 | data_path=os.path.join(config["data_root"], config["split"]), 105 | longtail=config["longtail"], 106 | ) 107 | train_loader = utils.get_dataloader( 108 | dataset, config["batch_size"], shuffle=False, **kwargs 109 | ) 110 | 111 | # HDF5 supports chunking and compression. You may want to experiment 112 | # with different chunk sizes to see how it runs on your machines. 113 | # Chunk Size/compression Read speed @ 256x256 Read speed @ 128x128 Filesize @ 128x128 Time to write @128x128 114 | # 1 / None 20/s 115 | # 500 / None ramps up to 77/s 102/s 61GB 23min 116 | # 500 / LZF 8/s 56GB 23min 117 | # 1000 / None 78/s 118 | # 5000 / None 81/s 119 | # auto:(125,1,16,32) / None 11/s 61GB 120 | 121 | print( 122 | "Starting to load dataset into an HDF5 file with chunk size %i and compression %s..."
123 | % (config["chunk_size"], config["compression"]) 124 | ) 125 | # Loop over train loader 126 | for i, (x, y) in enumerate(tqdm(train_loader)): 127 | # Rescale x from [-1, 1] (the dataloader's output range) to [0, 255] 128 | x = (255 * ((x + 1) / 2.0)).byte().numpy() 129 | # Convert y to numpy 130 | y = y.numpy() 131 | # If we're on the first batch, prepare the hdf5 132 | if i == 0: 133 | with h5.File( 134 | config["out_path"] 135 | + "/ILSVRC%i%s_xy.hdf5" 136 | % (config["resolution"], "" if not config["longtail"] else "longtail"), 137 | "w", 138 | ) as f: 139 | print("Producing dataset of len %d" % len(train_loader.dataset)) 140 | imgs_dset = f.create_dataset( 141 | "imgs", 142 | x.shape, 143 | dtype="uint8", 144 | maxshape=( 145 | len(train_loader.dataset), 146 | 3, 147 | config["resolution"], 148 | config["resolution"], 149 | ), 150 | chunks=( 151 | config["chunk_size"], 152 | 3, 153 | config["resolution"], 154 | config["resolution"], 155 | ), 156 | compression=config["compression"], 157 | ) 158 | print("Image chunks chosen as " + str(imgs_dset.chunks)) 159 | imgs_dset[...] = x 160 | labels_dset = f.create_dataset( 161 | "labels", 162 | y.shape, 163 | dtype="int64", 164 | maxshape=(len(train_loader.dataset),), 165 | chunks=(config["chunk_size"],), 166 | compression=config["compression"], 167 | ) 168 | print("Label chunks chosen as " + str(labels_dset.chunks)) 169 | labels_dset[...] = y 170 | # Else append to the hdf5 171 | else: 172 | with h5.File( 173 | config["out_path"] 174 | + "/ILSVRC%i%s_xy.hdf5" 175 | % (config["resolution"], "" if not config["longtail"] else "longtail"), 176 | "a", 177 | ) as f: 178 | f["imgs"].resize(f["imgs"].shape[0] + x.shape[0], axis=0) 179 | f["imgs"][-x.shape[0] :] = x 180 | f["labels"].resize(f["labels"].shape[0] + y.shape[0], axis=0) 181 | f["labels"][-y.shape[0] :] = y 182 | 183 | 184 | def main(): 185 | # parse command line and run 186 | parser = prepare_parser() 187 | config = vars(parser.parse_args()) 188 | print(config) 189 | run(config) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import utils 8 | from trainer import run 9 | import json 10 | 11 | LOCAL = False 12 | try: 13 | import submitit 14 | from submitit.helpers import Checkpointable 15 | except ImportError: 16 | print( 17 | "No submitit package found!
Defaulting to executing the script on the local machine" 18 | ) 19 | LOCAL = True 20 | Checkpointable = object  # fallback base class so Trainer can still be defined locally 21 | 22 | class Trainer(Checkpointable): 23 | def __call__(self, config): 24 | if config["run_setup"] == "local_debug" or LOCAL: 25 | run(config, "local_debug") 26 | else: 27 | run(config, "slurm", master_node=submitit.JobEnvironment().hostnames[0]) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = utils.prepare_parser() 32 | config = vars(parser.parse_args()) 33 | 34 | if config["json_config"] != "": 35 | data = json.load(open(config["json_config"])) 36 | for key in data.keys(): 37 | config[key] = data[key] 38 | else: 39 | print("Not using JSON configuration file!") 40 | config["G_batch_size"] = config["batch_size"] 41 | config["batch_size"] = ( 42 | config["batch_size"] * config["num_D_accumulations"] * config["num_D_steps"] 43 | ) 44 | 45 | trainer = Trainer() 46 | if config["run_setup"] == "local_debug" or LOCAL: 47 | trainer(config) 48 | else: 49 | print( 50 | "Using ", 51 | config["n_nodes"], 52 | " nodes and ", 53 | config["n_gpus_per_node"], 54 | " GPUs per node.", 55 | ) 56 | executor = submitit.SlurmExecutor( 57 | folder=config["slurm_logdir"], max_num_timeout=60 58 | ) 59 | executor.update_parameters( 60 | gpus_per_node=config["n_gpus_per_node"], 61 | partition=config["partition"], 62 | constraint="volta32gb", 63 | nodes=config["n_nodes"], 64 | ntasks_per_node=config["n_gpus_per_node"], 65 | cpus_per_task=8, 66 | mem=256000, 67 | time=3200, 68 | job_name=config["experiment_name"], 69 | exclusive=True if config["n_gpus_per_node"] == 8 else False, 70 | ) 71 | 72 | executor.submit(trainer, config) 73 | import time 74 | 75 | time.sleep(1) 76 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_BigGAN_bs256x8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | python train.py \ 12 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ 13 | --num_G_accumulations 8 --num_D_accumulations 8 \ 14 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 15 | --G_attn 64 --D_attn 64 \ 16 | --G_nl inplace_relu --D_nl inplace_relu \ 17 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 18 | --G_ortho 0.0 \ 19 | --G_shared \ 20 | --G_init ortho --D_init ortho \ 21 | --hier --dim_z 120 --shared_dim 128 \ 22 | --G_eval_mode \ 23 | --G_ch 96 --D_ch 96 \ 24 | --ema --use_ema --ema_start 20000 \ 25 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 26 | --use_multiepoch_sampler \ 27 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_BigGAN_bs512x4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved.
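# Note: with gradient accumulation the effective batch size is
# batch_size * num_accumulations, so the bs256x8 launcher above (256 * 8) and
# this bs512x4 launcher (512 * 4) both train with an effective batch of 2048,
# which matches the batch size used in the BigGAN paper; they differ only in
# the per-step memory/throughput trade-off.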
4 | # 5 | # All contributions by Andy Brock: 6 | # Copyright (c) 2019 Andy Brock 7 | # 8 | # MIT License 9 | # 10 | python train.py \ 11 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 512 --load_in_mem \ 12 | --num_G_accumulations 4 --num_D_accumulations 4 \ 13 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 14 | --G_attn 64 --D_attn 64 \ 15 | --G_nl inplace_relu --D_nl inplace_relu \ 16 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 17 | --G_ortho 0.0 \ 18 | --G_shared \ 19 | --G_init ortho --D_init ortho \ 20 | --hier --dim_z 120 --shared_dim 128 \ 21 | --G_eval_mode \ 22 | --G_ch 96 --D_ch 96 \ 23 | --ema --use_ema --ema_start 20000 \ 24 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 25 | --use_multiepoch_sampler \ 26 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_BigGAN_ch64_bs256x8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | python train.py \ 12 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ 13 | --num_G_accumulations 8 --num_D_accumulations 8 \ 14 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 15 | --G_attn 64 --D_attn 64 \ 16 | --G_nl inplace_relu --D_nl inplace_relu \ 17 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 18 | --G_ortho 0.0 \ 19 | --G_shared \ 20 | --G_init ortho --D_init ortho \ 21 | --hier --dim_z 120 --shared_dim 128 \ 22 | --G_eval_mode \ 23 | --G_ch 64 --D_ch 64 \ 24 | --ema --use_ema --ema_start 20000 \ 25 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 26 | --use_multiepoch_sampler 27 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_BigGAN_deep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | python train.py \ 12 | --model BigGANdeep \ 13 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ 14 | --num_G_accumulations 8 --num_D_accumulations 8 \ 15 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 16 | --G_attn 64 --D_attn 64 \ 17 | --G_ch 128 --D_ch 128 \ 18 | --G_depth 2 --D_depth 2 \ 19 | --G_nl inplace_relu --D_nl inplace_relu \ 20 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 21 | --G_ortho 0.0 \ 22 | --G_shared \ 23 | --G_init ortho --D_init ortho \ 24 | --hier --dim_z 128 --shared_dim 128 \ 25 | --ema --use_ema --ema_start 20000 --G_eval_mode \ 26 | --test_every 2000 --save_every 500 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 27 | --use_multiepoch_sampler \ 28 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_SAGAN_bs128x2_ema.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved.
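# Note: this SAGAN baseline intentionally differs from the BigGAN launchers above:
# plain relu (not inplace_relu), xavier (not ortho) init, SN/adam eps of 1e-8,
# no shared embedding or hierarchical z, and EMA starting at 2k instead of 20k
# iterations; the effective batch size is 128 * 2 = 256.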
5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | python train.py \ 12 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 128 \ 13 | --num_G_accumulations 2 --num_D_accumulations 2 \ 14 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 15 | --G_attn 64 --D_attn 64 \ 16 | --G_nl relu --D_nl relu \ 17 | --SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ 18 | --G_ortho 0.0 \ 19 | --G_init xavier --D_init xavier \ 20 | --ema --use_ema --ema_start 2000 --G_eval_mode \ 21 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 22 | --name_suffix SAGAN_ema \ 23 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_SNGAN.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | python train.py \ 12 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 64 \ 13 | --num_G_accumulations 1 --num_D_accumulations 1 \ 14 | --num_D_steps 5 --G_lr 2e-4 --D_lr 2e-4 --D_B2 0.900 --G_B2 0.900 \ 15 | --G_attn 0 --D_attn 0 \ 16 | --G_nl relu --D_nl relu \ 17 | --SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ 18 | --G_ortho 0.0 \ 19 | --D_thin \ 20 | --G_init xavier --D_init xavier \ 21 | --G_eval_mode \ 22 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 23 | --name_suffix SNGAN \ 24 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/launch_cifar_ema.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | CUDA_VISIBLE_DEVICES=0,1 python train.py \ 12 | --shuffle --batch_size 50 --parallel \ 13 | --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ 14 | --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ 15 | --dataset C10 \ 16 | --G_ortho 0.0 \ 17 | --G_attn 0 --D_attn 0 \ 18 | --G_init N02 --D_init N02 \ 19 | --ema --use_ema --ema_start 1000 \ 20 | --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 21 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/sample_BigGAN_bs256x8.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
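# Reading of the truncation flag used below (worth double-checking against
# sample.py, which parses it): --sample_trunc_curves start_step_end, so
# 0.05_0.05_1.0 sweeps the truncation value over 0.05, 0.10, ..., 1.0 and
# reports IS/FID at each point.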
3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # MIT License 8 | # 9 | # use z_var to change the variance of z for all the sampling 10 | # use --mybn --accumulate_stats --num_standing_accumulations 32 to 11 | # use running stats 12 | python sample.py \ 13 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ 14 | --num_G_accumulations 8 --num_D_accumulations 8 \ 15 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 16 | --G_attn 64 --D_attn 64 \ 17 | --G_ch 96 --D_ch 96 \ 18 | --G_nl inplace_relu --D_nl inplace_relu \ 19 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 20 | --G_ortho 0.0 \ 21 | --G_shared \ 22 | --G_init ortho --D_init ortho --skip_init \ 23 | --hier --dim_z 120 --shared_dim 128 \ 24 | --ema --ema_start 20000 \ 25 | --use_multiepoch_sampler \ 26 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 27 | --skip_init --G_batch_size 512 --use_ema --G_eval_mode --sample_trunc_curves 0.05_0.05_1.0 \ 28 | --sample_inception_metrics --sample_npz --sample_random --sample_sheets --sample_interps 29 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/sample_cifar_ema.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | CUDA_VISIBLE_DEVICES=0,1 python sample.py \ 12 | --shuffle --batch_size 50 --G_batch_size 256 --parallel \ 13 | --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ 14 | --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ 15 | --dataset C10 \ 16 | --G_ortho 0.0 \ 17 | --G_attn 0 --D_attn 0 \ 18 | --G_init N02 --D_init N02 \ 19 | --ema --use_ema --ema_start 1000 \ 20 | --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 21 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/utils/duplicate.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
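# Usage sketch: edit `source` and `target` below to your experiment names, then
# run `bash duplicate.sh`. It forks a run by copying the logs plus the
# G/G_ema/D weights, both optimizer states, and the state dict over to the new
# experiment name.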
3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # MIT License 8 | # 9 | #duplicate.sh 10 | source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0 11 | target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A 12 | logs_root=logs 13 | weights_root=weights 14 | echo "copying ${source} to ${target}" 15 | cp -r ${logs_root}/${source} ${logs_root}/${target} 16 | cp ${logs_root}/${source}_log.jsonl ${logs_root}/${target}_log.jsonl 17 | cp ${weights_root}/${source}_G.pth ${weights_root}/${target}_G.pth 18 | cp ${weights_root}/${source}_G_ema.pth ${weights_root}/${target}_G_ema.pth 19 | cp ${weights_root}/${source}_D.pth ${weights_root}/${target}_D.pth 20 | cp ${weights_root}/${source}_G_optim.pth ${weights_root}/${target}_G_optim.pth 21 | cp ${weights_root}/${source}_D_optim.pth ${weights_root}/${target}_D_optim.pth 22 | cp ${weights_root}/${source}_state_dict.pth ${weights_root}/${target}_state_dict.pth 23 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/scripts/utils/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # All contributions by Andy Brock: 7 | # Copyright (c) 2019 Andy Brock 8 | # 9 | # MIT License 10 | # 11 | # ImageNet 12 | python make_hdf5.py --resolution 64 --split 'train' --data_root '../../anyshot_longtail/data/Imagenet_all/' --out_path 'mock_data' 13 | python calculate_inception_moments.py --resolution 64 --split 'train' --data_root 'mock_data' --load_in_mem --out_path 'mock_data' 14 | python make_hdf5.py --resolution 64 --split 'val' --data_root '../../anyshot_longtail/data/Imagenet_all/' --out_path 'mock_data' 15 | python calculate_inception_moments.py --resolution 64 --split 'val' --data_root 'mock_data' --load_in_mem --out_path 'mock_data' 16 | 17 | # ImageNet-LT 18 | python make_hdf5.py --resolution 64 --split 'train' --data_root '../../anyshot_longtail/data/Imagenet_all/' --out_path 'mock_data' --longtail 19 | python calculate_inception_moments.py --resolution 64 --split 'train' --data_root 'mock_data' --longtail --load_in_mem 20 | python make_hdf5.py --resolution 64 --split 'val' --data_root '../../anyshot_longtail/data/Imagenet_all/' --out_path 'mock_data' 21 | python calculate_inception_moments.py --resolution 64 --split 'val' --data_root 'mock_data' --load_in_mem --out_path 'mock_data' --stratified_moments 22 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
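# Minimal usage sketch (assumes at least two visible GPUs; not from the original file):
#   import torch
#   from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
#   net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), SynchronizedBatchNorm2d(8))
#   net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
#   y = net(torch.randn(16, 3, 32, 32).cuda())  # BN stats are synced across replicas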
10 | 11 | from .batchnorm import ( 12 | SynchronizedBatchNorm1d, 13 | SynchronizedBatchNorm2d, 14 | SynchronizedBatchNorm3d, 15 | ) 16 | from .replicate import DataParallelWithCallback, patch_replication_callback 17 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | #! /usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | # File : batchnorm_reimpl.py 10 | # Author : acgtyrant 11 | # Date : 11/01/2018 12 | # 13 | # This file is part of Synchronized-BatchNorm-PyTorch. 14 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 15 | # Distributed under MIT License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.init as init 20 | 21 | __all__ = ["BatchNorm2dReimpl"] 22 | 23 | 24 | class BatchNorm2dReimpl(nn.Module): 25 | """ 26 | A re-implementation of batch normalization, used for testing the numerical 27 | stability. 28 | 29 | Author: acgtyrant 30 | See also: 31 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 32 | """ 33 | 34 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 35 | super().__init__() 36 | 37 | self.num_features = num_features 38 | self.eps = eps 39 | self.momentum = momentum 40 | self.weight = nn.Parameter(torch.empty(num_features)) 41 | self.bias = nn.Parameter(torch.empty(num_features)) 42 | self.register_buffer("running_mean", torch.zeros(num_features)) 43 | self.register_buffer("running_var", torch.ones(num_features)) 44 | self.reset_parameters() 45 | 46 | def reset_running_stats(self): 47 | self.running_mean.zero_() 48 | self.running_var.fill_(1) 49 | 50 | def reset_parameters(self): 51 | self.reset_running_stats() 52 | init.uniform_(self.weight) 53 | init.zeros_(self.bias) 54 | 55 | def forward(self, input_): 56 | batchsize, channels, height, width = input_.size() 57 | numel = batchsize * height * width 58 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 59 | sum_ = input_.sum(1) 60 | sum_of_square = input_.pow(2).sum(1) 61 | mean = sum_ / numel 62 | sumvar = sum_of_square - sum_ * mean 63 | 64 | self.running_mean = ( 65 | 1 - self.momentum 66 | ) * self.running_mean + self.momentum * mean.detach() 67 | unbias_var = sumvar / (numel - 1) 68 | self.running_var = ( 69 | 1 - self.momentum 70 | ) * self.running_var + self.momentum * unbias_var.detach() 71 | 72 | bias_var = sumvar / numel 73 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 74 | output = (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze( 75 | 1 76 | ) * self.weight.unsqueeze(1) + self.bias.unsqueeze(1) 77 | 78 | return ( 79 | output.view(channels, batchsize, height, width) 80 | .permute(1, 0, 2, 3) 81 | .contiguous() 82 | ) 83 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # -*- coding: utf-8 -*- 8 | # File : comm.py 9 | # Author : Jiayuan Mao 10 | # Email : maojiayuan@gmail.com 11 | # Date : 27/01/2018 12 | # 13 | # This file is part of Synchronized-BatchNorm-PyTorch.
14 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 15 | # Distributed under MIT License. 16 | 17 | import queue 18 | import collections 19 | import threading 20 | 21 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"] 22 | 23 | 24 | class FutureResult(object): 25 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 26 | 27 | def __init__(self): 28 | self._result = None 29 | self._lock = threading.Lock() 30 | self._cond = threading.Condition(self._lock) 31 | 32 | def put(self, result): 33 | with self._lock: 34 | assert self._result is None, "Previous result hasn't been fetched." 35 | self._result = result 36 | self._cond.notify() 37 | 38 | def get(self): 39 | with self._lock: 40 | if self._result is None: 41 | self._cond.wait() 42 | 43 | res = self._result 44 | self._result = None 45 | return res 46 | 47 | 48 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"]) 49 | _SlavePipeBase = collections.namedtuple( 50 | "_SlavePipeBase", ["identifier", "queue", "result"] 51 | ) 52 | 53 | 54 | class SlavePipe(_SlavePipeBase): 55 | """Pipe for master-slave communication.""" 56 | 57 | def run_slave(self, msg): 58 | self.queue.put((self.identifier, msg)) 59 | ret = self.result.get() 60 | self.queue.put(True) 61 | return ret 62 | 63 | 64 | class SyncMaster(object): 65 | """An abstract `SyncMaster` object. 66 | 67 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should 68 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 69 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 70 | and passed to a registered callback. 71 | - After receiving the messages, the master device should gather the information and determine the message to be passed 72 | back to each slave device. 73 | """ 74 | 75 | def __init__(self, master_callback): 76 | """ 77 | 78 | Args: 79 | master_callback: a callback to be invoked after having collected messages from slave devices. 80 | """ 81 | self._master_callback = master_callback 82 | self._queue = queue.Queue() 83 | self._registry = collections.OrderedDict() 84 | self._activated = False 85 | 86 | def __getstate__(self): 87 | return {"master_callback": self._master_callback} 88 | 89 | def __setstate__(self, state): 90 | self.__init__(state["master_callback"]) 91 | 92 | def register_slave(self, identifier): 93 | """ 94 | Register a slave device. 95 | 96 | Args: 97 | identifier: an identifier, usually the device id. 98 | 99 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 100 | 101 | """ 102 | if self._activated: 103 | assert self._queue.empty(), "Queue is not clean before next initialization." 104 | self._activated = False 105 | self._registry.clear() 106 | future = FutureResult() 107 | self._registry[identifier] = _MasterRegistry(future) 108 | return SlavePipe(identifier, self._queue, future) 109 | 110 | def run_master(self, master_msg): 111 | """ 112 | Main entry for the master device in each forward pass. 113 | Messages are first collected from each device (including the master device), and then 114 | a callback is invoked to compute the message to be sent back to each device 115 | (including the master device). 116 | 117 | Args: 118 | master_msg: the message that the master wants to send to itself. This will be placed as the first 119 | message when calling `master_callback`.
For detailed usage, see `_SynchronizedBatchNorm` for an example. 120 | 121 | Returns: the message to be sent back to the master device. 122 | 123 | """ 124 | self._activated = True 125 | 126 | intermediates = [(0, master_msg)] 127 | for i in range(self.nr_slaves): 128 | intermediates.append(self._queue.get()) 129 | 130 | results = self._master_callback(intermediates) 131 | assert results[0][0] == 0, "The first result should belong to the master." 132 | 133 | for i, res in results: 134 | if i == 0: 135 | continue 136 | self._registry[i].result.put(res) 137 | 138 | for i in range(self.nr_slaves): 139 | assert self._queue.get() is True 140 | 141 | return results[0][1] 142 | 143 | @property 144 | def nr_slaves(self): 145 | return len(self._registry) 146 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # -*- coding: utf-8 -*- 8 | # File : replicate.py 9 | # Author : Jiayuan Mao 10 | # Email : maojiayuan@gmail.com 11 | # Date : 27/01/2018 12 | # 13 | # This file is part of Synchronized-BatchNorm-PyTorch. 14 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 15 | # Distributed under MIT License. 16 | 17 | import functools 18 | 19 | from torch.nn.parallel.data_parallel import DataParallel 20 | 21 | __all__ = [ 22 | "CallbackContext", 23 | "execute_replication_callbacks", 24 | "DataParallelWithCallback", 25 | "patch_replication_callback", 26 | ] 27 | 28 | 29 | class CallbackContext(object): 30 | pass 31 | 32 | 33 | def execute_replication_callbacks(modules): 34 | """ 35 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 36 | 37 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 38 | 39 | Note that, as all replicated modules are isomorphic, we assign each sub-module a context 40 | (shared among multiple copies of this module on different devices). 41 | Through this context, different copies can share some information. 42 | 43 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 44 | of any slave copies. 45 | """ 46 | master_copy = modules[0] 47 | nr_modules = len(list(master_copy.modules())) 48 | ctxs = [CallbackContext() for _ in range(nr_modules)] 49 | 50 | for i, module in enumerate(modules): 51 | for j, m in enumerate(module.modules()): 52 | if hasattr(m, "__data_parallel_replicate__"): 53 | m.__data_parallel_replicate__(ctxs[j], i) 54 | 55 | 56 | class DataParallelWithCallback(DataParallel): 57 | """ 58 | Data Parallel with a replication callback. 59 | 60 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by 61 | the original `replicate` function. 62 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 63 | 64 | Examples: 65 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 66 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 67 | # sync_bn.__data_parallel_replicate__ will be invoked.
68 | """ 69 | 70 | def replicate(self, module, device_ids): 71 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 72 | execute_replication_callbacks(modules) 73 | return modules 74 | 75 | 76 | def patch_replication_callback(data_parallel): 77 | """ 78 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 79 | Useful when you have customized `DataParallel` implementation. 80 | 81 | Examples: 82 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 83 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 84 | > patch_replication_callback(sync_bn) 85 | # this is equivalent to 86 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 87 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 88 | """ 89 | 90 | assert isinstance(data_parallel, DataParallel) 91 | 92 | old_replicate = data_parallel.replicate 93 | 94 | @functools.wraps(old_replicate) 95 | def new_replicate(module, device_ids): 96 | modules = old_replicate(module, device_ids) 97 | execute_replication_callbacks(modules) 98 | return modules 99 | 100 | data_parallel.replicate = new_replicate 101 | -------------------------------------------------------------------------------- /BigGAN_PyTorch/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # All contributions by Andy Brock: 5 | # Copyright (c) 2019 Andy Brock 6 | # 7 | # -*- coding: utf-8 -*- 8 | # File : unittest.py 9 | # Author : Jiayuan Mao 10 | # Email : maojiayuan@gmail.com 11 | # Date : 27/01/2018 12 | # 13 | # This file is part of Synchronized-BatchNorm-PyTorch. 14 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 15 | # Distributed under MIT License. 16 | 17 | import unittest 18 | import torch 19 | 20 | 21 | class TorchTestCase(unittest.TestCase): 22 | def assertTensorClose(self, x, y): 23 | adiff = float((x - y).abs().max()) 24 | if (y == 0).all(): 25 | rdiff = "NaN" 26 | else: 27 | rdiff = float((adiff / y).abs().max()) 28 | 29 | message = ("Tensor close check failed\n" "adiff={}\n" "rdiff={}\n").format( 30 | adiff, rdiff 31 | ) 32 | self.assertTrue(torch.allclose(x, y), message) 33 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to IC-GAN 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | 30 | ## License 31 | By contributing to IC-GAN, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | python_packages: 8 | - "torch==1.8.0" 9 | - "torchvision==0.9.0" 10 | - "numpy==1.21.1" 11 | - "ipython==7.21.0" 12 | - "pytorch-pretrained-biggan==0.1.1" 13 | - "ftfy==6.0.3" 14 | - "cma==3.1.0" 15 | - "scikit-learn==0.24.2" 16 | - "imageio==2.9.0" 17 | - "Pillow==8.3.1" 18 | - "regex==2021.8.28" 19 | - "nltk==3.6.3" 20 | - "scipy==1.7.1" 21 | - "h5py==3.4.0" 22 | - "matplotlib==3.4.3" 23 | - "faiss-gpu==1.7.1.post2" 24 | 25 | predict: "predict.py:Predictor" 26 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/data_utils/__init__.py -------------------------------------------------------------------------------- /data_utils/compute_pdrc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # prdc 5 | # Copyright (c) 2020-present NAVER Corp. 
6 | # MIT license 7 | 8 | import numpy as np 9 | import sklearn.metrics 10 | 11 | __all__ = ["compute_prdc"] 12 | 13 | 14 | def compute_pairwise_distance(data_x, data_y=None): 15 | """ 16 | Parameters 17 | ---------- 18 | data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) 19 | data_y: numpy.ndarray([M, feature_dim], dtype=np.float32); defaults to data_x when None 20 | Returns 21 | ------- 22 | numpy.ndarray([N, M], dtype=np.float32) of pairwise distances (M = N when data_y is None). 23 | """ 24 | if data_y is None: 25 | data_y = data_x 26 | dists = sklearn.metrics.pairwise_distances( 27 | data_x, data_y, metric="euclidean", n_jobs=8 28 | ) 29 | return dists 30 | 31 | 32 | def get_kth_value(unsorted, k, axis=-1): 33 | """ 34 | Parameters 35 | ---------- 36 | unsorted: numpy.ndarray of any dimensionality. 37 | k: int 38 | axis: int 39 | Returns 40 | ------- 41 | kth values along the designated axis. 42 | """ 43 | indices = np.argpartition(unsorted, k, axis=axis)[..., :k] 44 | k_smallests = np.take_along_axis(unsorted, indices, axis=axis) 45 | kth_values = k_smallests.max(axis=axis) 46 | return kth_values 47 | 48 | 49 | def compute_nearest_neighbour_distances(input_features, nearest_k): 50 | """ 51 | Parameters 52 | ---------- 53 | input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 54 | nearest_k: int 55 | Returns 56 | ------- 57 | Distances to kth nearest neighbours. 58 | """ 59 | distances = compute_pairwise_distance(input_features) 60 | radii = get_kth_value(distances, k=nearest_k + 1, axis=-1) 61 | return radii 62 | 63 | 64 | def compute_prdc(real_features, fake_features, nearest_k): 65 | """ 66 | Computes precision, recall, density, and coverage given two manifolds. 67 | 68 | Parameters 69 | ---------- 70 | real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 71 | fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 72 | nearest_k: int. 73 | Returns 74 | ------- 75 | dict of precision, recall, density, and coverage. 76 | """ 77 | 78 | print( 79 | "Num real: {} Num fake: {}".format( 80 | real_features.shape[0], fake_features.shape[0] 81 | ) 82 | ) 83 | 84 | real_nearest_neighbour_distances = compute_nearest_neighbour_distances( 85 | real_features, nearest_k 86 | ) 87 | fake_nearest_neighbour_distances = compute_nearest_neighbour_distances( 88 | fake_features, nearest_k 89 | ) 90 | distance_real_fake = compute_pairwise_distance(real_features, fake_features) 91 | 92 | precision = ( 93 | (distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1)) 94 | .any(axis=0) 95 | .mean() 96 | ) 97 | 98 | recall = ( 99 | (distance_real_fake < np.expand_dims(fake_nearest_neighbour_distances, axis=0)) 100 | .any(axis=1) 101 | .mean() 102 | ) 103 | 104 | density = (1.0 / float(nearest_k)) * ( 105 | distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1) 106 | ).sum(axis=0).mean() 107 | 108 | coverage = ( 109 | distance_real_fake.min(axis=1) < real_nearest_neighbour_distances 110 | ).mean() 111 | 112 | return dict(precision=precision, recall=recall, density=density, coverage=coverage) 113 | -------------------------------------------------------------------------------- /data_utils/make_hdf5_nns.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
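# Usage sketch for data_utils/compute_pdrc.py above (toy features; the 64-d
# dimensionality and sample counts are arbitrary):
#   import numpy as np
#   from data_utils.compute_pdrc import compute_prdc
#   real = np.random.rand(200, 64).astype(np.float32)
#   fake = np.random.rand(150, 64).astype(np.float32)
#   metrics = compute_prdc(real, fake, nearest_k=5)
#   # -> dict with keys 'precision', 'recall', 'density', 'coverage'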
6 | # 7 | # All contributions by Andy Brock: 8 | # Copyright (c) 2019 Andy Brock 9 | # 10 | # MIT License 11 | """ Obtain nearest neighbors and store them in an HDF5 file. """ 12 | import os 13 | import sys 14 | from argparse import ArgumentParser 15 | from tqdm import tqdm, trange 16 | import h5py as h5 17 | 18 | import numpy as np 19 | import torch 20 | import utils 21 | 22 | 23 | def prepare_parser(): 24 | usage = "Parser for ImageNet HDF5 scripts." 25 | parser = ArgumentParser(description=usage) 26 | parser.add_argument( 27 | "--resolution", 28 | type=int, 29 | default=128, 30 | help="Which dataset resolution to train on, out of 64, 128, 256 (default: %(default)s)", 31 | ) 32 | parser.add_argument( 33 | "--split", 34 | type=str, 35 | default="train", 36 | help="Which dataset split to convert: train, val (default: %(default)s)", 37 | ) 38 | parser.add_argument( 39 | "--data_root", 40 | type=str, 41 | default="data", 42 | help="Default location where data is stored (default: %(default)s)", 43 | ) 44 | parser.add_argument( 45 | "--out_path", 46 | type=str, 47 | default="data", 48 | help="Default location where data in hdf5 format will be stored (default: %(default)s)", 49 | ) 50 | parser.add_argument( 51 | "--num_workers", 52 | type=int, 53 | default=16, 54 | help="Number of dataloader workers (default: %(default)s)", 55 | ) 56 | parser.add_argument( 57 | "--chunk_size", 58 | type=int, 59 | default=500, 60 | help="Chunk size for the HDF5 datasets (default: %(default)s)", 61 | ) 62 | parser.add_argument( 63 | "--compression", 64 | action="store_true", 65 | default=False, 66 | help="Use LZF compression? (default: %(default)s)", 67 | ) 68 | 69 | parser.add_argument( 70 | "--feature_extractor", 71 | type=str, 72 | default="classification", 73 | choices=["classification", "selfsupervised"], 74 | help="Choice of feature extractor", 75 | ) 76 | parser.add_argument( 77 | "--backbone_feature_extractor", 78 | type=str, 79 | default="resnet50", 80 | choices=["resnet50"], 81 | help="Choice of feature extractor backbone", 82 | ) 83 | parser.add_argument( 84 | "--k_nn", 85 | type=int, 86 | default=100, 87 | help="Number of nearest neighbors (default: %(default)s)", 88 | ) 89 | 90 | parser.add_argument( 91 | "--which_dataset", type=str, default="imagenet", help="Dataset choice."
92 | ) 93 | 94 | return parser 95 | 96 | 97 | def run(config): 98 | # Update compression entry 99 | config["compression"] = ( 100 | "lzf" if config["compression"] else None 101 | ) # No compression; can also use 'lzf' 102 | 103 | test_part = False 104 | if config["split"] == "test": 105 | config["split"] = "val" 106 | test_part = True 107 | if config["which_dataset"] in ["imagenet", "imagenet_lt"]: 108 | dataset_name_prefix = "ILSVRC" 109 | elif config["which_dataset"] == "coco": 110 | dataset_name_prefix = "COCO" 111 | else: 112 | dataset_name_prefix = config["which_dataset"] 113 | 114 | train_dataset = utils.get_dataset_hdf5( 115 | **{ 116 | "resolution": config["resolution"], 117 | "data_path": config["data_root"], 118 | "load_in_mem_feats": True, 119 | "compute_nns": True, 120 | "longtail": config["which_dataset"] == "imagenet_lt", 121 | "split": config["split"], 122 | "instance_cond": True, 123 | "feature_extractor": config["feature_extractor"], 124 | "backbone_feature_extractor": config["backbone_feature_extractor"], 125 | "k_nn": config["k_nn"], 126 | "ddp": False, 127 | "which_dataset": config["which_dataset"], 128 | "test_part": test_part, 129 | } 130 | ) 131 | 132 | all_nns = np.array(train_dataset.sample_nns)[:, : config["k_nn"]] 133 | all_nns_radii = train_dataset.kth_values[:, config["k_nn"]] 134 | print("NNs shape ", all_nns.shape, all_nns_radii.shape) 135 | labels_ = torch.Tensor(train_dataset.labels) 136 | acc = np.array( 137 | [(labels_[all_nns[:, i_nn]] == labels_).sum() for i_nn in range(config["k_nn"])] 138 | ).sum() / (len(labels_) * config["k_nn"]) 139 | print("For k ", config["k_nn"], " accuracy:", acc) 140 | 141 | h5file_name_nns = config["out_path"] + "/%s%i%s%s%s_feats_%s_%s_nn_k%i.hdf5" % ( 142 | dataset_name_prefix, 143 | config["resolution"], 144 | "" if config["which_dataset"] != "imagenet_lt" else "longtail", 145 | "_val" if config["split"] == "val" else "", 146 | "_test" if test_part else "", 147 | config["feature_extractor"], 148 | config["backbone_feature_extractor"], 149 | config["k_nn"], 150 | ) 151 | print("Filename is ", h5file_name_nns) 152 | 153 | with h5.File(h5file_name_nns, "w") as f: 154 | nns_dset = f.create_dataset( 155 | "sample_nns", 156 | all_nns.shape, 157 | dtype="int64", 158 | maxshape=all_nns.shape, 159 | chunks=(config["chunk_size"], all_nns.shape[1]), 160 | compression=config["compression"], 161 | ) 162 | nns_dset[...] = all_nns 163 | 164 | nns_radii_dset = f.create_dataset( 165 | "sample_nns_radius", 166 | all_nns_radii.shape, 167 | dtype="float", 168 | maxshape=all_nns_radii.shape, 169 | chunks=(config["chunk_size"],), 170 | compression=config["compression"], 171 | ) 172 | nns_radii_dset[...] = all_nns_radii 173 | 174 | 175 | def main(): 176 | # parse command line and run 177 | parser = prepare_parser() 178 | config = vars(parser.parse_args()) 179 | print(config) 180 | run(config) 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /data_utils/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
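# Usage sketch (fill in out_path / path_imnet / checkpoint paths below first):
#   bash data_utils/prepare_data.sh imagenet 128
#   bash data_utils/prepare_data.sh coco 256
#   bash data_utils/prepare_data.sh cityscapes 128 /path/to/cityscapes  # transfer dataset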
4 | 5 | resolution=$2 # 64,128,256 6 | dataset=$1 #'imagenet', 'imagenet_lt', 'coco', [a transfer dataset, such as 'cityscapes'] 7 | out_path='' 8 | path_imnet='' 9 | path_swav='swav_800ep_pretrain.pth.tar' 10 | path_classifier_lt='resnet50_uniform_e90.pth' 11 | 12 | 13 | ################## 14 | #### ImageNet #### 15 | ################## 16 | if [ $dataset = 'imagenet' ]; then 17 | python data_utils/make_hdf5.py --resolution $resolution --split 'train' --data_root $path_imnet --out_path $out_path --feature_extractor 'classification' --feature_augmentation 18 | python data_utils/make_hdf5.py --resolution $resolution --split 'train' --data_root $path_imnet --out_path $out_path --save_features_only --feature_extractor 'selfsupervised' --feature_augmentation --pretrained_model_path $path_swav 19 | python data_utils/make_hdf5.py --resolution $resolution --split 'val' --data_root $path_imnet --out_path $out_path --save_images_only 20 | ## Calculate inception moments 21 | for split in 'train' 'val'; do 22 | python data_utils/calculate_inception_moments.py --resolution $resolution --split $split --data_root $out_path --load_in_mem --out_path $out_path 23 | done 24 | # Compute NNs 25 | python data_utils/make_hdf5_nns.py --resolution $resolution --split 'train' --feature_extractor 'classification' --data_root $out_path --out_path $out_path --k_nn 50 26 | python data_utils/make_hdf5_nns.py --resolution $resolution --split 'train' --feature_extractor 'selfsupervised' --data_root $out_path --out_path $out_path --k_nn 50 27 | 28 | elif [ $dataset = 'imagenet_lt' ]; then 29 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset 'imagenet_lt' --split 'train' --data_root $path_imnet --out_path $out_path --feature_extractor 'classification' --feature_augmentation --pretrained_model_path $path_classifier_lt 30 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset 'imagenet_lt' --split 'val' --data_root $path_imnet --out_path $out_path --save_images_only 31 | # Calculate inception moments 32 | python data_utils/calculate_inception_moments.py --resolution $resolution --which_dataset 'imagenet_lt' --split 'train' --data_root $out_path --out_path $out_path 33 | python data_utils/calculate_inception_moments.py --resolution $resolution --split 'val' --data_root $out_path --out_path $out_path --stratified_moments 34 | # Compute NNs 35 | python data_utils/make_hdf5_nns.py --resolution $resolution --which_dataset 'imagenet_lt' --split 'train' --feature_extractor 'classification' --data_root $out_path --out_path $out_path --k_nn 5 36 | 37 | elif [ $dataset = 'coco' ]; then 38 | path_split=("train" "val") 39 | split=("train" "test") 40 | for i in "${!path_split[@]}"; do 41 | coco_data_path='COCO/022719/'${path_split[i]}'2017' 42 | coco_instances_path='datasets/coco/annotations/instances_'${path_split[i]}'2017.json' 43 | coco_stuff_path='datasets/coco/annotations/stuff_'${path_split[i]}'2017.json' 44 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset 'coco' --split ${split[i]} --data_root $coco_data_path --instance_json $coco_instances_path --stuff_json $coco_stuff_path --out_path $out_path --feature_extractor 'selfsupervised' --feature_augmentation --pretrained_model_path $path_swav 45 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset 'coco' --split ${split[i]} --data_root $coco_data_path --instance_json $coco_instances_path --stuff_json $coco_stuff_path --out_path $out_path --feature_extractor 'classification' --feature_augmentation 
46 | 47 | # Calculate inception moments 48 | python data_utils/calculate_inception_moments.py --resolution $resolution --which_dataset 'coco' --split ${split[i]} --data_root $out_path --load_in_mem --out_path $out_path 49 | # Compute NNs 50 | python data_utils/make_hdf5_nns.py --resolution $resolution --which_dataset 'coco' --split ${split[i]} --feature_extractor 'selfsupervised' --data_root $out_path --out_path $out_path --k_nn 5 51 | python data_utils/make_hdf5_nns.py --resolution $resolution --which_dataset 'coco' --split ${split[i]} --feature_extractor 'classification' --data_root $out_path --out_path $out_path --k_nn 5 52 | 53 | done 54 | # Transfer datasets 55 | else 56 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset $dataset --split 'train' --data_root $3 --feature_extractor 'classification' --out_path $out_path 57 | # Extract self-supervised (SwAV) features 58 | python data_utils/make_hdf5.py --resolution $resolution --which_dataset $dataset --split 'train' --data_root $3 --feature_extractor 'selfsupervised' --pretrained_model_path $path_swav --save_features_only --out_path $out_path 59 | 60 | # Compute NNs 61 | python data_utils/make_hdf5_nns.py --resolution $resolution --which_dataset $dataset --split 'train' --feature_extractor 'classification' --data_root $out_path --out_path $out_path --k_nn 5 62 | python data_utils/make_hdf5_nns.py --resolution $resolution --which_dataset $dataset --split 'train' --feature_extractor 'selfsupervised' --data_root $out_path --out_path $out_path --k_nn 5 63 | 64 | fi 65 | -------------------------------------------------------------------------------- /data_utils/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # From PyTorch: 8 | # 9 | # Copyright (c) 2016- Facebook, Inc (Adam Paszke) 10 | # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 11 | # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 12 | # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 13 | # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 14 | # Copyright (c) 2011-2013 NYU (Clement Farabet) 15 | # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 16 | # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 17 | # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 18 | # 19 | # From Caffe2: 20 | # 21 | # Copyright (c) 2016-present, Facebook Inc. All rights reserved. 22 | # 23 | # All contributions by Facebook: 24 | # Copyright (c) 2016 Facebook Inc. 25 | # 26 | # All contributions by Google: 27 | # Copyright (c) 2015 Google Inc. 28 | # All rights reserved. 29 | # 30 | # All contributions by Yangqing Jia: 31 | # Copyright (c) 2015 Yangqing Jia 32 | # All rights reserved. 33 | # 34 | # All contributions by Kakao Brain: 35 | # Copyright 2019-2020 Kakao Brain 36 | # 37 | # All contributions from Caffe: 38 | # Copyright(c) 2013, 2014, 2015, the respective contributors 39 | # All rights reserved. 40 | # 41 | # All other contributions: 42 | # Copyright(c) 2015, 2016 the respective contributors 43 | # All rights reserved.
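# Quick usage sketch (not part of the original file; shapes assume a 224x224 input):
#   import torch
#   from data_utils.resnet import resnet50
#   net = resnet50(pretrained=False, classifier_run=False).eval()
#   pooled, spatial = net(torch.randn(1, 3, 224, 224))
#   # pooled: (1, 2048) feature embedding; spatial: (1, 2048, 7, 7) layer4 map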
44 | import torch 45 | from torchvision.models.utils import load_state_dict_from_url 46 | from typing import Type, Any, Callable, Union, List, Optional 47 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet 48 | 49 | 50 | __all__ = [ 51 | "ResNet", 52 | "resnet18", 53 | "resnet34", 54 | "resnet50", 55 | "resnet101", 56 | "resnet152", 57 | "resnext50_32x4d", 58 | "resnext101_32x8d", 59 | "wide_resnet50_2", 60 | "wide_resnet101_2", 61 | ] 62 | 63 | 64 | model_urls = { 65 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 66 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 67 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 68 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 69 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 70 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 71 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 72 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", 73 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", 74 | } 75 | 76 | 77 | class ResNet_mine(ResNet): 78 | def __init__(self, block, layers, classifier_run=True, **kwargs): 79 | super().__init__(block, layers, **kwargs) 80 | self.classifier_run = classifier_run 81 | 82 | def _forward_impl(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor): 83 | # See note [TorchScript super()] 84 | x = self.conv1(x) 85 | x = self.bn1(x) 86 | x = self.relu(x) 87 | x = self.maxpool(x) 88 | 89 | x = self.layer1(x) 90 | x = self.layer2(x) 91 | x = self.layer3(x) 92 | x_ = self.layer4(x) 93 | 94 | x = self.avgpool(x_) 95 | x = torch.flatten(x, 1) 96 | if self.classifier_run: 97 | x = self.fc(x) 98 | 99 | return x, x_ 100 | 101 | def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor): 102 | return self._forward_impl(x) 103 | 104 | 105 | def pnorm(weights, p): 106 | normB = torch.norm(weights, 2, 1) 107 | ws = weights.clone() 108 | for i in range(weights.size(0)): 109 | ws[i] = ws[i] / torch.pow(normB[i], p) 110 | return ws 111 | 112 | 113 | def _resnet( 114 | arch: str, 115 | block: Type[Union[BasicBlock, Bottleneck]], 116 | layers: List[int], 117 | pretrained: bool, 118 | progress: bool, 119 | **kwargs: Any 120 | ) -> ResNet: 121 | model = ResNet_mine(block, layers, **kwargs) 122 | if pretrained: 123 | print("Inside resnet function, using ImageNet pretrained from model url!") 124 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 125 | model.load_state_dict(state_dict) 126 | return model 127 | 128 | 129 | def resnext50_32x4d( 130 | pretrained: bool = False, progress: bool = True, **kwargs: Any 131 | ) -> ResNet: 132 | r"""ResNeXt-50 32x4d model from 133 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 134 | Args: 135 | pretrained (bool): If True, returns a model pre-trained on ImageNet 136 | progress (bool): If True, displays a progress bar of the download to stderr 137 | """ 138 | kwargs["groups"] = 32 139 | kwargs["width_per_group"] = 4 140 | return _resnet( 141 | "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs 142 | ) 143 | 144 | 145 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 146 | r"""ResNet-50 model from 147 | `"Deep Residual Learning for Image Recognition" `_. 
148 | 149 | Args: 150 | pretrained (bool): If True, returns a model pre-trained on ImageNet 151 | progress (bool): If True, displays a progress bar of the download to stderr 152 | """ 153 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 154 | -------------------------------------------------------------------------------- /data_utils/store_coco_jpeg_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Store JPEG images from a hdf5 file, in order to compute FID scores (COCO-Stuff).""" 8 | from argparse import ArgumentParser 9 | import numpy as np 10 | import os 11 | import h5py as h5 12 | from imageio import imwrite as imsave 13 | 14 | import sys 15 | 16 | sys.path.insert(1, os.path.join(sys.path[0], "..")) 17 | from data_utils.utils import filter_by_hd 18 | 19 | 20 | def main(args): 21 | dataset_name_prefix = "COCO" 22 | test_part = True if args["split"] == "val" else False 23 | 24 | # HDF5 file name 25 | hdf5_filename = "%s%i%s%s" % ( 26 | dataset_name_prefix, 27 | args["resolution"], 28 | "_val" if args["split"] == "val" else "", 29 | "_test" if test_part else "", 30 | ) 31 | data_path_xy = os.path.join(args["data_root"], hdf5_filename + "_xy.hdf5") 32 | # Load data 33 | print("Loading images %s..." % (data_path_xy)) 34 | with h5.File(data_path_xy, "r") as f: 35 | imgs = f["imgs"][:] 36 | 37 | # Filter images 38 | if args["filter_hd"] > -1: 39 | filtered_idxs = filter_by_hd(args["filter_hd"]) 40 | else: 41 | filtered_idxs = range(len(imgs)) 42 | 43 | # Save images 44 | counter_i = 0 45 | for i, im in enumerate(imgs): 46 | if i in filtered_idxs: 47 | imsave( 48 | "%s/%06d.%s" % (args["out_path"], counter_i, "jpg"), 49 | im.transpose(1, 2, 0), 50 | ) 51 | counter_i += 1 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = ArgumentParser( 56 | description="Storing ground-truth COCO-Stuff images to compute FID metric." 57 | ) 58 | parser.add_argument( 59 | "--resolution", 60 | type=int, 61 | default=64, 62 | help="Data resolution (default: %(default)s)", 63 | ) 64 | parser.add_argument( 65 | "--split", 66 | type=str, 67 | default="train", 68 | choices=["train", "val"], 69 | help="Data split (default: %(default)s)", 70 | ) 71 | parser.add_argument( 72 | "--filter_hd", 73 | type=int, 74 | default=-1, 75 | help="Hamming distance to filter val test in COCO_Stuff (by default no filtering) (default: %(default)s)", 76 | ) 77 | parser.add_argument( 78 | "--data_root", 79 | type=str, 80 | default="data", 81 | help="Default location where the hdf5 file is stored (default: %(default)s)", 82 | ) 83 | parser.add_argument( 84 | "--out_path", 85 | type=str, 86 | default="data", 87 | help="Default location where the resulting images will be stored (default: %(default)s)", 88 | ) 89 | 90 | args = vars(parser.parse_args()) 91 | main(args) 92 | -------------------------------------------------------------------------------- /data_utils/store_kmeans_indexes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
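# Overview: load the precomputed backbone features from the dataset's hdf5
# file, L2-normalize them, fit faiss k-means with `kmeans_subsampled`
# centroids, and store the index of the real training instance closest to each
# centroid. Those indexes define the reduced set of conditioning instances
# used when training with subsampled instances.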
6 | 7 | """Store dataset indexes of datapoints selected by k-means algorithm.""" 8 | from argparse import ArgumentParser 9 | import numpy as np 10 | import os 11 | import h5py as h5 12 | import faiss 13 | 14 | 15 | def main(args): 16 | if args["which_dataset"] == "imagenet": 17 | dataset_name_prefix = "ILSVRC" 18 | im_prefix = "IN" 19 | elif args["which_dataset"] == "coco": 20 | dataset_name_prefix = "COCO" 21 | im_prefix = "COCO" 22 | else: 23 | dataset_name_prefix = args["which_dataset"] 24 | im_prefix = args["which_dataset"] 25 | # HDF5 filename 26 | filename = os.path.join( 27 | args["data_root"], 28 | "%s%s_feats_%s_%s.hdf5" 29 | % ( 30 | dataset_name_prefix, 31 | args["resolution"], 32 | args["feature_extractor"], 33 | args["backbone_feature_extractor"], 34 | ), 35 | ) 36 | # Load features 37 | print("Loading features %s..." % (filename)) 38 | with h5.File(filename, "r") as f: 39 | features = f["feats"][:] 40 | features = np.array(features) 41 | # Normalize features 42 | features /= np.linalg.norm(features, axis=1, keepdims=True) 43 | 44 | feat_dim = 2048 # Feature dimensionality of the ResNet-50 backbone 45 | # k-means 46 | print("Training k-means with %i centers..." % (args["kmeans_subsampled"])) 47 | kmeans = faiss.Kmeans( 48 | feat_dim, 49 | args["kmeans_subsampled"], 50 | niter=100, 51 | verbose=True, 52 | gpu=args["gpu"], 53 | min_points_per_centroid=200, 54 | spherical=False, 55 | ) 56 | kmeans.train(features.astype(np.float32)) 57 | 58 | # Find closest instances to each k-means cluster 59 | print("Finding closest instances to centers...") 60 | index = faiss.IndexFlatL2(feat_dim) 61 | index.add(features.astype(np.float32)) 62 | D, closest_sample = index.search(kmeans.centroids, 1) 63 | 64 | net_str = ( 65 | "rn50" 66 | if args["backbone_feature_extractor"] == "resnet50" 67 | else args["backbone_feature_extractor"] 68 | ) 69 | stored_filename = "%s_res%i_%s_%s_kmeans_k%i" % ( 70 | im_prefix, 71 | args["resolution"], 72 | net_str, 73 | args["feature_extractor"], 74 | args["kmeans_subsampled"], 75 | ) 76 | np.save( 77 | os.path.join(args["data_root"], stored_filename), 78 | {"center_examples": closest_sample}, 79 | ) 80 | print( 81 | "Instance indexes selected by k-means subsampling have been saved in file %s!" 82 | % (stored_filename) 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = ArgumentParser( 88 | description="Storing cluster indexes for k-means-based data subsampling" 89 | ) 90 | parser.add_argument( 91 | "--resolution", 92 | type=int, 93 | default=64, 94 | help="Data resolution (default: %(default)s)", 95 | ) 96 | parser.add_argument( 97 | "--which_dataset", type=str, default="imagenet", help="Dataset choice."
98 | ) 99 | parser.add_argument( 100 | "--data_root", 101 | type=str, 102 | default="data", 103 | help="Default location where data is stored (default: %(default)s)", 104 | ) 105 | parser.add_argument( 106 | "--feature_extractor", 107 | type=str, 108 | default="classification", 109 | choices=["classification", "selfsupervised"], 110 | help="Choice of feature extractor", 111 | ) 112 | parser.add_argument( 113 | "--backbone_feature_extractor", 114 | type=str, 115 | default="resnet50", 116 | choices=["resnet50"], 117 | help="Choice of feature extractor backbone", 118 | ) 119 | parser.add_argument( 120 | "--kmeans_subsampled", 121 | type=int, 122 | default=-1, 123 | help="Number of k-means centers if using subsampled training instances" 124 | " (default: %(default)s)", 125 | ) 126 | parser.add_argument( 127 | "--gpu", 128 | action="store_true", 129 | default=False, 130 | help="Use faiss with GPUs (default: %(default)s)", 131 | ) 132 | args = vars(parser.parse_args()) 133 | main(args) 134 | -------------------------------------------------------------------------------- /download-weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | wget https://dl.fbaipublicfiles.com/ic_gan/cc_icgan_biggan_imagenet_res256.tar.gz 4 | tar -xvf cc_icgan_biggan_imagenet_res256.tar.gz 5 | wget https://dl.fbaipublicfiles.com/ic_gan/icgan_biggan_imagenet_res256.tar.gz 6 | tar -xvf icgan_biggan_imagenet_res256.tar.gz 7 | wget https://dl.fbaipublicfiles.com/ic_gan/stored_instances.tar.gz 8 | tar -xvf stored_instances.tar.gz 9 | curl -L -o swav_pretrained.pth.tar -C - 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar' 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ic_gan_ddp_1.8.0 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2021.5.25=h06a4308_1 11 | - certifi=2021.5.30=py38h06a4308_0 12 | - click=7.1.2=pyhd3eb1b0_0 13 | - cloudpickle=1.6.0=py_0 14 | - cudatoolkit=10.2.89=hfd86e86_1 15 | - cycler=0.10.0=py38_0 16 | - cytoolz=0.11.0=py38h7b6447c_0 17 | - dask-core=2021.4.0=pyhd3eb1b0_0 18 | - decorator=4.4.2=pyhd3eb1b0_0 19 | - faiss-gpu=1.7.0=py3.8_h080d439_0_cuda10.2 20 | - ffmpeg=4.3=hf484d3e_0 21 | - freetype=2.10.4=h5ab3b9f_0 22 | - fsspec=0.9.0=pyhd3eb1b0_0 23 | - gmp=6.2.1=h2531618_2 24 | - gnutls=3.6.15=he1e5248_0 25 | - h5py=2.10.0=py38hd6299e0_1 26 | - hdf5=1.10.6=hb1b8bf9_0 27 | - imageio=2.9.0=pyhd3eb1b0_0 28 | - intel-openmp=2020.2=254 29 | - joblib=1.0.1=pyhd3eb1b0_0 30 | - jpeg=9b=h024ee3a_2 31 | - kiwisolver=1.3.1=py38h2531618_0 32 | - lame=3.100=h7b6447c_0 33 | - lcms2=2.12=h3be6417_0 34 | - ld_impl_linux-64=2.33.1=h53a641e_7 35 | - libfaiss=1.7.0=h4fe19ad_0_cuda10.2 36 | - libffi=3.3=he6710b0_2 37 | - libgcc-ng=9.1.0=hdf63c60_0 38 | - libgfortran-ng=7.3.0=hdf63c60_0 39 | - libiconv=1.15=h63c8f33_5 40 | - libidn2=2.3.0=h27cfd23_0 41 | - libpng=1.6.37=hbc83047_0 42 | - libstdcxx-ng=9.1.0=hdf63c60_0 43 | - libtasn1=4.16.0=h27cfd23_0 44 | - libtiff=4.1.0=h2733197_1 45 | - libunistring=0.9.10=h27cfd23_0 46 | - libuv=1.40.0=h7b6447c_0 47 | - lmdb=0.9.28=h2531618_0 48 | - locket=0.2.1=py38h06a4308_1 49 | - lz4-c=1.9.3=h2531618_0 50 | - matplotlib-base=3.3.4=py38h62a2d02_0 51 | - mkl=2020.2=256 52 | - mkl-service=2.3.0=py38he904b0f_0 53 | - 
mkl_fft=1.3.0=py38h54f3939_0 54 | - mkl_random=1.1.1=py38h0573a6f_0 55 | - ncurses=6.2=he6710b0_1 56 | - nettle=3.7.2=hbbd107a_1 57 | - networkx=2.5.1=pyhd3eb1b0_0 58 | - ninja=1.10.2=hff7bd54_1 59 | - numpy=1.19.2=py38h54aff64_0 60 | - numpy-base=1.19.2=py38hfa32c7d_0 61 | - olefile=0.46=py_0 62 | - openh264=2.1.0=hd408876_0 63 | - openssl=1.1.1k=h27cfd23_0 64 | - partd=1.2.0=pyhd3eb1b0_0 65 | - pillow=8.2.0=py38he98fc37_0 66 | - pip=21.0.1=py38h06a4308_0 67 | - psutil=5.8.0=py38h27cfd23_1 68 | - pyparsing=2.4.7=pyhd3eb1b0_0 69 | - python=3.8.8=hdb3f193_5 70 | - python-dateutil=2.8.1=pyhd3eb1b0_0 71 | - python-lmdb=1.1.1=py38h2531618_1 72 | - python_abi=3.8=1_cp38 73 | - pytorch=1.8.0=py3.8_cuda10.2_cudnn7.6.5_0 74 | - pywavelets=1.1.1=py38h7b6447c_2 75 | - pyyaml=5.4.1=py38h27cfd23_1 76 | - readline=8.1=h27cfd23_0 77 | - scikit-image=0.18.1=py38ha9443f7_0 78 | - scikit-learn=0.24.1=py38ha9443f7_0 79 | - scipy=1.6.2=py38h91f5cce_0 80 | - setuptools=52.0.0=py38h06a4308_0 81 | - six=1.15.0=py38h06a4308_0 82 | - sqlite=3.35.4=hdfb4753_0 83 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 84 | - tifffile=2020.10.1=py38hdd07704_2 85 | - tk=8.6.10=hbc83047_0 86 | - toolz=0.11.1=pyhd3eb1b0_0 87 | - torchaudio=0.8.0=py38 88 | - torchvision=0.9.0=py38_cu102 89 | - tornado=6.1=py38h27cfd23_0 90 | - tqdm=4.59.0=pyhd3eb1b0_1 91 | - typing_extensions=3.7.4.3=pyha847dfd_0 92 | - wheel=0.36.2=pyhd3eb1b0_0 93 | - xz=5.2.5=h7b6447c_0 94 | - yaml=0.2.5=h7b6447c_0 95 | - zlib=1.2.11=h7b6447c_3 96 | - zstd=1.4.9=haebb681_0 97 | - pip: 98 | - argon2-cffi==20.1.0 99 | - async-generator==1.10 100 | - attrs==21.2.0 101 | - backcall==0.2.0 102 | - bleach==3.3.0 103 | - cffi==1.14.5 104 | - chardet==4.0.0 105 | - defusedxml==0.7.1 106 | - entrypoints==0.3 107 | - filelock==3.0.12 108 | - gdown==3.13.0 109 | - idna==2.10 110 | - ipykernel==5.5.4 111 | - ipython==7.23.1 112 | - ipython-genutils==0.2.0 113 | - ipywidgets==7.6.3 114 | - jedi==0.18.0 115 | - jinja2==2.11.3 116 | - jsonschema==3.2.0 117 | - jupyter==1.0.0 118 | - jupyter-client==6.1.12 119 | - jupyter-console==6.4.0 120 | - jupyter-core==4.7.1 121 | - jupyterlab-pygments==0.1.2 122 | - jupyterlab-widgets==1.0.0 123 | - markupsafe==1.1.1 124 | - matplotlib-inline==0.1.2 125 | - mistune==0.8.4 126 | - nbclient==0.5.3 127 | - nbconvert==6.0.7 128 | - nbformat==5.1.3 129 | - nest-asyncio==1.5.1 130 | - notebook==6.3.0 131 | - packaging==20.9 132 | - pandocfilters==1.4.3 133 | - parso==0.8.2 134 | - pexpect==4.8.0 135 | - pickleshare==0.7.5 136 | - prometheus-client==0.10.1 137 | - prompt-toolkit==3.0.18 138 | - ptyprocess==0.7.0 139 | - pycparser==2.20 140 | - pygments==2.9.0 141 | - pyrsistent==0.17.3 142 | - pysocks==1.7.1 143 | - pyzmq==22.0.3 144 | - qtconsole==5.1.0 145 | - qtpy==1.9.0 146 | - requests==2.25.1 147 | - send2trash==1.5.0 148 | - submitit==1.3.3 149 | - terminado==0.9.4 150 | - testpath==0.4.4 151 | - traitlets==5.0.5 152 | - urllib3==1.26.4 153 | - wcwidth==0.2.5 154 | - webencodings==0.5.1 155 | - widgetsnbextension==3.5.1 156 | -------------------------------------------------------------------------------- /figures/github_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/figures/github_image.png -------------------------------------------------------------------------------- /figures/icgan_clip.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/figures/icgan_clip.png -------------------------------------------------------------------------------- /figures/icgan_transfer_all_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/figures/icgan_transfer_all_github.png -------------------------------------------------------------------------------- /inference/sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # All contributions by Andy Brock: 8 | # Copyright (c) 2019 Andy Brock 9 | # 10 | # MIT License 11 | """ BigGAN: The Authorized Unofficial PyTorch release 12 | Code by A. Brock and A. Andonian 13 | This code is an unofficial reimplementation of 14 | "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," 15 | by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). 16 | 17 | Let's go. 18 | """ 19 | 20 | import os 21 | import numpy as np 22 | from tqdm import tqdm, trange 23 | import json 24 | 25 | from imageio import imwrite as imsave 26 | 27 | # Import my stuff 28 | import sys 29 | 30 | sys.path.insert(1, os.path.join(sys.path[0], "..")) 31 | import inference.utils as inference_utils 32 | import BigGAN_PyTorch.utils as biggan_utils 33 | 34 | 35 | class Tester: 36 | def __init__(self, config): 37 | self.config = vars(config) if not isinstance(config, dict) else config 38 | 39 | def __call__(self) -> float: 40 | # Seed RNG 41 | biggan_utils.seed_rng(self.config["seed"]) 42 | 43 | import torch 44 | 45 | # Setup cudnn.benchmark for free speed 46 | torch.backends.cudnn.benchmark = True 47 | 48 | self.config = biggan_utils.update_config_roots( 49 | self.config, change_weight_folder=False 50 | ) 51 | # Prepare root folders if necessary 52 | biggan_utils.prepare_root(self.config) 53 | 54 | # Load model 55 | self.G, self.config = inference_utils.load_model_inference(self.config) 56 | biggan_utils.count_parameters(self.G) 57 | 58 | # Get sampling function and reference statistics for FID 59 | print("Eval reference set is ", self.config["eval_reference_set"]) 60 | sample, im_reference_filename = inference_utils.get_sampling_funct( 61 | self.config, 62 | self.G, 63 | instance_set=self.config["eval_instance_set"], 64 | reference_set=self.config["eval_reference_set"], 65 | which_dataset=self.config["which_dataset"], 66 | ) 67 | 68 | if config["which_dataset"] == "coco": 69 | image_format = "jpg" 70 | else: 71 | image_format = "png" 72 | if ( 73 | self.config["eval_instance_set"] == "val" 74 | and config["which_dataset"] == "coco" 75 | ): 76 | # using evaluation set 77 | test_part = True 78 | else: 79 | test_part = False 80 | path_samples = os.path.join( 81 | self.config["samples_root"], 82 | self.config["experiment_name"], 83 | "%s_images_seed%i%s%s%s" 84 | % ( 85 | config["which_dataset"], 86 | config["seed"], 87 | "_test" if test_part else "", 88 | "_hd" + str(self.config["filter_hd"]) 89 | if self.config["filter_hd"] > -1 90 | else "", 91 | "" 92 | if self.config["kmeans_subsampled"] == -1 93 | else "_" + str(self.config["kmeans_subsampled"]) + "centers", 94 | ), 95 | ) 96 | 97 | print("Path samples will be ", path_samples) 98 
| if not os.path.exists(path_samples): 99 | os.makedirs(path_samples) 100 | 101 | if not os.path.exists( 102 | os.path.join(self.config["samples_root"], self.config["experiment_name"]) 103 | ): 104 | os.mkdir( 105 | os.path.join( 106 | self.config["samples_root"], self.config["experiment_name"] 107 | ) 108 | ) 109 | print( 110 | "Sampling %d images and saving them with %s format..." 111 | % (self.config["sample_num_npz"], image_format) 112 | ) 113 | counter_i = 0 114 | for i in trange( 115 | int( 116 | np.ceil( 117 | self.config["sample_num_npz"] / float(self.config["batch_size"]) 118 | ) 119 | ) 120 | ): 121 | with torch.no_grad(): 122 | images, labels, _ = sample() 123 | 124 | fake_imgs = images.cpu().detach().numpy().transpose(0, 2, 3, 1) 125 | if self.config["model_backbone"] == "biggan": 126 | fake_imgs = fake_imgs * 0.5 + 0.5 127 | elif self.config["model_backbone"] == "stylegan2": 128 | fake_imgs = np.clip((fake_imgs * 127.5 + 128), 0, 255).astype( 129 | np.uint8 130 | ) 131 | for fake_img in fake_imgs: 132 | imsave( 133 | "%s/%06d.%s" % (path_samples, counter_i, image_format), fake_img 134 | ) 135 | counter_i += 1 136 | if counter_i >= self.config["sample_num_npz"]: 137 | break 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = biggan_utils.prepare_parser() 142 | parser = biggan_utils.add_sample_parser(parser) 143 | parser = inference_utils.add_backbone_parser(parser) 144 | config = vars(parser.parse_args()) 145 | config["n_classes"] = 1000 146 | if config["json_config"] != "": 147 | data = json.load(open(config["json_config"])) 148 | for key in data.keys(): 149 | if "exp_name" in key: 150 | config["experiment_name"] = data[key] 151 | else: 152 | config[key] = data[key] 153 | else: 154 | print("No json file to load configuration from") 155 | 156 | tester = Tester(config) 157 | 158 | tester() 159 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. In '...' directory, run command '...' 16 | 2. See error (copy&paste full log, including exceptions and **stacktraces**). 17 | 18 | Please copy&paste text instead of screenshots for better searchability. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Linux Ubuntu 20.04, Windows 10] 28 | - PyTorch version (e.g., pytorch 1.7.1) 29 | - CUDA toolkit version (e.g., CUDA 11.0) 30 | - NVIDIA driver version 31 | - GPU [e.g., Titan V, RTX 3090] 32 | - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:20.12-py3) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 
36 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .cache/ 3 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | FROM nvcr.io/nvidia/pytorch:20.12-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio-ffmpeg==0.4.3 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | # Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time 19 | # extension builds significantly faster as we only compile for the 20 | # currently active GPU configuration. 21 | RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 22 | ENTRYPOINT ["/entry.sh"] 23 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. 
You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 
96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/config_files/COCO_Stuff/IC-GAN/icgan_stylegan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "icgan_stylegan2_coco_res128", 3 | "data": "COCO128_xy.hdf5", 4 | "root_feats": "COCO128_feats_selfsupervised_resnet50.hdf5", 5 | "root_nns": "COCO128_feats_selfsupervised_resnet50_nn_k5.hdf5", 6 | "gpus": 2, 7 | "slurm": false, 8 | "aug": "noaug", 9 | "lrate": 0.0025, 10 | "gamma": 0.05, 11 | "kimg": 100000, 12 | "es_patience": 3738850, 13 | "instance_cond": true, 14 | "mirror": true, 15 | "resolution": 128, 16 | "feature_augmentation": true, 17 | "feature_extractor": "selfsupervised", 18 | "k_nn": 5, 19 | "run_setup": "local_debug" 20 | } -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/config_files/COCO_Stuff/IC-GAN/icgan_stylegan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "icgan_stylegan2_coco_res256", 3 | "data": "COCO256_xy.hdf5", 4 | "root_feats": "COCO256_feats_selfsupervised_resnet50.hdf5", 5 | "root_nns": "COCO256_feats_selfsupervised_resnet50_nn_k5.hdf5", 6 | "gpus": 2, 7 | "slurm": true, 8 | "aug": "noaug", 9 | "lrate": 0.003, 10 | "gamma": 0.5, 11 | "kimg": 100000, 12 | "es_patience": 3738850, 13 | "instance_cond": true, 14 | "mirror": true, 15 | "resolution": 256, 16 | "feature_augmentation": true, 17 | "feature_extractor": "selfsupervised", 18 | "k_nn": 5, 19 | "run_setup": "local_debug" 20 | } -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/config_files/COCO_Stuff/StyleGAN2/unconditional_stylegan_res128.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "unconditional_stylegan2_coco_res128", 3 | "data": "COCO128_xy.hdf5", 4 | "gpus": 2, 5 | "slurm": true, 6 | "aug": "noaug", 7 | "lrate": 0.0025, 8 | "gamma": 0.05, 9 | "kimg": 100000, 10 | "es_patience": 3738850, 11 | "mirror": true, 12 | "resolution": 128, 13 | "class_cond": false, 14 | "instance_cond": false, 15 | "run_setup": "local_debug" 16 | } -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/config_files/COCO_Stuff/StyleGAN2/unconditional_stylegan_res256.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "unconditional_stylegan2_coco_res256", 3 | "data": "COCO256_xy.hdf5", 4 | "gpus": 2, 5 | "slurm": true, 6 | "aug": "noaug", 7 | "lrate": 0.002, 8 | "gamma": 0.2, 9 | "kimg": 100000, 10 | "es_patience": 3738850, 11 | "mirror": true, 12 | "resolution": 256, 13 | "class_cond": false, 14 | "instance_cond": false, 15 | "run_setup": "local_debug" 16 | } -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 7 | # 8 | # NVIDIA CORPORATION and its licensors retain all intellectual property 9 | # and proprietary rights in and to this software, related documentation 10 | # and any modifications thereto. Any use, reproduction, disclosure or 11 | # distribution of this software and related documentation without an express 12 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 13 | 14 | set -e 15 | 16 | # Wrapper script for setting up `docker run` to properly 17 | # cache downloaded files, custom extension builds and 18 | # mount the source directory into the container and make it 19 | # run as non-root user. 20 | # 21 | # Use it like: 22 | # 23 | # ./docker_run.sh python generate.py --help 24 | # 25 | # To override the default `stylegan2ada:latest` image, run: 26 | # 27 | # IMAGE=my_image:v1.0 ./docker_run.sh python generate.py --help 28 | # 29 | 30 | rest=$@ 31 | 32 | IMAGE="${IMAGE:-sg2ada:latest}" 33 | 34 | CONTAINER_ID=$(docker inspect --format="{{.Id}}" ${IMAGE} 2> /dev/null) 35 | if [[ "${CONTAINER_ID}" ]]; then 36 | docker run --shm-size=2g --gpus all -it --rm -v `pwd`:/scratch --user $(id -u):$(id -g) \ 37 | --workdir=/scratch -e HOME=/scratch $IMAGE $@ 38 | else 39 | echo "Unknown container image: ${IMAGE}" 40 | exit 1 41 | fi 42 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ - Load LSUN dataset 9 | --source cifar-10-python.tar.gz - Load CIFAR-10 dataset 10 | --source path/ - Recursively load all images from path/ 11 | --source dataset.zip - Recursively load all images from dataset.zip 12 | 13 | The output dataset format can be either an image folder or a zip archive. 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir - Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive 18 | 19 | Images within the dataset archive will be stored as uncompressed PNG. 20 | 21 | Image scale/crop and resolution requirements: 22 | 23 | Output images must be square-shaped and they must all have the same power- 24 | of-two dimensions. 25 | 26 | To scale arbitrary input image size to a specific width and height, use 27 | the --width and --height options. Output resolution will be either the 28 | original input resolution (if --width/--height was not specified) or the 29 | one specified with --width/height. 30 | 31 | Use the --transform=center-crop or --transform=center-crop-wide options to 32 | apply a center crop transform on the input image. 
These options should be 33 | used with the --width and --height options. For example: 34 | 35 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 36 | --transform=center-crop-wide --width 512 --height=384 37 | 38 | Options: 39 | --source PATH Directory or archive name for input dataset [required] 40 | --dest PATH Output directory or archive name for output 41 | dataset [required] 42 | --max-images INTEGER Output only up to `max-images` images 43 | --resize-filter [box|lanczos] Filter to use when resizing images for 44 | output resolution [default: lanczos] 45 | --transform [center-crop|center-crop-wide] 46 | Input crop/resize mode 47 | --width INTEGER Output width 48 | --height INTEGER Output height 49 | --help Show this message and exit. 50 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docs/license.html: -------------------------------------------------------------------------------- NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) (HTML rendering; the license text is identical to stylegan2_ada_pytorch/LICENSE.txt above).
151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docs/stylegan2-ada-teaser-1024x252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/stylegan2_ada_pytorch/docs/stylegan2-ada-teaser-1024x252.png -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docs/stylegan2-ada-training-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ic_gan/8eff2f7390e385e801993211a941f3592c50984d/stylegan2_ada_pytorch/docs/stylegan2-ada-training-curves.png -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Training 4 | Generative Adversarial Networks with Limited Data". 5 | 6 | Examples: 7 | 8 | # Train with custom images using 1 GPU. 9 | python train.py --outdir=~/training-runs --data=~/my-image-folder 10 | 11 | # Train class-conditional CIFAR-10 using 2 GPUs. 12 | python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \ 13 | --gpus=2 --cfg=cifar --cond=1 14 | 15 | # Transfer learn MetFaces from FFHQ using 4 GPUs. 16 | python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \ 17 | --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10 18 | 19 | # Reproduce original StyleGAN2 config F. 20 | python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \ 21 | --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug 22 | 23 | Base configs (--cfg): 24 | auto Automatically select reasonable defaults based on resolution 25 | and GPU count. Good starting point for new datasets. 26 | stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024. 27 | paper256 Reproduce results for FFHQ and LSUN Cat at 256x256. 28 | paper512 Reproduce results for BreCaHAD and AFHQ at 512x512. 29 | paper1024 Reproduce results for MetFaces at 1024x1024. 30 | cifar Reproduce results for CIFAR-10 at 32x32. 31 | 32 | Transfer learning source networks (--resume): 33 | ffhq256 FFHQ trained at 256x256 resolution. 34 | ffhq512 FFHQ trained at 512x512 resolution. 35 | ffhq1024 FFHQ trained at 1024x1024 resolution. 36 | celebahq256 CelebA-HQ trained at 256x256 resolution. 37 | lsundog256 LSUN Dog trained at 256x256 resolution. 38 | Custom network pickle. 
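  # Example: resume from a custom network pickle (paths shown are illustrative).
  python train.py --outdir=~/training-runs --data=~/datasets/custom.zip \
      --gpus=2 --resume=~/training-runs/00000-custom/network-snapshot-000100.pkl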
39 | 40 | Options: 41 | --outdir DIR Where to save the results [required] 42 | --gpus INT Number of GPUs to use [default: 1] 43 | --snap INT Snapshot interval [default: 50 ticks] 44 | --metrics LIST Comma-separated list or "none" [default: 45 | fid50k_full] 46 | --seed INT Random seed [default: 0] 47 | -n, --dry-run Print training options and exit 48 | --data PATH Training data (directory or zip) [required] 49 | --cond BOOL Train conditional model based on dataset 50 | labels [default: false] 51 | --subset INT Train with only N images [default: all] 52 | --mirror BOOL Enable dataset x-flips [default: false] 53 | --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar] 54 | Base config [default: auto] 55 | --gamma FLOAT Override R1 gamma 56 | --kimg INT Override training duration 57 | --batch INT Override batch size 58 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 59 | --p FLOAT Augmentation probability for --aug=fixed 60 | --target FLOAT ADA target value for --aug=ada 61 | --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc] 62 | Augmentation pipeline [default: bgc] 63 | --resume PKL Resume training [default: noresume] 64 | --freezed INT Freeze-D [default: 0 layers] 65 | --fp32 BOOL Disable mixed-precision training 66 | --nhwc BOOL Use NHWC memory format with FP16 67 | --nobench BOOL Disable cuDNN benchmarking 68 | --allow-tf32 BOOL Allow PyTorch to use TF32 internally 69 | --workers INT Override number of DataLoader workers 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
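# The network pickle is opened with dnnlib.util.open_url and unpickled through
# legacy.load_network_pkl; image synthesis uses the exponential-moving-average
# generator weights stored under the "G_ema" key.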
11 | 12 | """Generate images using pretrained network pickle.""" 13 | 14 | import os 15 | import re 16 | from typing import List, Optional 17 | 18 | import click 19 | import dnnlib 20 | import numpy as np 21 | import PIL.Image 22 | import torch 23 | 24 | import legacy 25 | 26 | # ---------------------------------------------------------------------------- 27 | 28 | 29 | def num_range(s: str) -> List[int]: 30 | """Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.""" 31 | 32 | range_re = re.compile(r"^(\d+)-(\d+)$") 33 | m = range_re.match(s) 34 | if m: 35 | return list(range(int(m.group(1)), int(m.group(2)) + 1)) 36 | vals = s.split(",") 37 | return [int(x) for x in vals] 38 | 39 | 40 | # ---------------------------------------------------------------------------- 41 | 42 | 43 | @click.command() 44 | @click.pass_context 45 | @click.option("--network", "network_pkl", help="Network pickle filename", required=True) 46 | @click.option("--seeds", type=num_range, help="List of random seeds") 47 | @click.option( 48 | "--trunc", 49 | "truncation_psi", 50 | type=float, 51 | help="Truncation psi", 52 | default=1, 53 | show_default=True, 54 | ) 55 | @click.option( 56 | "--class", 57 | "class_idx", 58 | type=int, 59 | help="Class label (unconditional if not specified)", 60 | ) 61 | @click.option( 62 | "--noise-mode", 63 | help="Noise mode", 64 | type=click.Choice(["const", "random", "none"]), 65 | default="const", 66 | show_default=True, 67 | ) 68 | @click.option("--projected-w", help="Projection result file", type=str, metavar="FILE") 69 | @click.option( 70 | "--outdir", 71 | help="Where to save the output images", 72 | type=str, 73 | required=True, 74 | metavar="DIR", 75 | ) 76 | def generate_images( 77 | ctx: click.Context, 78 | network_pkl: str, 79 | seeds: Optional[List[int]], 80 | truncation_psi: float, 81 | noise_mode: str, 82 | outdir: str, 83 | class_idx: Optional[int], 84 | projected_w: Optional[str], 85 | ): 86 | """Generate images using pretrained network pickle. 87 | 88 | Examples: 89 | 90 | \b 91 | # Generate curated MetFaces images without truncation (Fig.10 left) 92 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ 93 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 94 | 95 | \b 96 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left) 97 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 98 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 99 | 100 | \b 101 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car) 102 | python generate.py --outdir=out --seeds=0-35 --class=1 \\ 103 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl 104 | 105 | \b 106 | # Render an image from projected W 107 | python generate.py --outdir=out --projected_w=projected_w.npz \\ 108 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 109 | """ 110 | 111 | print('Loading networks from "%s"...' % network_pkl) 112 | device = torch.device("cuda") 113 | with dnnlib.util.open_url(network_pkl) as f: 114 | G = legacy.load_network_pkl(f)["G_ema"].to(device) # type: ignore 115 | 116 | os.makedirs(outdir, exist_ok=True) 117 | 118 | # Synthesize the result of a W projection. 
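    # A projected-W file (e.g., as written by projector.py) stores an array of
    # shape [num_images, num_ws, w_dim] under the key "w"; each entry is fed
    # straight to G.synthesis, bypassing the mapping network, truncation, and
    # class labels.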
119 | if projected_w is not None: 120 | if seeds is not None: 121 | print("warn: --seeds is ignored when using --projected-w") 122 | print(f'Generating images from projected W "{projected_w}"') 123 | ws = np.load(projected_w)["w"] 124 | ws = torch.tensor(ws, device=device) # pylint: disable=not-callable 125 | assert ws.shape[1:] == (G.num_ws, G.w_dim) 126 | for idx, w in enumerate(ws): 127 | img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) 128 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 129 | img = PIL.Image.fromarray(img[0].cpu().numpy(), "RGB").save( 130 | f"{outdir}/proj{idx:02d}.png" 131 | ) 132 | return 133 | 134 | if seeds is None: 135 | ctx.fail("--seeds option is required when not using --projected-w") 136 | 137 | # Labels. 138 | label = torch.zeros([1, G.c_dim], device=device) 139 | if G.c_dim != 0: 140 | if class_idx is None: 141 | ctx.fail( 142 | "Must specify class label with --class when using a conditional network" 143 | ) 144 | label[:, class_idx] = 1 145 | else: 146 | if class_idx is not None: 147 | print("warn: --class=lbl ignored when running on an unconditional network") 148 | 149 | # Generate images. 150 | for seed_idx, seed in enumerate(seeds): 151 | print("Generating image for seed %d (%d/%d) ..." % (seed, seed_idx, len(seeds))) 152 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 153 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 154 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 155 | PIL.Image.fromarray(img[0].cpu().numpy(), "RGB").save( 156 | f"{outdir}/seed{seed:04d}.png" 157 | ) 158 | 159 | 160 | # ---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | generate_images() # pylint: disable=no-value-for-parameter 164 | 165 | # ---------------------------------------------------------------------------- 166 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
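# Given feature statistics (mu_r, Sigma_r) of real images and (mu_g, Sigma_g)
# of generated images,
#   FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 * (Sigma_r Sigma_g)^(1/2)),
# which compute_fid() below evaluates using scipy.linalg.sqrtm for the matrix
# square root.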
11 | 12 | """Frechet Inception Distance (FID) from the paper 13 | "GANs trained by a two time-scale update rule converge to a local Nash 14 | equilibrium". Matches the original implementation by Heusel et al. at 15 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 16 | 17 | import numpy as np 18 | import scipy.linalg 19 | from . import metric_utils 20 | 21 | # ---------------------------------------------------------------------------- 22 | 23 | 24 | def compute_fid(opts, max_real, num_gen): 25 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 26 | detector_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 27 | detector_kwargs = dict( 28 | return_features=True 29 | ) # Return raw features before the softmax layer. 30 | 31 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 32 | opts=opts, 33 | detector_url=detector_url, 34 | detector_kwargs=detector_kwargs, 35 | rel_lo=0, 36 | rel_hi=0, 37 | capture_mean_cov=True, 38 | max_items=max_real, 39 | ).get_mean_cov() 40 | 41 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 42 | opts=opts, 43 | detector_url=detector_url, 44 | detector_kwargs=detector_kwargs, 45 | rel_lo=0, 46 | rel_hi=1, 47 | capture_mean_cov=True, 48 | max_items=num_gen, 49 | ).get_mean_cov() 50 | 51 | if opts.rank != 0: 52 | return float("nan") 53 | 54 | m = np.square(mu_gen - mu_real).sum() 55 | s, _ = scipy.linalg.sqrtm( 56 | np.dot(sigma_gen, sigma_real), disp=False 57 | ) # pylint: disable=no-member 58 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 59 | return float(fid) 60 | 61 | 62 | # ---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Inception Score (IS) from the paper "Improved techniques for training 13 | GANs". Matches the original implementation by Salimans et al. at 14 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 15 | 16 | import numpy as np 17 | from . import metric_utils 18 | 19 | # ---------------------------------------------------------------------------- 20 | 21 | 22 | def compute_is(opts, num_gen, num_splits): 23 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 24 | detector_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 25 | detector_kwargs = dict( 26 | no_output_bias=True 27 | ) # Match the original implementation by not applying bias in the softmax layer. 
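    # The score for a split S of generated images is
    #   IS = exp( mean_{x in S} KL( p(y|x) || p(y) ) ),
    # where p(y) is approximated by averaging p(y|x) over the split; the loop
    # below evaluates this per split and returns the mean and std across splits.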
28 | 29 | gen_probs = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, 31 | detector_url=detector_url, 32 | detector_kwargs=detector_kwargs, 33 | capture_all=True, 34 | max_items=num_gen, 35 | ).get_all() 36 | 37 | if opts.rank != 0: 38 | return float("nan"), float("nan") 39 | 40 | scores = [] 41 | for i in range(num_splits): 42 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 43 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 44 | kl = np.mean(np.sum(kl, axis=1)) 45 | scores.append(np.exp(kl)) 46 | return float(np.mean(scores)), float(np.std(scores)) 47 | 48 | 49 | # ---------------------------------------------------------------------------- 50 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 13 | GANs". Matches the original implementation by Binkowski et al. at 14 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 15 | 16 | import numpy as np 17 | from . import metric_utils 18 | 19 | # ---------------------------------------------------------------------------- 20 | 21 | 22 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 23 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 24 | detector_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 25 | detector_kwargs = dict( 26 | return_features=True 27 | ) # Return raw features before the softmax layer. 
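    # KID is an unbiased estimator of the squared MMD between real and generated
    # feature distributions under the polynomial kernel k(x, y) = (x'y / n + 1)^3,
    # with n the feature dimensionality, averaged over `num_subsets` random
    # subsets of size m; the diagonal entries of the within-set kernel matrices
    # are excluded to keep the estimate unbiased.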
28 | 29 | real_features = metric_utils.compute_feature_stats_for_dataset( 30 | opts=opts, 31 | detector_url=detector_url, 32 | detector_kwargs=detector_kwargs, 33 | rel_lo=0, 34 | rel_hi=0, 35 | capture_all=True, 36 | max_items=max_real, 37 | ).get_all() 38 | 39 | gen_features = metric_utils.compute_feature_stats_for_generator( 40 | opts=opts, 41 | detector_url=detector_url, 42 | detector_kwargs=detector_kwargs, 43 | rel_lo=0, 44 | rel_hi=1, 45 | capture_all=True, 46 | max_items=num_gen, 47 | ).get_all() 48 | 49 | if opts.rank != 0: 50 | return float("nan") 51 | 52 | n = real_features.shape[1] 53 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 54 | t = 0 55 | for _subset_idx in range(num_subsets): 56 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 57 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 58 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 59 | b = (x @ y.T / n + 1) ** 3 60 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 61 | kid = t / num_subsets / m 62 | return float(kid) 63 | 64 | 65 | # ---------------------------------------------------------------------------- 66 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 13 | Architecture for Generative Adversarial Networks". Matches the original 14 | implementation by Karras et al. at 15 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 16 | 17 | import copy 18 | import numpy as np 19 | import torch 20 | import dnnlib 21 | from . import metric_utils 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | # Spherical interpolation of a batch of vectors. 26 | def slerp(a, b, t): 27 | a = a / a.norm(dim=-1, keepdim=True) 28 | b = b / b.norm(dim=-1, keepdim=True) 29 | d = (a * b).sum(dim=-1, keepdim=True) 30 | p = t * torch.acos(d) 31 | c = b - d * a 32 | c = c / c.norm(dim=-1, keepdim=True) 33 | d = a * torch.cos(p) + c * torch.sin(p) 34 | d = d / d.norm(dim=-1, keepdim=True) 35 | return d 36 | 37 | 38 | # ---------------------------------------------------------------------------- 39 | 40 | 41 | class PPLSampler(torch.nn.Module): 42 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 43 | assert space in ["z", "w"] 44 | assert sampling in ["full", "end"] 45 | super().__init__() 46 | self.G = copy.deepcopy(G) 47 | self.G_kwargs = G_kwargs 48 | self.epsilon = epsilon 49 | self.space = space 50 | self.sampling = sampling 51 | self.crop = crop 52 | self.vgg16 = copy.deepcopy(vgg16) 53 | 54 | def forward(self, c): 55 | # Generate random latents and interpolation t-values. 
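        # Each latent pair is rendered at interpolation positions t and
        # t + epsilon, and PPL is the squared LPIPS distance between the two
        # images divided by epsilon^2, i.e. a finite-difference estimate of the
        # perceptual path length. With sampling == "end", t is fixed at 0; with
        # "full", t ~ U[0, 1).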
56 | t = torch.rand([c.shape[0]], device=c.device) * ( 57 | 1 if self.sampling == "full" else 0 58 | ) 59 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 60 | 61 | # Interpolate in W or Z. 62 | if self.space == "w": 63 | w0, w1 = self.G.mapping(z=torch.cat([z0, z1]), c=torch.cat([c, c])).chunk(2) 64 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 65 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 66 | else: # space == 'z' 67 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 68 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 69 | wt0, wt1 = self.G.mapping( 70 | z=torch.cat([zt0, zt1]), c=torch.cat([c, c]) 71 | ).chunk(2) 72 | 73 | # Randomize noise buffers. 74 | for name, buf in self.G.named_buffers(): 75 | if name.endswith(".noise_const"): 76 | buf.copy_(torch.randn_like(buf)) 77 | 78 | # Generate images. 79 | img = self.G.synthesis( 80 | ws=torch.cat([wt0, wt1]), 81 | noise_mode="const", 82 | force_fp32=True, 83 | **self.G_kwargs 84 | ) 85 | 86 | # Center crop. 87 | if self.crop: 88 | assert img.shape[2] == img.shape[3] 89 | c = img.shape[2] // 8 90 | img = img[:, :, c * 3 : c * 7, c * 2 : c * 6] 91 | 92 | # Downsample to 256x256. 93 | factor = self.G.img_resolution // 256 94 | if factor > 1: 95 | img = img.reshape( 96 | [ 97 | -1, 98 | img.shape[1], 99 | img.shape[2] // factor, 100 | factor, 101 | img.shape[3] // factor, 102 | factor, 103 | ] 104 | ).mean([3, 5]) 105 | 106 | # Scale dynamic range from [-1,1] to [0,255]. 107 | img = (img + 1) * (255 / 2) 108 | if self.G.img_channels == 1: 109 | img = img.repeat([1, 3, 1, 1]) 110 | 111 | # Evaluate differential LPIPS. 112 | lpips_t0, lpips_t1 = self.vgg16( 113 | img, resize_images=False, return_lpips=True 114 | ).chunk(2) 115 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 116 | return dist 117 | 118 | 119 | # ---------------------------------------------------------------------------- 120 | 121 | 122 | def compute_ppl( 123 | opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False 124 | ): 125 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 126 | vgg16_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt" 127 | vgg16 = metric_utils.get_feature_detector( 128 | vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose 129 | ) 130 | 131 | # Setup sampler. 132 | sampler = PPLSampler( 133 | G=opts.G, 134 | G_kwargs=opts.G_kwargs, 135 | epsilon=epsilon, 136 | space=space, 137 | sampling=sampling, 138 | crop=crop, 139 | vgg16=vgg16, 140 | ) 141 | sampler.eval().requires_grad_(False).to(opts.device) 142 | if jit: 143 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) 144 | sampler = torch.jit.trace(sampler, [c], check_trace=False) 145 | 146 | # Sampling loop. 147 | dist = [] 148 | progress = opts.progress.sub(tag="ppl sampling", num_items=num_samples) 149 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 150 | progress.update(batch_start) 151 | c = [ 152 | dataset.get_label(np.random.randint(len(dataset))) 153 | for _i in range(batch_size) 154 | ] 155 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 156 | x = sampler(c) 157 | for src in range(opts.num_gpus): 158 | y = x.clone() 159 | if opts.num_gpus > 1: 160 | torch.distributed.broadcast(y, src=src) 161 | dist.append(y) 162 | progress.update(num_samples) 163 | 164 | # Compute PPL. 
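    # Only rank 0 reports a score; the other ranks return NaN below. Following
    # the StyleGAN paper, distances outside the [1st, 99th] percentile band are
    # discarded as outliers before averaging, so a few degenerate interpolation
    # paths cannot dominate the metric.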
165 | if opts.rank != 0: 166 | return float("nan") 167 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 168 | lo = np.percentile(dist, 1, interpolation="lower") 169 | hi = np.percentile(dist, 99, interpolation="higher") 170 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 171 | return float(ppl) 172 | 173 | 174 | # ---------------------------------------------------------------------------- 175 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 13 | Metric for Assessing Generative Models". Matches the original implementation 14 | by Kynkaanniemi et al. at 15 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 16 | 17 | import torch 18 | from . import metric_utils 19 | 20 | # ---------------------------------------------------------------------------- 21 | 22 | 23 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 24 | assert 0 <= rank < num_gpus 25 | num_cols = col_features.shape[0] 26 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 27 | col_batches = torch.nn.functional.pad( 28 | col_features, [0, 0, 0, -num_cols % num_batches] 29 | ).chunk(num_batches) 30 | dist_batches = [] 31 | for col_batch in col_batches[rank::num_gpus]: 32 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 33 | for src in range(num_gpus): 34 | dist_broadcast = dist_batch.clone() 35 | if num_gpus > 1: 36 | torch.distributed.broadcast(dist_broadcast, src=src) 37 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 38 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 39 | 40 | 41 | # ---------------------------------------------------------------------------- 42 | 43 | 44 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 45 | detector_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt" 46 | detector_kwargs = dict(return_features=True) 47 | 48 | real_features = ( 49 | metric_utils.compute_feature_stats_for_dataset( 50 | opts=opts, 51 | detector_url=detector_url, 52 | detector_kwargs=detector_kwargs, 53 | rel_lo=0, 54 | rel_hi=0, 55 | capture_all=True, 56 | max_items=max_real, 57 | ) 58 | .get_all_torch() 59 | .to(torch.float16) 60 | .to(opts.device) 61 | ) 62 | 63 | gen_features = ( 64 | metric_utils.compute_feature_stats_for_generator( 65 | opts=opts, 66 | detector_url=detector_url, 67 | detector_kwargs=detector_kwargs, 68 | rel_lo=0, 69 | rel_hi=1, 70 | capture_all=True, 71 | max_items=num_gen, 72 | ) 73 | .get_all_torch() 74 | .to(torch.float16) 75 | .to(opts.device) 76 | ) 77 | 78 | results = dict() 79 | for name, manifold, probes in [ 80 | ("precision", 
real_features, gen_features),
81 | ("recall", gen_features, real_features),
82 | ]:
83 | kth = []
84 | for manifold_batch in manifold.split(row_batch_size):
85 | dist = compute_distances(
86 | row_features=manifold_batch,
87 | col_features=manifold,
88 | num_gpus=opts.num_gpus,
89 | rank=opts.rank,
90 | col_batch_size=col_batch_size,
91 | )
92 | kth.append(
93 | dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16)
94 | if opts.rank == 0
95 | else None
96 | )
97 | kth = torch.cat(kth) if opts.rank == 0 else None
98 | pred = []
99 | for probes_batch in probes.split(row_batch_size):
100 | dist = compute_distances(
101 | row_features=probes_batch,
102 | col_features=manifold,
103 | num_gpus=opts.num_gpus,
104 | rank=opts.rank,
105 | col_batch_size=col_batch_size,
106 | )
107 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
108 | results[name] = float(
109 | torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else "nan"
110 | )
111 | return results["precision"], results["recall"]
112 |
113 |
114 | # ----------------------------------------------------------------------------
115 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/run.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from train import main
9 | from submitit.helpers import Checkpointable
10 |
11 | LOCAL = False
12 | try:
13 | import submitit
14 | except ImportError:
15 | print(
16 | "No submitit package found! Defaulting to executing the script on the local machine"
17 | )
18 | LOCAL = True
19 | import parser
20 | import json
21 |
22 |
23 | class Trainer(Checkpointable):
24 | def __call__(self, args, slurm=False):
25 | if slurm and not LOCAL:
26 | main(
27 | args,
28 | args.outdir,
29 | master_node=submitit.JobEnvironment().hostnames[0],
30 | port=args.port,
31 | )
32 | else:
33 | main(args, args.outdir, master_node="", dry_run=args.dry_run)
34 |
35 |
36 | if __name__ == "__main__":
37 | parser_ = parser.get_parser()
38 | args = parser_.parse_args()
39 |
40 | if args.json_config != "":
41 | data = json.load(open(args.json_config))
42 | for key in data.keys():
43 | setattr(args, key, data[key])
44 | else:
45 | print("Not using JSON configuration file!")
46 | if args.data_root is not None:
47 | print("Appending data_root to paths")
48 | args.data = os.path.join(args.data_root, args.data)
49 | args.root_feats = os.path.join(args.data_root, args.root_feats)
50 | args.root_nns = os.path.join(args.data_root, args.root_nns)
51 | args.outdir = args.base_root
52 |
53 | trainer = Trainer()
54 | if not args.slurm or LOCAL:
55 | trainer(args)
56 | else:
57 |
58 | executor = submitit.SlurmExecutor(folder=args.slurm_logdir, max_num_timeout=60)
59 | print(args.gpus)
60 | executor.update_parameters(
61 | gpus_per_node=args.gpus,
62 | partition=args.partition,
63 | constraint="volta32gb",
64 | nodes=args.nodes,
65 | ntasks_per_node=args.gpus,
66 | cpus_per_task=10,
67 | mem=256000,
68 | time=args.slurm_time,
69 | job_name=args.exp_name,
70 | exclusive=(args.gpus == 8),
71 | )
72 |
73 | job = executor.submit(trainer, args, slurm=True)
74 | print(job.job_id)
75 |
76 | import time
77 |
78 | time.sleep(1)
79 | 
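# A typical invocation, sketched with hypothetical paths (the actual flag names
# are defined in parser.py): run locally with a JSON config overriding the
# parsed defaults,
#
#   python run.py --json_config config_files/COCO_Stuff/IC-GAN/icgan_stylegan_res128.json \
#       --data_root /path/to/datasets
#
# or submit the same job through submitit by additionally passing the SLURM
# related options consumed above (partition, nodes, gpus, ...).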
-------------------------------------------------------------------------------- /stylegan2_ada_pytorch/style_mixing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Generate style mixing image matrix using pretrained network pickle.""" 13 | 14 | import os 15 | import re 16 | from typing import List 17 | 18 | import click 19 | import dnnlib 20 | import numpy as np 21 | import PIL.Image 22 | import torch 23 | 24 | import legacy 25 | 26 | # ---------------------------------------------------------------------------- 27 | 28 | 29 | def num_range(s: str) -> List[int]: 30 | """Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.""" 31 | 32 | range_re = re.compile(r"^(\d+)-(\d+)$") 33 | m = range_re.match(s) 34 | if m: 35 | return list(range(int(m.group(1)), int(m.group(2)) + 1)) 36 | vals = s.split(",") 37 | return [int(x) for x in vals] 38 | 39 | 40 | # ---------------------------------------------------------------------------- 41 | 42 | 43 | @click.command() 44 | @click.option("--network", "network_pkl", help="Network pickle filename", required=True) 45 | @click.option( 46 | "--rows", 47 | "row_seeds", 48 | type=num_range, 49 | help="Random seeds to use for image rows", 50 | required=True, 51 | ) 52 | @click.option( 53 | "--cols", 54 | "col_seeds", 55 | type=num_range, 56 | help="Random seeds to use for image columns", 57 | required=True, 58 | ) 59 | @click.option( 60 | "--styles", 61 | "col_styles", 62 | type=num_range, 63 | help="Style layer range", 64 | default="0-6", 65 | show_default=True, 66 | ) 67 | @click.option( 68 | "--trunc", 69 | "truncation_psi", 70 | type=float, 71 | help="Truncation psi", 72 | default=1, 73 | show_default=True, 74 | ) 75 | @click.option( 76 | "--noise-mode", 77 | help="Noise mode", 78 | type=click.Choice(["const", "random", "none"]), 79 | default="const", 80 | show_default=True, 81 | ) 82 | @click.option("--outdir", type=str, required=True) 83 | def generate_style_mix( 84 | network_pkl: str, 85 | row_seeds: List[int], 86 | col_seeds: List[int], 87 | col_styles: List[int], 88 | truncation_psi: float, 89 | noise_mode: str, 90 | outdir: str, 91 | ): 92 | """Generate images using pretrained network pickle. 93 | 94 | Examples: 95 | 96 | \b 97 | python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\ 98 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 99 | """ 100 | print('Loading networks from "%s"...' 
% network_pkl) 101 | device = torch.device("cuda") 102 | with dnnlib.util.open_url(network_pkl) as f: 103 | G = legacy.load_network_pkl(f)["G_ema"].to(device) # type: ignore 104 | 105 | os.makedirs(outdir, exist_ok=True) 106 | 107 | print("Generating W vectors...") 108 | all_seeds = list(set(row_seeds + col_seeds)) 109 | all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds]) 110 | all_w = G.mapping(torch.from_numpy(all_z).to(device), None) 111 | w_avg = G.mapping.w_avg 112 | all_w = w_avg + (all_w - w_avg) * truncation_psi 113 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} 114 | 115 | print("Generating images...") 116 | all_images = G.synthesis(all_w, noise_mode=noise_mode) 117 | all_images = ( 118 | (all_images.permute(0, 2, 3, 1) * 127.5 + 128) 119 | .clamp(0, 255) 120 | .to(torch.uint8) 121 | .cpu() 122 | .numpy() 123 | ) 124 | image_dict = { 125 | (seed, seed): image for seed, image in zip(all_seeds, list(all_images)) 126 | } 127 | 128 | print("Generating style-mixed images...") 129 | for row_seed in row_seeds: 130 | for col_seed in col_seeds: 131 | w = w_dict[row_seed].clone() 132 | w[col_styles] = w_dict[col_seed][col_styles] 133 | image = G.synthesis(w[np.newaxis], noise_mode=noise_mode) 134 | image = ( 135 | (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 136 | ) 137 | image_dict[(row_seed, col_seed)] = image[0].cpu().numpy() 138 | 139 | print("Saving images...") 140 | os.makedirs(outdir, exist_ok=True) 141 | for (row_seed, col_seed), image in image_dict.items(): 142 | PIL.Image.fromarray(image, "RGB").save(f"{outdir}/{row_seed}-{col_seed}.png") 143 | 144 | print("Saving image grid...") 145 | W = G.img_resolution 146 | H = G.img_resolution 147 | canvas = PIL.Image.new( 148 | "RGB", (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), "black" 149 | ) 150 | for row_idx, row_seed in enumerate([0] + row_seeds): 151 | for col_idx, col_seed in enumerate([0] + col_seeds): 152 | if row_idx == 0 and col_idx == 0: 153 | continue 154 | key = (row_seed, col_seed) 155 | if row_idx == 0: 156 | key = (col_seed, col_seed) 157 | if col_idx == 0: 158 | key = (row_seed, row_seed) 159 | canvas.paste( 160 | PIL.Image.fromarray(image_dict[key], "RGB"), (W * col_idx, H * row_idx) 161 | ) 162 | canvas.save(f"{outdir}/grid.png") 163 | 164 | 165 | # ---------------------------------------------------------------------------- 166 | 167 | if __name__ == "__main__": 168 | generate_style_mix() # pylint: disable=no-value-for-parameter 169 | 170 | # ---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | import os 13 | import glob 14 | import torch 15 | import torch.utils.cpp_extension 16 | import importlib 17 | import hashlib 18 | import shutil 19 | from pathlib import Path 20 | 21 | from torch.utils.file_baton import FileBaton 22 | 23 | # ---------------------------------------------------------------------------- 24 | # Global options. 25 | 26 | verbosity = "brief" # Verbosity level: 'none', 'brief', 'full' 27 | 28 | # ---------------------------------------------------------------------------- 29 | # Internal helper funcs. 30 | 31 | 32 | def _find_compiler_bindir(): 33 | patterns = [ 34 | "C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64", 35 | "C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64", 36 | "C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64", 37 | "C:/Program Files (x86)/Microsoft Visual Studio */vc/bin", 38 | ] 39 | for pattern in patterns: 40 | matches = sorted(glob.glob(pattern)) 41 | if len(matches): 42 | return matches[-1] 43 | return None 44 | 45 | 46 | # ---------------------------------------------------------------------------- 47 | # Main entry point for compiling and loading C++/CUDA plugins. 48 | 49 | _cached_plugins = dict() 50 | 51 | 52 | def get_plugin(module_name, sources, **build_kwargs): 53 | assert verbosity in ["none", "brief", "full"] 54 | 55 | # Already cached? 56 | if module_name in _cached_plugins: 57 | return _cached_plugins[module_name] 58 | 59 | # Print status. 60 | if verbosity == "full": 61 | print(f'Setting up PyTorch plugin "{module_name}"...') 62 | elif verbosity == "brief": 63 | print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True) 64 | 65 | try: # pylint: disable=too-many-nested-blocks 66 | # Make sure we can find the necessary compiler binaries. 67 | if os.name == "nt" and os.system("where cl.exe >nul 2>nul") != 0: 68 | compiler_bindir = _find_compiler_bindir() 69 | if compiler_bindir is None: 70 | raise RuntimeError( 71 | f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".' 72 | ) 73 | os.environ["PATH"] += ";" + compiler_bindir 74 | 75 | # Compile and load. 76 | verbose_build = verbosity == "full" 77 | 78 | # Incremental build md5sum trickery. Copies all the input source files 79 | # into a cached build directory under a combined md5 digest of the input 80 | # source files. Copying is done only if the combined digest has changed. 81 | # This keeps input file timestamps and filenames the same as in previous 82 | # extension builds, allowing for fast incremental rebuilds. 83 | # 84 | # This optimization is done only in case all the source files reside in 85 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 86 | # environment variable is set (we take this as a signal that the user 87 | # actually cares about this.) 
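        # Concretely, with TORCH_EXTENSIONS_DIR set, the sources end up
        # mirrored under a layout like (digest shown is hypothetical)
        #   $TORCH_EXTENSIONS_DIR/<module_name>/d41d8cd9.../bias_act.cpp
        # and are compiled from there, so a rebuild is triggered only when the
        # combined md5 digest of the source files changes.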
88 | source_dirs_set = set(os.path.dirname(source) for source in sources) 89 | if len(source_dirs_set) == 1 and ("TORCH_EXTENSIONS_DIR" in os.environ): 90 | all_source_files = sorted( 91 | list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()) 92 | ) 93 | 94 | # Compute a combined hash digest for all source files in the same 95 | # custom op directory (usually .cu, .cpp, .py and .h files). 96 | hash_md5 = hashlib.md5() 97 | for src in all_source_files: 98 | with open(src, "rb") as f: 99 | hash_md5.update(f.read()) 100 | build_dir = torch.utils.cpp_extension._get_build_directory( 101 | module_name, verbose=verbose_build 102 | ) # pylint: disable=protected-access 103 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 104 | 105 | if not os.path.isdir(digest_build_dir): 106 | os.makedirs(digest_build_dir, exist_ok=True) 107 | baton = FileBaton(os.path.join(digest_build_dir, "lock")) 108 | if baton.try_acquire(): 109 | try: 110 | for src in all_source_files: 111 | shutil.copyfile( 112 | src, 113 | os.path.join(digest_build_dir, os.path.basename(src)), 114 | ) 115 | finally: 116 | baton.release() 117 | else: 118 | # Someone else is copying source files under the digest dir, 119 | # wait until done and continue. 120 | baton.wait() 121 | digest_sources = [ 122 | os.path.join(digest_build_dir, os.path.basename(x)) for x in sources 123 | ] 124 | torch.utils.cpp_extension.load( 125 | name=module_name, 126 | build_directory=build_dir, 127 | verbose=verbose_build, 128 | sources=digest_sources, 129 | **build_kwargs, 130 | ) 131 | else: 132 | torch.utils.cpp_extension.load( 133 | name=module_name, verbose=verbose_build, sources=sources, **build_kwargs 134 | ) 135 | module = importlib.import_module(module_name) 136 | 137 | except: 138 | if verbosity == "brief": 139 | print("Failed!") 140 | raise 141 | 142 | # Print status and add to cache. 143 | if verbosity == "full": 144 | print(f'Done setting up PyTorch plugin "{module_name}".') 145 | elif verbosity == "brief": 146 | print("Done.") 147 | _cached_plugins[module_name] = module 148 | return module 149 | 150 | 151 | # ---------------------------------------------------------------------------- 152 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | // 6 | // NVIDIA CORPORATION and its licensors retain all intellectual property 7 | // and proprietary rights in and to this software, related documentation 8 | // and any modifications thereto. 
Any use, reproduction, disclosure or
9 | // distribution of this software and related documentation without an express
10 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
11 |
12 | #include <torch/extension.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <c10/cuda/CUDAGuard.h>
15 | #include "bias_act.h"
16 |
17 | //------------------------------------------------------------------------
18 |
19 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
20 | {
21 | if (x.dim() != y.dim())
22 | return false;
23 | for (int64_t i = 0; i < x.dim(); i++)
24 | {
25 | if (x.size(i) != y.size(i))
26 | return false;
27 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
28 | return false;
29 | }
30 | return true;
31 | }
32 |
33 | //------------------------------------------------------------------------
34 |
35 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
36 | {
37 | // Validate arguments.
38 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
39 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
40 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
41 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
42 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
43 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
44 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
45 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
46 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
47 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
48 |
49 | // Validate layout.
50 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
51 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
52 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
53 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
54 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
55 |
56 | // Create output tensor.
57 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
58 | torch::Tensor y = torch::empty_like(x);
59 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
60 |
61 | // Initialize CUDA kernel parameters.
62 | bias_act_kernel_params p;
63 | p.x = x.data_ptr();
64 | p.b = (b.numel()) ? b.data_ptr() : NULL;
65 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
66 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
67 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
68 | p.y = y.data_ptr();
69 | p.grad = grad;
70 | p.act = act;
71 | p.alpha = alpha;
72 | p.gain = gain;
73 | p.clamp = clamp;
74 | p.sizeX = (int)x.numel();
75 | p.sizeB = (int)b.numel();
76 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
77 |
78 | // Choose CUDA kernel.
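    // AT_DISPATCH_FLOATING_TYPES_AND_HALF expands the lambda below once per
    // supported element type (double, float, half) and picks the branch that
    // matches x at runtime, binding scalar_t to the concrete C++ type so the
    // matching bias_act_kernel instantiation is selected.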
79 | void* kernel;
80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
81 | {
82 | kernel = choose_bias_act_kernel<scalar_t>(p);
83 | });
84 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
85 |
86 | // Launch CUDA kernel.
87 | p.loopX = 4;
88 | int blockSize = 4 * 32;
89 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
90 | void* args[] = {&p};
91 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
92 | return y;
93 | }
94 |
95 | //------------------------------------------------------------------------
96 |
97 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
98 | {
99 | m.def("bias_act", &bias_act);
100 | }
101 |
102 | //------------------------------------------------------------------------
103 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/bias_act.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | // All rights reserved.
3 | //
4 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
5 | //
6 | // NVIDIA CORPORATION and its licensors retain all intellectual property
7 | // and proprietary rights in and to this software, related documentation
8 | // and any modifications thereto. Any use, reproduction, disclosure or
9 | // distribution of this software and related documentation without an express
10 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
11 |
12 | #include <c10/util/Half.h>
13 | #include "bias_act.h"
14 |
15 | //------------------------------------------------------------------------
16 | // Helpers.
17 |
18 | template <class T> struct InternalType;
19 | template <> struct InternalType<double> { typedef double scalar_t; };
20 | template <> struct InternalType<float> { typedef float scalar_t; };
21 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
22 |
23 | //------------------------------------------------------------------------
24 | // CUDA kernel.
25 |
26 | template <class T, int A>
27 | __global__ void bias_act_kernel(bias_act_kernel_params p)
28 | {
29 | typedef typename InternalType<T>::scalar_t scalar_t;
30 | int G = p.grad;
31 | scalar_t alpha = (scalar_t)p.alpha;
32 | scalar_t gain = (scalar_t)p.gain;
33 | scalar_t clamp = (scalar_t)p.clamp;
34 | scalar_t one = (scalar_t)1;
35 | scalar_t two = (scalar_t)2;
36 | scalar_t expRange = (scalar_t)80;
37 | scalar_t halfExpRange = (scalar_t)40;
38 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
39 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
40 |
41 | // Loop over elements.
42 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
43 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
44 | {
45 | // Load.
46 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
47 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
48 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
49 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
50 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
51 | scalar_t yy = (gain != 0) ? yref / gain : 0;
52 | scalar_t y = 0;
53 |
54 | // Apply bias.
55 | ((G == 0) ? x : xref) += b;
56 |
57 | // linear
58 | if (A == 1)
59 | {
60 | if (G == 0) y = x;
61 | if (G == 1) y = x;
62 | }
63 |
64 | // relu
65 | if (A == 2)
66 | {
67 | if (G == 0) y = (x > 0) ? x : 0;
68 | if (G == 1) y = (yy > 0) ? x : 0;
69 | }
70 |
71 | // lrelu
72 | if (A == 3)
73 | {
74 | if (G == 0) y = (x > 0) ? x : x * alpha;
75 | if (G == 1) y = (yy > 0) ? x : x * alpha;
76 | }
77 |
78 | // tanh
79 | if (A == 4)
80 | {
81 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
82 | if (G == 1) y = x * (one - yy * yy);
83 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
84 | }
85 |
86 | // sigmoid
87 | if (A == 5)
88 | {
89 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
90 | if (G == 1) y = x * yy * (one - yy);
91 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
92 | }
93 |
94 | // elu
95 | if (A == 6)
96 | {
97 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
98 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
99 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
100 | }
101 |
102 | // selu
103 | if (A == 7)
104 | {
105 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
106 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
107 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
108 | }
109 |
110 | // softplus
111 | if (A == 8)
112 | {
113 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
114 | if (G == 1) y = x * (one - exp(-yy));
115 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
116 | }
117 |
118 | // swish
119 | if (A == 9)
120 | {
121 | if (G == 0)
122 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
123 | else
124 | {
125 | scalar_t c = exp(xref);
126 | scalar_t d = c + one;
127 | if (G == 1)
128 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
129 | else
130 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
131 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
132 | }
133 | }
134 |
135 | // Apply gain.
136 | y *= gain * dy;
137 |
138 | // Clamp.
139 | if (clamp >= 0)
140 | {
141 | if (G == 0)
142 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
143 | else
144 | y = (yref > -clamp & yref < clamp) ? y : 0;
145 | }
146 |
147 | // Store.
148 | ((T*)p.y)[xi] = (T)y;
149 | }
150 | }
151 |
152 | //------------------------------------------------------------------------
153 | // CUDA kernel selection.
154 |
155 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
156 | {
157 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
158 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
159 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
160 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
161 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
162 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
163 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
164 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
165 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
166 | return NULL;
167 | }
168 |
169 | //------------------------------------------------------------------------
170 | // Template specializations.
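// The explicit instantiations below force the compiler to emit
// choose_bias_act_kernel for every element type that the AT_DISPATCH macro in
// bias_act.cpp can request; without them the host code would fail to link.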
171 |
172 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
173 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
174 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
175 |
176 | //------------------------------------------------------------------------
177 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | // All rights reserved.
3 | //
4 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
5 | //
6 | // NVIDIA CORPORATION and its licensors retain all intellectual property
7 | // and proprietary rights in and to this software, related documentation
8 | // and any modifications thereto. Any use, reproduction, disclosure or
9 | // distribution of this software and related documentation without an express
10 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
11 |
12 | //------------------------------------------------------------------------
13 | // CUDA kernel parameters.
14 |
15 | struct bias_act_kernel_params
16 | {
17 | const void* x; // [sizeX]
18 | const void* b; // [sizeB] or NULL
19 | const void* xref; // [sizeX] or NULL
20 | const void* yref; // [sizeX] or NULL
21 | const void* dy; // [sizeX] or NULL
22 | void* y; // [sizeX]
23 |
24 | int grad;
25 | int act;
26 | float alpha;
27 | float gain;
28 | float clamp;
29 |
30 | int sizeX;
31 | int sizeB;
32 | int stepB;
33 | int loopX;
34 | };
35 |
36 | //------------------------------------------------------------------------
37 | // CUDA kernel selection.
38 |
39 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
40 |
41 | //------------------------------------------------------------------------
42 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/fma.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # NVIDIA CORPORATION and its licensors retain all intellectual property
7 | # and proprietary rights in and to this software, related documentation
8 | # and any modifications thereto. Any use, reproduction, disclosure or
9 | # distribution of this software and related documentation without an express
10 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
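# fma(a, b, c) evaluates a * b + c through a custom autograd Function: the
# forward pass still calls torch.addcmul, but the hand-written backward reuses
# the saved operands directly and un-broadcasts the gradients itself, which is
# slightly cheaper than the autograd graph that addcmul would record.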
11 | 12 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 13 | 14 | import torch 15 | 16 | # ---------------------------------------------------------------------------- 17 | 18 | 19 | def fma(a, b, c): # => a * b + c 20 | return _FusedMultiplyAdd.apply(a, b, c) 21 | 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | 26 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 27 | @staticmethod 28 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 29 | out = torch.addcmul(c, a, b) 30 | ctx.save_for_backward(a, b) 31 | ctx.c_shape = c.shape 32 | return out 33 | 34 | @staticmethod 35 | def backward(ctx, dout): # pylint: disable=arguments-differ 36 | a, b = ctx.saved_tensors 37 | c_shape = ctx.c_shape 38 | da = None 39 | db = None 40 | dc = None 41 | 42 | if ctx.needs_input_grad[0]: 43 | da = _unbroadcast(dout * b, a.shape) 44 | 45 | if ctx.needs_input_grad[1]: 46 | db = _unbroadcast(dout * a, b.shape) 47 | 48 | if ctx.needs_input_grad[2]: 49 | dc = _unbroadcast(dout, c_shape) 50 | 51 | return da, db, dc 52 | 53 | 54 | # ---------------------------------------------------------------------------- 55 | 56 | 57 | def _unbroadcast(x, shape): 58 | extra_dims = x.ndim - len(shape) 59 | assert extra_dims >= 0 60 | dim = [ 61 | i 62 | for i in range(x.ndim) 63 | if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) 64 | ] 65 | if len(dim): 66 | x = x.sum(dim=dim, keepdim=True) 67 | if extra_dims: 68 | x = x.reshape(-1, *x.shape[extra_dims + 1 :]) 69 | assert x.shape == shape 70 | return x 71 | 72 | 73 | # ---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # NVIDIA CORPORATION and its licensors retain all intellectual property 7 | # and proprietary rights in and to this software, related documentation 8 | # and any modifications thereto. Any use, reproduction, disclosure or 9 | # distribution of this software and related documentation without an express 10 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 11 | 12 | """Custom replacement for `torch.nn.functional.grid_sample` that 13 | supports arbitrarily high order gradients between the input and output. 14 | Only works on 2D images and assumes 15 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 16 | 17 | import warnings 18 | import torch 19 | 20 | # pylint: disable=redefined-builtin 21 | # pylint: disable=arguments-differ 22 | # pylint: disable=protected-access 23 | 24 | # ---------------------------------------------------------------------------- 25 | 26 | enabled = False # Enable the custom op by setting this to true. 
27 | 28 | # ---------------------------------------------------------------------------- 29 | 30 | 31 | def grid_sample(input, grid): 32 | if _should_use_custom_op(): 33 | return _GridSample2dForward.apply(input, grid) 34 | return torch.nn.functional.grid_sample( 35 | input=input, 36 | grid=grid, 37 | mode="bilinear", 38 | padding_mode="zeros", 39 | align_corners=False, 40 | ) 41 | 42 | 43 | # ---------------------------------------------------------------------------- 44 | 45 | 46 | def _should_use_custom_op(): 47 | if not enabled: 48 | return False 49 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9"]): 50 | return True 51 | warnings.warn( 52 | f"grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample()." 53 | ) 54 | return False 55 | 56 | 57 | # ---------------------------------------------------------------------------- 58 | 59 | 60 | class _GridSample2dForward(torch.autograd.Function): 61 | @staticmethod 62 | def forward(ctx, input, grid): 63 | assert input.ndim == 4 64 | assert grid.ndim == 4 65 | output = torch.nn.functional.grid_sample( 66 | input=input, 67 | grid=grid, 68 | mode="bilinear", 69 | padding_mode="zeros", 70 | align_corners=False, 71 | ) 72 | ctx.save_for_backward(input, grid) 73 | return output 74 | 75 | @staticmethod 76 | def backward(ctx, grad_output): 77 | input, grid = ctx.saved_tensors 78 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 79 | return grad_input, grad_grid 80 | 81 | 82 | # ---------------------------------------------------------------------------- 83 | 84 | 85 | class _GridSample2dBackward(torch.autograd.Function): 86 | @staticmethod 87 | def forward(ctx, grad_output, input, grid): 88 | op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward") 89 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 90 | ctx.save_for_backward(grid) 91 | return grad_input, grad_grid 92 | 93 | @staticmethod 94 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 95 | _ = grad2_grad_grid # unused 96 | grid, = ctx.saved_tensors 97 | grad2_grad_output = None 98 | grad2_input = None 99 | grad2_grid = None 100 | 101 | if ctx.needs_input_grad[0]: 102 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 103 | 104 | assert not ctx.needs_input_grad[2] 105 | return grad2_grad_output, grad2_input, grad2_grid 106 | 107 | 108 | # ---------------------------------------------------------------------------- 109 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 5 | // 6 | // NVIDIA CORPORATION and its licensors retain all intellectual property 7 | // and proprietary rights in and to this software, related documentation 8 | // and any modifications thereto. Any use, reproduction, disclosure or 9 | // distribution of this software and related documentation without an express 10 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
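// upfirdn2d fuses zero-insertion upsampling, 2D FIR filtering, and
// downsampling into a single CUDA pass. This file only validates the
// arguments, sizes the output as out = (in * up + pad - filter + down) / down
// per axis, and launches the kernel variant chosen in upfirdn2d.cu.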
11 |
12 | #include <torch/extension.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <c10/cuda/CUDAGuard.h>
15 | #include "upfirdn2d.h"
16 |
17 | //------------------------------------------------------------------------
18 |
19 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
20 | {
21 | // Validate arguments.
22 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
23 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
24 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
25 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
26 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
27 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
28 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 |
33 | // Create output tensor.
34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 |
41 | // Initialize CUDA kernel parameters.
42 | upfirdn2d_kernel_params p;
43 | p.x = x.data_ptr();
44 | p.f = f.data_ptr();
45 | p.y = y.data_ptr();
46 | p.up = make_int2(upx, upy);
47 | p.down = make_int2(downx, downy);
48 | p.pad0 = make_int2(padx0, pady0);
49 | p.flip = (flip) ? 1 : 0;
50 | p.gain = gain;
51 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
52 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
53 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
54 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
55 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
56 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
57 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
58 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
59 |
60 | // Choose CUDA kernel.
61 | upfirdn2d_kernel_spec spec;
62 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
63 | {
64 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
65 | });
66 |
67 | // Set looping options.
68 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
69 | p.loopMinor = spec.loopMinor;
70 | p.loopX = spec.loopX;
71 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
72 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
73 |
74 | // Compute grid size.
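    // The launch shape depends on the kernel variant chosen above: the generic
    // "large" fallback (tileOutW < 0) uses a fixed 4x32 thread block and loops
    // over output columns, while the specialized "small" kernels tile the
    // output plane in tileOutW x tileOutH chunks with 256-thread blocks.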
75 | dim3 blockSize, gridSize;
76 | if (spec.tileOutW < 0) // large
77 | {
78 | blockSize = dim3(4, 32, 1);
79 | gridSize = dim3(
80 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
81 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
82 | p.launchMajor);
83 | }
84 | else // small
85 | {
86 | blockSize = dim3(256, 1, 1);
87 | gridSize = dim3(
88 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
89 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
90 | p.launchMajor);
91 | }
92 |
93 | // Launch CUDA kernel.
94 | void* args[] = {&p};
95 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
96 | return y;
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
101 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
102 | {
103 | m.def("upfirdn2d", &upfirdn2d);
104 | }
105 |
106 | //------------------------------------------------------------------------
107 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | // All rights reserved.
3 | //
4 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
5 | //
6 | // NVIDIA CORPORATION and its licensors retain all intellectual property
7 | // and proprietary rights in and to this software, related documentation
8 | // and any modifications thereto. Any use, reproduction, disclosure or
9 | // distribution of this software and related documentation without an express
10 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
11 |
12 | #include <cuda_runtime.h>
13 |
14 | //------------------------------------------------------------------------
15 | // CUDA kernel parameters.
16 |
17 | struct upfirdn2d_kernel_params
18 | {
19 | const void* x;
20 | const float* f;
21 | void* y;
22 |
23 | int2 up;
24 | int2 down;
25 | int2 pad0;
26 | int flip;
27 | float gain;
28 |
29 | int4 inSize; // [width, height, channel, batch]
30 | int4 inStride;
31 | int2 filterSize; // [width, height]
32 | int2 filterStride;
33 | int4 outSize; // [width, height, channel, batch]
34 | int4 outStride;
35 | int sizeMinor;
36 | int sizeMajor;
37 |
38 | int loopMinor;
39 | int loopMajor;
40 | int loopX;
41 | int launchMinor;
42 | int launchMajor;
43 | };
44 |
45 | //------------------------------------------------------------------------
46 | // CUDA kernel specialization.
47 |
48 | struct upfirdn2d_kernel_spec
49 | {
50 | void* kernel;
51 | int tileOutW;
52 | int tileOutH;
53 | int loopMinor;
54 | int loopX;
55 | };
56 |
57 | //------------------------------------------------------------------------
58 | // CUDA kernel selection.
59 |
60 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
61 |
62 | //------------------------------------------------------------------------
63 | -------------------------------------------------------------------------------- /stylegan2_ada_pytorch/training/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | --------------------------------------------------------------------------------