├── .gitignore
├── .idea
├── WGAN-LP-tensorflow.iml
├── misc.xml
├── modules.xml
└── vcs.xml
├── LICENSE
├── README.md
├── ckpts
└── README.md
├── data_generator.py
├── model.py
├── reg_losses.py
├── run_experiments.sh
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.idea/WGAN-LP-tensorflow.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 mikigom
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WGAN-LP-tensorflow
2 |
3 | [Report on arXiv](https://arxiv.org/abs/1712.05882)
4 |
5 | Reproduction code for the following paper:
6 |
7 | ```
8 | Title:
9 | On the regularization of Wasserstein GANs
10 | Authors:
11 | Petzka, Henning; Fischer, Asja; Lukovnicov, Denis
12 | Publication:
13 | eprint arXiv:1709.08894
14 | Publication Date:
15 | 09/2017
16 | Origin:
17 | ARXIV
18 | Keywords:
19 | Statistics - Machine Learning, Computer Science - Learning
20 | 2017arXiv170908894P
21 | ```
22 | [Original Paper on arXiv](https://arxiv.org/abs/1709.08894)
23 |
24 | ## Repository structure
25 |
26 | *data\_generator.py*
27 | - provides classes that generate the sample data needed for training.
28 |
29 | *reg\_losses.py*
30 | - defines the sampling method and loss term for regularization.
31 |
32 | *model.py*
33 | - implements 3-layer neural networks for a generator and a critic.
34 |
35 | *trainer.py*
36 | - a pipeline for model learning and visualization.
37 |
--------------------------------------------------------------------------------
/ckpts/README.md:
--------------------------------------------------------------------------------
1 | ## ckpts
2 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/igul222/improved_wgan_training/blob/master/gan_toy.py
2 |
3 | import numpy as np
4 | import sklearn.datasets
5 | import random
6 |
7 | # "Generator" in this module refers to Python generators (iterators), not to generative models.
8 |
9 |
class GeneratorGaussians8(object):
    """Endless source of mini-batches drawn from a mixture of 8 Gaussians.

    Four mixture centers sit on the coordinate axes and four on the
    diagonals; all of them are multiplied by ``scale``. Iterating over an
    instance never terminates: each step yields a fresh ``float32`` array
    with ``batch_size`` rows and 2 columns, already divided by ``stdev``.
    """

    def __init__(self,
                 batch_size: int=256,
                 scale: float=2.,
                 center_coor_min: float=-1.,
                 center_coor_max: float=1.,
                 stdev: float=1.414):
        self.batch_size = batch_size
        self.stdev = stdev

        hi, lo = center_coor_max, center_coor_min
        # The diagonal centers are divided by this length before scaling.
        diag_len = np.sqrt(center_coor_min**2 + center_coor_max**2)
        hi_diag = hi / diag_len
        lo_diag = lo / diag_len

        unscaled_centers = (
            (hi, 0.),
            (lo, 0.),
            (0., hi),
            (0., lo),
            (hi_diag, hi_diag),
            (hi_diag, lo_diag),
            (lo_diag, hi_diag),
            (lo_diag, lo_diag),
        )
        self.centers = [(scale * cx, scale * cy) for cx, cy in unscaled_centers]

    def __iter__(self):
        """Yield batches forever; every point is a random center plus small noise."""
        while True:
            # Per point: draw the noise first, then pick the center.
            points = [
                np.random.randn(2) * .02 + np.asarray(random.choice(self.centers))
                for _ in range(self.batch_size)
            ]
            batch = np.array(points, dtype='float32')
            batch /= self.stdev
            yield batch
45 |
46 |
class GeneratorGaussians25(object):
    """Endless source of mini-batches drawn from a grid of Gaussians.

    A fixed dataset is sampled once in the constructor: for each of the
    ``n_init_loop`` repetitions and each integer grid position ``(x, y)``
    in the inclusive x/y ranges, one point is drawn around the center
    ``(2 * x, 2 * y)`` with standard deviation ``noise_const``. The
    dataset is shuffled, divided by ``stdev`` and stored as ``float32``.
    With the default ranges this is a 5 x 5 grid, i.e. 25 Gaussians.

    Iterating yields consecutive ``batch_size``-row slices of that dataset
    and starts over after the last full slice; trailing points that do not
    fill a batch are never yielded.
    """

    def __init__(self,
                 batch_size: int=256,
                 n_init_loop: int=4000,
                 x_iter_range_min: int=-2,
                 x_iter_range_max: int=2,
                 y_iter_range_min: int=-2,
                 y_iter_range_max: int=2,
                 noise_const: float = 0.05,
                 stdev: float=2.828):
        self.batch_size = batch_size

        xs = np.arange(x_iter_range_min, x_iter_range_max + 1)
        ys = np.arange(y_iter_range_min, y_iter_range_max + 1)
        # Rows are ordered x-major / y-minor, like a nested "for x: for y:" loop.
        grid = np.stack(np.meshgrid(xs, ys, indexing='ij'), axis=-1).reshape(-1, 2)
        centers = np.tile(2.0 * grid, (n_init_loop, 1))

        # One vectorized draw replaces a Python loop over every single point.
        points = np.random.randn(*centers.shape) * noise_const + centers

        self.dataset = points.astype('float32')
        np.random.shuffle(self.dataset)
        self.dataset /= stdev

    def __iter__(self):
        """Yield full batches forever.

        Raises:
            ValueError: if the dataset does not contain a single full batch.
                Without this check the loop below would spin forever
                without ever yielding.
        """
        n_batches = len(self.dataset) // self.batch_size
        if n_batches < 1:
            raise ValueError(
                'batch_size (%d) must be between 1 and the dataset size (%d).'
                % (self.batch_size, len(self.dataset)))

        while True:
            for i in range(n_batches):
                yield self.dataset[i * self.batch_size:(i + 1) * self.batch_size]
74 |
75 |
class GeneratorSwissRoll(object):
    """Endless source of 2-D swiss-roll mini-batches.

    Each step calls ``sklearn.datasets.make_swiss_roll`` for ``batch_size``
    samples with noise level ``noise_stdev``, keeps coordinate columns 0 and
    2 as ``float32`` and divides them by ``stdev``.
    """

    def __init__(self,
                 batch_size: int=256,
                 noise_stdev: float=0.25,
                 stdev: float=7.5):
        self.stdev = stdev
        self.noise_stdev = noise_stdev
        self.batch_size = batch_size

    def __iter__(self):
        """Yield freshly sampled batches forever."""
        while True:
            # make_swiss_roll returns (points, t); only the points are used.
            roll, _ = sklearn.datasets.make_swiss_roll(
                n_samples=self.batch_size,
                noise=self.noise_stdev
            )
            batch = roll.astype('float32')[:, [0, 2]]
            batch /= self.stdev  # stdev plus a little
            yield batch
94 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from typing import Callable
3 |
slim = tf.contrib.slim

# Default slope applied to negative inputs by the leaky ReLU below.
__leaky_relu_alpha__ = 0.2


def __leaky_relu__(x, alpha=__leaky_relu_alpha__, name='Leaky_ReLU'):
    """Return the element-wise maximum of ``x`` and ``alpha * x``.

    Args:
        x: input tensor.
        alpha: factor applied to ``x`` for the second operand of the maximum.
        name: name given to the resulting op.
    """
    scaled = alpha*x
    return tf.maximum(x, scaled, name=name)
11 |
12 |
class Model(object):
    """Fully connected network built with ``slim`` inside one variable scope.

    The network is a stack of ``n_hidden_layers`` fully connected layers of
    ``n_hidden_neurons`` units using ``activation_fn``, followed by a
    fully connected output layer with ``n_out_dim`` units and no activation.
    The graph is built immediately by the constructor.

    Attributes:
        output_tensor: output of the final layer.
        var_list: variables collected from the model's variable scope.
    """

    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str,
                 n_hidden_neurons: int,
                 n_hidden_layers: int,
                 n_out_dim: int,
                 activation_fn: Callable,
                 reuse: bool):
        # Configuration.
        self.input = input_tensor
        self.variable_scope_name = variable_scope_name
        self.n_hidden_neurons = n_hidden_neurons
        self.n_hidden_layers = n_hidden_layers
        self.n_out_dim = n_out_dim
        self.activation_fn = activation_fn
        self.reuse = reuse

        # Filled in by define_model().
        self.output_tensor = None
        self.var_list = None

        self.define_model()

    def define_model(self):
        """Create the layers and record the scope's variables."""
        with tf.variable_scope(self.variable_scope_name, reuse=self.reuse) as scope:
            hidden = self.input

            # Hidden layers share their width and activation through arg_scope.
            with slim.arg_scope([slim.fully_connected],
                                num_outputs=self.n_hidden_neurons,
                                activation_fn=self.activation_fn):
                for _ in range(self.n_hidden_layers):
                    hidden = slim.fully_connected(inputs=hidden)

                # The output layer overrides both arg_scope defaults.
                self.output_tensor = slim.fully_connected(
                    inputs=hidden,
                    num_outputs=self.n_out_dim,
                    activation_fn=None)

            self.var_list = tf.contrib.framework.get_variables(scope)
46 |
47 |
class Generator(Model):
    """``Model`` preset for the generator.

    Defaults: scope ``'Generator'``, 3 hidden layers of 512 units, leaky
    ReLU activation and a 2-dimensional output.
    """

    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str='Generator',
                 n_hidden_neurons: int=512,
                 n_hidden_layers: int=3,
                 n_out_dim: int=2,
                 activation_fn: Callable=__leaky_relu__,
                 reuse: bool=False):
        super().__init__(input_tensor=input_tensor,
                         variable_scope_name=variable_scope_name,
                         n_hidden_neurons=n_hidden_neurons,
                         n_hidden_layers=n_hidden_layers,
                         n_out_dim=n_out_dim,
                         activation_fn=activation_fn,
                         reuse=reuse)
