├── .gitignore
├── LICENSE
├── LICENSE-NVIDIA
├── README.md
├── environment.yml
├── images
│   ├── ECGAN.png
│   └── imagenet.png
├── logs
│   ├── CIFAR10
│   │   └── ecgan_v2_none_1_0p01-train-2021_05_26_16_35_45.log
│   ├── ILSVRC2012
│   │   └── imagenet_ecgan_v2_contra_1_0p05-train-2021_10_03_00_11_58.log
│   └── TINY_ILSVRC2012
│       └── ecgan_v2_none_1_0p05-train-2021_05_26_16_47_55.log
└── src
    ├── configs
    │   ├── CIFAR10
    │   │   └── ecgan_v2_none_1_0p01.json
    │   ├── ILSVRC2012
    │   │   └── imagenet_ecgan_v2_contra_1_0p05.json
    │   └── TINY_ILSVRC2012
    │       └── ecgan_v2_none_1_0p05.json
    ├── data_utils
    │   └── load_dataset.py
    ├── inception_tf13.py
    ├── loader.py
    ├── main.py
    ├── metrics
    │   ├── Accuracy.py
    │   ├── CAS.py
    │   ├── DCA.py
    │   ├── FID.py
    │   ├── F_beta.py
    │   ├── IS.py
    │   ├── IntraFID.py
    │   ├── __init__.py
    │   ├── inception_network.py
    │   └── prepare_inception_moments.py
    ├── models
    │   ├── big_resnet.py
    │   ├── big_resnet_deep.py
    │   ├── dcgan.py
    │   └── resnet.py
    ├── sync_batchnorm
    │   ├── batchnorm.py
    │   ├── batchnorm_reimpl.py
    │   ├── comm.py
    │   ├── replicate.py
    │   └── unittest.py
    ├── utils
    │   ├── ada.py
    │   ├── ada_op
    │   │   ├── __init__.py
    │   │   ├── fused_act.py
    │   │   ├── fused_bias_act.cpp
    │   │   ├── fused_bias_act_kernel.cu
    │   │   ├── upfirdn2d.cpp
    │   │   ├── upfirdn2d.py
    │   │   └── upfirdn2d_kernel.cu
    │   ├── biggan_utils.py
    │   ├── cr_diff_aug.py
    │   ├── diff_aug.py
    │   ├── load_checkpoint.py
    │   ├── log.py
    │   ├── losses.py
    │   ├── make_hdf5.py
    │   ├── misc.py
    │   ├── model_ops.py
    │   └── sample.py
    └── worker.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | *.pyc
3 | *.DS_Store
4 | media/
5 | res/
6 |
7 | data/*
8 | checkpoints/*
9 | figures/*
10 | *.xlsx
11 | src/data/*
12 | src/checkpoints/*
13 | src/figures/*
14 | src/logs/*
15 | *test*
16 | src/unpublished/*
17 |
18 | # Swap
19 | [._]*.s[a-v][a-z]
20 | !*.svg # comment out if you don't need vector files
21 | [._]*.sw[a-p]
22 | [._]s[a-rt-v][a-z]
23 | [._]ss[a-gi-z]
24 | [._]sw[a-p]
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | PyTorch StudioGAN:
4 | Copyright (c) 2020 MinGuk Kang
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/LICENSE-NVIDIA:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 |
3 |
4 | Nvidia Source Code License-NC
5 |
6 | =======================================================================
7 |
8 | 1. Definitions
9 |
10 | "Licensor" means any person or entity that distributes its Work.
11 |
12 | "Software" means the original work of authorship made available under
13 | this License.
14 |
15 | "Work" means the Software and any additions to or derivative works of
16 | the Software that are made available under this License.
17 |
18 | "Nvidia Processors" means any central processing unit (CPU), graphics
19 | processing unit (GPU), field-programmable gate array (FPGA),
20 | application-specific integrated circuit (ASIC) or any combination
21 | thereof designed, made, sold, or provided by Nvidia or its affiliates.
22 |
23 | The terms "reproduce," "reproduction," "derivative works," and
24 | "distribution" have the meaning as provided under U.S. copyright law;
25 | provided, however, that for the purposes of this License, derivative
26 | works shall not include works that remain separable from, or merely
27 | link (or bind by name) to the interfaces of, the Work.
28 |
29 | Works, including the Software, are "made available" under this License
30 | by including in or with the Work either (a) a copyright notice
31 | referencing the applicability of this License to the Work, or (b) a
32 | copy of this License.
33 |
34 | 2. License Grants
35 |
36 | 2.1 Copyright Grant. Subject to the terms and conditions of this
37 | License, each Licensor grants to you a perpetual, worldwide,
38 | non-exclusive, royalty-free, copyright license to reproduce,
39 | prepare derivative works of, publicly display, publicly perform,
40 | sublicense and distribute its Work and any resulting derivative
41 | works in any form.
42 |
43 | 3. Limitations
44 |
45 | 3.1 Redistribution. You may reproduce or distribute the Work only
46 | if (a) you do so under this License, (b) you include a complete
47 | copy of this License with your distribution, and (c) you retain
48 | without modification any copyright, patent, trademark, or
49 | attribution notices that are present in the Work.
50 |
51 | 3.2 Derivative Works. You may specify that additional or different
52 | terms apply to the use, reproduction, and distribution of your
53 | derivative works of the Work ("Your Terms") only if (a) Your Terms
54 | provide that the use limitation in Section 3.3 applies to your
55 | derivative works, and (b) you identify the specific derivative
56 | works that are subject to Your Terms. Notwithstanding Your Terms,
57 | this License (including the redistribution requirements in Section
58 | 3.1) will continue to apply to the Work itself.
59 |
60 | 3.3 Use Limitation. The Work and any derivative works thereof only
61 | may be used or intended for use non-commercially. The Work or
62 | derivative works thereof may be used or intended for use by Nvidia
63 | or its affiliates commercially or non-commercially. As used herein,
64 | "non-commercially" means for research or evaluation purposes only.
65 |
66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67 | against any Licensor (including any claim, cross-claim or
68 | counterclaim in a lawsuit) to enforce any patents that you allege
69 | are infringed by any Work, then your rights under this License from
70 | such Licensor (including the grants in Sections 2.1 and 2.2) will
71 | terminate immediately.
72 |
73 | 3.5 Trademarks. This License does not grant any rights to use any
74 | Licensor's or its affiliates' names, logos, or trademarks, except
75 | as necessary to reproduce the notices described in this License.
76 |
77 | 3.6 Termination. If you violate any term of this License, then your
78 | rights under this License (including the grants in Sections 2.1 and
79 | 2.2) will terminate immediately.
80 |
81 | 4. Disclaimer of Warranty.
82 |
83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87 | THIS LICENSE.
88 |
89 | 5. Limitation of Liability.
90 |
91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99 | THE POSSIBILITY OF SUCH DAMAGES.
100 |
101 | =======================================================================
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Energy-based Conditional Generative Adversarial Network (ECGAN)
2 |
3 | This is the code for the NeurIPS 2021 paper "[A Unified View of cGANs with and without Classifiers](https://arxiv.org/abs/2111.01035)". The repository is modified from [StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN). If you find our work useful, please consider citing the following paper:
4 | ```bib
5 | @inproceedings{chen2021ECGAN,
6 | title = {A Unified View of cGANs with and without Classifiers},
7 | author = {Si-An Chen and Chun-Liang Li and Hsuan-Tien Lin},
8 | booktitle = {Advances in Neural Information Processing Systems},
9 | year = {2021}
10 | }
11 | ```
12 | Please feel free to contact [Si-An Chen](https://scholar.google.com/citations?hl=en&user=XtkmEncAAAAJ) if you have any questions about the code/paper.
13 |
14 | ## Introduction
15 | We propose a new Conditional Generative Adversarial Network (cGAN) framework called Energy-based Conditional Generative Adversarial Network (ECGAN), which provides a unified view of cGANs and achieves state-of-the-art results. We use the decomposition of the joint probability distribution to connect the goals of cGANs and classification in a unified framework. The framework, along with a classic energy model to parameterize distributions, justifies the use of classifiers for cGANs in a principled manner. It explains several popular cGAN variants, such as ACGAN, ProjGAN, and ContraGAN, as special cases with different levels of approximation. An illustration of the framework is shown below.
16 |
17 | ![An illustration of the ECGAN framework](./images/ECGAN.png)
18 | 
19 | 
20 | 
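In symbols, the decomposition referred to above can be sketched as follows (a schematic in our own notation, not copied from the paper; see the paper for the precise objectives):

```
% Schematic only; notation is ours.
\log p(x, y) = \log p(x) + \log p(y \mid x)

% With class-conditional logits (energies) f_y(x), one convenient parameterization is
p(y \mid x) = \frac{\exp f_y(x)}{\sum_{y'} \exp f_{y'}(x)}, \qquad
p(x) \propto \sum_{y'} \exp f_{y'}(x),

% so a single set of logits can drive both the unconditional (adversarial) term
% and the classification term of a cGAN objective.
```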
21 | ## Requirements
22 |
23 | - Anaconda
24 | - Python >= 3.6
25 | - 6.0.0 <= Pillow <= 7.0.0
26 | - scipy == 1.1.0 (Recommended for fast loading of [Inception Network](https://github.com/openai/improved-gan/blob/master/inception_score/model.py))
27 | - sklearn
28 | - seaborn
29 | - h5py
30 | - tqdm
31 | - torch >= 1.6.0 (Recommended for mixed precision training and knn analysis)
32 | - torchvision >= 0.7.0
33 | - tensorboard
34 | - 5.4.0 <= gcc <= 7.4.0 (Recommended for proper use of [adaptive discriminator augmentation module](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/tree/master/src/utils/ada_op))
35 |
36 |
37 | You can install the recommended environment as follows:
38 |
39 | ```
40 | conda env create -f environment.yml -n studiogan
41 | ```
42 |
43 | With docker, you can use:
44 | ```
45 | docker pull mgkang/studiogan:0.1
46 | ```
47 |
48 | ## Quick Start
49 |
50 | * Train (``-t``) and evaluate (``-e``) the model defined in ``CONFIG_PATH`` using GPU ``0``
51 | ```
52 | CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH
53 | ```
54 |
55 | * Train (``-t``) and evaluate (``-e``) the model defined in ``CONFIG_PATH`` using GPUs ``(0, 1, 2, 3)`` and ``DataParallel``
56 | ```
57 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH
58 | ```
59 |
60 | Try ``python3 src/main.py`` to see available options.
61 |
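A few other switches defined in ``src/main.py`` can be combined with the commands above. For example, an illustrative invocation (flags taken from the argument parser in ``src/main.py``; all of them are optional) that turns on DistributedDataParallel, synchronized batchnorm, and mixed-precision training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -DDP -sync_bn -mpc -c CONFIG_PATH
```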
62 | ## Dataset
63 |
64 | * CIFAR10: StudioGAN will automatically download the dataset once you execute ``main.py``.
65 |
66 | * Tiny Imagenet, Imagenet, or a custom dataset:
67 | 1. Download [Tiny Imagenet](https://tiny-imagenet.herokuapp.com) and [Imagenet](http://www.image-net.org), or prepare your own custom dataset.
68 | 2. Make the folder structure of the dataset as follows (see the shell sketch after the tree):
69 |
70 | ```
71 | ┌── docs
72 | ├── src
73 | └── data
74 |     └── ILSVRC2012 or TINY_ILSVRC2012 or CUSTOM
75 |         ├── train
76 |         │   ├── cls0
77 |         │   │   ├── train0.png
78 |         │   │   ├── train1.png
79 |         │   │   └── ...
80 |         │   ├── cls1
81 |         │   └── ...
82 |         └── valid
83 |             ├── cls0
84 |             │   ├── valid0.png
85 |             │   ├── valid1.png
86 |             │   └── ...
87 |             ├── cls1
88 |             └── ...
89 | ```
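The same layout can be created by hand; an illustrative sketch for a ``CUSTOM`` dataset (all paths are placeholders):
```
mkdir -p data/CUSTOM/train/cls0 data/CUSTOM/valid/cls0
cp /path/to/cls0_train_images/*.png data/CUSTOM/train/cls0/
cp /path/to/cls0_valid_images/*.png data/CUSTOM/valid/cls0/
# repeat for cls1, cls2, ... (one sub-folder per class, in both train/ and valid/)
```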
90 |
91 | ## Examples and Results
92 | The ``src/configs`` directory contains config files used in our experiments.
93 |
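As a reading aid, the file names encode the main loss settings (our interpretation, inferred from the config files in ``src/configs``): ``none`` vs. ``contra`` toggles the contrastive term, and the trailing ``0p01``/``0p05`` is the classification weight on the discriminator side. For example, ``ecgan_v2_none_1_0p01.json`` sets:
```
"cls_disc_lambda": 0.01,
"contrastive_lambda": 0.0
```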
94 | ### CIFAR10 (3x32x32)
95 | To train and evaluate ECGAN-UC on CIFAR10:
96 | ```
97 | python3 src/main.py -t -e -c src/configs/CIFAR10/ecgan_v2_none_1_0p01.json
98 | ```
99 |
100 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights |
101 | |---|---|---|---|---|---|---|---|---|
102 | | BigGAN-Mod | StudioGAN | 9.746 | 8.034 | 0.995 | 0.994 | - | - | - |
103 | | ContraGAN | StudioGAN | 9.729 | 8.065 | 0.993 | 0.992 | - | - | - |
104 | | Ours | - | **10.078** | **7.936** | 0.990 | 0.988 | [Cfg](./src/configs/CIFAR10/ecgan_v2_none_1_0p01.json) | [Log](./logs/CIFAR10/ecgan_v2_none_1_0p01-train-2021_05_26_16_35_45.log) | [Link](https://drive.google.com/drive/folders/1Kig2Loo2Ds5N3Pqc85R6c46Hbx5n9heM?usp=sharing) |
105 |
106 | ### Tiny ImageNet (3x64x64)
107 | To train and evaluate ECGAN-UC on Tiny ImageNet:
108 | ```
109 | python3 src/main.py -t -e -c src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json --eval_type valid
110 | ```
111 |
112 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights |
113 | |---|---|---|---|---|---|---|---|---|
114 | | BigGAN-Mod | StudioGAN | 11.998 | 31.92 | 0.956 | 0.879 | - | - | - |
115 | | ContraGAN | StudioGAN | 13.494 | 27.027 | 0.975 | 0.902 | - | - | - |
116 | | Ours | - | **18.445** | **18.319** | **0.977** | **0.973** | [Cfg](./src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json) | [Log](./logs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05-train-2021_05_26_16_47_55.log) | [Link](https://drive.google.com/drive/folders/1oVAIljTEIA3b0BHRVjcnukMf3POQQ3rw?usp=sharing) |
117 |
118 | ### ImageNet (3x128x128)
119 | To train and evaluate ECGAN-UCE on ImageNet (~12 days on 8 NVIDIA V100 GPUs):
120 | ```
121 | python3 src/main.py -t -e -l -sync_bn -c src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json --eval_type valid
122 | ```
123 |
124 | | Method | Reference | IS(⭡) | FID(⭣) | F_1/8(⭡) | F_8(⭡) | Cfg | Log | Weights |
125 | |---|---|---|---|---|---|---|---|---|
126 | | BigGAN | StudioGAN | 28.633 | 24.684 | 0.941 | 0.921 | - | - | - |
127 | | ContraGAN | StudioGAN | 25.249 | 25.161 | 0.947 | 0.855 | - | - | - |
128 | | Ours | - | **80.685** | **8.491** | **0.984** | **0.985** | [Cfg](./src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json) | [Log](./logs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05-train-2021_10_03_00_11_58.log) | [Link](https://drive.google.com/drive/folders/1EkcotNsnA-KBvOCFkvpJpUVoSDRxk-EV?usp=sharing) |
129 |
130 |
131 | ## Generated Images
132 | Here are some selected images generated by ECGAN.
133 | 
134 | ![Selected images generated by ECGAN on ImageNet](./images/imagenet.png)
135 | 
136 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: base
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
7 | - _libgcc_mutex=0.1=main
8 | - absl-py=0.9.0=py37_0
9 | - alabaster=0.7.12=py37_0
10 | - anaconda=2020.02=py37_0
11 | - anaconda-client=1.7.2=py37_0
12 | - anaconda-navigator=1.9.12=py37_0
13 | - anaconda-project=0.8.4=py_0
14 | - argh=0.26.2=py37_0
15 | - asn1crypto=1.3.0=py37_0
16 | - astroid=2.3.3=py37_0
17 | - astropy=4.0=py37h7b6447c_0
18 | - atomicwrites=1.3.0=py37_1
19 | - attrs=19.3.0=py_0
20 | - autopep8=1.4.4=py_0
21 | - babel=2.8.0=py_0
22 | - backcall=0.1.0=py37_0
23 | - backports=1.0=py_2
24 | - backports.functools_lru_cache=1.6.1=py_0
25 | - backports.shutil_get_terminal_size=1.0.0=py37_2
26 | - backports.tempfile=1.0=py_1
27 | - backports.weakref=1.0.post1=py_1
28 | - beautifulsoup4=4.8.2=py37_0
29 | - bitarray=1.2.1=py37h7b6447c_0
30 | - bkcharts=0.2=py37_0
31 | - blas=1.0=mkl
32 | - bleach=3.1.0=py37_0
33 | - blosc=1.16.3=hd408876_0
34 | - bokeh=1.4.0=py37_0
35 | - boto=2.49.0=py37_0
36 | - bottleneck=1.3.2=py37heb32a55_0
37 | - bzip2=1.0.8=h7b6447c_0
38 | - c-ares=1.15.0=h7b6447c_1001
39 | - ca-certificates=2020.1.1=0
40 | - cairo=1.14.12=h8948797_3
41 | - certifi=2019.11.28=py37_0
42 | - cffi=1.14.0=py37h2e261b9_0
43 | - chardet=3.0.4=py37_1003
44 | - click=7.0=py37_0
45 | - cloudpickle=1.3.0=py_0
46 | - clyent=1.2.2=py37_1
47 | - colorama=0.4.3=py_0
48 | - conda=4.8.4=py37_0
49 | - conda-build=3.18.11=py37_0
50 | - conda-env=2.6.0=1
51 | - conda-package-handling=1.6.0=py37h7b6447c_0
52 | - conda-verify=3.4.2=py_1
53 | - contextlib2=0.6.0.post1=py_0
54 | - cryptography=2.8=py37h1ba5d50_0
55 | - cudatoolkit=10.1.243=h6bb024c_0
56 | - curl=7.68.0=hbc83047_0
57 | - cycler=0.10.0=py37_0
58 | - cython=0.29.15=py37he6710b0_0
59 | - cytoolz=0.10.1=py37h7b6447c_0
60 | - dask=2.11.0=py_0
61 | - dask-core=2.11.0=py_0
62 | - dbus=1.13.12=h746ee38_0
63 | - decorator=4.4.1=py_0
64 | - defusedxml=0.6.0=py_0
65 | - diff-match-patch=20181111=py_0
66 | - distributed=2.11.0=py37_0
67 | - docutils=0.16=py37_0
68 | - entrypoints=0.3=py37_0
69 | - et_xmlfile=1.0.1=py37_0
70 | - expat=2.2.6=he6710b0_0
71 | - fastcache=1.1.0=py37h7b6447c_0
72 | - filelock=3.0.12=py_0
73 | - flake8=3.7.9=py37_0
74 | - flask=1.1.1=py_0
75 | - fontconfig=2.13.0=h9420a91_0
76 | - freetype=2.9.1=h8a8886c_1
77 | - fribidi=1.0.5=h7b6447c_0
78 | - fsspec=0.6.2=py_0
79 | - future=0.18.2=py37_0
80 | - get_terminal_size=1.0.0=haa9412d_0
81 | - gevent=1.4.0=py37h7b6447c_0
82 | - glib=2.63.1=h5a9c865_0
83 | - glob2=0.7=py_0
84 | - gmp=6.1.2=h6c8ec71_1
85 | - gmpy2=2.0.8=py37h10f8cd9_2
86 | - graphite2=1.3.13=h23475e2_0
87 | - greenlet=0.4.15=py37h7b6447c_0
88 | - grpcio=1.27.2=py37hf8bcb03_0
89 | - gst-plugins-base=1.14.0=hbbd80ab_1
90 | - gstreamer=1.14.0=hb453b48_1
91 | - h5py=2.10.0=py37h7918eee_0
92 | - harfbuzz=1.8.8=hffaf4a1_0
93 | - hdf5=1.10.4=hb1b8bf9_0
94 | - heapdict=1.0.1=py_0
95 | - html5lib=1.0.1=py37_0
96 | - hypothesis=5.5.4=py_0
97 | - icu=58.2=h9c2bf20_1
98 | - idna=2.8=py37_0
99 | - imageio=2.6.1=py37_0
100 | - imagesize=1.2.0=py_0
101 | - importlib-metadata=1.7.0=py37_0
102 | - importlib_metadata=1.5.0=py37_0
103 | - intel-openmp=2020.0=166
104 | - intervaltree=3.0.2=py_0
105 | - ipykernel=5.1.4=py37h39e3cac_0
106 | - ipython=7.12.0=py37h5ca1d4c_0
107 | - ipython_genutils=0.2.0=py37_0
108 | - ipywidgets=7.5.1=py_0
109 | - isort=4.3.21=py37_0
110 | - itsdangerous=1.1.0=py37_0
111 | - jbig=2.1=hdba287a_0
112 | - jdcal=1.4.1=py_0
113 | - jedi=0.14.1=py37_0
114 | - jeepney=0.4.2=py_0
115 | - jinja2=2.11.1=py_0
116 | - joblib=0.14.1=py_0
117 | - jpeg=9b=h024ee3a_2
118 | - json5=0.9.1=py_0
119 | - jsonschema=3.2.0=py37_0
120 | - jupyter=1.0.0=py37_7
121 | - jupyter_client=5.3.4=py37_0
122 | - jupyter_console=6.1.0=py_0
123 | - jupyter_core=4.6.1=py37_0
124 | - jupyterlab=1.2.6=pyhf63ae98_0
125 | - jupyterlab_server=1.0.6=py_0
126 | - keyring=21.1.0=py37_0
127 | - kiwisolver=1.1.0=py37he6710b0_0
128 | - krb5=1.17.1=h173b8e3_0
129 | - lazy-object-proxy=1.4.3=py37h7b6447c_0
130 | - ld_impl_linux-64=2.33.1=h53a641e_7
131 | - libarchive=3.3.3=h5d8350f_5
132 | - libcurl=7.68.0=h20c2e04_0
133 | - libedit=3.1.20181209=hc058e9b_0
134 | - libffi=3.2.1=hd88cf55_4
135 | - libgcc-ng=9.1.0=hdf63c60_0
136 | - libgfortran-ng=7.3.0=hdf63c60_0
137 | - liblief=0.9.0=h7725739_2
138 | - libpng=1.6.37=hbc83047_0
139 | - libprotobuf=3.13.0=hd408876_0
140 | - libsodium=1.0.16=h1bed415_0
141 | - libspatialindex=1.9.3=he6710b0_0
142 | - libssh2=1.8.2=h1ba5d50_0
143 | - libstdcxx-ng=9.1.0=hdf63c60_0
144 | - libtiff=4.1.0=h2733197_0
145 | - libtool=2.4.6=h7b6447c_5
146 | - libuuid=1.0.3=h1bed415_2
147 | - libxcb=1.13=h1bed415_1
148 | - libxml2=2.9.9=hea5a465_1
149 | - libxslt=1.1.33=h7d1a2b0_0
150 | - llvmlite=0.31.0=py37hd408876_0
151 | - locket=0.2.0=py37_1
152 | - lxml=4.5.0=py37hefd8a0e_0
153 | - lz4-c=1.8.1.2=h14c3975_0
154 | - lzo=2.10=h49e0be7_2
155 | - markdown=3.2.2=py37_0
156 | - markupsafe=1.1.1=py37h7b6447c_0
157 | - matplotlib=3.1.3=py37_0
158 | - matplotlib-base=3.1.3=py37hef1b27d_0
159 | - mccabe=0.6.1=py37_1
160 | - mistune=0.8.4=py37h7b6447c_0
161 | - mkl=2020.0=166
162 | - mkl-service=2.3.0=py37he904b0f_0
163 | - mkl_fft=1.0.15=py37ha843d7b_0
164 | - mkl_random=1.1.0=py37hd6b4f25_0
165 | - mock=4.0.1=py_0
166 | - more-itertools=8.2.0=py_0
167 | - mpc=1.1.0=h10f8cd9_1
168 | - mpfr=4.0.1=hdf1c602_3
169 | - mpmath=1.1.0=py37_0
170 | - msgpack-python=0.6.1=py37hfd86e86_1
171 | - multipledispatch=0.6.0=py37_0
172 | - navigator-updater=0.2.1=py37_0
173 | - nbconvert=5.6.1=py37_0
174 | - nbformat=5.0.4=py_0
175 | - ncurses=6.2=he6710b0_0
176 | - networkx=2.4=py_0
177 | - ninja=1.10.1=py37hfd86e86_0
178 | - nltk=3.4.5=py37_0
179 | - nose=1.3.7=py37_2
180 | - notebook=6.0.3=py37_0
181 | - numba=0.48.0=py37h0573a6f_0
182 | - numexpr=2.7.1=py37h423224d_0
183 | - numpy=1.18.1=py37h4f9e942_0
184 | - numpy-base=1.18.1=py37hde5b4d6_1
185 | - numpydoc=0.9.2=py_0
186 | - olefile=0.46=py37_0
187 | - openpyxl=3.0.3=py_0
188 | - openssl=1.1.1d=h7b6447c_4
189 | - packaging=20.1=py_0
190 | - pandas=1.0.1=py37h0573a6f_0
191 | - pandoc=2.2.3.2=0
192 | - pandocfilters=1.4.2=py37_1
193 | - pango=1.42.4=h049681c_0
194 | - parso=0.5.2=py_0
195 | - partd=1.1.0=py_0
196 | - patchelf=0.10=he6710b0_0
197 | - path=13.1.0=py37_0
198 | - path.py=12.4.0=0
199 | - pathlib2=2.3.5=py37_0
200 | - pathtools=0.1.2=py_1
201 | - patsy=0.5.1=py37_0
202 | - pcre=8.43=he6710b0_0
203 | - pep8=1.7.1=py37_0
204 | - pexpect=4.8.0=py37_0
205 | - pickleshare=0.7.5=py37_0
206 | - pip=20.0.2=py37_1
207 | - pixman=0.38.0=h7b6447c_0
208 | - pkginfo=1.5.0.1=py37_0
209 | - pluggy=0.13.1=py37_0
210 | - ply=3.11=py37_0
211 | - prometheus_client=0.7.1=py_0
212 | - prompt_toolkit=3.0.3=py_0
213 | - protobuf=3.13.0=py37hf484d3e_0
214 | - psutil=5.6.7=py37h7b6447c_0
215 | - ptyprocess=0.6.0=py37_0
216 | - py=1.8.1=py_0
217 | - py-lief=0.9.0=py37h7725739_2
218 | - pycodestyle=2.5.0=py37_0
219 | - pycosat=0.6.3=py37h7b6447c_0
220 | - pycparser=2.19=py37_0
221 | - pycrypto=2.6.1=py37h14c3975_9
222 | - pycurl=7.43.0.5=py37h1ba5d50_0
223 | - pydocstyle=4.0.1=py_0
224 | - pyflakes=2.1.1=py37_0
225 | - pygments=2.5.2=py_0
226 | - pylint=2.4.4=py37_0
227 | - pyodbc=4.0.30=py37he6710b0_0
228 | - pyopenssl=19.1.0=py37_0
229 | - pyparsing=2.4.6=py_0
230 | - pyqt=5.9.2=py37h05f1152_2
231 | - pyrsistent=0.15.7=py37h7b6447c_0
232 | - pysocks=1.7.1=py37_0
233 | - pytables=3.6.1=py37h71ec239_0
234 | - pytest=5.3.5=py37_0
235 | - pytest-arraydiff=0.3=py37h39e3cac_0
236 | - pytest-astropy=0.8.0=py_0
237 | - pytest-astropy-header=0.1.2=py_0
238 | - pytest-doctestplus=0.5.0=py_0
239 | - pytest-openfiles=0.4.0=py_0
240 | - pytest-remotedata=0.3.2=py37_0
241 | - python=3.7.6=h0371630_2
242 | - python-dateutil=2.8.1=py_0
243 | - python-jsonrpc-server=0.3.4=py_0
244 | - python-language-server=0.31.7=py37_0
245 | - python-libarchive-c=2.8=py37_13
246 | - pytorch=1.6.0=py3.7_cuda10.1.243_cudnn7.6.3_0
247 | - pytz=2019.3=py_0
248 | - pywavelets=1.1.1=py37h7b6447c_0
249 | - pyxdg=0.26=py_0
250 | - pyyaml=5.3=py37h7b6447c_0
251 | - pyzmq=18.1.1=py37he6710b0_0
252 | - qdarkstyle=2.8=py_0
253 | - qt=5.9.7=h5867ecd_1
254 | - qtawesome=0.6.1=py_0
255 | - qtconsole=4.6.0=py_1
256 | - qtpy=1.9.0=py_0
257 | - readline=7.0=h7b6447c_5
258 | - requests=2.22.0=py37_1
259 | - ripgrep=11.0.2=he32d670_0
260 | - rope=0.16.0=py_0
261 | - rtree=0.9.3=py37_0
262 | - ruamel_yaml=0.15.87=py37h7b6447c_0
263 | - scikit-image=0.16.2=py37h0573a6f_0
264 | - scikit-learn=0.22.1=py37hd81dba3_0
265 | - seaborn=0.10.0=py_0
266 | - secretstorage=3.1.2=py37_0
267 | - send2trash=1.5.0=py37_0
268 | - setuptools=45.2.0=py37_0
269 | - simplegeneric=0.8.1=py37_2
270 | - singledispatch=3.4.0.3=py37_0
271 | - sip=4.19.8=py37hf484d3e_0
272 | - six=1.14.0=py37_0
273 | - snappy=1.1.7=hbae5bb6_3
274 | - snowballstemmer=2.0.0=py_0
275 | - sortedcollections=1.1.2=py37_0
276 | - sortedcontainers=2.1.0=py37_0
277 | - soupsieve=1.9.5=py37_0
278 | - sphinx=2.4.0=py_0
279 | - sphinxcontrib=1.0=py37_1
280 | - sphinxcontrib-applehelp=1.0.1=py_0
281 | - sphinxcontrib-devhelp=1.0.1=py_0
282 | - sphinxcontrib-htmlhelp=1.0.2=py_0
283 | - sphinxcontrib-jsmath=1.0.1=py_0
284 | - sphinxcontrib-qthelp=1.0.2=py_0
285 | - sphinxcontrib-serializinghtml=1.1.3=py_0
286 | - sphinxcontrib-websupport=1.2.0=py_0
287 | - spyder=4.0.1=py37_0
288 | - spyder-kernels=1.8.1=py37_0
289 | - sqlalchemy=1.3.13=py37h7b6447c_0
290 | - sqlite=3.31.1=h7b6447c_0
291 | - statsmodels=0.11.0=py37h7b6447c_0
292 | - sympy=1.5.1=py37_0
293 | - tbb=2020.0=hfd86e86_0
294 | - tblib=1.6.0=py_0
295 | - tensorboard=1.15.0=pyhb230dea_0
296 | - terminado=0.8.3=py37_0
297 | - testpath=0.4.4=py_0
298 | - tk=8.6.8=hbc83047_0
299 | - toolz=0.10.0=py_0
300 | - torchvision=0.7.0=py37_cu101
301 | - tornado=6.0.3=py37h7b6447c_3
302 | - tqdm=4.42.1=py_0
303 | - traitlets=4.3.3=py37_0
304 | - ujson=1.35=py37h14c3975_0
305 | - unicodecsv=0.14.1=py37_0
306 | - unixodbc=2.3.7=h14c3975_0
307 | - urllib3=1.25.8=py37_0
308 | - watchdog=0.10.2=py37_0
309 | - wcwidth=0.1.8=py_0
310 | - webencodings=0.5.1=py37_1
311 | - werkzeug=1.0.0=py_0
312 | - wheel=0.34.2=py37_0
313 | - widgetsnbextension=3.5.1=py37_0
314 | - wrapt=1.11.2=py37h7b6447c_0
315 | - wurlitzer=2.0.0=py37_0
316 | - xlrd=1.2.0=py37_0
317 | - xlsxwriter=1.2.7=py_0
318 | - xlwt=1.3.0=py37_0
319 | - xmltodict=0.12.0=py_0
320 | - xz=5.2.4=h14c3975_4
321 | - yaml=0.1.7=had09818_2
322 | - yapf=0.28.0=py_0
323 | - zeromq=4.3.1=he6710b0_3
324 | - zict=1.0.0=py_0
325 | - zipp=2.2.0=py_0
326 | - zlib=1.2.11=h7b6447c_3
327 | - zstd=1.3.7=h0b5b093_0
328 | - pip:
329 | - pillow==6.2.2
330 | - scipy==1.1.0
331 | - sklearn==0.0
332 | prefix: /root/anaconda3
333 |
--------------------------------------------------------------------------------
/images/ECGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/images/ECGAN.png
--------------------------------------------------------------------------------
/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/images/imagenet.png
--------------------------------------------------------------------------------
/src/configs/CIFAR10/ecgan_v2_none_1_0p01.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_processing":{
3 | "dataset_name": "cifar10",
4 | "data_path": "./data/cifar10",
5 | "img_size": 32,
6 | "num_classes": 10,
7 | "batch_size4prcsing": 256,
8 | "chunk_size": 500,
9 | "compression": false
10 | },
11 |
12 | "train": {
13 | "model": {
14 | "architecture": "big_resnet",
15 | "conditional_strategy": "ECGAN",
16 | "pos_collected_numerator": true,
17 | "hypersphere_dim": 512,
18 | "nonlinear_embed": false,
19 | "normalize_embed": true,
20 | "g_spectral_norm": true,
21 | "d_spectral_norm": true,
22 | "activation_fn": "ReLU",
23 | "attention": true,
24 | "attention_after_nth_gen_block": 2,
25 | "attention_after_nth_dis_block": 1,
26 | "z_dim": 80,
27 | "shared_dim": 128,
28 | "g_conv_dim": 96,
29 | "d_conv_dim": 96,
30 | "G_depth": "N/A",
31 | "D_depth": "N/A"
32 | },
33 |
34 | "optimization": {
35 | "optimizer": "Adam",
36 | "batch_size": 64,
37 | "accumulation_steps": 1,
38 | "d_lr": 0.0002,
39 | "g_lr": 0.0002,
40 | "momentum": "N/A",
41 | "nesterov": "N/A",
42 | "alpha": "N/A",
43 | "beta1": 0.5,
44 | "beta2": 0.999,
45 | "g_steps_per_iter": 1,
46 | "d_steps_per_iter": 5,
47 | "total_step": 150000
48 | },
49 |
50 | "loss_function": {
51 | "adv_loss": "hinge",
52 |
53 | "cls_disc_lambda": 0.01,
54 | "cls_gen_lambda": "N/A",
55 | "cond_lambda": "N/A",
56 | "uncond_lambda": 1,
57 | "contrastive_type": "ContraGAN",
58 | "contrastive_lambda": 0.0,
59 | "margin": 0.0,
60 | "tempering_type": "discrete",
61 | "tempering_step": 1,
62 | "start_temperature": 1.0,
63 | "end_temperature": 1.0,
64 |
65 | "weight_clipping_for_dis": false,
66 | "weight_clipping_bound": "N/A",
67 |
68 | "gradient_penalty_for_dis": false,
69 | "gradient_penalty_lambda": "N/A",
70 |
71 | "deep_regret_analysis_for_dis": false,
72 | "regret_penalty_lambda": "N/A",
73 |
74 | "cr": false,
75 | "cr_lambda": "N/A",
76 |
77 | "bcr": false,
78 | "real_lambda": "N/A",
79 | "fake_lambda": "N/A",
80 |
81 | "zcr": false,
82 | "gen_lambda": "N/A",
83 | "dis_lambda": "N/A",
84 | "sigma_noise": "N/A"
85 | },
86 |
87 | "initialization":{
88 | "g_init": "ortho",
89 | "d_init": "ortho"
90 | },
91 |
92 | "training_and_sampling_setting":{
93 | "random_flip_preprocessing": true,
94 | "diff_aug": false,
95 |
96 | "ada": false,
97 | "ada_target": "N/A",
98 | "ada_length": "N/A",
99 |
100 | "prior": "gaussian",
101 | "truncated_factor": 1,
102 |
103 | "latent_op": false,
104 | "latent_op_rate":"N/A",
105 | "latent_op_step":"N/A",
106 | "latent_op_step4eval":"N/A",
107 | "latent_op_alpha":"N/A",
108 | "latent_op_beta":"N/A",
109 | "latent_norm_reg_weight":"N/A",
110 |
111 | "ema": true,
112 | "ema_decay": 0.9999,
113 | "ema_start": 1000
114 | }
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/configs/ILSVRC2012/imagenet_ecgan_v2_contra_1_0p05.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_processing":{
3 | "dataset_name": "imagenet",
4 | "data_path": "./data/ILSVRC2012",
5 | "img_size": 128,
6 | "num_classes": 1000,
7 | "batch_size4prcsing": 256,
8 | "chunk_size": 500,
9 | "compression": false
10 | },
11 |
12 | "train": {
13 | "model": {
14 | "architecture": "big_resnet",
15 | "conditional_strategy": "ECGAN",
16 | "pos_collected_numerator": true,
17 | "hypersphere_dim": 768,
18 | "nonlinear_embed": false,
19 | "normalize_embed": true,
20 | "g_spectral_norm": true,
21 | "d_spectral_norm": true,
22 | "activation_fn": "ReLU",
23 | "attention": true,
24 | "attention_after_nth_gen_block": 4,
25 | "attention_after_nth_dis_block": 1,
26 | "z_dim": 120,
27 | "shared_dim": 128,
28 | "g_conv_dim": 96,
29 | "d_conv_dim": 96,
30 | "G_depth": "N/A",
31 | "D_depth": "N/A"
32 | },
33 |
34 | "optimization": {
35 | "optimizer": "Adam",
36 | "batch_size": 256,
37 | "accumulation_steps": 1,
38 | "d_lr": 0.0002,
39 | "g_lr": 0.00005,
40 | "momentum": "N/A",
41 | "nesterov": "N/A",
42 | "alpha": "N/A",
43 | "beta1": 0.0,
44 | "beta2": 0.999,
45 | "g_steps_per_iter": 1,
46 | "d_steps_per_iter": 2,
47 | "total_step": 400000
48 | },
49 |
50 | "loss_function": {
51 | "adv_loss": "hinge",
52 |
53 | "cls_disc_lambda": 0.05,
54 | "cls_gen_lambda": "N/A",
55 | "cond_lambda": "N/A",
56 | "uncond_lambda": 1,
57 | "contrastive_type": "ContraGAN",
58 | "contrastive_lambda": 1.0,
59 | "margin": 0.0,
60 | "tempering_type": "discrete",
61 | "tempering_step": 1,
62 | "start_temperature": 1.0,
63 | "end_temperature": 1.0,
64 |
65 | "weight_clipping_for_dis": false,
66 | "weight_clipping_bound": "N/A",
67 |
68 | "gradient_penalty_for_dis": false,
69 | "gradient_penalty_lambda": "N/A",
70 |
71 | "deep_regret_analysis_for_dis": false,
72 | "regret_penalty_lambda": "N/A",
73 |
74 | "cr": false,
75 | "cr_lambda":"N/A",
76 |
77 | "bcr": false,
78 | "real_lambda": "N/A",
79 | "fake_lambda": "N/A",
80 |
81 | "zcr": false,
82 | "gen_lambda": "N/A",
83 | "dis_lambda": "N/A",
84 | "sigma_noise": "N/A"
85 | },
86 |
87 | "initialization":{
88 | "g_init": "ortho",
89 | "d_init": "ortho"
90 | },
91 |
92 | "training_and_sampling_setting":{
93 | "random_flip_preprocessing": true,
94 | "diff_aug":false,
95 |
96 | "ada": false,
97 | "ada_target": "N/A",
98 | "ada_length": "N/A",
99 |
100 | "prior": "gaussian",
101 | "truncated_factor": 1,
102 |
103 | "latent_op": false,
104 | "latent_op_rate":"N/A",
105 | "latent_op_step":"N/A",
106 | "latent_op_step4eval":"N/A",
107 | "latent_op_alpha":"N/A",
108 | "latent_op_beta":"N/A",
109 | "latent_norm_reg_weight":"N/A",
110 |
111 | "ema": true,
112 | "ema_decay": 0.9999,
113 | "ema_start": 20000
114 | }
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/configs/TINY_ILSVRC2012/ecgan_v2_none_1_0p05.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_processing":{
3 | "dataset_name": "tiny_imagenet",
4 | "data_path": "./data/TINY_ILSVRC2012",
5 | "img_size": 64,
6 | "num_classes": 200,
7 | "batch_size4prcsing": 256,
8 | "chunk_size": 500,
9 | "compression": false
10 | },
11 |
12 | "train": {
13 | "model": {
14 | "architecture": "big_resnet",
15 | "conditional_strategy": "ECGAN",
16 | "pos_collected_numerator": true,
17 | "hypersphere_dim": 768,
18 | "nonlinear_embed": false,
19 | "normalize_embed": true,
20 | "g_spectral_norm": true,
21 | "d_spectral_norm": true,
22 | "activation_fn": "ReLU",
23 | "attention": true,
24 | "attention_after_nth_gen_block": 3,
25 | "attention_after_nth_dis_block": 1,
26 | "z_dim": 100,
27 | "shared_dim": 128,
28 | "g_conv_dim": 80,
29 | "d_conv_dim": 80,
30 | "G_depth":"N/A",
31 | "D_depth":"N/A"
32 | },
33 |
34 | "optimization": {
35 | "optimizer": "Adam",
36 | "batch_size": 256,
37 | "accumulation_steps": 1,
38 | "d_lr": 0.0004,
39 | "g_lr": 0.0001,
40 | "momentum": "N/A",
41 | "nesterov": "N/A",
42 | "alpha": 0.01,
43 | "beta1": 0.0,
44 | "beta2": 0.999,
45 | "g_steps_per_iter": 1,
46 | "d_steps_per_iter": 1,
47 | "total_step": 100000
48 | },
49 |
50 | "loss_function": {
51 | "adv_loss": "hinge",
52 |
53 | "cls_disc_lambda": 0.05,
54 | "cls_gen_lambda": "N/A",
55 | "cond_lambda": "N/A",
56 | "uncond_lambda": 1,
57 | "contrastive_type": "ContraGAN",
58 | "contrastive_lambda": 0.0,
59 | "margin": 0.0,
60 | "tempering_type": "discrete",
61 | "tempering_step": 1,
62 | "start_temperature": 1.0,
63 | "end_temperature": 1.0,
64 |
65 | "weight_clipping_for_dis": false,
66 | "weight_clipping_bound": "N/A",
67 |
68 | "gradient_penalty_for_dis": false,
69 | "gradient_penalty_lambda": "N/A",
70 |
71 | "deep_regret_analysis_for_dis": false,
72 | "regret_penalty_lambda": "N/A",
73 |
74 | "cr": false,
75 | "cr_lambda": "N/A",
76 |
77 | "bcr": false,
78 | "real_lambda": "N/A",
79 | "fake_lambda": "N/A",
80 |
81 | "zcr": false,
82 | "gen_lambda": "N/A",
83 | "dis_lambda": "N/A",
84 | "sigma_noise": "N/A"
85 | },
86 |
87 | "initialization":{
88 | "g_init": "ortho",
89 | "d_init": "ortho"
90 | },
91 |
92 | "training_and_sampling_setting":{
93 | "random_flip_preprocessing": true,
94 | "diff_aug": false,
95 |
96 | "ada": false,
97 | "ada_target": "N/A",
98 | "ada_length": "N/A",
99 |
100 | "prior": "gaussian",
101 | "truncated_factor": 1,
102 |
103 | "latent_op": false,
104 | "latent_op_rate": "N/A",
105 | "latent_op_step": "N/A",
106 | "latent_op_step4eval": "N/A",
107 | "latent_op_alpha": "N/A",
108 | "latent_op_beta": "N/A",
109 | "latent_norm_reg_weight": "N/A",
110 |
111 | "ema": true,
112 | "ema_decay": 0.9999,
113 | "ema_start": 20000
114 | }
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/data_utils/load_dataset.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/data_utils/load_dataset.py
6 |
7 |
8 | import os
9 | import h5py as h5
10 | import numpy as np
11 | import random
12 | from scipy import io
13 | from PIL import ImageOps, Image
14 |
15 | import torch
16 | import torchvision.transforms as transforms
17 | from torch.utils.data import Dataset
18 | from torchvision.datasets import CIFAR10, STL10
19 | from torchvision.datasets import ImageFolder
20 |
21 |
22 |
23 | class RandomCropLongEdge(object):
24 | """
25 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
26 | MIT License
27 | Copyright (c) 2019 Andy Brock
28 | """
29 | def __call__(self, img):
30 | size = (min(img.size), min(img.size))
31 | # Only step forward along this edge if it's the long edge
32 | i = (0 if size[0] == img.size[0]
33 | else np.random.randint(low=0,high=img.size[0] - size[0]))
34 | j = (0 if size[1] == img.size[1]
35 | else np.random.randint(low=0,high=img.size[1] - size[1]))
36 | return transforms.functional.crop(img, i, j, size[0], size[1])
37 |
38 | def __repr__(self):
39 | return self.__class__.__name__
40 |
41 |
42 | class CenterCropLongEdge(object):
43 | """
44 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
45 | MIT License
46 | Copyright (c) 2019 Andy Brock
47 | """
48 | def __call__(self, img):
49 | return transforms.functional.center_crop(img, min(img.size))
50 |
51 | def __repr__(self):
52 | return self.__class__.__name__
53 |
54 |
55 | class LoadDataset(Dataset):
56 | def __init__(self, dataset_name, data_path, train, download, resize_size, hdf5_path=None, random_flip=False):
57 | super(LoadDataset, self).__init__()
58 | self.dataset_name = dataset_name
59 | self.data_path = data_path
60 | self.train = train
61 | self.download = download
62 | self.resize_size = resize_size
63 | self.hdf5_path = hdf5_path
64 | self.random_flip = random_flip
65 | self.norm_mean = [0.5,0.5,0.5]
66 | self.norm_std = [0.5,0.5,0.5]
67 |
68 | if self.hdf5_path is None:
69 | if self.dataset_name in ['cifar10', 'tiny_imagenet']:
70 | self.transforms = []
71 | elif self.dataset_name in ['imagenet', 'custom']:
72 | if train:
73 | self.transforms = [RandomCropLongEdge(), transforms.Resize(self.resize_size)]
74 | else:
75 | self.transforms = [CenterCropLongEdge(), transforms.Resize(self.resize_size)]
76 | else:
77 | self.transforms = [transforms.ToPILImage()]
78 |
79 | if random_flip:
80 | self.transforms += [transforms.RandomHorizontalFlip()]
81 |
82 | self.transforms += [transforms.ToTensor(), transforms.Normalize(self.norm_mean, self.norm_std)]
83 | self.transforms = transforms.Compose(self.transforms)
84 |
85 | self.load_dataset()
86 |
87 |
88 | def load_dataset(self):
89 | if self.dataset_name == 'cifar10':
90 | if self.hdf5_path is not None:
91 | print('Loading %s into memory...' % self.hdf5_path)
92 | with h5.File(self.hdf5_path, 'r') as f:
93 | self.data = f['imgs'][:]
94 | self.labels = f['labels'][:]
95 | else:
96 | self.data = CIFAR10(root=os.path.join('data', self.dataset_name),
97 | train=self.train,
98 | download=self.download)
99 |
100 | elif self.dataset_name == 'imagenet':
101 | if self.hdf5_path is not None:
102 | print('Loading %s into memory...' % self.hdf5_path)
103 | with h5.File(self.hdf5_path, 'r') as f:
104 | self.data = f['imgs'][:]
105 | self.labels = f['labels'][:]
106 | else:
107 | mode = 'train' if self.train == True else 'valid'
108 | root = os.path.join('data','ILSVRC2012', mode)
109 | self.data = ImageFolder(root=root)
110 |
111 | elif self.dataset_name == "tiny_imagenet":
112 | if self.hdf5_path is not None:
113 | print('Loading %s into memory...' % self.hdf5_path)
114 | with h5.File(self.hdf5_path, 'r') as f:
115 | self.data = f['imgs'][:]
116 | self.labels = f['labels'][:]
117 | else:
118 | mode = 'train' if self.train == True else 'valid'
119 | root = os.path.join('data','TINY_ILSVRC2012', mode)
120 | self.data = ImageFolder(root=root)
121 |
122 | elif self.dataset_name == "custom":
123 | if self.hdf5_path is not None:
124 | print('Loading %s into memory...' % self.hdf5_path)
125 | with h5.File(self.hdf5_path, 'r') as f:
126 | self.data = f['imgs'][:]
127 | self.labels = f['labels'][:]
128 | else:
129 | mode = 'train' if self.train == True else 'valid'
130 | root = os.path.join(self.data_path, mode)
131 | self.data = ImageFolder(root=root)
132 | else:
133 | raise NotImplementedError
134 |
135 |
136 | def __len__(self):
137 | if self.hdf5_path is not None:
138 | num_dataset = self.data.shape[0]
139 | else:
140 | num_dataset = len(self.data)
141 | return num_dataset
142 |
143 |
144 | def __getitem__(self, index):
145 | if self.hdf5_path is None:
146 | img, label = self.data[index]
147 | img, label = self.transforms(img), int(label)
148 | else:
149 | img, label = np.transpose(self.data[index], (1,2,0)), int(self.labels[index])
150 | img = self.transforms(img)
151 | return img, label
152 |
--------------------------------------------------------------------------------
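A minimal usage sketch for the ``LoadDataset`` class above (illustrative only; the paths, batch size, and DataLoader settings here are assumptions rather than values taken from ``worker.py``):

```python
# Illustrative sketch: wrap LoadDataset in a standard PyTorch DataLoader.
from torch.utils.data import DataLoader
from data_utils.load_dataset import LoadDataset  # assumes src/ is on sys.path

train_dataset = LoadDataset(dataset_name='cifar10',
                            data_path='./data/cifar10',
                            train=True,
                            download=True,
                            resize_size=32,
                            hdf5_path=None,      # or a file produced by utils/make_hdf5.py
                            random_flip=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          num_workers=8, pin_memory=True, drop_last=True)

images, labels = next(iter(train_loader))  # images are tensors normalized to [-1, 1]
```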
/src/inception_tf13.py:
--------------------------------------------------------------------------------
1 | ''' Tensorflow inception score code
2 | Derived from https://github.com/openai/improved-gan
3 | Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
4 | THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE
5 | To use this code, run sample.py on your model with --sample_npz, and then
6 | pass the experiment name via --run_name.
7 | This code also saves pool3 stats to an npz file for FID calculation
8 | '''
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import os.path
14 | import sys
15 | import tarfile
16 | import math
17 | from tqdm import tqdm, trange
18 | from argparse import ArgumentParser
19 |
20 | import numpy as np
21 | from six.moves import urllib
22 | import tensorflow as tf
23 |
24 | MODEL_DIR = './inception_model'
25 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
26 | softmax = None
27 |
28 | def prepare_parser():
29 | usage = 'Parser for TF1.3- Inception Score scripts.'
30 | parser = ArgumentParser(description=usage)
31 | parser.add_argument(
32 | '--run_name', type=str, default='',
33 | help="Which experiment's samples.npz file to pull and evaluate")
34 | parser.add_argument(
35 | '--type', type=str, default='',
36 | help='[real, fake]')
37 | parser.add_argument(
38 | '--batch_size', type=int, default=500,
39 | help='Default overall batchsize (default: %(default)s)')
40 | return parser
41 |
42 |
43 | def run(config):
44 | # Inception with TF1.3 or earlier.
45 | # Call this function with list of images. Each of elements should be a
46 | # numpy array with values ranging from 0 to 255.
47 | def get_inception_score(images, splits=10):
48 | assert(type(images) == list)
49 | assert(type(images[0]) == np.ndarray)
50 | assert(len(images[0].shape) == 3)
51 | assert(np.max(images[0]) > 10)
52 | assert(np.min(images[0]) >= 0.0)
53 | inps = []
54 | for img in images:
55 | img = img.astype(np.float32)
56 | inps.append(np.expand_dims(img, 0))
57 | bs = config['batch_size']
58 | with tf.Session() as sess:
59 | preds, pools = [], []
60 | n_batches = int(math.ceil(float(len(inps)) / float(bs)))
61 | for i in trange(n_batches):
62 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
63 | inp = np.concatenate(inp, 0)
64 | pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp})
65 | preds.append(pred)
66 | pools.append(pool)
67 | preds = np.concatenate(preds, 0)
68 | scores = []
69 | for i in range(splits):
70 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
71 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
72 | kl = np.mean(np.sum(kl, 1))
73 | scores.append(np.exp(kl))
74 | return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0))
75 | # Init inception
76 | def _init_inception():
77 | global softmax, pool3
78 | if not os.path.exists(MODEL_DIR):
79 | os.makedirs(MODEL_DIR)
80 | filename = DATA_URL.split('/')[-1]
81 | filepath = os.path.join(MODEL_DIR, filename)
82 | if not os.path.exists(filepath):
83 | def _progress(count, block_size, total_size):
84 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
85 | filename, float(count * block_size) / float(total_size) * 100.0))
86 | sys.stdout.flush()
87 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
88 | print()
89 | statinfo = os.stat(filepath)
90 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
91 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
92 | with tf.gfile.FastGFile(os.path.join(
93 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
94 | graph_def = tf.GraphDef()
95 | graph_def.ParseFromString(f.read())
96 | _ = tf.import_graph_def(graph_def, name='')
97 | # Works with an arbitrary minibatch size.
98 | with tf.Session() as sess:
99 | pool3 = sess.graph.get_tensor_by_name('pool_3:0')
100 | ops = pool3.graph.get_operations()
101 | for op_idx, op in enumerate(ops):
102 | for o in op.outputs:
103 | shape = o.get_shape()
104 | shape = [s.value for s in shape]
105 | new_shape = []
106 | for j, s in enumerate(shape):
107 | if s == 1 and j == 0:
108 | new_shape.append(None)
109 | else:
110 | new_shape.append(s)
111 | o._shape = tf.TensorShape(new_shape)
112 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
113 | logits = tf.matmul(tf.squeeze(pool3), w)
114 | softmax = tf.nn.softmax(logits)
115 |
116 | # if softmax is None: # No need to functionalize like this.
117 | _init_inception()
118 |
119 | fname = '%s/%s/%s/%s/samples.npz' % ("samples", config['run_name'], config['type'], "npz")
120 | print('loading %s ...'%fname)
121 | ims = np.load(fname)['x']
122 | import time
123 | t0 = time.time()
124 | inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=1)
125 | t1 = time.time()
126 | print('Inception took %3f seconds, score of %3f +/- %3f.'%(t1-t0, inc_mean, inc_std))
127 | def main():
128 | # parse command line and run
129 | parser = prepare_parser()
130 | config = vars(parser.parse_args())
131 | print(config)
132 | run(config)
133 |
134 | if __name__ == '__main__':
135 | main()
136 |
--------------------------------------------------------------------------------
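An illustrative invocation of the script above (the run name is a placeholder; per the ``fname`` template in ``run()``, it expects a ``samples/<run_name>/<type>/npz/samples.npz`` file to exist, and it must be run in a TensorFlow 1.x environment):

```
python3 src/inception_tf13.py --run_name YOUR_RUN_NAME --type fake --batch_size 500
```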
/src/main.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/main.py
6 |
7 |
8 | import json
9 | import os
10 | import sys
11 | import warnings
12 | from argparse import ArgumentParser
13 |
14 | from utils.misc import *
15 | from utils.make_hdf5 import make_hdf5
16 | from utils.log import make_run_name
17 | from loader import prepare_train_eval
18 |
19 | import torch
20 | from torch.backends import cudnn
21 | import torch.multiprocessing as mp
22 |
23 |
24 |
25 | RUN_NAME_FORMAT = (
26 | "{framework}-"
27 | "{phase}-"
28 | "{timestamp}"
29 | )
30 |
31 |
32 | def main():
33 | parser = ArgumentParser(add_help=False)
34 | parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json')
35 | parser.add_argument('--checkpoint_folder', type=str, default=None)
36 | parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint')
37 | parser.add_argument('--log_output_path', type=str, default=None)
38 |
39 | parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers')
40 | parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true')
41 | parser.add_argument('--num_workers', type=int, default=8, help='')
42 | parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether to turn on synchronized batchnorm')
43 | parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether to turn on mixed precision training')
44 | parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether to disable pytorch autograd debugging mode')
45 |
46 | parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='fraction of the training dataset to use')
47 | parser.add_argument('-stat_otf', '--bn_stat_OnTheFly', action='store_true', help='when evaluating, use the statistics of a batch')
48 | parser.add_argument('-std_stat', '--standing_statistics', action='store_true')
49 | parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm')
50 | parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator')
51 |
52 | parser.add_argument('-l', '--load_all_data_in_memory', action='store_true')
53 | parser.add_argument('-t', '--train', action='store_true')
54 | parser.add_argument('-e', '--eval', action='store_true')
55 | parser.add_argument('-s', '--save_images', action='store_true')
56 | parser.add_argument('-iv', '--image_visualization', action='store_true', help='whether to conduct image visualization')
57 | parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='whether to conduct k-nearest neighbor analysis')
58 | parser.add_argument('-itp', '--interpolation', action='store_true', help='whether to conduct interpolation analysis')
59 | parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether to conduct frequency analysis')
60 | parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether to conduct tsne analysis')
61 | parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas')
62 | parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas')
63 |
64 | parser.add_argument('--print_every', type=int, default=100, help='control log interval')
65 | parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval')
66 | parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]')
67 | args = parser.parse_args()
68 |
69 | if not args.train and \
70 | not args.eval and \
71 | not args.save_images and \
72 | not args.image_visualization and \
73 | not args.k_nearest_neighbor and \
74 | not args.interpolation and \
75 | not args.frequency_analysis and \
76 | not args.tsne_analysis:
77 | parser.print_help(sys.stderr)
78 | sys.exit(1)
79 |
80 | if args.config_path is not None:
81 | with open(args.config_path) as f:
82 | model_config = json.load(f)
83 | train_config = vars(args)
84 | else:
85 | raise NotImplementedError
86 |
87 | if model_config['data_processing']['dataset_name'] == 'cifar10':
88 | assert train_config['eval_type'] in ['train', 'test'], "Cifar10 does not contain a validation dataset."
89 | elif model_config['data_processing']['dataset_name'] in ['imagenet', 'tiny_imagenet', 'custom']:
90 | assert train_config['eval_type'] == 'train' or train_config['eval_type'] == 'valid', \
91 | "StudioGAN does not support the evaluation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets"
92 |
93 | if train_config['distributed_data_parallel']:
94 | msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, and frequency_analysis with DDP. " +\
95 | "Please use single-GPU training or DataParallel instead of DDP."
96 | assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + \
97 | train_config['interpolation'] + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg
98 |
99 | hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \
100 | if train_config['load_all_data_in_memory'] else None
101 |
102 | if train_config['seed'] == -1:
103 | cudnn.benchmark, cudnn.deterministic = True, False
104 | else:
105 | fix_all_seed(train_config['seed'])
106 | cudnn.benchmark, cudnn.deterministic = False, True
107 |
108 | world_size, rank = torch.cuda.device_count(), torch.cuda.current_device()
109 | if world_size == 1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')
110 |
111 | if train_config['disable_debugging_API']: torch.autograd.set_detect_anomaly(False)
112 | check_flag_0(model_config['train']['optimization']['batch_size'], world_size, train_config['freeze_layers'], train_config['checkpoint_folder'],
113 | model_config['train']['model']['architecture'], model_config['data_processing']['img_size'])
114 |
115 | run_name = make_run_name(RUN_NAME_FORMAT, framework=train_config['config_path'].split('/')[-1][:-5], phase='train')
116 |
117 | if train_config['distributed_data_parallel'] and world_size > 1:
118 | print("Train the models through DistributedDataParallel (DDP) mode.")
119 | mp.spawn(prepare_train_eval, nprocs=world_size, args=(world_size, run_name, train_config, model_config, hdf5_path_train))
120 | else:
121 | prepare_train_eval(rank, world_size, run_name, train_config, model_config, hdf5_path_train=hdf5_path_train)
122 |
123 | if __name__ == '__main__':
124 | main()
125 |
--------------------------------------------------------------------------------
/src/metrics/Accuracy.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/metrics/Accuracy.py
6 |
7 |
8 | import numpy as np
9 | import math
10 | from scipy import linalg
11 | from tqdm import tqdm
12 |
13 | from utils.sample import sample_latents
14 | from utils.losses import latent_optimise
15 |
16 | import torch
17 | from torch.nn import DataParallel
18 | from torch.nn.parallel import DistributedDataParallel
19 |
20 |
21 |
22 | def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate, truncated_factor, prior, latent_op,
23 | latent_op_step, latent_op_alpha, latent_op_beta, device, cr, logger, eval_generated_sample=False):
24 | data_iter = iter(dataloader)
25 | batch_size = dataloader.batch_size
26 | disable_tqdm = device != 0
27 |
28 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
29 | z_dim = generator.module.z_dim
30 | num_classes = generator.module.num_classes
31 | conditional_strategy = discriminator.module.conditional_strategy
32 | else:
33 | z_dim = generator.z_dim
34 | num_classes = generator.num_classes
35 | conditional_strategy = discriminator.conditional_strategy
36 |
37 | total_batch = num_evaluate//batch_size
38 |
39 | if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis"]:
40 | cutoff = 0.5
41 | elif D_loss.__name__ == "loss_hinge_dis":
42 | cutoff = 0.0
43 | elif D_loss.__name__ == "loss_wgan_dis":
44 | raise NotImplementedError
45 |
46 | if device == 0: logger.info("Calculate Accuracies....")
47 |
48 | if eval_generated_sample:
49 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
50 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)
51 | if latent_op:
52 | zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step,
53 | 1.0, latent_op_alpha, latent_op_beta, False, device)
54 |
55 | real_images, real_labels = next(data_iter)
56 | real_images, real_labels = real_images.to(device), real_labels.to(device)
57 |
58 | fake_images = generator(zs, fake_labels, evaluation=True)
59 |
60 | with torch.no_grad():
61 | if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
62 | _, _, dis_out_fake = discriminator(fake_images, fake_labels)
63 | _, _, dis_out_real = discriminator(real_images, real_labels)
64 | elif conditional_strategy == "ACGAN":
65 | _, dis_out_fake = discriminator(fake_images, fake_labels)
66 | _, dis_out_real = discriminator(real_images, real_labels)
67 | elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
68 | dis_out_fake = discriminator(fake_images, fake_labels)
69 | dis_out_real = discriminator(real_images, real_labels)
70 | elif conditional_strategy == 'ECGAN':
71 | _, dis_out_fake, _, _, _ = discriminator(fake_images, fake_labels)
72 | _, dis_out_real, _, _, _ = discriminator(real_images, real_labels)
73 | else:
74 | raise NotImplementedError
75 |
76 | dis_out_fake = dis_out_fake.detach().cpu().numpy()
77 | dis_out_real = dis_out_real.detach().cpu().numpy()
78 |
79 | if batch_id == 0:
80 | confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
81 | confid_label = np.concatenate(([0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)
82 | else:
83 | confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
84 | confid_label = np.concatenate((confid_label, [0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)
85 |
86 | real_confid = confid[confid_label==1.0]
87 | fake_confid = confid[confid_label==0.0]
88 |
89 | true_positive = real_confid[np.where(real_confid>cutoff)]
90 | true_negative = fake_confid[np.where(fake_confid<cutoff)]
124 | only_real_acc = len(true_positive)/len(real_confid)
125 |
126 | return only_real_acc
127 |
--------------------------------------------------------------------------------
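A minimal sketch (not part of the repository) of the thresholding that calculate_accuracy performs: discriminator outputs above the loss-dependent cutoff count as "real" predictions, outputs below it as "fake". The synthetic scores and the hinge-loss cutoff of 0.0 are illustrative assumptions.

    import numpy as np

    rng = np.random.default_rng(0)
    dis_out_real = rng.normal(loc=0.5, scale=1.0, size=1000)    # D(x) on real images (synthetic)
    dis_out_fake = rng.normal(loc=-0.5, scale=1.0, size=1000)   # D(G(z)) on generated images (synthetic)
    cutoff = 0.0                                                 # hinge loss => decision boundary at 0

    only_real_acc = (dis_out_real > cutoff).mean()   # fraction of reals classified as real
    only_fake_acc = (dis_out_fake < cutoff).mean()   # fraction of fakes classified as fake
    print(only_real_acc, only_fake_acc)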
/src/metrics/CAS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 |
4 | from utils.sample import sample_latents
5 | from utils.losses import latent_optimise
6 |
7 | import torch
8 | from torch.nn import DataParallel
9 | from torch.nn.parallel import DistributedDataParallel
10 | from pytorchcv.model_provider import get_model as ptcv_get_model
11 |
12 |
13 | def calculate_classifier_accuracy_score(dataloader, generator, discriminator, num_evaluate, truncated_factor, prior, latent_op,
14 | latent_op_step, latent_op_alpha, latent_op_beta, device, logger, eval_generated_sample=False):
15 | data_iter = iter(dataloader)
16 | batch_size = dataloader.batch_size
17 | disable_tqdm = device != 0
18 |     net = ptcv_get_model('wrn40_8_cifar10', pretrained=True).to(device).eval()  # inference mode for the pretrained classifier
19 |
20 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
21 | z_dim = generator.module.z_dim
22 | num_classes = generator.module.num_classes
23 | conditional_strategy = discriminator.module.conditional_strategy
24 | else:
25 | z_dim = generator.z_dim
26 | num_classes = generator.num_classes
27 | conditional_strategy = discriminator.conditional_strategy
28 |
29 | total_batch = num_evaluate//batch_size
30 |
31 | if device == 0: logger.info("Calculate Classifier Accuracy Score....")
32 |
33 | all_pred_fake, all_pred_real, all_fake_labels, all_real_labels = [], [], [], []
34 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
35 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)
36 | if latent_op:
37 | zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step,
38 | 1.0, latent_op_alpha, latent_op_beta, False, device)
39 |
40 | real_images, real_labels = next(data_iter)
41 | real_images, real_labels = real_images.to(device), real_labels.to(device)
42 |
43 | fake_images = generator(zs, fake_labels, evaluation=True)
44 |
45 | with torch.no_grad():
46 | pred_fake = net(fake_images).detach().cpu().numpy()
47 | pred_real = net(real_images).detach().cpu().numpy()
48 |
49 | all_pred_fake.append(pred_fake)
50 | all_pred_real.append(pred_real)
51 | all_fake_labels.append(fake_labels.cpu().numpy())
52 | all_real_labels.append(real_labels.cpu().numpy())
53 |
54 | all_pred_fake = np.concatenate(all_pred_fake, axis=0).argmax(axis=1)
55 | all_pred_real = np.concatenate(all_pred_real, axis=0).argmax(axis=1)
56 | all_fake_labels = np.concatenate(all_fake_labels)
57 | all_real_labels = np.concatenate(all_real_labels)
58 |
59 | fake_cas = (all_pred_fake == all_fake_labels).mean()
60 | real_cas = (all_pred_real == all_real_labels).mean()
61 |
62 | return real_cas, fake_cas
63 |
--------------------------------------------------------------------------------
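As a minimal sketch (illustrative only), the classifier accuracy score above reduces to top-1 accuracy of the fixed pretrained network against the conditioning labels; the toy logits and labels below are assumptions.

    import numpy as np

    pred_logits = np.array([[2.0, 0.1, 0.3],     # classifier logits, shape (N, num_classes)
                            [0.2, 1.5, 0.1],
                            [0.1, 0.2, 1.8]])
    labels = np.array([0, 1, 2])                 # labels the images were generated (or annotated) with

    cas = (pred_logits.argmax(axis=1) == labels).mean()
    print(cas)   # 1.0 for this toy batch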
/src/metrics/DCA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import softmax
3 | from scipy.stats import entropy
4 | from tqdm import tqdm
5 |
6 | from utils.sample import sample_latents
7 | from utils.losses import latent_optimise
8 |
9 | import torch
10 | from torch.nn import DataParallel
11 | from torch.nn.parallel import DistributedDataParallel
12 |
13 |
14 | def calculate_discriminator_classification_accuracy(dataloader, generator, discriminator, num_evaluate, truncated_factor, prior, latent_op,
15 | latent_op_step, latent_op_alpha, latent_op_beta, device, logger, eval_generated_sample=False):
16 | data_iter = iter(dataloader)
17 | batch_size = dataloader.batch_size
18 | disable_tqdm = device != 0
19 |
20 | if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
21 | conditional_strategy = discriminator.module.conditional_strategy
22 | else:
23 | conditional_strategy = discriminator.conditional_strategy
24 |
25 | total_batch = num_evaluate//batch_size
26 |
27 | if device == 0: logger.info("Calculate Discriminator Classification Accuracy....")
28 |
29 | all_pred_real, all_real_labels = [], []
30 | # all_cond_output, all_uncond_output = [], []
31 | for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
32 | real_images, real_labels = next(data_iter)
33 | real_images, real_labels = real_images.to(device), real_labels.to(device)
34 |
35 | with torch.no_grad():
36 | if conditional_strategy == "ACGAN":
37 | cls_out_real, _ = discriminator(real_images, real_labels)
38 | elif conditional_strategy == 'ECGAN':
39 | cls_out_real, cond_output, uncond_output, _, _ = discriminator(real_images, real_labels)
40 | pred_real = cls_out_real.detach().cpu().numpy()
41 |
42 | all_pred_real.append(pred_real)
43 | all_real_labels.append(real_labels.cpu().numpy())
44 | # all_cond_output.append(cond_output.detach().cpu().numpy())
45 | # all_uncond_output.append(uncond_output.detach().cpu().numpy())
46 |
47 | # mean_cond_output = np.abs(np.concatenate(all_cond_output)).mean()
48 | # mean_uncond_output = np.abs(np.concatenate(all_uncond_output)).mean()
49 | # print(f'Conditional Output Mean: {mean_cond_output}')
50 | # print(f'Unconditional Output Mean: {mean_uncond_output}')
51 | all_pred_logits = np.concatenate(all_pred_real)
52 | all_pred_entropy = entropy(softmax(all_pred_logits, axis=0), axis=0).mean()
53 | all_pred_real = all_pred_logits.argmax(axis=1)
54 | all_real_labels = np.concatenate(all_real_labels)
55 |
56 | dca = (all_pred_real == all_real_labels).mean()
57 |
58 | return dca, all_pred_entropy
59 |
--------------------------------------------------------------------------------
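A minimal sketch (illustrative only) of the two quantities returned above: top-1 accuracy of the discriminator's class logits on real images, and the mean entropy of the softmax predictions taken over the batch axis, as in the file. The toy logits and labels are assumptions.

    import numpy as np
    from scipy.special import softmax
    from scipy.stats import entropy

    cls_out_real = np.array([[3.0, 0.1, 0.2],    # discriminator class logits, shape (N, num_classes)
                             [0.3, 2.5, 0.1],
                             [0.2, 0.1, 1.8]])
    real_labels = np.array([0, 1, 2])

    dca = (cls_out_real.argmax(axis=1) == real_labels).mean()
    pred_entropy = entropy(softmax(cls_out_real, axis=0), axis=0).mean()
    print(dca, pred_entropy)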
/src/metrics/FID.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead of Tensorflow
4 | Copyright 2018 Institute of Bioinformatics, JKU Linz
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 |
8 | You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | import numpy as np
19 | import math
20 | import os
21 | import shutil
22 | from os.path import dirname, abspath, exists, join
23 | from scipy import linalg
24 | from tqdm import tqdm
25 |
26 | from utils.sample import sample_latents
27 | from utils.losses import latent_optimise
28 |
29 | import torch
30 | from torch.nn import DataParallel
31 | from torch.nn.parallel import DistributedDataParallel
32 | from torchvision.utils import save_image
33 |
34 |
35 |
36 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
37 | """Numpy implementation of the Frechet Distance.
38 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
39 | and X_2 ~ N(mu_2, C_2) is
40 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
41 | Stable version by Dougal J. Sutherland.
42 | Params:
43 | -- mu1 : Numpy array containing the activations of a layer of the
44 | inception net (like returned by the function 'get_predictions')
45 | for generated samples.
46 |     -- mu2   : The sample mean over activations, precalculated on a
47 | representative data set.
48 | -- sigma1: The covariance matrix over activations for generated samples.
49 |     -- sigma2: The covariance matrix over activations, precalculated on a
50 | representative data set.
51 | Returns:
52 | -- : The Frechet Distance.
53 | """
54 |
55 | mu1 = np.atleast_1d(mu1)
56 | mu2 = np.atleast_1d(mu2)
57 |
58 | sigma1 = np.atleast_2d(sigma1)
59 | sigma2 = np.atleast_2d(sigma2)
60 |
61 | assert mu1.shape == mu2.shape, \
62 | 'Training and test mean vectors have different lengths'
63 | assert sigma1.shape == sigma2.shape, \
64 | 'Training and test covariances have different dimensions'
65 |
66 | diff = mu1 - mu2
67 |
68 | # Product might be almost singular
69 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
70 | if not np.isfinite(covmean).all():
71 | offset = np.eye(sigma1.shape[0]) * eps
72 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
73 |
74 | # Numerical error might give slight imaginary component
75 | if np.iscomplexobj(covmean):
76 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
77 | m = np.max(np.abs(covmean.imag))
78 | raise ValueError('Imaginary component {}'.format(m))
79 | covmean = covmean.real
80 |
81 | tr_covmean = np.trace(covmean)
82 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
83 |
84 | def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op, latent_op_step,
85 | latent_op_alpha, latent_op_beta, device):
86 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel):
87 | z_dim = gen.module.z_dim
88 | num_classes = gen.module.num_classes
89 | conditional_strategy = dis.module.conditional_strategy
90 | else:
91 | z_dim = gen.z_dim
92 | num_classes = gen.num_classes
93 | conditional_strategy = dis.conditional_strategy
94 |
95 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)
96 |
97 | if latent_op:
98 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha,
99 | latent_op_beta, False, device)
100 |
101 | with torch.no_grad():
102 | batch_images = gen(zs, fake_labels, evaluation=True)
103 |
104 | return batch_images, fake_labels
105 |
106 |
107 | def get_activations(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, is_generate,
108 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=False, run_name=None):
109 | """Calculates the activations of the pool_3 layer for all images.
110 | Params:
111 | -- data_loader : data_loader of training images
112 | -- generator : instance of GANs' generator
113 | -- inception_model : Instance of inception model
114 |
115 | Returns:
116 | -- A numpy array of dimension (num images, dims) that contains the
117 | activations of the given tensor when feeding inception with the
118 | query tensor.
119 | """
120 | if is_generate is True:
121 | batch_size = data_loader.batch_size
122 | total_instance = n_generate
123 | n_batches = math.ceil(float(total_instance) / float(batch_size))
124 | else:
125 | batch_size = data_loader.batch_size
126 | total_instance = len(data_loader.dataset)
127 | n_batches = math.ceil(float(total_instance) / float(batch_size))
128 | data_iter = iter(data_loader)
129 |
130 | num_classes = generator.module.num_classes if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel) else generator.num_classes
131 | pred_arr = np.empty((total_instance, 2048))
132 |
133 | for i in tqdm(range(0, n_batches), disable=tqdm_disable):
134 | start = i*batch_size
135 | end = start + batch_size
136 | if is_generate is True:
137 | images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op,
138 | latent_op_step, latent_op_alpha, latent_op_beta, device)
139 | images = images.to(device)
140 |
141 | with torch.no_grad():
142 | embeddings, logits = inception_model(images)
143 |
144 | if total_instance >= batch_size:
145 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1)
146 | else:
147 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1)
148 |
149 | total_instance -= images.shape[0]
150 | else:
151 | try:
152 | feed_list = next(data_iter)
153 | images = feed_list[0]
154 | images = images.to(device)
155 | with torch.no_grad():
156 | embeddings, logits = inception_model(images)
157 |
158 | if total_instance >= batch_size:
159 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1)
160 | else:
161 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1)
162 | total_instance -= images.shape[0]
163 |
164 | except StopIteration:
165 | break
166 | return pred_arr
167 |
168 |
169 | def calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
170 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name=None):
171 | act = get_activations(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
172 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name)
173 | mu = np.mean(act, axis=0)
174 | sigma = np.cov(act, rowvar=False)
175 | return mu, sigma
176 |
177 |
178 | def calculate_fid_score(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
179 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger, pre_cal_mean=None, pre_cal_std=None, run_name=None):
180 | disable_tqdm = device != 0
181 | inception_model.eval()
182 |
183 | if device == 0: logger.info("Calculating FID Score....")
184 | if pre_cal_mean is not None and pre_cal_std is not None:
185 | m1, s1 = pre_cal_mean, pre_cal_std
186 | else:
187 | m1, s1 = calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor,
188 | prior, False, False, 0, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm)
189 |
190 | m2, s2 = calculate_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
191 | True, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm, run_name=run_name)
192 |
193 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
194 |
195 | return fid_value, m1, s1
196 |
--------------------------------------------------------------------------------
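A minimal sketch (illustrative only) of the Frechet distance formula implemented above, d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), fitted to two small random feature sets; the 8-dimensional activations are stand-ins for the 2048-dimensional Inception features.

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)
    act1 = rng.normal(size=(500, 8))               # "real" activations (synthetic)
    act2 = rng.normal(loc=0.1, size=(500, 8))      # "generated" activations (synthetic)

    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)

    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    print(fid)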
/src/metrics/F_beta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Taken from:
3 | # https://github.com/google/compare_gan/blob/master/compare_gan/src/prd_score.py
4 | #
5 | # Changes:
6 | # - default dpi changed from 150 to 300
7 | # - added handling of cases where P = Q, where precision/recall may be
8 | # just above 1, leading to errors for the f_beta computation
9 | #
10 | # Copyright 2018 Google LLC & Hwalsuk Lee.
11 | #
12 | # Licensed under the Apache License, Version 2.0 (the "License");
13 | # you may not use this file except in compliance with the License.
14 | # You may obtain a copy of the License at
15 | #
16 | # http://www.apache.org/licenses/LICENSE-2.0
17 | #
18 | # Unless required by applicable law or agreed to in writing, software
19 | # distributed under the License is distributed on an "AS IS" BASIS,
20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | # See the License for the specific language governing permissions and
22 | # limitations under the License.
23 |
24 |
25 | import math
26 | import numpy as np
27 | from sklearn.cluster import MiniBatchKMeans
28 | from tqdm import tqdm
29 |
30 | from utils.sample import sample_latents
31 | from utils.losses import latent_optimise
32 |
33 | import torch
34 | from torch.nn import DataParallel
35 | from torch.nn.parallel import DistributedDataParallel
36 |
37 |
38 |
39 | class precision_recall(object):
40 | def __init__(self,inception_model, device):
41 | self.inception_model = inception_model
42 | self.device = device
43 | self.disable_tqdm = device != 0
44 |
45 |
46 | def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size):
47 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel):
48 | z_dim = gen.module.z_dim
49 | num_classes = gen.module.num_classes
50 | conditional_strategy = dis.module.conditional_strategy
51 | else:
52 | z_dim = gen.z_dim
53 | num_classes = gen.num_classes
54 | conditional_strategy = dis.conditional_strategy
55 |
56 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, self.device)
57 |
58 | if latent_op:
59 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha,
60 | latent_op_beta, False, self.device)
61 |
62 | with torch.no_grad():
63 | batch_images = gen(zs, fake_labels, evaluation=True)
64 |
65 | return batch_images
66 |
67 |
68 | def inception_softmax(self, batch_images):
69 | with torch.no_grad():
70 | embeddings, logits = self.inception_model(batch_images)
71 | return embeddings
72 |
73 |
74 | def cluster_into_bins(self, real_embeds, fake_embeds, num_clusters):
75 | representations = np.vstack([real_embeds, fake_embeds])
76 | kmeans = MiniBatchKMeans(n_clusters=num_clusters, n_init=10)
77 | labels = kmeans.fit(representations).labels_
78 |
79 | real_labels = labels[:len(real_embeds)]
80 | fake_labels = labels[len(real_embeds):]
81 |
82 | real_density = np.histogram(real_labels, bins=num_clusters, range=[0, num_clusters], density=True)[0]
83 | fake_density = np.histogram(fake_labels, bins=num_clusters, range=[0, num_clusters], density=True)[0]
84 |
85 | return real_density, fake_density
86 |
87 |
88 | def compute_PRD(self, real_density, fake_density, num_angles=1001, epsilon=1e-10):
89 | angles = np.linspace(epsilon, np.pi/2 - epsilon, num=num_angles)
90 | slopes = np.tan(angles)
91 |
92 | slopes_2d = np.expand_dims(slopes, 1)
93 |
94 | real_density_2d = np.expand_dims(real_density, 0)
95 | fake_density_2d = np.expand_dims(fake_density, 0)
96 |
97 | precision = np.minimum(real_density_2d*slopes_2d, fake_density_2d).sum(axis=1)
98 | recall = precision / slopes
99 |
100 | max_val = max(np.max(precision), np.max(recall))
101 | if max_val > 1.001:
102 | raise ValueError('Detected value > 1.001, this should not happen.')
103 | precision = np.clip(precision, 0, 1)
104 | recall = np.clip(recall, 0, 1)
105 |
106 | return precision, recall
107 |
108 | def compute_precision_recall(self, dataloader, gen, dis, num_generate, num_runs, num_clusters, truncated_factor, prior,
109 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size, device, num_angles=1001):
110 | dataset_iter = iter(dataloader)
111 | n_batches = int(math.ceil(float(num_generate) / float(batch_size)))
112 | for i in tqdm(range(n_batches), disable = self.disable_tqdm):
113 | real_images, real_labels = next(dataset_iter)
114 | real_images, real_labels = real_images.to(self.device), real_labels.to(self.device)
115 | fake_images = self.generate_images(gen, dis, truncated_factor, prior, latent_op, latent_op_step,
116 | latent_op_alpha, latent_op_beta, batch_size)
117 |
118 | real_embed = self.inception_softmax(real_images).detach().cpu().numpy()
119 | fake_embed = self.inception_softmax(fake_images).detach().cpu().numpy()
120 | if i == 0:
121 | real_embeds = np.array(real_embed, dtype=np.float64)
122 | fake_embeds = np.array(fake_embed, dtype=np.float64)
123 | else:
124 | real_embeds = np.concatenate([real_embeds, np.array(real_embed, dtype=np.float64)], axis=0)
125 | fake_embeds = np.concatenate([fake_embeds, np.array(fake_embed, dtype=np.float64)], axis=0)
126 |
127 | real_embeds = real_embeds[:num_generate]
128 | fake_embeds = fake_embeds[:num_generate]
129 |
130 | precisions = []
131 | recalls = []
132 | for _ in range(num_runs):
133 | real_density, fake_density = self.cluster_into_bins(real_embeds, fake_embeds, num_clusters)
134 | precision, recall = self.compute_PRD(real_density, fake_density, num_angles=num_angles)
135 | precisions.append(precision)
136 | recalls.append(recall)
137 |
138 | mean_precision = np.mean(precisions, axis=0)
139 | mean_recall = np.mean(recalls, axis=0)
140 |
141 | return mean_precision, mean_recall
142 |
143 |
144 | def compute_f_beta(self, precision, recall, beta=1, epsilon=1e-10):
145 | return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall + epsilon)
146 |
147 |
148 | def calculate_f_beta_score(dataloader, gen, dis, inception_model, num_generate, num_runs, num_clusters, beta, truncated_factor,
149 | prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger):
150 | inception_model.eval()
151 |
152 | batch_size = dataloader.batch_size
153 | PR = precision_recall(inception_model, device=device)
154 | if device == 0: logger.info("Calculate F_beta Score....")
155 | precision, recall = PR.compute_precision_recall(dataloader, gen, dis, num_generate, num_runs, num_clusters, truncated_factor,
156 | prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size, device)
157 |
158 | if not ((precision >= 0).all() and (precision <= 1).all()):
159 | raise ValueError('All values in precision must be in [0, 1].')
160 | if not ((recall >= 0).all() and (recall <= 1).all()):
161 | raise ValueError('All values in recall must be in [0, 1].')
162 | if beta <= 0:
163 | raise ValueError('Given parameter beta %s must be positive.' % str(beta))
164 |
165 | f_beta = np.max(PR.compute_f_beta(precision, recall, beta=beta))
166 | f_beta_inv = np.max(PR.compute_f_beta(precision, recall, beta=1/beta))
167 | return precision, recall, f_beta, f_beta_inv
168 |
--------------------------------------------------------------------------------
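A minimal sketch (illustrative only) of the PRD construction used by compute_PRD and the F_beta reduction: the precision-recall curve is traced by sweeping slopes over two cluster histograms. The three-bin densities and beta=8 are assumptions.

    import numpy as np

    real_density = np.array([0.5, 0.3, 0.2])    # histogram of real embeddings over k-means bins
    fake_density = np.array([0.3, 0.3, 0.4])    # histogram of fake embeddings over the same bins

    angles = np.linspace(1e-10, np.pi / 2 - 1e-10, num=1001)
    slopes = np.tan(angles)[:, None]
    precision = np.minimum(real_density[None, :] * slopes, fake_density[None, :]).sum(axis=1)
    recall = precision / slopes[:, 0]

    beta = 8
    f_beta = np.max((1 + beta**2) * precision * recall / (beta**2 * precision + recall + 1e-10))
    f_beta_inv = np.max((1 + (1 / beta)**2) * precision * recall / ((1 / beta)**2 * precision + recall + 1e-10))
    print(f_beta, f_beta_inv)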
/src/metrics/IS.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/metrics/IS.py
6 |
7 | import math
8 | from tqdm import tqdm
9 |
10 | from utils.sample import sample_latents
11 | from utils.losses import latent_optimise
12 |
13 | import torch
14 | from torch.nn import DataParallel
15 | from torch.nn.parallel import DistributedDataParallel
16 |
17 |
18 |
19 | class evaluator(object):
20 | def __init__(self,inception_model, device):
21 | self.inception_model = inception_model
22 | self.device = device
23 | self.disable_tqdm = device != 0
24 |
25 | def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size):
26 | if isinstance(gen, DataParallel) or isinstance(gen, DistributedDataParallel):
27 | z_dim = gen.module.z_dim
28 | num_classes = gen.module.num_classes
29 | conditional_strategy = dis.module.conditional_strategy
30 | else:
31 | z_dim = gen.z_dim
32 | num_classes = gen.num_classes
33 | conditional_strategy = dis.conditional_strategy
34 |
35 | zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, self.device)
36 |
37 | if latent_op:
38 | zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy, latent_op_step, 1.0, latent_op_alpha,
39 | latent_op_beta, False, self.device)
40 |
41 | with torch.no_grad():
42 | batch_images = gen(zs, fake_labels, evaluation=True)
43 |
44 | return batch_images
45 |
46 |
47 | def inception_softmax(self, batch_images):
48 | with torch.no_grad():
49 | embeddings, logits = self.inception_model(batch_images)
50 | y = torch.nn.functional.softmax(logits, dim=1)
51 | return y
52 |
53 |
54 | def kl_scores(self, ys, splits):
55 | scores = []
56 | n_images = ys.shape[0]
57 | with torch.no_grad():
58 | for j in range(splits):
59 | part = ys[(j*n_images//splits): ((j+1)*n_images//splits), :]
60 | kl = part * (torch.log(part) - torch.log(torch.unsqueeze(torch.mean(part, 0), 0)))
61 | kl = torch.mean(torch.sum(kl, 1))
62 | kl = torch.exp(kl)
63 | scores.append(kl.unsqueeze(0))
64 | scores = torch.cat(scores, 0)
65 | m_scores = torch.mean(scores).detach().cpu().numpy()
66 | m_std = torch.std(scores).detach().cpu().numpy()
67 | return m_scores, m_std
68 |
69 |
70 | def eval_gen(self, gen, dis, n_eval, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha,
71 | latent_op_beta, split, batch_size):
72 | ys = []
73 | n_batches = int(math.ceil(float(n_eval) / float(batch_size)))
74 | for i in tqdm(range(n_batches), disable=self.disable_tqdm):
75 | batch_images = self.generate_images(gen, dis, truncated_factor, prior, latent_op, latent_op_step,
76 | latent_op_alpha, latent_op_beta, batch_size)
77 | y = self.inception_softmax(batch_images)
78 | ys.append(y)
79 |
80 | with torch.no_grad():
81 | ys = torch.cat(ys, 0)
82 | m_scores, m_std = self.kl_scores(ys[:n_eval], splits=split)
83 | return m_scores, m_std
84 |
85 |
86 | def eval_dataset(self, dataloader, splits):
87 | batch_size = dataloader.batch_size
88 | n_images = len(dataloader.dataset)
89 | n_batches = int(math.ceil(float(n_images)/float(batch_size)))
90 | dataset_iter = iter(dataloader)
91 | ys = []
92 | for i in tqdm(range(n_batches), disable=self.disable_tqdm):
93 | feed_list = next(dataset_iter)
94 | batch_images, batch_labels = feed_list[0], feed_list[1]
95 | batch_images = batch_images.to(self.device)
96 | y = self.inception_softmax(batch_images)
97 | ys.append(y)
98 |
99 | with torch.no_grad():
100 | ys = torch.cat(ys, 0)
101 | m_scores, m_std = self.kl_scores(ys, splits=splits)
102 | return m_scores, m_std
103 |
104 |
105 | def calculate_incep_score(dataloader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
106 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, splits, device, logger):
107 | inception_model.eval()
108 |
109 | batch_size = dataloader.batch_size
110 | evaluator_instance = evaluator(inception_model, device=device)
111 | if device == 0: logger.info("Calculating Inception Score....")
112 | kl_score, kl_std = evaluator_instance.eval_gen(generator, discriminator, n_generate, truncated_factor, prior, latent_op,
113 | latent_op_step, latent_op_alpha, latent_op_beta, splits, batch_size)
114 | return kl_score, kl_std
115 |
--------------------------------------------------------------------------------
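A minimal sketch (illustrative only) of the KL computation behind kl_scores, i.e. the Inception Score exp(E_x[KL(p(y|x) || p(y))]) on a toy softmax matrix, without the split-wise averaging.

    import numpy as np

    p_yx = np.array([[0.90, 0.05, 0.05],    # p(y|x) for three generated images (synthetic)
                     [0.10, 0.80, 0.10],
                     [0.05, 0.05, 0.90]])
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal class distribution p(y)

    kl = (p_yx * (np.log(p_yx) - np.log(p_y))).sum(axis=1).mean()
    print(np.exp(kl))                       # Inception Score of this toy batch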
/src/metrics/IntraFID.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 | from torch.nn import DataParallel
8 | from torch.nn.parallel import DistributedDataParallel
9 |
10 | from metrics.FID import generate_images
11 | from metrics.FID import calculate_frechet_distance
12 |
13 |
14 | def get_activations_with_label(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior, is_generate,
15 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=False, run_name=None):
16 | """Calculates the activations of the pool_3 layer for all images.
17 | Params:
18 | -- data_loader : data_loader of training images
19 | -- generator : instance of GANs' generator
20 | -- inception_model : Instance of inception model
21 |
22 | Returns:
23 | -- A numpy array of dimension (num images, dims) that contains the
24 | activations of the given tensor when feeding inception with the
25 | query tensor.
26 | """
27 | if is_generate is True:
28 | batch_size = data_loader.batch_size
29 | total_instance = n_generate
30 | n_batches = math.ceil(float(total_instance) / float(batch_size))
31 | else:
32 | batch_size = data_loader.batch_size
33 | total_instance = len(data_loader.dataset)
34 | n_batches = math.ceil(float(total_instance) / float(batch_size))
35 | data_iter = iter(data_loader)
36 |
37 | num_classes = generator.module.num_classes if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel) else generator.num_classes
38 | pred_arr = np.empty((total_instance, 2048))
39 | label_arr = []
40 |
41 | for i in tqdm(range(0, n_batches), disable=tqdm_disable):
42 | start = i*batch_size
43 | end = start + batch_size
44 | if is_generate is True:
45 | images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op,
46 | latent_op_step, latent_op_alpha, latent_op_beta, device)
47 | images = images.to(device)
48 |
49 | with torch.no_grad():
50 | embeddings, logits = inception_model(images)
51 |
52 | if total_instance >= batch_size:
53 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1)
54 | else:
55 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1)
56 |
57 | total_instance -= images.shape[0]
58 | else:
59 | try:
60 | feed_list = next(data_iter)
61 | images = feed_list[0]
62 | labels = feed_list[1]
63 | images = images.to(device)
64 | with torch.no_grad():
65 | embeddings, logits = inception_model(images)
66 |
67 | if total_instance >= batch_size:
68 | pred_arr[start:end] = embeddings.cpu().data.numpy().reshape(batch_size, -1)
69 | else:
70 | pred_arr[start:] = embeddings[:total_instance].cpu().data.numpy().reshape(total_instance, -1)
71 | total_instance -= images.shape[0]
72 |
73 | except StopIteration:
74 | break
75 | label_arr.append(labels.cpu().data.numpy())
76 | label_arr = np.concatenate(label_arr)[:len(pred_arr)]
77 | return pred_arr, label_arr
78 |
79 |
80 | def calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
81 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name=None):
82 | act, labels = get_activations_with_label(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
83 | is_generate, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable, run_name)
84 | num_classes = len(np.unique(labels))
85 | mu, sigma = [], []
86 | for i in tqdm(range(num_classes)):
87 | mu.append(np.mean(act[labels == i], axis=0))
88 | sigma.append(np.cov(act[labels == i], rowvar=False))
89 | return mu, sigma
90 |
91 |
92 | def calculate_intra_fid_score(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
93 | latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, logger, pre_cal_mean=None, pre_cal_std=None, run_name=None):
94 | disable_tqdm = device != 0
95 | inception_model.eval()
96 |
97 | if device == 0: logger.info("Calculating Intra-FID Score....")
98 | if pre_cal_mean is not None and pre_cal_std is not None:
99 | m1, s1 = pre_cal_mean, pre_cal_std
100 | else:
101 | m1, s1 = calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor,
102 | prior, False, False, 0, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm)
103 |
104 | m2, s2 = calculate_intra_activation_statistics(data_loader, generator, discriminator, inception_model, n_generate, truncated_factor, prior,
105 | True, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device, tqdm_disable=disable_tqdm, run_name=run_name)
106 |
107 | intra_fid = []
108 | for i in tqdm(range(len(m1))):
109 | intra_fid.append(calculate_frechet_distance(m1[i], s1[i], m2[i], s2[i]))
110 | intra_fid = np.mean(intra_fid)
111 |
112 | return intra_fid, m1, s1
113 |
--------------------------------------------------------------------------------
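A minimal sketch (illustrative only) of the per-class grouping that distinguishes intra-FID from FID: activations are split by label and a separate (mu, sigma) pair is fitted per class before the per-class Frechet distances are averaged. The arrays below are random stand-ins.

    import numpy as np

    rng = np.random.default_rng(0)
    act = rng.normal(size=(600, 8))          # stand-in for pool3 activations
    labels = rng.integers(0, 3, size=600)    # class label attached to each activation

    mu = [act[labels == c].mean(axis=0) for c in range(3)]
    sigma = [np.cov(act[labels == c], rowvar=False) for c in range(3)]
    print(len(mu), mu[0].shape, sigma[0].shape)   # 3 (8,) (8, 8)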
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sian-chen/PyTorch-ECGAN/974e86692611dd3ce4136cf7b2b786f5a011be6b/src/metrics/__init__.py
--------------------------------------------------------------------------------
/src/metrics/inception_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 |
12 |
13 | # Inception weights ported to Pytorch from
14 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
15 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
16 |
17 |
18 | class InceptionV3(nn.Module):
19 | """Pretrained InceptionV3 network returning feature maps"""
20 |
21 | # Index of default block of inception to return,
22 | # corresponds to output of final average pooling
23 | def __init__(self,
24 | resize_input=True,
25 | normalize_input=False,
26 | requires_grad=False):
27 | """Build pretrained InceptionV3
28 | Parameters
29 | ----------
30 | resize_input : bool
31 | If true, bilinearly resizes input to width and height 299 before
32 | feeding input to model. As the network without fully connected
33 | layers is fully convolutional, it should be able to handle inputs
34 | of arbitrary size, so resizing might not be strictly needed
35 | normalize_input : bool
36 | If true, scales the input from range (0, 1) to the range the
37 | pretrained Inception network expects, namely (-1, 1)
38 | requires_grad : bool
39 | If true, parameters of the model require gradients. Possibly useful
40 | for finetuning the network
41 | """
42 | super(InceptionV3, self).__init__()
43 |
44 | self.resize_input = resize_input
45 | self.normalize_input = normalize_input
46 | self.blocks = nn.ModuleList()
47 |
48 | state_dict, inception = fid_inception_v3()
49 |
50 | # Block 0: input to maxpool1
51 | block0 = [
52 | inception.Conv2d_1a_3x3,
53 | inception.Conv2d_2a_3x3,
54 | inception.Conv2d_2b_3x3,
55 | nn.MaxPool2d(kernel_size=3, stride=2)
56 | ]
57 | self.blocks.append(nn.Sequential(*block0))
58 |
59 | # Block 1: maxpool1 to maxpool2
60 | block1 = [
61 | inception.Conv2d_3b_1x1,
62 | inception.Conv2d_4a_3x3,
63 | nn.MaxPool2d(kernel_size=3, stride=2)
64 | ]
65 | self.blocks.append(nn.Sequential(*block1))
66 |
67 | # Block 2: maxpool2 to aux classifier
68 | block2 = [
69 | inception.Mixed_5b,
70 | inception.Mixed_5c,
71 | inception.Mixed_5d,
72 | inception.Mixed_6a,
73 | inception.Mixed_6b,
74 | inception.Mixed_6c,
75 | inception.Mixed_6d,
76 | inception.Mixed_6e,
77 | ]
78 | self.blocks.append(nn.Sequential(*block2))
79 |
80 | # Block 3: aux classifier to final avgpool
81 | block3 = [
82 | inception.Mixed_7a,
83 | inception.Mixed_7b,
84 | inception.Mixed_7c,
85 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
86 | ]
87 | self.blocks.append(nn.Sequential(*block3))
88 |
89 | with torch.no_grad():
90 | self.fc = nn.Linear(2048, 1008, bias=True)
91 | self.fc.weight.copy_(state_dict['fc.weight'])
92 | self.fc.bias.copy_(state_dict['fc.bias'])
93 |
94 | for param in self.parameters():
95 | param.requires_grad = requires_grad
96 |
97 |
98 | def forward(self, inp):
99 | """Get Inception feature maps
100 | Parameters
101 | ----------
102 | inp : torch.autograd.Variable
103 | Input tensor of shape Bx3xHxW. Values are expected to be in
104 | range (0, 1)
105 | Returns
106 | -------
107 |         Tuple of two tensors: the 2048-d pool3 features and the 1008-way
108 |         logits from the FID-variant fc layer
109 | """
110 | x = inp
111 |
112 | if self.resize_input:
113 | x = F.interpolate(x,
114 | size=(299, 299),
115 | mode='bilinear',
116 | align_corners=False)
117 |
118 | if self.normalize_input:
119 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
120 |
121 | for idx, block in enumerate(self.blocks):
122 | x = block(x)
123 |
124 | x = F.dropout(x, training=False)
125 | x = torch.flatten(x, 1)
126 | logit = self.fc(x)
127 | return x, logit
128 |
129 |
130 | def fid_inception_v3():
131 | """Build pretrained Inception model for FID computation
132 | The Inception model for FID computation uses a different set of weights
133 | and has a slightly different structure than torchvision's Inception.
134 | This method first constructs torchvision's Inception and then patches the
135 | necessary parts that are different in the FID Inception model.
136 | """
137 | inception = models.inception_v3(num_classes=1008,
138 | aux_logits=False,
139 | pretrained=False)
140 |
141 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
142 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
143 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
144 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
145 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
146 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
147 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
148 | inception.Mixed_7b = FIDInceptionE_1(1280)
149 | inception.Mixed_7c = FIDInceptionE_2(2048)
150 | # inception.fc = nn.Linear(2048, 1008, bias=False)
151 |
152 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
153 | inception.load_state_dict(state_dict)
154 | return state_dict, inception
155 |
156 |
157 | class FIDInceptionA(models.inception.InceptionA):
158 | """InceptionA block patched for FID computation"""
159 | def __init__(self, in_channels, pool_features):
160 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
161 |
162 | def forward(self, x):
163 | branch1x1 = self.branch1x1(x)
164 |
165 | branch5x5 = self.branch5x5_1(x)
166 | branch5x5 = self.branch5x5_2(branch5x5)
167 |
168 | branch3x3dbl = self.branch3x3dbl_1(x)
169 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
170 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
171 |
172 |         # Patch: Tensorflow's average pool does not use the padded zeros in
173 | # its average calculation
174 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
175 | count_include_pad=False)
176 | branch_pool = self.branch_pool(branch_pool)
177 |
178 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
179 | return torch.cat(outputs, 1)
180 |
181 |
182 | class FIDInceptionC(models.inception.InceptionC):
183 | """InceptionC block patched for FID computation"""
184 | def __init__(self, in_channels, channels_7x7):
185 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
186 |
187 | def forward(self, x):
188 | branch1x1 = self.branch1x1(x)
189 |
190 | branch7x7 = self.branch7x7_1(x)
191 | branch7x7 = self.branch7x7_2(branch7x7)
192 | branch7x7 = self.branch7x7_3(branch7x7)
193 |
194 | branch7x7dbl = self.branch7x7dbl_1(x)
195 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
196 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
197 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
198 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
199 |
200 |         # Patch: Tensorflow's average pool does not use the padded zeros in
201 | # its average calculation
202 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
203 | count_include_pad=False)
204 | branch_pool = self.branch_pool(branch_pool)
205 |
206 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
207 | return torch.cat(outputs, 1)
208 |
209 |
210 | class FIDInceptionE_1(models.inception.InceptionE):
211 | """First InceptionE block patched for FID computation"""
212 | def __init__(self, in_channels):
213 | super(FIDInceptionE_1, self).__init__(in_channels)
214 |
215 | def forward(self, x):
216 | branch1x1 = self.branch1x1(x)
217 |
218 | branch3x3 = self.branch3x3_1(x)
219 | branch3x3 = [
220 | self.branch3x3_2a(branch3x3),
221 | self.branch3x3_2b(branch3x3),
222 | ]
223 | branch3x3 = torch.cat(branch3x3, 1)
224 |
225 | branch3x3dbl = self.branch3x3dbl_1(x)
226 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
227 | branch3x3dbl = [
228 | self.branch3x3dbl_3a(branch3x3dbl),
229 | self.branch3x3dbl_3b(branch3x3dbl),
230 | ]
231 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
232 |
233 |         # Patch: Tensorflow's average pool does not use the padded zeros in
234 | # its average calculation
235 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
236 | count_include_pad=False)
237 | branch_pool = self.branch_pool(branch_pool)
238 |
239 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
240 | return torch.cat(outputs, 1)
241 |
242 |
243 | class FIDInceptionE_2(models.inception.InceptionE):
244 | """Second InceptionE block patched for FID computation"""
245 | def __init__(self, in_channels):
246 | super(FIDInceptionE_2, self).__init__(in_channels)
247 |
248 | def forward(self, x):
249 | branch1x1 = self.branch1x1(x)
250 |
251 | branch3x3 = self.branch3x3_1(x)
252 | branch3x3 = [
253 | self.branch3x3_2a(branch3x3),
254 | self.branch3x3_2b(branch3x3),
255 | ]
256 | branch3x3 = torch.cat(branch3x3, 1)
257 |
258 | branch3x3dbl = self.branch3x3dbl_1(x)
259 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
260 | branch3x3dbl = [
261 | self.branch3x3dbl_3a(branch3x3dbl),
262 | self.branch3x3dbl_3b(branch3x3dbl),
263 | ]
264 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
265 |
266 | # Patch: The FID Inception model uses max pooling instead of average
267 | # pooling. This is likely an error in this specific Inception
268 | # implementation, as other Inception models use average pooling here
269 | # (which matches the description in the paper).
270 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
271 | branch_pool = self.branch_pool(branch_pool)
272 |
273 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
274 | return torch.cat(outputs, 1)
275 |
--------------------------------------------------------------------------------
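A minimal usage sketch (illustrative only, assuming src/ is on PYTHONPATH and the FID weights can be downloaded): the wrapper returns the 2048-d pool3 features together with the 1008-way logits.

    import torch
    from metrics.inception_network import InceptionV3   # assumes the script runs from src/

    model = InceptionV3(resize_input=True, normalize_input=True).eval()
    with torch.no_grad():
        images = torch.rand(4, 3, 64, 64)        # toy batch with values in (0, 1)
        embeddings, logits = model(images)
    print(embeddings.shape, logits.shape)         # torch.Size([4, 2048]) torch.Size([4, 1008])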
/src/metrics/prepare_inception_moments.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/metrics/prepare_inception_moments.py
6 |
7 | import numpy as np
8 | import os
9 |
10 | from metrics.FID import calculate_activation_statistics
11 | from metrics.IS import evaluator
12 |
13 |
14 |
15 | def prepare_inception_moments(dataloader, eval_mode, generator, inception_model, splits, run_name, logger, device):
16 | dataset_name = dataloader.dataset.dataset_name
17 | inception_model.eval()
18 |
19 | save_path = os.path.abspath(os.path.join("./data", dataset_name + "_" + eval_mode +'_' + 'inception_moments.npz'))
20 | is_file = os.path.isfile(save_path)
21 |
22 | if is_file:
23 | mu = np.load(save_path)['mu']
24 | sigma = np.load(save_path)['sigma']
25 | else:
26 | if device == 0: logger.info('Calculate moments of {} dataset'.format(eval_mode))
27 | mu, sigma = calculate_activation_statistics(data_loader=dataloader,
28 | generator=generator,
29 | discriminator=None,
30 | inception_model=inception_model,
31 | n_generate=None,
32 | truncated_factor=None,
33 | prior=None,
34 | is_generate=False,
35 | latent_op=False,
36 | latent_op_step=None,
37 | latent_op_alpha=None,
38 | latent_op_beta=None,
39 | device=device,
40 | tqdm_disable=False,
41 | run_name=run_name)
42 |
43 | if device == 0: logger.info('Save calculated means and covariances to disk...')
44 | np.savez(save_path, **{'mu': mu, 'sigma': sigma})
45 |
46 | if is_file:
47 | pass
48 | else:
49 | if device == 0: logger.info('calculate inception score of {} dataset.'.format(eval_mode))
50 | evaluator_instance = evaluator(inception_model, device=device)
51 | is_score, is_std = evaluator_instance.eval_dataset(dataloader, splits=splits)
52 | if device == 0: logger.info('Inception score={is_score}-Inception_std={is_std}'.format(is_score=is_score, is_std=is_std))
53 | return mu, sigma
54 |
--------------------------------------------------------------------------------
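A minimal sketch (illustrative only) of the cache written above: a plain .npz holding 'mu' and 'sigma'. The file name below is a hypothetical example of the dataset_name/eval_mode pattern, not a path guaranteed to exist.

    import numpy as np

    stats = np.load("./data/cifar10_test_inception_moments.npz")   # hypothetical cached file
    mu, sigma = stats["mu"], stats["sigma"]
    print(mu.shape, sigma.shape)   # expected: (2048,) and (2048, 2048)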
/src/models/dcgan.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # models/dcgan.py
6 |
7 |
8 | from utils.model_ops import *
9 | from utils.misc import *
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 |
17 | class GenBlock(nn.Module):
18 | def __init__(self, in_channels, out_channels, g_spectral_norm, activation_fn, conditional_bn, num_classes):
19 | super(GenBlock, self).__init__()
20 | self.conditional_bn = conditional_bn
21 |
22 | if g_spectral_norm:
23 | self.deconv0 = sndeconv2d(in_channels=in_channels, out_channels=out_channels,
24 | kernel_size=4, stride=2, padding=1)
25 | else:
26 | self.deconv0 = deconv2d(in_channels=in_channels, out_channels=out_channels,
27 | kernel_size=4, stride=2, padding=1)
28 |
29 | if self.conditional_bn:
30 | self.bn0 = ConditionalBatchNorm2d(num_features=out_channels, num_classes=num_classes,
31 | spectral_norm=g_spectral_norm)
32 | else:
33 | self.bn0 = batchnorm_2d(in_features=out_channels)
34 |
35 | if activation_fn == "ReLU":
36 | self.activation = nn.ReLU(inplace=True)
37 | elif activation_fn == "Leaky_ReLU":
38 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
39 | elif activation_fn == "ELU":
40 | self.activation = nn.ELU(alpha=1.0, inplace=True)
41 | elif activation_fn == "GELU":
42 | self.activation = nn.GELU()
43 | else:
44 | raise NotImplementedError
45 |
46 | def forward(self, x, label):
47 | x = self.deconv0(x)
48 | if self.conditional_bn:
49 | x = self.bn0(x, label)
50 | else:
51 | x = self.bn0(x)
52 | out = self.activation(x)
53 | return out
54 |
55 |
56 | class Generator(nn.Module):
57 | """Generator."""
58 | def __init__(self, z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn,
59 | conditional_strategy, num_classes, initialize, G_depth, mixed_precision):
60 | super(Generator, self).__init__()
61 | self.in_dims = [512, 256, 128]
62 | self.out_dims = [256, 128, 64]
63 |
64 | self.z_dim = z_dim
65 | self.num_classes = num_classes
66 | self.mixed_precision = mixed_precision
67 | conditional_bn = True if conditional_strategy in ["ACGAN", "ProjGAN", "ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN", "ECGAN"] else False
68 |
69 | if g_spectral_norm:
70 | self.linear0 = snlinear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4)
71 | else:
72 | self.linear0 = linear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4)
73 |
74 | self.blocks = []
75 | for index in range(len(self.in_dims)):
76 | self.blocks += [[GenBlock(in_channels=self.in_dims[index],
77 | out_channels=self.out_dims[index],
78 | g_spectral_norm=g_spectral_norm,
79 | activation_fn=activation_fn,
80 | conditional_bn=conditional_bn,
81 | num_classes=self.num_classes)]]
82 |
83 | if index+1 == attention_after_nth_gen_block and attention is True:
84 | self.blocks += [[Self_Attn(self.out_dims[index], g_spectral_norm)]]
85 |
86 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
87 |
88 | if g_spectral_norm:
89 | self.conv4 = snconv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
90 | else:
91 | self.conv4 = conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
92 |
93 | self.tanh = nn.Tanh()
94 |
95 | # Weight init
96 | if initialize is not False:
97 | init_weights(self.modules, initialize)
98 |
99 | def forward(self, z, label, evaluation=False):
100 | with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp:
101 | act = self.linear0(z)
102 | act = act.view(-1, self.in_dims[0], 4, 4)
103 | for index, blocklist in enumerate(self.blocks):
104 | for block in blocklist:
105 | if isinstance(block, Self_Attn):
106 | act = block(act)
107 | else:
108 | act = block(act, label)
109 | act = self.conv4(act)
110 | out = self.tanh(act)
111 | return out
112 |
113 |
114 | class DiscBlock(nn.Module):
115 | def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn):
116 | super(DiscBlock, self).__init__()
117 | self.d_spectral_norm = d_spectral_norm
118 |
119 | if d_spectral_norm:
120 | self.conv0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
121 | self.conv1 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)
122 | else:
123 | self.conv0 = conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
124 | self.conv1 = conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)
125 |
126 | self.bn0 = batchnorm_2d(in_features=out_channels)
127 | self.bn1 = batchnorm_2d(in_features=out_channels)
128 |
129 | if activation_fn == "ReLU":
130 | self.activation = nn.ReLU(inplace=True)
131 | elif activation_fn == "Leaky_ReLU":
132 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
133 | elif activation_fn == "ELU":
134 | self.activation = nn.ELU(alpha=1.0, inplace=True)
135 | elif activation_fn == "GELU":
136 | self.activation = nn.GELU()
137 | else:
138 | raise NotImplementedError
139 |
140 | def forward(self, x):
141 | x = self.conv0(x)
142 | if self.d_spectral_norm is False:
143 | x = self.bn0(x)
144 | x = self.activation(x)
145 | x = self.conv1(x)
146 | if self.d_spectral_norm is False:
147 | x = self.bn1(x)
148 | out = self.activation(x)
149 | return out
150 |
151 |
152 | class Discriminator(nn.Module):
153 | """Discriminator."""
154 | def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy,
155 | hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, initialize, D_depth, mixed_precision):
156 | super(Discriminator, self).__init__()
157 | self.in_dims = [3] + [64, 128]
158 | self.out_dims = [64, 128, 256]
159 |
160 | self.d_spectral_norm = d_spectral_norm
161 | self.conditional_strategy = conditional_strategy
162 | self.num_classes = num_classes
163 | self.nonlinear_embed = nonlinear_embed
164 | self.normalize_embed = normalize_embed
165 | self.mixed_precision = mixed_precision
166 |
167 | self.blocks = []
168 | for index in range(len(self.in_dims)):
169 | self.blocks += [[DiscBlock(in_channels=self.in_dims[index],
170 | out_channels=self.out_dims[index],
171 | d_spectral_norm=d_spectral_norm,
172 | activation_fn=activation_fn)]]
173 |
174 | if index+1 == attention_after_nth_dis_block and attention is True:
175 | self.blocks += [[Self_Attn(self.out_dims[index], d_spectral_norm)]]
176 |
177 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
178 |
179 | if self.d_spectral_norm:
180 | self.conv = snconv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
181 | else:
182 | self.conv = conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
183 | self.bn = batchnorm_2d(in_features=512)
184 |
185 | if activation_fn == "ReLU":
186 | self.activation = nn.ReLU(inplace=True)
187 | elif activation_fn == "Leaky_ReLU":
188 | self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
189 | elif activation_fn == "ELU":
190 | self.activation = nn.ELU(alpha=1.0, inplace=True)
191 | elif activation_fn == "GELU":
192 | self.activation = nn.GELU()
193 | else:
194 | raise NotImplementedError
195 |
196 | if d_spectral_norm:
197 | if self.conditional_strategy == 'ECGAN':
198 | self.linear1 = snlinear(in_features=512, out_features=num_classes)
199 | self.linear2 = snlinear(in_features=512, out_features=hypersphere_dim)
200 | if self.nonlinear_embed:
201 | self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim)
202 | self.embedding = sn_embedding(num_classes, hypersphere_dim)
203 | else:
204 | self.linear1 = snlinear(in_features=512, out_features=1)
205 | if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']:
206 | self.linear2 = snlinear(in_features=512, out_features=hypersphere_dim)
207 | if self.nonlinear_embed:
208 | self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim)
209 | self.embedding = sn_embedding(num_classes, hypersphere_dim)
210 | elif self.conditional_strategy == 'ProjGAN':
211 | self.embedding = sn_embedding(num_classes, 512)
212 | elif self.conditional_strategy == 'ACGAN':
213 | self.linear4 = snlinear(in_features=512, out_features=num_classes)
214 | else:
215 | pass
216 | else:
217 | if self.conditional_strategy == 'ECGAN':
218 | self.linear1 = linear(in_features=512, out_features=num_classes)
219 |                 # mirror the spectral-norm branch above: the class-embedding head is always built
220 |                 self.linear2 = linear(in_features=512, out_features=hypersphere_dim)
221 |                 if self.nonlinear_embed:
222 |                     self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim)
223 |                 self.embedding = embedding(num_classes, hypersphere_dim)
224 | else:
225 | self.linear1 = linear(in_features=512, out_features=1)
226 | if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']:
227 | self.linear2 = linear(in_features=512, out_features=hypersphere_dim)
228 | if self.nonlinear_embed:
229 | self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim)
230 | self.embedding = embedding(num_classes, hypersphere_dim)
231 | elif self.conditional_strategy == 'ProjGAN':
232 | self.embedding = embedding(num_classes, 512)
233 | elif self.conditional_strategy == 'ACGAN':
234 | self.linear4 = linear(in_features=512, out_features=num_classes)
235 | else:
236 | pass
237 |
238 | # Weight init
239 | if initialize is not False:
240 | init_weights(self.modules, initialize)
241 |
242 |
243 | def forward(self, x, label, evaluation=False):
244 | with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp:
245 | h = x
246 | for index, blocklist in enumerate(self.blocks):
247 | for block in blocklist:
248 | h = block(h)
249 | h = self.conv(h)
250 | if self.d_spectral_norm is False:
251 | h = self.bn(h)
252 | h = self.activation(h)
253 | h = torch.sum(h, dim=[2, 3])
254 |
255 | if self.conditional_strategy == 'no':
256 | authen_output = torch.squeeze(self.linear1(h))
257 | return authen_output
258 |
259 | elif self.conditional_strategy == 'ECGAN':
260 | cls_output = self.linear1(h)
261 | cond_output = (cls_output * F.one_hot(label.squeeze(), num_classes=self.num_classes)).sum(dim=1)
262 | uncond_output = torch.logsumexp(cls_output, dim=1)
263 | cls_proxy = self.embedding(label)
264 | cls_embed = self.linear2(h)
265 | if self.nonlinear_embed:
266 | cls_embed = self.linear3(self.activation(cls_embed))
267 | if self.normalize_embed:
268 | cls_proxy = F.normalize(cls_proxy, dim=1)
269 | cls_embed = F.normalize(cls_embed, dim=1)
270 | return cls_output, cond_output, uncond_output, cls_proxy, cls_embed
271 |
272 | elif self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']:
273 | authen_output = torch.squeeze(self.linear1(h))
274 | cls_proxy = self.embedding(label)
275 | cls_embed = self.linear2(h)
276 | if self.nonlinear_embed:
277 | cls_embed = self.linear3(self.activation(cls_embed))
278 | if self.normalize_embed:
279 | cls_proxy = F.normalize(cls_proxy, dim=1)
280 | cls_embed = F.normalize(cls_embed, dim=1)
281 | return cls_proxy, cls_embed, authen_output
282 |
283 | elif self.conditional_strategy == 'ProjGAN':
284 | authen_output = torch.squeeze(self.linear1(h))
285 | proj = torch.sum(torch.mul(self.embedding(label), h), 1)
286 | return authen_output + proj
287 |
288 | elif self.conditional_strategy == 'ACGAN':
289 | authen_output = torch.squeeze(self.linear1(h))
290 | cls_output = self.linear4(h)
291 | return cls_output, authen_output
292 |
293 | else:
294 | raise NotImplementedError
295 |
--------------------------------------------------------------------------------
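A minimal sketch (illustrative only) of the ECGAN head in the discriminator above: the class logits are reduced to a conditional score by selecting the logit of the paired label, and to an unconditional score via logsumexp over all classes. The toy logits and labels are assumptions.

    import torch
    import torch.nn.functional as F

    cls_output = torch.tensor([[2.0, -1.0, 0.5],   # discriminator class logits, shape (N, num_classes)
                               [0.1,  1.5, -0.3]])
    label = torch.tensor([0, 1])

    cond_output = (cls_output * F.one_hot(label, num_classes=3)).sum(dim=1)   # tensor([2.0, 1.5])
    uncond_output = torch.logsumexp(cls_output, dim=1)                        # class-agnostic realness score
    print(cond_output, uncond_output)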
/src/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | """
2 | -*- coding: utf-8 -*-
3 | File : batchnorm_reimpl.py
4 | Author : Jiayuan Mao
5 | Email : maojiayuan@gmail.com
6 | Date : 27/01/2018
7 |
8 | This file is part of Synchronized-BatchNorm-PyTorch.
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
10 | Distributed under MIT License.
11 |
12 | MIT License
13 |
14 | Copyright (c) 2018 Jiayuan MAO
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 | """
34 |
35 |
36 | import torch
37 | import torch.nn as nn
38 | import torch.nn.init as init
39 |
40 | __all__ = ['BatchNorm2dReimpl']
41 |
42 |
43 | class BatchNorm2dReimpl(nn.Module):
44 | """
45 | A re-implementation of batch normalization, used for testing the numerical
46 | stability.
47 |
48 | Author: acgtyrant
49 | See also:
50 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
51 | """
52 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
53 | super().__init__()
54 |
55 | self.num_features = num_features
56 | self.eps = eps
57 | self.momentum = momentum
58 | self.weight = nn.Parameter(torch.empty(num_features))
59 | self.bias = nn.Parameter(torch.empty(num_features))
60 | self.register_buffer('running_mean', torch.zeros(num_features))
61 | self.register_buffer('running_var', torch.ones(num_features))
62 | self.reset_parameters()
63 |
64 | def reset_running_stats(self):
65 | self.running_mean.zero_()
66 | self.running_var.fill_(1)
67 |
68 | def reset_parameters(self):
69 | self.reset_running_stats()
70 | init.uniform_(self.weight)
71 | init.zeros_(self.bias)
72 |
73 | def forward(self, input_):
74 | batchsize, channels, height, width = input_.size()
75 | numel = batchsize * height * width
76 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
77 | sum_ = input_.sum(1)
78 | sum_of_square = input_.pow(2).sum(1)
79 | mean = sum_ / numel
80 | sumvar = sum_of_square - sum_ * mean
81 |
82 | self.running_mean = (
83 | (1 - self.momentum) * self.running_mean
84 | + self.momentum * mean.detach()
85 | )
86 | unbias_var = sumvar / (numel - 1)
87 | self.running_var = (
88 | (1 - self.momentum) * self.running_var
89 | + self.momentum * unbias_var.detach()
90 | )
91 |
92 | bias_var = sumvar / numel
93 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
94 | output = (
95 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
96 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
97 |
98 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
99 |
100 |
--------------------------------------------------------------------------------
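Since BatchNorm2dReimpl exists only to sanity-check numerics, a quick way to exercise it is to compare its training-mode output against torch.nn.BatchNorm2d on the same input. A minimal sketch (illustrative, not part of the repository's tests; it assumes src/ is on PYTHONPATH):

import torch
import torch.nn as nn

from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

torch.manual_seed(0)
x = torch.randn(8, 16, 4, 4)

ref = nn.BatchNorm2d(16, eps=1e-5, momentum=0.1)
reimpl = BatchNorm2dReimpl(16, eps=1e-5, momentum=0.1)
reimpl.weight.data.copy_(ref.weight.data)   # start both modules from identical affine parameters
reimpl.bias.data.copy_(ref.bias.data)

out_ref = ref(x)        # both modules are in training mode by default
out_reimpl = reimpl(x)
print(torch.allclose(out_ref, out_reimpl, atol=1e-5))   # expected: True (up to float error)
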
/src/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | """
2 | -*- coding: utf-8 -*-
3 | File : comm.py
4 | Author : Jiayuan Mao
5 | Email : maojiayuan@gmail.com
6 | Date : 27/01/2018
7 |
8 | This file is part of Synchronized-BatchNorm-PyTorch.
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
10 | Distributed under MIT License.
11 |
12 | MIT License
13 |
14 | Copyright (c) 2018 Jiayuan MAO
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 | """
34 |
35 |
36 | import queue
37 | import collections
38 | import threading
39 |
40 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
41 |
42 |
43 | class FutureResult(object):
44 | """A thread-safe future implementation. Used only as one-to-one pipe."""
45 |
46 | def __init__(self):
47 | self._result = None
48 | self._lock = threading.Lock()
49 | self._cond = threading.Condition(self._lock)
50 |
51 | def put(self, result):
52 | with self._lock:
53 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
54 | self._result = result
55 | self._cond.notify()
56 |
57 | def get(self):
58 | with self._lock:
59 | if self._result is None:
60 | self._cond.wait()
61 |
62 | res = self._result
63 | self._result = None
64 | return res
65 |
66 |
67 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
68 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
69 |
70 |
71 | class SlavePipe(_SlavePipeBase):
72 | """Pipe for master-slave communication."""
73 |
74 | def run_slave(self, msg):
75 | self.queue.put((self.identifier, msg))
76 | ret = self.result.get()
77 | self.queue.put(True)
78 | return ret
79 |
80 |
81 | class SyncMaster(object):
82 | """An abstract `SyncMaster` object.
83 |
84 |     - During replication, as data parallel will trigger a callback on each module, all slave devices should
85 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
86 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
87 |     and passed to a registered callback.
88 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
89 |     back to each slave device.
90 | """
91 |
92 | def __init__(self, master_callback):
93 | """
94 |
95 | Args:
96 | master_callback: a callback to be invoked after having collected messages from slave devices.
97 | """
98 | self._master_callback = master_callback
99 | self._queue = queue.Queue()
100 | self._registry = collections.OrderedDict()
101 | self._activated = False
102 |
103 | def __getstate__(self):
104 | return {'master_callback': self._master_callback}
105 |
106 | def __setstate__(self, state):
107 | self.__init__(state['master_callback'])
108 |
109 | def register_slave(self, identifier):
110 | """
111 |         Register a slave device.
112 |
113 | Args:
114 |             identifier: an identifier, usually the device id.
115 |
116 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
117 |
118 | """
119 | if self._activated:
120 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
121 | self._activated = False
122 | self._registry.clear()
123 | future = FutureResult()
124 | self._registry[identifier] = _MasterRegistry(future)
125 | return SlavePipe(identifier, self._queue, future)
126 |
127 | def run_master(self, master_msg):
128 | """
129 | Main entry for the master device in each forward pass.
130 |         The messages are first collected from each device (including the master device), and then
131 |         a callback is invoked to compute the message to be sent back to each device
132 | (including the master device).
133 |
134 | Args:
135 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
136 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
137 |
138 | Returns: the message to be sent back to the master device.
139 |
140 | """
141 | self._activated = True
142 |
143 | intermediates = [(0, master_msg)]
144 | for i in range(self.nr_slaves):
145 | intermediates.append(self._queue.get())
146 |
147 | results = self._master_callback(intermediates)
148 |         assert results[0][0] == 0, 'The first result should belong to the master.'
149 |
150 | for i, res in results:
151 | if i == 0:
152 | continue
153 | self._registry[i].result.put(res)
154 |
155 | for i in range(self.nr_slaves):
156 | assert self._queue.get() is True
157 |
158 | return results[0][1]
159 |
160 | @property
161 | def nr_slaves(self):
162 | return len(self._registry)
163 |
--------------------------------------------------------------------------------
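The master/slave handshake implemented by FutureResult, SlavePipe, and SyncMaster is easiest to see with a toy callback: every device sends one message, the master reduces them, and each device receives the reduced value back. A small CPU-only sketch (the messages here are plain integers; in the real code they are per-device batch statistics, and src/ is assumed to be on PYTHONPATH):

import threading

from sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (identifier, message); the first entry is the master's own message
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(master_callback)
pipes = [master.register_slave(i) for i in (1, 2)]   # slaves register before the "forward pass"

results = {}
def slave(pipe, value):
    results[pipe.identifier] = pipe.run_slave(value)  # blocks until the master replies

threads = [threading.Thread(target=slave, args=(p, v)) for p, v in zip(pipes, (10, 20))]
for t in threads:
    t.start()
print(master.run_master(1))   # the master contributes 1, so every device receives 1 + 10 + 20 = 31
for t in threads:
    t.join()
print(results)                # {1: 31, 2: 31}
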
/src/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | """
2 | -*- coding: utf-8 -*-
3 | File : replicate.py
4 | Author : Jiayuan Mao
5 | Email : maojiayuan@gmail.com
6 | Date : 27/01/2018
7 |
8 | This file is part of Synchronized-BatchNorm-PyTorch.
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
10 | Distributed under MIT License.
11 |
12 | MIT License
13 |
14 | Copyright (c) 2018 Jiayuan MAO
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 | """
34 |
35 |
36 | import functools
37 |
38 | from torch.nn.parallel.data_parallel import DataParallel
39 |
40 | __all__ = [
41 | 'CallbackContext',
42 | 'execute_replication_callbacks',
43 | 'DataParallelWithCallback',
44 | 'patch_replication_callback'
45 | ]
46 |
47 |
48 | class CallbackContext(object):
49 | pass
50 |
51 |
52 | def execute_replication_callbacks(modules):
53 | """
54 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
55 |
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 |     Note that, as all modules are isomorphic, we assign each sub-module a context
59 | (shared among multiple copies of this module on different devices).
60 | Through this context, different copies can share some information.
61 |
62 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
63 | of any slave copies.
64 | """
65 | master_copy = modules[0]
66 | nr_modules = len(list(master_copy.modules()))
67 | ctxs = [CallbackContext() for _ in range(nr_modules)]
68 |
69 | for i, module in enumerate(modules):
70 | for j, m in enumerate(module.modules()):
71 | if hasattr(m, '__data_parallel_replicate__'):
72 | m.__data_parallel_replicate__(ctxs[j], i)
73 |
74 |
75 | class DataParallelWithCallback(DataParallel):
76 | """
77 | Data Parallel with a replication callback.
78 |
79 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by
80 |     the original `replicate` function.
81 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
82 |
83 | Examples:
84 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
85 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
86 | # sync_bn.__data_parallel_replicate__ will be invoked.
87 | """
88 |
89 | def replicate(self, module, device_ids):
90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 |
95 | def patch_replication_callback(data_parallel):
96 | """
97 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
98 |     Useful when you have a customized `DataParallel` implementation.
99 |
100 | Examples:
101 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
102 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
103 | > patch_replication_callback(sync_bn)
104 | # this is equivalent to
105 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
106 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
107 | """
108 |
109 | assert isinstance(data_parallel, DataParallel)
110 |
111 | old_replicate = data_parallel.replicate
112 |
113 | @functools.wraps(old_replicate)
114 | def new_replicate(module, device_ids):
115 | modules = old_replicate(module, device_ids)
116 | execute_replication_callbacks(modules)
117 | return modules
118 |
119 | data_parallel.replicate = new_replicate
120 |
--------------------------------------------------------------------------------
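DataParallelWithCallback is used exactly like torch.nn.DataParallel; the only difference is that each replica receives a `__data_parallel_replicate__(ctx, copy_id)` call after replication, which is what the synchronized BatchNorm layers use to share state across copies. A minimal sketch, assuming at least two visible GPUs and src/ on PYTHONPATH:

import torch
import torch.nn as nn

from sync_batchnorm.replicate import DataParallelWithCallback

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def __data_parallel_replicate__(self, ctx, copy_id):
        # ctx is shared by all replicas of this module; copy_id is the replica index
        ctx.copy_ids = getattr(ctx, 'copy_ids', []) + [copy_id]

    def forward(self, x):
        return self.fc(x)

model = DataParallelWithCallback(Toy().cuda(), device_ids=[0, 1])
out = model(torch.randn(4, 8).cuda())   # replication (and the callback) happens on this call
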
/src/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | """
2 | -*- coding: utf-8 -*-
3 | File : unittest.py
4 | Author : Jiayuan Mao
5 | Email : maojiayuan@gmail.com
6 | Date : 27/01/2018
7 |
8 | This file is part of Synchronized-BatchNorm-PyTorch.
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
10 | Distributed under MIT License.
11 |
12 | MIT License
13 |
14 | Copyright (c) 2018 Jiayuan MAO
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 | """
34 |
35 |
36 | import unittest
37 | import torch
38 |
39 |
40 | class TorchTestCase(unittest.TestCase):
41 | def assertTensorClose(self, x, y):
42 | adiff = float((x - y).abs().max())
43 | if (y == 0).all():
44 | rdiff = 'NaN'
45 | else:
46 | rdiff = float((adiff / y).abs().max())
47 |
48 | message = (
49 | 'Tensor close check failed\n'
50 | 'adiff={}\n'
51 | 'rdiff={}\n'
52 | ).format(adiff, rdiff)
53 | self.assertTrue(torch.allclose(x, y), message)
54 |
55 |
--------------------------------------------------------------------------------
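TorchTestCase is used like any unittest.TestCase; assertTensorClose reports the absolute and relative differences when the check fails. A tiny illustrative example (assumes src/ is on PYTHONPATH):

import torch
import unittest

from sync_batchnorm.unittest import TorchTestCase

class ExampleTest(TorchTestCase):
    def test_close(self):
        a = torch.ones(3)
        self.assertTensorClose(a, a + 1e-9)   # passes: within torch.allclose tolerance

if __name__ == '__main__':
    unittest.main()
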
/src/utils/ada.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Kim Seonghyeon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 |
26 | import math
27 |
28 | from utils.ada_op import upfirdn2d
29 |
30 | import torch
31 | from torch.nn import functional as F
32 |
33 |
34 |
35 | SYM6 = (
36 | 0.015404109327027373,
37 | 0.0034907120842174702,
38 | -0.11799011114819057,
39 | -0.048311742585633,
40 | 0.4910559419267466,
41 | 0.787641141030194,
42 | 0.3379294217276218,
43 | -0.07263752278646252,
44 | -0.021060292512300564,
45 | 0.04472490177066578,
46 | 0.0017677118642428036,
47 | -0.007800708325034148,
48 | )
49 |
50 |
51 | def translate_mat(t_x, t_y):
52 | batch = t_x.shape[0]
53 |
54 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
55 | translate = torch.stack((t_x, t_y), 1)
56 | mat[:, :2, 2] = translate
57 |
58 | return mat
59 |
60 |
61 | def rotate_mat(theta):
62 | batch = theta.shape[0]
63 |
64 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
65 | sin_t = torch.sin(theta)
66 | cos_t = torch.cos(theta)
67 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
68 | mat[:, :2, :2] = rot
69 |
70 | return mat
71 |
72 |
73 | def scale_mat(s_x, s_y):
74 | batch = s_x.shape[0]
75 |
76 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
77 | mat[:, 0, 0] = s_x
78 | mat[:, 1, 1] = s_y
79 |
80 | return mat
81 |
82 |
83 | def translate3d_mat(t_x, t_y, t_z):
84 | batch = t_x.shape[0]
85 |
86 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
87 | translate = torch.stack((t_x, t_y, t_z), 1)
88 | mat[:, :3, 3] = translate
89 |
90 | return mat
91 |
92 |
93 | def rotate3d_mat(axis, theta):
94 | batch = theta.shape[0]
95 |
96 | u_x, u_y, u_z = axis
97 |
98 | eye = torch.eye(3).unsqueeze(0)
99 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
100 | outer = torch.tensor(axis)
101 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
102 |
103 | sin_t = torch.sin(theta).view(-1, 1, 1)
104 | cos_t = torch.cos(theta).view(-1, 1, 1)
105 |
106 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
107 |
108 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
109 | eye_4[:, :3, :3] = rot
110 |
111 | return eye_4
112 |
113 |
114 | def scale3d_mat(s_x, s_y, s_z):
115 | batch = s_x.shape[0]
116 |
117 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
118 | mat[:, 0, 0] = s_x
119 | mat[:, 1, 1] = s_y
120 | mat[:, 2, 2] = s_z
121 |
122 | return mat
123 |
124 |
125 | def luma_flip_mat(axis, i):
126 | batch = i.shape[0]
127 |
128 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
129 | axis = torch.tensor(axis + (0,))
130 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
131 |
132 | return eye - flip
133 |
134 |
135 | def saturation_mat(axis, i):
136 | batch = i.shape[0]
137 |
138 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
139 | axis = torch.tensor(axis + (0,))
140 | axis = torch.ger(axis, axis)
141 | saturate = axis + (eye - axis) * i.view(-1, 1, 1)
142 |
143 | return saturate
144 |
145 |
146 | def lognormal_sample(size, mean=0, std=1):
147 | return torch.empty(size).log_normal_(mean=mean, std=std)
148 |
149 |
150 | def category_sample(size, categories):
151 | category = torch.tensor(categories)
152 | sample = torch.randint(high=len(categories), size=(size,))
153 |
154 | return category[sample]
155 |
156 |
157 | def uniform_sample(size, low, high):
158 | return torch.empty(size).uniform_(low, high)
159 |
160 |
161 | def normal_sample(size, mean=0, std=1):
162 | return torch.empty(size).normal_(mean, std)
163 |
164 |
165 | def bernoulli_sample(size, p):
166 | return torch.empty(size).bernoulli_(p)
167 |
168 |
169 | def random_mat_apply(p, transform, prev, eye):
170 | size = transform.shape[0]
171 | select = bernoulli_sample(size, p).view(size, 1, 1)
172 | select_transform = select * transform + (1 - select) * eye
173 |
174 | return select_transform @ prev
175 |
176 |
177 | def sample_affine(p, size, height, width):
178 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
179 | eye = G
180 |
181 | # flip
182 | param = category_sample(size, (0, 1))
183 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size))
184 | G = random_mat_apply(p, Gc, G, eye)
185 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
186 |
187 | # 90 rotate
188 | param = category_sample(size, (0, 3))
189 | Gc = rotate_mat(-math.pi / 2 * param)
190 | G = random_mat_apply(p, Gc, G, eye)
191 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
192 |
193 | # integer translate
194 | param = uniform_sample(size, -0.125, 0.125)
195 | param_height = torch.round(param * height) / height
196 | param_width = torch.round(param * width) / width
197 | Gc = translate_mat(param_width, param_height)
198 | G = random_mat_apply(p, Gc, G, eye)
199 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
200 |
201 | # isotropic scale
202 | param = lognormal_sample(size, std=0.2 * math.log(2))
203 | Gc = scale_mat(param, param)
204 | G = random_mat_apply(p, Gc, G, eye)
205 | # print('isotropic scale', G, scale_mat(param, param), sep='\n')
206 |
207 | p_rot = 1 - math.sqrt(1 - p)
208 |
209 | # pre-rotate
210 | param = uniform_sample(size, -math.pi, math.pi)
211 | Gc = rotate_mat(-param)
212 | G = random_mat_apply(p_rot, Gc, G, eye)
213 | # print('pre-rotate', G, rotate_mat(-param), sep='\n')
214 |
215 | # anisotropic scale
216 | param = lognormal_sample(size, std=0.2 * math.log(2))
217 | Gc = scale_mat(param, 1 / param)
218 | G = random_mat_apply(p, Gc, G, eye)
219 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
220 |
221 | # post-rotate
222 | param = uniform_sample(size, -math.pi, math.pi)
223 | Gc = rotate_mat(-param)
224 | G = random_mat_apply(p_rot, Gc, G, eye)
225 | # print('post-rotate', G, rotate_mat(-param), sep='\n')
226 |
227 | # fractional translate
228 | param = normal_sample(size, std=0.125)
229 | Gc = translate_mat(param, param)
230 | G = random_mat_apply(p, Gc, G, eye)
231 | # print('fractional translate', G, translate_mat(param, param), sep='\n')
232 |
233 | return G
234 |
235 |
236 | def sample_color(p, size):
237 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
238 | eye = C
239 | axis_val = 1 / math.sqrt(3)
240 | axis = (axis_val, axis_val, axis_val)
241 |
242 | # brightness
243 | param = normal_sample(size, std=0.2)
244 | Cc = translate3d_mat(param, param, param)
245 | C = random_mat_apply(p, Cc, C, eye)
246 |
247 | # contrast
248 | param = lognormal_sample(size, std=0.5 * math.log(2))
249 | Cc = scale3d_mat(param, param, param)
250 | C = random_mat_apply(p, Cc, C, eye)
251 |
252 | # luma flip
253 | param = category_sample(size, (0, 1))
254 | Cc = luma_flip_mat(axis, param)
255 | C = random_mat_apply(p, Cc, C, eye)
256 |
257 | # hue rotation
258 | param = uniform_sample(size, -math.pi, math.pi)
259 | Cc = rotate3d_mat(axis, param)
260 | C = random_mat_apply(p, Cc, C, eye)
261 |
262 | # saturation
263 | param = lognormal_sample(size, std=1 * math.log(2))
264 | Cc = saturation_mat(axis, param)
265 | C = random_mat_apply(p, Cc, C, eye)
266 |
267 | return C
268 |
269 |
270 | def make_grid(shape, x0, x1, y0, y1, device):
271 | n, c, h, w = shape
272 | grid = torch.empty(n, h, w, 3, device=device)
273 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
274 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
275 | grid[:, :, :, 2] = 1
276 |
277 | return grid
278 |
279 |
280 | def affine_grid(grid, mat):
281 | n, h, w, _ = grid.shape
282 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
283 |
284 |
285 | def get_padding(G, height, width):
286 | extreme = (
287 | G[:, :2, :]
288 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t()
289 | )
290 |
291 | size = torch.tensor((width, height))
292 |
293 | pad_low = (
294 | ((extreme.min(-1).values + 1) * size)
295 | .clamp(max=0)
296 | .abs()
297 | .ceil()
298 | .max(0)
299 | .values.to(torch.int64)
300 | .tolist()
301 | )
302 | pad_high = (
303 | (extreme.max(-1).values * size - size)
304 | .clamp(min=0)
305 | .ceil()
306 | .max(0)
307 | .values.to(torch.int64)
308 | .tolist()
309 | )
310 |
311 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1]
312 |
313 |
314 | def try_sample_affine_and_pad(img, p, pad_k, G=None):
315 | batch, _, height, width = img.shape
316 |
317 | G_try = G
318 |
319 | while True:
320 | if G is None:
321 | G_try = sample_affine(p, batch, height, width)
322 |
323 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(
324 | torch.inverse(G_try), height, width
325 | )
326 |
327 | try:
328 | img_pad = F.pad(
329 | img,
330 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k),
331 | mode="reflect",
332 | )
333 |
334 | except RuntimeError:
335 | continue
336 |
337 | break
338 |
339 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
340 |
341 |
342 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
343 | kernel = antialiasing_kernel
344 | len_k = len(kernel)
345 | pad_k = (len_k + 1) // 2
346 |
347 | kernel = torch.as_tensor(kernel)
348 | kernel = torch.ger(kernel, kernel).to(img)
349 | kernel_flip = torch.flip(kernel, (0, 1))
350 |
351 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
352 | img, p, pad_k, G
353 | )
354 |
355 | p_ux1 = pad_x1
356 | p_ux2 = pad_x2 + 1
357 | p_uy1 = pad_y1
358 | p_uy2 = pad_y2 + 1
359 | w_p = img_pad.shape[3] - len_k + 1
360 | h_p = img_pad.shape[2] - len_k + 1
361 | h_o = img.shape[2]
362 | w_o = img.shape[3]
363 |
364 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2)
365 |
366 | grid = make_grid(
367 | img_2x.shape,
368 | -2 * p_ux1 / w_o - 1,
369 | 2 * (w_p - p_ux1) / w_o - 1,
370 | -2 * p_uy1 / h_o - 1,
371 | 2 * (h_p - p_uy1) / h_o - 1,
372 | device=img_2x.device,
373 | ).to(img_2x)
374 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x))
375 | grid = grid * torch.tensor(
376 | [w_o / w_p, h_o / h_p], device=grid.device
377 | ) + torch.tensor(
378 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device
379 | )
380 |
381 | img_affine = F.grid_sample(
382 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros"
383 | )
384 |
385 | img_down = upfirdn2d(img_affine, kernel, down=2)
386 |
387 | end_y = -pad_y2 - 1
388 | if end_y == 0:
389 | end_y = img_down.shape[2]
390 |
391 | end_x = -pad_x2 - 1
392 | if end_x == 0:
393 | end_x = img_down.shape[3]
394 |
395 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x]
396 |
397 | return img, G
398 |
399 |
400 | def apply_color(img, mat):
401 | batch = img.shape[0]
402 | img = img.permute(0, 2, 3, 1)
403 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
404 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
405 | img = img @ mat_mul + mat_add
406 | img = img.permute(0, 3, 1, 2)
407 |
408 | return img
409 |
410 |
411 | def random_apply_color(img, p, C=None):
412 | if C is None:
413 | C = sample_color(p, img.shape[0])
414 |
415 | img = apply_color(img, C.to(img))
416 |
417 | return img, C
418 |
419 |
420 | def augment(img, p, transform_matrix=(None, None)):
421 | img, G = random_apply_affine(img, p, transform_matrix[0])
422 | img, C = random_apply_color(img, p, transform_matrix[1])
423 |
424 | return img, (G, C)
425 |
--------------------------------------------------------------------------------
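augment() is the entry point of the ADA-style pipeline above: with probability p it samples a per-image affine matrix G and color matrix C, applies them with antialiased resampling, and returns both matrices so the identical transform can be re-applied to another batch. Because utils.ada_op builds CUDA extensions on import, the sketch below assumes a working CUDA toolchain; the shapes are illustrative:

import torch

from utils.ada import augment   # assumes src/ is on PYTHONPATH and the CUDA ops compile

imgs = torch.randn(8, 3, 64, 64, device='cuda')    # dummy image batch
aug_imgs, (G, C) = augment(imgs, p=0.5)            # sample and apply affine + color transforms
# aug_imgs keeps the input shape; G is (8, 3, 3) and C is (8, 4, 4)

# Re-apply exactly the same transforms to another batch (e.g., for consistency regularization).
other = torch.randn(8, 3, 64, 64, device='cuda')
other_aug, _ = augment(other, p=0.5, transform_matrix=(G, C))
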
/src/utils/ada_op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/src/utils/ada_op/fused_act.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Kim Seonghyeon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 |
26 | import os
27 |
28 | import torch
29 | from torch import nn
30 | from torch.nn import functional as F
31 | from torch.autograd import Function
32 | from torch.utils.cpp_extension import load
33 |
34 |
35 | module_path = os.path.dirname(__file__)
36 | fused = load(
37 | "fused",
38 | sources=[
39 | os.path.join(module_path, "fused_bias_act.cpp"),
40 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
41 | ],
42 | )
43 |
44 |
45 | class FusedLeakyReLUFunctionBackward(Function):
46 | @staticmethod
47 | def forward(ctx, grad_output, out, negative_slope, scale):
48 | ctx.save_for_backward(out)
49 | ctx.negative_slope = negative_slope
50 | ctx.scale = scale
51 |
52 | empty = grad_output.new_empty(0)
53 |
54 | grad_input = fused.fused_bias_act(
55 | grad_output, empty, out, 3, 1, negative_slope, scale
56 | )
57 |
58 | dim = [0]
59 |
60 | if grad_input.ndim > 2:
61 | dim += list(range(2, grad_input.ndim))
62 |
63 | grad_bias = grad_input.sum(dim).detach()
64 |
65 | return grad_input, grad_bias
66 |
67 | @staticmethod
68 | def backward(ctx, gradgrad_input, gradgrad_bias):
69 | out, = ctx.saved_tensors
70 | gradgrad_out = fused.fused_bias_act(
71 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
72 | )
73 |
74 | return gradgrad_out, None, None, None
75 |
76 |
77 | class FusedLeakyReLUFunction(Function):
78 | @staticmethod
79 | def forward(ctx, input, bias, negative_slope, scale):
80 | empty = input.new_empty(0)
81 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
82 | ctx.save_for_backward(out)
83 | ctx.negative_slope = negative_slope
84 | ctx.scale = scale
85 |
86 | return out
87 |
88 | @staticmethod
89 | def backward(ctx, grad_output):
90 | out, = ctx.saved_tensors
91 |
92 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
93 | grad_output, out, ctx.negative_slope, ctx.scale
94 | )
95 |
96 | return grad_input, grad_bias, None, None
97 |
98 |
99 | class FusedLeakyReLU(nn.Module):
100 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
101 | super().__init__()
102 |
103 | self.bias = nn.Parameter(torch.zeros(channel))
104 | self.negative_slope = negative_slope
105 | self.scale = scale
106 |
107 | def forward(self, input):
108 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
109 |
110 |
111 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
112 | if input.device.type == "cpu":
113 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
114 | return (
115 | F.leaky_relu(
116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117 | )
118 | * scale
119 | )
120 |
121 | else:
122 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
123 |
--------------------------------------------------------------------------------
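fused_leaky_relu amounts to "add a per-channel bias, apply leaky ReLU, rescale", fused into one CUDA kernel; the CPU branch above spells the same math out. A plain-PyTorch reference sketch of that computation (illustrative only):

import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # broadcast the per-channel bias over the trailing dimensions, then activate and rescale
    rest_dim = [1] * (x.ndim - bias.ndim - 1)
    return F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim),
                        negative_slope=negative_slope) * scale

x = torch.randn(2, 4, 8, 8)
bias = torch.zeros(4)
print(fused_leaky_relu_reference(x, bias).shape)   # torch.Size([2, 4, 8, 8])
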
/src/utils/ada_op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | MIT License
3 |
4 | Copyright (c) 2019 Kim Seonghyeon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | */
24 |
25 |
26 | #include <torch/extension.h>
27 |
28 |
29 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
30 | int act, int grad, float alpha, float scale);
31 |
32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
35 |
36 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
37 | int act, int grad, float alpha, float scale) {
38 | CHECK_CUDA(input);
39 | CHECK_CUDA(bias);
40 |
41 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
42 | }
43 |
44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
45 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
46 | }
47 |
--------------------------------------------------------------------------------
/src/utils/ada_op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
100 |
--------------------------------------------------------------------------------
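The `switch (act * 10 + grad)` in the kernel above encodes which function is computed: act selects the activation (1 = linear, 3 = leaky ReLU) and grad selects the derivative order (0 = forward value, 1 = first derivative gated by the saved reference tensor, 2 = second derivative, which is zero for both activations). A scalar Python sketch of the same lookup, for reference only:

def fused_bias_act_reference(x, ref, act, grad, alpha, scale):
    # mirrors switch (act * 10 + grad) in fused_bias_act_kernel for a single element
    key = act * 10 + grad
    if key in (10, 11):            # linear: value and first derivative are both x
        y = x
    elif key == 12:                # linear: second derivative is zero
        y = 0.0
    elif key == 30:                # leaky ReLU forward
        y = x if x > 0.0 else x * alpha
    elif key == 31:                # leaky ReLU backward: gate the incoming gradient by the saved output
        y = x if ref > 0.0 else x * alpha
    elif key == 32:                # leaky ReLU second derivative is zero
        y = 0.0
    else:                          # the kernel's default case
        y = x
    return y * scale

print(fused_bias_act_reference(-1.0, 0.0, act=3, grad=0, alpha=0.2, scale=2 ** 0.5))   # -0.2 * sqrt(2)
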
/src/utils/ada_op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | MIT License
3 |
4 | Copyright (c) 2019 Kim Seonghyeon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | */
24 |
25 |
26 | #include <torch/extension.h>
27 |
28 |
29 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
30 | int up_x, int up_y, int down_x, int down_y,
31 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
32 |
33 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
34 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
35 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
36 |
37 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
38 | int up_x, int up_y, int down_x, int down_y,
39 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
40 | CHECK_CUDA(input);
41 | CHECK_CUDA(kernel);
42 |
43 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
44 | }
45 |
46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
47 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
48 | }
49 |
--------------------------------------------------------------------------------
/src/utils/ada_op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Kim Seonghyeon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 |
26 | import os
27 |
28 | import torch
29 | from torch.nn import functional as F
30 | from torch.autograd import Function
31 | from torch.utils.cpp_extension import load
32 |
33 |
34 | module_path = os.path.dirname(__file__)
35 | upfirdn2d_op = load(
36 | "upfirdn2d",
37 | sources=[
38 | os.path.join(module_path, "upfirdn2d.cpp"),
39 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
40 | ],
41 | )
42 |
43 |
44 | class UpFirDn2dBackward(Function):
45 | @staticmethod
46 | def forward(
47 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
48 | ):
49 |
50 | up_x, up_y = up
51 | down_x, down_y = down
52 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
53 |
54 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
55 |
56 | grad_input = upfirdn2d_op.upfirdn2d(
57 | grad_output,
58 | grad_kernel,
59 | down_x,
60 | down_y,
61 | up_x,
62 | up_y,
63 | g_pad_x0,
64 | g_pad_x1,
65 | g_pad_y0,
66 | g_pad_y1,
67 | )
68 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
69 |
70 | ctx.save_for_backward(kernel)
71 |
72 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
73 |
74 | ctx.up_x = up_x
75 | ctx.up_y = up_y
76 | ctx.down_x = down_x
77 | ctx.down_y = down_y
78 | ctx.pad_x0 = pad_x0
79 | ctx.pad_x1 = pad_x1
80 | ctx.pad_y0 = pad_y0
81 | ctx.pad_y1 = pad_y1
82 | ctx.in_size = in_size
83 | ctx.out_size = out_size
84 |
85 | return grad_input
86 |
87 | @staticmethod
88 | def backward(ctx, gradgrad_input):
89 | kernel, = ctx.saved_tensors
90 |
91 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
92 |
93 | gradgrad_out = upfirdn2d_op.upfirdn2d(
94 | gradgrad_input,
95 | kernel,
96 | ctx.up_x,
97 | ctx.up_y,
98 | ctx.down_x,
99 | ctx.down_y,
100 | ctx.pad_x0,
101 | ctx.pad_x1,
102 | ctx.pad_y0,
103 | ctx.pad_y1,
104 | )
105 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
106 | gradgrad_out = gradgrad_out.view(
107 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
108 | )
109 |
110 | return gradgrad_out, None, None, None, None, None, None, None, None
111 |
112 |
113 | class UpFirDn2d(Function):
114 | @staticmethod
115 | def forward(ctx, input, kernel, up, down, pad):
116 | up_x, up_y = up
117 | down_x, down_y = down
118 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
119 |
120 | kernel_h, kernel_w = kernel.shape
121 | batch, channel, in_h, in_w = input.shape
122 | ctx.in_size = input.shape
123 |
124 | input = input.reshape(-1, in_h, in_w, 1)
125 |
126 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
127 |
128 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
129 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
130 | ctx.out_size = (out_h, out_w)
131 |
132 | ctx.up = (up_x, up_y)
133 | ctx.down = (down_x, down_y)
134 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
135 |
136 | g_pad_x0 = kernel_w - pad_x0 - 1
137 | g_pad_y0 = kernel_h - pad_y0 - 1
138 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
139 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
140 |
141 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
142 |
143 | out = upfirdn2d_op.upfirdn2d(
144 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
145 | )
146 | # out = out.view(major, out_h, out_w, minor)
147 | out = out.view(-1, channel, out_h, out_w)
148 |
149 | return out
150 |
151 | @staticmethod
152 | def backward(ctx, grad_output):
153 | kernel, grad_kernel = ctx.saved_tensors
154 |
155 | grad_input = UpFirDn2dBackward.apply(
156 | grad_output,
157 | kernel,
158 | grad_kernel,
159 | ctx.up,
160 | ctx.down,
161 | ctx.pad,
162 | ctx.g_pad,
163 | ctx.in_size,
164 | ctx.out_size,
165 | )
166 |
167 | return grad_input, None, None, None, None
168 |
169 |
170 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
171 | if input.device.type == "cpu":
172 | out = upfirdn2d_native(
173 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
174 | )
175 |
176 | else:
177 | out = UpFirDn2d.apply(
178 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
179 | )
180 |
181 | return out
182 |
183 |
184 | def upfirdn2d_native(
185 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
186 | ):
187 | _, channel, in_h, in_w = input.shape
188 | input = input.reshape(-1, in_h, in_w, 1)
189 |
190 | _, in_h, in_w, minor = input.shape
191 | kernel_h, kernel_w = kernel.shape
192 |
193 | out = input.view(-1, in_h, 1, in_w, 1, minor)
194 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
195 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
196 |
197 | out = F.pad(
198 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
199 | )
200 | out = out[
201 | :,
202 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
203 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
204 | :,
205 | ]
206 |
207 | out = out.permute(0, 3, 1, 2)
208 | out = out.reshape(
209 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
210 | )
211 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
212 | out = F.conv2d(out, w)
213 | out = out.reshape(
214 | -1,
215 | minor,
216 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
217 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
218 | )
219 | out = out.permute(0, 2, 3, 1)
220 | out = out[:, ::down_y, ::down_x, :]
221 |
222 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
223 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
224 |
225 | return out.view(-1, channel, out_h, out_w)
226 |
--------------------------------------------------------------------------------
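upfirdn2d performs upsample-by-zero-insertion, padding, FIR filtering, and downsampling in one pass; the output size follows the formula computed in UpFirDn2d.forward. A short usage sketch (illustrative; importing utils.ada_op builds the CUDA extension, so a working CUDA toolchain is assumed):

import torch

from utils.ada_op import upfirdn2d   # assumes src/ is on PYTHONPATH

x = torch.randn(1, 3, 16, 16, device='cuda')
k = torch.ones(4, 4) / 16.0          # illustrative 4x4 averaging kernel
y = upfirdn2d(x, k.to(x), up=2, down=1, pad=(1, 2))   # 2x upsample, filter, no downsample

# out = (in * up + pad0 + pad1 - kernel) // down + 1  ->  (16 * 2 + 1 + 2 - 4) // 1 + 1 = 32
print(y.shape)                       # torch.Size([1, 3, 32, 32])
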
/src/utils/ada_op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                  x.data_ptr<scalar_t>(),
317 |                                                  k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                  x.data_ptr<scalar_t>(),
325 |                                                  k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                  x.data_ptr<scalar_t>(),
333 |                                                  k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                  x.data_ptr<scalar_t>(),
341 |                                                  k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel
347 | <<>>(out.data_ptr(),
348 | x.data_ptr(),
349 | k.data_ptr(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel
355 | <<>>(out.data_ptr(),
356 | x.data_ptr(),
357 | k.data_ptr(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<<>>(
363 | out.data_ptr(), x.data_ptr(),
364 | k.data_ptr(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
370 |
--------------------------------------------------------------------------------
/src/utils/biggan_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
3 |
4 | MIT License
5 |
6 | Copyright (c) 2019 Andy Brock
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 | """
26 |
27 |
28 | import random
29 |
30 | from utils.sample import sample_latents
31 |
32 | import torch
33 |
34 |
35 |
36 | class ema(object):
37 | def __init__(self, source, target, decay=0.9999, start_itr=0):
38 | self.source = source
39 | self.target = target
40 | self.decay = decay
41 | # Optional parameter indicating what iteration to start the decay at
42 | self.start_itr = start_itr
43 | # Initialize target's params to be source's
44 | self.source_dict = self.source.state_dict()
45 | self.target_dict = self.target.state_dict()
46 | print('Initializing EMA parameters to be source parameters...')
47 | with torch.no_grad():
48 | for key in self.source_dict:
49 | self.target_dict[key].data.copy_(self.source_dict[key].data)
50 |
51 | def update(self, itr=None):
52 | # If an iteration counter is provided and itr is less than the start itr,
53 | # peg the ema weights to the underlying weights.
54 |         if itr is not None and itr < self.start_itr:
55 | decay = 0.0
56 | else:
57 | decay = self.decay
58 | with torch.no_grad():
59 | for key in self.source_dict:
60 | self.target_dict[key].data.copy_(self.target_dict[key].data * decay + self.source_dict[key].data * (1 - decay))
61 |
62 |
63 | class ema_DP_SyncBN(object):
64 | def __init__(self, source, target, decay=0.9999, start_itr=0):
65 | self.source = source
66 | self.target = target
67 | self.decay = decay
68 | self.start_itr = start_itr
69 | # Initialize target's params to be source's
70 | print('Initializing EMA parameters to be source parameters...')
71 | with torch.no_grad():
72 | for key in self.source.state_dict():
73 | self.target.state_dict()[key].data.copy_(self.source.state_dict()[key].data)
74 |
75 |
76 | def update(self, itr=None):
77 | # If an iteration counter is provided and itr is less than the start itr,
78 | # peg the ema weights to the underlying weights.
79 |         if itr is not None and itr < self.start_itr:
80 | decay = 0.0
81 | else:
82 | decay = self.decay
83 | with torch.no_grad():
84 | for key in self.source.state_dict():
85 | data = self.target.state_dict()[key].data*decay + self.source.state_dict()[key].data*(1.-decay)
86 | self.target.state_dict()[key].data.copy_(data)
87 |
88 |
89 | def ortho(model, strength=1e-4, blacklist=[]):
90 | with torch.no_grad():
91 | for param in model.parameters():
92 | # Only apply this to parameters with at least 2 axes, and not in the blacklist
93 | if len(param.shape) < 2 or any([param is item for item in blacklist]):
94 | continue
95 | w = param.view(param.shape[0], -1)
96 | grad = (2 * torch.mm(torch.mm(w, w.t())
97 | * (1. - torch.eye(w.shape[0], device=w.device)), w))
98 | param.grad.data += strength * grad.view(param.shape)
99 |
100 |
101 | # Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
102 | def interp(x0, x1, num_midpoints):
103 | lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype)
104 | return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))
105 |
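106 | # Illustrative usage sketch (shapes, decay, and step values below are examples,
107 | # not taken from this project's configs). `interp` broadcasts over the middle
108 | # axis as described in the comment above, and `ema` keeps a moving-average copy
109 | # of a source model:
110 | #   z0 = torch.randn(4, 1, 128, device='cuda')
111 | #   z1 = torch.randn(4, 1, 128, device='cuda')
112 | #   zs = interp(z0, z1, num_midpoints=8)      # shape: (4, 10, 128)
113 | #
114 | #   ema_updater = ema(gen, gen_copy, decay=0.9999, start_itr=1000)
115 | #   ema_updater.update(itr=current_step)      # call once per training step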
--------------------------------------------------------------------------------
/src/utils/cr_diff_aug.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/utils/cr_diff_aug.py
6 |
7 |
8 | import random
9 |
10 | import torch
11 | import torch.nn.functional as F
12 |
13 |
14 |
15 | def CR_DiffAug(x, flip=True, translation=True):
16 | if flip:
17 | x = random_flip(x, 0.5)
18 | if translation:
19 | x = random_translation(x, 1/8)
20 | if flip or translation:
21 | x = x.contiguous()
22 | return x
23 |
24 |
25 | def random_flip(x, p):
26 | x_out = x.clone()
27 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
28 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0)
29 | flip_mask = flip_prob < p
30 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device)
31 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1)
32 | return x_out
33 |
34 |
35 | def random_translation(x, ratio):
36 | max_t_x, max_t_y = int(x.shape[2]*ratio), int(x.shape[3]*ratio)
37 | t_x = torch.randint(-max_t_x, max_t_x + 1, size = [x.shape[0], 1, 1], device=x.device)
38 | t_y = torch.randint(-max_t_y, max_t_y + 1, size = [x.shape[0], 1, 1], device=x.device)
39 |
40 | grid_batch, grid_x, grid_y = torch.meshgrid(
41 | torch.arange(x.shape[0], dtype=torch.long, device=x.device),
42 | torch.arange(x.shape[2], dtype=torch.long, device=x.device),
43 | torch.arange(x.shape[3], dtype=torch.long, device=x.device),
44 | )
45 |
46 | grid_x = (grid_x + t_x) + max_t_x
47 | grid_y = (grid_y + t_y) + max_t_y
48 | x_pad = F.pad(input=x, pad=[max_t_x, max_t_x, max_t_y, max_t_y], mode='reflect')
49 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
50 | return x
51 |
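52 | # Illustrative usage sketch (the tensor below is a placeholder): CR_DiffAug is
53 | # the flip/translation augmentation typically applied to discriminator inputs
54 | # for consistency regularization.
55 | #   imgs = torch.randn(8, 3, 32, 32)
56 | #   imgs_aug = CR_DiffAug(imgs, flip=True, translation=True)   # same shape as imgs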
--------------------------------------------------------------------------------
/src/utils/diff_aug.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | """
26 |
27 |
28 | import torch
29 | import torch.nn.functional as F
30 |
31 |
32 |
33 | ### Differentiable Augmentation for Data-Efficient GAN Training (https://arxiv.org/abs/2006.10738)
34 | ### Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
35 | ### https://github.com/mit-han-lab/data-efficient-gans
36 |
37 |
38 | def DiffAugment(x, policy='', channels_first=True):
39 | if policy:
40 | if not channels_first:
41 | x = x.permute(0, 3, 1, 2)
42 | for p in policy.split(','):
43 | for f in AUGMENT_FNS[p]:
44 | x = f(x)
45 | if not channels_first:
46 | x = x.permute(0, 2, 3, 1)
47 | x = x.contiguous()
48 | return x
49 |
50 |
51 | def rand_brightness(x):
52 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
53 | return x
54 |
55 |
56 | def rand_saturation(x):
57 | x_mean = x.mean(dim=1, keepdim=True)
58 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
59 | return x
60 |
61 |
62 | def rand_contrast(x):
63 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
64 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
65 | return x
66 |
67 |
68 | def rand_translation(x, ratio=0.125):
69 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
70 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
71 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
72 | grid_batch, grid_x, grid_y = torch.meshgrid(
73 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
74 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
75 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
76 | )
77 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
78 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
79 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
80 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
81 | return x
82 |
83 |
84 | def rand_cutout(x, ratio=0.5):
85 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
86 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
87 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
88 | grid_batch, grid_x, grid_y = torch.meshgrid(
89 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
90 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
91 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
92 | )
93 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
94 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
95 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
96 | mask[grid_batch, grid_x, grid_y] = 0
97 | x = x * mask.unsqueeze(1)
98 | return x
99 |
100 |
101 | AUGMENT_FNS = {
102 | 'color': [rand_brightness, rand_saturation, rand_contrast],
103 | 'translation': [rand_translation],
104 | 'cutout': [rand_cutout],
105 | }
106 |
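107 | # Illustrative usage sketch: `policy` is a comma-separated list of keys from
108 | # AUGMENT_FNS above; the tensor and the policy string below are examples.
109 | #   x = torch.randn(8, 3, 32, 32)                        # NCHW images
110 | #   x_aug = DiffAugment(x, policy='color,translation,cutout')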
--------------------------------------------------------------------------------
/src/utils/load_checkpoint.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/utils/load_checkpoint.py
6 |
7 |
8 | import os
9 |
10 | import torch
11 |
12 |
13 |
14 | def load_checkpoint(model, optimizer, filename, metric=False, ema=False):
15 | start_step = 0
16 | if ema:
17 | checkpoint = torch.load(filename)
18 | model.load_state_dict(checkpoint['state_dict'])
19 | return model
20 | else:
21 | checkpoint = torch.load(filename)
22 | seed = checkpoint['seed']
23 | run_name = checkpoint['run_name']
24 | start_step = checkpoint['step']
25 | model.load_state_dict(checkpoint['state_dict'])
26 | optimizer.load_state_dict(checkpoint['optimizer'])
27 | ada_p = checkpoint['ada_p']
28 | for state in optimizer.state.values():
29 | for k, v in state.items():
30 | if isinstance(v, torch.Tensor):
31 | state[k] = v.cuda()
32 |
33 | if metric:
34 | best_step = checkpoint['best_step']
35 | best_fid = checkpoint['best_fid']
36 | best_fid_checkpoint_path = checkpoint['best_fid_checkpoint_path']
37 | return model, optimizer, seed, run_name, start_step, ada_p, best_step, best_fid, best_fid_checkpoint_path
38 | return model, optimizer, seed, run_name, start_step, ada_p
39 |
--------------------------------------------------------------------------------
/src/utils/log.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/utils/log.py
6 |
7 |
8 | import json
9 | import os
10 | import logging
11 | from os.path import dirname, abspath, exists, join
12 | from datetime import datetime
13 |
14 |
15 |
16 | def make_run_name(format, framework, phase):
17 | return format.format(
18 | framework=framework,
19 | phase=phase,
20 | timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
21 | )
22 |
23 |
24 | def make_logger(run_name, log_output):
25 | if log_output is not None:
26 | run_name = log_output.split('/')[-1].split('.')[0]
27 | logger = logging.getLogger(run_name)
28 | logger.propagate = False
29 | log_filepath = log_output if log_output is not None else join('logs', f'{run_name}.log')
30 |
31 | log_dir = dirname(abspath(log_filepath))
32 | if not exists(log_dir):
33 | os.makedirs(log_dir)
34 |
35 |     if not logger.handlers:    # add handlers only if this logger has not been configured yet
36 | file_handler = logging.FileHandler(log_filepath, 'a', 'utf-8')
37 | stream_handler = logging.StreamHandler(os.sys.stdout)
38 |
39 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
40 |
41 | file_handler.setFormatter(formatter)
42 | stream_handler.setFormatter(formatter)
43 |
44 | logger.addHandler(file_handler)
45 | logger.addHandler(stream_handler)
46 | logger.setLevel(logging.INFO)
47 | return logger
48 |
49 |
50 | def make_checkpoint_dir(checkpoint_dir, run_name):
51 | checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name)
52 | if not exists(abspath(checkpoint_dir)):
53 | os.makedirs(checkpoint_dir)
54 | return checkpoint_dir
55 |
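56 | # Illustrative usage sketch (the run-name format string is an example, not the
57 | # one used elsewhere in this repository):
58 | #   run_name = make_run_name('{framework}-{phase}-{timestamp}', framework='ecgan', phase='train')
59 | #   logger = make_logger(run_name, log_output=None)
60 | #   logger.info('training started')
61 | #   ckpt_dir = make_checkpoint_dir(None, run_name)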
--------------------------------------------------------------------------------
/src/utils/make_hdf5.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
3 |
4 | MIT License
5 |
6 | Copyright (c) 2019 Andy Brock
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | """
23 |
24 |
25 | import os
26 | import sys
27 | import h5py as h5
28 | import numpy as np
29 | import PIL
30 | from argparse import ArgumentParser
31 | from tqdm import tqdm, trange
32 |
33 | from data_utils.load_dataset import LoadDataset
34 |
35 | import torch
36 | import torchvision.transforms as transforms
37 | from torch.utils.data import DataLoader
38 |
39 |
40 |
41 | def make_hdf5(model_config, train_config, mode):
42 | if 'hdf5' in model_config['dataset_name']:
43 | raise ValueError('Reading from an HDF5 file which you will probably be '
44 | 'about to overwrite! Override this error only if you know '
45 |                          'what you\'re doing!')
46 |
47 | file_name = '{dataset_name}_{size}_{mode}.hdf5'.format(dataset_name=model_config['dataset_name'], size=model_config['img_size'], mode=mode)
48 | file_path = os.path.join(model_config['data_path'], file_name)
49 | train = True if mode == "train" else False
50 |
51 | if os.path.isfile(file_path):
52 | print("{file_name} exist!\nThe file are located in the {file_path}".format(file_name=file_name, file_path=file_path))
53 | else:
54 | dataset = LoadDataset(model_config['dataset_name'], model_config['data_path'], train=train, download=True, resize_size=model_config['img_size'],
55 | hdf5_path=None, random_flip=False)
56 |
57 | loader = DataLoader(dataset,
58 | batch_size=model_config['batch_size4prcsing'],
59 | shuffle=False,
60 | pin_memory=False,
61 | num_workers=train_config['num_workers'],
62 | drop_last=False)
63 |
64 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (model_config['dataset_name'],
65 | model_config['chunk_size'],
66 | model_config['compression']))
67 | # Loop over loader
68 | for i,(x,y) in enumerate(tqdm(loader)):
69 | # Numpyify x, y
70 | x = (255 * ((x + 1) / 2.0)).byte().numpy()
71 | y = y.numpy()
72 | # If we're on the first batch, prepare the hdf5
73 | if i==0:
74 | with h5.File(file_path, 'w') as f:
75 | print('Producing dataset of len %d' % len(loader.dataset))
76 | imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8', maxshape=(len(loader.dataset), 3,
77 | model_config['img_size'], model_config['img_size']),
78 | chunks=(model_config['chunk_size'], 3, model_config['img_size'], model_config['img_size']), compression=model_config['compression'])
79 | print('Image chunks chosen as ' + str(imgs_dset.chunks))
80 | imgs_dset[...] = x
81 |
82 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(loader.dataset),),
83 | chunks=(model_config['chunk_size'],), compression=model_config['compression'])
84 | print('Label chunks chosen as ' + str(labels_dset.chunks))
85 | labels_dset[...] = y
86 | # Else append to the hdf5
87 | else:
88 | with h5.File(file_path, 'a') as f:
89 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
90 | f['imgs'][-x.shape[0]:] = x
91 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
92 | f['labels'][-y.shape[0]:] = y
93 | return file_path
94 |
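95 | # Illustrative usage sketch: the two config dicts only need the keys read above;
96 | # the values here are placeholders rather than this project's settings.
97 | #   model_config = {'dataset_name': 'cifar10', 'data_path': './data/cifar10',
98 | #                   'img_size': 32, 'batch_size4prcsing': 256,
99 | #                   'chunk_size': 500, 'compression': None}
100 | #   train_config = {'num_workers': 4}
101 | #   hdf5_path = make_hdf5(model_config, train_config, mode='train')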
--------------------------------------------------------------------------------
/src/utils/model_ops.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/utils/model_ops.py
6 |
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.utils import spectral_norm
11 | from torch.nn import init
12 |
13 |
14 |
15 | def init_weights(modules, initialize):
16 | for module in modules():
17 | if (isinstance(module, nn.Conv2d)
18 | or isinstance(module, nn.ConvTranspose2d)
19 | or isinstance(module, nn.Linear)):
20 | if initialize == 'ortho':
21 | init.orthogonal_(module.weight)
22 | if module.bias is not None:
23 | module.bias.data.fill_(0.)
24 | elif initialize == 'N02':
25 | init.normal_(module.weight, 0, 0.02)
26 | if module.bias is not None:
27 | module.bias.data.fill_(0.)
28 | elif initialize in ['glorot', 'xavier']:
29 | init.xavier_uniform_(module.weight)
30 | if module.bias is not None:
31 | module.bias.data.fill_(0.)
32 | else:
33 | print('Init style not recognized...')
34 | elif isinstance(module, nn.Embedding):
35 | if initialize == 'ortho':
36 | init.orthogonal_(module.weight)
37 | elif initialize == 'N02':
38 | init.normal_(module.weight, 0, 0.02)
39 | elif initialize in ['glorot', 'xavier']:
40 | init.xavier_uniform_(module.weight)
41 | else:
42 | print('Init style not recognized...')
43 | else:
44 | pass
45 |
46 |
47 | def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
48 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
49 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
50 |
51 | def deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True):
52 | return nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
54 |
55 | def linear(in_features, out_features, bias=True):
56 | return nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
57 |
58 | def embedding(num_embeddings, embedding_dim):
59 | return nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
60 |
61 | def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
62 | return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
63 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6)
64 |
65 | def sndeconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True):
66 | return spectral_norm(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
67 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6)
68 |
69 | def snlinear(in_features, out_features, bias=True):
70 | return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6)
71 |
72 | def sn_embedding(num_embeddings, embedding_dim):
73 | return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6)
74 |
75 | def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True):
76 | return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True)
77 |
78 |
79 | class ConditionalBatchNorm2d(nn.Module):
80 | # https://github.com/voletiv/self-attention-GAN-pytorch
81 | def __init__(self, num_features, num_classes, spectral_norm):
82 | super().__init__()
83 | self.num_features = num_features
84 | self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False)
85 |
86 | if spectral_norm:
87 | self.embed0 = sn_embedding(num_classes, num_features)
88 | self.embed1 = sn_embedding(num_classes, num_features)
89 | else:
90 | self.embed0 = embedding(num_classes, num_features)
91 | self.embed1 = embedding(num_classes, num_features)
92 |
93 | def forward(self, x, y):
94 | gain = (1 + self.embed0(y)).view(-1, self.num_features, 1, 1)
95 | bias = self.embed1(y).view(-1, self.num_features, 1, 1)
96 | out = self.bn(x)
97 | return out * gain + bias
98 |
99 |
100 | class ConditionalBatchNorm2d_for_skip_and_shared(nn.Module):
101 | # https://github.com/voletiv/self-attention-GAN-pytorch
102 | def __init__(self, num_features, z_dims_after_concat, spectral_norm):
103 | super().__init__()
104 | self.num_features = num_features
105 | self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False)
106 |
107 | if spectral_norm:
108 | self.gain = snlinear(z_dims_after_concat, num_features, bias=False)
109 | self.bias = snlinear(z_dims_after_concat, num_features, bias=False)
110 | else:
111 | self.gain = linear(z_dims_after_concat, num_features, bias=False)
112 | self.bias = linear(z_dims_after_concat, num_features, bias=False)
113 |
114 | def forward(self, x, y):
115 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
116 | bias = self.bias(y).view(y.size(0), -1, 1, 1)
117 | out = self.bn(x)
118 | return out * gain + bias
119 |
120 |
121 | class Self_Attn(nn.Module):
122 | # https://github.com/voletiv/self-attention-GAN-pytorch
123 | def __init__(self, in_channels, spectral_norm):
124 | super(Self_Attn, self).__init__()
125 | self.in_channels = in_channels
126 |
127 | if spectral_norm:
128 | self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False)
129 | self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False)
130 | self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False)
131 | self.conv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False)
132 | else:
133 | self.conv1x1_theta = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False)
134 | self.conv1x1_phi = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False)
135 | self.conv1x1_g = conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False)
136 | self.conv1x1_attn = conv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False)
137 |
138 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
139 | self.softmax = nn.Softmax(dim=-1)
140 | self.sigma = nn.Parameter(torch.zeros(1))
141 |
142 | def forward(self, x):
143 | """
144 | inputs :
145 | x : input feature maps(B X C X H X W)
146 | returns :
147 | out : self attention value + input feature
148 | attention: B X N X N (N is Width*Height)
149 | """
150 | _, ch, h, w = x.size()
151 | # Theta path
152 | theta = self.conv1x1_theta(x)
153 | theta = theta.view(-1, ch//8, h*w)
154 | # Phi path
155 | phi = self.conv1x1_phi(x)
156 | phi = self.maxpool(phi)
157 | phi = phi.view(-1, ch//8, h*w//4)
158 | # Attn map
159 | attn = torch.bmm(theta.permute(0, 2, 1), phi)
160 | attn = self.softmax(attn)
161 | # g path
162 | g = self.conv1x1_g(x)
163 | g = self.maxpool(g)
164 | g = g.view(-1, ch//2, h*w//4)
165 | # Attn_g
166 | attn_g = torch.bmm(g, attn.permute(0, 2, 1))
167 | attn_g = attn_g.view(-1, ch//2, h, w)
168 | attn_g = self.conv1x1_attn(attn_g)
169 | return x + self.sigma*attn_g
170 |
171 |
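172 | # Illustrative usage sketch (channel count, class count, and spatial size are
173 | # arbitrary examples; shapes follow the docstring of Self_Attn.forward):
174 | #   attn = Self_Attn(in_channels=64, spectral_norm=True)
175 | #   feats = torch.randn(2, 64, 16, 16)
176 | #   out = attn(feats)                 # same shape as feats: (2, 64, 16, 16)
177 | #
178 | #   cbn = ConditionalBatchNorm2d(num_features=64, num_classes=10, spectral_norm=True)
179 | #   y = torch.randint(0, 10, (2,))
180 | #   normed = cbn(feats, y)            # per-class gain and bias applied after BN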
--------------------------------------------------------------------------------
/src/utils/sample.py:
--------------------------------------------------------------------------------
1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2 | # The MIT License (MIT)
3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4 |
5 | # src/utils/sample.py
6 |
7 |
8 | import numpy as np
9 | import random
10 | from numpy import linalg
11 | from math import sin,cos,sqrt
12 |
13 | from utils.losses import latent_optimise
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | from torch.nn import DataParallel
18 |
19 |
20 |
21 | def sample_latents(dist, batch_size, dim, truncated_factor=1, num_classes=None, perturb=None, device=torch.device("cpu"), sampler="default"):
22 | if num_classes:
23 | if sampler == "default":
24 | y_fake = torch.randint(low=0, high=num_classes, size=(batch_size,), dtype=torch.long, device=device)
25 | elif sampler == "class_order_some":
26 | assert batch_size % 8 == 0, "The size of the batches should be a multiple of 8."
27 | num_classes_plot = batch_size//8
28 | indices = np.random.permutation(num_classes)[:num_classes_plot]
29 | elif sampler == "class_order_all":
30 | batch_size = num_classes*8
31 | indices = [c for c in range(num_classes)]
32 | elif isinstance(sampler, int):
33 | y_fake = torch.tensor([sampler]*batch_size, dtype=torch.long).to(device)
34 | else:
35 | raise NotImplementedError
36 |
37 | if sampler in ["class_order_some", "class_order_all"]:
38 | y_fake = []
39 | for idx in indices:
40 | y_fake += [idx]*8
41 | y_fake = torch.tensor(y_fake, dtype=torch.long).to(device)
42 | else:
43 | y_fake = None
44 |
45 | if isinstance(perturb, float) and perturb > 0.0:
46 | if dist == "gaussian":
47 | latents = torch.randn(batch_size, dim, device=device)/truncated_factor
48 | eps = perturb*torch.randn(batch_size, dim, device=device)
49 | latents_eps = latents + eps
50 | elif dist == "uniform":
51 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device)
52 | eps = perturb*torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device)
53 | latents_eps = latents + eps
54 | elif dist == "hyper_sphere":
55 | latents, latents_eps = random_ball(batch_size, dim, perturb=perturb)
56 | latents, latents_eps = torch.FloatTensor(latents).to(device), torch.FloatTensor(latents_eps).to(device)
57 | return latents, y_fake, latents_eps
58 | else:
59 | if dist == "gaussian":
60 | latents = torch.randn(batch_size, dim, device=device)/truncated_factor
61 | elif dist == "uniform":
62 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device)
63 | elif dist == "hyper_sphere":
64 |             latents = torch.FloatTensor(random_ball(batch_size, dim, perturb=perturb)).to(device)
65 | return latents, y_fake
66 |
67 |
68 | def random_ball(batch_size, z_dim, perturb=False):
69 | if perturb:
70 | normal = np.random.normal(size=(z_dim, batch_size))
71 | random_directions = normal/linalg.norm(normal, axis=0)
72 |         random_radii = np.random.random(batch_size) ** (1/z_dim)
73 | zs = 1.0 * (random_directions * random_radii).T
74 |
75 | normal_perturb = normal + 0.05*np.random.normal(size=(z_dim, batch_size))
76 | perturb_random_directions = normal_perturb/linalg.norm(normal_perturb, axis=0)
77 |         perturb_random_radii = np.random.random(batch_size) ** (1/z_dim)
78 | zs_perturb = 1.0 * (perturb_random_directions * perturb_random_radii).T
79 | return zs, zs_perturb
80 | else:
81 | normal = np.random.normal(size=(z_dim, batch_size))
82 | random_directions = normal/linalg.norm(normal, axis=0)
83 |         random_radii = np.random.random(batch_size) ** (1/z_dim)
84 | zs = 1.0 * (random_directions * random_radii).T
85 | return zs
86 |
87 |
88 | # Convenience function to sample an index, not actually a 1-hot
89 | def sample_1hot(batch_size, num_classes, device='cuda'):
90 | return torch.randint(low=0, high=num_classes, size=(batch_size,),
91 | device=device, dtype=torch.int64, requires_grad=False)
92 |
93 |
94 | def make_mask(labels, n_cls, device):
95 | labels = labels.detach().cpu().numpy()
96 | n_samples = labels.shape[0]
97 | mask_multi = np.zeros([n_cls, n_samples])
98 | for c in range(n_cls):
99 | c_indices = np.where(labels==c)
100 |         mask_multi[c, c_indices] = 1
101 |
102 | mask_multi = torch.tensor(mask_multi).type(torch.long)
103 | return mask_multi.to(device)
104 |
105 |
106 | def target_class_sampler(dataset, target_class):
107 | try:
108 | targets = dataset.data.targets
109 | except:
110 | targets = dataset.labels
111 | weights = [True if target == target_class else False for target in targets]
112 | num_samples = sum(weights)
113 | weights = torch.DoubleTensor(weights)
114 |     sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=False)
115 | return num_samples, sampler
116 |
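117 | # Illustrative usage sketch (batch size, latent dim, and class count are example
118 | # values): without `perturb`, sample_latents returns (latents, labels); with a
119 | # positive float `perturb`, it also returns the perturbed latents.
120 | #   z, y = sample_latents('gaussian', batch_size=16, dim=128, num_classes=10,
121 | #                         device=torch.device('cuda'))
122 | #   z, y, z_eps = sample_latents('gaussian', 16, 128, num_classes=10, perturb=0.05,
123 | #                                device=torch.device('cuda'))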
--------------------------------------------------------------------------------