├── .gitignore ├── LICENSE ├── LICENSE-NVIDIA ├── README.md ├── README_StudioGAN.md ├── configs ├── Conditional_img_synthesis │ ├── acgan_sngan_cifar32_rel_no.json │ ├── acgan_sngan_cifar32_rel_no_weightReg_0.01_no.json │ ├── acgan_sngan_cifar32_rel_no_weightReg_0.1_no.json │ ├── acgan_sngan_lsun_rel_no_weightReg_0.01_no.json │ ├── acgan_snresgan_cifar32_rel_no.json │ ├── contra_biggan_tiny64_hinge_no.json │ ├── proj_biggan_tiny64_hinge_no.json │ ├── proj_sagan_tiny64_hinge_no.json │ ├── proj_sngan_cifar32_hinge_no.json │ ├── proj_sngan_cifar32_rel_no.json │ ├── proj_sngan_cifar32_rel_no_weightReg_0.01_512_no.json │ ├── proj_sngan_cifar32_rel_no_weightReg_0.01_no.json │ ├── proj_sngan_cifar32_rel_no_weightReg_0.1_512_no.json │ ├── proj_sngan_cifar32_rel_no_weightReg_0.1_no.json │ ├── proj_sngan_lsun_rel_no.json │ ├── proj_sngan_lsun_rel_no_weightReg_0.01_512_no.json │ ├── proj_sngan_lsun_rel_no_weightReg_0.1_512_no.json │ └── proj_sngan_tiny64_hinge_no.json └── Unconditional_img_synthesis │ ├── no_dcgan_cifar32_hinge_no.json │ ├── no_dcgan_cifar32_rel_no_weightReg_0.01_no.json │ ├── no_dcgan_cifar32_rel_no_weightReg_0.1_no.json │ ├── no_dcgan_cifar32_rel_no_weightReg_no.json │ ├── no_dcgan_cifar32_rel_weightReg_0.01_no.json │ ├── no_dcgan_cifar32_rel_weightReg_0.1_no.json │ ├── no_dcgan_cifar32_rel_weightReg_no.json │ ├── no_dcgan_lsun_rel_weightReg_0.01_no.json │ ├── no_dcgan_lsun_rel_weightReg_0.1_no.json │ ├── no_dcgan_lsun_rel_weightReg_no.json │ ├── no_resgan_lsun_rel_no_weightReg_0.01_no.json │ ├── no_resgan_lsun_rel_no_weightReg_0.1_no.json │ └── no_resgan_lsun_rel_no_weightReg_no.json ├── data_utils ├── imbalance_cifar.py └── load_dataset.py ├── docs ├── ContraGAN.md └── figures │ ├── Table1.png │ ├── Table2.png │ ├── Table3.png │ ├── conditional GAN.png │ ├── generated images.png │ └── metric learning loss.png ├── environment.yml ├── figures └── Table3.png ├── main.py ├── make_hdf5.py ├── metrics ├── FID.py ├── IS.py ├── inception_network.py └── 
prepare_inception_moments_eval_dataset.py ├── models ├── biggan.py ├── biggan_deep.py ├── dcgan.py ├── linear_classifier.py ├── model_ops.py └── resgan.py ├── resnet_cifar.py ├── sync_batchnorm ├── batchnorm.py ├── batchnorm_reimpl.py ├── comm.py ├── replicate.py └── unittest.py ├── train.py ├── trainer.py ├── utils ├── ada.py ├── biggan_utils.py ├── calculate_accuracy.py ├── diff_aug.py ├── icr.py ├── load_checkpoint.py ├── log.py ├── losses.py ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── plot.py ├── sample.py └── utils.py └── weight_regularizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.pyc 3 | db.sqlite3 4 | *.DS_Store 5 | media/ 6 | res/ 7 | .vscode/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | PyTorch StudioGAN: 4 | Copyright (c) 2020 MinGuk Kang 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 
72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Class Balancing GAN with A Classifier In the Loop ([Paper](https://arxiv.org/abs/2106.09402)) 2 | 3 | 4 | This is code release for our **UAI 2021** paper Class Balancing GAN with a Classifier in the Loop. 
5 | ![approach](https://user-images.githubusercontent.com/15148765/125190714-1f9a3300-e25c-11eb-9933-e13e91c79ea6.jpg) 6 | 7 | 8 | 9 | 10 | ## 1. Requirements 11 | 12 | - Anaconda 13 | - Python > 3.6 14 | - torch > 1.6.0 15 | - torchvision > 0.7.0 16 | - Pillow < 7 17 | - apex 0.1 (for fused optimizers) 18 | - tensorboard 19 | - h5py 20 | - tqdm 21 | 22 | You can install the recommended environment setting as follows: 23 | 24 | ``` 25 | conda env create -f environment.yml -n classbalancinggan 26 | ``` 27 | 28 | 29 | 30 | ## 2. Dataset (CIFAR10, LSUN) 31 | The CIFAR-10 dataset will be downloaded automatically into the ```./data``` folder in the project directory. To download the LSUN dataset, please follow the instructions [here](https://github.com/fyu/lsun), then update the config file with the dataset path. 32 | 33 | 34 | ## 3. Pretrained Classifier 35 | 36 | One of the requirements of our framework is the availability of a classifier pretrained on the data of the classes on which you want to train the GAN. For all the results, we use the [LDAM-DRW](https://github.com/kaidic/LDAM-DRW) repo to obtain the pretrained models. We provide links for downloading the pretrained classifier models. 37 | 38 | Dataset | 0.01 | 0.1 | 1.0 39 | --- | --- | --- | --- 40 | CIFAR | [link](https://drive.google.com/file/d/18OPwjIpFYYcNJfuLNcnyEY3e_V5UGScf/view?usp=sharing) | [link](https://drive.google.com/file/d/1o-5f0b2Fr7LwxK0lgThZ3yLcZ2VTpZiI/view?usp=sharing) | [link](https://drive.google.com/file/d/1o-5f0b2Fr7LwxK0lgThZ3yLcZ2VTpZiI/view?usp=sharing) 41 | LSUN | [link](https://drive.google.com/file/d/1vvNVQLFFmpv1qxX_28V-sVDdSHwFM58X/view?usp=sharing) | [link](https://drive.google.com/file/d/1OouiaShrUiwn48EtYasRQKmRxrq74rSE/view?usp=sharing) | [link](https://drive.google.com/file/d/1dSTuv2IFEeYshyr0MQCjnGVSkj1_lI1w/view?usp=sharing) 42 | 43 | Please download these files before you start running experiments.
Update the path of the pretrained models in the ```pretrained_model_path``` field of the configurations in the ```./configs``` folder. 44 | 45 | 46 | ## 4. How to run 47 | For each of the imbalance factors (i.e., 0.01, 0.1, and 1.0), there is a separate configuration file in the ```./configs``` folder. 48 | 49 | For CIFAR10 image generation training: 50 | 51 | ``` 52 | python3 main.py -c "./configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_weightReg_0.01_no.json" -mpc --eval 53 | ``` 54 | 55 | For LSUN image generation training: 56 | 57 | ``` 58 | python3 main.py -t -c "./configs/Unconditional_img_synthesis/no_dcgan_lsun_rel_weightReg_0.1_no.json" -mpc --eval 59 | ``` 60 | Most experiments were run on an NVIDIA RTX 2080 Ti GPU (11 GB). 61 | ## 5. References 62 | 63 | **PyTorch-StudioGAN** : https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 64 | 65 | **LDAM-DRW**: https://github.com/kaidic/LDAM-DRW 66 | 67 | We thank the authors for open-sourcing their code, which has been immensely helpful. 68 | 69 | ## 6. Citation 70 | Please email us in case of any queries. If you find our work useful, please consider citing the following paper: 71 | 72 | ``` 73 | @inproceedings{rangwani2021class, 74 | title={Class Balancing GAN with a Classifier in the Loop}, 75 | author={Rangwani, Harsh and Mopuri, Konda Reddy and Babu, R Venkatesh}, 76 | booktitle={Uncertainty in Artificial Intelligence}, 77 | pages={1618--1627}, 78 | year={2021}, 79 | organization={PMLR} 80 | } 81 | ``` 82 | 83 | -------------------------------------------------------------------------------- /README_StudioGAN.md: -------------------------------------------------------------------------------- 1 | ## StudioGAN: A Library for Experiment and Evaluation of GANs (Early Version) 2 | 3 | StudioGAN is a PyTorch library providing implementations of representative Generative Adversarial Networks (GANs) for conditional/unconditional image synthesis.
This project aims to help machine learning researchers compare new ideas with other GANs in the same PyTorch environment. 4 | 5 | 6 | ## 1. Implemented GANs 7 | 8 | * [Vanilla DCGAN (Radford et al.)](https://arxiv.org/abs/1511.06434) 9 | * [WGAN Weight Clipping (Arjovsky et al.)](https://arxiv.org/abs/1701.07875) 10 | * [WGAN Gradient Penalty (Gulrajani et al.)](https://arxiv.org/abs/1704.00028) 11 | * [ACGAN (Odena et al.)](https://arxiv.org/abs/1610.09585) 12 | * [Geometric GAN (Lim and Ye)](https://arxiv.org/abs/1705.02894) 13 | * [cGAN (Miyato and Koyama)](https://arxiv.org/abs/1802.05637) 14 | * [SNDCGAN, SNResGAN (Miyato et al.)](https://arxiv.org/abs/1802.05957) 15 | * [SAGAN (Zhang et al.)](https://arxiv.org/abs/1805.08318) 16 | * [BigGAN (Brock et al.)](https://arxiv.org/abs/1809.11096) 17 | * [BigGAN-Deep (Brock et al.)](https://arxiv.org/abs/1809.11096) 18 | * [CRGAN (Zhang et al.)](https://arxiv.org/abs/1910.12027) 19 | * [ICRGAN (Zhao et al.)](https://arxiv.org/abs/2002.04724) 20 | * [DiffAugment (Zhao et al.)](https://arxiv.org/abs/2006.10738) 21 | * [Adaptive Discriminator Augmentation (Karras et al.)](https://arxiv.org/abs/2006.06676) 22 | * [ContraGAN (Ours)](https://arxiv.org/abs/2006.12681) 23 | 24 | ## 2. To be implemented 25 | * [LOGAN (Wu et al.)](https://arxiv.org/abs/1912.00953) 26 | 27 | ## 3. Requirements 28 | 29 | - Anaconda 30 | - Python > 3.6 31 | - torch > 1.6.0 32 | - torchvision > 0.7.0 33 | - Pillow < 7 34 | - apex 0.1 (for fused optimizers) 35 | - tensorboard 36 | - h5py 37 | - tqdm 38 | 39 | You can install the recommended environment setting as follows: 40 | 41 | ``` 42 | conda env create -f environment.yml -n StudioGAN 43 | ``` 44 | 45 | or using Docker: 46 | ``` 47 | docker pull minguk/studio_gan:latest 48 | ``` 49 | 50 | ## 4.
Dataset(CIFAR10, Tiny ImageNet, ImageNet possible) 51 | The folder structure of the datasets is shown below: 52 | ``` 53 | ├── data 54 |    └── ILSVRC2012 55 |    ├── train 56 |           ├── n01443537 57 |    ├── image1.png 58 |    ├── image2.png 59 | └── ... 60 | ├── n01629819 61 | └── ... 62 |    ├── valid 63 | └── val_folder 64 | ├── val1.png 65 |    ├── val2.png 66 | └── ... 67 | ``` 68 | 69 | 70 | ## 5. How to run 71 | 72 | For CIFAR10 image generation: 73 | 74 | ``` 75 | CUDA_VISIBLE_DEVICES=0 python3 main.py -t -e -rm_API -c "./configs/Table1/contra_biggan_cifar32_hinge_no.json" 76 | ``` 77 | 78 | For Tiny ImageNet image generation: 79 | 80 | ``` 81 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py -t -e -rm_API -c "./configs/Table1/contra_biggan_tiny32_hinge_no.json" 82 | ``` 83 | 84 | For ImageNet image generation: 85 | 86 | ``` 87 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 main.py -t -e -rm_API -c "./configs/Imagenet_experiments/contra_biggan_imagenet128_hinge_no.json" 88 | ``` 89 | 90 | For ImageNet image generation (loading all images into main memory to reduce I/O bottleneck): 91 | ``` 92 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 main.py -t -e -rm_API -c "./configs/Imagenet_experiments/contra_biggan_imagenet128_hinge_no.json" -l 93 | ``` 94 | 95 | ## 6. About PyTorch FID 96 | 97 | FID is a widely used metric to evaluate the performance of a GAN model. Calculating FID requires a pre-trained inception-V3 network, and approaches use Tensorflow-based FID (https://github.com/bioinf-jku/TTUR), or PyTorch-based FID (https://github.com/mseitzer/pytorch-fid). StudioGAN utilizes the PyTorch-based FID to test GAN models in the same PyTorch environment seamlessly. We show that the PyTorch based FID implementation used in StudioGAN provides almost the same results with the TensorFlow implementation. The results are summarized in the table below. 98 |

99 | 100 | ## 6. References 101 | 102 | **Self-Attention module:** https://github.com/voletiv/self-attention-GAN-pytorch 103 | 104 | **DiffAugment:** https://github.com/mit-han-lab/data-efficient-gans 105 | 106 | **Adaptive Discriminator Augmentation:** https://github.com/rosinality/stylegan2-pytorch 107 | 108 | **Exponential Moving Average:** https://github.com/ajbrock/BigGAN-PyTorch 109 | 110 | **Tensorflow FID:** https://github.com/bioinf-jku/TTUR 111 | 112 | **Pytorch FID:** https://github.com/mseitzer/pytorch-fid 113 | 114 | **Synchronized BatchNorm:** https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 115 | 116 | **Implementation Details:** https://github.com/ajbrock/BigGAN-PyTorch 117 | 118 | ## Citation 119 | StudioGAN is established for the following research project. Please cite our work if you use StudioGAN. 120 | ```bib 121 | @article{kang2020ContraGAN, 122 | title = {{Contrastive Generative Adversarial Networks}}, 123 | author = {Minguk Kang and Jaesik Park}, 124 | journal = {arXiv preprint arXiv:2006.12681}, 125 | year = {2020} 126 | } 127 | ``` 128 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/acgan_sngan_cifar32_rel_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "ACGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 
| "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth":"N/A", 33 | "D_depth":"N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda":"N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | "weight_reg": true, 80 | "weight_lambda": 0, 81 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 82 | 83 | "evaluation_checkpoint" :"./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | 85 | }, 86 | 87 | "initialization":{ 88 | "g_init": "ortho", 89 | "d_init": "ortho" 90 | }, 91 | 92 | "training_and_sampling_setting":{ 93 | "random_flip_preprocessing": false, 94 | "diff_aug":false, 95 | 96 | "ada": false, 97 | "fixed_augment_p": "N/A", 98 | "ada_target": "N/A", 99 | "ada_length": "N/A", 100 | 101 | "prior": "gaussian", 102 | "truncated_factor": 1, 103 | 104 | "latent_op": false, 105 | "latent_op_rate":"N/A", 106 | "latent_op_step":"N/A", 107 | "latent_op_step4eval":"N/A", 108 | "latent_op_alpha":"N/A", 
109 | "latent_op_beta":"N/A", 110 | "latent_norm_reg_weight":"N/A", 111 | "latent_op_lambda": "N/A", 112 | 113 | "ema": true, 114 | "ema_decay": 0.9999, 115 | "ema_start": 20000 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/acgan_sngan_cifar32_rel_no_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "ACGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 
64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint" :"./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/acgan_sngan_cifar32_rel_no_weightReg_0.1_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "ACGAN", 18 | 
"pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint" :"./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | 
"fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/acgan_sngan_lsun_rel_no_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "ACGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | 
"adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "../LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "../LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/acgan_snresgan_cifar32_rel_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "~/data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 
| "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resgan", 15 | "conditional_strategy": "ACGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "Leaky_ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 64, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.0002, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.5, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "rel", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "fixed_augment_p": "N/A", 90 | "ada_target": "N/A", 91 | "ada_length": "N/A", 92 | 93 | 
"prior": "gaussian", 94 | "truncated_factor": 1, 95 | 96 | "latent_op": false, 97 | "latent_op_rate":"N/A", 98 | "latent_op_step":"N/A", 99 | "latent_op_step4eval":"N/A", 100 | "latent_op_alpha":"N/A", 101 | "latent_op_beta":"N/A", 102 | "latent_norm_reg_weight":"N/A", 103 | "latent_op_lambda": "N/A", 104 | 105 | "ema": true, 106 | "ema_decay": 0.9999, 107 | "ema_start": 20000 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/contra_biggan_tiny64_hinge_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "biggan", 15 | "conditional_strategy": "ContraGAN", 16 | "pos_collected_numerator":true, 17 | "hypersphere_dim": 768, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 100, 27 | "shared_dim": 128, 28 | "g_conv_dim": 80, 29 | "d_conv_dim": 80, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 4, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": 1.0, 54 | "margin": 1.0, 55 | "tempering_type": "constant", 56 | "tempering_step": "N/A", 57 | 
"start_temperature": 1.0, 58 | "end_temperature": 1.0, 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "ada_target": "N/A", 90 | "ada_length": "N/A", 91 | 92 | "prior": "gaussian", 93 | "truncated_factor": 1, 94 | 95 | "latent_op": false, 96 | "latent_op_rate":"N/A", 97 | "latent_op_step":"N/A", 98 | "latent_op_step4eval":"N/A", 99 | "latent_op_alpha":"N/A", 100 | "latent_op_beta":"N/A", 101 | "latent_norm_reg_weight":"N/A", 102 | 103 | "ema": true, 104 | "ema_decay": 0.9999, 105 | "ema_start": 20000 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_biggan_tiny64_hinge_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "biggan", 15 | "conditional_strategy": "cGAN", 16 | "pos_collected_numerator":false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | 
"attention_after_nth_dis_block": 1, 26 | "z_dim": 100, 27 | "shared_dim": 128, 28 | "g_conv_dim": 80, 29 | "d_conv_dim": 80, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 4, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "ada_target": "N/A", 90 | "ada_length": "N/A", 91 | 92 | "prior": "gaussian", 93 | "truncated_factor": 1, 94 | 95 | "latent_op": false, 96 | "latent_op_rate":"N/A", 97 | "latent_op_step":"N/A", 98 | "latent_op_step4eval":"N/A", 99 | "latent_op_alpha":"N/A", 100 | "latent_op_beta":"N/A", 101 | "latent_norm_reg_weight":"N/A", 102 | 103 | "ema": true, 104 | "ema_decay": 0.9999, 105 | "ema_start": 20000 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sagan_tiny64_hinge_no.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resgan", 15 | "conditional_strategy": "cGAN", 16 | "pos_collected_numerator":false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 4, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | 
"training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "ada_target": "N/A", 90 | "ada_length": "N/A", 91 | 92 | "prior": "gaussian", 93 | "truncated_factor": 1, 94 | 95 | "latent_op": false, 96 | "latent_op_rate":"N/A", 97 | "latent_op_step":"N/A", 98 | "latent_op_step4eval":"N/A", 99 | "latent_op_alpha":"N/A", 100 | "latent_op_beta":"N/A", 101 | "latent_norm_reg_weight":"N/A", 102 | 103 | "ema": true, 104 | "ema_decay": 0.9999, 105 | "ema_start": 20000 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_hinge_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "dcgan", 15 | "conditional_strategy": "cGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "Leaky_ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": "N/A", 29 | "d_conv_dim": "N/A", 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 64, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.0002, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.5, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | 
"adv_loss": "hinge", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "fixed_augment_p": "N/A", 90 | "ada_target": "N/A", 91 | "ada_length": "N/A", 92 | 93 | "prior": "gaussian", 94 | "truncated_factor": 1, 95 | 96 | "latent_op": false, 97 | "latent_op_rate":"N/A", 98 | "latent_op_step":"N/A", 99 | "latent_op_step4eval":"N/A", 100 | "latent_op_alpha":"N/A", 101 | "latent_op_beta":"N/A", 102 | "latent_norm_reg_weight":"N/A", 103 | "latent_op_lambda": "N/A", 104 | 105 | "ema": true, 106 | "ema_decay": 0.9999, 107 | "ema_start": 20000 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_rel_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | 
"hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth":"N/A", 33 | "D_depth":"N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda":"N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint":"./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": 
"N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_rel_no_weightReg_0.01_512_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "~/data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | 
"contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_rel_no_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "~/data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | 
"batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 64, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "/mnt/data/harsh/harsh-dir2/DeGAN/CIFAR_10/train_teacher/LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 83 | }, 84 | 85 | "initialization":{ 86 | 
"g_init": "ortho", 87 | "d_init": "ortho" 88 | }, 89 | 90 | "training_and_sampling_setting":{ 91 | "random_flip_preprocessing": false, 92 | "diff_aug":false, 93 | 94 | "ada": false, 95 | "fixed_augment_p": "N/A", 96 | "ada_target": "N/A", 97 | "ada_length": "N/A", 98 | 99 | "prior": "gaussian", 100 | "truncated_factor": 1, 101 | 102 | "latent_op": false, 103 | "latent_op_rate":"N/A", 104 | "latent_op_step":"N/A", 105 | "latent_op_step4eval":"N/A", 106 | "latent_op_alpha":"N/A", 107 | "latent_op_beta":"N/A", 108 | "latent_norm_reg_weight":"N/A", 109 | "latent_op_lambda": "N/A", 110 | 111 | "ema": true, 112 | "ema_decay": 0.9999, 113 | "ema_start": 20000 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_rel_no_weightReg_0.1_512_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "~/data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": 
"N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_cifar32_rel_no_weightReg_0.1_no.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "~/data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 64, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | 
"weight_lambda" : 0, 82 | "pretrained_model_path": "/mnt/data/harsh/harsh-dir2/DeGAN/CIFAR_10/train_teacher/LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 83 | }, 84 | 85 | "initialization":{ 86 | "g_init": "ortho", 87 | "d_init": "ortho" 88 | }, 89 | 90 | "training_and_sampling_setting":{ 91 | "random_flip_preprocessing": false, 92 | "diff_aug":false, 93 | 94 | "ada": false, 95 | "fixed_augment_p": "N/A", 96 | "ada_target": "N/A", 97 | "ada_length": "N/A", 98 | 99 | "prior": "gaussian", 100 | "truncated_factor": 1, 101 | 102 | "latent_op": false, 103 | "latent_op_rate":"N/A", 104 | "latent_op_step":"N/A", 105 | "latent_op_step4eval":"N/A", 106 | "latent_op_alpha":"N/A", 107 | "latent_op_beta":"N/A", 108 | "latent_norm_reg_weight":"N/A", 109 | "latent_op_lambda": "N/A", 110 | 111 | "ema": true, 112 | "ema_decay": 0.9999, 113 | "ema_start": 20000 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_lsun_rel_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 128, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth":"N/A", 33 | "D_depth":"N/A" 
34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 2, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "hinge", 54 | 55 | "contrastive_lambda":"N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":true, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 
117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_lsun_rel_no_weightReg_0.01_512_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 
73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_lsun_rel_no_weightReg_0.1_512_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "cGAN", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | 
"attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | 
"latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Conditional_img_synthesis/proj_sngan_tiny64_hinge_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resgan", 15 | "conditional_strategy": "cGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 4, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 
63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "ada_target": "N/A", 90 | "ada_length": "N/A", 91 | 92 | "prior": "gaussian", 93 | "truncated_factor": 1, 94 | 95 | "latent_op": false, 96 | "latent_op_rate":"N/A", 97 | "latent_op_step":"N/A", 98 | "latent_op_step4eval":"N/A", 99 | "latent_op_alpha":"N/A", 100 | "latent_op_beta":"N/A", 101 | "latent_norm_reg_weight":"N/A", 102 | 103 | "ema": true, 104 | "ema_decay": 0.9999, 105 | "ema_start": 20000 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_hinge_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "dcgan", 15 | "conditional_strategy": "no", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "Leaky_ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": "N/A", 29 | "d_conv_dim": "N/A", 30 | 
"G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 64, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.0002, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.5, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "rel", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penelty_lambda": "N/A", 65 | 66 | "consistency_reg": false, 67 | "consistency_lambda":"N/A", 68 | 69 | "bcr": false, 70 | "real_lambda": "N/A", 71 | "fake_lambda": "N/A", 72 | 73 | "zcr": false, 74 | "gen_lambda": "N/A", 75 | "dis_lambda": "N/A", 76 | "sigma_noise": "N/A" 77 | }, 78 | 79 | "initialization":{ 80 | "g_init": "ortho", 81 | "d_init": "ortho" 82 | }, 83 | 84 | "training_and_sampling_setting":{ 85 | "random_flip_preprocessing": false, 86 | "diff_aug":false, 87 | 88 | "ada": false, 89 | "ada_target": "N/A", 90 | "ada_length": "N/A", 91 | 92 | "prior": "gaussian", 93 | "truncated_factor": 1, 94 | 95 | "latent_op": false, 96 | "latent_op_rate":"N/A", 97 | "latent_op_step":"N/A", 98 | "latent_op_step4eval":"N/A", 99 | "latent_op_alpha":"N/A", 100 | "latent_op_beta":"N/A", 101 | "latent_norm_reg_weight":"N/A", 102 | 103 | "ema": true, 104 | "ema_decay": 0.9999, 105 | "ema_start": 20000 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_no_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | 
"dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 32, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": 
"./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_no_weightReg_0.1_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | 
"G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 
0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_no_weightReg_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": false, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 128, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 500000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "vanilla", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 
| 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": 
true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 10, 82 | "pretrained_model_path": "./cifar10_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./cifar10_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | 
"latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_weightReg_0.1_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "dcgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 
61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 7.5, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_cifar32_rel_weightReg_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": 
"dcgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": "N/A", 31 | "d_conv_dim": "N/A", 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.5, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 5, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/cifar10_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | 
"diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_lsun_rel_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 
| }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 10, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_lsun_rel_weightReg_0.1_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 
| "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 7.5, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": 
"./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_dcgan_lsun_rel_weightReg_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | 
"accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 2.5, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | 
-------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_resgan_lsun_rel_no_weightReg_0.01_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.01, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": 
"N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_0.01_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_resgan_lsun_rel_no_weightReg_0.1_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 0.1, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | 
"attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": "N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_0.1_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 
| "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /configs/Unconditional_img_synthesis/no_resgan_lsun_rel_no_weightReg_no.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "lsun", 4 | "data_path": "../lsun_dataset", 5 | "img_size": 64, 6 | "num_classes": 5, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false, 10 | "imb_factor": 1.0, 11 | "imb_type": "exp" 12 | }, 13 | 14 | "train": { 15 | "model": { 16 | "architecture": "resgan", 17 | "conditional_strategy": "no", 18 | "pos_collected_numerator": false, 19 | "hypersphere_dim": "N/A", 20 | "nonlinear_embed": false, 21 | "normalize_embed": false, 22 | "g_spectral_norm": false, 23 | "d_spectral_norm": true, 24 | "activation_fn": "Leaky_ReLU", 25 | "attention": false, 26 | "attention_after_nth_gen_block": "N/A", 27 | "attention_after_nth_dis_block": "N/A", 28 | "z_dim": 128, 29 | "shared_dim": "N/A", 30 | "g_conv_dim": 64, 31 | "d_conv_dim": 64, 32 | "G_depth": "N/A", 33 | "D_depth": "N/A" 34 | }, 35 | 36 | "optimization": { 37 | "optimizer": "Adam", 38 | "batch_size": 256, 39 | "accumulation_steps": 1, 40 | "d_lr": 0.0002, 41 | "g_lr": 0.0002, 42 | "momentum": "N/A", 43 | "nesterov": "N/A", 44 | "alpha": "N/A", 45 | "beta1": 0.0, 46 | "beta2": 0.999, 47 | "g_steps_per_iter": 1, 48 | "d_steps_per_iter": 1, 49 | "total_step": 100000 50 | }, 51 | 52 | "loss_function": { 53 | "adv_loss": "rel", 54 | 55 | "contrastive_lambda": "N/A", 56 | "margin": "N/A", 57 | "tempering_type": "N/A", 58 | "tempering_step": "N/A", 59 | "start_temperature": "N/A", 60 | "end_temperature": "N/A", 61 | 62 | "weight_clipping_for_dis": false, 63 | "weight_clipping_bound": 
"N/A", 64 | 65 | "gradient_penalty_for_dis": false, 66 | "gradient_penelty_lambda": "N/A", 67 | 68 | "consistency_reg": false, 69 | "consistency_lambda":"N/A", 70 | 71 | "bcr": false, 72 | "real_lambda": "N/A", 73 | "fake_lambda": "N/A", 74 | 75 | "zcr": false, 76 | "gen_lambda": "N/A", 77 | "dis_lambda": "N/A", 78 | "sigma_noise": "N/A", 79 | 80 | "weight_reg" : true, 81 | "weight_lambda" : 0, 82 | "pretrained_model_path": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar", 83 | "evaluation_checkpoint": "./LDAM-DRW/checkpoint/lsun_resnet32_Focal_DRW_exp_1.0_0/ckpt.best.pth.tar" 84 | }, 85 | 86 | "initialization":{ 87 | "g_init": "ortho", 88 | "d_init": "ortho" 89 | }, 90 | 91 | "training_and_sampling_setting":{ 92 | "random_flip_preprocessing": false, 93 | "diff_aug":false, 94 | 95 | "ada": false, 96 | "fixed_augment_p": "N/A", 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | "latent_op_lambda": "N/A", 111 | 112 | "ema": true, 113 | "ema_decay": 0.9999, 114 | "ema_start": 20000 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /data_utils/imbalance_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | 6 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 7 | cls_num = 10 8 | 9 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True, 10 | transform=None, target_transform=None, 11 | download=False): 12 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download) 13 | 
np.random.seed(rand_number) 14 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 15 | print(img_num_list) 16 | self.gen_imbalanced_data(img_num_list) 17 | 18 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 19 | img_max = len(self.data) / cls_num 20 | img_num_per_cls = [] 21 | if imb_type == 'exp': 22 | for cls_idx in range(cls_num): 23 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 24 | img_num_per_cls.append(int(num)) 25 | elif imb_type == 'step': 26 | for cls_idx in range(cls_num // 2): 27 | img_num_per_cls.append(int(img_max)) 28 | for cls_idx in range(cls_num // 2): 29 | img_num_per_cls.append(int(img_max * imb_factor)) 30 | else: 31 | img_num_per_cls.extend([int(img_max)] * cls_num) 32 | return img_num_per_cls 33 | 34 | def gen_imbalanced_data(self, img_num_per_cls): 35 | new_data = [] 36 | new_targets = [] 37 | targets_np = np.array(self.targets, dtype=np.int64) 38 | classes = np.unique(targets_np) 39 | # np.random.shuffle(classes) 40 | self.num_per_cls_dict = dict() 41 | for the_class, the_img_num in zip(classes, img_num_per_cls): 42 | self.num_per_cls_dict[the_class] = the_img_num 43 | idx = np.where(targets_np == the_class)[0] 44 | np.random.shuffle(idx) 45 | selec_idx = idx[:the_img_num] 46 | new_data.append(self.data[selec_idx, ...]) 47 | new_targets.extend([the_class, ] * the_img_num) 48 | new_data = np.vstack(new_data) 49 | self.data = new_data 50 | self.targets = new_targets 51 | 52 | def get_cls_num_list(self): 53 | cls_num_list = [] 54 | for i in range(self.cls_num): 55 | cls_num_list.append(self.num_per_cls_dict[i]) 56 | return cls_num_list 57 | 58 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 59 | """`CIFAR100 `_ Dataset. 60 | This is a subclass of the `CIFAR10` Dataset. 
61 | """ 62 | base_folder = 'cifar-100-python' 63 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 64 | filename = "cifar-100-python.tar.gz" 65 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 66 | train_list = [ 67 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 68 | ] 69 | 70 | test_list = [ 71 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 72 | ] 73 | meta = { 74 | 'filename': 'meta', 75 | 'key': 'fine_label_names', 76 | 'md5': '7973b15100ade9c7d40fb424638fde48', 77 | } 78 | cls_num = 100 79 | 80 | 81 | class IMBALANCELSUN(torchvision.datasets.LSUN): 82 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, classes="val", 83 | transform=None, target_transform=None, max_samples = None): 84 | super(IMBALANCELSUN, self).__init__(root, classes, transform, target_transform,) 85 | np.random.seed(rand_number) 86 | self.max_samples = max_samples 87 | img_num_list = self.get_img_num_per_cls(len(self.classes), imb_type, imb_factor) 88 | self.gen_imbalanced_data(img_num_list) 89 | 90 | 91 | 92 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 93 | img_max = self.max_samples 94 | img_num_per_cls = [] 95 | if imb_type == 'exp': 96 | for cls_idx in range(cls_num): 97 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 98 | img_num_per_cls.append(int(num)) 99 | elif imb_type == 'step': 100 | for cls_idx in range(cls_num // 2): 101 | img_num_per_cls.append(int(img_max)) 102 | for cls_idx in range(cls_num // 2): 103 | img_num_per_cls.append(int(img_max * imb_factor)) 104 | else: 105 | img_num_per_cls.extend([int(img_max)] * cls_num) 106 | return img_num_per_cls 107 | 108 | def gen_imbalanced_data(self, img_num_list): 109 | 110 | modified_self_indices = [] 111 | count = 0 112 | for c in img_num_list: 113 | count += c 114 | modified_self_indices.append(count) 115 | 116 | self.indices = modified_self_indices 117 | self.length = count 118 | 119 | 120 | 121 | 122 | if __name__ == '__main__': 123 | transform = 
transforms.Compose( 124 | [transforms.ToTensor(), 125 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 126 | trainset = IMBALANCECIFAR100(root='./data', train=True, 127 | download=True, transform=transform) 128 | trainloader = iter(trainset) 129 | data, label = next(trainloader) 130 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /docs/ContraGAN.md: -------------------------------------------------------------------------------- 1 | ## Contrastive Generative Adversarial Networks 2 | 3 | **Abstract** 4 | 5 | Conditional image synthesis is the task to generate high-fidelity diverse images using class label information. Although many studies have shown realistic results, there is room for improvement if the number of classes increases. In this paper, we propose a novel conditional contrastive loss to maximize a lower bound on mutual information between samples from the same class. Our framework, called Contrastive Generative Adversarial Networks (ContraGAN), learns to synthesize images using class information and data-to-data relations of training examples. The discriminator in ContraGAN discriminates the authenticity of given samples and maximizes the mutual information between embeddings of real images from the same class. Simultaneously, the generator attempts to synthesize images to fool the discriminator and to maximize the mutual information of fake images from the same class prior. The experimental results show that ContraGAN is robust to network architecture selection and outperforms state-of-the-art models by 3.7% and 11.2% on CIFAR10 and Tiny ImageNet datasets, respectively, without any data augmentation. For the fair comparison, we re-implement nine state-of-the-art approaches to test various methods under the same condition. 6 | 7 | 8 | ## 1. Illustrative Figures of Conditional Contrastive Loss 9 |

10 | 11 | 12 | Illustrative figures visualize the metric learning losses (a, b, c) and conditional GANs (d, e, f). The objective of all losses is to pull samples together if they have the same label and to push them apart otherwise. The color indicates the class label, and the shape represents the role. (Square) an embedding of an image. (Diamond) an augmented embedding. (Circle) a reference. Each loss is applied to the reference. (Star) the embedding of a class label. (Triangle) the one-hot encoding of a class label. The thickness of the red and blue lines represents the strength of the pull and push forces, respectively. Compared to ACGAN and cGAN, our loss, inspired by NT-Xent, considers data-to-data relationships and infers full information without data augmentation. 13 | 14 | 15 | ## 2. Schematics of the discriminator of three conditional GANs 16 | ![Figure2](https://github.com/MINGUKKANG/Pytorch-GAN-Benchmark/blob/master/figures/conditional%20GAN.png) 17 | 18 | Schematics of the discriminators of three conditional GANs. (a) ACGAN has an auxiliary classifier to help the generator synthesize well-classifiable images. (b) cGAN improves on ACGAN by adding the inner product of an embedded image and the corresponding class embedding. (c) Our approach extends cGAN with a conditional contrastive loss (2C loss) between embedded images and the actual label embedding. ContraGAN considers multiple positive and negative pairs in the same batch, as shown in figure (f) above. 19 | 20 | ## 3. Results 21 | 22 | ### Quantitative Results 23 | **Table1:** Experiments using the CIFAR10 and Tiny ImageNet datasets. Using three backbone architectures (SNDCGAN, SNResGAN, and BigGAN), we test three approaches that condition on class information differently (ACGAN, cGAN, and ours). Mean ± variance of FID is reported. 24 |

25 | 26 | **Table2:** Comparison with state-of-the-art GAN models. FID values marked with "*" are those reported in the original papers (BigGAN and CRGAN). The other FID values are obtained from our implementation. 27 |

28 | 29 | ### Qualitative Results 30 | **Figure 1:** Examples of generated images using ContraGAN. (left) CIFAR10, FID: 10.322, (right) Tiny ImageNet, FID: 27.018. 31 |

32 | 33 | ## 4. Pre-trained model on Tiny ImageNet 34 | ``` 35 | https://drive.google.com/file/d/1XsouS_HIlES9CAshrtgYA3b6H2qUXTnx/view?usp=sharing 36 | ``` 37 | 38 | ## 5. How to run 39 | 40 | For CIFAR10 image generation tasks: 41 | 42 | ``` 43 | CUDA_VISIBLE_DEVICES=0 python3 main.py --eval -t -c "./configs/Table1/contra_biggan32_cifar_hinge_no.json" 44 | ``` 45 | 46 | For Tiny ImageNet image generation tasks: 47 | 48 | ``` 49 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py --eval -t -c "./configs/Table1/contra_biggan64_tiny_hinge_no.json" 50 | ``` 51 | 52 | To use the pre-trained model on Tiny ImageNet: 53 | ``` 54 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py -c "./configs/Table1/contra_biggan64_tiny_hinge_no.json" --checkpoint_folder "./checkpoints/contra_tiny_imagenet_pretrained" --step 50000 --eval -t 55 | ``` 56 |
https://raw.githubusercontent.com/val-iisc/class-balancing-gan/6c254c5fdf86155ad6c99e9e4668fc0de713da1f/docs/figures/conditional GAN.png -------------------------------------------------------------------------------- /docs/figures/generated images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/class-balancing-gan/6c254c5fdf86155ad6c99e9e4668fc0de713da1f/docs/figures/generated images.png -------------------------------------------------------------------------------- /docs/figures/metric learning loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/class-balancing-gan/6c254c5fdf86155ad6c99e9e4668fc0de713da1f/docs/figures/metric learning loss.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: studiogan 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - ca-certificates=2021.5.30=ha878542_0 10 | - certifi=2021.5.30=py37h89c1867_0 11 | - cudatoolkit=11.0.221=h6bb024c_0 12 | - cycler=0.10.0=py37_0 13 | - dbus=1.13.16=hb2f20db_0 14 | - expat=2.2.9=he6710b0_2 15 | - fontconfig=2.13.0=h9420a91_0 16 | - freetype=2.9.1=h8a8886c_1 17 | - glib=2.63.1=h5a9c865_0 18 | - gst-plugins-base=1.14.0=hbbd80ab_1 19 | - gstreamer=1.14.0=hb453b48_1 20 | - h5py=2.10.0=py37hd6299e0_1 21 | - hdf5=1.10.6=hb1b8bf9_0 22 | - icu=58.2=he6710b0_3 23 | - intel-openmp=2020.0=166 24 | - jpeg=9b=h024ee3a_2 25 | - kiwisolver=1.2.0=py37hfd86e86_0 26 | - kornia=0.5.5=pyhd8ed1ab_0 27 | - ld_impl_linux-64=2.33.1=h53a641e_7 28 | - libedit=3.1.20181209=hc058e9b_0 29 | - libffi=3.2.1=hd88cf55_4 30 | - libgcc-ng=9.1.0=hdf63c60_0 31 | - libgfortran-ng=7.3.0=hdf63c60_0 32 | - libpng=1.6.37=hbc83047_0 33 | - 
libstdcxx-ng=9.1.0=hdf63c60_0 34 | - libtiff=4.1.0=h2733197_0 35 | - libuuid=1.0.3=h1bed415_2 36 | - libuv=1.40.0=h7b6447c_0 37 | - libxcb=1.14=h7b6447c_0 38 | - libxml2=2.9.9=hea5a465_1 39 | - matplotlib=3.1.3=py37_0 40 | - matplotlib-base=3.1.3=py37hef1b27d_0 41 | - mkl=2020.0=166 42 | - mkl-include=2020.0=166 43 | - mkl-service=2.3.0=py37he904b0f_0 44 | - mkl_fft=1.0.15=py37ha843d7b_0 45 | - mkl_random=1.1.0=py37hd6b4f25_0 46 | - ncurses=6.2=he6710b0_0 47 | - ninja=1.9.0=py37hfd86e86_0 48 | - numpy=1.18.1=py37h4f9e942_0 49 | - numpy-base=1.18.1=py37hde5b4d6_1 50 | - olefile=0.46=py37_0 51 | - openssl=1.1.1k=h27cfd23_0 52 | - pcre=8.44=he6710b0_0 53 | - pillow=7.0.0=py37hb39fc2d_0 54 | - pip=21.0.1=py37h06a4308_0 55 | - pyparsing=2.4.7=py_0 56 | - pyqt=5.9.2=py37h05f1152_2 57 | - python=3.7.6=h0371630_2 58 | - python-dateutil=2.8.1=py_0 59 | - python_abi=3.7=2_cp37m 60 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 61 | - qt=5.9.7=h5867ecd_1 62 | - readline=7.0=h7b6447c_5 63 | - setuptools=45.2.0=py37_0 64 | - sip=4.19.8=py37hf484d3e_0 65 | - six=1.14.0=py37_0 66 | - sqlite=3.31.1=h7b6447c_0 67 | - tk=8.6.8=hbc83047_0 68 | - torchaudio=0.7.2=py37 69 | - torchvision=0.8.2=py37_cu110 70 | - tornado=6.0.4=py37h7b6447c_1 71 | - tqdm=4.47.0=py_0 72 | - typing_extensions=3.7.4.3=py_0 73 | - wheel=0.34.2=py37_0 74 | - xz=5.2.4=h14c3975_4 75 | - zlib=1.2.11=h7b6447c_3 76 | - zstd=1.3.7=h0b5b093_0 77 | - pip: 78 | - absl-py==0.12.0 79 | - cachetools==4.2.1 80 | - chardet==4.0.0 81 | - google-auth==1.28.0 82 | - google-auth-oauthlib==0.4.3 83 | - grpcio==1.36.1 84 | - idna==2.10 85 | - importlib-metadata==3.7.3 86 | - joblib==1.0.1 87 | - markdown==3.3.4 88 | - nfnets-pytorch==0.1.2 89 | - oauthlib==3.1.0 90 | - pandas==1.2.3 91 | - protobuf==3.15.6 92 | - pyasn1==0.4.8 93 | - pyasn1-modules==0.2.8 94 | - pytz==2021.1 95 | - requests==2.25.1 96 | - requests-oauthlib==1.3.0 97 | - rsa==4.7.2 98 | - scikit-learn==0.24.1 99 | - scipy==1.6.1 100 | - seaborn==0.11.1 101 | - 
sklearn==0.0 102 | - tensorboard==2.1.1 103 | - threadpoolctl==2.1.0 104 | - torchlars==0.1.2 105 | - urllib3==1.26.4 106 | - werkzeug==1.0.1 107 | - zipp==3.4.1 108 | prefix: /home/harsh/anaconda3/envs/studiogan 109 | -------------------------------------------------------------------------------- /figures/Table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/class-balancing-gan/6c254c5fdf86155ad6c99e9e4668fc0de713da1f/figures/Table3.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # main.py 6 | 7 | 8 | from argparse import ArgumentParser 9 | import json 10 | import os 11 | 12 | from make_hdf5 import make_hdf5 13 | from train import train_framework 14 | 15 | 16 | 17 | def main(): 18 | parser = ArgumentParser(add_help=False) 19 | parser.add_argument('-c', '--config_path', type=str, default='./configs/Table1/proj_biggan_cifar32_hinge_no.json') 20 | parser.add_argument('--checkpoint_folder', type=str, default=None) 21 | parser.add_argument('-current', '--load_current', action='store_true', help='choose whether you load current or est weights') 22 | parser.add_argument('--log_output_path', type=str, default=None) 23 | 24 | parser.add_argument('--seed', type=int, default=82624, help='seed for generating random number') 25 | parser.add_argument('--num_workers', type=int, default=16, help='') 26 | parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='select whether turn on synchronized batchnorm') 27 | parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='select whether turn on mixed precision training') 28 | 
parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode') 29 | parser.add_argument('-fz_op', '--fused_optimization', action='store_true', help='using fused optimization for faster training') 30 | 31 | parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset') 32 | parser.add_argument('-l', '--load_all_data_in_memory', action='store_true') 33 | parser.add_argument('-t', '--train', action='store_true') 34 | parser.add_argument('-e', '--eval', action='store_true') 35 | parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis') 36 | parser.add_argument('-itp', '--interpolation', action='store_true', help='select whether conduct interpolation analysis') 37 | parser.add_argument('-le', '--linear_evaluation', action='store_true', help='select whether conduct linear classification on the feature space') 38 | parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas') 39 | parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas') 40 | parser.add_argument('--step_linear_eval', type=int, default=10000, help='number of steps for the optimization') 41 | 42 | parser.add_argument('--print_every', type=int, default=100, help='control log interval') 43 | parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval') 44 | parser.add_argument('--update_every', type=int, default=2000, help='Update the batch size every') 45 | parser.add_argument('--type4eval_dataset', type=str, default='test', help='[train/valid/test]') 46 | 47 | 48 | 49 | args = parser.parse_args() 50 | 51 | if args.config_path is not None: 52 | with open(args.config_path) as f: 53 | model_config = json.load(f) 54 | train_config = vars(args) 55 | else: 56 | raise NotImplementedError 57 | 58 | dataset = 
model_config['data_processing']['dataset_name'] 59 | if dataset == 'cifar10': 60 | assert args.type4eval_dataset == 'train' or args.type4eval_dataset == 'test', "cifar10 does not contain dataset for validation" 61 | elif dataset == 'imagenet' or dataset == 'tiny_imagenet': 62 | assert args.type4eval_dataset == 'train' or args.type4eval_dataset == 'valid',\ 63 | "we do not support the evaluation mode using test images in tiny_imagenet/imagenet dataset" 64 | 65 | hdf5_path_train = make_hdf5(mode='train',**model_config['data_processing'], **train_config) if args.load_all_data_in_memory else None 66 | 67 | train_framework(**train_config, 68 | **model_config['data_processing'], 69 | **model_config['train']['model'], 70 | **model_config['train']['optimization'], 71 | **model_config['train']['loss_function'], 72 | **model_config['train']['initialization'], 73 | **model_config['train']['training_and_sampling_setting'], 74 | train_config=train_config, model_config=model_config['train'], hdf5_path_train=hdf5_path_train) 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /make_hdf5.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | """ 23 | 24 | 25 | from data_utils.load_dataset import LoadDataset 26 | 27 | import os 28 | import sys 29 | from argparse import ArgumentParser 30 | from tqdm import tqdm, trange 31 | import h5py as h5 32 | import numpy as np 33 | import PIL 34 | 35 | import torch 36 | import torchvision.transforms as transforms 37 | from torch.utils.data import DataLoader 38 | 39 | 40 | 41 | 42 | def make_hdf5(dataset_name, data_path, img_size, batch_size4prcsing, num_workers, chunk_size, compression, mode, **_): 43 | if 'hdf5' in dataset_name: 44 | raise ValueError('Reading from an HDF5 file which you will probably be ' 45 | 'about to overwrite! 
Override this error only if you know ' 46 | 'what you''re doing!') 47 | 48 | file_name = '{dataset_name}_{size}_{mode}.hdf5'.format(dataset_name=dataset_name, size=img_size, mode=mode) 49 | file_path = os.path.join(data_path, file_name) 50 | train = True if mode == "train" else False 51 | 52 | if os.path.isfile(file_path): 53 | print("{file_name} exist!\nThe file are located in the {file_path}".format(file_name=file_name, file_path=file_path)) 54 | else: 55 | dataset = LoadDataset(dataset_name, data_path, train=train, download=True, resize_size=img_size, 56 | conditional_strategy="no", hdf5_path=None, consistency_reg=False, random_flip=False) 57 | 58 | loader = DataLoader(dataset, 59 | batch_size=batch_size4prcsing, 60 | shuffle=False, 61 | pin_memory=False, 62 | num_workers=num_workers, 63 | drop_last=False) 64 | 65 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (dataset_name, chunk_size, compression)) 66 | # Loop over loader 67 | for i,(x,y) in enumerate(tqdm(loader)): 68 | # Numpyify x, y 69 | x = (255 * ((x + 1) / 2.0)).byte().numpy() 70 | y = y.numpy() 71 | # If we're on the first batch, prepare the hdf5 72 | if i==0: 73 | with h5.File(file_path, 'w') as f: 74 | print('Producing dataset of len %d' % len(loader.dataset)) 75 | imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8', maxshape=(len(loader.dataset), 3, img_size, img_size), 76 | chunks=(chunk_size, 3, img_size, img_size), compression=compression) 77 | print('Image chunks chosen as ' + str(imgs_dset.chunks)) 78 | imgs_dset[...] = x 79 | 80 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(loader.dataset),), 81 | chunks=(chunk_size,), compression=compression) 82 | print('Label chunks chosen as ' + str(labels_dset.chunks)) 83 | labels_dset[...] 
= y 84 | # Else append to the hdf5 85 | else: 86 | with h5.File(file_path, 'a') as f: 87 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) 88 | f['imgs'][-x.shape[0]:] = x 89 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) 90 | f['labels'][-y.shape[0]:] = y 91 | return file_path 92 | -------------------------------------------------------------------------------- /metrics/IS.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # metrics/IS.py 6 | 7 | 8 | from utils.sample import sample_latents 9 | from utils.losses import latent_optimise 10 | 11 | import torch 12 | from torch.nn import DataParallel 13 | 14 | import math 15 | from tqdm import tqdm 16 | 17 | 18 | 19 | class evaluator(object): 20 | def __init__(self,inception_model, device): 21 | self.inception_model = inception_model 22 | self.device = device 23 | 24 | 25 | def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size): 26 | if isinstance(gen, DataParallel): 27 | z_dim = gen.module.z_dim 28 | num_classes = gen.module.num_classes 29 | else: 30 | z_dim = gen.z_dim 31 | num_classes = gen.num_classes 32 | 33 | z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, self.device) 34 | 35 | if latent_op: 36 | z = latent_optimise(z, fake_labels, gen, dis, latent_op_step, 1.0, latent_op_alpha, latent_op_beta, False, self.device) 37 | 38 | with torch.no_grad(): 39 | batch_images = gen(z, fake_labels, evaluation=True) 40 | 41 | return batch_images 42 | 43 | 44 | def inception_softmax(self, batch_images): 45 | with torch.no_grad(): 46 | embeddings, logits = self.inception_model(batch_images) 47 | y = torch.nn.functional.softmax(logits, dim=1) 48 
| return y 49 | 50 | 51 | def kl_scores(self, ys, splits): 52 | scores = [] 53 | n_images = ys.shape[0] 54 | with torch.no_grad(): 55 | for j in range(splits): 56 | part = ys[(j*n_images//splits): ((j+1)*n_images//splits), :] 57 | kl = part * (torch.log(part) - torch.log(torch.unsqueeze(torch.mean(part, 0), 0))) 58 | kl = torch.mean(torch.sum(kl, 1)) 59 | kl = torch.exp(kl) 60 | scores.append(kl.unsqueeze(0)) 61 | scores = torch.cat(scores, 0) 62 | m_scores = torch.mean(scores).detach().cpu().numpy() 63 | m_std = torch.std(scores).detach().cpu().numpy() 64 | return m_scores, m_std 65 | 66 | 67 | def eval_gen(self, gen, dis, n_eval, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, 68 | latent_op_beta, split, batch_size): 69 | ys = [] 70 | n_batches = int(math.ceil(float(n_eval) / float(batch_size))) 71 | for i in tqdm(range(n_batches)): 72 | batch_images = self.generate_images(gen, dis, truncated_factor, prior, latent_op, latent_op_step, 73 | latent_op_alpha, latent_op_beta, batch_size) 74 | y = self.inception_softmax(batch_images) 75 | ys.append(y) 76 | 77 | with torch.no_grad(): 78 | ys = torch.cat(ys, 0) 79 | m_scores, m_std = self.kl_scores(ys[:n_eval], splits=split) 80 | return m_scores, m_std 81 | 82 | 83 | def eval_dataset(self, dataloader, splits): 84 | batch_size = dataloader.batch_size 85 | n_images = len(dataloader.dataset) 86 | n_batches = len(dataloader) #int(math.ceil(float(n_images)/float(batch_size))) 87 | dataset_iter = iter(dataloader) 88 | ys = [] 89 | for i in tqdm(range(n_batches)): 90 | feed_list = next(dataset_iter) 91 | batch_images, batch_labels = feed_list[0], feed_list[1] 92 | batch_images = batch_images.to(self.device) 93 | y = self.inception_softmax(batch_images) 94 | ys.append(y) 95 | 96 | with torch.no_grad(): 97 | ys = torch.cat(ys, 0) 98 | m_scores, m_std = self.kl_scores(ys, splits=splits) 99 | return m_scores, m_std 100 | 101 | 102 | def calculate_incep_score(dataloader, generator, discriminator, 
inception_model, n_generate, truncated_factor, prior, 103 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, splits, device): 104 | generator.eval() 105 | discriminator.eval() 106 | inception_model.eval() 107 | 108 | batch_size = dataloader.batch_size 109 | evaluator_instance = evaluator(inception_model, device=device) 110 | print("Calculating Inception Score....") 111 | kl_score, kl_std = evaluator_instance.eval_gen(generator, discriminator, n_generate, truncated_factor, prior, 112 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, splits, batch_size) 113 | generator.train() 114 | discriminator.train() 115 | return kl_score, kl_std 116 | -------------------------------------------------------------------------------- /metrics/prepare_inception_moments_eval_dataset.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # metrics/prepare_inception_moments_eval_dataset.py 6 | 7 | 8 | from metrics.FID import calculate_activation_statistics 9 | from metrics.IS import evaluator 10 | 11 | import numpy as np 12 | import os 13 | 14 | 15 | 16 | 17 | def prepare_inception_moments_eval_dataset(dataloader, eval_mode, generator, inception_model, splits, run_name, logger, device): 18 | dataset_name = dataloader.dataset.dataset_name 19 | inception_model.eval() 20 | 21 | save_path = os.path.abspath(os.path.join("./data", dataset_name + "_" + eval_mode +'_' + 'inception_moments.npz')) 22 | is_file = os.path.isfile(save_path) 23 | 24 | if is_file is True: 25 | mu = np.load(save_path)['mu'] 26 | sigma = np.load(save_path)['sigma'] 27 | else: 28 | logger.info('Calculate moments of {} dataset'.format(eval_mode)) 29 | mu, sigma = calculate_activation_statistics(data_loader=dataloader, 30 | generator=generator, 31 | 
discriminator=None, 32 | inception_model=inception_model, 33 | n_generate=None, 34 | truncated_factor=None, 35 | prior=None, 36 | is_generate=False, 37 | latent_op=False, 38 | latent_op_step=None, 39 | latent_op_alpha=None, 40 | latent_op_beta=None, 41 | device=device, 42 | tqdm_disable=False, 43 | run_name=run_name) 44 | 45 | logger.info('Saving calculated means and covariances to disk...') 46 | np.savez(save_path, **{'mu': mu, 'sigma': sigma}) 47 | 48 | logger.info('calculate inception score of {} dataset'.format(eval_mode)) 49 | evaluator_instance = evaluator(inception_model, device=device) 50 | is_score, is_std = evaluator_instance.eval_dataset(dataloader, splits=splits) 51 | logger.info('Inception score={is_score}-Inception_std={is_std}'.format(is_score=is_score, is_std=is_std)) 52 | return mu, sigma, is_score, is_std 53 | 54 | 55 | -------------------------------------------------------------------------------- /models/linear_classifier.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # models/linear_classifier.py 6 | 7 | 8 | from models.model_ops import * 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | 16 | class linear_classifier(nn.Module): 17 | def __init__(self, in_channels, num_classes): 18 | super(linear_classifier, self).__init__() 19 | self.linear = linear(in_features=in_channels, out_features=num_classes) 20 | 21 | def forward(self, x): 22 | out = self.linear(x) 23 | return out -------------------------------------------------------------------------------- /resnet_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet for CIFAR10 as described in paper [1]. 
3 | The implementation and structure of this file is hugely influenced by [2] 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web is copy-paste from 6 | torchvision's resnet and has wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 8 | number of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4m 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in you work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | from torch.nn import Parameter 29 | 30 | __all__ = ['ResNet_s', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 31 | 32 | def _weights_init(m): 33 | classname = m.__class__.__name__ 34 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 35 | init.kaiming_normal_(m.weight) 36 | 37 | class NormedLinear(nn.Module): 38 | 39 | def __init__(self, in_features, out_features): 40 | super(NormedLinear, self).__init__() 41 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 42 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 43 | 44 | def forward(self, x): 45 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 46 | return out 47 | 48 | class LambdaLayer(nn.Module): 49 | 50 | def __init__(self, lambd): 51 | super(LambdaLayer, self).__init__() 52 | self.lambd = lambd 53 | 54 | def forward(self, x): 55 | return 
self.lambd(x) 56 | 57 | 58 | class BasicBlock(nn.Module): 59 | expansion = 1 60 | 61 | def __init__(self, in_planes, planes, stride=1, option='A'): 62 | super(BasicBlock, self).__init__() 63 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | 68 | self.shortcut = nn.Sequential() 69 | if stride != 1 or in_planes != planes: 70 | if option == 'A': 71 | """ 72 | For CIFAR10 ResNet paper uses option A. 73 | """ 74 | self.shortcut = LambdaLayer(lambda x: 75 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 76 | elif option == 'B': 77 | self.shortcut = nn.Sequential( 78 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 79 | nn.BatchNorm2d(self.expansion * planes) 80 | ) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = self.bn2(self.conv2(out)) 85 | out += self.shortcut(x) 86 | out = F.relu(out) 87 | return out 88 | 89 | 90 | class ResNet_s(nn.Module): 91 | 92 | def __init__(self, block, num_blocks, num_classes=10, use_norm=False): 93 | super(ResNet_s, self).__init__() 94 | self.in_planes = 16 95 | 96 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 97 | self.bn1 = nn.BatchNorm2d(16) 98 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 99 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 100 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 101 | if use_norm: 102 | self.linear = NormedLinear(64, num_classes) 103 | else: 104 | self.linear = nn.Linear(64, num_classes) 105 | self.apply(_weights_init) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1]*(num_blocks-1) 109 | layers = [] 110 | for stride in strides: 111 | 
layers.append(block(self.in_planes, planes, stride)) 112 | self.in_planes = planes * block.expansion 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = F.avg_pool2d(out, out.size()[3]) 122 | out = out.view(out.size(0), -1) 123 | out = self.linear(out) 124 | return out 125 | 126 | 127 | def resnet20(): 128 | return ResNet_s(BasicBlock, [3, 3, 3]) 129 | 130 | 131 | def resnet32(num_classes=10, use_norm=False): 132 | return ResNet_s(BasicBlock, [5, 5, 5], num_classes=num_classes, use_norm=use_norm) 133 | 134 | 135 | def resnet44(): 136 | return ResNet_s(BasicBlock, [7, 7, 7]) 137 | 138 | 139 | def resnet56(): 140 | return ResNet_s(BasicBlock, [9, 9, 9]) 141 | 142 | 143 | def resnet110(): 144 | return ResNet_s(BasicBlock, [18, 18, 18]) 145 | 146 | 147 | def resnet1202(): 148 | return ResNet_s(BasicBlock, [200, 200, 200]) 149 | 150 | 151 | def test(net): 152 | import numpy as np 153 | total_params = 0 154 | 155 | for x in filter(lambda p: p.requires_grad, net.parameters()): 156 | total_params += np.prod(x.data.numpy().shape) 157 | print("Total number of params", total_params) 158 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 159 | 160 | 161 | if __name__ == "__main__": 162 | for net_name in __all__: 163 | if net_name.startswith('resnet'): 164 | print(net_name) 165 | test(globals()[net_name]()) 166 | print() -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : batchnorm_reimpl.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.init as init 39 | 40 | __all__ = ['BatchNorm2dReimpl'] 41 | 42 | 43 | class BatchNorm2dReimpl(nn.Module): 44 | """ 45 | A re-implementation of batch normalization, used for testing the numerical 46 | stability.
47 | 48 | Author: acgtyrant 49 | See also: 50 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 51 | """ 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1): # per-channel affine weight/bias plus running_mean/running_var buffers 53 | super().__init__() 54 | 55 | self.num_features = num_features 56 | self.eps = eps 57 | self.momentum = momentum 58 | self.weight = nn.Parameter(torch.empty(num_features)) 59 | self.bias = nn.Parameter(torch.empty(num_features)) 60 | self.register_buffer('running_mean', torch.zeros(num_features)) 61 | self.register_buffer('running_var', torch.ones(num_features)) 62 | self.reset_parameters() 63 | 64 | def reset_running_stats(self): 65 | self.running_mean.zero_() 66 | self.running_var.fill_(1) 67 | 68 | def reset_parameters(self): 69 | self.reset_running_stats() 70 | init.uniform_(self.weight) 71 | init.zeros_(self.bias) 72 | 73 | def forward(self, input_): # assumes a 4-D (N, C, H, W) input, as unpacked below; always normalizes with the current batch statistics (there is no eval-mode branch) 74 | batchsize, channels, height, width = input_.size() 75 | numel = batchsize * height * width 76 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 77 | sum_ = input_.sum(1) 78 | sum_of_square = input_.pow(2).sum(1) 79 | mean = sum_ / numel 80 | sumvar = sum_of_square - sum_ * mean # per-channel sum of squared deviations: sum(x^2) - n * mean^2 81 | 82 | self.running_mean = ( # exponential moving average, weight `momentum` on the new batch mean 83 | (1 - self.momentum) * self.running_mean 84 | + self.momentum * mean.detach() 85 | ) 86 | unbias_var = sumvar / (numel - 1) # unbiased variance feeds the running estimate; NOTE(review): divides by zero when numel == 1 87 | self.running_var = ( 88 | (1 - self.momentum) * self.running_var 89 | + self.momentum * unbias_var.detach() 90 | ) 91 | 92 | bias_var = sumvar / numel # biased variance is what normalizes the current batch 93 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 94 | output = ( 95 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 96 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 97 | 98 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() # back to (N, C, H, W) 99 | 100 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File
: comm.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import queue 37 | import collections 38 | import threading 39 | 40 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 41 | 42 | 43 | class FutureResult(object): 44 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 45 | 46 | def __init__(self): 47 | self._result = None 48 | self._lock = threading.Lock() 49 | self._cond = threading.Condition(self._lock) 50 | 51 | def put(self, result): 52 | with self._lock: 53 | assert self._result is None, 'Previous result has\'t been fetched.' 
54 | self._result = result 55 | self._cond.notify() 56 | 57 | def get(self): 58 | with self._lock: 59 | if self._result is None: 60 | self._cond.wait() 61 | 62 | res = self._result 63 | self._result = None 64 | return res 65 | 66 | 67 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 68 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 69 | 70 | 71 | class SlavePipe(_SlavePipeBase): 72 | """Pipe for master-slave communication.""" 73 | 74 | def run_slave(self, msg): 75 | self.queue.put((self.identifier, msg)) 76 | ret = self.result.get() 77 | self.queue.put(True) 78 | return ret 79 | 80 | 81 | class SyncMaster(object): 82 | """An abstract `SyncMaster` object. 83 | 84 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 85 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 86 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 87 | and passed to a registered callback. 88 | - After receiving the messages, the master device should gather the information and determine to message passed 89 | back to each slave devices. 90 | """ 91 | 92 | def __init__(self, master_callback): 93 | """ 94 | 95 | Args: 96 | master_callback: a callback to be invoked after having collected messages from slave devices. 97 | """ 98 | self._master_callback = master_callback 99 | self._queue = queue.Queue() 100 | self._registry = collections.OrderedDict() 101 | self._activated = False 102 | 103 | def __getstate__(self): 104 | return {'master_callback': self._master_callback} 105 | 106 | def __setstate__(self, state): 107 | self.__init__(state['master_callback']) 108 | 109 | def register_slave(self, identifier): 110 | """ 111 | Register an slave device. 112 | 113 | Args: 114 | identifier: an identifier, usually is the device id. 
115 | 116 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 117 | 118 | """ 119 | if self._activated: 120 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 121 | self._activated = False 122 | self._registry.clear() 123 | future = FutureResult() 124 | self._registry[identifier] = _MasterRegistry(future) 125 | return SlavePipe(identifier, self._queue, future) 126 | 127 | def run_master(self, master_msg): 128 | """ 129 | Main entry for the master device in each forward pass. 130 | The messages were first collected from each devices (including the master device), and then 131 | an callback will be invoked to compute the message to be sent back to each devices 132 | (including the master device). 133 | 134 | Args: 135 | master_msg: the message that the master want to send to itself. This will be placed as the first 136 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 137 | 138 | Returns: the message to be sent back to the master device. 139 | 140 | """ 141 | self._activated = True 142 | 143 | intermediates = [(0, master_msg)] 144 | for i in range(self.nr_slaves): 145 | intermediates.append(self._queue.get()) 146 | 147 | results = self._master_callback(intermediates) 148 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
149 | 150 | for i, res in results: 151 | if i == 0: 152 | continue 153 | self._registry[i].result.put(res) 154 | 155 | for i in range(self.nr_slaves): 156 | assert self._queue.get() is True 157 | 158 | return results[0][1] 159 | 160 | @property 161 | def nr_slaves(self): 162 | return len(self._registry) 163 | -------------------------------------------------------------------------------- /sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : replicate.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 
33 | """ 34 | 35 | 36 | import functools 37 | 38 | from torch.nn.parallel.data_parallel import DataParallel 39 | 40 | __all__ = [ 41 | 'CallbackContext', 42 | 'execute_replication_callbacks', 43 | 'DataParallelWithCallback', 44 | 'patch_replication_callback' 45 | ] 46 | 47 | 48 | class CallbackContext(object): # empty namespace; one instance is shared by all replicas of the same sub-module position 49 | pass 50 | 51 | 52 | def execute_replication_callbacks(modules): 53 | """ 54 | Execute a replication callback `__data_parallel_replicate__` on each module created by original replication. 55 | 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Note that, as all modules are isomorphic, we assign each sub-module with a context 59 | (shared among multiple copies of this module on different devices). 60 | Through this context, different copies can share some information. 61 | 62 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 63 | of any slave copies. 64 | """ 65 | master_copy = modules[0] 66 | nr_modules = len(list(master_copy.modules())) 67 | ctxs = [CallbackContext() for _ in range(nr_modules)] # one context per sub-module position j, reused by every copy i 68 | 69 | for i, module in enumerate(modules): # copy 0 (the master) is visited first, which gives the ordering guarantee above 70 | for j, m in enumerate(module.modules()): # relies on every copy enumerating its sub-modules in the same order 71 | if hasattr(m, '__data_parallel_replicate__'): 72 | m.__data_parallel_replicate__(ctxs[j], i) 73 | 74 | 75 | class DataParallelWithCallback(DataParallel): 76 | """ 77 | Data Parallel with a replication callback. 78 | 79 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 80 | original `replicate` function. 81 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 82 | 83 | Examples: 84 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 85 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 86 | # sync_bn.__data_parallel_replicate__ will be invoked.
87 | """ 88 | 89 | def replicate(self, module, device_ids): 90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | 95 | def patch_replication_callback(data_parallel): 96 | """ 97 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 98 | Useful when you have customized `DataParallel` implementation. 99 | 100 | Examples: 101 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 102 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 103 | > patch_replication_callback(sync_bn) 104 | # this is equivalent to 105 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 106 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 107 | """ 108 | 109 | assert isinstance(data_parallel, DataParallel) # NOTE(review): assert is skipped under python -O 110 | 111 | old_replicate = data_parallel.replicate 112 | 113 | @functools.wraps(old_replicate) 114 | def new_replicate(module, device_ids): 115 | modules = old_replicate(module, device_ids) 116 | execute_replication_callbacks(modules) 117 | return modules 118 | 119 | data_parallel.replicate = new_replicate # instance attribute shadows DataParallel.replicate for this object only 120 | -------------------------------------------------------------------------------- /sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : unittest.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License.
11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 
33 | """ 34 | 35 | 36 | import unittest 37 | import torch 38 | 39 | 40 | class TorchTestCase(unittest.TestCase): 41 | def assertTensorClose(self, x, y): # asserts torch.allclose(x, y) with default tolerances; the failure message reports max absolute and relative differences 42 | adiff = float((x - y).abs().max()) 43 | if (y == 0).all(): 44 | rdiff = 'NaN' 45 | else: 46 | rdiff = float((adiff / y).abs().max()) # relative to each element of y, so it is inf/nan when only some elements of y are 0 47 | 48 | message = ( 49 | 'Tensor close check failed\n' 50 | 'adiff={}\n' 51 | 'rdiff={}\n' 52 | ).format(adiff, rdiff) 53 | self.assertTrue(torch.allclose(x, y), message) 54 | 55 | -------------------------------------------------------------------------------- /utils/biggan_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE.
25 | """ 26 | 27 | 28 | import torch 29 | 30 | 31 | 32 | class ema_(object): 33 | def __init__(self, source, target, decay=0.9999, start_itr=0): 34 | self.source = source 35 | self.target = target 36 | self.decay = decay 37 | # Optional parameter indicating what iteration to start the decay at 38 | self.start_itr = start_itr 39 | # Initialize target's params to be source's 40 | self.source_dict = self.source.state_dict() 41 | self.target_dict = self.target.state_dict() 42 | print('Initializing EMA parameters to be source parameters...') 43 | with torch.no_grad(): 44 | for key in self.source_dict: 45 | self.target_dict[key].data.copy_(self.source_dict[key].data) 46 | # target_dict[key].data = source_dict[key].data # Doesn't work! 47 | 48 | def update(self, itr=None): 49 | # If an iteration counter is provided and itr is less than the start itr, 50 | # peg the ema weights to the underlying weights. 51 | if itr and itr < self.start_itr: 52 | decay = 0.0 53 | else: 54 | decay = self.decay 55 | with torch.no_grad(): 56 | for key in self.source_dict: 57 | self.target_dict[key].data.copy_(self.target_dict[key].data * decay + self.source_dict[key].data * (1 - decay)) 58 | 59 | 60 | def ortho(model, strength=1e-4, blacklist=[]): 61 | with torch.no_grad(): 62 | for param in model.parameters(): 63 | # Only apply this to parameters with at least 2 axes, and not in the blacklist 64 | if len(param.shape) < 2 or any([param is item for item in blacklist]): 65 | continue 66 | w = param.view(param.shape[0], -1) 67 | grad = (2 * torch.mm(torch.mm(w, w.t()) 68 | * (1. - torch.eye(w.shape[0], device=w.device)), w)) 69 | param.grad.data += strength * grad.view(param.shape) 70 | 71 | # Convenience utility to switch off requires_grad 72 | def toggle_grad(model, on_or_off): 73 | for param in model.parameters(): 74 | param.requires_grad = on_or_off 75 | 76 | 77 | # Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) 
78 | def interp(x0, x1, num_midpoints): 79 | lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) 80 | return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) -------------------------------------------------------------------------------- /utils/diff_aug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | """ 26 | 27 | 28 | import torch 29 | import torch.nn.functional as F 30 | 31 | 32 | 33 | ### Differentiable Augmentation for Data-Efficient GAN Training (https://arxiv.org/abs/2006.10738) 34 | ### Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 35 | ### https://github.com/mit-han-lab/data-efficient-gans 36 | 37 | 38 | def DiffAugment(x, policy='', channels_first=True): 39 | if policy: 40 | if not channels_first: 41 | x = x.permute(0, 3, 1, 2) 42 | for p in policy.split(','): 43 | for f in AUGMENT_FNS[p]: 44 | x = f(x) 45 | if not channels_first: 46 | x = x.permute(0, 2, 3, 1) 47 | x = x.contiguous() 48 | return x 49 | 50 | 51 | def rand_brightness(x): 52 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 53 | return x 54 | 55 | 56 | def rand_saturation(x): 57 | x_mean = x.mean(dim=1, keepdim=True) 58 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 59 | return x 60 | 61 | 62 | def rand_contrast(x): 63 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 64 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 65 | return x 66 | 67 | 68 | def rand_translation(x, ratio=0.125): 69 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 70 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 71 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 72 | grid_batch, grid_x, grid_y = torch.meshgrid( 73 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 74 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 75 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 76 | ) 77 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 78 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 79 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 80 | x = x_pad.permute(0, 2, 3, 
1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 81 | return x 82 | 83 | 84 | def rand_cutout(x, ratio=0.5): 85 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 86 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 87 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 88 | grid_batch, grid_x, grid_y = torch.meshgrid( 89 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 90 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 91 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 92 | ) 93 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 94 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 95 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 96 | mask[grid_batch, grid_x, grid_y] = 0 97 | x = x * mask.unsqueeze(1) 98 | return x 99 | 100 | 101 | AUGMENT_FNS = { 102 | 'color': [rand_brightness, rand_saturation, rand_contrast], 103 | 'translation': [rand_translation], 104 | 'cutout': [rand_cutout], 105 | } 106 | -------------------------------------------------------------------------------- /utils/icr.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # utils/icr.py 6 | 7 | 8 | import random 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | 15 | def ICR_Aug(x, flip=True, translation=True): 16 | if flip: 17 | x = random_flip(x, 0.5) 18 | if translation: 19 | x = random_translation(x, 1/8) 20 | if flip or translation: 21 | x = x.contiguous() 22 | return x 23 | 24 | 25 | def random_flip(x, p): 26 | x_out 
= x.clone() 27 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 28 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) 29 | flip_mask = flip_prob < p 30 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) 31 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) 32 | return x_out 33 | 34 | 35 | def random_translation(x, ratio): 36 | max_t_x, max_t_y = int(x.shape[2]*ratio), int(x.shape[3]*ratio) 37 | t_x = torch.randint(-max_t_x, max_t_x + 1, size = [x.shape[0], 1, 1], device=x.device) 38 | t_y = torch.randint(-max_t_y, max_t_y + 1, size = [x.shape[0], 1, 1], device=x.device) 39 | 40 | grid_batch, grid_x, grid_y = torch.meshgrid( 41 | torch.arange(x.shape[0], dtype=torch.long, device=x.device), 42 | torch.arange(x.shape[2], dtype=torch.long, device=x.device), 43 | torch.arange(x.shape[3], dtype=torch.long, device=x.device), 44 | ) 45 | 46 | grid_x = (grid_x + t_x) + max_t_x 47 | grid_y = (grid_y + t_y) + max_t_y 48 | x_pad = F.pad(input=x, pad=[max_t_x, max_t_x, max_t_y, max_t_y], mode='reflect') 49 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 50 | return x 51 | -------------------------------------------------------------------------------- /utils/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # utils/load_checkpoint.py 6 | 7 | 8 | import torch 9 | 10 | import os 11 | 12 | 13 | 14 | def load_checkpoint(model, optimizer, filename, metric=False, ema=False): 15 | # Note: Input model & optimizer should be pre-defined. This routine only updates their states. 
16 | start_step = 0 17 | if ema: 18 | checkpoint = torch.load(filename) 19 | model.load_state_dict(checkpoint['state_dict']) 20 | return model 21 | else: 22 | checkpoint = torch.load(filename) 23 | seed = checkpoint['seed'] 24 | run_name = checkpoint['run_name'] 25 | start_step = checkpoint['step'] 26 | model.load_state_dict(checkpoint['state_dict']) 27 | optimizer.load_state_dict(checkpoint['optimizer']) 28 | ada_p = checkpoint['ada_p'] 29 | for state in optimizer.state.values(): 30 | for k, v in state.items(): 31 | if isinstance(v, torch.Tensor): 32 | state[k] = v.cuda() 33 | 34 | if metric: 35 | best_step = checkpoint['best_step'] 36 | best_fid = checkpoint['best_fid'] 37 | best_fid_checkpoint_path = checkpoint['best_fid_checkpoint_path'] 38 | return model, optimizer, seed, run_name, start_step, ada_p, best_step, best_fid, best_fid_checkpoint_path 39 | return model, optimizer, seed, run_name, start_step, ada_p 40 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import dirname, abspath, exists, join 3 | import os 4 | import logging 5 | from datetime import datetime 6 | 7 | 8 | 9 | def make_run_name(format, framework, phase): 10 | return format.format( 11 | framework=framework, 12 | phase=phase, 13 | timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 14 | ) 15 | 16 | 17 | def make_logger(run_name, log_output): 18 | if log_output is not None: 19 | run_name = log_output.split('/')[-1].split('.')[0] 20 | logger = logging.getLogger(run_name) 21 | logger.propagate = False 22 | log_filepath = log_output if log_output is not None else join('logs', f'{run_name}.log') 23 | 24 | log_dir = dirname(abspath(log_filepath)) 25 | if not exists(log_dir): 26 | os.makedirs(log_dir) 27 | 28 | if not logger.handlers: # execute only if logger doesn't already exist 29 | file_handler = logging.FileHandler(log_filepath, 'a', 
'utf-8') 30 | stream_handler = logging.StreamHandler(os.sys.stdout) 31 | 32 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 33 | 34 | file_handler.setFormatter(formatter) 35 | stream_handler.setFormatter(formatter) 36 | 37 | logger.addHandler(file_handler) 38 | logger.addHandler(stream_handler) 39 | logger.setLevel(logging.INFO) 40 | return logger 41 | 42 | 43 | def make_checkpoint_dir(checkpoint_dir, run_name): 44 | checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name) 45 | if not exists(abspath(checkpoint_dir)): 46 | os.makedirs(checkpoint_dir) 47 | return checkpoint_dir 48 | -------------------------------------------------------------------------------- /utils/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /utils/op/fused_act.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | import os 27 | 28 | import torch 29 | from torch import nn 30 | from torch.nn import functional as F 31 | from torch.autograd import Function 32 | from torch.utils.cpp_extension import load 33 | 34 | 35 | module_path = os.path.dirname(__file__) 36 | fused = load( 37 | "fused", 38 | sources=[ 39 | os.path.join(module_path, "fused_bias_act.cpp"), 40 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 41 | ], 42 | ) 43 | 44 | 45 | class FusedLeakyReLUFunctionBackward(Function): 46 | @staticmethod 47 | def forward(ctx, grad_output, out, negative_slope, scale): 48 | ctx.save_for_backward(out) 49 | ctx.negative_slope = negative_slope 50 | ctx.scale = scale 51 | 52 | empty = grad_output.new_empty(0) 53 | 54 | grad_input = fused.fused_bias_act( 55 | grad_output, empty, out, 3, 1, negative_slope, scale 56 | ) 57 | 58 | dim = [0] 59 | 60 | if grad_input.ndim > 2: 61 | dim += list(range(2, grad_input.ndim)) 62 | 63 | grad_bias = grad_input.sum(dim).detach() 64 | 65 | return grad_input, grad_bias 66 | 67 | @staticmethod 68 | def backward(ctx, gradgrad_input, gradgrad_bias): 69 | out, = ctx.saved_tensors 70 | gradgrad_out = fused.fused_bias_act( 71 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 72 | ) 73 | 74 | return gradgrad_out, None, None, None 75 | 76 | 77 | class FusedLeakyReLUFunction(Function): 78 | @staticmethod 79 | def forward(ctx, input, bias, negative_slope, scale): 80 | empty = input.new_empty(0) 81 | out = 
fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 82 | ctx.save_for_backward(out) 83 | ctx.negative_slope = negative_slope 84 | ctx.scale = scale 85 | 86 | return out 87 | 88 | @staticmethod 89 | def backward(ctx, grad_output): 90 | out, = ctx.saved_tensors 91 | 92 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 93 | grad_output, out, ctx.negative_slope, ctx.scale 94 | ) 95 | 96 | return grad_input, grad_bias, None, None 97 | 98 | 99 | class FusedLeakyReLU(nn.Module): 100 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 101 | super().__init__() 102 | 103 | self.bias = nn.Parameter(torch.zeros(channel)) 104 | self.negative_slope = negative_slope 105 | self.scale = scale 106 | 107 | def forward(self, input): 108 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 109 | 110 | 111 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 112 | if input.device.type == "cpu": 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 123 | -------------------------------------------------------------------------------- /utils/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | 
The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | */ 24 | 25 | 26 | #include 27 | 28 | 29 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 30 | int act, int grad, float alpha, float scale); 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 37 | int act, int grad, float alpha, float scale) { 38 | CHECK_CUDA(input); 39 | CHECK_CUDA(bias); 40 | 41 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 42 | } 43 | 44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 45 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 46 | } 47 | -------------------------------------------------------------------------------- /utils/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } 100 | -------------------------------------------------------------------------------- /utils/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | */ 24 | 25 | 26 | #include 27 | 28 | 29 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 30 | int up_x, int up_y, int down_x, int down_y, 31 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 32 | 33 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 34 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 35 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 36 | 37 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 38 | int up_x, int up_y, int down_x, int down_y, 39 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 40 | CHECK_CUDA(input); 41 | CHECK_CUDA(kernel); 42 | 43 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 44 | } 45 | 46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 47 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 48 | } 49 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # utils/plot.py 6 | 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from os.path import dirname, abspath, exists 11 | import os 12 | 13 | from torchvision.utils import save_image 14 | 15 | 16 | 17 | def plot_img_canvas(images, save_path, logger, nrow): 18 | directory = dirname(save_path) 19 | 20 | if not exists(abspath(directory)): 21 | 
os.makedirs(directory) 22 | 23 | save_image(images, save_path, padding=0, nrow=nrow) 24 | logger.info("Saved image to {}".format(save_path)) 25 | 26 | 27 | def plot_confidence_histogram(confidence, labels, save_path, logger): 28 | directory = dirname(save_path) 29 | 30 | if not exists(abspath(directory)): 31 | os.makedirs(directory) 32 | 33 | f, ax = plt.subplots(1,1) 34 | real_confidence = confidence[labels==1.0] 35 | gen_confidence = confidence[labels!=1.0] 36 | plt.hist([real_confidence, gen_confidence], 20, density=True, alpha=0.5, color=['red','blue'], label= ['Real samples', 'Generated samples']) 37 | plt.legend(loc='upper right') 38 | ax.set_title('Confidence Histogram', fontsize=15) 39 | ax.set_xlabel('Confidence') 40 | ax.set_ylabel('Density') 41 | 42 | plt.savefig(save_path, dpi=1000) 43 | logger.info("Saved image to {}".format(save_path)) 44 | plt.close() 45 | 46 | 47 | def discrete_cmap(base_cmap, num_classes): 48 | base = plt.cm.get_cmap(base_cmap) 49 | color_list = base(np.linspace(0,1,num_classes)) 50 | cmap_name = base.name + str(num_classes) 51 | return base.from_list(cmap_name,color_list,num_classes) 52 | 53 | def plot_2d_scatter(x0,x1, num_classes, labels, file_name): 54 | plt.figure(figsize = (8,6)) 55 | plt.scatter(x0, x1, c = labels, marker ='.', edgecolor = 'none', cmap = discrete_cmap('jet', num_classes), alpha=0.5) 56 | plt.colorbar() 57 | plt.grid() 58 | # plt.xlim((0.0, 2.0)) 59 | # plt.ylim((0.0, 2.0)) 60 | if not exists(abspath('./experimetns')): 61 | os.makedirs('./experimetns') 62 | plt.savefig('./experimetns/' + file_name) 63 | plt.close() 64 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # 
utils/utils.py 6 | 7 | 8 | import numpy as np 9 | import random 10 | import os 11 | from datetime import datetime 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.nn import DataParallel 16 | 17 | #Loading model Weights 18 | def load_model_weights(model, checkpoint_path): 19 | checkpoint = torch.load(checkpoint_path, map_location='cuda:0')['state_dict'] 20 | for key in list(checkpoint.keys()): 21 | if 'module.' in key: 22 | checkpoint[key.replace('module.', '')] = checkpoint[key] 23 | del checkpoint[key] 24 | 25 | model.load_state_dict(checkpoint) 26 | 27 | return model 28 | 29 | 30 | # fix python, numpy, torch seed 31 | def fix_all_seed(seed): 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | torch.cuda.manual_seed(seed) 37 | 38 | def count_parameters(module): 39 | return 'Number of parameters: {}'.format(sum([p.data.nelement() for p in module.parameters()])) 40 | 41 | def elapsed_time(start_time): 42 | now = datetime.now() 43 | elapsed = now - start_time 44 | return str(elapsed).split('.')[0] # remove milliseconds 45 | 46 | def reshape_weight_to_matrix(weight): 47 | weight_mat = weight 48 | dim =0 49 | if dim != 0: 50 | # permute dim to front 51 | weight_mat = weight_mat.permute(dim, *[d for d in range(weight_mat.dim()) if d != dim]) 52 | height = weight_mat.size(0) 53 | return weight_mat.reshape(height, -1) 54 | 55 | 56 | def find_string(list_, string): 57 | for i, s in enumerate(list_): 58 | if string == s: 59 | return i 60 | 61 | def find_and_remove(path): 62 | if os.path.isfile(path): 63 | os.remove(path) 64 | 65 | def calculate_all_sn(model): 66 | sigmas = {} 67 | with torch.no_grad(): 68 | for name, param in model.named_parameters(): 69 | if "weight" in name and "bn" not in name and "shared" not in name and "deconv" not in name: 70 | if "blocks" in name: 71 | splited_name = name.split('.') 72 | idx = find_string(splited_name, 'blocks') 73 | block_idx = 
int(splited_name[int(idx+1)]) 74 | module_idx = int(splited_name[int(idx+2)]) 75 | operation_name = splited_name[idx+3] 76 | if isinstance(model, DataParallel): 77 | operations = model.module.blocks[block_idx][module_idx] 78 | else: 79 | operations = model.blocks[block_idx][module_idx] 80 | operation = getattr(operations, operation_name) 81 | else: 82 | splited_name = name.split('.') 83 | idx = find_string(splited_name, 'module') if isinstance(model, DataParallel) else -1 84 | operation_name = splited_name[idx+1] 85 | if isinstance(model, DataParallel): 86 | operation = getattr(model.module, operation_name) 87 | else: 88 | operation = getattr(model, operation_name) 89 | 90 | weight_orig = reshape_weight_to_matrix(operation.weight_orig) 91 | weight_u = operation.weight_u 92 | weight_v = operation.weight_v 93 | sigmas[name] = torch.dot(weight_u, torch.mv(weight_orig, weight_v)) 94 | return sigmas 95 | 96 | 97 | def load_model_weights(model, checkpoint_path): 98 | checkpoint = torch.load(checkpoint_path, map_location='cuda:0')['state_dict'] 99 | for key in list(checkpoint.keys()): 100 | if 'module.' in key: 101 | checkpoint[key.replace('module.', '')] = checkpoint[key] 102 | del checkpoint[key] 103 | 104 | model.load_state_dict(checkpoint) 105 | 106 | return model 107 | -------------------------------------------------------------------------------- /weight_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def KL(P,Q): 6 | 7 | """ Epsilon is used here to avoid conditional code for 8 | checking that neither P nor Q is equal to 0. """ 9 | epsilon = 0.00001 10 | 11 | # You may want to instead make copies to avoid changing the np arrays. 
12 | P = P+epsilon 13 | Q = Q+epsilon 14 | 15 | divergence = np.sum(P*np.log(P/Q)) 16 | return divergence 17 | 18 | 19 | 20 | class WeightRegularizer(): 21 | 22 | def __init__(self, num_classes, pretrained_model, init_param = 1, decay_weight = 0.5, beta = 0.99, number_sample = 10, mode="effective", init_size=512, class_req_perf = -1, actual_perf = 90, convex_comb = False): 23 | """Weight Regularizer is the regularizer based on output of a classifier. It 24 | causes the GAN to avoid mode collapse for the imbalanced classes. 25 | """ 26 | 27 | self.count_class_samples = [50] * num_classes 28 | self.decay_weight = decay_weight 29 | self.beta = beta 30 | self.num_classes = num_classes 31 | self.stats = [] 32 | self.number_to_sample = number_sample 33 | self.mode = mode 34 | self.pretrained_model = pretrained_model 35 | self.pred_class = [init_size] * self.num_classes 36 | self.convex_comb = convex_comb 37 | 38 | if self.mode == "effective": 39 | effective_num = 1.0 - np.power(self.beta, self.count_class_samples) 40 | weights = (1.0 - self.beta) / np.array(effective_num) 41 | else: 42 | weights = 1 / np.array(self.count_class_samples) 43 | 44 | self.weights = weights / np.sum(weights) * self.num_classes 45 | self.ce_loss = torch.nn.CrossEntropyLoss() 46 | 47 | # For decreasing classifier accuracy for ablations 48 | if class_req_perf >= 0 and actual_perf >= class_req_perf: 49 | self.random_ratio = (actual_perf - class_req_perf) /(actual_perf - 100/num_classes) 50 | else: 51 | self.random_ratio = None 52 | 53 | 54 | 55 | 56 | self.i = 0 57 | 58 | def update(self, logger = None): 59 | """Update the effective class statistics. 
60 | """ 61 | 62 | 63 | 64 | stats, kl_div = self.get_stats() 65 | if logger != None: 66 | logger.info("Mean Number of Samples %f Variance of Number of Samples %f KL Divergence of the Samples %f ."%(np.mean(stats), np.std(stats), kl_div)) 67 | 68 | if self.convex_comb: 69 | factor = 1 - self.decay_weight 70 | else: 71 | factor = 1 72 | 73 | self.count_class_samples = [self.decay_weight * i for i in self.count_class_samples] 74 | for i in range(self.num_classes): 75 | self.count_class_samples[i] += (factor) * self.pred_class[i] 76 | 77 | self.reset_stats() 78 | # Clamp the values to one for values < 1 79 | self.count_class_samples = [max(1, i) for i in self.count_class_samples] 80 | #logger.info("Updated Cumulative Samples:" + str(self.count_class_samples)) 81 | 82 | if self.mode == "effective": 83 | effective_num = 1.0 - np.power(self.beta, self.count_class_samples) 84 | weights = (1.0 - self.beta) / np.array(effective_num) 85 | else: 86 | weights = 1 / np.array(self.count_class_samples) 87 | 88 | self.weights = weights / np.sum(weights) * self.num_classes 89 | 90 | 91 | def log_and_print(self, count_class_samples, writer = None): 92 | """ Print the log of statistics of generated image classes""" 93 | 94 | if len(self.stats) == self.number_to_sample: 95 | self.stats.pop(0) 96 | 97 | print("Max and Min of class samples",max(count_class_samples), min(count_class_samples)) 98 | self.stats.append(count_class_samples) 99 | 100 | stats = np.mean(np.array(self.stats), axis=0) 101 | stats = stats/np.sum(stats) 102 | print("Mean Number of Samples %f Standard Deviation of Number of Samples %f"%(np.mean(stats), np.std(stats))) 103 | 104 | if writer is not None: 105 | writer.add_scalar("12. 
Variance of Samples in Different Classes", np.std(stats)) 106 | 107 | def reset_stats(self): 108 | """[Reset the stats produced by classifier] 109 | """ 110 | self.pred_class = [0] * self.num_classes 111 | 112 | 113 | def get_stats(self): 114 | stats = np.array(self.pred_class) 115 | stats = stats/np.sum(stats) 116 | uniform = np.mean(stats)*np.ones(len(stats)) 117 | kl_divergence = KL(stats, uniform) 118 | 119 | return stats, kl_divergence 120 | 121 | 122 | 123 | 124 | 125 | 126 | def loss(self, input_images, labels=False, writer = None, target_labels = None): 127 | """This calculates the loss of the regularizer 128 | 129 | Args: 130 | softmax_output ([Tensor]): [Softmax output of the given batch] 131 | """ 132 | temperature = 5 * 1e-1 133 | with torch.cuda.amp.autocast(): 134 | if self.pretrained_model.__class__.__name__ == "ResNetV2": 135 | output = self.pretrained_model(F.interpolate(input_images, mode='bilinear', size=(128, 128))) 136 | softmax_output = torch.softmax(output, dim = 1) 137 | else: 138 | 139 | output = self.pretrained_model(input_images) 140 | softmax_output = torch.softmax(output, dim = 1) 141 | 142 | 143 | 144 | pred_class_max = torch.argmax(softmax_output, dim = 1).cpu() 145 | 146 | if self.random_ratio is not None: 147 | # logic for random sampling 148 | random_mask = torch.rand((input_images.shape[0], )).le(self.random_ratio) 149 | 150 | random_labels = torch.randint(0, self.num_classes, (input_images.shape[0], )) 151 | pred_class_max = (~random_mask) * pred_class_max + random_mask * random_labels 152 | 153 | pred_class = [(pred_class_max == i).sum().item() for i in range(self.num_classes)] 154 | 155 | ### For debugging 156 | if writer is not None: 157 | pred_class_prob = np.array(pred_class)/np.sum(pred_class) 158 | count_class_prob = np.array(self.count_class_samples)/np.sum(self.count_class_samples) 159 | writer.add_histogram('expected', count_class_prob, self.i) 160 | writer.add_histogram('predicted', pred_class_prob, self.i) 161 | 
self.i += 1 162 | pred_class_dict = [(k,v) for v, k in enumerate(pred_class)] 163 | pred_class_dict = sorted(pred_class_dict, reverse=True) 164 | 165 | self.pred_class = [ i + j for i,j in zip(self.pred_class, pred_class)] 166 | 167 | 168 | 169 | sm_batch_mean = torch.from_numpy(self.weights).float().to(device = softmax_output.device)* torch.mean(softmax_output,dim=0) 170 | div_loss = torch.sum(sm_batch_mean*torch.log(torch.mean(softmax_output,dim=0))) 171 | 172 | 173 | if labels: 174 | return div_loss, softmax_output 175 | 176 | return div_loss 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | --------------------------------------------------------------------------------