64 |
65 |
class Critic(Model):
    """``Model`` preset for the critic.

    Defaults: scope ``'Critic'``, 3 hidden layers of 512 units, leaky ReLU
    activation and a single output value per input row.
    """

    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str='Critic',
                 n_hidden_neurons: int=512,
                 n_hidden_layers: int=3,
                 n_out_dim: int=1,
                 activation_fn: Callable=__leaky_relu__,
                 reuse: bool=False):
        super().__init__(input_tensor=input_tensor,
                         variable_scope_name=variable_scope_name,
                         n_hidden_neurons=n_hidden_neurons,
                         n_hidden_layers=n_hidden_layers,
                         n_out_dim=n_out_dim,
                         activation_fn=activation_fn,
                         reuse=reuse)
82 |
--------------------------------------------------------------------------------
/reg_losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from model import Critic
3 | slim = tf.contrib.slim
4 |
5 |
def get_perbatuation_samples(training_samples, generated_samples, per_type,
                             dragan_parameter_C):
    """Build the points ``x_hat`` at which the critic gradient is penalized.

    Args:
        training_samples: batch of real samples.
        generated_samples: batch of generator outputs.
        per_type: sampling strategy, one of
            ``'no_purf'``: use the real samples unchanged;
            ``'wgan_gp'``: random per-row interpolation between real and
            generated samples (Gulrajani et al., arXiv:1704.00028);
            ``'dragan_only_training'``: real samples plus bounded random
            noise (Kodali et al., arXiv:1705.07215);
            ``'dragan_both'``: the same noise applied to the concatenation of
            real and generated samples.
        dragan_parameter_C: scale of the DRAGAN noise; unused by the other
            strategies.

    Returns:
        The tensor of perturbed samples.

    Raises:
        NotImplementedError: if ``per_type`` is not one of the values above.
    """
    if per_type == 'no_purf':
        return training_samples

    # Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." arXiv preprint arXiv:1704.00028 (2017).
    if per_type == 'wgan_gp':
        epsilon = tf.random_uniform(
            shape=[tf.shape(training_samples)[0], 1],
            minval=0.,
            maxval=1.
        )
        return epsilon * training_samples + (1 - epsilon) * generated_samples

    # Kodali, Naveen, et al. "How to Train Your DRAGAN." arXiv preprint arXiv:1705.07215 (2017).
    if per_type == 'dragan_only_training':
        samples = training_samples
    elif per_type == 'dragan_both':
        samples = tf.concat([training_samples, generated_samples], axis=0)
    else:
        # The exception must actually be raised; merely constructing it would
        # let an unknown per_type fall through and return None.
        raise NotImplementedError('arg per_type is not injected correctly.')

    n_samples = tf.shape(samples)[0]

    u = tf.random_uniform(
        shape=[n_samples, 1],
        minval=0.,
        maxval=1.
    )
    # tf.nn.moments returns (mean, variance); the noise scale is the
    # standard deviation, so take the square root of the variance.
    _, batch_variance = tf.nn.moments(tf.reshape(samples, [-1]), axes=[0])
    batch_std = tf.sqrt(batch_variance)

    delta = dragan_parameter_C * batch_std * u

    alpha = tf.random_uniform(
        shape=[n_samples, 1],
        minval=0.,
        maxval=1.
    )

    return samples + (1 - alpha) * delta
