├── .gitignore ├── LICENSE ├── LICENSE-NVIDIA ├── README.md ├── environment.yml ├── images ├── ECGAN.png └── imagenet.png ├── logs ├── CIFAR10 │ └── ecgan_v2_none_1_0p01-train-2021_05_26_16_35_45.log ├── ILSVRC2012 │ └── imagenet_ecgan_v2_contra_1_0p05-train-2021_10_03_00_11_58.log └── TINY_ILSVRC2012 │ └── ecgan_v2_none_1_0p05-train-2021_05_26_16_47_55.log └── src ├── configs ├── CIFAR10 │ └── ecgan_v2_none_1_0p01.json ├── ILSVRC2012 │ └── imagenet_ecgan_v2_contra_1_0p05.json └── TINY_ILSVRC2012 │ └── ecgan_v2_none_1_0p05.json ├── data_utils └── load_dataset.py ├── inception_tf13.py ├── loader.py ├── main.py ├── metrics ├── Accuracy.py ├── CAS.py ├── DCA.py ├── FID.py ├── F_beta.py ├── IS.py ├── IntraFID.py ├── __init__.py ├── inception_network.py └── prepare_inception_moments.py ├── models ├── big_resnet.py ├── big_resnet_deep.py ├── dcgan.py └── resnet.py ├── sync_batchnorm ├── batchnorm.py ├── batchnorm_reimpl.py ├── comm.py ├── replicate.py └── unittest.py ├── utils ├── ada.py ├── ada_op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── biggan_utils.py ├── cr_diff_aug.py ├── diff_aug.py ├── load_checkpoint.py ├── log.py ├── losses.py ├── make_hdf5.py ├── misc.py ├── model_ops.py └── sample.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.pyc 3 | *.DS_Store 4 | media/ 5 | res/ 6 | 7 | data/* 8 | checkpoints/* 9 | figures/* 10 | *.xlsx 11 | src/data/* 12 | src/checkpoints/* 13 | src/figures/* 14 | src/logs/* 15 | *test* 16 | src/unpublished/* 17 | 18 | # Swap 19 | [._]*.s[a-v][a-z] 20 | !*.svg # comment out if you don't need vector files 21 | [._]*.sw[a-p] 22 | [._]s[a-rt-v][a-z] 23 | [._]ss[a-gi-z] 24 | [._]sw[a-p] 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | PyTorch StudioGAN: 4 | Copyright (c) 2020 MinGuk Kang 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. 
This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Energy-based Conditional Generative Adversarial Network (ECGAN) 2 | 3 | This is the code for the NeurIPS 2021 paper "[A Unified View of cGANs with and without Classifiers](https://arxiv.org/abs/2111.01035)". The repository is modified from [StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN). If you find our work useful, please consider citing the following paper: 4 | ```bib 5 | @inproceedings{chen2021ECGAN, 6 | title = {A Unified View of cGANs with and without Classifiers}, 7 | author = {Si-An Chen and Chun-Liang Li and Hsuan-Tien Lin}, 8 | booktitle = {Advances in Neural Information Processing Systems}, 9 | year = {2021} 10 | } 11 | ``` 12 | Please feel free to contact [Si-An Chen](https://scholar.google.com/citations?hl=en&user=XtkmEncAAAAJ) if you have any questions about the code/paper. 13 | 14 | ## Introduction 15 | We propose a new Conditional Generative Adversarial Network (cGAN) framework called Energy-based Conditional Generative Adversarial Network (ECGAN) which provides a unified view of cGANs and achieves state-of-the-art results. We use the decomposition of the joint probability distribution to connect the goals of cGANs and classification as a unified framework. The framework, along with a classic energy model to parameterize distributions, justifies the use of classifiers for cGANs in a principled manner. It explains several popular cGAN variants, such as ACGAN, ProjGAN, and ContraGAN, as special cases with different levels of approximations. An illustration of the framework is shown below. 16 |
![Overview of the ECGAN framework](images/ECGAN.png)
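
The decomposition can be made concrete in a few lines of PyTorch. The sketch below is only an illustration of the idea, not code from this repository (the actual heads and losses live in `src/models/` and `src/utils/losses.py`, and the helper name `joint_log_probs` is made up here): once the discriminator assigns a per-class score f_y(x) to an image and p(x, y) is parameterized as an energy model proportional to exp(f_y(x)), the unconditional term log p(x) reduces to a logsumexp over the class scores and the conditional term log p(y|x) to an ordinary softmax classifier, so one set of logits serves both the adversarial and the classification objectives.

```
import torch
import torch.nn.functional as F

def joint_log_probs(logits, labels):
    """Split per-class scores f_y(x) into the two terms ECGAN couples.

    With p(x, y) proportional to exp(f_y(x)):
      log p(x)     ~ logsumexp_y f_y(x)    (unconditional / energy term)
      log p(y | x) = log_softmax(f(x))_y   (classification term)
    and log p(x, y) = log p(x) + log p(y | x) up to a shared constant.
    """
    log_px = torch.logsumexp(logits, dim=1)
    log_py_given_x = F.log_softmax(logits, dim=1).gather(1, labels[:, None]).squeeze(1)
    return log_px, log_py_given_x

# Toy check: the two terms recombine into the selected logit f_y(x).
logits = torch.randn(4, 10)              # a batch of 4 "images", 10 classes
labels = torch.randint(0, 10, (4,))
log_px, log_py_given_x = joint_log_probs(logits, labels)
assert torch.allclose(log_px + log_py_given_x,
                      logits.gather(1, labels[:, None]).squeeze(1))
```

In the repository these two roles show up as the separate classification and unconditional outputs returned by the ECGAN discriminator (see `src/metrics/DCA.py`), weighted by the config keys `cls_disc_lambda` and `uncond_lambda`.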
19 | 20 | 21 | ## Requirements 22 | 23 | - Anaconda 24 | - Python >= 3.6 25 | - 6.0.0 <= Pillow <= 7.0.0 26 | - scipy == 1.1.0 (Recommended for fast loading of [Inception Network](https://github.com/openai/improved-gan/blob/master/inception_score/model.py)) 27 | - sklearn 28 | - seaborn 29 | - h5py 30 | - tqdm 31 | - torch >= 1.6.0 (Recommended for mixed precision training and knn analysis) 32 | - torchvision >= 0.7.0 33 | - tensorboard 34 | - 5.4.0 <= gcc <= 7.4.0 (Recommended for proper use of [adaptive discriminator augmentation module](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/tree/master/src/utils/ada_op)) 35 | 36 | 37 | You can install the recommended environment as follows: 38 | 39 | ``` 40 | conda env create -f environment.yml -n studiogan 41 | ``` 42 | 43 | With docker, you can use: 44 | ``` 45 | docker pull mgkang/studiogan:0.1 46 | ``` 47 | 48 | ## Quick Start 49 | 50 | * Train (``-t``) and evaluate (``-e``) the model defined in ``CONFIG_PATH`` using GPU ``0`` 51 | ``` 52 | CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH 53 | ``` 54 | 55 | * Train (``-t``) and evaluate (``-e``) the model defined in ``CONFIG_PATH`` using GPUs ``(0, 1, 2, 3)`` and ``DataParallel`` 56 | ``` 57 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH 58 | ``` 59 | 60 | Try ``python3 src/main.py`` to see available options. 61 | 62 | ## Dataset 63 | 64 | * CIFAR10: StudioGAN will automatically download the dataset once you execute ``main.py``. 65 | 66 | * Tiny Imagenet, Imagenet, or a custom dataset: 67 | 1. download [Tiny Imagenet](https://tiny-imagenet.herokuapp.com) and [Imagenet](http://www.image-net.org). Prepare your own dataset. 68 | 2. make the folder structure of the dataset as follows: 69 | 70 | ``` 71 | ┌── docs 72 | ├── src 73 | └── data 74 | └── ILSVRC2012 or TINY_ILSVRC2012 or CUSTOM 75 |    ├── train 76 |     │   ├── cls0 77 | │ │ ├── train0.png 78 | │ │ ├── train1.png 79 | │ │ └── ... 80 | │ ├── cls1 81 | │ └── ... 82 | └── valid 83 | ├── cls0 84 | │ ├── valid0.png 85 | │ ├── valid1.png 86 | │ └── ... 87 | ├── cls1 88 | └── ... 89 | ``` 90 | 91 | ## Examples and Results 92 | The ``src/configs`` directory contains config files used in our experiments. 
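Each config bundles the data, model, optimization, and loss settings for one run; the ECGAN-specific weights sit under `train.loss_function`, and the file names appear to encode the contrastive term (`none`/`contra`), `uncond_lambda`, and `cls_disc_lambda` (`0p01` = 0.01). The snippet below is not part of the repository's tooling; it simply loads one of the bundled configs and prints those knobs:

```
import json

# Inspect one of the configs shipped under src/configs.
with open("src/configs/CIFAR10/ecgan_v2_none_1_0p01.json") as f:
    cfg = json.load(f)

loss_cfg = cfg["train"]["loss_function"]
print(cfg["train"]["model"]["conditional_strategy"])   # ECGAN
print(loss_cfg["uncond_lambda"],                       # 1
      loss_cfg["cls_disc_lambda"],                     # 0.01
      loss_cfg["contrastive_lambda"])                  # 0.0
```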
93 | 94 | ### CIFAR10 (3x32x32) 95 | To train and evaluate ECGAN-UC on CIFAR10: 96 | ``` 97 | python3 src/main.py -t -e -c src/configs/CIFAR10/ecgan_v2_none_1_0p01.json 98 | ``` 99 | 100 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights | 101 | |---|---|---|---|---|---|---|---|---| 102 | | BigGAN-Mod | StudioGAN | 9.746 | 8.034 | 0.995 | 0.994 | - | - | - | 103 | | ContraGAN | StudioGAN | 9.729 | 8.065 | 0.993 | 0.992 | - | - | - | 104 | | Ours | - | **10.078** | **7.936** | 0.990 | 0.988 | [Cfg](./src/configs/CIFAR10/ecgan_v2_none_1_0p01.json) | [Log](./logs/CIFAR10/ecgan_v2_none_1_0p01-train-2021_05_26_16_35_45.log) | [Link](https://drive.google.com/drive/folders/1Kig2Loo2Ds5N3Pqc85R6c46Hbx5n9heM?usp=sharing) | 105 | 106 | ### Tiny ImageNet (3x64x64) 107 | To train and evaluate ECGAN-UC on Tiny ImageNet: 108 | ``` 109 | python3 src/main.py -t -e -c src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json --eval_type valid 110 | ``` 111 | 112 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights | 113 | |---|---|---|---|---|---|---|---|---| 114 | | BigGAN-Mod | StudioGAN | 11.998 | 31.92 | 0.956 | 0.879 | - | - | - | 115 | | ContraGAN | StudioGAN | 13.494 | 27.027 | 0.975 | 0.902 | - | - | - | 116 | | Ours | - | **18.445** | **18.319** | **0.977** | **0.973** | [Cfg](./src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json) | [Log](./logs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05-train-2021_05_26_16_47_55.log) | [Link](https://drive.google.com/drive/folders/1oVAIljTEIA3b0BHRVjcnukMf3POQQ3rw?usp=sharing) | 117 | 118 | ### ImageNet (3x128x128) 119 | To train and evaluate ECGAN-UCE on ImageNet (~12 days on 8 NVIDIA V100 GPUs): 120 | ``` 121 | python3 src/main.py -t -e -l -sync_bn -c src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json --eval_type valid 122 | ``` 123 | 124 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights | 125 | |---|---|---|---|---|---|---|---|---| 126 | | BigGAN | StudioGAN | 28.633 | 24.684 | 0.941 | 0.921 | - | - | - | 127 | | ContraGAN | StudioGAN | 25.249 | 25.161 | 0.947 | 0.855 | - | - | - | 128 | | Ours | - | **80.685** | **8.491** | **0.984** | **0.985** | [Cfg](./src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json) | [Log](./logs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05-train-2021_10_03_00_11_58.log) | [Link](https://drive.google.com/drive/folders/1EkcotNsnA-KBvOCFkvpJpUVoSDRxk-EV?usp=sharing) | 129 | 130 | 131 | ## Generated Images 132 | Here are some selected images generated by ECGAN. 133 |
![Selected images generated by ECGAN on ImageNet](images/imagenet.png)
136 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.9.0=py37_0 9 | - alabaster=0.7.12=py37_0 10 | - anaconda=2020.02=py37_0 11 | - anaconda-client=1.7.2=py37_0 12 | - anaconda-navigator=1.9.12=py37_0 13 | - anaconda-project=0.8.4=py_0 14 | - argh=0.26.2=py37_0 15 | - asn1crypto=1.3.0=py37_0 16 | - astroid=2.3.3=py37_0 17 | - astropy=4.0=py37h7b6447c_0 18 | - atomicwrites=1.3.0=py37_1 19 | - attrs=19.3.0=py_0 20 | - autopep8=1.4.4=py_0 21 | - babel=2.8.0=py_0 22 | - backcall=0.1.0=py37_0 23 | - backports=1.0=py_2 24 | - backports.functools_lru_cache=1.6.1=py_0 25 | - backports.shutil_get_terminal_size=1.0.0=py37_2 26 | - backports.tempfile=1.0=py_1 27 | - backports.weakref=1.0.post1=py_1 28 | - beautifulsoup4=4.8.2=py37_0 29 | - bitarray=1.2.1=py37h7b6447c_0 30 | - bkcharts=0.2=py37_0 31 | - blas=1.0=mkl 32 | - bleach=3.1.0=py37_0 33 | - blosc=1.16.3=hd408876_0 34 | - bokeh=1.4.0=py37_0 35 | - boto=2.49.0=py37_0 36 | - bottleneck=1.3.2=py37heb32a55_0 37 | - bzip2=1.0.8=h7b6447c_0 38 | - c-ares=1.15.0=h7b6447c_1001 39 | - ca-certificates=2020.1.1=0 40 | - cairo=1.14.12=h8948797_3 41 | - certifi=2019.11.28=py37_0 42 | - cffi=1.14.0=py37h2e261b9_0 43 | - chardet=3.0.4=py37_1003 44 | - click=7.0=py37_0 45 | - cloudpickle=1.3.0=py_0 46 | - clyent=1.2.2=py37_1 47 | - colorama=0.4.3=py_0 48 | - conda=4.8.4=py37_0 49 | - conda-build=3.18.11=py37_0 50 | - conda-env=2.6.0=1 51 | - conda-package-handling=1.6.0=py37h7b6447c_0 52 | - conda-verify=3.4.2=py_1 53 | - contextlib2=0.6.0.post1=py_0 54 | - cryptography=2.8=py37h1ba5d50_0 55 | - cudatoolkit=10.1.243=h6bb024c_0 56 | - curl=7.68.0=hbc83047_0 57 | - cycler=0.10.0=py37_0 58 | - cython=0.29.15=py37he6710b0_0 59 | - cytoolz=0.10.1=py37h7b6447c_0 60 | - dask=2.11.0=py_0 61 | - dask-core=2.11.0=py_0 62 | - dbus=1.13.12=h746ee38_0 63 | - decorator=4.4.1=py_0 64 | - defusedxml=0.6.0=py_0 65 | - diff-match-patch=20181111=py_0 66 | - distributed=2.11.0=py37_0 67 | - docutils=0.16=py37_0 68 | - entrypoints=0.3=py37_0 69 | - et_xmlfile=1.0.1=py37_0 70 | - expat=2.2.6=he6710b0_0 71 | - fastcache=1.1.0=py37h7b6447c_0 72 | - filelock=3.0.12=py_0 73 | - flake8=3.7.9=py37_0 74 | - flask=1.1.1=py_0 75 | - fontconfig=2.13.0=h9420a91_0 76 | - freetype=2.9.1=h8a8886c_1 77 | - fribidi=1.0.5=h7b6447c_0 78 | - fsspec=0.6.2=py_0 79 | - future=0.18.2=py37_0 80 | - get_terminal_size=1.0.0=haa9412d_0 81 | - gevent=1.4.0=py37h7b6447c_0 82 | - glib=2.63.1=h5a9c865_0 83 | - glob2=0.7=py_0 84 | - gmp=6.1.2=h6c8ec71_1 85 | - gmpy2=2.0.8=py37h10f8cd9_2 86 | - graphite2=1.3.13=h23475e2_0 87 | - greenlet=0.4.15=py37h7b6447c_0 88 | - grpcio=1.27.2=py37hf8bcb03_0 89 | - gst-plugins-base=1.14.0=hbbd80ab_1 90 | - gstreamer=1.14.0=hb453b48_1 91 | - h5py=2.10.0=py37h7918eee_0 92 | - harfbuzz=1.8.8=hffaf4a1_0 93 | - hdf5=1.10.4=hb1b8bf9_0 94 | - heapdict=1.0.1=py_0 95 | - html5lib=1.0.1=py37_0 96 | - hypothesis=5.5.4=py_0 97 | - icu=58.2=h9c2bf20_1 98 | - idna=2.8=py37_0 99 | - imageio=2.6.1=py37_0 100 | - imagesize=1.2.0=py_0 101 | - importlib-metadata=1.7.0=py37_0 102 | - importlib_metadata=1.5.0=py37_0 103 | - intel-openmp=2020.0=166 104 | - intervaltree=3.0.2=py_0 105 | - ipykernel=5.1.4=py37h39e3cac_0 106 | - ipython=7.12.0=py37h5ca1d4c_0 107 | - ipython_genutils=0.2.0=py37_0 108 | - 
ipywidgets=7.5.1=py_0 109 | - isort=4.3.21=py37_0 110 | - itsdangerous=1.1.0=py37_0 111 | - jbig=2.1=hdba287a_0 112 | - jdcal=1.4.1=py_0 113 | - jedi=0.14.1=py37_0 114 | - jeepney=0.4.2=py_0 115 | - jinja2=2.11.1=py_0 116 | - joblib=0.14.1=py_0 117 | - jpeg=9b=h024ee3a_2 118 | - json5=0.9.1=py_0 119 | - jsonschema=3.2.0=py37_0 120 | - jupyter=1.0.0=py37_7 121 | - jupyter_client=5.3.4=py37_0 122 | - jupyter_console=6.1.0=py_0 123 | - jupyter_core=4.6.1=py37_0 124 | - jupyterlab=1.2.6=pyhf63ae98_0 125 | - jupyterlab_server=1.0.6=py_0 126 | - keyring=21.1.0=py37_0 127 | - kiwisolver=1.1.0=py37he6710b0_0 128 | - krb5=1.17.1=h173b8e3_0 129 | - lazy-object-proxy=1.4.3=py37h7b6447c_0 130 | - ld_impl_linux-64=2.33.1=h53a641e_7 131 | - libarchive=3.3.3=h5d8350f_5 132 | - libcurl=7.68.0=h20c2e04_0 133 | - libedit=3.1.20181209=hc058e9b_0 134 | - libffi=3.2.1=hd88cf55_4 135 | - libgcc-ng=9.1.0=hdf63c60_0 136 | - libgfortran-ng=7.3.0=hdf63c60_0 137 | - liblief=0.9.0=h7725739_2 138 | - libpng=1.6.37=hbc83047_0 139 | - libprotobuf=3.13.0=hd408876_0 140 | - libsodium=1.0.16=h1bed415_0 141 | - libspatialindex=1.9.3=he6710b0_0 142 | - libssh2=1.8.2=h1ba5d50_0 143 | - libstdcxx-ng=9.1.0=hdf63c60_0 144 | - libtiff=4.1.0=h2733197_0 145 | - libtool=2.4.6=h7b6447c_5 146 | - libuuid=1.0.3=h1bed415_2 147 | - libxcb=1.13=h1bed415_1 148 | - libxml2=2.9.9=hea5a465_1 149 | - libxslt=1.1.33=h7d1a2b0_0 150 | - llvmlite=0.31.0=py37hd408876_0 151 | - locket=0.2.0=py37_1 152 | - lxml=4.5.0=py37hefd8a0e_0 153 | - lz4-c=1.8.1.2=h14c3975_0 154 | - lzo=2.10=h49e0be7_2 155 | - markdown=3.2.2=py37_0 156 | - markupsafe=1.1.1=py37h7b6447c_0 157 | - matplotlib=3.1.3=py37_0 158 | - matplotlib-base=3.1.3=py37hef1b27d_0 159 | - mccabe=0.6.1=py37_1 160 | - mistune=0.8.4=py37h7b6447c_0 161 | - mkl=2020.0=166 162 | - mkl-service=2.3.0=py37he904b0f_0 163 | - mkl_fft=1.0.15=py37ha843d7b_0 164 | - mkl_random=1.1.0=py37hd6b4f25_0 165 | - mock=4.0.1=py_0 166 | - more-itertools=8.2.0=py_0 167 | - mpc=1.1.0=h10f8cd9_1 168 | - mpfr=4.0.1=hdf1c602_3 169 | - mpmath=1.1.0=py37_0 170 | - msgpack-python=0.6.1=py37hfd86e86_1 171 | - multipledispatch=0.6.0=py37_0 172 | - navigator-updater=0.2.1=py37_0 173 | - nbconvert=5.6.1=py37_0 174 | - nbformat=5.0.4=py_0 175 | - ncurses=6.2=he6710b0_0 176 | - networkx=2.4=py_0 177 | - ninja=1.10.1=py37hfd86e86_0 178 | - nltk=3.4.5=py37_0 179 | - nose=1.3.7=py37_2 180 | - notebook=6.0.3=py37_0 181 | - numba=0.48.0=py37h0573a6f_0 182 | - numexpr=2.7.1=py37h423224d_0 183 | - numpy=1.18.1=py37h4f9e942_0 184 | - numpy-base=1.18.1=py37hde5b4d6_1 185 | - numpydoc=0.9.2=py_0 186 | - olefile=0.46=py37_0 187 | - openpyxl=3.0.3=py_0 188 | - openssl=1.1.1d=h7b6447c_4 189 | - packaging=20.1=py_0 190 | - pandas=1.0.1=py37h0573a6f_0 191 | - pandoc=2.2.3.2=0 192 | - pandocfilters=1.4.2=py37_1 193 | - pango=1.42.4=h049681c_0 194 | - parso=0.5.2=py_0 195 | - partd=1.1.0=py_0 196 | - patchelf=0.10=he6710b0_0 197 | - path=13.1.0=py37_0 198 | - path.py=12.4.0=0 199 | - pathlib2=2.3.5=py37_0 200 | - pathtools=0.1.2=py_1 201 | - patsy=0.5.1=py37_0 202 | - pcre=8.43=he6710b0_0 203 | - pep8=1.7.1=py37_0 204 | - pexpect=4.8.0=py37_0 205 | - pickleshare=0.7.5=py37_0 206 | - pip=20.0.2=py37_1 207 | - pixman=0.38.0=h7b6447c_0 208 | - pkginfo=1.5.0.1=py37_0 209 | - pluggy=0.13.1=py37_0 210 | - ply=3.11=py37_0 211 | - prometheus_client=0.7.1=py_0 212 | - prompt_toolkit=3.0.3=py_0 213 | - protobuf=3.13.0=py37hf484d3e_0 214 | - psutil=5.6.7=py37h7b6447c_0 215 | - ptyprocess=0.6.0=py37_0 216 | - py=1.8.1=py_0 217 | - py-lief=0.9.0=py37h7725739_2 
218 | - pycodestyle=2.5.0=py37_0 219 | - pycosat=0.6.3=py37h7b6447c_0 220 | - pycparser=2.19=py37_0 221 | - pycrypto=2.6.1=py37h14c3975_9 222 | - pycurl=7.43.0.5=py37h1ba5d50_0 223 | - pydocstyle=4.0.1=py_0 224 | - pyflakes=2.1.1=py37_0 225 | - pygments=2.5.2=py_0 226 | - pylint=2.4.4=py37_0 227 | - pyodbc=4.0.30=py37he6710b0_0 228 | - pyopenssl=19.1.0=py37_0 229 | - pyparsing=2.4.6=py_0 230 | - pyqt=5.9.2=py37h05f1152_2 231 | - pyrsistent=0.15.7=py37h7b6447c_0 232 | - pysocks=1.7.1=py37_0 233 | - pytables=3.6.1=py37h71ec239_0 234 | - pytest=5.3.5=py37_0 235 | - pytest-arraydiff=0.3=py37h39e3cac_0 236 | - pytest-astropy=0.8.0=py_0 237 | - pytest-astropy-header=0.1.2=py_0 238 | - pytest-doctestplus=0.5.0=py_0 239 | - pytest-openfiles=0.4.0=py_0 240 | - pytest-remotedata=0.3.2=py37_0 241 | - python=3.7.6=h0371630_2 242 | - python-dateutil=2.8.1=py_0 243 | - python-jsonrpc-server=0.3.4=py_0 244 | - python-language-server=0.31.7=py37_0 245 | - python-libarchive-c=2.8=py37_13 246 | - pytorch=1.6.0=py3.7_cuda10.1.243_cudnn7.6.3_0 247 | - pytz=2019.3=py_0 248 | - pywavelets=1.1.1=py37h7b6447c_0 249 | - pyxdg=0.26=py_0 250 | - pyyaml=5.3=py37h7b6447c_0 251 | - pyzmq=18.1.1=py37he6710b0_0 252 | - qdarkstyle=2.8=py_0 253 | - qt=5.9.7=h5867ecd_1 254 | - qtawesome=0.6.1=py_0 255 | - qtconsole=4.6.0=py_1 256 | - qtpy=1.9.0=py_0 257 | - readline=7.0=h7b6447c_5 258 | - requests=2.22.0=py37_1 259 | - ripgrep=11.0.2=he32d670_0 260 | - rope=0.16.0=py_0 261 | - rtree=0.9.3=py37_0 262 | - ruamel_yaml=0.15.87=py37h7b6447c_0 263 | - scikit-image=0.16.2=py37h0573a6f_0 264 | - scikit-learn=0.22.1=py37hd81dba3_0 265 | - seaborn=0.10.0=py_0 266 | - secretstorage=3.1.2=py37_0 267 | - send2trash=1.5.0=py37_0 268 | - setuptools=45.2.0=py37_0 269 | - simplegeneric=0.8.1=py37_2 270 | - singledispatch=3.4.0.3=py37_0 271 | - sip=4.19.8=py37hf484d3e_0 272 | - six=1.14.0=py37_0 273 | - snappy=1.1.7=hbae5bb6_3 274 | - snowballstemmer=2.0.0=py_0 275 | - sortedcollections=1.1.2=py37_0 276 | - sortedcontainers=2.1.0=py37_0 277 | - soupsieve=1.9.5=py37_0 278 | - sphinx=2.4.0=py_0 279 | - sphinxcontrib=1.0=py37_1 280 | - sphinxcontrib-applehelp=1.0.1=py_0 281 | - sphinxcontrib-devhelp=1.0.1=py_0 282 | - sphinxcontrib-htmlhelp=1.0.2=py_0 283 | - sphinxcontrib-jsmath=1.0.1=py_0 284 | - sphinxcontrib-qthelp=1.0.2=py_0 285 | - sphinxcontrib-serializinghtml=1.1.3=py_0 286 | - sphinxcontrib-websupport=1.2.0=py_0 287 | - spyder=4.0.1=py37_0 288 | - spyder-kernels=1.8.1=py37_0 289 | - sqlalchemy=1.3.13=py37h7b6447c_0 290 | - sqlite=3.31.1=h7b6447c_0 291 | - statsmodels=0.11.0=py37h7b6447c_0 292 | - sympy=1.5.1=py37_0 293 | - tbb=2020.0=hfd86e86_0 294 | - tblib=1.6.0=py_0 295 | - tensorboard=1.15.0=pyhb230dea_0 296 | - terminado=0.8.3=py37_0 297 | - testpath=0.4.4=py_0 298 | - tk=8.6.8=hbc83047_0 299 | - toolz=0.10.0=py_0 300 | - torchvision=0.7.0=py37_cu101 301 | - tornado=6.0.3=py37h7b6447c_3 302 | - tqdm=4.42.1=py_0 303 | - traitlets=4.3.3=py37_0 304 | - ujson=1.35=py37h14c3975_0 305 | - unicodecsv=0.14.1=py37_0 306 | - unixodbc=2.3.7=h14c3975_0 307 | - urllib3=1.25.8=py37_0 308 | - watchdog=0.10.2=py37_0 309 | - wcwidth=0.1.8=py_0 310 | - webencodings=0.5.1=py37_1 311 | - werkzeug=1.0.0=py_0 312 | - wheel=0.34.2=py37_0 313 | - widgetsnbextension=3.5.1=py37_0 314 | - wrapt=1.11.2=py37h7b6447c_0 315 | - wurlitzer=2.0.0=py37_0 316 | - xlrd=1.2.0=py37_0 317 | - xlsxwriter=1.2.7=py_0 318 | - xlwt=1.3.0=py37_0 319 | - xmltodict=0.12.0=py_0 320 | - xz=5.2.4=h14c3975_4 321 | - yaml=0.1.7=had09818_2 322 | - yapf=0.28.0=py_0 323 | - 
zeromq=4.3.1=he6710b0_3 324 | - zict=1.0.0=py_0 325 | - zipp=2.2.0=py_0 326 | - zlib=1.2.11=h7b6447c_3 327 | - zstd=1.3.7=h0b5b093_0 328 | - pip: 329 | - pillow==6.2.2 330 | - scipy==1.1.0 331 | - sklearn==0.0 332 | prefix: /root/anaconda3 333 | -------------------------------------------------------------------------------- /images/ECGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/images/ECGAN.png -------------------------------------------------------------------------------- /images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/images/imagenet.png -------------------------------------------------------------------------------- /src/configs/CIFAR10/ecgan_v2_none_1_0p01.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "cifar10", 4 | "data_path": "./data/cifar10", 5 | "img_size": 32, 6 | "num_classes": 10, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ECGAN", 16 | "pos_collected_numerator": true, 17 | "hypersphere_dim": 512, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 2, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 80, 27 | "shared_dim": 128, 28 | "g_conv_dim": 96, 29 | "d_conv_dim": 96, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 64, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.0002, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.5, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 5, 47 | "total_step": 150000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "cls_disc_lambda": 0.01, 54 | "cls_gen_lambda": "N/A", 55 | "cond_lambda": "N/A", 56 | "uncond_lambda": 1, 57 | "contrastive_type": "ContraGAN", 58 | "contrastive_lambda": 0.0, 59 | "margin": 0.0, 60 | "tempering_type": "discrete", 61 | "tempering_step": 1, 62 | "start_temperature": 1.0, 63 | "end_temperature": 1.0, 64 | 65 | "weight_clipping_for_dis": false, 66 | "weight_clipping_bound": "N/A", 67 | 68 | "gradient_penalty_for_dis": false, 69 | "gradient_penalty_lambda": "N/A", 70 | 71 | "deep_regret_analysis_for_dis": false, 72 | "regret_penalty_lambda": "N/A", 73 | 74 | "cr": false, 75 | "cr_lambda": "N/A", 76 | 77 | "bcr": false, 78 | "real_lambda": "N/A", 79 | "fake_lambda": "N/A", 80 | 81 | "zcr": false, 82 | "gen_lambda": "N/A", 83 | "dis_lambda": "N/A", 84 | "sigma_noise": "N/A" 85 | }, 86 | 87 | "initialization":{ 88 | "g_init": "ortho", 89 | "d_init": "ortho" 90 | }, 91 | 92 | "training_and_sampling_setting":{ 93 | "random_flip_preprocessing": true, 94 | "diff_aug": false, 95 | 96 | "ada": false, 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | 
"latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | 111 | "ema": true, 112 | "ema_decay": 0.9999, 113 | "ema_start": 1000 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ECGAN", 16 | "pos_collected_numerator": true, 17 | "hypersphere_dim": 768, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 4, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 120, 27 | "shared_dim": 128, 28 | "g_conv_dim": 96, 29 | "d_conv_dim": 96, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.00005, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 2, 47 | "total_step": 400000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "cls_disc_lambda": 0.05, 54 | "cls_gen_lambda": "N/A", 55 | "cond_lambda": "N/A", 56 | "uncond_lambda": 1, 57 | "contrastive_type": "ContraGAN", 58 | "contrastive_lambda": 1.0, 59 | "margin": 0.0, 60 | "tempering_type": "discrete", 61 | "tempering_step": 1, 62 | "start_temperature": 1.0, 63 | "end_temperature": 1.0, 64 | 65 | "weight_clipping_for_dis": false, 66 | "weight_clipping_bound": "N/A", 67 | 68 | "gradient_penalty_for_dis": false, 69 | "gradient_penalty_lambda": "N/A", 70 | 71 | "deep_regret_analysis_for_dis": false, 72 | "regret_penalty_lambda": "N/A", 73 | 74 | "cr": false, 75 | "cr_lambda":"N/A", 76 | 77 | "bcr": false, 78 | "real_lambda": "N/A", 79 | "fake_lambda": "N/A", 80 | 81 | "zcr": false, 82 | "gen_lambda": "N/A", 83 | "dis_lambda": "N/A", 84 | "sigma_noise": "N/A" 85 | }, 86 | 87 | "initialization":{ 88 | "g_init": "ortho", 89 | "d_init": "ortho" 90 | }, 91 | 92 | "training_and_sampling_setting":{ 93 | "random_flip_preprocessing": true, 94 | "diff_aug":false, 95 | 96 | "ada": false, 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate":"N/A", 105 | "latent_op_step":"N/A", 106 | "latent_op_step4eval":"N/A", 107 | "latent_op_alpha":"N/A", 108 | "latent_op_beta":"N/A", 109 | "latent_norm_reg_weight":"N/A", 110 | 111 | "ema": true, 112 | "ema_decay": 0.9999, 113 | "ema_start": 20000 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | 
"compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ECGAN", 16 | "pos_collected_numerator": true, 17 | "hypersphere_dim": 768, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 100, 27 | "shared_dim": 128, 28 | "g_conv_dim": 80, 29 | "d_conv_dim": 80, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": 0.01, 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "cls_disc_lambda": 0.05, 54 | "cls_gen_lambda": "N/A", 55 | "cond_lambda": "N/A", 56 | "uncond_lambda": 1, 57 | "contrastive_type": "ContraGAN", 58 | "contrastive_lambda": 0.0, 59 | "margin": 0.0, 60 | "tempering_type": "discrete", 61 | "tempering_step": 1, 62 | "start_temperature": 1.0, 63 | "end_temperature": 1.0, 64 | 65 | "weight_clipping_for_dis": false, 66 | "weight_clipping_bound": "N/A", 67 | 68 | "gradient_penalty_for_dis": false, 69 | "gradient_penalty_lambda": "N/A", 70 | 71 | "deep_regret_analysis_for_dis": false, 72 | "regret_penalty_lambda": "N/A", 73 | 74 | "cr": false, 75 | "cr_lambda": "N/A", 76 | 77 | "bcr": false, 78 | "real_lambda": "N/A", 79 | "fake_lambda": "N/A", 80 | 81 | "zcr": false, 82 | "gen_lambda": "N/A", 83 | "dis_lambda": "N/A", 84 | "sigma_noise": "N/A" 85 | }, 86 | 87 | "initialization":{ 88 | "g_init": "ortho", 89 | "d_init": "ortho" 90 | }, 91 | 92 | "training_and_sampling_setting":{ 93 | "random_flip_preprocessing": true, 94 | "diff_aug": false, 95 | 96 | "ada": false, 97 | "ada_target": "N/A", 98 | "ada_length": "N/A", 99 | 100 | "prior": "gaussian", 101 | "truncated_factor": 1, 102 | 103 | "latent_op": false, 104 | "latent_op_rate": "N/A", 105 | "latent_op_step": "N/A", 106 | "latent_op_step4eval": "N/A", 107 | "latent_op_alpha": "N/A", 108 | "latent_op_beta": "N/A", 109 | "latent_norm_reg_weight": "N/A", 110 | 111 | "ema": true, 112 | "ema_decay": 0.9999, 113 | "ema_start": 20000 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/data_utils/load_dataset.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/data_utils/load_dataset.py 6 | 7 | 8 | import os 9 | import h5py as h5 10 | import numpy as np 11 | import random 12 | from scipy import io 13 | from PIL import ImageOps, Image 14 | 15 | import torch 16 | import torchvision.transforms as transforms 17 | from torch.utils.data import Dataset 18 | from torchvision.datasets import CIFAR10, STL10 19 | from torchvision.datasets import ImageFolder 20 | 21 | 22 | 23 | class RandomCropLongEdge(object): 24 | """ 25 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 26 | MIT License 27 | Copyright (c) 2019 Andy Brock 28 | """ 29 | def __call__(self, img): 30 | size = (min(img.size), 
min(img.size)) 31 | # Only step forward along this edge if it's the long edge 32 | i = (0 if size[0] == img.size[0] 33 | else np.random.randint(low=0,high=img.size[0] - size[0])) 34 | j = (0 if size[1] == img.size[1] 35 | else np.random.randint(low=0,high=img.size[1] - size[1])) 36 | return transforms.functional.crop(img, i, j, size[0], size[1]) 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ 40 | 41 | 42 | class CenterCropLongEdge(object): 43 | """ 44 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 45 | MIT License 46 | Copyright (c) 2019 Andy Brock 47 | """ 48 | def __call__(self, img): 49 | return transforms.functional.center_crop(img, min(img.size)) 50 | 51 | def __repr__(self): 52 | return self.__class__.__name__ 53 | 54 | 55 | class LoadDataset(Dataset): 56 | def __init__(self, dataset_name, data_path, train, download, resize_size, hdf5_path=None, random_flip=False): 57 | super(LoadDataset, self).__init__() 58 | self.dataset_name = dataset_name 59 | self.data_path = data_path 60 | self.train = train 61 | self.download = download 62 | self.resize_size = resize_size 63 | self.hdf5_path = hdf5_path 64 | self.random_flip = random_flip 65 | self.norm_mean = [0.5,0.5,0.5] 66 | self.norm_std = [0.5,0.5,0.5] 67 | 68 | if self.hdf5_path is None: 69 | if self.dataset_name in ['cifar10', 'tiny_imagenet']: 70 | self.transforms = [] 71 | elif self.dataset_name in ['imagenet', 'custom']: 72 | if train: 73 | self.transforms = [RandomCropLongEdge(), transforms.Resize(self.resize_size)] 74 | else: 75 | self.transforms = [CenterCropLongEdge(), transforms.Resize(self.resize_size)] 76 | else: 77 | self.transforms = [transforms.ToPILImage()] 78 | 79 | if random_flip: 80 | self.transforms += [transforms.RandomHorizontalFlip()] 81 | 82 | self.transforms += [transforms.ToTensor(), transforms.Normalize(self.norm_mean, self.norm_std)] 83 | self.transforms = transforms.Compose(self.transforms) 84 | 85 | self.load_dataset() 86 | 87 | 88 | def load_dataset(self): 89 | if self.dataset_name == 'cifar10': 90 | if self.hdf5_path is not None: 91 | print('Loading %s into memory...' % self.hdf5_path) 92 | with h5.File(self.hdf5_path, 'r') as f: 93 | self.data = f['imgs'][:] 94 | self.labels = f['labels'][:] 95 | else: 96 | self.data = CIFAR10(root=os.path.join('data', self.dataset_name), 97 | train=self.train, 98 | download=self.download) 99 | 100 | elif self.dataset_name == 'imagenet': 101 | if self.hdf5_path is not None: 102 | print('Loading %s into memory...' % self.hdf5_path) 103 | with h5.File(self.hdf5_path, 'r') as f: 104 | self.data = f['imgs'][:] 105 | self.labels = f['labels'][:] 106 | else: 107 | mode = 'train' if self.train == True else 'valid' 108 | root = os.path.join('data','ILSVRC2012', mode) 109 | self.data = ImageFolder(root=root) 110 | 111 | elif self.dataset_name == "tiny_imagenet": 112 | if self.hdf5_path is not None: 113 | print('Loading %s into memory...' % self.hdf5_path) 114 | with h5.File(self.hdf5_path, 'r') as f: 115 | self.data = f['imgs'][:] 116 | self.labels = f['labels'][:] 117 | else: 118 | mode = 'train' if self.train == True else 'valid' 119 | root = os.path.join('data','TINY_ILSVRC2012', mode) 120 | self.data = ImageFolder(root=root) 121 | 122 | elif self.dataset_name == "custom": 123 | if self.hdf5_path is not None: 124 | print('Loading %s into memory...' 
% self.hdf5_path) 125 | with h5.File(self.hdf5_path, 'r') as f: 126 | self.data = f['imgs'][:] 127 | self.labels = f['labels'][:] 128 | else: 129 | mode = 'train' if self.train == True else 'valid' 130 | root = os.path.join(self.data_path, mode) 131 | self.data = ImageFolder(root=root) 132 | else: 133 | raise NotImplementedError 134 | 135 | 136 | def __len__(self): 137 | if self.hdf5_path is not None: 138 | num_dataset = self.data.shape[0] 139 | else: 140 | num_dataset = len(self.data) 141 | return num_dataset 142 | 143 | 144 | def __getitem__(self, index): 145 | if self.hdf5_path is None: 146 | img, label = self.data[index] 147 | img, label = self.transforms(img), int(label) 148 | else: 149 | img, label = np.transpose(self.data[index], (1,2,0)), int(self.labels[index]) 150 | img = self.transforms(img) 151 | return img, label 152 | -------------------------------------------------------------------------------- /src/inception_tf13.py: -------------------------------------------------------------------------------- 1 | ''' Tensorflow inception score code 2 | Derived from https://github.com/openai/improved-gan 3 | Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 4 | THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE 5 | To use this code, run sample.py on your model with --sample_npz, and then 6 | pass the experiment name in the --experiment_name. 7 | This code also saves pool3 stats to an npz file for FID calculation 8 | ''' 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import os.path 14 | import sys 15 | import tarfile 16 | import math 17 | from tqdm import tqdm, trange 18 | from argparse import ArgumentParser 19 | 20 | import numpy as np 21 | from six.moves import urllib 22 | import tensorflow as tf 23 | 24 | MODEL_DIR = './inception_model' 25 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 26 | softmax = None 27 | 28 | def prepare_parser(): 29 | usage = 'Parser for TF1.3- Inception Score scripts.' 30 | parser = ArgumentParser(description=usage) 31 | parser.add_argument( 32 | '--run_name', type=str, default='', 33 | help='Which experiment''s samples.npz file to pull and evaluate') 34 | parser.add_argument( 35 | '--type', type=str, default='', 36 | help='[real, fake]') 37 | parser.add_argument( 38 | '--batch_size', type=int, default=500, 39 | help='Default overall batchsize (default: %(default)s)') 40 | return parser 41 | 42 | 43 | def run(config): 44 | # Inception with TF1.3 or earlier. 45 | # Call this function with list of images. Each of elements should be a 46 | # numpy array with values ranging from 0 to 255. 
47 | def get_inception_score(images, splits=10): 48 | assert(type(images) == list) 49 | assert(type(images[0]) == np.ndarray) 50 | assert(len(images[0].shape) == 3) 51 | assert(np.max(images[0]) > 10) 52 | assert(np.min(images[0]) >= 0.0) 53 | inps = [] 54 | for img in images: 55 | img = img.astype(np.float32) 56 | inps.append(np.expand_dims(img, 0)) 57 | bs = config['batch_size'] 58 | with tf.Session() as sess: 59 | preds, pools = [], [] 60 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 61 | for i in trange(n_batches): 62 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 63 | inp = np.concatenate(inp, 0) 64 | pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp}) 65 | preds.append(pred) 66 | pools.append(pool) 67 | preds = np.concatenate(preds, 0) 68 | scores = [] 69 | for i in range(splits): 70 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 71 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 72 | kl = np.mean(np.sum(kl, 1)) 73 | scores.append(np.exp(kl)) 74 | return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0)) 75 | # Init inception 76 | def _init_inception(): 77 | global softmax, pool3 78 | if not os.path.exists(MODEL_DIR): 79 | os.makedirs(MODEL_DIR) 80 | filename = DATA_URL.split('/')[-1] 81 | filepath = os.path.join(MODEL_DIR, filename) 82 | if not os.path.exists(filepath): 83 | def _progress(count, block_size, total_size): 84 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 85 | filename, float(count * block_size) / float(total_size) * 100.0)) 86 | sys.stdout.flush() 87 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 88 | print() 89 | statinfo = os.stat(filepath) 90 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 91 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 92 | with tf.gfile.FastGFile(os.path.join( 93 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 94 | graph_def = tf.GraphDef() 95 | graph_def.ParseFromString(f.read()) 96 | _ = tf.import_graph_def(graph_def, name='') 97 | # Works with an arbitrary minibatch size. 98 | with tf.Session() as sess: 99 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 100 | ops = pool3.graph.get_operations() 101 | for op_idx, op in enumerate(ops): 102 | for o in op.outputs: 103 | shape = o.get_shape() 104 | shape = [s.value for s in shape] 105 | new_shape = [] 106 | for j, s in enumerate(shape): 107 | if s == 1 and j == 0: 108 | new_shape.append(None) 109 | else: 110 | new_shape.append(s) 111 | o._shape = tf.TensorShape(new_shape) 112 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 113 | logits = tf.matmul(tf.squeeze(pool3), w) 114 | softmax = tf.nn.softmax(logits) 115 | 116 | # if softmax is None: # No need to functionalize like this. 
117 | _init_inception() 118 | 119 | fname = '%s/%s/%s/%s/samples.npz' % ("samples", config['run_name'], config['type'], "npz") 120 | print('loading %s ...'%fname) 121 | ims = np.load(fname)['x'] 122 | import time 123 | t0 = time.time() 124 | inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=1) 125 | t1 = time.time() 126 | print('Inception took %3f seconds, score of %3f +/- %3f.'%(t1-t0, inc_mean, inc_std)) 127 | def main(): 128 | # parse command line and run 129 | parser = prepare_parser() 130 | config = vars(parser.parse_args()) 131 | print(config) 132 | run(config) 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/main.py 6 | 7 | 8 | import json 9 | import os 10 | import sys 11 | import warnings 12 | from argparse import ArgumentParser 13 | 14 | from utils.misc import * 15 | from utils.make_hdf5 import make_hdf5 16 | from utils.log import make_run_name 17 | from loader import prepare_train_eval 18 | 19 | import torch 20 | from torch.backends import cudnn 21 | import torch.multiprocessing as mp 22 | 23 | 24 | 25 | RUN_NAME_FORMAT = ( 26 | "{framework}-" 27 | "{phase}-" 28 | "{timestamp}" 29 | ) 30 | 31 | 32 | def main(): 33 | parser = ArgumentParser(add_help=False) 34 | parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json') 35 | parser.add_argument('--checkpoint_folder', type=str, default=None) 36 | parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint') 37 | parser.add_argument('--log_output_path', type=str, default=None) 38 | 39 | parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers') 40 | parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true') 41 | parser.add_argument('--num_workers', type=int, default=8, help='') 42 | parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether turn on synchronized batchnorm') 43 | parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether turn on mixed precision training') 44 | parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode') 45 | 46 | parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset') 47 | parser.add_argument('-stat_otf', '--bn_stat_OnTheFly', action='store_true', help='when evaluating, use the statistics of a batch') 48 | parser.add_argument('-std_stat', '--standing_statistics', action='store_true') 49 | parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm') 50 | parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator') 51 | 52 | parser.add_argument('-l', '--load_all_data_in_memory', action='store_true') 53 | parser.add_argument('-t', '--train', action='store_true') 54 | parser.add_argument('-e', '--eval', action='store_true') 55 | parser.add_argument('-s', '--save_images', action='store_true') 56 | parser.add_argument('-iv', 
'--image_visualization', action='store_true', help='select whether conduct image visualization') 57 | parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis') 58 | parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis') 59 | parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis') 60 | parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis') 61 | parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas') 62 | parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas') 63 | 64 | parser.add_argument('--print_every', type=int, default=100, help='control log interval') 65 | parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval') 66 | parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]') 67 | args = parser.parse_args() 68 | 69 | if not args.train and \ 70 | not args.eval and \ 71 | not args.save_images and \ 72 | not args.image_visualization and \ 73 | not args.k_nearest_neighbor and \ 74 | not args.interpolation and \ 75 | not args.frequency_analysis and \ 76 | not args.tsne_analysis: 77 | parser.print_help(sys.stderr) 78 | sys.exit(1) 79 | 80 | if args.config_path is not None: 81 | with open(args.config_path) as f: 82 | model_config = json.load(f) 83 | train_config = vars(args) 84 | else: 85 | raise NotImplementedError 86 | 87 | if model_config['data_processing']['dataset_name'] == 'cifar10': 88 | assert train_config['eval_type'] in ['train', 'test'], "Cifar10 does not contain dataset for validation." 89 | elif model_config['data_processing']['dataset_name'] in ['imagenet', 'tiny_imagenet', 'custom']: 90 | assert train_config['eval_type'] == 'train' or train_config['eval_type'] == 'valid', \ 91 | "StudioGAN dose not support the evalutation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets" 92 | 93 | if train_config['distributed_data_parallel']: 94 | msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, and frequency_analysis with DDP. " +\ 95 | "Please change DDP with a single GPU training or DataParallel instead." 96 | assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + \ 97 | train_config['interpolation'] + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg 98 | 99 | hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \ 100 | if train_config['load_all_data_in_memory'] else None 101 | 102 | if train_config['seed'] == -1: 103 | cudnn.benchmark, cudnn.deterministic = True, False 104 | else: 105 | fix_all_seed(train_config['seed']) 106 | cudnn.benchmark, cudnn.deterministic = False, True 107 | 108 | world_size, rank = torch.cuda.device_count(), torch.cuda.current_device() 109 | if world_size == 1: warnings.warn('You have chosen a specific GPU. 
This will completely disable data parallelism.') 110 | 111 | if train_config['disable_debugging_API']: torch.autograd.set_detect_anomaly(False) 112 | check_flag_0(model_config['train']['optimization']['batch_size'], world_size, train_config['freeze_layers'], train_config['checkpoint_folder'], 113 | model_config['train']['model']['architecture'], model_config['data_processing']['img_size']) 114 | 115 | run_name = make_run_name(RUN_NAME_FORMAT, framework=train_config['config_path'].split('/')[-1][:-5], phase='train') 116 | 117 | if train_config['distributed_data_parallel'] and world_size > 1: 118 | print("Train the models through DistributedDataParallel (DDP) mode.") 119 | mp.spawn(prepare_train_eval, nprocs=world_size, args=(world_size, run_name, train_config, model_config, hdf5_path_train)) 120 | else: 121 | prepare_train_eval(rank, world_size, run_name, train_config, model_config, hdf5_path_train=hdf5_path_train) 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /src/metrics/Accuracy.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/metrics/Accuracy.py 6 | 7 | 8 | import numpy as np 9 | import math 10 | from scipy import linalg 11 | from tqdm import tqdm 12 | 13 | from utils.sample import sample_latents 14 | from utils.losses import latent_optimise 15 | 16 | import torch 17 | from torch.nn import DataParallel 18 | from torch.nn.parallel import DistributedDataParallel 19 | 20 | 21 | 22 | def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate, truncated_factor, prior, latent_op, 23 | latent_op_step, latent_op_alpha, latent_op_beta, device, cr, logger, eval_generated_sample=False): 24 | data_iter = iter(dataloader) 25 | batch_size = dataloader.batch_size 26 | disable_tqdm = device != 0 27 | 28 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel): 29 | z_dim = generator.module.z_dim 30 | num_classes = generator.module.num_classes 31 | conditional_strategy = discriminator.module.conditional_strategy 32 | else: 33 | z_dim = generator.z_dim 34 | num_classes = generator.num_classes 35 | conditional_strategy = discriminator.conditional_strategy 36 | 37 | total_batch = num_evaluate//batch_size 38 | 39 | if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis"]: 40 | cutoff = 0.5 41 | elif D_loss.__name__ == "loss_hinge_dis": 42 | cutoff = 0.0 43 | elif D_loss.__name__ == "loss_wgan_dis": 44 | raise NotImplementedError 45 | 46 | if device == 0: logger.info("Calculate Accuracies....") 47 | 48 | if eval_generated_sample: 49 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm): 50 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device) 51 | if latent_op: 52 | zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step, 53 | 1.0, latent_op_alpha, latent_op_beta, False, device) 54 | 55 | real_images, real_labels = next(data_iter) 56 | real_images, real_labels = real_images.to(device), real_labels.to(device) 57 | 58 | fake_images = generator(zs, fake_labels, evaluation=True) 59 | 60 | with torch.no_grad(): 61 | if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]: 62 | _, _, dis_out_fake = 
discriminator(fake_images, fake_labels) 63 | _, _, dis_out_real = discriminator(real_images, real_labels) 64 | elif conditional_strategy == "ACGAN": 65 | _, dis_out_fake = discriminator(fake_images, fake_labels) 66 | _, dis_out_real = discriminator(real_images, real_labels) 67 | elif conditional_strategy == "ProjGAN" or conditional_strategy == "no": 68 | dis_out_fake = discriminator(fake_images, fake_labels) 69 | dis_out_real = discriminator(real_images, real_labels) 70 | elif conditional_strategy == 'ECGAN': 71 | _, dis_out_fake, _, _, _ = discriminator(fake_images, fake_labels) 72 | _, dis_out_real, _, _, _ = discriminator(real_images, real_labels) 73 | else: 74 | raise NotImplementedError 75 | 76 | dis_out_fake = dis_out_fake.detach().cpu().numpy() 77 | dis_out_real = dis_out_real.detach().cpu().numpy() 78 | 79 | if batch_id == 0: 80 | confid = np.concatenate((dis_out_fake, dis_out_real), axis=0) 81 | confid_label = np.concatenate(([0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0) 82 | else: 83 | confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0) 84 | confid_label = np.concatenate((confid_label, [0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0) 85 | 86 | real_confid = confid[confid_label==1.0] 87 | fake_confid = confid[confid_label==0.0] 88 | 89 | true_positive = real_confid[np.where(real_confid>cutoff)] 90 | true_negative = fake_confid[np.where(fake_confidcutoff)] 124 | only_real_acc = len(true_positive)/len(real_confid) 125 | 126 | return only_real_acc 127 | -------------------------------------------------------------------------------- /src/metrics/CAS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from utils.sample import sample_latents 5 | from utils.losses import latent_optimise 6 | 7 | import torch 8 | from torch.nn import DataParallel 9 | from torch.nn.parallel import DistributedDataParallel 10 | from pytorchcv.model_provider import get_model as ptcv_get_model 11 | 12 | 13 | def calculate_classifier_accuracy_score(dataloader, generator, discriminator, num_evaluate, truncated_factor, prior, latent_op, 14 | latent_op_step, latent_op_alpha, latent_op_beta, device, logger, eval_generated_sample=False): 15 | data_iter = iter(dataloader) 16 | batch_size = dataloader.batch_size 17 | disable_tqdm = device != 0 18 | net = ptcv_get_model('wrn40_8_cifar10', pretrained=True).to(device) 19 | 20 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel): 21 | z_dim = generator.module.z_dim 22 | num_classes = generator.module.num_classes 23 | conditional_strategy = discriminator.module.conditional_strategy 24 | else: 25 | z_dim = generator.z_dim 26 | num_classes = generator.num_classes 27 | conditional_strategy = discriminator.conditional_strategy 28 | 29 | total_batch = num_evaluate//batch_size 30 | 31 | if device == 0: logger.info("Calculate Classifier Accuracy Score....") 32 | 33 | all_pred_fake, all_pred_real, all_fake_labels, all_real_labels = [], [], [], [] 34 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm): 35 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device) 36 | if latent_op: 37 | zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step, 38 | 1.0, latent_op_alpha, latent_op_beta, False, device) 39 | 40 | real_images, real_labels = next(data_iter) 41 | real_images, real_labels = real_images.to(device), 
real_labels.to(device) 42 | 43 | fake_images = generator(zs, fake_labels, evaluation=True) 44 | 45 | with torch.no_grad(): 46 | pred_fake = net(fake_images).detach().cpu().numpy() 47 | pred_real = net(real_images).detach().cpu().numpy() 48 | 49 | all_pred_fake.append(pred_fake) 50 | all_pred_real.append(pred_real) 51 | all_fake_labels.append(fake_labels.cpu().numpy()) 52 | all_real_labels.append(real_labels.cpu().numpy()) 53 | 54 | all_pred_fake = np.concatenate(all_pred_fake, axis=0).argmax(axis=1) 55 | all_pred_real = np.concatenate(all_pred_real, axis=0).argmax(axis=1) 56 | all_fake_labels = np.concatenate(all_fake_labels) 57 | all_real_labels = np.concatenate(all_real_labels) 58 | 59 | fake_cas = (all_pred_fake == all_fake_labels).mean() 60 | real_cas = (all_pred_real == all_real_labels).mean() 61 | 62 | return real_cas, fake_cas 63 | -------------------------------------------------------------------------------- /src/metrics/DCA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import softmax 3 | from scipy.stats import entropy 4 | from tqdm import tqdm 5 | 6 | from utils.sample import sample_latents 7 | from utils.losses import latent_optimise 8 | 9 | import torch 10 | from torch.nn import DataParallel 11 | from torch.nn.parallel import DistributedDataParallel 12 | 13 | 14 | def calculate_discriminator_classification_accuracy(dataloader, generator, discriminator, num_evaluate, truncated_factor, prior, latent_op, 15 | latent_op_step, latent_op_alpha, latent_op_beta, device, logger, eval_generated_sample=False): 16 | data_iter = iter(dataloader) 17 | batch_size = dataloader.batch_size 18 | disable_tqdm = device != 0 19 | 20 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel): 21 | conditional_strategy = discriminator.module.conditional_strategy 22 | else: 23 | conditional_strategy = discriminator.conditional_strategy 24 | 25 | total_batch = num_evaluate//batch_size 26 | 27 | if device == 0: logger.info("Calculate Discriminator Classification Accuracy....") 28 | 29 | all_pred_real, all_real_labels = [], [] 30 | # all_cond_output, all_uncond_output = [], [] 31 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm): 32 | real_images, real_labels = next(data_iter) 33 | real_images, real_labels = real_images.to(device), real_labels.to(device) 34 | 35 | with torch.no_grad(): 36 | if conditional_strategy == "ACGAN": 37 | cls_out_real, _ = discriminator(real_images, real_labels) 38 | elif conditional_strategy == 'ECGAN': 39 | cls_out_real, cond_output, uncond_output, _, _ = discriminator(real_images, real_labels) 40 | pred_real = cls_out_real.detach().cpu().numpy() 41 | 42 | all_pred_real.append(pred_real) 43 | all_real_labels.append(real_labels.cpu().numpy()) 44 | # all_cond_output.append(cond_output.detach().cpu().numpy()) 45 | # all_uncond_output.append(uncond_output.detach().cpu().numpy()) 46 | 47 | # mean_cond_output = np.abs(np.concatenate(all_cond_output)).mean() 48 | # mean_uncond_output = np.abs(np.concatenate(all_uncond_output)).mean() 49 | # print(f'Conditional Output Mean: {mean_cond_output}') 50 | # print(f'Unconditional Output Mean: {mean_uncond_output}') 51 | all_pred_logits = np.concatenate(all_pred_real) 52 | all_pred_entropy = entropy(softmax(all_pred_logits, axis=0), axis=0).mean() 53 | all_pred_real = all_pred_logits.argmax(axis=1) 54 | all_real_labels = np.concatenate(all_real_labels) 55 | 56 | dca = (all_pred_real == all_real_labels).mean() 57 
| 58 | return dca, all_pred_entropy 59 | -------------------------------------------------------------------------------- /src/metrics/FID.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead of Tensorflow 4 | Copyright 2018 Institute of Bioinformatics, JKU Linz 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | 8 | You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import numpy as np 19 | import math 20 | import os 21 | import shutil 22 | from os.path import dirname, abspath, exists, join 23 | from scipy import linalg 24 | from tqdm import tqdm 25 | 26 | from utils.sample import sample_latents 27 | from utils.losses import latent_optimise 28 | 29 | import torch 30 | from torch.nn import DataParallel 31 | from torch.nn.parallel import DistributedDataParallel 32 | from torchvision.utils import save_image 33 | 34 | 35 | 36 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 37 | """Numpy implementation of the Frechet Distance. 38 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 39 | and X_2 ~ N(mu_2, C_2) is 40 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 41 | Stable version by Dougal J. Sutherland. 42 | Params: 43 | -- mu1 : Numpy array containing the activations of a layer of the 44 | inception net (like returned by the function 'get_predictions') 45 | for generated samples. 46 | -- mu2 : The sample mean over activations, precalculated on a 47 | representative data set. 48 | -- sigma1: The covariance matrix over activations for generated samples. 49 | -- sigma2: The covariance matrix over activations, precalculated on a 50 | representative data set. 51 | Returns: 52 | -- : The Frechet Distance. 
53 | """ 54 | 55 | mu1 = np.atleast_1d(mu1) 56 | mu2 = np.atleast_1d(mu2) 57 | 58 | sigma1 = np.atleast_2d(sigma1) 59 | sigma2 = np.atleast_2d(sigma2) 60 | 61 | assert mu1.shape == mu2.shape, \ 62 | 'Training and test mean vectors have different lengths' 63 | assert sigma1.shape == sigma2.shape, \ 64 | 'Training and test covariances have different dimensions' 65 | 66 | diff = mu1 - mu2 67 | 68 | # Product might be almost singular 69 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 70 | if not np.isfinite(covmean).all(): 71 | offset = np.eye(sigma1.shape[0]) * eps 72 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 73 | 74 | # Numerical error might give slight imaginary component 75 | if np.iscomplexobj(covmean): 76 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 77 | m = np.max(np.abs(covmean.imag)) 78 | raise ValueError('Imaginary component {}'.format(m)) 79 | covmean = covmean.real 80 | 81 | tr_covmean = np.trace(covmean) 82 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 83 | 84 | def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op, latent_op_step, 85 | latent_op_alpha, latent_op_beta, device): 86 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel): 87 | z_dim = gen.module.z_dim 88 | num_classes = gen.module.num_classes 89 | conditional_strategy = dis.module.conditional_strategy 90 | else: 91 | z_dim = gen.z_dim 92 | num_classes = gen.num_classes 93 | conditional_strategy = dis.conditional_strategy 94 | 95 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device) 96 | 97 | if latent_op: 98 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha, 99 | latent_op_beta, False, device) 100 | 101 | with torch.no_grad(): 102 | batch_images = gen(zs, fake_labels, evaluation=True) 103 | 104 | return batch_images, fake_labels 105 | 106 | 107 | def get_activations(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, is_generate, 108 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=False, run_name=None): 109 | """Calculates the activations of the pool_3 layer for all images. 110 | Params: 111 | -- data_loader : data_loader of training images 112 | -- generator : instance of GANs' generator 113 | -- inception_model : Instance of inception model 114 | 115 | Returns: 116 | -- A numpy array of dimension (num images, dims) that contains the 117 | activations of the given tensor when feeding inception with the 118 | query tensor. 
119 | """ 120 | if is_generate is True: 121 | batch_size = data_loader.batch_size 122 | total_instance = n_generate 123 | n_batches = math.ceil(float(total_instance) / float(batch_size)) 124 | else: 125 | batch_size = data_loader.batch_size 126 | total_instance = len(data_loader.dataset) 127 | n_batches = math.ceil(float(total_instance) / float(batch_size)) 128 | data_iter = iter(data_loader) 129 | 130 | num_classes = generator.module.num_classes if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel) else generator.num_classes 131 | pred_arr = np.empty((total_instance, 2048)) 132 | 133 | for i in tqdm(range(0, n_batches), disable=tqdm_disable): 134 | start = i*batch_size 135 | end = start + batch_size 136 | if is_generate is True: 137 | images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op, 138 | latent_op_step, latent_op_alpha, latent_op_beta, device) 139 | images = images.to(device) 140 | 141 | with torch.no_grad(): 142 | embeddings, logits = inception_model(images) 143 | 144 | if total_instance >= batch_size: 145 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1) 146 | else: 147 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1) 148 | 149 | total_instance -= images.shape[0] 150 | else: 151 | try: 152 | feed_list = next(data_iter) 153 | images = feed_list[0] 154 | images = images.to(device) 155 | with torch.no_grad(): 156 | embeddings, logits = inception_model(images) 157 | 158 | if total_instance >= batch_size: 159 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1) 160 | else: 161 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1) 162 | total_instance -= images.shape[0] 163 | 164 | except StopIteration: 165 | break 166 | return pred_arr 167 | 168 | 169 | def calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 170 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name=None): 171 | act = get_activations(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 172 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name) 173 | mu = np.mean(act, axis=0) 174 | sigma = np.cov(act, rowvar=False) 175 | return mu, sigma 176 | 177 | 178 | def calculate_fid_score(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 179 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger, pre_cal_mean=None, pre_cal_std=None, run_name=None): 180 | disable_tqdm = device != 0 181 | inception_model.eval() 182 | 183 | if device == 0: logger.info("Calculating FID Score....") 184 | if pre_cal_mean is not None and pre_cal_std is not None: 185 | m1, s1 = pre_cal_mean, pre_cal_std 186 | else: 187 | m1, s1 = calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, 188 | prior, False, False, 0, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm) 189 | 190 | m2, s2 = calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 191 | True, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm, run_name=run_name) 192 | 193 | fid_value = 
calculate_frechet_distance(m1, s1, m2, s2) 194 | 195 | return fid_value, m1, s1 196 | -------------------------------------------------------------------------------- /src/metrics/F_beta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Taken from: 3 | # https://github.com/google/compare_gan/blob/master/compare_gan/src/prd_score.py 4 | # 5 | # Changes: 6 | # - default dpi changed from 150 to 300 7 | # - added handling of cases where P = Q, where precision/recall may be 8 | # just above 1, leading to errors for the f_beta computation 9 | # 10 | # Copyright 2018 Google LLC & Hwalsuk Lee. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 23 | 24 | 25 | import math 26 | import numpy as np 27 | from sklearn.cluster import MiniBatchKMeans 28 | from tqdm import tqdm 29 | 30 | from utils.sample import sample_latents 31 | from utils.losses import latent_optimise 32 | 33 | import torch 34 | from torch.nn import DataParallel 35 | from torch.nn.parallel import DistributedDataParallel 36 | 37 | 38 | 39 | class precision_recall(object): 40 | def __init__(self,inception_model, device): 41 | self.inception_model = inception_model 42 | self.device = device 43 | self.disable_tqdm = device != 0 44 | 45 | 46 | def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size): 47 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel): 48 | z_dim = gen.module.z_dim 49 | num_classes = gen.module.num_classes 50 | conditional_strategy = dis.module.conditional_strategy 51 | else: 52 | z_dim = gen.z_dim 53 | num_classes = gen.num_classes 54 | conditional_strategy = dis.conditional_strategy 55 | 56 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, self.device) 57 | 58 | if latent_op: 59 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha, 60 | latent_op_beta, False, self.device) 61 | 62 | with torch.no_grad(): 63 | batch_images = gen(zs, fake_labels, evaluation=True) 64 | 65 | return batch_images 66 | 67 | 68 | def inception_softmax(self, batch_images): 69 | with torch.no_grad(): 70 | embeddings, logits = self.inception_model(batch_images) 71 | return embeddings 72 | 73 | 74 | def cluster_into_bins(self, real_embeds, fake_embeds, num_clusters): 75 | representations = np.vstack([real_embeds, fake_embeds]) 76 | kmeans = MiniBatchKMeans(n_clusters=num_clusters, n_init=10) 77 | labels = kmeans.fit(representations).labels_ 78 | 79 | real_labels = labels[:len(real_embeds)] 80 | fake_labels = labels[len(real_embeds):] 81 | 82 | real_density = np.histogram(real_labels, bins=num_clusters, range=[0, num_clusters], density=True)[0] 83 | fake_density = np.histogram(fake_labels, bins=num_clusters, range=[0, num_clusters], density=True)[0] 84 | 85 | return real_density, fake_density 86 | 87 | 88 | def compute_PRD(self, real_density, 
fake_density, num_angles=1001, epsilon=1e-10): 89 | angles = np.linspace(epsilon, np.pi/2 - epsilon, num=num_angles) 90 | slopes = np.tan(angles) 91 | 92 | slopes_2d = np.expand_dims(slopes, 1) 93 | 94 | real_density_2d = np.expand_dims(real_density, 0) 95 | fake_density_2d = np.expand_dims(fake_density, 0) 96 | 97 | precision = np.minimum(real_density_2d*slopes_2d, fake_density_2d).sum(axis=1) 98 | recall = precision / slopes 99 | 100 | max_val = max(np.max(precision), np.max(recall)) 101 | if max_val > 1.001: 102 | raise ValueError('Detected value > 1.001, this should not happen.') 103 | precision = np.clip(precision, 0, 1) 104 | recall = np.clip(recall, 0, 1) 105 | 106 | return precision, recall 107 | 108 | def compute_precision_recall(self, dataloader, gen, dis, num_generate, num_runs, num_clusters, truncated_factor, prior, 109 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size, device, num_angles=1001): 110 | dataset_iter = iter(dataloader) 111 | n_batches = int(math.ceil(float(num_generate) / float(batch_size))) 112 | for i in tqdm(range(n_batches), disable = self.disable_tqdm): 113 | real_images, real_labels = next(dataset_iter) 114 | real_images, real_labels = real_images.to(self.device), real_labels.to(self.device) 115 | fake_images = self.generate_images(gen, dis, truncated_factor, prior, latent_op, latent_op_step, 116 | latent_op_alpha, latent_op_beta, batch_size) 117 | 118 | real_embed = self.inception_softmax(real_images).detach().cpu().numpy() 119 | fake_embed = self.inception_softmax(fake_images).detach().cpu().numpy() 120 | if i == 0: 121 | real_embeds = np.array(real_embed, dtype=np.float64) 122 | fake_embeds = np.array(fake_embed, dtype=np.float64) 123 | else: 124 | real_embeds = np.concatenate([real_embeds, np.array(real_embed, dtype=np.float64)], axis=0) 125 | fake_embeds = np.concatenate([fake_embeds, np.array(fake_embed, dtype=np.float64)], axis=0) 126 | 127 | real_embeds = real_embeds[:num_generate] 128 | fake_embeds = fake_embeds[:num_generate] 129 | 130 | precisions = [] 131 | recalls = [] 132 | for _ in range(num_runs): 133 | real_density, fake_density = self.cluster_into_bins(real_embeds, fake_embeds, num_clusters) 134 | precision, recall = self.compute_PRD(real_density, fake_density, num_angles=num_angles) 135 | precisions.append(precision) 136 | recalls.append(recall) 137 | 138 | mean_precision = np.mean(precisions, axis=0) 139 | mean_recall = np.mean(recalls, axis=0) 140 | 141 | return mean_precision, mean_recall 142 | 143 | 144 | def compute_f_beta(self, precision, recall, beta=1, epsilon=1e-10): 145 | return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall + epsilon) 146 | 147 | 148 | def calculate_f_beta_score(dataloader, gen, dis, inception_model, num_generate, num_runs, num_clusters, beta, truncated_factor, 149 | prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger): 150 | inception_model.eval() 151 | 152 | batch_size = dataloader.batch_size 153 | PR = precision_recall(inception_model, device=device) 154 | if device == 0: logger.info("Calculate F_beta Score....") 155 | precision, recall = PR.compute_precision_recall(dataloader, gen, dis, num_generate, num_runs, num_clusters, truncated_factor, 156 | prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size, device) 157 | 158 | if not ((precision >= 0).all() and (precision <= 1).all()): 159 | raise ValueError('All values in precision must be in [0, 1].') 160 | if not ((recall >= 0).all() and (recall <= 1).all()): 
161 | raise ValueError('All values in recall must be in [0, 1].') 162 | if beta <= 0: 163 | raise ValueError('Given parameter beta %s must be positive.' % str(beta)) 164 | 165 | f_beta = np.max(PR.compute_f_beta(precision, recall, beta=beta)) 166 | f_beta_inv = np.max(PR.compute_f_beta(precision, recall, beta=1/beta)) 167 | return precision, recall, f_beta, f_beta_inv 168 | -------------------------------------------------------------------------------- /src/metrics/IS.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/metrics/IS.py 6 | 7 | import math 8 | from tqdm import tqdm 9 | 10 | from utils.sample import sample_latents 11 | from utils.losses import latent_optimise 12 | 13 | import torch 14 | from torch.nn import DataParallel 15 | from torch.nn.parallel import DistributedDataParallel 16 | 17 | 18 | 19 | class evaluator(object): 20 | def __init__(self,inception_model, device): 21 | self.inception_model = inception_model 22 | self.device = device 23 | self.disable_tqdm = device != 0 24 | 25 | def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size): 26 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel): 27 | z_dim = gen.module.z_dim 28 | num_classes = gen.module.num_classes 29 | conditional_strategy = dis.module.conditional_strategy 30 | else: 31 | z_dim = gen.z_dim 32 | num_classes = gen.num_classes 33 | conditional_strategy = dis.conditional_strategy 34 | 35 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, self.device) 36 | 37 | if latent_op: 38 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha, 39 | latent_op_beta, False, self.device) 40 | 41 | with torch.no_grad(): 42 | batch_images = gen(zs, fake_labels, evaluation=True) 43 | 44 | return batch_images 45 | 46 | 47 | def inception_softmax(self, batch_images): 48 | with torch.no_grad(): 49 | embeddings, logits = self.inception_model(batch_images) 50 | y = torch.nn.functional.softmax(logits, dim=1) 51 | return y 52 | 53 | 54 | def kl_scores(self, ys, splits): 55 | scores = [] 56 | n_images = ys.shape[0] 57 | with torch.no_grad(): 58 | for j in range(splits): 59 | part = ys[(j*n_images//splits): ((j+1)*n_images//splits), :] 60 | kl = part * (torch.log(part) - torch.log(torch.unsqueeze(torch.mean(part, 0), 0))) 61 | kl = torch.mean(torch.sum(kl, 1)) 62 | kl = torch.exp(kl) 63 | scores.append(kl.unsqueeze(0)) 64 | scores = torch.cat(scores, 0) 65 | m_scores = torch.mean(scores).detach().cpu().numpy() 66 | m_std = torch.std(scores).detach().cpu().numpy() 67 | return m_scores, m_std 68 | 69 | 70 | def eval_gen(self, gen, dis, n_eval, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, 71 | latent_op_beta, split, batch_size): 72 | ys = [] 73 | n_batches = int(math.ceil(float(n_eval) / float(batch_size))) 74 | for i in tqdm(range(n_batches), disable=self.disable_tqdm): 75 | batch_images = self.generate_images(gen, dis, truncated_factor, prior, latent_op, latent_op_step, 76 | latent_op_alpha, latent_op_beta, batch_size) 77 | y = self.inception_softmax(batch_images) 78 | ys.append(y) 79 | 80 | with torch.no_grad(): 81 | ys = torch.cat(ys, 0) 82 | m_scores, m_std = 
self.kl_scores(ys[:n_eval], splits=split) 83 | return m_scores, m_std 84 | 85 | 86 | def eval_dataset(self, dataloader, splits): 87 | batch_size = dataloader.batch_size 88 | n_images = len(dataloader.dataset) 89 | n_batches = int(math.ceil(float(n_images)/float(batch_size))) 90 | dataset_iter = iter(dataloader) 91 | ys = [] 92 | for i in tqdm(range(n_batches), disable=self.disable_tqdm): 93 | feed_list = next(dataset_iter) 94 | batch_images, batch_labels = feed_list[0], feed_list[1] 95 | batch_images = batch_images.to(self.device) 96 | y = self.inception_softmax(batch_images) 97 | ys.append(y) 98 | 99 | with torch.no_grad(): 100 | ys = torch.cat(ys, 0) 101 | m_scores, m_std = self.kl_scores(ys, splits=splits) 102 | return m_scores, m_std 103 | 104 | 105 | def calculate_incep_score(dataloader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 106 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, splits, device, logger): 107 | inception_model.eval() 108 | 109 | batch_size = dataloader.batch_size 110 | evaluator_instance = evaluator(inception_model, device=device) 111 | if device == 0: logger.info("Calculating Inception Score....") 112 | kl_score, kl_std = evaluator_instance.eval_gen(generator, discriminator, n_generate, truncated_factor, prior, latent_op, 113 | latent_op_step, latent_op_alpha, latent_op_beta, splits, batch_size) 114 | return kl_score, kl_std 115 | -------------------------------------------------------------------------------- /src/metrics/IntraFID.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | from torch.nn import DataParallel 8 | from torch.nn.parallel import DistributedDataParallel 9 | 10 | from metrics.FID import generate_images 11 | from metrics.FID import calculate_frechet_distance 12 | 13 | 14 | def get_activations_with_label(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, is_generate, 15 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=False, run_name=None): 16 | """Calculates the activations of the pool_3 layer for all images. 17 | Params: 18 | -- data_loader : data_loader of training images 19 | -- generator : instance of GANs' generator 20 | -- inception_model : Instance of inception model 21 | 22 | Returns: 23 | -- A numpy array of dimension (num images, dims) that contains the 24 | activations of the given tensor when feeding inception with the 25 | query tensor. 
26 | """ 27 | if is_generate is True: 28 | batch_size = data_loader.batch_size 29 | total_instance = n_generate 30 | n_batches = math.ceil(float(total_instance) / float(batch_size)) 31 | else: 32 | batch_size = data_loader.batch_size 33 | total_instance = len(data_loader.dataset) 34 | n_batches = math.ceil(float(total_instance) / float(batch_size)) 35 | data_iter = iter(data_loader) 36 | 37 | num_classes = generator.module.num_classes if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel) else generator.num_classes 38 | pred_arr = np.empty((total_instance, 2048)) 39 | label_arr = [] 40 | 41 | for i in tqdm(range(0, n_batches), disable=tqdm_disable): 42 | start = i*batch_size 43 | end = start + batch_size 44 | if is_generate is True: 45 | images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op, 46 | latent_op_step, latent_op_alpha, latent_op_beta, device) 47 | images = images.to(device) 48 | 49 | with torch.no_grad(): 50 | embeddings, logits = inception_model(images) 51 | 52 | if total_instance >= batch_size: 53 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1) 54 | else: 55 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1) 56 | 57 | total_instance -= images.shape[0] 58 | else: 59 | try: 60 | feed_list = next(data_iter) 61 | images = feed_list[0] 62 | labels = feed_list[1] 63 | images = images.to(device) 64 | with torch.no_grad(): 65 | embeddings, logits = inception_model(images) 66 | 67 | if total_instance >= batch_size: 68 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1) 69 | else: 70 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1) 71 | total_instance -= images.shape[0] 72 | 73 | except StopIteration: 74 | break 75 | label_arr.append(labels.cpu().data.numpy()) 76 | label_arr = np.concatenate(label_arr)[:len(pred_arr)] 77 | return pred_arr, label_arr 78 | 79 | 80 | def calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 81 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name=None): 82 | act, labels = get_activations_with_label(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 83 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name) 84 | num_classes = len(np.unique(labels)) 85 | mu, sigma = [], [] 86 | for i in tqdm(range(num_classes)): 87 | mu.append(np.mean(act[labels == i], axis=0)) 88 | sigma.append(np.cov(act[labels == i], rowvar=False)) 89 | return mu, sigma 90 | 91 | 92 | def calculate_intra_fid_score(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 93 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger, pre_cal_mean=None, pre_cal_std=None, run_name=None): 94 | disable_tqdm = device != 0 95 | inception_model.eval() 96 | 97 | if device == 0: logger.info("Calculating Intra-FID Score....") 98 | if pre_cal_mean is not None and pre_cal_std is not None: 99 | m1, s1 = pre_cal_mean, pre_cal_std 100 | else: 101 | m1, s1 = calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, 102 | prior, False, False, 0, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm) 103 | 104 | m2, s2 = 
calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, 105 | True, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm, run_name=run_name) 106 | 107 | intra_fid = [] 108 | for i in tqdm(range(len(m1))): 109 | intra_fid.append(calculate_frechet_distance(m1[i], s1[i], m2[i], s2[i])) 110 | intra_fid = np.mean(intra_fid) 111 | 112 | return intra_fid, m1, s1 113 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/src/metrics/__init__.py -------------------------------------------------------------------------------- /src/metrics/inception_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | 12 | 13 | # Inception weights ported to Pytorch from 14 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 15 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 16 | 17 | 18 | class InceptionV3(nn.Module): 19 | """Pretrained InceptionV3 network returning feature maps""" 20 | 21 | # Index of default block of inception to return, 22 | # corresponds to output of final average pooling 23 | def __init__(self, 24 | resize_input=True, 25 | normalize_input=False, 26 | requires_grad=False): 27 | """Build pretrained InceptionV3 28 | Parameters 29 | ---------- 30 | resize_input : bool 31 | If true, bilinearly resizes input to width and height 299 before 32 | feeding input to model. As the network without fully connected 33 | layers is fully convolutional, it should be able to handle inputs 34 | of arbitrary size, so resizing might not be strictly needed 35 | normalize_input : bool 36 | If true, scales the input from range (0, 1) to the range the 37 | pretrained Inception network expects, namely (-1, 1) 38 | requires_grad : bool 39 | If true, parameters of the model require gradients. 
Possibly useful 40 | for finetuning the network 41 | """ 42 | super(InceptionV3, self).__init__() 43 | 44 | self.resize_input = resize_input 45 | self.normalize_input = normalize_input 46 | self.blocks = nn.ModuleList() 47 | 48 | state_dict, inception = fid_inception_v3() 49 | 50 | # Block 0: input to maxpool1 51 | block0 = [ 52 | inception.Conv2d_1a_3x3, 53 | inception.Conv2d_2a_3x3, 54 | inception.Conv2d_2b_3x3, 55 | nn.MaxPool2d(kernel_size=3, stride=2) 56 | ] 57 | self.blocks.append(nn.Sequential(*block0)) 58 | 59 | # Block 1: maxpool1 to maxpool2 60 | block1 = [ 61 | inception.Conv2d_3b_1x1, 62 | inception.Conv2d_4a_3x3, 63 | nn.MaxPool2d(kernel_size=3, stride=2) 64 | ] 65 | self.blocks.append(nn.Sequential(*block1)) 66 | 67 | # Block 2: maxpool2 to aux classifier 68 | block2 = [ 69 | inception.Mixed_5b, 70 | inception.Mixed_5c, 71 | inception.Mixed_5d, 72 | inception.Mixed_6a, 73 | inception.Mixed_6b, 74 | inception.Mixed_6c, 75 | inception.Mixed_6d, 76 | inception.Mixed_6e, 77 | ] 78 | self.blocks.append(nn.Sequential(*block2)) 79 | 80 | # Block 3: aux classifier to final avgpool 81 | block3 = [ 82 | inception.Mixed_7a, 83 | inception.Mixed_7b, 84 | inception.Mixed_7c, 85 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 86 | ] 87 | self.blocks.append(nn.Sequential(*block3)) 88 | 89 | with torch.no_grad(): 90 | self.fc = nn.Linear(2048, 1008, bias=True) 91 | self.fc.weight.copy_(state_dict['fc.weight']) 92 | self.fc.bias.copy_(state_dict['fc.bias']) 93 | 94 | for param in self.parameters(): 95 | param.requires_grad = requires_grad 96 | 97 | 98 | def forward(self, inp): 99 | """Get Inception feature maps 100 | Parameters 101 | ---------- 102 | inp : torch.autograd.Variable 103 | Input tensor of shape Bx3xHxW. Values are expected to be in 104 | range (0, 1) 105 | Returns 106 | ------- 107 | List of torch.autograd.Variable, corresponding to the selected output 108 | block, sorted ascending by index 109 | """ 110 | x = inp 111 | 112 | if self.resize_input: 113 | x = F.interpolate(x, 114 | size=(299, 299), 115 | mode='bilinear', 116 | align_corners=False) 117 | 118 | if self.normalize_input: 119 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 120 | 121 | for idx, block in enumerate(self.blocks): 122 | x = block(x) 123 | 124 | x = F.dropout(x, training=False) 125 | x = torch.flatten(x, 1) 126 | logit = self.fc(x) 127 | return x, logit 128 | 129 | 130 | def fid_inception_v3(): 131 | """Build pretrained Inception model for FID computation 132 | The Inception model for FID computation uses a different set of weights 133 | and has a slightly different structure than torchvision's Inception. 134 | This method first constructs torchvision's Inception and then patches the 135 | necessary parts that are different in the FID Inception model. 
136 | """ 137 | inception = models.inception_v3(num_classes=1008, 138 | aux_logits=False, 139 | pretrained=False) 140 | 141 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 142 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 143 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 144 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 145 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 146 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 147 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 148 | inception.Mixed_7b = FIDInceptionE_1(1280) 149 | inception.Mixed_7c = FIDInceptionE_2(2048) 150 | # inception.fc = nn.Linear(2048, 1008, bias=False) 151 | 152 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 153 | inception.load_state_dict(state_dict) 154 | return state_dict, inception 155 | 156 | 157 | class FIDInceptionA(models.inception.InceptionA): 158 | """InceptionA block patched for FID computation""" 159 | def __init__(self, in_channels, pool_features): 160 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 161 | 162 | def forward(self, x): 163 | branch1x1 = self.branch1x1(x) 164 | 165 | branch5x5 = self.branch5x5_1(x) 166 | branch5x5 = self.branch5x5_2(branch5x5) 167 | 168 | branch3x3dbl = self.branch3x3dbl_1(x) 169 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 170 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 171 | 172 | # Patch: Tensorflow's average pool does not use the padded zero's in 173 | # its average calculation 174 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 175 | count_include_pad=False) 176 | branch_pool = self.branch_pool(branch_pool) 177 | 178 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 179 | return torch.cat(outputs, 1) 180 | 181 | 182 | class FIDInceptionC(models.inception.InceptionC): 183 | """InceptionC block patched for FID computation""" 184 | def __init__(self, in_channels, channels_7x7): 185 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 186 | 187 | def forward(self, x): 188 | branch1x1 = self.branch1x1(x) 189 | 190 | branch7x7 = self.branch7x7_1(x) 191 | branch7x7 = self.branch7x7_2(branch7x7) 192 | branch7x7 = self.branch7x7_3(branch7x7) 193 | 194 | branch7x7dbl = self.branch7x7dbl_1(x) 195 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 196 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 197 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 198 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 199 | 200 | # Patch: Tensorflow's average pool does not use the padded zero's in 201 | # its average calculation 202 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 203 | count_include_pad=False) 204 | branch_pool = self.branch_pool(branch_pool) 205 | 206 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 207 | return torch.cat(outputs, 1) 208 | 209 | 210 | class FIDInceptionE_1(models.inception.InceptionE): 211 | """First InceptionE block patched for FID computation""" 212 | def __init__(self, in_channels): 213 | super(FIDInceptionE_1, self).__init__(in_channels) 214 | 215 | def forward(self, x): 216 | branch1x1 = self.branch1x1(x) 217 | 218 | branch3x3 = self.branch3x3_1(x) 219 | branch3x3 = [ 220 | self.branch3x3_2a(branch3x3), 221 | self.branch3x3_2b(branch3x3), 222 | ] 223 | branch3x3 = torch.cat(branch3x3, 1) 224 | 225 | branch3x3dbl = self.branch3x3dbl_1(x) 226 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 227 | branch3x3dbl = [ 228 | 
self.branch3x3dbl_3a(branch3x3dbl), 229 | self.branch3x3dbl_3b(branch3x3dbl), 230 | ] 231 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 232 | 233 | # Patch: Tensorflow's average pool does not use the padded zero's in 234 | # its average calculation 235 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 236 | count_include_pad=False) 237 | branch_pool = self.branch_pool(branch_pool) 238 | 239 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 240 | return torch.cat(outputs, 1) 241 | 242 | 243 | class FIDInceptionE_2(models.inception.InceptionE): 244 | """Second InceptionE block patched for FID computation""" 245 | def __init__(self, in_channels): 246 | super(FIDInceptionE_2, self).__init__(in_channels) 247 | 248 | def forward(self, x): 249 | branch1x1 = self.branch1x1(x) 250 | 251 | branch3x3 = self.branch3x3_1(x) 252 | branch3x3 = [ 253 | self.branch3x3_2a(branch3x3), 254 | self.branch3x3_2b(branch3x3), 255 | ] 256 | branch3x3 = torch.cat(branch3x3, 1) 257 | 258 | branch3x3dbl = self.branch3x3dbl_1(x) 259 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 260 | branch3x3dbl = [ 261 | self.branch3x3dbl_3a(branch3x3dbl), 262 | self.branch3x3dbl_3b(branch3x3dbl), 263 | ] 264 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 265 | 266 | # Patch: The FID Inception model uses max pooling instead of average 267 | # pooling. This is likely an error in this specific Inception 268 | # implementation, as other Inception models use average pooling here 269 | # (which matches the description in the paper). 270 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 271 | branch_pool = self.branch_pool(branch_pool) 272 | 273 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 274 | return torch.cat(outputs, 1) 275 | -------------------------------------------------------------------------------- /src/metrics/prepare_inception_moments.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/metrics/prepare_inception_moments.py 6 | 7 | import numpy as np 8 | import os 9 | 10 | from metrics.FID import calculate_activation_statistics 11 | from metrics.IS import evaluator 12 | 13 | 14 | 15 | def prepare_inception_moments(dataloader, eval_mode, generator, inception_model, splits, run_name, logger, device): 16 | dataset_name = dataloader.dataset.dataset_name 17 | inception_model.eval() 18 | 19 | save_path = os.path.abspath(os.path.join("./data", dataset_name + "_" + eval_mode +'_' + 'inception_moments.npz')) 20 | is_file = os.path.isfile(save_path) 21 | 22 | if is_file: 23 | mu = np.load(save_path)['mu'] 24 | sigma = np.load(save_path)['sigma'] 25 | else: 26 | if device == 0: logger.info('Calculate moments of {} dataset'.format(eval_mode)) 27 | mu, sigma = calculate_activation_statistics(data_loader=dataloader, 28 | generator=generator, 29 | discriminator=None, 30 | inception_model=inception_model, 31 | n_generate=None, 32 | truncated_factor=None, 33 | prior=None, 34 | is_generate=False, 35 | latent_op=False, 36 | latent_op_step=None, 37 | latent_op_alpha=None, 38 | latent_op_beta=None, 39 | device=device, 40 | tqdm_disable=False, 41 | run_name=run_name) 42 | 43 | if device == 0: logger.info('Save calculated means and covariances to disk...') 44 | np.savez(save_path, **{'mu': mu, 'sigma': sigma}) 45 | 46 | if is_file: 47 | pass 48 | 
else: 49 | if device == 0: logger.info('calculate inception score of {} dataset.'.format(eval_mode)) 50 | evaluator_instance = evaluator(inception_model, device=device) 51 | is_score, is_std = evaluator_instance.eval_dataset(dataloader, splits=splits) 52 | if device == 0: logger.info('Inception score={is_score}-Inception_std={is_std}'.format(is_score=is_score, is_std=is_std)) 53 | return mu, sigma 54 | -------------------------------------------------------------------------------- /src/models/dcgan.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # models/dcgan.py 6 | 7 | 8 | from utils.model_ops import * 9 | from utils.misc import * 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | 17 | class GenBlock(nn.Module): 18 | def __init__(self, in_channels, out_channels, g_spectral_norm, activation_fn, conditional_bn, num_classes): 19 | super(GenBlock, self).__init__() 20 | self.conditional_bn = conditional_bn 21 | 22 | if g_spectral_norm: 23 | self.deconv0 = sndeconv2d(in_channels=in_channels, out_channels=out_channels, 24 | kernel_size=4, stride=2, padding=1) 25 | else: 26 | self.deconv0 = deconv2d(in_channels=in_channels, out_channels=out_channels, 27 | kernel_size=4, stride=2, padding=1) 28 | 29 | if self.conditional_bn: 30 | self.bn0 = ConditionalBatchNorm2d(num_features=out_channels, num_classes=num_classes, 31 | spectral_norm=g_spectral_norm) 32 | else: 33 | self.bn0 = batchnorm_2d(in_features=out_channels) 34 | 35 | if activation_fn == "ReLU": 36 | self.activation = nn.ReLU(inplace=True) 37 | elif activation_fn == "Leaky_ReLU": 38 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 39 | elif activation_fn == "ELU": 40 | self.activation = nn.ELU(alpha=1.0, inplace=True) 41 | elif activation_fn == "GELU": 42 | self.activation = nn.GELU() 43 | else: 44 | raise NotImplementedError 45 | 46 | def forward(self, x, label): 47 | x = self.deconv0(x) 48 | if self.conditional_bn: 49 | x = self.bn0(x, label) 50 | else: 51 | x = self.bn0(x) 52 | out = self.activation(x) 53 | return out 54 | 55 | 56 | class Generator(nn.Module): 57 | """Generator.""" 58 | def __init__(self, z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, 59 | conditional_strategy, num_classes, initialize, G_depth, mixed_precision): 60 | super(Generator, self).__init__() 61 | self.in_dims = [512, 256, 128] 62 | self.out_dims = [256, 128, 64] 63 | 64 | self.z_dim = z_dim 65 | self.num_classes = num_classes 66 | self.mixed_precision = mixed_precision 67 | conditional_bn = True if conditional_strategy in ["ACGAN", "ProjGAN", "ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN", "ECGAN"] else False 68 | 69 | if g_spectral_norm: 70 | self.linear0 = snlinear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4) 71 | else: 72 | self.linear0 = linear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4) 73 | 74 | self.blocks = [] 75 | for index in range(len(self.in_dims)): 76 | self.blocks += [[GenBlock(in_channels=self.in_dims[index], 77 | out_channels=self.out_dims[index], 78 | g_spectral_norm=g_spectral_norm, 79 | activation_fn=activation_fn, 80 | conditional_bn=conditional_bn, 81 | num_classes=self.num_classes)]] 82 | 83 | if index+1 == attention_after_nth_gen_block and 
attention is True: 84 | self.blocks += [[Self_Attn(self.out_dims[index], g_spectral_norm)]] 85 | 86 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) 87 | 88 | if g_spectral_norm: 89 | self.conv4 = snconv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) 90 | else: 91 | self.conv4 = conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) 92 | 93 | self.tanh = nn.Tanh() 94 | 95 | # Weight init 96 | if initialize is not False: 97 | init_weights(self.modules, initialize) 98 | 99 | def forward(self, z, label, evaluation=False): 100 | with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: 101 | act = self.linear0(z) 102 | act = act.view(-1, self.in_dims[0], 4, 4) 103 | for index, blocklist in enumerate(self.blocks): 104 | for block in blocklist: 105 | if isinstance(block, Self_Attn): 106 | act = block(act) 107 | else: 108 | act = block(act, label) 109 | act = self.conv4(act) 110 | out = self.tanh(act) 111 | return out 112 | 113 | 114 | class DiscBlock(nn.Module): 115 | def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn): 116 | super(DiscBlock, self).__init__() 117 | self.d_spectral_norm = d_spectral_norm 118 | 119 | if d_spectral_norm: 120 | self.conv0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) 121 | self.conv1 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1) 122 | else: 123 | self.conv0 = conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) 124 | self.conv1 = conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1) 125 | 126 | self.bn0 = batchnorm_2d(in_features=out_channels) 127 | self.bn1 = batchnorm_2d(in_features=out_channels) 128 | 129 | if activation_fn == "ReLU": 130 | self.activation = nn.ReLU(inplace=True) 131 | elif activation_fn == "Leaky_ReLU": 132 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 133 | elif activation_fn == "ELU": 134 | self.activation = nn.ELU(alpha=1.0, inplace=True) 135 | elif activation_fn == "GELU": 136 | self.activation = nn.GELU() 137 | else: 138 | raise NotImplementedError 139 | 140 | def forward(self, x): 141 | x = self.conv0(x) 142 | if self.d_spectral_norm is False: 143 | x = self.bn0(x) 144 | x = self.activation(x) 145 | x = self.conv1(x) 146 | if self.d_spectral_norm is False: 147 | x = self.bn1(x) 148 | out = self.activation(x) 149 | return out 150 | 151 | 152 | class Discriminator(nn.Module): 153 | """Discriminator.""" 154 | def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, 155 | hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, initialize, D_depth, mixed_precision): 156 | super(Discriminator, self).__init__() 157 | self.in_dims = [3] + [64, 128] 158 | self.out_dims = [64, 128, 256] 159 | 160 | self.d_spectral_norm = d_spectral_norm 161 | self.conditional_strategy = conditional_strategy 162 | self.num_classes = num_classes 163 | self.nonlinear_embed = nonlinear_embed 164 | self.normalize_embed = normalize_embed 165 | self.mixed_precision = mixed_precision 166 | 167 | self.blocks = [] 168 | for index in range(len(self.in_dims)): 169 | self.blocks += [[DiscBlock(in_channels=self.in_dims[index], 170 | out_channels=self.out_dims[index], 171 | 
d_spectral_norm=d_spectral_norm, 172 | activation_fn=activation_fn)]] 173 | 174 | if index+1 == attention_after_nth_dis_block and attention is True: 175 | self.blocks += [[Self_Attn(self.out_dims[index], d_spectral_norm)]] 176 | 177 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) 178 | 179 | if self.d_spectral_norm: 180 | self.conv = snconv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) 181 | else: 182 | self.conv = conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) 183 | self.bn = batchnorm_2d(in_features=512) 184 | 185 | if activation_fn == "ReLU": 186 | self.activation = nn.ReLU(inplace=True) 187 | elif activation_fn == "Leaky_ReLU": 188 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 189 | elif activation_fn == "ELU": 190 | self.activation = nn.ELU(alpha=1.0, inplace=True) 191 | elif activation_fn == "GELU": 192 | self.activation = nn.GELU() 193 | else: 194 | raise NotImplementedError 195 | 196 | if d_spectral_norm: 197 | if self.conditional_strategy == 'ECGAN': 198 | self.linear1 = snlinear(in_features=512, out_features=num_classes) 199 | self.linear2 = snlinear(in_features=512, out_features=hypersphere_dim) 200 | if self.nonlinear_embed: 201 | self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim) 202 | self.embedding = sn_embedding(num_classes, hypersphere_dim) 203 | else: 204 | self.linear1 = snlinear(in_features=512, out_features=1) 205 | if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: 206 | self.linear2 = snlinear(in_features=512, out_features=hypersphere_dim) 207 | if self.nonlinear_embed: 208 | self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim) 209 | self.embedding = sn_embedding(num_classes, hypersphere_dim) 210 | elif self.conditional_strategy == 'ProjGAN': 211 | self.embedding = sn_embedding(num_classes, 512) 212 | elif self.conditional_strategy == 'ACGAN': 213 | self.linear4 = snlinear(in_features=512, out_features=num_classes) 214 | else: 215 | pass 216 | else: 217 | if self.conditional_strategy == 'ECGAN': 218 | self.linear1 = linear(in_features=512, out_features=num_classes) 219 | if self.contrastive_lambda: 220 | self.linear2 = linear(in_features=512, out_features=hypersphere_dim) 221 | if self.nonlinear_embed: 222 | self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim) 223 | self.embedding = embedding(num_classes, hypersphere_dim) 224 | else: 225 | self.linear1 = linear(in_features=512, out_features=1) 226 | if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: 227 | self.linear2 = linear(in_features=512, out_features=hypersphere_dim) 228 | if self.nonlinear_embed: 229 | self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim) 230 | self.embedding = embedding(num_classes, hypersphere_dim) 231 | elif self.conditional_strategy == 'ProjGAN': 232 | self.embedding = embedding(num_classes, 512) 233 | elif self.conditional_strategy == 'ACGAN': 234 | self.linear4 = linear(in_features=512, out_features=num_classes) 235 | else: 236 | pass 237 | 238 | # Weight init 239 | if initialize is not False: 240 | init_weights(self.modules, initialize) 241 | 242 | 243 | def forward(self, x, label, evaluation=False): 244 | with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: 245 | h = x 246 | for index, blocklist in enumerate(self.blocks): 247 | for block in 
blocklist: 248 | h = block(h) 249 | h = self.conv(h) 250 | if self.d_spectral_norm is False: 251 | h = self.bn(h) 252 | h = self.activation(h) 253 | h = torch.sum(h, dim=[2, 3]) 254 | 255 | if self.conditional_strategy == 'no': 256 | authen_output = torch.squeeze(self.linear1(h)) 257 | return authen_output 258 | 259 | elif self.conditional_strategy == 'ECGAN': 260 | cls_output = self.linear1(h) 261 | cond_output = (cls_output * F.one_hot(label.squeeze(), num_classes=self.num_classes)).sum(dim=1) 262 | uncond_output = torch.logsumexp(cls_output, dim=1) 263 | cls_proxy = self.embedding(label) 264 | cls_embed = self.linear2(h) 265 | if self.nonlinear_embed: 266 | cls_embed = self.linear3(self.activation(cls_embed)) 267 | if self.normalize_embed: 268 | cls_proxy = F.normalize(cls_proxy, dim=1) 269 | cls_embed = F.normalize(cls_embed, dim=1) 270 | return cls_output, cond_output, uncond_output, cls_proxy, cls_embed 271 | 272 | elif self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: 273 | authen_output = torch.squeeze(self.linear1(h)) 274 | cls_proxy = self.embedding(label) 275 | cls_embed = self.linear2(h) 276 | if self.nonlinear_embed: 277 | cls_embed = self.linear3(self.activation(cls_embed)) 278 | if self.normalize_embed: 279 | cls_proxy = F.normalize(cls_proxy, dim=1) 280 | cls_embed = F.normalize(cls_embed, dim=1) 281 | return cls_proxy, cls_embed, authen_output 282 | 283 | elif self.conditional_strategy == 'ProjGAN': 284 | authen_output = torch.squeeze(self.linear1(h)) 285 | proj = torch.sum(torch.mul(self.embedding(label), h), 1) 286 | return authen_output + proj 287 | 288 | elif self.conditional_strategy == 'ACGAN': 289 | authen_output = torch.squeeze(self.linear1(h)) 290 | cls_output = self.linear4(h) 291 | return cls_output, authen_output 292 | 293 | else: 294 | raise NotImplementedError 295 | -------------------------------------------------------------------------------- /src/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : batchnorm_reimpl.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 
33 | """ 34 | 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.init as init 39 | 40 | __all__ = ['BatchNorm2dReimpl'] 41 | 42 | 43 | class BatchNorm2dReimpl(nn.Module): 44 | """ 45 | A re-implementation of batch normalization, used for testing the numerical 46 | stability. 47 | 48 | Author: acgtyrant 49 | See also: 50 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 51 | """ 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 53 | super().__init__() 54 | 55 | self.num_features = num_features 56 | self.eps = eps 57 | self.momentum = momentum 58 | self.weight = nn.Parameter(torch.empty(num_features)) 59 | self.bias = nn.Parameter(torch.empty(num_features)) 60 | self.register_buffer('running_mean', torch.zeros(num_features)) 61 | self.register_buffer('running_var', torch.ones(num_features)) 62 | self.reset_parameters() 63 | 64 | def reset_running_stats(self): 65 | self.running_mean.zero_() 66 | self.running_var.fill_(1) 67 | 68 | def reset_parameters(self): 69 | self.reset_running_stats() 70 | init.uniform_(self.weight) 71 | init.zeros_(self.bias) 72 | 73 | def forward(self, input_): 74 | batchsize, channels, height, width = input_.size() 75 | numel = batchsize * height * width 76 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 77 | sum_ = input_.sum(1) 78 | sum_of_square = input_.pow(2).sum(1) 79 | mean = sum_ / numel 80 | sumvar = sum_of_square - sum_ * mean 81 | 82 | self.running_mean = ( 83 | (1 - self.momentum) * self.running_mean 84 | + self.momentum * mean.detach() 85 | ) 86 | unbias_var = sumvar / (numel - 1) 87 | self.running_var = ( 88 | (1 - self.momentum) * self.running_var 89 | + self.momentum * unbias_var.detach() 90 | ) 91 | 92 | bias_var = sumvar / numel 93 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 94 | output = ( 95 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 96 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 97 | 98 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 99 | 100 | -------------------------------------------------------------------------------- /src/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : comm.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import queue 37 | import collections 38 | import threading 39 | 40 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 41 | 42 | 43 | class FutureResult(object): 44 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 45 | 46 | def __init__(self): 47 | self._result = None 48 | self._lock = threading.Lock() 49 | self._cond = threading.Condition(self._lock) 50 | 51 | def put(self, result): 52 | with self._lock: 53 | assert self._result is None, 'Previous result has\'t been fetched.' 54 | self._result = result 55 | self._cond.notify() 56 | 57 | def get(self): 58 | with self._lock: 59 | if self._result is None: 60 | self._cond.wait() 61 | 62 | res = self._result 63 | self._result = None 64 | return res 65 | 66 | 67 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 68 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 69 | 70 | 71 | class SlavePipe(_SlavePipeBase): 72 | """Pipe for master-slave communication.""" 73 | 74 | def run_slave(self, msg): 75 | self.queue.put((self.identifier, msg)) 76 | ret = self.result.get() 77 | self.queue.put(True) 78 | return ret 79 | 80 | 81 | class SyncMaster(object): 82 | """An abstract `SyncMaster` object. 83 | 84 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 85 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 86 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 87 | and passed to a registered callback. 88 | - After receiving the messages, the master device should gather the information and determine to message passed 89 | back to each slave devices. 90 | """ 91 | 92 | def __init__(self, master_callback): 93 | """ 94 | 95 | Args: 96 | master_callback: a callback to be invoked after having collected messages from slave devices. 97 | """ 98 | self._master_callback = master_callback 99 | self._queue = queue.Queue() 100 | self._registry = collections.OrderedDict() 101 | self._activated = False 102 | 103 | def __getstate__(self): 104 | return {'master_callback': self._master_callback} 105 | 106 | def __setstate__(self, state): 107 | self.__init__(state['master_callback']) 108 | 109 | def register_slave(self, identifier): 110 | """ 111 | Register an slave device. 112 | 113 | Args: 114 | identifier: an identifier, usually is the device id. 115 | 116 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 117 | 118 | """ 119 | if self._activated: 120 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 121 | self._activated = False 122 | self._registry.clear() 123 | future = FutureResult() 124 | self._registry[identifier] = _MasterRegistry(future) 125 | return SlavePipe(identifier, self._queue, future) 126 | 127 | def run_master(self, master_msg): 128 | """ 129 | Main entry for the master device in each forward pass. 130 | The messages were first collected from each devices (including the master device), and then 131 | an callback will be invoked to compute the message to be sent back to each devices 132 | (including the master device). 
133 | 134 | Args: 135 | master_msg: the message that the master want to send to itself. This will be placed as the first 136 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 137 | 138 | Returns: the message to be sent back to the master device. 139 | 140 | """ 141 | self._activated = True 142 | 143 | intermediates = [(0, master_msg)] 144 | for i in range(self.nr_slaves): 145 | intermediates.append(self._queue.get()) 146 | 147 | results = self._master_callback(intermediates) 148 | assert results[0][0] == 0, 'The first result should belongs to the master.' 149 | 150 | for i, res in results: 151 | if i == 0: 152 | continue 153 | self._registry[i].result.put(res) 154 | 155 | for i in range(self.nr_slaves): 156 | assert self._queue.get() is True 157 | 158 | return results[0][1] 159 | 160 | @property 161 | def nr_slaves(self): 162 | return len(self._registry) 163 | -------------------------------------------------------------------------------- /src/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : replicate.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import functools 37 | 38 | from torch.nn.parallel.data_parallel import DataParallel 39 | 40 | __all__ = [ 41 | 'CallbackContext', 42 | 'execute_replication_callbacks', 43 | 'DataParallelWithCallback', 44 | 'patch_replication_callback' 45 | ] 46 | 47 | 48 | class CallbackContext(object): 49 | pass 50 | 51 | 52 | def execute_replication_callbacks(modules): 53 | """ 54 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 55 | 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Note that, as all modules are isomorphism, we assign each sub-module with a context 59 | (shared among multiple copies of this module on different devices). 60 | Through this context, different copies can share some information. 
61 | 62 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 63 | of any slave copies. 64 | """ 65 | master_copy = modules[0] 66 | nr_modules = len(list(master_copy.modules())) 67 | ctxs = [CallbackContext() for _ in range(nr_modules)] 68 | 69 | for i, module in enumerate(modules): 70 | for j, m in enumerate(module.modules()): 71 | if hasattr(m, '__data_parallel_replicate__'): 72 | m.__data_parallel_replicate__(ctxs[j], i) 73 | 74 | 75 | class DataParallelWithCallback(DataParallel): 76 | """ 77 | Data Parallel with a replication callback. 78 | 79 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 80 | original `replicate` function. 81 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 82 | 83 | Examples: 84 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 85 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 86 | # sync_bn.__data_parallel_replicate__ will be invoked. 87 | """ 88 | 89 | def replicate(self, module, device_ids): 90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | 95 | def patch_replication_callback(data_parallel): 96 | """ 97 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 98 | Useful when you have customized `DataParallel` implementation. 99 | 100 | Examples: 101 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 102 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 103 | > patch_replication_callback(sync_bn) 104 | # this is equivalent to 105 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 106 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 107 | """ 108 | 109 | assert isinstance(data_parallel, DataParallel) 110 | 111 | old_replicate = data_parallel.replicate 112 | 113 | @functools.wraps(old_replicate) 114 | def new_replicate(module, device_ids): 115 | modules = old_replicate(module, device_ids) 116 | execute_replication_callbacks(modules) 117 | return modules 118 | 119 | data_parallel.replicate = new_replicate 120 | -------------------------------------------------------------------------------- /src/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : unittest.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 
25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import unittest 37 | import torch 38 | 39 | 40 | class TorchTestCase(unittest.TestCase): 41 | def assertTensorClose(self, x, y): 42 | adiff = float((x - y).abs().max()) 43 | if (y == 0).all(): 44 | rdiff = 'NaN' 45 | else: 46 | rdiff = float((adiff / y).abs().max()) 47 | 48 | message = ( 49 | 'Tensor close check failed\n' 50 | 'adiff={}\n' 51 | 'rdiff={}\n' 52 | ).format(adiff, rdiff) 53 | self.assertTrue(torch.allclose(x, y), message) 54 | 55 | -------------------------------------------------------------------------------- /src/utils/ada.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | 26 | import math 27 | 28 | from utils.ada_op import upfirdn2d 29 | 30 | import torch 31 | from torch.nn import functional as F 32 | 33 | 34 | 35 | SYM6 = ( 36 | 0.015404109327027373, 37 | 0.0034907120842174702, 38 | -0.11799011114819057, 39 | -0.048311742585633, 40 | 0.4910559419267466, 41 | 0.787641141030194, 42 | 0.3379294217276218, 43 | -0.07263752278646252, 44 | -0.021060292512300564, 45 | 0.04472490177066578, 46 | 0.0017677118642428036, 47 | -0.007800708325034148, 48 | ) 49 | 50 | 51 | def translate_mat(t_x, t_y): 52 | batch = t_x.shape[0] 53 | 54 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 55 | translate = torch.stack((t_x, t_y), 1) 56 | mat[:, :2, 2] = translate 57 | 58 | return mat 59 | 60 | 61 | def rotate_mat(theta): 62 | batch = theta.shape[0] 63 | 64 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 65 | sin_t = torch.sin(theta) 66 | cos_t = torch.cos(theta) 67 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) 68 | mat[:, :2, :2] = rot 69 | 70 | return mat 71 | 72 | 73 | def scale_mat(s_x, s_y): 74 | batch = s_x.shape[0] 75 | 76 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 77 | mat[:, 0, 0] = s_x 78 | mat[:, 1, 1] = s_y 79 | 80 | return mat 81 | 82 | 83 | def translate3d_mat(t_x, t_y, t_z): 84 | batch = t_x.shape[0] 85 | 86 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 87 | translate = torch.stack((t_x, t_y, t_z), 1) 88 | mat[:, :3, 3] = translate 89 | 90 | return mat 91 | 92 | 93 | def rotate3d_mat(axis, theta): 94 | batch = theta.shape[0] 95 | 96 | u_x, u_y, u_z = axis 97 | 98 | eye = torch.eye(3).unsqueeze(0) 99 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) 100 | outer = torch.tensor(axis) 101 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0) 102 | 103 | sin_t = torch.sin(theta).view(-1, 1, 1) 104 | cos_t = torch.cos(theta).view(-1, 1, 1) 105 | 106 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer 107 | 108 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 109 | eye_4[:, :3, :3] = rot 110 | 111 | return eye_4 112 | 113 | 114 | def scale3d_mat(s_x, s_y, s_z): 115 | batch = s_x.shape[0] 116 | 117 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 118 | mat[:, 0, 0] = s_x 119 | mat[:, 1, 1] = s_y 120 | mat[:, 2, 2] = s_z 121 | 122 | return mat 123 | 124 | 125 | def luma_flip_mat(axis, i): 126 | batch = i.shape[0] 127 | 128 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 129 | axis = torch.tensor(axis + (0,)) 130 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) 131 | 132 | return eye - flip 133 | 134 | 135 | def saturation_mat(axis, i): 136 | batch = i.shape[0] 137 | 138 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 139 | axis = torch.tensor(axis + (0,)) 140 | axis = torch.ger(axis, axis) 141 | saturate = axis + (eye - axis) * i.view(-1, 1, 1) 142 | 143 | return saturate 144 | 145 | 146 | def lognormal_sample(size, mean=0, std=1): 147 | return torch.empty(size).log_normal_(mean=mean, std=std) 148 | 149 | 150 | def category_sample(size, categories): 151 | category = torch.tensor(categories) 152 | sample = torch.randint(high=len(categories), size=(size,)) 153 | 154 | return category[sample] 155 | 156 | 157 | def uniform_sample(size, low, high): 158 | return torch.empty(size).uniform_(low, high) 159 | 160 | 161 | def normal_sample(size, mean=0, std=1): 162 | return torch.empty(size).normal_(mean, std) 163 | 164 | 165 | def bernoulli_sample(size, p): 166 | return torch.empty(size).bernoulli_(p) 167 | 168 | 169 | def random_mat_apply(p, 
transform, prev, eye): 170 | size = transform.shape[0] 171 | select = bernoulli_sample(size, p).view(size, 1, 1) 172 | select_transform = select * transform + (1 - select) * eye 173 | 174 | return select_transform @ prev 175 | 176 | 177 | def sample_affine(p, size, height, width): 178 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) 179 | eye = G 180 | 181 | # flip 182 | param = category_sample(size, (0, 1)) 183 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) 184 | G = random_mat_apply(p, Gc, G, eye) 185 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') 186 | 187 | # 90 rotate 188 | param = category_sample(size, (0, 3)) 189 | Gc = rotate_mat(-math.pi / 2 * param) 190 | G = random_mat_apply(p, Gc, G, eye) 191 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') 192 | 193 | # integer translate 194 | param = uniform_sample(size, -0.125, 0.125) 195 | param_height = torch.round(param * height) / height 196 | param_width = torch.round(param * width) / width 197 | Gc = translate_mat(param_width, param_height) 198 | G = random_mat_apply(p, Gc, G, eye) 199 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') 200 | 201 | # isotropic scale 202 | param = lognormal_sample(size, std=0.2 * math.log(2)) 203 | Gc = scale_mat(param, param) 204 | G = random_mat_apply(p, Gc, G, eye) 205 | # print('isotropic scale', G, scale_mat(param, param), sep='\n') 206 | 207 | p_rot = 1 - math.sqrt(1 - p) 208 | 209 | # pre-rotate 210 | param = uniform_sample(size, -math.pi, math.pi) 211 | Gc = rotate_mat(-param) 212 | G = random_mat_apply(p_rot, Gc, G, eye) 213 | # print('pre-rotate', G, rotate_mat(-param), sep='\n') 214 | 215 | # anisotropic scale 216 | param = lognormal_sample(size, std=0.2 * math.log(2)) 217 | Gc = scale_mat(param, 1 / param) 218 | G = random_mat_apply(p, Gc, G, eye) 219 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') 220 | 221 | # post-rotate 222 | param = uniform_sample(size, -math.pi, math.pi) 223 | Gc = rotate_mat(-param) 224 | G = random_mat_apply(p_rot, Gc, G, eye) 225 | # print('post-rotate', G, rotate_mat(-param), sep='\n') 226 | 227 | # fractional translate 228 | param = normal_sample(size, std=0.125) 229 | Gc = translate_mat(param, param) 230 | G = random_mat_apply(p, Gc, G, eye) 231 | # print('fractional translate', G, translate_mat(param, param), sep='\n') 232 | 233 | return G 234 | 235 | 236 | def sample_color(p, size): 237 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) 238 | eye = C 239 | axis_val = 1 / math.sqrt(3) 240 | axis = (axis_val, axis_val, axis_val) 241 | 242 | # brightness 243 | param = normal_sample(size, std=0.2) 244 | Cc = translate3d_mat(param, param, param) 245 | C = random_mat_apply(p, Cc, C, eye) 246 | 247 | # contrast 248 | param = lognormal_sample(size, std=0.5 * math.log(2)) 249 | Cc = scale3d_mat(param, param, param) 250 | C = random_mat_apply(p, Cc, C, eye) 251 | 252 | # luma flip 253 | param = category_sample(size, (0, 1)) 254 | Cc = luma_flip_mat(axis, param) 255 | C = random_mat_apply(p, Cc, C, eye) 256 | 257 | # hue rotation 258 | param = uniform_sample(size, -math.pi, math.pi) 259 | Cc = rotate3d_mat(axis, param) 260 | C = random_mat_apply(p, Cc, C, eye) 261 | 262 | # saturation 263 | param = lognormal_sample(size, std=1 * math.log(2)) 264 | Cc = saturation_mat(axis, param) 265 | C = random_mat_apply(p, Cc, C, eye) 266 | 267 | return C 268 | 269 | 270 | def make_grid(shape, x0, x1, y0, y1, device): 271 | n, c, h, w = shape 272 | grid = torch.empty(n, 
h, w, 3, device=device) 273 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) 274 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) 275 | grid[:, :, :, 2] = 1 276 | 277 | return grid 278 | 279 | 280 | def affine_grid(grid, mat): 281 | n, h, w, _ = grid.shape 282 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) 283 | 284 | 285 | def get_padding(G, height, width): 286 | extreme = ( 287 | G[:, :2, :] 288 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t() 289 | ) 290 | 291 | size = torch.tensor((width, height)) 292 | 293 | pad_low = ( 294 | ((extreme.min(-1).values + 1) * size) 295 | .clamp(max=0) 296 | .abs() 297 | .ceil() 298 | .max(0) 299 | .values.to(torch.int64) 300 | .tolist() 301 | ) 302 | pad_high = ( 303 | (extreme.max(-1).values * size - size) 304 | .clamp(min=0) 305 | .ceil() 306 | .max(0) 307 | .values.to(torch.int64) 308 | .tolist() 309 | ) 310 | 311 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1] 312 | 313 | 314 | def try_sample_affine_and_pad(img, p, pad_k, G=None): 315 | batch, _, height, width = img.shape 316 | 317 | G_try = G 318 | 319 | while True: 320 | if G is None: 321 | G_try = sample_affine(p, batch, height, width) 322 | 323 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding( 324 | torch.inverse(G_try), height, width 325 | ) 326 | 327 | try: 328 | img_pad = F.pad( 329 | img, 330 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k), 331 | mode="reflect", 332 | ) 333 | 334 | except RuntimeError: 335 | continue 336 | 337 | break 338 | 339 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) 340 | 341 | 342 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): 343 | kernel = antialiasing_kernel 344 | len_k = len(kernel) 345 | pad_k = (len_k + 1) // 2 346 | 347 | kernel = torch.as_tensor(kernel) 348 | kernel = torch.ger(kernel, kernel).to(img) 349 | kernel_flip = torch.flip(kernel, (0, 1)) 350 | 351 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( 352 | img, p, pad_k, G 353 | ) 354 | 355 | p_ux1 = pad_x1 356 | p_ux2 = pad_x2 + 1 357 | p_uy1 = pad_y1 358 | p_uy2 = pad_y2 + 1 359 | w_p = img_pad.shape[3] - len_k + 1 360 | h_p = img_pad.shape[2] - len_k + 1 361 | h_o = img.shape[2] 362 | w_o = img.shape[3] 363 | 364 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2) 365 | 366 | grid = make_grid( 367 | img_2x.shape, 368 | -2 * p_ux1 / w_o - 1, 369 | 2 * (w_p - p_ux1) / w_o - 1, 370 | -2 * p_uy1 / h_o - 1, 371 | 2 * (h_p - p_uy1) / h_o - 1, 372 | device=img_2x.device, 373 | ).to(img_2x) 374 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x)) 375 | grid = grid * torch.tensor( 376 | [w_o / w_p, h_o / h_p], device=grid.device 377 | ) + torch.tensor( 378 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device 379 | ) 380 | 381 | img_affine = F.grid_sample( 382 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros" 383 | ) 384 | 385 | img_down = upfirdn2d(img_affine, kernel, down=2) 386 | 387 | end_y = -pad_y2 - 1 388 | if end_y == 0: 389 | end_y = img_down.shape[2] 390 | 391 | end_x = -pad_x2 - 1 392 | if end_x == 0: 393 | end_x = img_down.shape[3] 394 | 395 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x] 396 | 397 | return img, G 398 | 399 | 400 | def apply_color(img, mat): 401 | batch = img.shape[0] 402 | img = img.permute(0, 2, 3, 1) 403 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) 404 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) 405 | img = img @ 
mat_mul + mat_add 406 | img = img.permute(0, 3, 1, 2) 407 | 408 | return img 409 | 410 | 411 | def random_apply_color(img, p, C=None): 412 | if C is None: 413 | C = sample_color(p, img.shape[0]) 414 | 415 | img = apply_color(img, C.to(img)) 416 | 417 | return img, C 418 | 419 | 420 | def augment(img, p, transform_matrix=(None, None)): 421 | img, G = random_apply_affine(img, p, transform_matrix[0]) 422 | img, C = random_apply_color(img, p, transform_matrix[1]) 423 | 424 | return img, (G, C) 425 | -------------------------------------------------------------------------------- /src/utils/ada_op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /src/utils/ada_op/fused_act.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | 26 | import os 27 | 28 | import torch 29 | from torch import nn 30 | from torch.nn import functional as F 31 | from torch.autograd import Function 32 | from torch.utils.cpp_extension import load 33 | 34 | 35 | module_path = os.path.dirname(__file__) 36 | fused = load( 37 | "fused", 38 | sources=[ 39 | os.path.join(module_path, "fused_bias_act.cpp"), 40 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 41 | ], 42 | ) 43 | 44 | 45 | class FusedLeakyReLUFunctionBackward(Function): 46 | @staticmethod 47 | def forward(ctx, grad_output, out, negative_slope, scale): 48 | ctx.save_for_backward(out) 49 | ctx.negative_slope = negative_slope 50 | ctx.scale = scale 51 | 52 | empty = grad_output.new_empty(0) 53 | 54 | grad_input = fused.fused_bias_act( 55 | grad_output, empty, out, 3, 1, negative_slope, scale 56 | ) 57 | 58 | dim = [0] 59 | 60 | if grad_input.ndim > 2: 61 | dim += list(range(2, grad_input.ndim)) 62 | 63 | grad_bias = grad_input.sum(dim).detach() 64 | 65 | return grad_input, grad_bias 66 | 67 | @staticmethod 68 | def backward(ctx, gradgrad_input, gradgrad_bias): 69 | out, = ctx.saved_tensors 70 | gradgrad_out = fused.fused_bias_act( 71 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 72 | ) 73 | 74 | return gradgrad_out, None, None, None 75 | 76 | 77 | class FusedLeakyReLUFunction(Function): 78 | @staticmethod 79 | def forward(ctx, input, bias, negative_slope, scale): 80 | empty = input.new_empty(0) 81 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 82 | ctx.save_for_backward(out) 83 | ctx.negative_slope = negative_slope 84 | ctx.scale = scale 85 | 86 | return out 87 | 88 | @staticmethod 89 | def backward(ctx, grad_output): 90 | out, = ctx.saved_tensors 91 | 92 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 93 | grad_output, out, ctx.negative_slope, ctx.scale 94 | ) 95 | 96 | return grad_input, grad_bias, None, None 97 | 98 | 99 | class FusedLeakyReLU(nn.Module): 100 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 101 | super().__init__() 102 | 103 | self.bias = nn.Parameter(torch.zeros(channel)) 104 | self.negative_slope = negative_slope 105 | self.scale = scale 106 | 107 | def forward(self, input): 108 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 109 | 110 | 111 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 112 | if input.device.type == "cpu": 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 123 | -------------------------------------------------------------------------------- /src/utils/ada_op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall 
be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | */ 24 | 25 | 26 | #include 27 | 28 | 29 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 30 | int act, int grad, float alpha, float scale); 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 37 | int act, int grad, float alpha, float scale) { 38 | CHECK_CUDA(input); 39 | CHECK_CUDA(bias); 40 | 41 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 42 | } 43 | 44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 45 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 46 | } 47 | -------------------------------------------------------------------------------- /src/utils/ada_op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } 100 | -------------------------------------------------------------------------------- /src/utils/ada_op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | */ 24 | 25 | 26 | #include 27 | 28 | 29 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 30 | int up_x, int up_y, int down_x, int down_y, 31 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 32 | 33 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 34 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 35 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 36 | 37 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 38 | int up_x, int up_y, int down_x, int down_y, 39 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 40 | CHECK_CUDA(input); 41 | CHECK_CUDA(kernel); 42 | 43 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 44 | } 45 | 46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 47 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 48 | } 49 | -------------------------------------------------------------------------------- /src/utils/ada_op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | 26 | import os 27 | 28 | import torch 29 | from torch.nn import functional as F 30 | from torch.autograd import Function 31 | from torch.utils.cpp_extension import load 32 | 33 | 34 | module_path = os.path.dirname(__file__) 35 | upfirdn2d_op = load( 36 | "upfirdn2d", 37 | sources=[ 38 | os.path.join(module_path, "upfirdn2d.cpp"), 39 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 40 | ], 41 | ) 42 | 43 | 44 | class UpFirDn2dBackward(Function): 45 | @staticmethod 46 | def forward( 47 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 48 | ): 49 | 50 | up_x, up_y = up 51 | down_x, down_y = down 52 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 53 | 54 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 55 | 56 | grad_input = upfirdn2d_op.upfirdn2d( 57 | grad_output, 58 | grad_kernel, 59 | down_x, 60 | down_y, 61 | up_x, 62 | up_y, 63 | g_pad_x0, 64 | g_pad_x1, 65 | g_pad_y0, 66 | g_pad_y1, 67 | ) 68 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 69 | 70 | ctx.save_for_backward(kernel) 71 | 72 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 73 | 74 | ctx.up_x = up_x 75 | ctx.up_y = up_y 76 | ctx.down_x = down_x 77 | ctx.down_y = down_y 78 | ctx.pad_x0 = pad_x0 79 | ctx.pad_x1 = pad_x1 80 | ctx.pad_y0 = pad_y0 81 | ctx.pad_y1 = pad_y1 82 | ctx.in_size = in_size 83 | ctx.out_size = out_size 84 | 85 | return grad_input 86 | 87 | @staticmethod 88 | def backward(ctx, gradgrad_input): 89 | kernel, = ctx.saved_tensors 90 | 91 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 92 | 93 | gradgrad_out = upfirdn2d_op.upfirdn2d( 94 | gradgrad_input, 95 | kernel, 96 | ctx.up_x, 97 | ctx.up_y, 98 | ctx.down_x, 99 | ctx.down_y, 100 | ctx.pad_x0, 101 | ctx.pad_x1, 102 | ctx.pad_y0, 103 | ctx.pad_y1, 104 | ) 105 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 106 | gradgrad_out = gradgrad_out.view( 107 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 108 | ) 109 | 110 | return gradgrad_out, None, None, None, None, None, None, None, None 111 | 112 | 113 | class UpFirDn2d(Function): 114 | @staticmethod 115 | def forward(ctx, input, kernel, up, down, pad): 116 | up_x, up_y = up 117 | down_x, down_y = down 118 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 119 | 120 | kernel_h, kernel_w = kernel.shape 121 | batch, channel, in_h, in_w = input.shape 122 | ctx.in_size = input.shape 123 | 124 | input = input.reshape(-1, in_h, in_w, 1) 125 | 126 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 127 | 128 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 129 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 130 | ctx.out_size = (out_h, out_w) 131 | 132 | ctx.up = (up_x, up_y) 133 | ctx.down = (down_x, down_y) 134 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 135 | 136 | g_pad_x0 = kernel_w - pad_x0 - 1 137 | g_pad_y0 = kernel_h - pad_y0 - 1 138 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 139 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 140 | 141 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 142 | 143 | out = upfirdn2d_op.upfirdn2d( 144 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 145 | ) 146 | # out = out.view(major, out_h, out_w, minor) 147 | out = out.view(-1, channel, out_h, out_w) 148 | 149 | return out 150 | 151 | @staticmethod 152 | def backward(ctx, grad_output): 153 | kernel, grad_kernel = ctx.saved_tensors 
154 | 155 | grad_input = UpFirDn2dBackward.apply( 156 | grad_output, 157 | kernel, 158 | grad_kernel, 159 | ctx.up, 160 | ctx.down, 161 | ctx.pad, 162 | ctx.g_pad, 163 | ctx.in_size, 164 | ctx.out_size, 165 | ) 166 | 167 | return grad_input, None, None, None, None 168 | 169 | 170 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 171 | if input.device.type == "cpu": 172 | out = upfirdn2d_native( 173 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 174 | ) 175 | 176 | else: 177 | out = UpFirDn2d.apply( 178 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 179 | ) 180 | 181 | return out 182 | 183 | 184 | def upfirdn2d_native( 185 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 186 | ): 187 | _, channel, in_h, in_w = input.shape 188 | input = input.reshape(-1, in_h, in_w, 1) 189 | 190 | _, in_h, in_w, minor = input.shape 191 | kernel_h, kernel_w = kernel.shape 192 | 193 | out = input.view(-1, in_h, 1, in_w, 1, minor) 194 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 195 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 196 | 197 | out = F.pad( 198 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 199 | ) 200 | out = out[ 201 | :, 202 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 203 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 204 | :, 205 | ] 206 | 207 | out = out.permute(0, 3, 1, 2) 208 | out = out.reshape( 209 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 210 | ) 211 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 212 | out = F.conv2d(out, w) 213 | out = out.reshape( 214 | -1, 215 | minor, 216 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 217 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 218 | ) 219 | out = out.permute(0, 2, 3, 1) 220 | out = out[:, ::down_y, ::down_x, :] 221 | 222 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 223 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 224 | 225 | return out.view(-1, channel, out_h, out_w) 226 | -------------------------------------------------------------------------------- /src/utils/ada_op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } 370 | -------------------------------------------------------------------------------- /src/utils/biggan_utils.py: -------------------------------------------------------------------------------- 1 
| """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | """ 26 | 27 | 28 | import random 29 | 30 | from utils.sample import sample_latents 31 | 32 | import torch 33 | 34 | 35 | 36 | class ema(object): 37 | def __init__(self, source, target, decay=0.9999, start_itr=0): 38 | self.source = source 39 | self.target = target 40 | self.decay = decay 41 | # Optional parameter indicating what iteration to start the decay at 42 | self.start_itr = start_itr 43 | # Initialize target's params to be source's 44 | self.source_dict = self.source.state_dict() 45 | self.target_dict = self.target.state_dict() 46 | print('Initializing EMA parameters to be source parameters...') 47 | with torch.no_grad(): 48 | for key in self.source_dict: 49 | self.target_dict[key].data.copy_(self.source_dict[key].data) 50 | 51 | def update(self, itr=None): 52 | # If an iteration counter is provided and itr is less than the start itr, 53 | # peg the ema weights to the underlying weights. 54 | if itr >= 0 and itr < self.start_itr: 55 | decay = 0.0 56 | else: 57 | decay = self.decay 58 | with torch.no_grad(): 59 | for key in self.source_dict: 60 | self.target_dict[key].data.copy_(self.target_dict[key].data * decay + self.source_dict[key].data * (1 - decay)) 61 | 62 | 63 | class ema_DP_SyncBN(object): 64 | def __init__(self, source, target, decay=0.9999, start_itr=0): 65 | self.source = source 66 | self.target = target 67 | self.decay = decay 68 | self.start_itr = start_itr 69 | # Initialize target's params to be source's 70 | print('Initializing EMA parameters to be source parameters...') 71 | with torch.no_grad(): 72 | for key in self.source.state_dict(): 73 | self.target.state_dict()[key].data.copy_(self.source.state_dict()[key].data) 74 | 75 | 76 | def update(self, itr=None): 77 | # If an iteration counter is provided and itr is less than the start itr, 78 | # peg the ema weights to the underlying weights. 
79 | if itr >= 0 and itr < self.start_itr: 80 | decay = 0.0 81 | else: 82 | decay = self.decay 83 | with torch.no_grad(): 84 | for key in self.source.state_dict(): 85 | data = self.target.state_dict()[key].data*decay + self.source.state_dict()[key].data*(1.-decay) 86 | self.target.state_dict()[key].data.copy_(data) 87 | 88 | 89 | def ortho(model, strength=1e-4, blacklist=[]): 90 | with torch.no_grad(): 91 | for param in model.parameters(): 92 | # Only apply this to parameters with at least 2 axes, and not in the blacklist 93 | if len(param.shape) < 2 or any([param is item for item in blacklist]): 94 | continue 95 | w = param.view(param.shape[0], -1) 96 | grad = (2 * torch.mm(torch.mm(w, w.t()) 97 | * (1. - torch.eye(w.shape[0], device=w.device)), w)) 98 | param.grad.data += strength * grad.view(param.shape) 99 | 100 | 101 | # Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) 102 | def interp(x0, x1, num_midpoints): 103 | lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) 104 | return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) 105 | -------------------------------------------------------------------------------- /src/utils/cr_diff_aug.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/cr_diff_aug.py 6 | 7 | 8 | import random 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | 15 | def CR_DiffAug(x, flip=True, translation=True): 16 | if flip: 17 | x = random_flip(x, 0.5) 18 | if translation: 19 | x = random_translation(x, 1/8) 20 | if flip or translation: 21 | x = x.contiguous() 22 | return x 23 | 24 | 25 | def random_flip(x, p): 26 | x_out = x.clone() 27 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 28 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) 29 | flip_mask = flip_prob < p 30 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) 31 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) 32 | return x_out 33 | 34 | 35 | def random_translation(x, ratio): 36 | max_t_x, max_t_y = int(x.shape[2]*ratio), int(x.shape[3]*ratio) 37 | t_x = torch.randint(-max_t_x, max_t_x + 1, size = [x.shape[0], 1, 1], device=x.device) 38 | t_y = torch.randint(-max_t_y, max_t_y + 1, size = [x.shape[0], 1, 1], device=x.device) 39 | 40 | grid_batch, grid_x, grid_y = torch.meshgrid( 41 | torch.arange(x.shape[0], dtype=torch.long, device=x.device), 42 | torch.arange(x.shape[2], dtype=torch.long, device=x.device), 43 | torch.arange(x.shape[3], dtype=torch.long, device=x.device), 44 | ) 45 | 46 | grid_x = (grid_x + t_x) + max_t_x 47 | grid_y = (grid_y + t_y) + max_t_y 48 | x_pad = F.pad(input=x, pad=[max_t_x, max_t_x, max_t_y, max_t_y], mode='reflect') 49 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 50 | return x 51 | -------------------------------------------------------------------------------- /src/utils/diff_aug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | 28 | import torch 29 | import torch.nn.functional as F 30 | 31 | 32 | 33 | ### Differentiable Augmentation for Data-Efficient GAN Training (https://arxiv.org/abs/2006.10738) 34 | ### Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 35 | ### https://github.com/mit-han-lab/data-efficient-gans 36 | 37 | 38 | def DiffAugment(x, policy='', channels_first=True): 39 | if policy: 40 | if not channels_first: 41 | x = x.permute(0, 3, 1, 2) 42 | for p in policy.split(','): 43 | for f in AUGMENT_FNS[p]: 44 | x = f(x) 45 | if not channels_first: 46 | x = x.permute(0, 2, 3, 1) 47 | x = x.contiguous() 48 | return x 49 | 50 | 51 | def rand_brightness(x): 52 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 53 | return x 54 | 55 | 56 | def rand_saturation(x): 57 | x_mean = x.mean(dim=1, keepdim=True) 58 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 59 | return x 60 | 61 | 62 | def rand_contrast(x): 63 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 64 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 65 | return x 66 | 67 | 68 | def rand_translation(x, ratio=0.125): 69 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 70 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 71 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 72 | grid_batch, grid_x, grid_y = torch.meshgrid( 73 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 74 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 75 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 76 | ) 77 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 78 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 79 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 80 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 81 | return x 82 | 83 | 84 | def rand_cutout(x, ratio=0.5): 85 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 86 | offset_x = torch.randint(0, 
x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 87 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 88 | grid_batch, grid_x, grid_y = torch.meshgrid( 89 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 90 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 91 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 92 | ) 93 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 94 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 95 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 96 | mask[grid_batch, grid_x, grid_y] = 0 97 | x = x * mask.unsqueeze(1) 98 | return x 99 | 100 | 101 | AUGMENT_FNS = { 102 | 'color': [rand_brightness, rand_saturation, rand_contrast], 103 | 'translation': [rand_translation], 104 | 'cutout': [rand_cutout], 105 | } 106 | -------------------------------------------------------------------------------- /src/utils/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/load_checkpoint.py 6 | 7 | 8 | import os 9 | 10 | import torch 11 | 12 | 13 | 14 | def load_checkpoint(model, optimizer, filename, metric=False, ema=False): 15 | start_step = 0 16 | if ema: 17 | checkpoint = torch.load(filename) 18 | model.load_state_dict(checkpoint['state_dict']) 19 | return model 20 | else: 21 | checkpoint = torch.load(filename) 22 | seed = checkpoint['seed'] 23 | run_name = checkpoint['run_name'] 24 | start_step = checkpoint['step'] 25 | model.load_state_dict(checkpoint['state_dict']) 26 | optimizer.load_state_dict(checkpoint['optimizer']) 27 | ada_p = checkpoint['ada_p'] 28 | for state in optimizer.state.values(): 29 | for k, v in state.items(): 30 | if isinstance(v, torch.Tensor): 31 | state[k] = v.cuda() 32 | 33 | if metric: 34 | best_step = checkpoint['best_step'] 35 | best_fid = checkpoint['best_fid'] 36 | best_fid_checkpoint_path = checkpoint['best_fid_checkpoint_path'] 37 | return model, optimizer, seed, run_name, start_step, ada_p, best_step, best_fid, best_fid_checkpoint_path 38 | return model, optimizer, seed, run_name, start_step, ada_p 39 | -------------------------------------------------------------------------------- /src/utils/log.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/log.py 6 | 7 | 8 | import json 9 | import os 10 | import logging 11 | from os.path import dirname, abspath, exists, join 12 | from datetime import datetime 13 | 14 | 15 | 16 | def make_run_name(format, framework, phase): 17 | return format.format( 18 | framework=framework, 19 | phase=phase, 20 | timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 21 | ) 22 | 23 | 24 | def make_logger(run_name, log_output): 25 | if log_output is not None: 26 | run_name = log_output.split('/')[-1].split('.')[0] 27 | logger = logging.getLogger(run_name) 28 | logger.propagate = False 29 | log_filepath = log_output if log_output is not None else join('logs', 
f'{run_name}.log') 30 | 31 | log_dir = dirname(abspath(log_filepath)) 32 | if not exists(log_dir): 33 | os.makedirs(log_dir) 34 | 35 | if not logger.handlers: # execute only if logger doesn't already exist 36 | file_handler = logging.FileHandler(log_filepath, 'a', 'utf-8') 37 | stream_handler = logging.StreamHandler(os.sys.stdout) 38 | 39 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 40 | 41 | file_handler.setFormatter(formatter) 42 | stream_handler.setFormatter(formatter) 43 | 44 | logger.addHandler(file_handler) 45 | logger.addHandler(stream_handler) 46 | logger.setLevel(logging.INFO) 47 | return logger 48 | 49 | 50 | def make_checkpoint_dir(checkpoint_dir, run_name): 51 | checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name) 52 | if not exists(abspath(checkpoint_dir)): 53 | os.makedirs(checkpoint_dir) 54 | return checkpoint_dir 55 | -------------------------------------------------------------------------------- /src/utils/make_hdf5.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | """ 23 | 24 | 25 | import os 26 | import sys 27 | import h5py as h5 28 | import numpy as np 29 | import PIL 30 | from argparse import ArgumentParser 31 | from tqdm import tqdm, trange 32 | 33 | from data_utils.load_dataset import LoadDataset 34 | 35 | import torch 36 | import torchvision.transforms as transforms 37 | from torch.utils.data import DataLoader 38 | 39 | 40 | 41 | def make_hdf5(model_config, train_config, mode): 42 | if 'hdf5' in model_config['dataset_name']: 43 | raise ValueError('Reading from an HDF5 file which you will probably be ' 44 | 'about to overwrite! 
Override this error only if you know ' 45 | 'what you''re doing!') 46 | 47 | file_name = '{dataset_name}_{size}_{mode}.hdf5'.format(dataset_name=model_config['dataset_name'], size=model_config['img_size'], mode=mode) 48 | file_path = os.path.join(model_config['data_path'], file_name) 49 | train = True if mode == "train" else False 50 | 51 | if os.path.isfile(file_path): 52 | print("{file_name} exist!\nThe file are located in the {file_path}".format(file_name=file_name, file_path=file_path)) 53 | else: 54 | dataset = LoadDataset(model_config['dataset_name'], model_config['data_path'], train=train, download=True, resize_size=model_config['img_size'], 55 | hdf5_path=None, random_flip=False) 56 | 57 | loader = DataLoader(dataset, 58 | batch_size=model_config['batch_size4prcsing'], 59 | shuffle=False, 60 | pin_memory=False, 61 | num_workers=train_config['num_workers'], 62 | drop_last=False) 63 | 64 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (model_config['dataset_name'], 65 | model_config['chunk_size'], 66 | model_config['compression'])) 67 | # Loop over loader 68 | for i,(x,y) in enumerate(tqdm(loader)): 69 | # Numpyify x, y 70 | x = (255 * ((x + 1) / 2.0)).byte().numpy() 71 | y = y.numpy() 72 | # If we're on the first batch, prepare the hdf5 73 | if i==0: 74 | with h5.File(file_path, 'w') as f: 75 | print('Producing dataset of len %d' % len(loader.dataset)) 76 | imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8', maxshape=(len(loader.dataset), 3, 77 | model_config['img_size'], model_config['img_size']), 78 | chunks=(model_config['chunk_size'], 3, model_config['img_size'], model_config['img_size']), compression=model_config['compression']) 79 | print('Image chunks chosen as ' + str(imgs_dset.chunks)) 80 | imgs_dset[...] = x 81 | 82 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(loader.dataset),), 83 | chunks=(model_config['chunk_size'],), compression=model_config['compression']) 84 | print('Label chunks chosen as ' + str(labels_dset.chunks)) 85 | labels_dset[...] = y 86 | # Else append to the hdf5 87 | else: 88 | with h5.File(file_path, 'a') as f: 89 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) 90 | f['imgs'][-x.shape[0]:] = x 91 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) 92 | f['labels'][-y.shape[0]:] = y 93 | return file_path 94 | -------------------------------------------------------------------------------- /src/utils/model_ops.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/model_ops.py 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.utils import spectral_norm 11 | from torch.nn import init 12 | 13 | 14 | 15 | def init_weights(modules, initialize): 16 | for module in modules(): 17 | if (isinstance(module, nn.Conv2d) 18 | or isinstance(module, nn.ConvTranspose2d) 19 | or isinstance(module, nn.Linear)): 20 | if initialize == 'ortho': 21 | init.orthogonal_(module.weight) 22 | if module.bias is not None: 23 | module.bias.data.fill_(0.) 24 | elif initialize == 'N02': 25 | init.normal_(module.weight, 0, 0.02) 26 | if module.bias is not None: 27 | module.bias.data.fill_(0.) 
28 | elif initialize in ['glorot', 'xavier']: 29 | init.xavier_uniform_(module.weight) 30 | if module.bias is not None: 31 | module.bias.data.fill_(0.) 32 | else: 33 | print('Init style not recognized...') 34 | elif isinstance(module, nn.Embedding): 35 | if initialize == 'ortho': 36 | init.orthogonal_(module.weight) 37 | elif initialize == 'N02': 38 | init.normal_(module.weight, 0, 0.02) 39 | elif initialize in ['glorot', 'xavier']: 40 | init.xavier_uniform_(module.weight) 41 | else: 42 | print('Init style not recognized...') 43 | else: 44 | pass 45 | 46 | 47 | def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 48 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 49 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 50 | 51 | def deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True): 52 | return nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 54 | 55 | def linear(in_features, out_features, bias=True): 56 | return nn.Linear(in_features=in_features, out_features=out_features, bias=bias) 57 | 58 | def embedding(num_embeddings, embedding_dim): 59 | return nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 60 | 61 | def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 62 | return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 63 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6) 64 | 65 | def sndeconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True): 66 | return spectral_norm(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 67 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6) 68 | 69 | def snlinear(in_features, out_features, bias=True): 70 | return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6) 71 | 72 | def sn_embedding(num_embeddings, embedding_dim): 73 | return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6) 74 | 75 | def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True): 76 | return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True) 77 | 78 | 79 | class ConditionalBatchNorm2d(nn.Module): 80 | # https://github.com/voletiv/self-attention-GAN-pytorch 81 | def __init__(self, num_features, num_classes, spectral_norm): 82 | super().__init__() 83 | self.num_features = num_features 84 | self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False) 85 | 86 | if spectral_norm: 87 | self.embed0 = sn_embedding(num_classes, num_features) 88 | self.embed1 = sn_embedding(num_classes, num_features) 89 | else: 90 | self.embed0 = embedding(num_classes, num_features) 91 | self.embed1 = embedding(num_classes, num_features) 92 | 93 | def forward(self, x, y): 94 | gain = (1 + self.embed0(y)).view(-1, self.num_features, 1, 1) 95 | bias = self.embed1(y).view(-1, self.num_features, 1, 1) 96 | out = self.bn(x) 97 | return out * gain + bias 98 | 99 | 100 | class ConditionalBatchNorm2d_for_skip_and_shared(nn.Module): 101 | # 
https://github.com/voletiv/self-attention-GAN-pytorch 102 | def __init__(self, num_features, z_dims_after_concat, spectral_norm): 103 | super().__init__() 104 | self.num_features = num_features 105 | self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False) 106 | 107 | if spectral_norm: 108 | self.gain = snlinear(z_dims_after_concat, num_features, bias=False) 109 | self.bias = snlinear(z_dims_after_concat, num_features, bias=False) 110 | else: 111 | self.gain = linear(z_dims_after_concat, num_features, bias=False) 112 | self.bias = linear(z_dims_after_concat, num_features, bias=False) 113 | 114 | def forward(self, x, y): 115 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) 116 | bias = self.bias(y).view(y.size(0), -1, 1, 1) 117 | out = self.bn(x) 118 | return out * gain + bias 119 | 120 | 121 | class Self_Attn(nn.Module): 122 | # https://github.com/voletiv/self-attention-GAN-pytorch 123 | def __init__(self, in_channels, spectral_norm): 124 | super(Self_Attn, self).__init__() 125 | self.in_channels = in_channels 126 | 127 | if spectral_norm: 128 | self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) 129 | self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) 130 | self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False) 131 | self.conv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False) 132 | else: 133 | self.conv1x1_theta = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) 134 | self.conv1x1_phi = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) 135 | self.conv1x1_g = conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False) 136 | self.conv1x1_attn = conv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False) 137 | 138 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) 139 | self.softmax = nn.Softmax(dim=-1) 140 | self.sigma = nn.Parameter(torch.zeros(1)) 141 | 142 | def forward(self, x): 143 | """ 144 | inputs : 145 | x : input feature maps(B X C X H X W) 146 | returns : 147 | out : self attention value + input feature 148 | attention: B X N X N (N is Width*Height) 149 | """ 150 | _, ch, h, w = x.size() 151 | # Theta path 152 | theta = self.conv1x1_theta(x) 153 | theta = theta.view(-1, ch//8, h*w) 154 | # Phi path 155 | phi = self.conv1x1_phi(x) 156 | phi = self.maxpool(phi) 157 | phi = phi.view(-1, ch//8, h*w//4) 158 | # Attn map 159 | attn = torch.bmm(theta.permute(0, 2, 1), phi) 160 | attn = self.softmax(attn) 161 | # g path 162 | g = self.conv1x1_g(x) 163 | g = self.maxpool(g) 164 | g = g.view(-1, ch//2, h*w//4) 165 | # Attn_g 166 | attn_g = torch.bmm(g, attn.permute(0, 2, 1)) 167 | attn_g = attn_g.view(-1, ch//2, h, w) 168 | attn_g = self.conv1x1_attn(attn_g) 169 | return x + self.sigma*attn_g 170 | 171 | -------------------------------------------------------------------------------- /src/utils/sample.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit 
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/sample.py 6 | 7 | 8 | import numpy as np 9 | import random 10 | from numpy import linalg 11 | from math import sin,cos,sqrt 12 | 13 | from utils.losses import latent_optimise 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.nn import DataParallel 18 | 19 | 20 | 21 | def sample_latents(dist, batch_size, dim, truncated_factor=1, num_classes=None, perturb=None, device=torch.device("cpu"), sampler="default"): 22 | if num_classes: 23 | if sampler == "default": 24 | y_fake = torch.randint(low=0, high=num_classes, size=(batch_size,), dtype=torch.long, device=device) 25 | elif sampler == "class_order_some": 26 | assert batch_size % 8 == 0, "The size of the batches should be a multiple of 8." 27 | num_classes_plot = batch_size//8 28 | indices = np.random.permutation(num_classes)[:num_classes_plot] 29 | elif sampler == "class_order_all": 30 | batch_size = num_classes*8 31 | indices = [c for c in range(num_classes)] 32 | elif isinstance(sampler, int): 33 | y_fake = torch.tensor([sampler]*batch_size, dtype=torch.long).to(device) 34 | else: 35 | raise NotImplementedError 36 | 37 | if sampler in ["class_order_some", "class_order_all"]: 38 | y_fake = [] 39 | for idx in indices: 40 | y_fake += [idx]*8 41 | y_fake = torch.tensor(y_fake, dtype=torch.long).to(device) 42 | else: 43 | y_fake = None 44 | 45 | if isinstance(perturb, float) and perturb > 0.0: 46 | if dist == "gaussian": 47 | latents = torch.randn(batch_size, dim, device=device)/truncated_factor 48 | eps = perturb*torch.randn(batch_size, dim, device=device) 49 | latents_eps = latents + eps 50 | elif dist == "uniform": 51 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 52 | eps = perturb*torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 53 | latents_eps = latents + eps 54 | elif dist == "hyper_sphere": 55 | latents, latents_eps = random_ball(batch_size, dim, perturb=perturb) 56 | latents, latents_eps = torch.FloatTensor(latents).to(device), torch.FloatTensor(latents_eps).to(device) 57 | return latents, y_fake, latents_eps 58 | else: 59 | if dist == "gaussian": 60 | latents = torch.randn(batch_size, dim, device=device)/truncated_factor 61 | elif dist == "uniform": 62 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 63 | elif dist == "hyper_sphere": 64 | latents = random_ball(batch_size, dim, perturb=perturb).to(device) 65 | return latents, y_fake 66 | 67 | 68 | def random_ball(batch_size, z_dim, perturb=False): 69 | if perturb: 70 | normal = np.random.normal(size=(z_dim, batch_size)) 71 | random_directions = normal/linalg.norm(normal, axis=0) 72 | random_radii = random.random(batch_size) ** (1/z_dim) 73 | zs = 1.0 * (random_directions * random_radii).T 74 | 75 | normal_perturb = normal + 0.05*np.random.normal(size=(z_dim, batch_size)) 76 | perturb_random_directions = normal_perturb/linalg.norm(normal_perturb, axis=0) 77 | perturb_random_radii = random.random(batch_size) ** (1/z_dim) 78 | zs_perturb = 1.0 * (perturb_random_directions * perturb_random_radii).T 79 | return zs, zs_perturb 80 | else: 81 | normal = np.random.normal(size=(z_dim, batch_size)) 82 | random_directions = normal/linalg.norm(normal, axis=0) 83 | random_radii = random.random(batch_size) ** (1/z_dim) 84 | zs = 1.0 * (random_directions * random_radii).T 85 | return zs 86 | 87 | 88 | # Convenience function to sample an index, not actually a 1-hot 89 | def sample_1hot(batch_size, num_classes, 
device='cuda'): 90 | return torch.randint(low=0, high=num_classes, size=(batch_size,), 91 | device=device, dtype=torch.int64, requires_grad=False) 92 | 93 | 94 | def make_mask(labels, n_cls, device): 95 | labels = labels.detach().cpu().numpy() 96 | n_samples = labels.shape[0] 97 | mask_multi = np.zeros([n_cls, n_samples]) 98 | for c in range(n_cls): 99 | c_indices = np.where(labels==c) 100 | mask_multi[c, c_indices] = 1 101 | 102 | mask_multi = torch.tensor(mask_multi).type(torch.long) 103 | return mask_multi.to(device) 104 | 105 | 106 | def target_class_sampler(dataset, target_class): 107 | try: 108 | targets = dataset.data.targets 109 | except AttributeError: 110 | targets = dataset.labels 111 | weights = [True if target == target_class else False for target in targets] 112 | num_samples = sum(weights) 113 | weights = torch.DoubleTensor(weights) 114 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=False) 115 | return num_samples, sampler 116 | --------------------------------------------------------------------------------
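
The following is a minimal usage sketch, not a file from this repository, showing how the utilities dumped above typically fit together in one training iteration: sample_latents draws latents and fake labels, DiffAugment and CR_DiffAug augment images before they reach the discriminator, make_mask builds a class-membership mask for class-conditional losses, and the ema helper tracks an exponential moving average of the generator. The toy generator G and the import path utils.biggan_utils for the ema class are assumptions for illustration, not names confirmed by the files shown here.

# Illustrative sketch only (not repository code). Assumption: the ema class above
# lives in utils/biggan_utils.py; the generator G below is a toy stand-in.
import copy

import torch
import torch.nn as nn

from utils.sample import sample_latents, make_mask
from utils.diff_aug import DiffAugment
from utils.cr_diff_aug import CR_DiffAug
from utils.biggan_utils import ema  # assumed module for the ema class shown above

num_classes, z_dim, batch_size = 10, 80, 16
G = nn.Sequential(nn.Linear(z_dim, 3 * 32 * 32), nn.Tanh())  # toy generator
G_ema = copy.deepcopy(G)
ema_tracker = ema(source=G, target=G_ema, decay=0.9999, start_itr=1000)

for itr in range(2):
    # 1) latents and fake labels; "gaussian" draws z ~ N(0, I) / truncated_factor
    zs, fake_labels = sample_latents("gaussian", batch_size, z_dim,
                                     truncated_factor=1, num_classes=num_classes,
                                     device=torch.device("cpu"))
    fakes = G(zs).view(batch_size, 3, 32, 32)

    # 2) differentiable augmentation before the discriminator sees the images;
    #    CR_DiffAug is the flip/translation variant used for consistency regularization
    fakes_aug = DiffAugment(fakes, policy="color,translation,cutout")
    reals = torch.rand(batch_size, 3, 32, 32) * 2 - 1  # placeholder batch in [-1, 1]
    reals_aug = CR_DiffAug(reals)

    # 3) class-membership mask (n_cls x batch) for class-conditional losses
    mask = make_mask(fake_labels, num_classes, torch.device("cpu"))

    # ... discriminator / generator losses and optimizer steps would go here ...

    # 4) after the generator update, refresh the EMA copy of its weights
    ema_tracker.update(itr)

In the repository itself this wiring presumably lives in the training worker; the snippet only demonstrates the call signatures of the utilities reproduced above.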
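
A second sketch, equally hypothetical and not taken from the repository, exercises two building blocks from src/utils/model_ops.py: ConditionalBatchNorm2d, which applies class-conditional gain and bias on top of an affine-free BatchNorm, and Self_Attn, the SAGAN-style self-attention block. The feature sizes used here (64 channels, 16x16 maps, 10 classes) are arbitrary.

# Hypothetical smoke test for ConditionalBatchNorm2d and Self_Attn (not repo code).
import torch

from utils.model_ops import ConditionalBatchNorm2d, Self_Attn

num_classes, channels = 10, 64
cbn = ConditionalBatchNorm2d(num_features=channels, num_classes=num_classes, spectral_norm=True)
attn = Self_Attn(in_channels=channels, spectral_norm=True)

h = torch.randn(4, channels, 16, 16)        # a batch of feature maps
y = torch.randint(0, num_classes, (4,))     # class labels
h = cbn(h, y)    # normalize, then scale/shift with per-class embeddings
h = attn(h)      # add the attention-weighted residual (sigma starts at 0)
print(h.shape)   # torch.Size([4, 64, 16, 16])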