├── .gitignore
├── .idea
│   ├── WGAN-LP-tensorflow.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── ckpts
│   └── README.md
├── data_generator.py
├── model.py
├── reg_losses.py
├── run_experiments.sh
└── trainer.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/.idea/WGAN-LP-tensorflow.iml:
--------------------------------------------------------------------------------
(IntelliJ IDEA module file; its XML content was lost during extraction.)

--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA project settings; XML content lost during extraction.)

--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA module registry; XML content lost during extraction.)

--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA VCS mapping; XML content lost during extraction.)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 mikigom

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# WGAN-LP-tensorflow

[Report on arXiv](https://arxiv.org/abs/1712.05882)

Reproduction code for the following paper:

```
Title:
On the regularization of Wasserstein GANs
Authors:
Petzka, Henning; Fischer, Asja; Lukovnicov, Denis
Publication:
eprint arXiv:1709.08894
Publication Date:
09/2017
Origin:
ARXIV
Keywords:
Statistics - Machine Learning, Computer Science - Learning
Bibliographic Code:
2017arXiv170908894P
```

[Original Paper on arXiv](https://arxiv.org/abs/1709.08894)

## Repository structure

*data\_generator.py*
- Provides the classes that generate the toy sample data (8 Gaussians, 25 Gaussians, Swiss roll) used for training.

*reg\_losses.py*
- Defines the perturbation sampling methods and the loss terms used to regularize the critic.

*model.py*
- Implements the 3-hidden-layer fully connected networks used for the generator and the critic.

*trainer.py*
- The training pipeline, including visualization of critic level sets and an optional earth mover's distance (EMD) estimate.
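
## Usage

The paper's experiments can be reproduced with `run_experiments.sh`; a single run looks like the command below (flag names as defined in `trainer.py`; note that `--Purturbation_type` is spelled the way the flag is implemented):

```
python3 trainer.py --dataset GeneratorSwissRoll \
                   --Regularization_type LP \
                   --Purturbation_type wgan_gp \
                   --Lambda 5.0
```

TensorBoard summaries, level-set plots, and checkpoints are written to a run-specific subdirectory of `ckpts/`.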
--------------------------------------------------------------------------------
/ckpts/README.md:
--------------------------------------------------------------------------------
## ckpts

--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
# Modified from https://github.com/igul222/improved_wgan_training/blob/master/gan_toy.py

import numpy as np
import sklearn.datasets
import random

# 'Generator' in this module refers to the Python generator protocol, not a generative model.


class GeneratorGaussians8(object):
    def __init__(self,
                 batch_size: int=256,
                 scale: float=2.,
                 center_coor_min: float=-1.,
                 center_coor_max: float=1.,
                 stdev: float=1.414):
        self.batch_size = batch_size
        self.stdev = stdev
        diag_len = np.sqrt(center_coor_min**2 + center_coor_max**2)
        centers = [
            (center_coor_max, 0.),
            (center_coor_min, 0.),
            (0., center_coor_max),
            (0., center_coor_min),
            (center_coor_max / diag_len, center_coor_max / diag_len),
            (center_coor_max / diag_len, center_coor_min / diag_len),
            (center_coor_min / diag_len, center_coor_max / diag_len),
            (center_coor_min / diag_len, center_coor_min / diag_len)
        ]
        self.centers = [(scale * x, scale * y) for x, y in centers]

    def __iter__(self):
        while True:
            dataset = []
            for i in range(self.batch_size):
                point = np.random.randn(2) * .02
                center = random.choice(self.centers)
                point[0] += center[0]
                point[1] += center[1]
                dataset.append(point)
            dataset = np.array(dataset, dtype='float32')
            dataset /= self.stdev
            yield dataset


class GeneratorGaussians25(object):
    def __init__(self,
                 batch_size: int=256,
                 n_init_loop: int=4000,
                 x_iter_range_min: int=-2,
                 x_iter_range_max: int=2,
                 y_iter_range_min: int=-2,
                 y_iter_range_max: int=2,
                 noise_const: float=0.05,
                 stdev: float=2.828):
        self.batch_size = batch_size
        self.dataset = []
        for i in range(n_init_loop):
            for x in range(x_iter_range_min, x_iter_range_max+1):
                for y in range(y_iter_range_min, y_iter_range_max+1):
                    point = np.random.randn(2) * noise_const
                    point[0] += 2 * x
                    point[1] += 2 * y
                    self.dataset.append(point)
        self.dataset = np.array(self.dataset, dtype='float32')
        np.random.shuffle(self.dataset)
        self.dataset /= stdev

    def __iter__(self):
        while True:
            for i in range(int(len(self.dataset) / self.batch_size)):
                yield self.dataset[i * self.batch_size:(i + 1)*self.batch_size]


class GeneratorSwissRoll(object):
    def __init__(self,
                 batch_size: int=256,
                 noise_stdev: float=0.25,
                 stdev: float=7.5):
        self.batch_size = batch_size
        self.noise_stdev = noise_stdev
        self.stdev = stdev

    def __iter__(self):
        while True:
            data = sklearn.datasets.make_swiss_roll(
                n_samples=self.batch_size,
                noise=self.noise_stdev
            )[0]
            data = data.astype('float32')[:, [0, 2]]
            data /= self.stdev  # stdev plus a little
            yield data
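
# Minimal smoke test (added for illustration; not part of the original file).
# Each class above is an infinite iterable yielding (batch_size, 2) float32
# batches, so a single next() call is enough to sanity-check shapes.
if __name__ == '__main__':
    for cls in (GeneratorGaussians8, GeneratorGaussians25, GeneratorSwissRoll):
        batch = next(iter(cls(batch_size=8)))
        print(cls.__name__, batch.shape, batch.dtype)  # e.g. (8, 2) float32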
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from typing import Callable

slim = tf.contrib.slim

__leaky_relu_alpha__ = 0.2


def __leaky_relu__(x, alpha=__leaky_relu_alpha__, name='Leaky_ReLU'):
    return tf.maximum(x, alpha*x, name=name)


class Model(object):
    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str,
                 n_hidden_neurons: int,
                 n_hidden_layers: int,
                 n_out_dim: int,
                 activation_fn: Callable,
                 reuse: bool):
        self.input = input_tensor
        self.variable_scope_name = variable_scope_name
        self.n_hidden_neurons = n_hidden_neurons
        self.n_hidden_layers = n_hidden_layers
        self.n_out_dim = n_out_dim
        self.activation_fn = activation_fn
        self.reuse = reuse
        self.output_tensor = None
        self.var_list = None
        self.define_model()

    def define_model(self):
        with tf.variable_scope(self.variable_scope_name, reuse=self.reuse) as vs:
            x = self.input
            with slim.arg_scope([slim.fully_connected],
                                num_outputs=self.n_hidden_neurons,
                                activation_fn=self.activation_fn):
                for i in range(self.n_hidden_layers):
                    x = slim.fully_connected(inputs=x)
            self.output_tensor = slim.fully_connected(inputs=x,
                                                      num_outputs=self.n_out_dim,
                                                      activation_fn=None)

            self.var_list = tf.contrib.framework.get_variables(vs)


class Generator(Model):
    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str='Generator',
                 n_hidden_neurons: int=512,
                 n_hidden_layers: int=3,
                 n_out_dim: int=2,
                 activation_fn: Callable=__leaky_relu__,
                 reuse: bool=False):
        super(Generator, self).__init__(input_tensor,
                                        variable_scope_name,
                                        n_hidden_neurons,
                                        n_hidden_layers,
                                        n_out_dim,
                                        activation_fn,
                                        reuse)


class Critic(Model):
    def __init__(self,
                 input_tensor: tf.Variable,
                 variable_scope_name: str='Critic',
                 n_hidden_neurons: int=512,
                 n_hidden_layers: int=3,
                 n_out_dim: int=1,
                 activation_fn: Callable=__leaky_relu__,
                 reuse: bool=False):
        super(Critic, self).__init__(input_tensor,
                                     variable_scope_name,
                                     n_hidden_neurons,
                                     n_hidden_layers,
                                     n_out_dim,
                                     activation_fn,
                                     reuse)
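
# Usage sketch (explanatory comment, not in the original file; assumes
# TensorFlow 1.x with tf.contrib available). 'real_pts' below is a
# hypothetical placeholder for a batch of 2-D data points:
#
#     z = tf.random_normal([256, 2])
#     g = Generator(z)                    # g.output_tensor has shape (256, 2)
#     c = Critic(g.output_tensor)         # c.output_tensor has shape (256, 1)
#     c2 = Critic(real_pts, reuse=True)   # second copy sharing the critic's weights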
--------------------------------------------------------------------------------
/reg_losses.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from model import Critic
slim = tf.contrib.slim


def get_perturbation_samples(training_samples, generated_samples, per_type,
                             dragan_parameter_C):
    x_hat = None

    if per_type == 'no_purf':
        x_hat = training_samples

    # Gulrajani, Ishaan, et al. "Improved Training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).
    elif per_type == 'wgan_gp':
        epsilon = tf.random_uniform(
            shape=[tf.shape(training_samples)[0], 1],
            minval=0.,
            maxval=1.
        )
        x_hat = epsilon * training_samples + (1 - epsilon) * generated_samples

    # Kodali, Naveen, et al. "How to Train Your DRAGAN." arXiv preprint arXiv:1705.07215 (2017).
    elif per_type == 'dragan_only_training':
        u = tf.random_uniform(
            shape=[tf.shape(training_samples)[0], 1],
            minval=0.,
            maxval=1.
        )
        _, batch_std = tf.nn.moments(tf.reshape(training_samples, [-1]), axes=[0])

        delta = dragan_parameter_C * batch_std * u

        alpha = tf.random_uniform(
            shape=[tf.shape(training_samples)[0], 1],
            minval=0.,
            maxval=1.
        )

        x_hat = training_samples + (1 - alpha) * delta

    elif per_type == 'dragan_both':
        samples = tf.concat([training_samples, generated_samples], axis=0)

        u = tf.random_uniform(
            shape=[tf.shape(samples)[0], 1],
            minval=0.,
            maxval=1.
        )
        _, batch_std = tf.nn.moments(tf.reshape(samples, [-1]), axes=[0])

        delta = dragan_parameter_C * batch_std * u

        alpha = tf.random_uniform(
            shape=[tf.shape(samples)[0], 1],
            minval=0.,
            maxval=1.
        )

        x_hat = samples + (1 - alpha) * delta

    else:
        raise NotImplementedError('arg per_type is not injected correctly.')

    return x_hat


def get_regularization_term(training_samples, generated_samples,
                            reg_type, per_type,
                            critic_variable_scope_name,
                            dragan_parameter_C=0.5):
    x_hat = get_perturbation_samples(training_samples, generated_samples, per_type,
                                     dragan_parameter_C)

    critic = Critic(x_hat, variable_scope_name=critic_variable_scope_name, reuse=True)
    gradients = tf.gradients(critic.output_tensor, [x_hat])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))

    gradient_penalty = None

    # Gulrajani, Ishaan, et al. "Improved Training of Wasserstein GANs." arXiv preprint arXiv:1704.00028 (2017).
    if reg_type == 'GP':
        gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)

    # Henning Petzka, Asja Fischer, and Denis Lukovnicov. "On the regularization of Wasserstein GANs."
    # arXiv preprint arXiv:1709.08894 (2017).
    elif reg_type == 'LP':
        gradient_penalty = tf.reduce_mean((tf.maximum(0., slopes - 1)) ** 2)

    else:
        raise NotImplementedError('arg reg_type is not injected correctly.')

    return gradient_penalty, x_hat
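
# Illustration of the two penalties on example gradient norms (plain NumPy,
# independent of the graph code above; added for exposition):
#
#     import numpy as np
#     slopes = np.array([0.5, 1.0, 1.5])
#     gp = (slopes - 1.0) ** 2                 # -> [0.25, 0.0, 0.25]
#     lp = np.maximum(0.0, slopes - 1.0) ** 2  # -> [0.0,  0.0, 0.25]
#
# GP pushes gradient norms toward exactly 1 from both sides, while LP only
# penalizes norms above 1, which is the one-sided relaxation proposed by
# Petzka et al. (arXiv:1709.08894).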
--------------------------------------------------------------------------------
/run_experiments.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Fig6-Top
python3 trainer.py --Regularization_type GP --Purturbation_type wgan_gp --Lambda 10.0
# Fig6-Middle, Fig13, Fig15
python3 trainer.py --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0
# Fig6-Bottom
python3 trainer.py --Regularization_type LP --Purturbation_type wgan_gp --Lambda 10.0
# Not shown in Fig6, but mentioned in the paper
python3 trainer.py --Regularization_type LP --Purturbation_type wgan_gp --Lambda 1.0
python3 trainer.py --Regularization_type LP --Purturbation_type wgan_gp --Lambda 100.0

# Fig11-Top
python3 trainer.py --dataset GeneratorGaussians8 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 10.0
# Fig11-Middle
python3 trainer.py --dataset GeneratorGaussians8 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0
# Fig11-Bottom
python3 trainer.py --dataset GeneratorGaussians8 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 10.0
# Not shown in Fig11, but mentioned in the paper
python3 trainer.py --dataset GeneratorGaussians8 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 1.0
python3 trainer.py --dataset GeneratorGaussians8 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 100.0

# Fig12-Top
python3 trainer.py --dataset GeneratorGaussians25 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 10.0
# Fig12-Middle
python3 trainer.py --dataset GeneratorGaussians25 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0
# Fig12-Bottom
python3 trainer.py --dataset GeneratorGaussians25 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 10.0
# Not shown in Fig12, but mentioned in the paper
python3 trainer.py --dataset GeneratorGaussians25 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 1.0
python3 trainer.py --dataset GeneratorGaussians25 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 100.0

# Fig7-Top, Fig9-Top
python3 trainer.py --Regularization_type GP --Purturbation_type wgan_gp --Lambda 5.0
# Fig7-Bottom, Fig9-Bottom
python3 trainer.py --Regularization_type LP --Purturbation_type wgan_gp --Lambda 5.0

# Fig8-Top
python3 trainer.py --Regularization_type GP --Purturbation_type dragan_only_training --Lambda 5.0
# Fig8-Middle, Fig14-Top
python3 trainer.py --Regularization_type GP --Purturbation_type dragan_both --Lambda 5.0
# Fig8-Bottom, Fig14-Bottom
python3 trainer.py --Regularization_type LP --Purturbation_type dragan_both --Lambda 5.0

# Runs with EMD recording enabled
python3 trainer.py --n_epoch 2000 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 1.0 --emd_records True
python3 trainer.py --n_epoch 2000 --Regularization_type GP --Purturbation_type wgan_gp --Lambda 5.0 --emd_records True
python3 trainer.py --n_epoch 2000 --Regularization_type LP --Purturbation_type wgan_gp --Lambda 5.0 --emd_records True
python3 trainer.py --n_epoch 2000 --Regularization_type GP --Purturbation_type dragan_both --Lambda 5.0 --emd_records True
python3 trainer.py --n_epoch 2000 --Regularization_type LP --Purturbation_type dragan_both --Lambda 5.0 --emd_records True
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance

import data_generator
from model import Generator, Critic
from reg_losses import get_regularization_term

slim = tf.contrib.slim

__eval_step_list__ = [10, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 15000, 20000]

flags = tf.app.flags
flags.DEFINE_integer("n_epoch", 20000, "Number of training epochs (generator updates) [20000]")
flags.DEFINE_integer("n_batch_size", 256, "Batch size to train [256]")
flags.DEFINE_integer("latent_dimensionality", 2, "Dimensionality of the latent variables [2]")

"""
During training, 10 critic updates are performed for every generator update,
except for the first 25 generator updates,
where the critic is updated 100 times for each generator update
in order to get closer to the optimal critic at the beginning of training.
"""

flags.DEFINE_integer("begining_init_step", 25, "[25]")
flags.DEFINE_integer("n_c_iters_under_begining_init_step", 100, "[100]")
flags.DEFINE_integer("n_c_iters_over_begining_init_step", 10, "[10]")
flags.DEFINE_integer("interval_record_earth_mover", 10, "[10]")

flags.DEFINE_float("learning_rate", 5e-5, "Learning rate of optimizer [5e-5]")
flags.DEFINE_float("Lambda", 5., "Weight for the critic's regularization term [5]")
flags.DEFINE_string("Regularization_type", "LP", "[no_reg, no_reg_but_clipping, LP, GP]")
flags.DEFINE_string("Purturbation_type", "dragan_only_training",
                    "[no_purf, wgan_gp, dragan_only_training, dragan_both]")
flags.DEFINE_string("dataset", 'GeneratorSwissRoll',
                    "Which dataset is used? [GeneratorGaussians8, GeneratorGaussians25, GeneratorSwissRoll]")

flags.DEFINE_string("critic_variable_scope_name", "Critic", "[Critic]")
flags.DEFINE_string("generator_variable_scope_name", "Generator", "[Generator]")

flags.DEFINE_bool("emd_records", False, "Whether EMD is recorded (it takes some time...) [True, False]")
FLAGS = flags.FLAGS


class Trainer(object):
    def __init__(self):
        self.dataset_generator = None
        self.real_input = None

        self.z = None

        self.generator = None
        self.critic_x = None
        self.critic_gz = None

        self.g_loss = None
        self.c_negative_loss = None
        self.c_regularization_loss = None
        self.c_loss = None
        self.c_clipping = None
        self.x_hat = None

        self.ckpt_dir = None
        self.summary_writer = None
        self.c_summary_op = None
        self.g_summary_op = None
        self.emd_placeholder = None
        self.emd_summary = None

        self.saver = None

        self.step = None

        self.sess = None
        self.step_inc = None
        self.g_opt = None
        self.c_opt = None

        self.g_update_fetch_dict = None
        self.c_update_fetch_dict = None
        self.c_feed_dict = None

        self.coord = None
        self.threads = None

        self.define_dataset()
        self.define_latent()
        self.define_model()
        self.define_loss()
        self.define_optim()
        self.define_writer_and_summary()
        self.define_saver()
        self.initialize_session_and_etc()
        self.define_feed_and_fetch()

    def define_dataset(self):
        self.dataset_generator = iter(getattr(data_generator, FLAGS.dataset)(FLAGS.n_batch_size))
        self.real_input = tf.placeholder(tf.float32, shape=(None, 2))

    def define_latent(self):
        self.z = tf.random_normal([FLAGS.n_batch_size, FLAGS.latent_dimensionality], mean=0.0, stddev=1.0, name='z')

    def define_model(self):
        self.generator = Generator(self.z,
                                   variable_scope_name=FLAGS.generator_variable_scope_name)
        self.critic_x = Critic(self.real_input,
                               variable_scope_name=FLAGS.critic_variable_scope_name)
        self.critic_gz = Critic(self.generator.output_tensor,
                                variable_scope_name=FLAGS.critic_variable_scope_name,
                                reuse=True)

    def define_loss(self):
        self.g_loss = -tf.reduce_mean(self.critic_gz.output_tensor)
        self.c_negative_loss = -self.g_loss - tf.reduce_mean(self.critic_x.output_tensor)
        if FLAGS.Regularization_type == 'no_reg_but_clipping' or \
                FLAGS.Regularization_type == 'no_reg':
            self.c_regularization_loss = tf.Variable(0., trainable=False)
        else:
            self.c_regularization_loss, self.x_hat = get_regularization_term(
                training_samples=self.real_input,
                generated_samples=self.generator.output_tensor,
                reg_type=FLAGS.Regularization_type,
                per_type=FLAGS.Purturbation_type,
                critic_variable_scope_name=FLAGS.critic_variable_scope_name
            )

        self.c_loss = self.c_negative_loss + FLAGS.Lambda * self.c_regularization_loss
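
    # Sign conventions above (explanatory comment, not in the original source):
    # the critic f maximizes E[f(x)] - E[f(G(z))], so its minimization target is
    #     c_loss = E[f(G(z))] - E[f(x)] + Lambda * gradient_penalty,
    # while the generator minimizes g_loss = -E[f(G(z))].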
    def define_optim(self):
        self.step = tf.Variable(0, name='step', trainable=False)
        self.step_inc = tf.assign(self.step, self.step + 1)

        optimizer = tf.train.RMSPropOptimizer(FLAGS.learning_rate)

        self.g_opt = optimizer.minimize(self.g_loss, var_list=self.generator.var_list)
        self.c_opt = optimizer.minimize(self.c_loss, var_list=self.critic_x.var_list)

        with tf.control_dependencies([self.c_opt]):
            if FLAGS.Regularization_type == 'no_reg_but_clipping':
                self.c_clipping = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in self.critic_x.var_list]
            else:
                self.c_clipping = [tf.no_op()]

    def define_writer_and_summary(self):
        self.ckpt_dir = ''.join(['ckpts/',
                                 FLAGS.dataset+'_',
                                 FLAGS.Regularization_type+'_',
                                 FLAGS.Purturbation_type+'_',
                                 str(FLAGS.Lambda)+'_',
                                 str(FLAGS.emd_records),
                                 '/'])

        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)

        self.summary_writer = tf.summary.FileWriter(self.ckpt_dir)

        self.c_summary_op = tf.summary.merge([
            tf.summary.scalar('loss/c', self.c_loss),
            tf.summary.scalar('loss/c_negative_loss', self.c_negative_loss),
            tf.summary.scalar('loss/c_regularization_loss', self.c_regularization_loss)
        ])
        self.g_summary_op = tf.summary.merge([
            tf.summary.scalar('loss/g', self.g_loss)
        ])

        self.emd_placeholder = tf.placeholder(tf.float32, shape=())
        self.emd_summary = tf.summary.scalar('EMD', self.emd_placeholder)

    def define_saver(self):
        self.saver = tf.train.Saver()

    def initialize_session_and_etc(self):
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)
        self.sess = tf.Session(config=sess_config)

        self.sess.run(tf.local_variables_initializer())
        self.sess.run(tf.global_variables_initializer())

        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    def define_feed_and_fetch(self):
        self.g_update_fetch_dict = {
            "opt": self.g_opt,
            "z": self.z,
            "G_z": self.generator.output_tensor,
            "loss": self.g_loss,
            'summary': self.g_summary_op,
            "step": self.step
        }

        self.c_update_fetch_dict = {
            'gradient_clipping': self.c_clipping,
            "x": self.real_input,
            "G_z": self.generator.output_tensor,
            "loss": self.c_loss,
            "negative_loss": self.c_negative_loss,
            "regularization_loss": self.c_regularization_loss,
            'summary': self.c_summary_op,
            "step": self.step
        }

        self.c_feed_dict = {
            self.real_input: None
        }

    def draw_level_sets(self, step,
                        x_min=-2.5, x_max=2.5,
                        y_min=-2.5, y_max=2.5,
                        n_batch=2):
        x = np.linspace(x_min, x_max, 200)
        y = np.linspace(y_min, y_max, 200)

        x, y = np.meshgrid(x, y)
        grid_pts = np.stack([x.flatten(), y.flatten()], axis=1)

        z = self.sess.run(self.critic_x.output_tensor, feed_dict={self.real_input: grid_pts})
        z = np.reshape(z, (200, 200))

        plt.contour(x, y, z, 30, cmap='copper')

        real = list()
        fake = list()
        perturbed = list()
        for i in range(n_batch):
            __real__ = next(self.dataset_generator)

            __fake__, __perturbed__ = \
                self.sess.run([self.generator.output_tensor, self.x_hat], feed_dict={self.real_input: __real__})
            real.append(__real__)
            fake.append(__fake__)
            perturbed.append(__perturbed__)

        real = np.vstack(real)
        fake = np.vstack(fake)
        perturbed = np.vstack(perturbed)

        plt.scatter(perturbed[:, 0], perturbed[:, 1], c='r', s=1)
        plt.scatter(real[:, 0], real[:, 1], c='y', s=2)
        plt.scatter(fake[:, 0], fake[:, 1], c='g', s=2)

        plt.savefig(self.ckpt_dir+str(step)+'.png')
        plt.clf()

    def estimate_earth_mover_distance(self, step, n_batch=2):
        real = list()
        fake = list()

        for i in range(n_batch):
            __real__ = next(self.dataset_generator)
            __fake__ = self.sess.run(self.generator.output_tensor)

            real.append(__real__)
            fake.append(__fake__)

        real = np.vstack(real)
        fake = np.vstack(fake)

        cost_matrix = distance.cdist(fake, real, 'euclidean')

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        linear_sum = cost_matrix[row_ind, col_ind].sum()
        emd = linear_sum/real.shape[0]

        emd_fetch = self.sess.run(self.emd_summary, feed_dict={self.emd_placeholder: emd})
        self.summary_writer.add_summary(emd_fetch, step)
        self.summary_writer.flush()
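
    # Note on the estimate above (explanatory comment, not in the original
    # source): matching the fake batch to the real batch with the Hungarian
    # algorithm (scipy.optimize.linear_sum_assignment) and averaging the
    # matched Euclidean distances yields an empirical earth mover's distance
    # between the two samples. The assignment is O(n^3) in the number of
    # points, which is why recording it is optional (see the emd_records flag).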
    def train(self):
        try:
            c_fetch_dict = None
            print("[.] Learning Start...")
            step = 0
            while not self.coord.should_stop():
                if step > FLAGS.n_epoch:
                    break

                self.c_feed_dict[self.real_input] = next(self.dataset_generator)
                step = self.sess.run(self.step)

                n_c_iters = (FLAGS.n_c_iters_under_begining_init_step
                             if step < FLAGS.begining_init_step
                             else FLAGS.n_c_iters_over_begining_init_step)
                for _ in range(n_c_iters):
                    c_fetch_dict = self.sess.run(self.c_update_fetch_dict,
                                                 feed_dict=self.c_feed_dict)

                g_fetch_dict = self.sess.run(self.g_update_fetch_dict)

                self.summary_writer.add_summary(c_fetch_dict["summary"], c_fetch_dict["step"])
                self.summary_writer.add_summary(g_fetch_dict["summary"], g_fetch_dict["step"])
                self.summary_writer.flush()

                if step in __eval_step_list__:
                    self.draw_level_sets(step)

                if FLAGS.emd_records and step % FLAGS.interval_record_earth_mover == 0 and step != 0:
                    self.estimate_earth_mover_distance(step)

                self.sess.run(self.step_inc)

        except KeyboardInterrupt:
            print("Interrupted")
            self.coord.request_stop()
        finally:
            self.saver.save(self.sess, self.ckpt_dir)
            print('Stop')
            self.coord.request_stop()
            self.coord.join(self.threads)


if __name__ == '__main__':
    trainer = Trainer()
    trainer.train()
    trainer.sess.close()
--------------------------------------------------------------------------------