├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
└── realness_loss.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Junho Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## RealnessGAN — Simple TensorFlow Implementation [[Paper]](https://openreview.net/pdf?id=B1lPaCNtPB)
### : Real or Not Real, that is the Question
<div align="center">
  <img src="./assets/teaser.png">
</div>

## Usage
```python

fake_img = generator(noise)

real_logit = discriminator(real_img)
fake_logit = discriminator(fake_img)

# Ra: whether to use the relativistic generator objective (see realness_loss.py)
g_loss = generator_loss(Ra, real_logit, fake_logit)
d_loss = discriminator_loss(Ra, real_logit, fake_logit)

```
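
A fuller training-step wiring is sketched below; it is not part of the original repository. The toy generator and discriminator networks, the 32x32x3 image size, the 51-outcome logit dimension, and the Adam settings are illustrative assumptions; only `generator_loss` and `discriminator_loss` come from `realness_loss.py`. Because the losses call `.numpy()` internally, the step has to run eagerly (do not wrap it in `@tf.function`).

```python
import tensorflow as tf
from realness_loss import generator_loss, discriminator_loss

z_dim, num_outcomes = 128, 51

# Toy stand-in networks (assumptions): swap in your own architectures.
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="relu", input_shape=(z_dim,)),
    tf.keras.layers.Dense(32 * 32 * 3, activation="tanh"),
    tf.keras.layers.Reshape((32, 32, 3)),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(num_outcomes),  # one logit per outcome of the realness distribution
])

g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
d_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

def train_step(real_img, Ra=True):
    noise = tf.random.normal([real_img.shape[0], z_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_img = generator(noise, training=True)

        real_logit = discriminator(real_img, training=True)
        fake_logit = discriminator(fake_img, training=True)

        g_loss = generator_loss(Ra, real_logit, fake_logit)
        d_loss = discriminator_loss(Ra, real_logit, fake_logit)

    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    return g_loss, d_loss

# one step on a random batch, for example:
# g_loss, d_loss = train_step(tf.random.normal([16, 32, 32, 3]))
```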

## Author
[Junho Kim](http://bit.ly/jhkim_ai)
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/RealnessGAN-Tensorflow/f22df9bdf179148568abc602494b35b0cbb6d1e7/assets/teaser.png
--------------------------------------------------------------------------------
/realness_loss.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

# tensorflow >= 2.0
# note: realness_loss uses .numpy(), so these losses must run in eager mode


def discriminator_loss(Ra, real_logit, fake_logit):
    # Ra = Relativistic
    # The discriminator objective is the same with or without Ra; the flag only
    # changes the generator objective (see generator_loss).
    fake_logit = tf.exp(tf.nn.log_softmax(fake_logit, axis=-1))
    real_logit = tf.exp(tf.nn.log_softmax(real_logit, axis=-1))

    num_outcomes = real_logit.shape[-1]

    # anchor A0 (fake): histogram of a narrow Gaussian
    gauss = np.random.normal(0, 0.1, 1000)
    count, bins = np.histogram(gauss, num_outcomes)
    anchor0 = count / sum(count)  # anchor_fake

    # anchor A1 (real): histogram of a uniform distribution
    unif = np.random.uniform(-1, 1, 1000)
    count, bins = np.histogram(unif, num_outcomes)
    anchor1 = count / sum(count)  # anchor_real

    anchor_real = tf.zeros([real_logit.shape[0], num_outcomes]) + tf.cast(anchor1, tf.float32)
    anchor_fake = tf.zeros([real_logit.shape[0], num_outcomes]) + tf.cast(anchor0, tf.float32)

    real_loss = realness_loss(anchor_real, real_logit, skewness=10.0)
    fake_loss = realness_loss(anchor_fake, fake_logit, skewness=-10.0)

    loss = real_loss + fake_loss

    return loss


def generator_loss(Ra, real_logit, fake_logit):
    # Ra = Relativistic

    if Ra:
        # relativistic: the discriminator's distribution on real images is the
        # target for the generated images
        fake_logit = tf.exp(tf.nn.log_softmax(fake_logit, axis=-1))
        real_logit = tf.exp(tf.nn.log_softmax(real_logit, axis=-1))

        fake_loss = realness_loss(real_logit, fake_logit)

    else:
        # non-relativistic: push the generated images toward the "real" anchor A1
        num_outcomes = real_logit.shape[-1]
        unif = np.random.uniform(-1, 1, 1000)
        count, bins = np.histogram(unif, num_outcomes)
        anchor1 = count / sum(count)  # anchor_real
        anchor_real = tf.zeros([real_logit.shape[0], num_outcomes]) + tf.cast(anchor1, tf.float32)

        fake_logit = tf.exp(tf.nn.log_softmax(fake_logit, axis=-1))
        fake_loss = realness_loss(anchor_real, fake_logit, skewness=10.0)

    loss = fake_loss

    return loss


# realness_loss shifts the anchor distribution by `skewness` along its support
# of `num_outcomes` evenly spaced atoms on [negative_skew, positive_skew],
# redistributes the mass onto the fixed atoms by linear interpolation
# (a C51-style projection), and then takes the cross-entropy between this
# "skewed" anchor and the predicted realness distribution `feature`.
# E.g. with num_outcomes = 3 the atoms are [-10, 0, 10]; skewness = 10 shifts
# them to [0, 10, 20] -> clipped to [0, 10, 10], so all anchor mass lands on
# the two right-most ("most real") atoms, and skewness = -10 mirrors this
# toward the "most fake" side.
def realness_loss(anchor, feature, skewness=0.0, positive_skew=10.0, negative_skew=-10.0):
    """
    num_outcomes = anchor.shape[-1]
    positive_skew = 10.0
    negative_skew = -10.0
    # [num_outcomes, positive_skew, negative_skew]
    # [51, 10.0, -10.0]
    # [21, 1.0, -1.0]

    gauss = np.random.normal(0, 0.1, 1000)
    count, bins = np.histogram(gauss, num_outcomes)
    anchor0 = count / sum(count)  # anchor_fake

    unif = np.random.uniform(-1, 1, 1000)
    count, bins = np.histogram(unif, num_outcomes)
    anchor1 = count / sum(count)  # anchor_real
    """

    batch_size = feature.shape[0]
    num_outcomes = feature.shape[-1]

    # evenly spaced atoms on [negative_skew, positive_skew]
    supports = tf.linspace(start=negative_skew, stop=positive_skew, num=num_outcomes)
    delta = (positive_skew - negative_skew) / (num_outcomes - 1)

    skew = tf.fill(dims=[batch_size, num_outcomes], value=skewness)

    # experiment to adjust KL divergence between positive/negative anchors
    Tz = skew + tf.reshape(supports, shape=[1, -1]) * tf.ones(shape=[batch_size, 1])
    Tz = tf.clip_by_value(Tz, negative_skew, positive_skew)

    b = (Tz - negative_skew) / delta
    lower_b = tf.cast(tf.math.floor(b), tf.int32).numpy()
    upper_b = tf.cast(tf.math.ceil(b), tf.int32).numpy()

    # keep lower and upper bins distinct so the interpolation weights stay valid
    lower_b[(upper_b > 0) * (lower_b == upper_b)] -= 1
    upper_b[(lower_b < (num_outcomes - 1)) * (lower_b == upper_b)] += 1

    # flat indexing: row i occupies [i * num_outcomes, (i + 1) * num_outcomes)
    offset = tf.expand_dims(tf.linspace(start=0.0, stop=float((batch_size - 1) * num_outcomes), num=batch_size), axis=1)
    offset = tf.tile(offset, multiples=[1, num_outcomes])

    # redistribute the anchor mass onto the shifted support by linear interpolation;
    # np.add.at accumulates correctly when several atoms fall into the same bin
    skewed_anchor = np.zeros([batch_size * num_outcomes], dtype=np.float32)
    lower_idx = tf.cast(tf.reshape(tf.cast(lower_b, tf.float32) + offset, shape=[-1]), tf.int32).numpy()
    lower_updates = tf.reshape(anchor * (tf.cast(upper_b, tf.float32) - b), shape=[-1]).numpy()
    np.add.at(skewed_anchor, lower_idx, lower_updates)

    upper_idx = tf.cast(tf.reshape(tf.cast(upper_b, tf.float32) + offset, shape=[-1]), tf.int32).numpy()
    upper_updates = tf.reshape(anchor * (b - tf.cast(lower_b, tf.float32)), shape=[-1]).numpy()
    np.add.at(skewed_anchor, upper_idx, upper_updates)

    # the skewed anchor is a constant target; gradients only flow through `feature`
    skewed_anchor = tf.reshape(tf.constant(skewed_anchor), shape=[batch_size, num_outcomes])

    # cross-entropy between the skewed anchor and the predicted realness distribution
    loss = -tf.reduce_mean(tf.reduce_sum(skewed_anchor * tf.math.log(feature + 1e-16), axis=-1))

    return loss
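
if __name__ == "__main__":
    # Minimal eager-mode smoke test (an illustrative addition, not part of the
    # original repository): random 51-outcome logits for a batch of 4 samples.
    real_logit = tf.random.normal([4, 51])
    fake_logit = tf.random.normal([4, 51])

    for Ra in (False, True):
        d_loss = discriminator_loss(Ra, real_logit, fake_logit)
        g_loss = generator_loss(Ra, real_logit, fake_logit)
        print(f"Ra={Ra}  d_loss={d_loss.numpy():.4f}  g_loss={g_loss.numpy():.4f}")
--------------------------------------------------------------------------------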