65 |
66 |
def get_regularization_term(training_samples, generated_samples,
                            reg_type, per_type,
                            critic_variable_scope_name,
                            dragan_parameter_C=0.5):
    """Build the critic's gradient-norm regularization term.

    The critic is re-applied (with ``reuse=True``) to the perturbed samples
    and the per-row Euclidean norm of its gradient with respect to those
    samples is penalized.

    Args:
        training_samples: batch of real samples.
        generated_samples: batch of generator outputs.
        reg_type: ``'GP'`` for the two-sided penalty ``(norm - 1)^2``
            (Gulrajani et al., arXiv:1704.00028) or ``'LP'`` for the one-sided
            penalty ``max(0, norm - 1)^2`` (Petzka et al., arXiv:1709.08894).
        per_type: sampling strategy forwarded to ``get_perbatuation_samples``.
        critic_variable_scope_name: variable scope of the existing critic.
        dragan_parameter_C: noise scale forwarded to
            ``get_perbatuation_samples``.

    Returns:
        Tuple ``(gradient_penalty, x_hat)``: the scalar penalty and the
        perturbed samples it was evaluated on.

    Raises:
        NotImplementedError: if ``reg_type`` or ``per_type`` is not supported.
    """
    # Validate before any graph op is created. The exception has to be raised
    # explicitly; otherwise the penalty would silently be None.
    if reg_type not in ('GP', 'LP'):
        raise NotImplementedError('arg reg_type is not injected correctly.')

    x_hat = get_perbatuation_samples(training_samples, generated_samples, per_type,
                                     dragan_parameter_C)
    if x_hat is None:
        # Only reachable if the sampler returned nothing for this per_type.
        raise NotImplementedError('arg per_type is not injected correctly.')

    critic = Critic(x_hat, variable_scope_name=critic_variable_scope_name, reuse=True)
    gradients = tf.gradients(critic.output_tensor, [x_hat])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

    if reg_type == 'GP':
        # Gulrajani, Ishaan, et al. "Improved training of wasserstein gans."
        # arXiv preprint arXiv:1704.00028 (2017).
        gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)
    else:
        # Henning Petzka, Asja Fischer, and Denis Lukovnicov. "On the regularization of Wasserstein GANs."
        # arXiv preprint arXiv:1709.08894 (2017).
        gradient_penalty = tf.reduce_mean((tf.maximum(0., slopes - 1)) ** 2)

    return gradient_penalty, x_hat
