├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
└── realness_loss.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## RealnessGAN — Simple TensorFlow Implementation [[Paper]](https://openreview.net/pdf?id=B1lPaCNtPB)
2 | ### : Real or Not Real, that is the Question
3 |
4 |
5 |

6 |
7 |
8 | ## Usage
9 | ```python
10 |
11 | fake_img = generator(noise)
12 |
13 | real_logit = discriminator(real_img)
14 | fake_logit = discriminator(fake_img)
15 |
16 | g_loss = generator_loss(Ra, real_logit, fake_logit)  # Ra: True selects the relativistic generator loss
17 | d_loss = discriminator_loss(Ra, real_logit, fake_logit)
18 |
19 | ```
20 |
21 | ## Author
22 | [Junho Kim](http://bit.ly/jhkim_ai)
23 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/RealnessGAN-Tensorflow/f22df9bdf179148568abc602494b35b0cbb6d1e7/assets/teaser.png
--------------------------------------------------------------------------------
/realness_loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | # tensorflow >= 2.0
5 |
def _histogram_anchor(samples, num_outcomes):
    """Bin *samples* into ``num_outcomes`` equal-width bins and normalise to sum to 1."""
    count, _ = np.histogram(samples, num_outcomes)
    return count / count.sum()


def discriminator_loss(Ra, real_logit, fake_logit):
    """Discriminator objective: real outputs vs. a "real" anchor plus fake outputs vs. a "fake" anchor.

    Args:
        Ra: Relativistic flag. Accepted for signature parity with
            ``generator_loss``; the discriminator objective is the same for
            both values, so it has no effect here.
        real_logit: Discriminator output for real images; the last axis holds
            the ``num_outcomes`` realness logits.
        fake_logit: Discriminator output for generated images, same last-axis
            size as ``real_logit``.

    Returns:
        ``realness_loss`` on the real batch (uniform anchor, skewness 10.0)
        plus ``realness_loss`` on the fake batch (gaussian anchor,
        skewness -10.0).
    """
    # Turn logits into probability distributions over the outcomes.
    real_prob = tf.nn.softmax(real_logit, axis=-1)
    fake_prob = tf.nn.softmax(fake_logit, axis=-1)

    num_outcomes = real_logit.shape[-1]

    # The anchors are re-sampled from NumPy's global RNG on every call, so the
    # targets are slightly noisy histograms.  Gaussian is drawn before uniform
    # to keep the RNG consumption order stable.
    anchor0 = _histogram_anchor(np.random.normal(0, 0.1, 1000), num_outcomes)  # anchor_fake
    anchor1 = _histogram_anchor(np.random.uniform(-1, 1, 1000), num_outcomes)  # anchor_real

    # Broadcast each anchor over its own batch; the fake anchor must follow the
    # fake batch size, which may differ from the real one.
    anchor_real = tf.zeros([real_logit.shape[0], num_outcomes]) + tf.cast(anchor1, tf.float32)
    anchor_fake = tf.zeros([fake_logit.shape[0], num_outcomes]) + tf.cast(anchor0, tf.float32)

    real_loss = realness_loss(anchor_real, real_prob, skewness=10.0)
    fake_loss = realness_loss(anchor_fake, fake_prob, skewness=-10.0)

    return real_loss + fake_loss
51 |
52 |
def generator_loss(Ra, real_logit, fake_logit):
    """Generator objective.

    With a truthy ``Ra`` (relativistic), the softmax of ``real_logit`` is used
    directly as the anchor for the softmax of ``fake_logit``, with no skew.
    Otherwise the anchor is a normalised histogram of 1000 uniform(-1, 1)
    samples, tiled over the batch, and the loss is taken with skewness 10.0.
    """
    if not Ra:
        outcome_count = real_logit.shape[-1]

        # Histogram of uniform samples serves as the "real" anchor.
        uniform_samples = np.random.uniform(-1, 1, 1000)
        histogram, _edges = np.histogram(uniform_samples, outcome_count)
        real_anchor = histogram / sum(histogram)

        batch_anchor = tf.zeros([real_logit.shape[0], outcome_count]) + tf.cast(real_anchor, tf.float32)

        fake_prob = tf.exp(tf.nn.log_softmax(fake_logit, axis=-1))
        return realness_loss(batch_anchor, fake_prob, skewness=10.0)

    # Relativistic: compare fake outputs against the real outputs themselves.
    fake_prob = tf.exp(tf.nn.log_softmax(fake_logit, axis=-1))
    real_prob = tf.exp(tf.nn.log_softmax(real_logit, axis=-1))
    return realness_loss(real_prob, fake_prob)
75 |
def realness_loss(anchor, feature, skewness=0.0, positive_skew=10.0, negative_skew=-10.0):
    """Cross-entropy between a skewed anchor distribution and ``feature``.

    The anchor's probability mass is shifted by ``skewness`` along an evenly
    spaced support on ``[negative_skew, positive_skew]`` (clipped at the ends)
    and redistributed onto the neighbouring support points by linear
    interpolation.  The loss is
    ``-mean_over_batch(sum_over_outcomes(skewed_anchor * log(feature + 1e-16)))``.

    Settings noted by the original author as
    ``[num_outcomes, positive_skew, negative_skew]``: ``[51, 10.0, -10.0]`` and
    ``[21, 1.0, -1.0]``.

    Args:
        anchor: Target distribution; last axis has ``num_outcomes`` entries
            (either ``[batch, num_outcomes]`` or broadcastable to it).  It is
            treated as a constant: no gradient flows through it.
        feature: Predicted probabilities, ``num_outcomes`` on the last axis.
        skewness: Shift applied to every support point before clipping.
        positive_skew: Upper end of the support.
        negative_skew: Lower end of the support.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        ValueError: If the last dimension of ``feature`` is unknown or < 2
            (the support spacing would be undefined).
    """
    feature = tf.convert_to_tensor(feature)
    num_outcomes = feature.shape[-1]
    if num_outcomes is None or num_outcomes < 2:
        raise ValueError("feature must have a static last dimension of at least 2 outcomes")

    dtype = feature.dtype
    anchor = tf.cast(anchor, dtype)

    supports = tf.cast(tf.linspace(negative_skew, positive_skew, num_outcomes), dtype)
    delta = (positive_skew - negative_skew) / (num_outcomes - 1)

    # experiment to adjust KL divergence between positive/negative anchors
    Tz = tf.clip_by_value(supports + skewness, negative_skew, positive_skew)

    # Fractional bin position of every shifted support point; clipped so that
    # floating-point round-off can never push an index outside [0, K - 1].
    b = tf.clip_by_value((Tz - negative_skew) / delta, 0.0, num_outcomes - 1)
    lower_b = tf.math.floor(b)
    upper_b = tf.math.ceil(b)

    # When b lands exactly on a bin, widen the [lower, upper] pair so the two
    # interpolation weights still sum to one (lower first, then upper using the
    # already-adjusted lower, matching the sequential in-place updates).
    lower_b = tf.where(
        tf.logical_and(upper_b > 0, tf.equal(lower_b, upper_b)), lower_b - 1, lower_b
    )
    upper_b = tf.where(
        tf.logical_and(lower_b < num_outcomes - 1, tf.equal(lower_b, upper_b)), upper_b + 1, upper_b
    )

    # projection[j, k] = share of source bin j's mass that lands in target bin k.
    # Building it with one-hot rows accumulates correctly when several source
    # bins map to the same target bin (e.g. everything clipped onto the edge).
    lower_hot = tf.one_hot(tf.cast(lower_b, tf.int32), num_outcomes, dtype=dtype)
    upper_hot = tf.one_hot(tf.cast(upper_b, tf.int32), num_outcomes, dtype=dtype)
    projection = (
        tf.expand_dims(upper_b - b, axis=-1) * lower_hot
        + tf.expand_dims(b - lower_b, axis=-1) * upper_hot
    )

    # Contract the anchor's outcome axis with the projection's source axis.
    skewed_anchor = tf.stop_gradient(tf.tensordot(anchor, projection, axes=1))

    log_feature = tf.math.log(feature + 1e-16)
    loss = -tf.reduce_mean(tf.reduce_sum(skewed_anchor * log_feature, axis=-1))

    return loss
--------------------------------------------------------------------------------