93 |
--------------------------------------------------------------------------------
/run_experiments.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Launch one training run; every argument is forwarded to trainer.py.
train() {
    python3 trainer.py "$@"
}

# Five-run GP/LP sweep with the wgan_gp perturbation. Optional leading
# arguments (e.g. a --dataset selection) are passed before the other flags.
# Runs 1-3 are the top/middle/bottom panels of a figure; runs 4-5 are not
# shown in the figure, but written.
lambda_sweep() {
    train "$@" --Regularization_type GP --Purturbation_type wgan_gp --Lambda 10.0
    train "$@" --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0
    train "$@" --Regularization_type LP --Purturbation_type wgan_gp --Lambda 10.0
    train "$@" --Regularization_type LP --Purturbation_type wgan_gp --Lambda 1.0
    train "$@" --Regularization_type LP --Purturbation_type wgan_gp --Lambda 100.0
}

# 2000-step run with earth mover's distance recording enabled.
emd_run() {
    train --n_epoch 2000 "$@" --emd_records True
}

# Fig6 (the middle run is also Fig13, Fig15)
lambda_sweep

# Fig11
lambda_sweep --dataset GeneratorGaussians8

# Fig12
lambda_sweep --dataset GeneratorGaussians25

# Fig7-Top, Fig9-Top
train --Regularization_type GP --Purturbation_type wgan_gp --Lambda 5.0
# Fig7-Bottom, Fig9-Bottom
train --Regularization_type LP --Purturbation_type wgan_gp --Lambda 5.0

# Fig8-Top
train --Regularization_type GP --Purturbation_type dragan_only_training --Lambda 5.0
# Fig8-Middle, Fig14-Top
train --Regularization_type GP --Purturbation_type dragan_both --Lambda 5.0
# Fig8-Bottom, Fig14-Bottom
train --Regularization_type LP --Purturbation_type dragan_both --Lambda 5.0


emd_run --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0
emd_run --Regularization_type GP --Purturbation_type wgan_gp --Lambda 5.0
emd_run --Regularization_type LP --Purturbation_type wgan_gp --Lambda 5.0
emd_run --Regularization_type GP --Purturbation_type dragan_both --Lambda 5.0
emd_run --Regularization_type LP --Purturbation_type dragan_both --Lambda 5.0
51 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from scipy.optimize import linear_sum_assignment
7 | from scipy.spatial import distance
8 |
9 | import data_generator
10 | from model import Generator, Critic
11 | from reg_losses import get_regularization_term
12 |
slim = tf.contrib.slim

# Steps at which Trainer.train() saves a level-set plot of the critic.
__eval_step_list__ = [10, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 15000, 20000]

flags = tf.app.flags

# --- Run length and sizes -------------------------------------------------
flags.DEFINE_integer("n_epoch", 20000, "Epoch to train [20000]")
flags.DEFINE_integer("n_batch_size", 256, "Batch size to train [256]")
flags.DEFINE_integer("latent_dimensionality", 2, "Dimensionality of the latent variables [2]")

# --- Critic / generator update schedule -----------------------------------
# During training, 10 critic updates are performed for every generator update,
# except for the first 25 generator updates,
# where the critic is updated 100 times for each generator update
# in order to get closer to the optimal critic in the beginning of training.
flags.DEFINE_integer("begining_init_step", 25, "[25]")
flags.DEFINE_integer("n_c_iters_under_begining_init_step", 100, "[100]")
flags.DEFINE_integer("n_c_iters_over_begining_init_step", 10, "[10]")
flags.DEFINE_integer("interval_record_earth_mover", 10, "[10]")

# --- Optimization and regularization --------------------------------------
flags.DEFINE_float("learning_rate", 5e-5, "Learning rate of optimizer [5e-5]")
flags.DEFINE_float("Lambda", 5., "Weights for critics' regularization term [5]")
flags.DEFINE_string("Regularization_type", "LP", "[no_reg, no_reg_but_clipping, LP, GP]")
flags.DEFINE_string("Purturbation_type", "dragan_only_training",
                    "[no_purf, wgan_gp, dragan_only_training, dragan_both]")

# --- Data -------------------------------------------------------------------
flags.DEFINE_string("dataset", 'GeneratorSwissRoll',
                    "Which dataset is used? [GeneratorGaussians8, GeneratorGaussians25, GeneratorSwissRoll]")

# --- Variable scopes --------------------------------------------------------
flags.DEFINE_string("critic_variable_scope_name", "Critic", "[Critic]")
flags.DEFINE_string("generator_variable_scope_name", "Generator", "Generator")

# --- Evaluation -------------------------------------------------------------
flags.DEFINE_bool("emd_records", False, "Whether EMD is recorded. (It takes some time...)[True, False]")

FLAGS = flags.FLAGS
47 |
48 |
class Trainer(object):
    """Build the WGAN graph from ``FLAGS`` and run the training loop.

    The constructor builds everything in order: data source, latent noise,
    generator and critic, losses, optimizers, summaries, saver and session.
    ``train()`` then alternates critic and generator updates, writes
    summaries and, at selected steps, saves level-set plots and
    (optionally) an earth mover's distance estimate.
    """

    def __init__(self):
        # Data and latent noise.
        self.dataset_generator = None
        self.real_input = None

        self.z = None

        # Networks.
        self.generator = None
        self.critic_x = None
        self.critic_gz = None

        # Losses. x_hat stays None when no gradient regularization is used.
        self.g_loss = None
        self.c_negative_loss = None
        self.c_regularization_loss = None
        self.c_loss = None
        self.c_clipping = None
        self.x_hat = None

        # Logging.
        self.ckpt_dir = None
        self.summary_writer = None
        self.c_summary_op = None
        self.g_summary_op = None
        self.emd_placeholder = None
        self.emd_summary = None

        self.saver = None

        self.step = None

        # Session and training ops.
        self.sess = None
        self.step_inc = None
        self.g_opt = None
        self.c_opt = None

        self.g_update_fetch_dict = None
        self.c_update_fetch_dict = None
        self.c_feed_dict = None

        self.coord = None
        self.threads = None

        self.define_dataset()
        self.define_latent()
        self.define_model()
        self.define_loss()
        self.define_optim()
        self.define_writer_and_summary()
        self.define_saver()
        self.initialize_session_and_etc()
        self.define_feed_and_fetch()

    def define_dataset(self):
        """Create the batch iterator named by ``FLAGS.dataset`` and the input placeholder."""
        dataset_cls = getattr(data_generator, FLAGS.dataset)
        self.dataset_generator = iter(dataset_cls(FLAGS.n_batch_size))
        self.real_input = tf.placeholder(tf.float32, shape=(None, 2))

    def define_latent(self):
        """Create the standard-normal latent noise tensor."""
        self.z = tf.random_normal([FLAGS.n_batch_size, FLAGS.latent_dimensionality], mean=0.0, stddev=1.0, name='z')

    def define_model(self):
        """Create the generator and two critic heads that share variables."""
        self.generator = Generator(self.z,
                                   variable_scope_name=FLAGS.generator_variable_scope_name)
        self.critic_x = Critic(self.real_input,
                               variable_scope_name=FLAGS.critic_variable_scope_name)
        self.critic_gz = Critic(self.generator.output_tensor,
                                variable_scope_name=FLAGS.critic_variable_scope_name,
                                reuse=True)

    def define_loss(self):
        """Define the generator loss and the (optionally regularized) critic loss."""
        self.g_loss = -tf.reduce_mean(self.critic_gz.output_tensor)
        self.c_negative_loss = -self.g_loss - tf.reduce_mean(self.critic_x.output_tensor)
        if FLAGS.Regularization_type in ('no_reg_but_clipping', 'no_reg'):
            # No gradient penalty: a constant zero keeps the summaries uniform.
            self.c_regularization_loss = tf.Variable(0., trainable=False)
        else:
            self.c_regularization_loss, self.x_hat = get_regularization_term(
                training_samples=self.real_input,
                generated_samples=self.generator.output_tensor,
                reg_type=FLAGS.Regularization_type,
                per_type=FLAGS.Purturbation_type,
                critic_variable_scope_name=FLAGS.critic_variable_scope_name
            )

        self.c_loss = self.c_negative_loss + FLAGS.Lambda * self.c_regularization_loss

    def define_optim(self):
        """Create the step counter, both RMSProp update ops and the clipping op."""
        self.step = tf.Variable(0, name='step', trainable=False)
        self.step_inc = tf.assign(self.step, self.step + 1)

        optimizer = tf.train.RMSPropOptimizer(FLAGS.learning_rate)

        self.g_opt = optimizer.minimize(self.g_loss, var_list=self.generator.var_list)
        self.c_opt = optimizer.minimize(self.c_loss, var_list=self.critic_x.var_list)

        # c_clipping is what the training loop fetches; the control dependency
        # makes fetching it run the critic update as well.
        with tf.control_dependencies([self.c_opt]):
            if FLAGS.Regularization_type == 'no_reg_but_clipping':
                self.c_clipping = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in self.critic_x.var_list]
            else:
                self.c_clipping = [tf.no_op()]

    def define_writer_and_summary(self):
        """Create the run directory, the summary writer and all summary ops."""
        run_name = '_'.join([FLAGS.dataset,
                             FLAGS.Regularization_type,
                             FLAGS.Purturbation_type,
                             str(FLAGS.Lambda),
                             str(FLAGS.emd_records)])
        self.ckpt_dir = 'ckpts/' + run_name + '/'

        os.makedirs(self.ckpt_dir, exist_ok=True)

        self.summary_writer = tf.summary.FileWriter(self.ckpt_dir)

        self.c_summary_op = tf.summary.merge([
            tf.summary.scalar('loss/c', self.c_loss),
            tf.summary.scalar('loss/c_negative_loss', self.c_negative_loss),
            tf.summary.scalar('loss/c_regularization_loss', self.c_regularization_loss)
        ])
        self.g_summary_op = tf.summary.merge([
            tf.summary.scalar('loss/g', self.g_loss)
        ])

        # The EMD is computed in NumPy and fed in through this placeholder.
        self.emd_placeholder = tf.placeholder(tf.float32, shape=())
        self.emd_summary = tf.summary.scalar('EMD', self.emd_placeholder)

    def define_saver(self):
        """Create the checkpoint saver."""
        self.saver = tf.train.Saver()

    def initialize_session_and_etc(self):
        """Open the session, initialize variables and start queue runners."""
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)
        self.sess = tf.Session(config=sess_config)

        self.sess.run(tf.local_variables_initializer())
        self.sess.run(tf.global_variables_initializer())

        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    def define_feed_and_fetch(self):
        """Assemble the fetch/feed dictionaries used by the training loop."""
        self.g_update_fetch_dict = {
            "opt": self.g_opt,
            "z": self.z,
            "G_z": self.generator.output_tensor,
            "loss": self.g_loss,
            'summary': self.g_summary_op,
            "step": self.step
        }

        self.c_update_fetch_dict = {
            'gradient_clipping': self.c_clipping,
            "x": self.real_input,
            "G_z": self.generator.output_tensor,
            "loss": self.c_loss,
            "negative_loss": self.c_negative_loss,
            "regularization_loss": self.c_regularization_loss,
            'summary': self.c_summary_op,
            "step": self.step
        }

        # The real batch is filled in by train() before every critic phase.
        self.c_feed_dict = {
            self.real_input: None
        }

    def draw_level_sets(self, step,
                        x_min=-2.5, x_max=2.5,
                        y_min=-2.5, y_max=2.5,
                        n_batch=2):
        """Save a plot of critic level sets with real, fake and perturbed samples.

        The critic is evaluated on a 200 x 200 grid over the given window and
        drawn as contour lines. ``n_batch`` batches of real (yellow),
        generated (green) and, when a regularization term exists, perturbed
        (red) samples are scattered on top. The figure is written to
        ``<ckpt_dir><step>.png``.
        """
        x = np.linspace(x_min, x_max, 200)
        y = np.linspace(y_min, y_max, 200)

        x, y = np.meshgrid(x, y)
        grid_pts = np.stack([x.flatten(), y.flatten()], axis=1)

        z = self.sess.run(self.critic_x.output_tensor, feed_dict={self.real_input: grid_pts})
        z = np.reshape(z, (200, 200))

        plt.contour(x, y, z, 30, cmap='copper')

        # x_hat is None for 'no_reg' / 'no_reg_but_clipping'; fetching None
        # would make Session.run raise, so it is only fetched when defined.
        fetches = [self.generator.output_tensor]
        has_x_hat = self.x_hat is not None
        if has_x_hat:
            fetches.append(self.x_hat)

        real = list()
        fake = list()
        perturbated = list()
        for _ in range(n_batch):
            real_batch = next(self.dataset_generator)
            fetched = self.sess.run(fetches, feed_dict={self.real_input: real_batch})

            real.append(real_batch)
            fake.append(fetched[0])
            if has_x_hat:
                perturbated.append(fetched[1])

        if perturbated:
            perturbated = np.vstack(perturbated)
            plt.scatter(perturbated[:, 0], perturbated[:, 1], c='r', s=1)

        # With n_batch < 1 there is nothing to stack or scatter.
        if real:
            real = np.vstack(real)
            fake = np.vstack(fake)
            plt.scatter(real[:, 0], real[:, 1], c='y', s=2)
            plt.scatter(fake[:, 0], fake[:, 1], c='g', s=2)

        plt.savefig(self.ckpt_dir+str(step)+'.png')
        plt.clf()

    def estimate_earth_mover_distance(self, step, n_batch=2):
        """Estimate the EMD between real and generated samples and log it.

        ``n_batch`` batches of each kind are collected, matched one-to-one
        with ``linear_sum_assignment`` on their Euclidean distance matrix,
        and the mean matched distance is written as the ``EMD`` summary.
        """
        real = list()
        fake = list()

        for _ in range(n_batch):
            real.append(next(self.dataset_generator))
            fake.append(self.sess.run(self.generator.output_tensor))

        real = np.vstack(real)
        fake = np.vstack(fake)

        cost_matrix = distance.cdist(fake, real, 'euclidean')

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        linear_sum = cost_matrix[row_ind, col_ind].sum()
        emd = linear_sum/real.shape[0]

        emd_fetch = self.sess.run(self.emd_summary, feed_dict={self.emd_placeholder: emd})
        self.summary_writer.add_summary(emd_fetch, step)
        self.summary_writer.flush()

    def train(self):
        """Run the training loop until ``FLAGS.n_epoch`` steps or an interrupt.

        A checkpoint is saved and the coordinator is stopped on every exit
        path, including ``KeyboardInterrupt``.
        """
        try:
            c_fetch_dict = None
            print("[.] Learning Start...")
            step = 0
            while not self.coord.should_stop():
                if step > FLAGS.n_epoch:
                    break

                # NOTE(review): one real batch is reused for every critic
                # iteration of this step; confirm this is intended.
                self.c_feed_dict[self.real_input] = next(self.dataset_generator)
                step = self.sess.run(self.step)

                n_c_iters = (FLAGS.n_c_iters_under_begining_init_step
                             if step < FLAGS.begining_init_step
                             else FLAGS.n_c_iters_over_begining_init_step)
                for _ in range(n_c_iters):
                    c_fetch_dict = self.sess.run(self.c_update_fetch_dict,
                                                 feed_dict=self.c_feed_dict)

                g_fetch_dict = self.sess.run(self.g_update_fetch_dict)

                # c_fetch_dict is still None if no critic iteration has run yet.
                if c_fetch_dict is not None:
                    self.summary_writer.add_summary(c_fetch_dict["summary"], c_fetch_dict["step"])
                self.summary_writer.add_summary(g_fetch_dict["summary"], g_fetch_dict["step"])
                self.summary_writer.flush()

                if step in __eval_step_list__:
                    self.draw_level_sets(step)

                if FLAGS.emd_records and step % FLAGS.interval_record_earth_mover == 0 and step != 0:
                    self.estimate_earth_mover_distance(step)

                self.sess.run(self.step_inc)

        except KeyboardInterrupt:
            print("Interrupted")
            self.coord.request_stop()
        finally:
            self.saver.save(self.sess, self.ckpt_dir)
            print('Stop')
            self.coord.request_stop()
            self.coord.join(self.threads)
317 |
318 |
def main():
    """Build the trainer, run training and close the TensorFlow session."""
    trainer = Trainer()
    trainer.train()
    trainer.sess.close()


if __name__ == '__main__':
    main()
323 |
--------------------------------------------------------------------------------