├── src
│   ├── __init__.py
│   ├── helper
│   │   ├── __init__.py
│   │   ├── visualizer.py
│   │   ├── generator.py
│   │   └── trainer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── ops.py
│   │   ├── base.py
│   │   ├── distribution.py
│   │   ├── vae.py
│   │   ├── modules.py
│   │   ├── layers.py
│   │   └── aae.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── dataflow.py
│   │   └── viz.py
│   └── dataflow
│       ├── __init__.py
│       └── mnist.py
├── experiment
│   ├── __init__.py
│   ├── vae_mnist.py
│   └── aae_mnist.py
├── figs
│   ├── gmm.png
│   ├── s_1.png
│   ├── s_2.png
│   ├── s_3.png
│   ├── s_4.png
│   ├── gaussian.png
│   ├── gmm_latent.png
│   ├── semi_gan.png
│   ├── semi_gan_1.png
│   ├── semi_gan_2.png
│   ├── semi_train.png
│   ├── semi_valid.png
│   ├── gmm_manifold.png
│   ├── gaussian_latent.png
│   ├── gmm_10k_label.png
│   ├── gmm_10k_label_0.png
│   ├── gmm_10k_label_1.png
│   ├── gmm_10k_label_2.png
│   ├── gmm_10k_label_9.png
│   ├── gmm_full_label.png
│   ├── supervise_code2.png
│   ├── supervise_code5.png
│   ├── gaussian_diagonal.png
│   ├── gaussian_manifold.png
│   ├── gmm_full_label_0.png
│   ├── gmm_full_label_1.png
│   ├── gmm_full_label_2.png
│   ├── gmm_full_label_9.png
│   └── supervise_code10.png
├── LICENSE
├── .gitignore
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/helper/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dataflow/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figs/gmm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm.png
--------------------------------------------------------------------------------
/figs/s_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/s_1.png
--------------------------------------------------------------------------------
/figs/s_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/s_2.png
--------------------------------------------------------------------------------
/figs/s_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/s_3.png
--------------------------------------------------------------------------------
/figs/s_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/s_4.png
--------------------------------------------------------------------------------
/figs/gaussian.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gaussian.png
--------------------------------------------------------------------------------
/figs/gmm_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_latent.png
--------------------------------------------------------------------------------
/figs/semi_gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/semi_gan.png
--------------------------------------------------------------------------------
/figs/semi_gan_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/semi_gan_1.png
--------------------------------------------------------------------------------
/figs/semi_gan_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/semi_gan_2.png
--------------------------------------------------------------------------------
/figs/semi_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/semi_train.png
--------------------------------------------------------------------------------
/figs/semi_valid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/semi_valid.png
--------------------------------------------------------------------------------
/figs/gmm_manifold.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_manifold.png
--------------------------------------------------------------------------------
/figs/gaussian_latent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gaussian_latent.png
--------------------------------------------------------------------------------
/figs/gmm_10k_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_10k_label.png
--------------------------------------------------------------------------------
/figs/gmm_10k_label_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_10k_label_0.png
--------------------------------------------------------------------------------
/figs/gmm_10k_label_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_10k_label_1.png
--------------------------------------------------------------------------------
/figs/gmm_10k_label_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_10k_label_2.png
--------------------------------------------------------------------------------
/figs/gmm_10k_label_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_10k_label_9.png
--------------------------------------------------------------------------------
/figs/gmm_full_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_full_label.png
--------------------------------------------------------------------------------
/figs/supervise_code2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/supervise_code2.png
--------------------------------------------------------------------------------
/figs/supervise_code5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/supervise_code5.png
--------------------------------------------------------------------------------
/figs/gaussian_diagonal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gaussian_diagonal.png
--------------------------------------------------------------------------------
/figs/gaussian_manifold.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gaussian_manifold.png
--------------------------------------------------------------------------------
/figs/gmm_full_label_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_full_label_0.png
--------------------------------------------------------------------------------
/figs/gmm_full_label_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_full_label_1.png
--------------------------------------------------------------------------------
/figs/gmm_full_label_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_full_label_2.png
--------------------------------------------------------------------------------
/figs/gmm_full_label_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/gmm_full_label_9.png
--------------------------------------------------------------------------------
/figs/supervise_code10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/conan7882/adversarial-autoencoders/HEAD/figs/supervise_code10.png
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: utils.py 4 | # Author:
Qian Ge 5 | 6 | 7 | def make_list(inputs): 8 | if not isinstance(inputs, list): 9 | return [inputs] 10 | else: 11 | return inputs 12 | 13 | def assert_len(check_list): 14 | for ele in check_list[1:]: 15 | assert len(check_list[0]) == len(ele) -------------------------------------------------------------------------------- /src/models/ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: ops.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | import tensorflow_probability as tfp 8 | 9 | 10 | def tf_sample_standard_diag_guassian(b_size, n_code): 11 | mean_list = [0.0 for i in range(0, n_code)] 12 | std_list = [1.0 for i in range(0, n_code)] 13 | mvn = tfp.distributions.MultivariateNormalDiag( 14 | loc=mean_list, 15 | scale_diag=std_list) 16 | samples = mvn.sample(sample_shape=(b_size,), seed=None, name='sample') 17 | return samples 18 | 19 | def tf_sample_diag_guassian(mean, std, b_size, n_code): 20 | mean_list = [0.0 for i in range(0, n_code)] 21 | std_list = [1.0 for i in range(0, n_code)] 22 | mvn = tfp.distributions.MultivariateNormalDiag( 23 | loc=mean_list, 24 | scale_diag=std_list) 25 | samples = mvn.sample(sample_shape=(b_size,), seed=None, name='sample') 26 | samples = mean + tf.multiply(std, samples) 27 | 28 | return samples -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Qian Ge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/utils/dataflow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: dataflow.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | import scipy.misc 8 | import numpy as np 9 | from datetime import datetime 10 | 11 | 12 | _RNG_SEED = None 13 | 14 | def get_rng(obj=None): 15 | """ 16 | This function is copied from `tensorpack 17 | `__. 18 | Get a good RNG seeded with time, pid and the object. 19 | Args: 20 | obj: some object to use to generate random seed. 21 | Returns: 22 | np.random.RandomState: the RNG. 
23 | """ 24 | seed = (id(obj) + os.getpid() + 25 | int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 26 | if _RNG_SEED is not None: 27 | seed = _RNG_SEED 28 | return np.random.RandomState(seed) 29 | 30 | 31 | def get_file_list(file_dir, file_ext, sub_name=None): 32 | re_list = [] 33 | 34 | if sub_name is None: 35 | return np.array([os.path.join(root, name) 36 | for root, dirs, files in os.walk(file_dir) 37 | for name in sorted(files) if name.endswith(file_ext)]) 38 | else: 39 | return np.array([os.path.join(root, name) 40 | for root, dirs, files in os.walk(file_dir) 41 | for name in sorted(files) if name.endswith(file_ext) and sub_name in name]) 42 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: base.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | from abc import abstractmethod 8 | 9 | 10 | class BaseModel(object): 11 | """ Model with single loss and single optimizer """ 12 | 13 | def set_is_training(self, is_training=True): 14 | self.is_training = is_training 15 | 16 | def get_loss(self): 17 | try: 18 | return self._loss 19 | except AttributeError: 20 | self._loss = self._get_loss() 21 | return self._loss 22 | 23 | def _get_loss(self): 24 | raise NotImplementedError() 25 | 26 | def get_optimizer(self): 27 | try: 28 | return self.optimizer 29 | except AttributeError: 30 | self.optimizer = self._get_optimizer() 31 | return self.optimizer 32 | 33 | def _get_optimizer(self): 34 | raise NotImplementedError() 35 | 36 | def get_train_op(self): 37 | with tf.name_scope('train'): 38 | opt = self.get_optimizer() 39 | loss = self.get_loss() 40 | var_list = tf.trainable_variables() 41 | grads = tf.gradients(loss, var_list) 42 | # [tf.summary.histogram('gradient/' + var.name, grad, 43 | # collections=['train']) for grad, var in zip(grads, var_list)] 44 | return opt.apply_gradients(zip(grads, var_list)) 45 | 46 | 47 | -------------------------------------------------------------------------------- /src/utils/viz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: viz.py 4 | # Author: Qian Ge 5 | 6 | import numpy as np 7 | import imageio 8 | 9 | 10 | def viz_batch_im(batch_im, grid_size, save_path, 11 | gap=0, gap_color=0, shuffle=False): 12 | 13 | batch_im = np.array(batch_im) 14 | if len(batch_im.shape) == 4: 15 | n_channel = batch_im.shape[-1] 16 | elif len(batch_im.shape) == 3: 17 | n_channel = 1 18 | batch_im = np.expand_dims(batch_im, axis=-1) 19 | assert len(grid_size) == 2 20 | 21 | h = batch_im.shape[1] 22 | w = batch_im.shape[2] 23 | 24 | merge_im = np.zeros((h * grid_size[0] + (grid_size[0] + 1) * gap, 25 | w * grid_size[1] + (grid_size[1] + 1) * gap, 26 | n_channel)) + gap_color 27 | 28 | n_viz_filter = min(batch_im.shape[0], grid_size[0] * grid_size[1]) 29 | if shuffle == True: 30 | pick_id = np.random.permutation(batch_im.shape[0]) 31 | else: 32 | pick_id = range(0, batch_im.shape[0]) 33 | for idx in range(0, n_viz_filter): 34 | i = idx % grid_size[1] 35 | j = idx // grid_size[1] 36 | cur_filter = batch_im[pick_id[idx], :, :, :] 37 | merge_im[j * (h + gap) + gap: j * (h + gap) + h + gap, 38 | i * (w + gap) + gap: i * (w + gap) + w + gap, :]\ 39 | = (cur_filter) 40 | imageio.imwrite(save_path, np.squeeze(merge_im)) 41 | 42 | 43 | 
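if __name__ == '__main__':
    # Minimal usage sketch for viz_batch_im: tile 16 random grayscale
    # images into a 4x4 montage with 2-pixel white gaps. The shapes, the
    # output file name and the gap settings are illustrative assumptions,
    # not values fixed by this repo.
    demo_batch = np.random.randint(0, 256, size=(16, 28, 28, 1), dtype=np.uint8)
    viz_batch_im(demo_batch, grid_size=[4, 4], save_path='montage_demo.png',
                 gap=2, gap_color=255, shuffle=False)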
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /src/helper/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: visualizer.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | import numpy as np 8 | import tensorflow as tf 9 | import matplotlib.patches as mpatches 10 | 11 | # import src.utils.viz as viz 12 | 13 | class Visualizer(object): 14 | def __init__(self, model, save_path=None): 15 | 16 | self._save_path = save_path 17 | self._model = model 18 | self._latent_op = model.layers['z'] 19 | 20 | def viz_2Dlatent_variable(self, sess, dataflow, batch_size=128, file_id=None): 21 | """ 22 | modify from: 23 | https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py#L45 24 | """ 25 | import matplotlib as mpl 26 | mpl.use('Agg') 27 | import matplotlib.pyplot as plt 28 | self._model.set_is_training(False) 29 | 30 | dataflow.setup(epoch_val=0, batch_size=batch_size) 31 | latent_var_list = [] 32 | label_list = [] 33 | 34 | while dataflow.epochs_completed == 0: 35 | batch_data = dataflow.next_batch_dict() 36 | im = batch_data['im'] 37 | labels = batch_data['label'] 38 | latent_var = sess.run( 39 | self._latent_op, 40 | feed_dict={self._model.encoder_in: im, 41 | self._model.keep_prob: 1.}) 42 | latent_var_list.extend(latent_var) 43 | # try: 44 | # latent_var_list.extend(latent_var[:, pick_dim]) 45 | # except UnboundLocalError: 46 | # pick_dim = np.random.choice(len(latent_var[0]), 2, replace=False) 47 | # pick_dim = sorted(pick_dim) 48 | # latent_var_list.extend(latent_var[:, pick_dim]) 49 | 50 | 
label_list.extend(labels) 51 | # print(latent_var) 52 | 53 | xs, ys = np.array(latent_var_list).T 54 | 55 | plt.figure() 56 | # plt.title("round {}: {} in latent space".format(model.step, title)) 57 | kwargs = {'alpha': 0.8} 58 | 59 | classes = set(label_list) 60 | if classes: 61 | colormap = plt.cm.rainbow(np.linspace(0, 1, len(classes))) 62 | kwargs['c'] = [colormap[i] for i in label_list] 63 | 64 | # make room for legend 65 | ax = plt.subplot(111, aspect='equal') 66 | box = ax.get_position() 67 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 68 | handles = [mpatches.Circle((0,0), label=class_, color=colormap[i]) 69 | for i, class_ in enumerate(classes)] 70 | ax.legend(handles=handles, shadow=True, bbox_to_anchor=(1.05, 0.45), 71 | fancybox=True, loc='center left') 72 | 73 | plt.scatter(xs, ys, s=2, **kwargs) 74 | 75 | ax.set_xlim([-3.5, 3.5]) 76 | ax.set_ylim([-3.5, 3.5]) 77 | 78 | if file_id is not None: 79 | fig_save_path = os.path.join(self._save_path, 'latent_{}.png'.format(file_id)) 80 | else: 81 | fig_save_path = os.path.join(self._save_path, 'latent.png') 82 | plt.savefig(fig_save_path, bbox_inches="tight") 83 | 84 | -------------------------------------------------------------------------------- /src/models/distribution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: distribution.py 4 | # Author: Qian Ge 5 | 6 | import numpy as np 7 | from math import sin, cos 8 | 9 | 10 | def interpolate(plot_size=20, interpolate_range=[-3, 3, -3, 3]): 11 | assert len(interpolate_range) == 4 12 | nx = plot_size 13 | ny = plot_size 14 | min_x = interpolate_range[0] 15 | max_x = interpolate_range[1] 16 | min_y = interpolate_range[2] 17 | max_y = interpolate_range[3] 18 | 19 | zs = np.rollaxis(np.mgrid[min_x: max_x: nx*1j, max_y:min_y: ny*1j], 0, 3) 20 | zs = zs.transpose(1, 0, 2) 21 | return np.reshape(zs, (plot_size*plot_size, 2)) 22 | 23 | def interpolate_gm(plot_size=20, interpolate_range=[-1., 1., -0.2, 0.2], 24 | mode_id=0, n_mode=10): 25 | n_samples = plot_size * plot_size 26 | def sample(x, y, mode_id, n_mode): 27 | shift = 1.4 28 | r = 2.0 * np.pi / float(n_mode) * float(mode_id) 29 | new_x = x * cos(r) - y * sin(r) 30 | new_y = x * sin(r) + y * cos(r) 31 | new_x += shift * cos(r) 32 | new_y += shift * sin(r) 33 | return np.array([new_x, new_y]).reshape((2,)) 34 | 35 | interp_grid = interpolate(plot_size=plot_size, interpolate_range=interpolate_range) 36 | x = interp_grid[:, 0] 37 | y = interp_grid[:, 1] 38 | 39 | z = np.empty((n_samples, 2), dtype=np.float32) 40 | for i in range(n_samples): 41 | z[i, :2] = sample(x[i], y[i], mode_id, n_mode) 42 | return z 43 | 44 | def gaussian(batch_size, n_dim, mean=0, var=1.): 45 | z = np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32) 46 | return z 47 | 48 | def diagonal_gaussian(batch_size, n_dim, mean=0, var=1.): 49 | cov_mat = np.diag([var for i in range(n_dim)]) 50 | mean_vec = [mean for i in range(n_dim)] 51 | z = np.random.multivariate_normal( 52 | mean_vec, cov_mat, (batch_size,)).astype(np.float32) 53 | return z 54 | 55 | def gaussian_mixture(batch_size, n_dim=2, n_labels=10, 56 | x_var=0.5, y_var=0.1, label_indices=None): 57 | # borrow from: 58 | # https://github.com/nicklhy/AdversarialAutoEncoder/blob/master/data_factory.py#L40 59 | if n_dim % 2 != 0: 60 | raise Exception("n_dim must be a multiple of 2.") 61 | 62 | def sample(x, y, label, n_labels): 63 | shift = 1.4 64 | if label >= n_labels: 65 | label =
np.random.randint(0, n_labels) 66 | r = 2.0 * np.pi / float(n_labels) * float(label) 67 | new_x = x * cos(r) - y * sin(r) 68 | new_y = x * sin(r) + y * cos(r) 69 | new_x += shift * cos(r) 70 | new_y += shift * sin(r) 71 | return np.array([new_x, new_y]).reshape((2,)) 72 | 73 | x = np.random.normal(0, x_var, (batch_size, n_dim // 2)) 74 | y = np.random.normal(0, y_var, (batch_size, n_dim // 2)) 75 | z = np.empty((batch_size, n_dim), dtype=np.float32) 76 | for batch in range(batch_size): 77 | for zi in range(n_dim // 2): 78 | if label_indices is not None: 79 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels) 80 | else: 81 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels) 82 | 83 | return z 84 | -------------------------------------------------------------------------------- /src/helper/generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: generator.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | import math 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | import src.utils.viz as viz 12 | import src.models.distribution as distribution 13 | 14 | 15 | class Generator(object): 16 | def __init__(self, generate_model, distr_type='gaussian', 17 | n_labels=None, use_label=False, save_path=None): 18 | 19 | self._save_path = save_path 20 | self._g_model = generate_model 21 | self._generate_op = generate_model.layers['generate'] 22 | 23 | self._dist = distr_type 24 | self._n_labels = n_labels 25 | self._use_label = use_label 26 | if self._use_label: 27 | assert self._n_labels is not None 28 | 29 | def sample_style(self, sess, plot_size, n_sample=10, file_id=None): 30 | # epochs_completed, batch_size = dataflow.epochs_completed, dataflow.batch_size 31 | # dataflow.setup(epoch_val=0, batch_size=n_sample) 32 | 33 | # batch_data = dataflow.next_batch_dict() 34 | # latent_var = sess.run( 35 | # self._latent_op, 36 | # feed_dict={self._g_model.encoder_in: batch_data['im'], 37 | # self._g_model.keep_prob: 1.}) 38 | 39 | # label = [] 40 | # for i in range(n_labels): 41 | # label.extend([i for k in range(n_sample)]) 42 | # code = np.tile(latent_var, [n_labels, 1]) # [n_class*10, n_code] 43 | # print(batch_data['label']) 44 | gen_im = sess.run(self._g_model.layers['generate'], 45 | feed_dict={ 46 | # self._g_model.image: batch_data['im'], 47 | # self._g_model.label: label, 48 | # self._g_model.keep_prob: 1. 
49 | }) 50 | 51 | if self._save_path: 52 | if file_id is not None: 53 | im_save_path = os.path.join( 54 | self._save_path, 'sample_style_{}.png'.format(file_id)) 55 | else: 56 | im_save_path = os.path.join( 57 | self._save_path, 'sample_style.png') 58 | 59 | n_sample = len(gen_im) 60 | plot_size = int(min(plot_size, math.sqrt(n_sample))) 61 | viz.viz_batch_im(batch_im=gen_im, grid_size=[plot_size, plot_size], 62 | save_path=im_save_path, gap=0, gap_color=0, 63 | shuffle=False) 64 | 65 | # dataflow.setup(epoch_val=epochs_completed, batch_size=batch_size) 66 | 67 | def generate_samples(self, sess, plot_size, manifold=False, file_id=None): 68 | # if z is None: 69 | # gen_im = sess.run(self._generate_op) 70 | # else: 71 | n_samples = plot_size * plot_size 72 | 73 | label_indices = None 74 | if self._use_label: 75 | cur_r = 0 76 | label_indices = [] 77 | cur_label = -1 78 | while cur_r < plot_size: 79 | cur_label = cur_label + 1 if cur_label < self._n_labels - 1 else 0 80 | row_label = np.ones(plot_size) * cur_label 81 | label_indices.extend(row_label) 82 | cur_r += 1 83 | 84 | if manifold: 85 | if self._dist == 'gaussian': 86 | random_code = distribution.interpolate( 87 | plot_size=plot_size, interpolate_range=[-3, 3, -3, 3]) 88 | self.viz_samples(sess, random_code, plot_size, file_id=file_id) 89 | else: 90 | for mode_id in range(self._n_labels): 91 | random_code = distribution.interpolate_gm( 92 | plot_size=plot_size, interpolate_range=[-1., 1., -0.2, 0.2], 93 | mode_id=mode_id, n_mode=self._n_labels) 94 | self.viz_samples(sess, random_code, plot_size, 95 | file_id='{}_{}'.format(file_id, mode_id)) 96 | else: 97 | if self._dist == 'gaussian': 98 | random_code = distribution.diagonal_gaussian( 99 | n_samples, self._g_model.n_code, mean=0, var=1.0) 100 | else: 101 | random_code = distribution.gaussian_mixture( 102 | n_samples, n_dim=self._g_model.n_code, n_labels=self._n_labels, 103 | x_var=0.5, y_var=0.1, label_indices=label_indices) 104 | 105 | self.viz_samples(sess, random_code, plot_size, file_id=file_id) 106 | 107 | def viz_samples(self, sess, random_code, plot_size, file_id=None): 108 | gen_im = sess.run(self._generate_op, feed_dict={self._g_model.z: random_code}) 109 | if self._save_path: 110 | if file_id is not None: 111 | im_save_path = os.path.join( 112 | self._save_path, 'generate_im_{}.png'.format(file_id)) 113 | else: 114 | im_save_path = os.path.join( 115 | self._save_path, 'generate_im.png') 116 | 117 | n_sample = len(gen_im) 118 | plot_size = int(min(plot_size, math.sqrt(n_sample))) 119 | viz.viz_batch_im(batch_im=gen_im, grid_size=[plot_size, plot_size], 120 | save_path=im_save_path, gap=0, gap_color=0, 121 | shuffle=False) 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /src/dataflow/mnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: mnist.py 4 | # Author: Qian Ge 5 | 6 | # from tensorflow.examples.tutorials.mnist import input_data 7 | import os 8 | import gzip 9 | import struct 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from src.utils.dataflow import get_rng 14 | 15 | def identity(im): 16 | return im 17 | 18 | class MNISTData(object): 19 | def __init__(self, name, data_dir='', n_use_label=None, n_use_sample=None, 20 | batch_dict_name=None, shuffle=True, pf=identity): 21 | assert os.path.isdir(data_dir) 22 | self._data_dir = data_dir 23 | 24 | self._shuffle = shuffle 25 | self._pf = pf 
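        # `pf` is a per-image preprocessing function applied when the raw
        # MNIST arrays are loaded (e.g. the /255. rescaling defined in
        # experiment/vae_mnist.py).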
26 | 27 | if not isinstance(batch_dict_name, list): 28 | batch_dict_name = [batch_dict_name] 29 | self._batch_dict_name = batch_dict_name 30 | 31 | assert name in ['train', 'test', 'val'] 32 | self.setup(epoch_val=0, batch_size=1) 33 | 34 | self._load_files(name, n_use_label, n_use_sample) 35 | self._image_id = 0 36 | 37 | def next_batch_dict(self): 38 | batch_data = self.next_batch() 39 | data_dict = {key: data for key, data in zip(self._batch_dict_name, batch_data)} 40 | return data_dict 41 | 42 | def _load_files(self, name, n_use_label, n_use_sample): 43 | if name == 'train': 44 | image_name = 'train-images-idx3-ubyte.gz' 45 | label_name = 'train-labels-idx1-ubyte.gz' 46 | else: 47 | image_name = 't10k-images-idx3-ubyte.gz' 48 | label_name = 't10k-labels-idx1-ubyte.gz' 49 | 50 | image_path = os.path.join(self._data_dir, image_name) 51 | label_path = os.path.join(self._data_dir, label_name) 52 | 53 | with gzip.open(label_path) as f: 54 | magic = struct.unpack('>I', f.read(4)) 55 | if magic[0] != 2049: 56 | raise Exception('Invalid file: unexpected magic number.') 57 | n_label = struct.unpack('>I', f.read(4)) 58 | label_list = np.fromstring(f.read(n_label[0]), dtype = np.uint8) 59 | 60 | with gzip.open(image_path) as f: 61 | magic = struct.unpack('>I', f.read(4)) 62 | if magic[0] != 2051: 63 | raise Exception('Invalid file: unexpected magic number.') 64 | n_im, rows, cols = struct.unpack('>III', f.read(12)) 65 | image_list = np.fromstring(f.read(n_im * rows * cols), dtype = np.uint8) 66 | image_list = np.reshape(image_list, (n_im, rows, cols, 1)) 67 | # image_list = image_list.astype(np.float32) 68 | im_list = [] 69 | if n_use_sample is not None and n_use_sample < len(label_list): 70 | remain_sample = n_use_sample // 10 * 10 71 | left_sample = n_use_sample - remain_sample 72 | keep_sign = [0 for i in range(10)] 73 | data_idx = 0 74 | new_label_list = [] 75 | for idx, im in enumerate(image_list): 76 | 77 | if remain_sample > 0: 78 | if keep_sign[label_list[idx]] < (n_use_sample // 10): 79 | keep_sign[label_list[idx]] += 1 80 | im_list.append(self._pf(im)) 81 | new_label_list.append(label_list[idx]) 82 | remain_sample -= 1 83 | else: 84 | break 85 | im_list.extend(image_list[idx:idx + left_sample]) 86 | new_label_list.extend(label_list[idx:idx + left_sample]) 87 | label_list = new_label_list 88 | 89 | else: 90 | for im in image_list: 91 | im_list.append(self._pf(im)) 92 | 93 | self.im_list = np.array(im_list) 94 | self.label_list = np.array(label_list) 95 | 96 | if n_use_label is not None and n_use_label < self.size(): 97 | remain_sample = n_use_label // 10 * 10 98 | left_sample = n_use_label - remain_sample 99 | keep_sign = [0 for i in range(10)] 100 | data_idx = 0 101 | while remain_sample > 0: 102 | if keep_sign[self.label_list[data_idx]] < (n_use_label // 10): 103 | keep_sign[self.label_list[data_idx]] += 1 104 | remain_sample -= 1 105 | else: 106 | self.label_list[data_idx] = 10 107 | data_idx += 1 108 | 109 | self.label_list[data_idx + left_sample:] = 10 110 | self._suffle_files() 111 | 112 | def _suffle_files(self): 113 | if self._shuffle: 114 | idxs = np.arange(self.size()) 115 | 116 | self.rng.shuffle(idxs) 117 | self.im_list = self.im_list[idxs] 118 | self.label_list = self.label_list[idxs] 119 | 120 | def size(self): 121 | return self.im_list.shape[0] 122 | 123 | def next_batch(self): 124 | assert self._batch_size <= self.size(), \ 125 | "batch_size {} cannot be larger than data size {}".\ 126 | format(self._batch_size, self.size()) 127 | start = self._image_id 128 | 
self._image_id += self._batch_size 129 | end = self._image_id 130 | batch_files = self.im_list[start:end] 131 | batch_label = self.label_list[start:end] 132 | 133 | if self._image_id + self._batch_size > self.size(): 134 | self._epochs_completed += 1 135 | self._image_id = 0 136 | self._suffle_files() 137 | return [batch_files, batch_label] 138 | 139 | def setup(self, epoch_val, batch_size, **kwargs): 140 | self.reset_epochs_completed(epoch_val) 141 | self.set_batch_size(batch_size) 142 | self.reset_state() 143 | try: 144 | self._suffle_files() 145 | except AttributeError: 146 | pass 147 | 148 | def reset_epoch(self): 149 | self._epochs_completed = 0 150 | 151 | @property 152 | def batch_size(self): 153 | return self._batch_size 154 | 155 | @property 156 | def epochs_completed(self): 157 | return self._epochs_completed 158 | 159 | def set_batch_size(self, batch_size): 160 | self._batch_size = batch_size 161 | 162 | def reset_epochs_completed(self, epoch_val): 163 | self._epochs_completed = epoch_val 164 | 165 | def reset_state(self): 166 | self.rng = get_rng(self) 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: vae.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | from src.models.base import BaseModel 8 | import src.models.layers as L 9 | import src.models.modules as modules 10 | import src.models.ops as ops 11 | 12 | # INIT_W = tf.keras.initializers.he_normal() 13 | INIT_W = tf.contrib.layers.variance_scaling_initializer() 14 | 15 | 16 | class VAE(BaseModel): 17 | def __init__(self, im_size=[28, 28], n_code=1000, n_channel=1, wd=0): 18 | self._n_channel = n_channel 19 | self._wd = wd 20 | self._n_code = n_code 21 | self._im_size = im_size 22 | self.layers = {} 23 | 24 | def create_generate_model(self, b_size): 25 | self.set_is_training(False) 26 | self._create_generate_input() 27 | self.z = ops.tf_sample_standard_diag_guassian(b_size, self._n_code) 28 | self.layers['generate'] = tf.nn.sigmoid(self.decoder(self.z)) 29 | 30 | def _create_generate_input(self): 31 | self.z = tf.placeholder( 32 | tf.float32, name='latent_z', 33 | shape=[None, self._n_code]) 34 | self.keep_prob = 1. 
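    # Sketch of the math used by the graph built below (standard VAE
    # derivation): `sample_latent` applies the reparameterization trick,
    # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through
    # mu and sigma. `_get_loss` then pairs a Bernoulli reconstruction term
    # with the closed-form divergence
    #   KL(N(mu, sigma) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 1 - 2*log(sigma));
    # the constant -1 term only offsets the reported value and has no
    # effect on the gradients.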
35 | 36 | def create_train_model(self): 37 | self.set_is_training(True) 38 | self._create_train_input() 39 | self.layers['encoder_out'] = self.encoder() 40 | self.layers['z'], self.layers['z_mu'], self.layers['z_std'], self.layers['z_log_std'] =\ 41 | self.sample_latent() 42 | self.layers['decoder_out'] = self.decoder(self.layers['z']) 43 | 44 | def _create_train_input(self): 45 | self.image = tf.placeholder( 46 | tf.float32, name='image', 47 | shape=[None, self._im_size[0], self._im_size[1], self._n_channel]) 48 | self.lr = tf.placeholder(tf.float32, name='lr') 49 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 50 | 51 | def encoder(self): 52 | with tf.variable_scope('encoder'): 53 | # cnn_out = modules.encoder_CNN( 54 | # self.image, is_training=self.is_training, init_w=INIT_W, 55 | # wd=self._wd, bn=False, name='encoder_CNN') 56 | 57 | fc_out = modules.encoder_FC(self.image, self.is_training, keep_prob=self.keep_prob, wd=self._wd, name='encoder_FC', init_w=INIT_W) 58 | 59 | # fc_out = L.linear( 60 | # out_dim=self._n_code*2, layer_dict=self.layers, 61 | # inputs=cnn_out, init_w=INIT_W, wd=self._wd, name='Linear') 62 | 63 | return fc_out 64 | 65 | def sample_latent(self): 66 | with tf.variable_scope('sample_latent'): 67 | cnn_out = self.layers['encoder_out'] 68 | 69 | z_mean = L.linear( 70 | out_dim=self._n_code, layer_dict=self.layers, 71 | inputs=cnn_out, init_w=INIT_W, wd=self._wd, name='latent_mean') 72 | z_std = L.linear( 73 | out_dim=self._n_code, layer_dict=self.layers, nl=L.softplus, 74 | inputs=cnn_out, init_w=INIT_W, wd=self._wd, name='latent_std') 75 | z_log_std = tf.log(z_std + 1e-8) 76 | 77 | b_size = tf.shape(cnn_out)[0] 78 | z = ops.tf_sample_diag_guassian(z_mean, z_std, b_size, self._n_code) 79 | return z, z_mean, z_std, z_log_std 80 | 81 | def decoder(self, inputs): 82 | with tf.variable_scope('decoder'): 83 | # out_h = int(self._im_size[0] / 4) 84 | # out_w = int(self._im_size[1] / 4) 85 | # out_dim = out_h * out_w * 64 86 | 87 | # z_linear = L.linear( 88 | # out_dim=out_dim, layer_dict=self.layers, nl=tf.nn.relu, 89 | # inputs=inputs, init_w=INIT_W, wd=self._wd, name='z_linear') 90 | # z_linear = tf.reshape(z_linear, (-1, out_h, out_w, 64)) 91 | 92 | # decoder_out = modules.decoder_CNN( 93 | # z_linear, is_training=self.is_training, out_channel=self._n_channel, 94 | # wd=self._wd, bn=False, name='decoder_CNN') 95 | 96 | fc_out = modules.decoder_FC(inputs, self.is_training, keep_prob=self.keep_prob, wd=self._wd, name='decoder_FC', init_w=INIT_W) 97 | out_dim = self._im_size[0] * self._im_size[1] * self._n_channel 98 | decoder_out = L.linear( 99 | out_dim=out_dim, layer_dict=self.layers, 100 | inputs=fc_out, init_w=None, wd=self._wd, name='decoder_linear') 101 | decoder_out = tf.reshape(decoder_out, (-1, self._im_size[0], self._im_size[1], self._n_channel)) 102 | 103 | return decoder_out 104 | 105 | def _get_loss(self): 106 | with tf.name_scope('loss'): 107 | with tf.name_scope('likelihood'): 108 | p_hat = tf.nn.sigmoid(self.layers['decoder_out'], name='estimate_prob') 109 | p = self.image 110 | cross_entropy = p * tf.log(p_hat + 1e-6) + (1 - p) * tf.log(1 - p_hat + 1e-6) 111 | # likelihood_loss = tf.nn.sigmoid_cross_entropy_with_logits( 112 | # labels=label, 113 | # logits=logits, 114 | # name='likelihood_loss') 115 | cross_entropy_loss = -tf.reduce_mean(tf.reduce_sum(cross_entropy, axis=[1,2,3])) 116 | 117 | with tf.name_scope('KL'): 118 | kl_loss = tf.reduce_sum( 119 | tf.square(self.layers['z_mu']) 120 | + tf.square(self.layers['z_std']) 121 | - 2 * 
self.layers['z_log_std'], 122 | axis=1) 123 | 124 | kl_loss = 0.5 * kl_loss 125 | kl_loss = tf.reduce_mean(kl_loss) 126 | 127 | return cross_entropy_loss + kl_loss 128 | 129 | def _get_optimizer(self): 130 | return tf.train.AdamOptimizer(self.lr) 131 | 132 | def get_valid_summary(self): 133 | with tf.name_scope('generate'): 134 | tf.summary.image( 135 | 'image', 136 | tf.cast(self.layers['generate'], tf.float32), 137 | collections=['generate']) 138 | return tf.summary.merge_all(key='generate') 139 | 140 | def get_train_summary(self): 141 | tf.summary.image( 142 | 'input_image', 143 | tf.cast(self.image, tf.float32), 144 | collections=['train']) 145 | tf.summary.image( 146 | 'out_image', 147 | tf.cast(tf.nn.sigmoid(self.layers['decoder_out']), tf.float32), 148 | collections=['train']) 149 | tf.summary.histogram( 150 | name='encoder distribution', values=self.layers['z'], 151 | collections=['train']) 152 | return tf.summary.merge_all(key='train') 153 | 154 | -------------------------------------------------------------------------------- /src/models/modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: modules.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | import src.models.layers as L 8 | 9 | 10 | def encoder_FC(inputs, is_training, n_hidden=1000, nl=tf.nn.relu, 11 | keep_prob=0.5, wd=0, name='encoder_FC', init_w=None): 12 | layer_dict = {} 13 | layer_dict['cur_input'] = inputs 14 | with tf.variable_scope(name): 15 | arg_scope = tf.contrib.framework.arg_scope 16 | with arg_scope([L.linear], 17 | out_dim=n_hidden, layer_dict=layer_dict, init_w=init_w, 18 | wd=wd): 19 | L.linear(name='linear1', nl=nl) 20 | L.drop_out(layer_dict, is_training, keep_prob=keep_prob) 21 | L.linear(name='linear2', nl=nl) 22 | L.drop_out(layer_dict, is_training, keep_prob=keep_prob) 23 | 24 | return layer_dict['cur_input'] 25 | 26 | def decoder_FC(inputs, is_training, n_hidden=1000, nl=tf.nn.relu, 27 | keep_prob=0.5, wd=0, name='decoder_FC', init_w=None): 28 | layer_dict = {} 29 | layer_dict['cur_input'] = inputs 30 | with tf.variable_scope(name): 31 | arg_scope = tf.contrib.framework.arg_scope 32 | with arg_scope([L.linear], 33 | out_dim=n_hidden, layer_dict=layer_dict, init_w=init_w, 34 | wd=wd): 35 | L.linear(name='linear1', nl=nl) 36 | L.drop_out(layer_dict, is_training, keep_prob=keep_prob) 37 | L.linear(name='linear2', nl=nl) 38 | L.drop_out(layer_dict, is_training, keep_prob=keep_prob) 39 | 40 | return layer_dict['cur_input'] 41 | 42 | def discriminator_FC(inputs, is_training, n_hidden=1000, nl=tf.nn.relu, 43 | wd=0, name='discriminator_FC', init_w=None): 44 | layer_dict = {} 45 | layer_dict['cur_input'] = inputs 46 | with tf.variable_scope(name): 47 | arg_scope = tf.contrib.framework.arg_scope 48 | with arg_scope([L.linear], 49 | layer_dict=layer_dict, init_w=init_w, 50 | wd=wd): 51 | L.linear(name='linear1', nl=nl, out_dim=n_hidden) 52 | L.linear(name='linear2', nl=nl, out_dim=n_hidden) 53 | L.linear(name='output', out_dim=1) 54 | 55 | return layer_dict['cur_input'] 56 | 57 | def encoder_CNN(inputs, is_training, wd=0, bn=False, name='encoder_CNN', 58 | init_w=tf.keras.initializers.he_normal()): 59 | # init_w = tf.keras.initializers.he_normal() 60 | layer_dict = {} 61 | layer_dict['cur_input'] = inputs 62 | with tf.variable_scope(name): 63 | arg_scope = tf.contrib.framework.arg_scope 64 | with arg_scope([L.conv], 65 | layer_dict=layer_dict, bn=bn, nl=tf.nn.relu, 66 | init_w=init_w, 
padding='SAME', pad_type='ZERO', 67 | is_training=is_training, wd=0): 68 | 69 | L.conv(filter_size=5, out_dim=32, name='conv1', add_summary=False) 70 | L.max_pool(layer_dict, name='pool1') 71 | L.conv(filter_size=3, out_dim=64, name='conv2', add_summary=False) 72 | L.max_pool(layer_dict, name='pool2') 73 | # L.conv(filter_size=3, out_dim=128, name='conv3', add_summary=False) 74 | # L.max_pool(layer_dict, name='pool3') 75 | 76 | return layer_dict['cur_input'] 77 | 78 | def decoder_CNN(inputs, is_training, out_channel=1, wd=0, bn=False, name='decoder_CNN', 79 | init_w=tf.keras.initializers.he_normal()): 80 | # init_w = tf.keras.initializers.he_normal() 81 | layer_dict = {} 82 | layer_dict['cur_input'] = inputs 83 | with tf.variable_scope(name): 84 | arg_scope = tf.contrib.framework.arg_scope 85 | with arg_scope([L.transpose_conv], 86 | layer_dict=layer_dict, nl=tf.nn.relu, stride=2, 87 | init_w=init_w, wd=0): 88 | 89 | # L.transpose_conv(filter_size=3, out_dim=64, name='deconv1') 90 | L.transpose_conv(filter_size=3, out_dim=32, name='deconv2') 91 | L.transpose_conv(filter_size=3, out_dim=out_channel, name='deconv3') 92 | 93 | return layer_dict['cur_input'] 94 | 95 | def train_discrimator(fake_in, real_in, loss_weight, opt, var_list, name): 96 | with tf.name_scope(name): 97 | with tf.name_scope('discrimator_loss'): 98 | loss_real = tf.nn.sigmoid_cross_entropy_with_logits( 99 | labels=tf.ones_like(real_in), 100 | logits=real_in, 101 | name='loss_real') 102 | loss_fake = tf.nn.sigmoid_cross_entropy_with_logits( 103 | labels=tf.zeros_like(fake_in), 104 | logits=fake_in, 105 | name='loss_fake') 106 | d_loss = tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake) 107 | 108 | # opt = tf.train.AdamOptimizer(lr, beta1=0.5) 109 | # opt = tf.train.MomentumOptimizer(self.lr, momentum=0.1) 110 | # dc_var = [var for var in all_variables if 'dc_' in var.name] 111 | # var_list = tf.trainable_variables(scope='discriminator') 112 | # print(tf.trainable_variables()) 113 | # print(var_list) 114 | grads = tf.gradients(d_loss * loss_weight, var_list) 115 | # [tf.summary.histogram('gradient/' + var.name, grad, 116 | # collections=['train']) for grad, var in zip(grads, var_list)] 117 | train_op = opt.apply_gradients(zip(grads, var_list)) 118 | 119 | return d_loss, train_op 120 | 121 | def train_generator(fake_in, loss_weight, opt, var_list, name): 122 | with tf.name_scope(name): 123 | with tf.name_scope('generator_loss'): 124 | loss_fake = tf.nn.sigmoid_cross_entropy_with_logits( 125 | labels=tf.ones_like(fake_in), 126 | logits=fake_in, 127 | name='loss_fake') 128 | g_loss = tf.reduce_mean(loss_fake) 129 | # opt = tf.train.AdamOptimizer(lr, beta1=0.5) 130 | # print(var_list) 131 | grads = tf.gradients(g_loss * loss_weight, var_list) 132 | train_op = opt.apply_gradients(zip(grads, var_list)) 133 | 134 | return g_loss, train_op 135 | 136 | def train_by_cross_entropy_loss(logits, labels, loss_weight, opt, var_list, name): 137 | with tf.name_scope(name): 138 | with tf.name_scope('cross_entropy_loss'): 139 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 140 | labels=labels, 141 | logits=logits, 142 | name='cross_entropy') 143 | cross_entropy = tf.reduce_mean(cross_entropy) 144 | # opt = tf.train.AdamOptimizer(lr, beta1=0.5) 145 | # print(var_list) 146 | grads = tf.gradients(cross_entropy * loss_weight, var_list) 147 | train_op = opt.apply_gradients(zip(grads, var_list)) 148 | 149 | return cross_entropy, train_op 150 | 151 | 
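# A minimal wiring sketch for the two adversarial objectives above, as an
# adversarial autoencoder would use them. All tensor and scope names below
# are illustrative assumptions, not symbols defined in this file, and the
# second discriminator call would additionally need variable reuse:
#
#   opt = tf.train.AdamOptimizer(2e-4, beta1=0.5)
#   d_real = discriminator_FC(prior_code, is_training)    # logits for p(z) samples
#   d_fake = discriminator_FC(encoder_code, is_training)  # logits for q(z|x) samples
#   d_loss, train_d_op = train_discrimator(
#       d_fake, d_real, loss_weight=1., opt=opt,
#       var_list=tf.trainable_variables(scope='discriminator_FC'), name='train_d')
#   g_loss, train_g_op = train_generator(
#       d_fake, loss_weight=1., opt=opt,
#       var_list=tf.trainable_variables(scope='encoder_FC'), name='train_g')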
-------------------------------------------------------------------------------- /experiment/vae_mnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: vae_mnist.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | import tensorflow as tf 10 | import platform 11 | import scipy.misc 12 | import argparse 13 | import matplotlib.pyplot as plt 14 | 15 | sys.path.append('../') 16 | from src.dataflow.mnist import MNISTData 17 | from src.models.vae import VAE 18 | from src.helper.trainer import Trainer 19 | from src.helper.generator import Generator 20 | from src.helper.visualizer import Visualizer 21 | import src.models.distribution as distribution 22 | 23 | DATA_PATH = '/home/qge2/workspace/data/MNIST_data/' 24 | SAVE_PATH = '/home/qge2/workspace/data/out/vae/vae/' 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--train', action='store_true', 30 | help='Train the model') 31 | parser.add_argument('--generate', action='store_true', 32 | help='generate') 33 | parser.add_argument('--viz', action='store_true', 34 | help='visualize') 35 | parser.add_argument('--test', action='store_true', 36 | help='test') 37 | parser.add_argument('--load', type=int, default=99, 38 | help='Load step of pre-trained') 39 | parser.add_argument('--lr', type=float, default=1e-3, 40 | help='Init learning rate') 41 | parser.add_argument('--ncode', type=int, default=2, 42 | help='number of code') 43 | 44 | parser.add_argument('--bsize', type=int, default=128, 45 | help='Init learning rate') 46 | parser.add_argument('--maxepoch', type=int, default=100, 47 | help='Max iteration') 48 | 49 | 50 | return parser.parse_args() 51 | 52 | 53 | def preprocess_im(im): 54 | im = im / 255. 
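    # rescale pixels to [0, 1] so they match the sigmoid/Bernoulli
    # reconstruction loss used by the VAE decoder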
55 | return im 56 | 57 | def train(): 58 | FLAGS = get_args() 59 | train_data = MNISTData('train', 60 | data_dir=DATA_PATH, 61 | shuffle=True, 62 | pf=preprocess_im, 63 | batch_dict_name=['im', 'label']) 64 | train_data.setup(epoch_val=0, batch_size=FLAGS.bsize) 65 | valid_data = MNISTData('test', 66 | data_dir=DATA_PATH, 67 | shuffle=True, 68 | pf=preprocess_im, 69 | batch_dict_name=['im', 'label']) 70 | valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize) 71 | 72 | with tf.variable_scope('VAE') as scope: 73 | model = VAE(n_code=FLAGS.ncode, wd=0) 74 | model.create_train_model() 75 | 76 | with tf.variable_scope('VAE') as scope: 77 | scope.reuse_variables() 78 | valid_model = VAE(n_code=FLAGS.ncode, wd=0) 79 | valid_model.create_generate_model(b_size=400) 80 | 81 | trainer = Trainer(model, valid_model, train_data, init_lr=FLAGS.lr, save_path=SAVE_PATH) 82 | if FLAGS.ncode == 2: 83 | z = distribution.interpolate(plot_size=20) 84 | z = np.reshape(z, (400, 2)) 85 | visualizer = Visualizer(model, save_path=SAVE_PATH) 86 | else: 87 | z = None 88 | generator = Generator(generate_model=valid_model, save_path=SAVE_PATH) 89 | 90 | sessconfig = tf.ConfigProto() 91 | sessconfig.gpu_options.allow_growth = True 92 | with tf.Session(config=sessconfig) as sess: 93 | writer = tf.summary.FileWriter(SAVE_PATH) 94 | saver = tf.train.Saver() 95 | sess.run(tf.global_variables_initializer()) 96 | writer.add_graph(sess.graph) 97 | 98 | for epoch_id in range(FLAGS.maxepoch): 99 | trainer.train_epoch(sess, summary_writer=writer) 100 | trainer.valid_epoch(sess, summary_writer=writer) 101 | if epoch_id % 10 == 0: 102 | saver.save(sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id)) 103 | if FLAGS.ncode == 2: 104 | generator.viz_samples(sess, z, plot_size=20, file_id=epoch_id) 105 | visualizer.viz_2Dlatent_variable(sess, valid_data, file_id=epoch_id) 106 | 107 | def generate(): 108 | FLAGS = get_args() 109 | plot_size = 20 110 | 111 | with tf.variable_scope('VAE') as scope: 112 | # scope.reuse_variables() 113 | generate_model = VAE(n_code=FLAGS.ncode, wd=0) 114 | generate_model.create_generate_model(b_size=plot_size*plot_size) 115 | 116 | generator = Generator(generate_model=generate_model, save_path=SAVE_PATH) 117 | 118 | sessconfig = tf.ConfigProto() 119 | sessconfig.gpu_options.allow_growth = True 120 | with tf.Session(config=sessconfig) as sess: 121 | saver = tf.train.Saver() 122 | sess.run(tf.global_variables_initializer()) 123 | saver.restore(sess, '{}vae-epoch-{}'.format(SAVE_PATH, FLAGS.load)) 124 | generator.generate_samples(sess, plot_size=plot_size) 125 | 126 | def visualize(): 127 | FLAGS = get_args() 128 | plot_size = 20 129 | 130 | valid_data = MNISTData('test', 131 | data_dir=DATA_PATH, 132 | shuffle=True, 133 | pf=preprocess_im, 134 | batch_dict_name=['im', 'label']) 135 | valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize) 136 | 137 | with tf.variable_scope('VAE') as scope: 138 | model = VAE(n_code=FLAGS.ncode, wd=0) 139 | model.create_train_model() 140 | 141 | with tf.variable_scope('VAE') as scope: 142 | scope.reuse_variables() 143 | valid_model = VAE(n_code=FLAGS.ncode, wd=0) 144 | valid_model.create_generate_model(b_size=400) 145 | 146 | visualizer = Visualizer(model, save_path=SAVE_PATH) 147 | generator = Generator(generate_model=valid_model, save_path=SAVE_PATH) 148 | 149 | z = distribution.interpolate(plot_size=plot_size) 150 | z = np.reshape(z, (plot_size*plot_size, 2)) 151 | 152 | sessconfig = tf.ConfigProto() 153 | sessconfig.gpu_options.allow_growth = True 154 | with
tf.Session(config=sessconfig) as sess: 155 | saver = tf.train.Saver() 156 | sess.run(tf.global_variables_initializer()) 157 | saver.restore(sess, '{}vae-epoch-{}'.format(SAVE_PATH, FLAGS.load)) 158 | visualizer.viz_2Dlatent_variable(sess, valid_data) 159 | generator.viz_samples(sess, z, plot_size=plot_size) 160 | 161 | def test(): 162 | valid_data = MNISTData('test', 163 | data_dir=DATA_PATH, 164 | shuffle=True, 165 | pf=preprocess_im, 166 | batch_dict_name=['im', 'label']) 167 | batch_data = valid_data.next_batch_dict() 168 | plt.figure() 169 | plt.imshow(np.squeeze(batch_data['im'][0])) 170 | plt.show() 171 | print(batch_data['label']) 172 | 173 | if __name__ == '__main__': 174 | FLAGS = get_args() 175 | 176 | if FLAGS.train: 177 | train() 178 | elif FLAGS.generate: 179 | generate() 180 | elif FLAGS.viz: 181 | visualize() 182 | elif FLAGS.test: 183 | test() 184 | 185 | -------------------------------------------------------------------------------- /src/models/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: layers.py 4 | # Author: Qian Ge 5 | 6 | import math 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.contrib.framework import add_arg_scope 10 | 11 | 12 | def get_shape4D(in_val): 13 | """ 14 | Return a 4D shape 15 | Args: 16 | in_val (int or list with length 2) 17 | Returns: 18 | list with length 4 19 | """ 20 | # if isinstance(in_val, int): 21 | return [1] + get_shape2D(in_val) + [1] 22 | 23 | def get_shape2D(in_val): 24 | """ 25 | Return a 2D shape 26 | Args: 27 | in_val (int or list with length 2) 28 | Returns: 29 | list with length 2 30 | """ 31 | # in_val is either an int or a list of length 2 32 | if isinstance(in_val, int): 33 | return [in_val, in_val] 34 | if isinstance(in_val, list): 35 | assert len(in_val) == 2 36 | return in_val 37 | raise RuntimeError('Illegal shape: {}'.format(in_val)) 38 | 39 | def batch_flatten(x): 40 | """ 41 | Flatten the tensor except the first dimension.
42 | """ 43 | shape = x.get_shape().as_list()[1:] 44 | if None not in shape: 45 | return tf.reshape(x, [-1, int(np.prod(shape))]) 46 | return tf.reshape(x, tf.stack([tf.shape(x)[0], -1])) 47 | 48 | def softplus(inputs, name): 49 | return tf.log(1 + tf.exp(inputs), name=name) 50 | 51 | def softmax(logits, axis=-1, name='softmax'): 52 | with tf.name_scope(name): 53 | max_in = tf.reduce_max(logits, axis=axis, keepdims=True) 54 | stable_in = logits - max_in 55 | normal_p = tf.reduce_sum(tf.exp(stable_in), axis=axis, keepdims=True) 56 | 57 | return tf.exp(stable_in) / normal_p 58 | 59 | def leaky_relu(x, leak=0.2, name='LeakyRelu'): 60 | """ 61 | leaky_relu 62 | Allow a small non-zero gradient when the unit is not active 63 | Args: 64 | x (tf.tensor): a tensor 65 | leak (float): Default to 0.2 66 | Returns: 67 | tf.tensor with name 'name' 68 | """ 69 | return tf.maximum(x, leak*x, name=name) 70 | 71 | 72 | @add_arg_scope 73 | def linear(out_dim, 74 | layer_dict=None, 75 | inputs=None, 76 | init_w=None, 77 | init_b=tf.zeros_initializer(), 78 | wd=0, 79 | name='Linear', 80 | nl=tf.identity): 81 | with tf.variable_scope(name): 82 | if inputs is None: 83 | assert layer_dict is not None 84 | inputs = layer_dict['cur_input'] 85 | inputs = batch_flatten(inputs) 86 | in_dim = inputs.get_shape().as_list()[1] 87 | if wd > 0: 88 | regularizer = tf.contrib.layers.l2_regularizer(scale=wd) 89 | else: 90 | regularizer=None 91 | weights = tf.get_variable('weights', 92 | shape=[in_dim, out_dim], 93 | # dtype=None, 94 | initializer=init_w, 95 | regularizer=regularizer, 96 | trainable=True) 97 | biases = tf.get_variable('biases', 98 | shape=[out_dim], 99 | # dtype=None, 100 | initializer=init_b, 101 | regularizer=None, 102 | trainable=True) 103 | # print('init: {}'.format(weights)) 104 | act = tf.nn.xw_plus_b(inputs, weights, biases) 105 | result = nl(act, name='output') 106 | if layer_dict is not None: 107 | layer_dict['cur_input'] = result 108 | 109 | return result 110 | 111 | @add_arg_scope 112 | def transpose_conv( 113 | filter_size, 114 | out_dim, 115 | layer_dict, 116 | inputs=None, 117 | out_shape=None, 118 | stride=2, 119 | padding='SAME', 120 | trainable=True, 121 | nl=tf.identity, 122 | init_w=None, 123 | init_b=tf.zeros_initializer(), 124 | wd=0, 125 | constant_init=False, 126 | name='dconv'): 127 | if inputs is None: 128 | inputs = layer_dict['cur_input'] 129 | stride = get_shape4D(stride) 130 | in_dim = inputs.get_shape().as_list()[-1] 131 | 132 | # TODO other ways to determine the output shape 133 | x_shape = tf.shape(inputs) 134 | # assume output shape is input_shape*stride 135 | if out_shape is None: 136 | out_shape = tf.stack([x_shape[0], 137 | tf.multiply(x_shape[1], stride[1]), 138 | tf.multiply(x_shape[2], stride[2]), 139 | out_dim]) 140 | 141 | filter_shape = get_shape2D(filter_size) + [out_dim, in_dim] 142 | 143 | with tf.variable_scope(name) as scope: 144 | if wd > 0: 145 | regularizer = tf.contrib.layers.l2_regularizer(scale=wd) 146 | else: 147 | regularizer=None 148 | weights = tf.get_variable('weights', 149 | filter_shape, 150 | initializer=init_w, 151 | trainable=trainable, 152 | regularizer=regularizer) 153 | biases = tf.get_variable('biases', 154 | [out_dim], 155 | initializer=init_b, 156 | trainable=trainable) 157 | 158 | output = tf.nn.conv2d_transpose(inputs, 159 | weights, 160 | output_shape=out_shape, 161 | strides=stride, 162 | padding=padding, 163 | name=scope.name) 164 | 165 | output = tf.nn.bias_add(output, biases) 166 | output.set_shape([None, None, None, out_dim]) 167 | 
output = nl(output, name='output') 168 | layer_dict['cur_input'] = output 169 | return output 170 | 171 | def max_pool(layer_dict, 172 | inputs=None, 173 | name='max_pool', 174 | filter_size=2, 175 | stride=None, 176 | padding='SAME', 177 | switch=False): 178 | """ 179 | Max pooling layer 180 | Args: 181 | inputs (tf.tensor): input tensor; defaults to layer_dict['cur_input'] 182 | name (str): name scope of the layer 183 | filter_size (int or list with length 2): size of filter 184 | stride (int or list with length 2): Default to be the same as filter_size 185 | padding (str): 'VALID' or 'SAME'. Use 'SAME' for FCN. 186 | Returns: 187 | tf.tensor with name 'name' 188 | """ 189 | if inputs is not None: 190 | layer_dict['cur_input'] = inputs 191 | padding = padding.upper() 192 | filter_shape = get_shape4D(filter_size) 193 | if stride is None: 194 | stride = filter_shape 195 | else: 196 | stride = get_shape4D(stride) 197 | 198 | if switch: 199 | layer_dict['cur_input'], switch_s = tf.nn.max_pool_with_argmax( 200 | layer_dict['cur_input'], 201 | ksize=filter_shape, 202 | strides=stride, 203 | padding=padding, 204 | Targmax=tf.int64, 205 | name=name) 206 | return layer_dict['cur_input'], switch_s 207 | else: 208 | layer_dict['cur_input'] = tf.nn.max_pool( 209 | layer_dict['cur_input'], 210 | ksize=filter_shape, 211 | strides=stride, 212 | padding=padding, 213 | name=name) 214 | return layer_dict['cur_input'], None 215 | 216 | @add_arg_scope 217 | def conv(filter_size, 218 | out_dim, 219 | layer_dict, 220 | inputs=None, 221 | pretrained_dict=None, 222 | stride=1, 223 | dilations=[1, 1, 1, 1], 224 | bn=False, 225 | nl=tf.identity, 226 | init_w=None, 227 | init_b=tf.zeros_initializer(), 228 | use_bias=True, 229 | padding='SAME', 230 | pad_type='ZERO', 231 | trainable=True, 232 | is_training=None, 233 | wd=0, 234 | name='conv', 235 | add_summary=False): 236 | if inputs is None: 237 | inputs = layer_dict['cur_input'] 238 | stride = get_shape4D(stride) 239 | in_dim = inputs.get_shape().as_list()[-1] 240 | filter_shape = get_shape2D(filter_size) + [in_dim, out_dim] 241 | 242 | if padding == 'SAME' and pad_type == 'REFLECT': 243 | pad_size_1 = int((filter_shape[0] - 1) / 2) 244 | pad_size_2 = int((filter_shape[1] - 1) / 2) 245 | inputs = tf.pad( 246 | inputs, 247 | [[0, 0], [pad_size_1, pad_size_1], [pad_size_2, pad_size_2], [0, 0]], 248 | "REFLECT") 249 | padding = 'VALID' 250 | 251 | with tf.variable_scope(name): 252 | if wd > 0: 253 | regularizer = tf.contrib.layers.l2_regularizer(scale=wd) 254 | else: 255 | regularizer = None 256 | 257 | if pretrained_dict is not None and name in pretrained_dict: 258 | try: 259 | load_w = pretrained_dict[name][0] 260 | except KeyError: 261 | load_w = pretrained_dict[name]['weights'] 262 | print('Load {} weights!'.format(name)) 263 | 264 | load_w = np.reshape(load_w, filter_shape) 265 | init_w = tf.constant_initializer(load_w) 266 | 267 | weights = tf.get_variable('weights', 268 | filter_shape, 269 | initializer=init_w, 270 | trainable=trainable, 271 | regularizer=regularizer) 272 | if add_summary: 273 | tf.summary.histogram( 274 | 'weights/{}'.format(name), weights, collections = ['train']) 275 | 276 | outputs = tf.nn.conv2d(inputs, 277 | filter=weights, 278 | strides=stride, 279 | padding=padding, 280 | use_cudnn_on_gpu=True, 281 | data_format="NHWC", 282 | dilations=dilations, 283 | name='conv2d') 284 | 285 | if use_bias: 286 | if pretrained_dict is not None and name in pretrained_dict: 287 | try: 288 | load_b = pretrained_dict[name][1] 289 | except KeyError: 290 | load_b = pretrained_dict[name]['biases'] 290 |
291 | print('Load {} biases!'.format(name)) 292 | 293 | load_b = np.reshape(load_b, [out_dim]) 294 | init_b = tf.constant_initializer(load_b) 295 | 296 | biases = tf.get_variable('biases', 297 | [out_dim], 298 | initializer=init_b, 299 | trainable=trainable) 300 | outputs += biases 301 | 302 | # if bn is True: 303 | # outputs = layers.batch_norm(outputs, train=is_training, name='bn') 304 | 305 | layer_dict['cur_input'] = nl(outputs) 306 | layer_dict[name] = layer_dict['cur_input'] 307 | return layer_dict['cur_input'] 308 | 309 | def drop_out(layer_dict, is_training, inputs=None, keep_prob=0.5): 310 | if inputs is None: 311 | inputs = layer_dict['cur_input'] 312 | if is_training: 313 | layer_dict['cur_input'] = tf.nn.dropout(inputs, keep_prob=keep_prob) 314 | else: 315 | layer_dict['cur_input'] = inputs 316 | return layer_dict['cur_input'] 317 | 318 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Autoencoders (AAE) 2 | 3 | - Tensorflow implementation of [Adversarial Autoencoders](https://arxiv.org/abs/1511.05644) (ICLR 2016) 4 | - Similar to the variational autoencoder (VAE), AAE imposes a prior on the latent variable z. However, instead of maximizing the evidence lower bound (ELBO) like VAE, AAE uses an adversarial network to guide the model distribution of z toward the prior distribution. 5 | - This repository reproduces several experiments mentioned in the paper. 6 | 7 | ## Requirements 8 | - Python 3.3+ 9 | - [TensorFlow 1.9+](https://www.tensorflow.org/) 10 | - [TensorFlow Probability](https://github.com/tensorflow/probability) 11 | - [Numpy](http://www.numpy.org/) 12 | - [Scipy](https://www.scipy.org/) 13 | 14 | 15 | ## Implementation details 16 | - All the AAE models are defined in [src/models/aae.py](src/models/aae.py). 17 | - The model corresponding to fig 1 and 3 in the paper can be found here: [train](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L110) and [test](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L164). 18 | - The model corresponding to fig 6 in the paper can be found here: [train](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L110) and [test](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L148). 19 | - The model corresponding to fig 8 in the paper can be found here: [train](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L71) and [test](https://github.com/conan7882/adversarial-autoencoders-tf/blob/master/src/models/aae.py#L182). 20 | - Examples of how to use the AAE models can be found in [experiment/aae_mnist.py](experiment/aae_mnist.py). 21 | - The encoder, decoder and all discriminators consist of two fully connected layers with 1000 hidden units each and ReLU activations. The decoder and all discriminators have one additional fully connected output layer. 22 | - Images are normalized to [-1, 1] before being fed into the encoder, and tanh is used as the output nonlinearity of the decoder. 23 | - All the sub-networks are optimized by the Adam optimizer with `beta1 = 0.5`. 24 | 25 | ## Preparation 26 | - Download the MNIST dataset from [here](http://yann.lecun.com/exdb/mnist/). 27 | - Set up the paths in [`experiment/aae_mnist.py`](experiment/aae_mnist.py): 28 | `DATA_PATH ` is the directory containing the MNIST dataset.
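(For reference, the committed script uses `DATA_PATH = '/home/qge2/workspace/data/MNIST_data/'`; change it to a directory that exists on your machine.)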
29 | `SAVE_PATH ` is the path where output images and trained models are saved. 30 | 31 | ## Usage 32 | The script [experiment/aae_mnist.py](experiment/aae_mnist.py) contains all the experiments shown here. Detailed usage for each experiment is described below along with the results. 33 | ### Arguments 34 | * `--train`: Train the model of Fig 1 and 3 in the paper. 35 | * `--train_supervised`: Train the model of Fig 6 in the paper. 36 | * `--train_semisupervised`: Train the model of Fig 8 in the paper. 37 | * `--label`: Incorporate label information in the adversarial regularization (Fig 3 in the paper). 38 | * `--generate`: Randomly sample images from a trained model. 39 | * `--viz`: Visualize the latent space and data manifold (only when `--ncode` is 2). 40 | * `--supervise`: Sample from the supervised model (Fig 6 in the paper) when `--generate` is set. 41 | * `--load`: The epoch ID of the pre-trained model to restore. 42 | * `--ncode`: Dimension of code. Default: `2` 43 | * `--dist_type`: Type of prior distribution imposed on the hidden codes. Default: `gaussian`; use `gmm` for the Gaussian mixture distribution. 44 | * `--noise`: Add noise to the encoder input (Gaussian with std=0.6). 45 | * `--lr`: Initial learning rate. Default: `2e-4`. 46 | * `--dropout`: Keep probability for dropout. Default: `1.0`. 47 | * `--bsize`: Batch size. Default: `128`. 48 | * `--maxepoch`: Max number of epochs. Default: `100`. 49 | * `--encw`: Weight of autoencoder loss. Default: `1.0`. 50 | * `--genw`: Weight of z generator loss. Default: `6.0`. 51 | * `--disw`: Weight of z discriminator loss. Default: `6.0`. 52 | * `--clsw`: Weight of semi-supervised loss. Default: `1.0`. 53 | * `--ygenw`: Weight of y generator loss. Default: `6.0`. 54 | * `--ydisw`: Weight of y discriminator loss. Default: `6.0`. 55 | 56 | ## 1. Adversarial Autoencoder 57 | 58 | ### Architecture 59 | *Architecture* | *Description* 60 | :---: | :--- | 61 | | The top row is an autoencoder. z is sampled through the re-parameterization trick discussed in the [variational autoencoder paper](https://arxiv.org/abs/1312.6114). The bottom row is a discriminator that separates samples generated by the encoder from samples drawn from the prior distribution p(z). 62 | 63 | ### Hyperparameters 64 | *name* | *value* | 65 | :---| :---| 66 | Reconstruction Loss Weight | 1.0 | 67 | Latent z G/D Loss Weight | 6.0 / 6.0 | 68 | Batch Size | 128 | 69 | Max Epoch | 400 | 70 | Learning Rate | 2e-4 (initial) / 2e-5 (100 epochs) / 2e-6 (300 epochs) 71 | 72 | ### Usage 73 | 74 | - Training. Summaries, randomly sampled images and the latent space during training will be saved in `SAVE_PATH`. 75 | 76 | ``` 77 | python aae_mnist.py --train \ 78 | --ncode CODE_DIM \ 79 | --dist_type TYPE_OF_PRIOR (`gaussian` or `gmm`) 80 | ``` 81 | 82 | - Randomly sample images from a trained model. The image will be saved in `SAVE_PATH` with the name `generate_im.png`. 83 | ``` 84 | python aae_mnist.py --generate \ 85 | --ncode CODE_DIM \ 86 | --dist_type TYPE_OF_PRIOR (`gaussian` or `gmm`)\ 87 | --load RESTORE_MODEL_ID 88 | ``` 89 | - Visualize the latent space and data manifold (only when the code dimension is 2). Images will be saved in `SAVE_PATH` with the names `generate_im.png` and `latent.png`. For the Gaussian prior, there will be one data-manifold image. For the mixture of 10 2D Gaussians, there will be 10 data-manifold images, one for each component of the distribution.
90 | ``` 91 | python aae_mnist.py --viz \ 92 | --ncode CODE_DIM \ 93 | --dist_type TYPE_OF_PRIOR (`gaussian` or `gmm`)\ 94 | --load RESTORE_MODEL_ID 95 | ``` 96 | 104 | 105 | ### Result 106 | - For the 2D Gaussian prior, we can see sharp transitions (no gaps), as mentioned in the paper. Also, the learned manifold shows that almost all the sampled images are readable. 107 | - For the mixture of 10 Gaussians, I uniformly sample images over a 2D square, as for the 2D Gaussian, instead of sampling along the axes of the corresponding mixture component (that is shown in the next section). In the gap areas between two components, the model is less likely to generate good samples. 108 | 109 | *Prior Distribution* | *Learned Coding Space* | *Learned Manifold* 110 | :---: | :---: | :---: | 111 | | | 112 | | | 113 | 114 | ## 2. Incorporating label in the Adversarial Regularization 115 | 116 | ### Architecture 117 | *Architecture* | *Description* 118 | :---: | :--- | 119 | | The only difference from the previous model is that the one-hot label is used as input of the discriminator, with one extra class for unlabeled data. For the mixture-of-Gaussians prior, real samples for each labeled class are drawn from the corresponding component; for unlabeled data, real samples are drawn from the full mixture distribution. 120 | 121 | ### Hyperparameters 122 | Hyperparameters are the same as in the previous section. 123 | 124 | ### Usage 125 | - Training. Summaries, randomly sampled images and the latent space will be saved in `SAVE_PATH`. 126 | 127 | ``` 128 | python aae_mnist.py --train --label \ 129 | --ncode CODE_DIM \ 130 | --dist_type TYPE_OF_PRIOR (`gaussian` or `gmm`) 131 | ``` 132 | 133 | - Randomly sample images from a trained model. The image will be saved in `SAVE_PATH` with the name `generate_im.png`. 134 | ``` 135 | python aae_mnist.py --generate --label --ncode CODE_DIM --dist_type TYPE_OF_PRIOR --load RESTORE_MODEL_ID 136 | ``` 137 | 138 | - Visualize the latent space and data manifold (only when the code dimension is 2). Images will be saved in `SAVE_PATH` with the names `generate_im.png` and `latent.png`. For the Gaussian prior, there will be one data-manifold image. For the mixture of 10 2D Gaussians, there will be 10 data-manifold images, one for each component of the distribution. 139 | ``` 140 | python aae_mnist.py --viz --label \ 141 | --ncode CODE_DIM \ 142 | --dist_type TYPE_OF_PRIOR (`gaussian` or `gmm`) \ 143 | --load RESTORE_MODEL_ID 144 | ``` 145 | ### Result 146 | - Compared with the results in the previous section, incorporating label information yields a better-fitted code distribution. 147 | - The learned manifold images demonstrate that each Gaussian component corresponds to one class of digit. However, the style representation is not as consistent within each mixture component as shown in the paper. For example, in the rightmost column of the first-row experiment, the lower stroke of digit 1 tilts to the left while the lower stroke of digit 9 tilts to the right. 148 | 149 | *Number of Labels Used* | *Learned Coding Space* | *Learned Manifold* 150 | :--- | :---: | :---: | 151 | **Full labels**| | 152 | **10k labeled data and 40k unlabeled data** | | 153 | 154 | ## 3. Supervised Adversarial Autoencoders 155 | 156 | ### Architecture 157 | *Architecture* | *Description* 158 | :---: | :--- | 159 | | The decoder takes the code as well as a one-hot vector encoding the label as input. This forces the network to learn a code that is independent of the label. 160 | 161 | ### Hyperparameters 162 | 163 | ### Usage 164 | - Training.
Summaries and randomly sampled images will be saved in `SAVE_PATH`. 165 | 166 | ``` 167 | python aae_mnist.py --train_supervised \ 168 | --ncode CODE_DIM 169 | ``` 170 | 171 | - Randomly sample images from a trained model. The image will be saved in `SAVE_PATH` with the name `sample_style.png`. 172 | ``` 173 | python aae_mnist.py --generate --supervise \ 174 | --ncode CODE_DIM \ 175 | --load RESTORE_MODEL_ID 176 | ``` 177 | 178 | ### Result 179 | - The result images are generated by using the same code within each column and the same digit label within each row. 180 | - When the code dimension is 2, each column clearly shares the same style. But for dimension 10, some digits are hardly readable. There may be implementation issues, or the hyper-parameters may not be well chosen, so that the code still depends on the label. 181 | 182 | *Code Dim=2* | *Code Dim=10* | 183 | :---: | :---: | 184 | | | 185 | 186 | ## 4. Semi-supervised learning 187 | 188 | ### Architecture 189 | *Architecture* | *Description* 190 | :---: | :--- | 191 | | The encoder outputs the code z as well as the estimated label y. The decoder then takes both the code z and the one-hot label y as input. A Gaussian distribution is imposed on the code z and a Categorical distribution on the label y. In this implementation, the semi-supervised classification phase is run every ten training steps when using 1000 labeled images, and the one-hot label y is approximated by the output of the softmax for back-propagation. 192 | 193 | ### Hyperparameters 194 | *name* | *value* | 195 | :---| :---| 196 | Dimension of z | 10 | 197 | Reconstruction Loss Weight | 1.0 | 198 | Latent z G/D Loss Weight | 6.0 / 6.0 | 199 | Latent y G/D Loss Weight | 6.0 / 6.0 | 200 | Batch Size | 128 | 201 | Max Epoch | 250 | 202 | Learning Rate | 1e-4 (initial) / 1e-5 (150 epochs) / 1e-6 (200 epochs) 203 | 204 | ### Usage 205 | - Training. Summaries will be saved in `SAVE_PATH`. 206 | 207 | ``` 208 | python aae_mnist.py \ 209 | --ncode 10 \ 210 | --train_semisupervised \ 211 | --lr 2e-4 \ 212 | --maxepoch 250 213 | ``` 214 | 215 | ### Result 216 | - 1280 labels are used (128 labeled images per class). 217 | 218 | Learning curve for the training set (computed only on the labeled part of the training set) 219 | ![train](figs/semi_train.png) 220 | 221 | Learning curve for the testing set 222 | - The accuracy on the testing set is 97.10% at around 200 epochs.
223 | ![valid](figs/semi_valid.png) 224 | 225 | 226 | -------------------------------------------------------------------------------- /experiment/aae_mnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: aae_mnist.py 4 | # Author: Qian Ge 5 | 6 | import sys 7 | import argparse 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | sys.path.append('../') 12 | from src.dataflow.mnist import MNISTData 13 | from src.models.aae import AAE 14 | from src.helper.trainer import Trainer 15 | from src.helper.generator import Generator 16 | from src.helper.visualizer import Visualizer 17 | 18 | 19 | DATA_PATH = '/home/qge2/workspace/data/MNIST_data/' 20 | SAVE_PATH = '/home/qge2/workspace/data/out/vae/' 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--train', action='store_true', 25 | help='Train the model of Fig 1 and 3 in the paper.') 26 | parser.add_argument('--train_supervised', action='store_true', 27 | help='Train the model of Fig 6 in the paper.') 28 | parser.add_argument('--train_semisupervised', action='store_true', 29 | help='Train the model of Fig 8 in the paper.') 30 | parser.add_argument('--label', action='store_true', 31 | help='Incorporate label info (Fig 3 in the paper).') 32 | parser.add_argument('--generate', action='store_true', 33 | help='Sample images from trained model.') 34 | parser.add_argument('--viz', action='store_true', 35 | help='Visualize learned model when ncode=2.') 36 | parser.add_argument('--supervise', action='store_true', 37 | help='Sampling from supervised model (Fig 6 in the paper).') 38 | parser.add_argument('--load', type=int, default=99, 39 | help='The epoch ID of pre-trained model to be restored.') 40 | 41 | parser.add_argument('--ncode', type=int, default=2, 42 | help='Dimension of code') 43 | parser.add_argument('--dist_type', type=str, default='gaussian', 44 | help='Prior distribution to be imposed on latent z (gaussian and gmm).') 45 | parser.add_argument('--noise', action='store_true', 46 | help='Add noise to encoder input (Gaussian with std=0.6).') 47 | 48 | parser.add_argument('--lr', type=float, default=2e-4, 49 | help='Initial learning rate') 50 | parser.add_argument('--dropout', type=float, default=1.0, 51 | help='Keep probability for dropout') 52 | parser.add_argument('--bsize', type=int, default=128, 53 | help='Batch size') 54 | parser.add_argument('--maxepoch', type=int, default=100, 55 | help='Max number of epochs') 56 | 57 | parser.add_argument('--encw', type=float, default=1., 58 | help='Weight of autoencoder loss') 59 | parser.add_argument('--genw', type=float, default=6., 60 | help='Weight of z generator loss') 61 | parser.add_argument('--disw', type=float, default=6., 62 | help='Weight of z discriminator loss') 63 | 64 | parser.add_argument('--clsw', type=float, default=1., 65 | help='Weight of semi-supervised loss') 66 | parser.add_argument('--ygenw', type=float, default=6., 67 | help='Weight of y generator loss') 68 | parser.add_argument('--ydisw', type=float, default=6., 69 | help='Weight of y discriminator loss') 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def preprocess_im(im): 75 | """ normalize input image to [-1., 1.] """ 76 | im = im / 255. * 2. - 1. 
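# maps pixel values from [0, 255] to [-1, 1], matching the tanh output range of the decoder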
77 | return im 78 | 79 | def read_train_data(batch_size, n_use_label=None, n_use_sample=None): 80 | """ Function to load training data 81 | 82 | If n_use_label or n_use_sample is not None, samples will be 83 | randomly picked to have a balanced number of examples 84 | 85 | Args: 86 | batch_size (int): batch size 87 | n_use_label (int): how many labels are used for training 88 | n_use_sample (int): how many samples are used for training 89 | 90 | Returns: 91 | MNISTData 92 | 93 | """ 94 | data = MNISTData('train', 95 | data_dir=DATA_PATH, 96 | shuffle=True, 97 | pf=preprocess_im, 98 | n_use_label=n_use_label, 99 | n_use_sample=n_use_sample, 100 | batch_dict_name=['im', 'label']) 101 | data.setup(epoch_val=0, batch_size=batch_size) 102 | return data 103 | 104 | def read_valid_data(batch_size): 105 | """ Function to load validation data """ 106 | data = MNISTData('test', 107 | data_dir=DATA_PATH, 108 | shuffle=True, 109 | pf=preprocess_im, 110 | batch_dict_name=['im', 'label']) 111 | data.setup(epoch_val=0, batch_size=batch_size) 112 | return data 113 | 114 | def semisupervised_train(): 115 | """ Function for semisupervised training (Fig 8 in the paper) 116 | 117 | Validation runs after each epoch of training. 118 | The loss of each module is averaged and saved to summaries 119 | every 100 steps. 120 | """ 121 | 122 | FLAGS = get_args() 123 | # load dataset 124 | train_data_unlabel = read_train_data(FLAGS.bsize) 125 | train_data_label = read_train_data(FLAGS.bsize, n_use_sample=1280) 126 | train_data = {'unlabeled': train_data_unlabel, 'labeled': train_data_label} 127 | valid_data = read_valid_data(FLAGS.bsize) 128 | 129 | # create an AAE model for semisupervised training 130 | train_model = AAE( 131 | n_code=FLAGS.ncode, wd=0, n_class=10, add_noise=FLAGS.noise, 132 | enc_weight=FLAGS.encw, gen_weight=FLAGS.genw, dis_weight=FLAGS.disw, 133 | cat_dis_weight=FLAGS.ydisw, cat_gen_weight=FLAGS.ygenw, cls_weight=FLAGS.clsw) 134 | train_model.create_semisupervised_train_model() 135 | 136 | # create a separate AAE model for semisupervised validation 137 | # that shares weights with the training model 138 | cls_valid_model = AAE(n_code=FLAGS.ncode, n_class=10) 139 | cls_valid_model.create_semisupervised_test_model() 140 | 141 | # initialize a trainer for training 142 | trainer = Trainer(train_model, 143 | cls_valid_model=cls_valid_model, 144 | generate_model=None, 145 | train_data=train_data, 146 | init_lr=FLAGS.lr, 147 | save_path=SAVE_PATH) 148 | 149 | sessconfig = tf.ConfigProto() 150 | sessconfig.gpu_options.allow_growth = True 151 | with tf.Session(config=sessconfig) as sess: 152 | writer = tf.summary.FileWriter(SAVE_PATH) 153 | sess.run(tf.global_variables_initializer()) 154 | writer.add_graph(sess.graph) 155 | for epoch_id in range(FLAGS.maxepoch): 156 | trainer.train_semisupervised_epoch( 157 | sess, ae_dropout=FLAGS.dropout, summary_writer=writer) 158 | trainer.valid_semisupervised_epoch( 159 | sess, valid_data, summary_writer=writer) 160 | 161 | def supervised_train(): 162 | """ Function for supervised training (Fig 6 in the paper) 163 | 164 | Validation runs after each epoch of training. 165 | The loss of each module is averaged and saved to summaries 166 | every 100 steps. Every 10 epochs, 10 different styles for the 10 digits 167 | will be saved.
168 | """ 169 | 170 | FLAGS = get_args() 171 | # load dataset 172 | train_data = read_train_data(FLAGS.bsize) 173 | valid_data = read_valid_data(FLAGS.bsize) 174 | 175 | # create an AAE model for supervised training 176 | model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10, 177 | use_supervise=True, add_noise=FLAGS.noise, 178 | enc_weight=FLAGS.encw, gen_weight=FLAGS.genw, 179 | dis_weight=FLAGS.disw) 180 | model.create_train_model() 181 | 182 | # Create an separated AAE model for supervised validation 183 | # shared weights with training model. This model is used to 184 | # generate 10 different style for 10 digits for every 10 epochs. 185 | valid_model = AAE(n_code=FLAGS.ncode, use_supervise=True, n_class=10) 186 | valid_model.create_generate_style_model(n_sample=10) 187 | 188 | # initialize a trainer for training 189 | trainer = Trainer(model, valid_model, train_data, 190 | init_lr=FLAGS.lr, save_path=SAVE_PATH) 191 | # initialize a generator for generating style images 192 | generator = Generator( 193 | generate_model=valid_model, save_path=SAVE_PATH, n_labels=10) 194 | 195 | sessconfig = tf.ConfigProto() 196 | sessconfig.gpu_options.allow_growth = True 197 | with tf.Session(config=sessconfig) as sess: 198 | writer = tf.summary.FileWriter(SAVE_PATH) 199 | saver = tf.train.Saver() 200 | sess.run(tf.global_variables_initializer()) 201 | writer.add_graph(sess.graph) 202 | 203 | for epoch_id in range(FLAGS.maxepoch): 204 | trainer.train_z_gan_epoch( 205 | sess, ae_dropout=FLAGS.dropout, summary_writer=writer) 206 | trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer) 207 | 208 | if epoch_id % 10 == 0: 209 | saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id)) 210 | generator.sample_style(sess, valid_data, plot_size=10, 211 | file_id=epoch_id, n_sample=10) 212 | saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id)) 213 | 214 | def train(): 215 | """ Function for unsupervised training and incorporate 216 | label info in adversarial regularization 217 | (Fig 1 and 3 in the paper) 218 | 219 | Validation will be processed after each epoch of training. 220 | Loss of each modules will be averaged and saved in summaries 221 | every 100 steps. Random samples and learned latent space will 222 | be saved for every 10 epochs. 223 | """ 224 | 225 | FLAGS = get_args() 226 | # image size for visualization. plot_size * plot_size digits will be visualized. 227 | plot_size = 20 228 | 229 | # Use 10000 labels info to train latent space 230 | n_use_label = 10000 231 | 232 | # load data 233 | train_data = read_train_data(FLAGS.bsize, n_use_label=n_use_label) 234 | valid_data = read_valid_data(FLAGS.bsize) 235 | 236 | # create an AAE model for training 237 | model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10, 238 | use_label=FLAGS.label, add_noise=FLAGS.noise, 239 | enc_weight=FLAGS.encw, gen_weight=FLAGS.genw, 240 | dis_weight=FLAGS.disw) 241 | model.create_train_model() 242 | 243 | # Create an separated AAE model for validation shared weights 244 | # with training model. This model is used to 245 | # randomly sample model data every 10 epoches. 246 | valid_model = AAE(n_code=FLAGS.ncode, n_class=10) 247 | valid_model.create_generate_model(b_size=400) 248 | 249 | # initialize a trainer for training 250 | trainer = Trainer(model, valid_model, train_data, 251 | distr_type=FLAGS.dist_type, use_label=FLAGS.label, 252 | init_lr=FLAGS.lr, save_path=SAVE_PATH) 253 | # Initialize a visualizer and a generator to monitor learned 254 | # latent space and data generation. 
255 | # Latent space visualization only for code dim = 2 256 | if FLAGS.ncode == 2: 257 | visualizer = Visualizer(model, save_path=SAVE_PATH) 258 | generator = Generator(generate_model=valid_model, save_path=SAVE_PATH, 259 | distr_type=FLAGS.dist_type, n_labels=10, 260 | use_label=FLAGS.label) 261 | 262 | sessconfig = tf.ConfigProto() 263 | sessconfig.gpu_options.allow_growth = True 264 | with tf.Session(config=sessconfig) as sess: 265 | writer = tf.summary.FileWriter(SAVE_PATH) 266 | saver = tf.train.Saver() 267 | sess.run(tf.global_variables_initializer()) 268 | writer.add_graph(sess.graph) 269 | 270 | for epoch_id in range(FLAGS.maxepoch): 271 | trainer.train_z_gan_epoch(sess, ae_dropout=FLAGS.dropout, summary_writer=writer) 272 | trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer) 273 | 274 | if epoch_id % 10 == 0: 275 | saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id)) 276 | generator.generate_samples(sess, plot_size=plot_size, file_id=epoch_id) 277 | if FLAGS.ncode == 2: 278 | visualizer.viz_2Dlatent_variable(sess, valid_data, file_id=epoch_id) 279 | saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id)) 280 | 281 | def generate(): 282 | """ Function for sampling images from a trained model """ 283 | FLAGS = get_args() 284 | plot_size = 20 285 | 286 | # Create model for sampling 287 | generate_model = AAE(n_code=FLAGS.ncode, n_class=10) 288 | 289 | if FLAGS.supervise: 290 | # create sampling model of Fig 6 in the paper 291 | generate_model.create_generate_style_model(n_sample=10) 292 | else: 293 | # create sampling model of Fig 1 and 3 in the paper 294 | generate_model.create_generate_model(b_size=plot_size*plot_size) 295 | 296 | # initialize the Generator for sampling 297 | generator = Generator(generate_model=generate_model, save_path=SAVE_PATH, 298 | distr_type=FLAGS.dist_type, n_labels=10, use_label=FLAGS.label) 299 | 300 | sessconfig = tf.ConfigProto() 301 | sessconfig.gpu_options.allow_growth = True 302 | with tf.Session(config=sessconfig) as sess: 303 | saver = tf.train.Saver() 304 | sess.run(tf.global_variables_initializer()) 305 | saver.restore(sess, '{}aae-epoch-{}'.format(SAVE_PATH, FLAGS.load)) 306 | if FLAGS.supervise: 307 | generator.sample_style(sess, plot_size=10, n_sample=10) 308 | else: 309 | generator.generate_samples(sess, plot_size=plot_size) 310 | 311 | def visualize(): 312 | """ Function to visualize the latent space of a trained model when ncode = 2 """ 313 | FLAGS = get_args() 314 | if FLAGS.ncode != 2: 315 | raise ValueError('Visualization only for ncode = 2!') 316 | 317 | plot_size = 20 318 | 319 | # read validation set 320 | valid_data = MNISTData('test', 321 | data_dir=DATA_PATH, 322 | shuffle=True, 323 | pf=preprocess_im, 324 | batch_dict_name=['im', 'label']) 325 | valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize) 326 | 327 | # create model for computing the latent z 328 | model = AAE(n_code=FLAGS.ncode, use_label=FLAGS.label, n_class=10) 329 | model.create_train_model() 330 | 331 | # create model for sampling images 332 | valid_model = AAE(n_code=FLAGS.ncode) 333 | valid_model.create_generate_model(b_size=400) 334 | 335 | # initialize Visualizer and Generator 336 | visualizer = Visualizer(model, save_path=SAVE_PATH) 337 | generator = Generator(generate_model=valid_model, save_path=SAVE_PATH, 338 | distr_type=FLAGS.dist_type, n_labels=10, 339 | use_label=FLAGS.label) 340 | 341 | sessconfig = tf.ConfigProto() 342 | sessconfig.gpu_options.allow_growth = True 343 | with tf.Session(config=sessconfig) as sess: 344 | saver =
tf.train.Saver() 345 | sess.run(tf.global_variables_initializer()) 346 | saver.restore(sess, '{}aae-epoch-{}'.format(SAVE_PATH, FLAGS.load)) 347 | # visualize the learned latent space 348 | visualizer.viz_2Dlatent_variable(sess, valid_data) 349 | # visualize the learned manifold 350 | generator.generate_samples(sess, plot_size=plot_size, manifold=True) 351 | 352 | if __name__ == '__main__': 353 | FLAGS = get_args() 354 | 355 | if FLAGS.train: 356 | train() 357 | elif FLAGS.train_supervised: 358 | supervised_train() 359 | elif FLAGS.train_semisupervised: 360 | semisupervised_train() 361 | elif FLAGS.generate: 362 | generate() 363 | elif FLAGS.viz: 364 | visualize() 365 | -------------------------------------------------------------------------------- /src/helper/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: trainer.py 4 | # Author: Qian Ge 5 | 6 | import os 7 | # import scipy.misc 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | import src.utils.viz as viz 12 | import src.models.distribution as distribution 13 | 14 | 15 | def display(global_step, 16 | step, 17 | scaler_sum_list, 18 | name_list, 19 | collection, 20 | summary_val=None, 21 | summary_writer=None, 22 | ): 23 | print('[step: {}]'.format(global_step), end='') 24 | for val, name in zip(scaler_sum_list, name_list): 25 | print(' {}: {:.4f}'.format(name, val * 1. / step), end='') 26 | print('') 27 | if summary_writer is not None: 28 | s = tf.Summary() 29 | for val, name in zip(scaler_sum_list, name_list): 30 | s.value.add(tag='{}/{}'.format(collection, name), 31 | simple_value=val * 1. / step) 32 | summary_writer.add_summary(s, global_step) 33 | if summary_val is not None: 34 | summary_writer.add_summary(summary_val, global_step) 35 | 36 | class Trainer(object): 37 | def __init__(self, train_model, generate_model, train_data, cls_valid_model=None, distr_type='gaussian', 38 | use_label=False, init_lr=1e-3, save_path=None): 39 | 40 | self._save_path = save_path 41 | 42 | self._t_model = train_model 43 | 44 | self._train_data = train_data 45 | self._lr = init_lr 46 | self._use_label = use_label 47 | self._dist = distr_type 48 | 49 | self._train_op = train_model.get_reconstruction_train_op() 50 | self._loss_op = train_model.get_reconstruction_loss() 51 | self._train_summary_op = train_model.get_train_summary() 52 | self._valid_summary_op = train_model.get_valid_summary() 53 | 54 | try: 55 | self._train_d_op = train_model.get_latent_discrimator_train_op() 56 | self._train_g_op = train_model.get_latent_generator_train_op() 57 | self._d_loss_op = train_model.latent_d_loss 58 | self._g_loss_op = train_model.latent_g_loss 59 | except (AttributeError, KeyError): 60 | pass 61 | 62 | try: 63 | self._train_cat_d_op = train_model.get_cat_discrimator_train_op() 64 | self._train_cat_g_op = train_model.get_cat_generator_train_op() 65 | self._cat_d_loss_op = train_model.cat_d_loss 66 | self._cat_g_loss_op = train_model.cat_g_loss 67 | except (AttributeError, KeyError): 68 | pass 69 | 70 | try: 71 | self._cls_train_op = train_model.get_cls_train_op() 72 | self._cls_loss_op = train_model.get_cls_loss() 73 | self._cls_accuracy_op = train_model.get_cls_accuracy() 74 | except (AttributeError, KeyError): 75 | pass 76 | 77 | if generate_model is not None: 78 | self._g_model = generate_model 79 | self._generate_op = generate_model.layers['generate'] 80 | self._generate_summary_op = generate_model.get_generate_summary() 81 | 82 | if
cls_valid_model is not None: 83 | self._cls_v_model = cls_valid_model 84 | self._cls_valid_loss_op = cls_valid_model.get_cls_loss() 85 | self._cls_v_accuracy_op = cls_valid_model.get_cls_accuracy() 86 | 87 | self.global_step = 0 88 | self.epoch_id = 0 89 | 90 | def valid_semisupervised_epoch(self, sess, dataflow, summary_writer=None): 91 | dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size) 92 | display_name_list = ['cls_loss', 'cls_accuracy'] 93 | cur_summary = None 94 | step = 0 95 | cls_loss_sum = 0 96 | cls_accuracy_sum = 0 97 | while dataflow.epochs_completed < 1: 98 | step += 1 99 | batch_data = dataflow.next_batch_dict() 100 | im = batch_data['im'] 101 | label = batch_data['label'] 102 | 103 | cls_loss, cls_accuracy = sess.run( 104 | [self._cls_valid_loss_op, self._cls_v_accuracy_op], 105 | feed_dict={self._cls_v_model.image: im, 106 | self._cls_v_model.label: label}) 107 | cls_loss_sum += cls_loss 108 | cls_accuracy_sum += cls_accuracy 109 | 110 | print('[Valid]: ', end='') 111 | display(self.global_step, 112 | step, 113 | [cls_loss_sum, cls_accuracy_sum], 114 | display_name_list, 115 | 'valid', 116 | summary_val=cur_summary, 117 | summary_writer=summary_writer) 118 | 119 | 120 | def train_semisupervised_epoch(self, sess, ae_dropout=1.0, summary_writer=None): 121 | label_data = self._train_data['labeled'] 122 | unlabel_data = self._train_data['unlabeled'] 123 | display_name_list = ['loss', 'z_d_loss', 'z_g_loss', 'y_d_loss', 'y_g_loss', 124 | 'cls_loss', 'cls_accuracy'] 125 | cur_summary = None 126 | cur_epoch = unlabel_data.epochs_completed 127 | 128 | self.epoch_id += 1 129 | 130 | if self.epoch_id == 150: 131 | self._lr = self._lr / 10 132 | if self.epoch_id == 200: 133 | self._lr = self._lr / 10 134 | 135 | step = 0 136 | loss_sum = 0 137 | z_d_loss_sum = 0 138 | z_g_loss_sum = 0 139 | y_d_loss_sum = 0 140 | y_g_loss_sum = 0 141 | cls_loss_sum = 0 142 | cls_accuracy_sum = 0 143 | while cur_epoch == unlabel_data.epochs_completed: 144 | self.global_step += 1 145 | step += 1 146 | 147 | batch_data = unlabel_data.next_batch_dict() 148 | im = batch_data['im'] 149 | label = batch_data['label'] 150 | 151 | z_real_sample = distribution.diagonal_gaussian( 152 | len(im), self._t_model.n_code, mean=0, var=1.0) 153 | 154 | y_real_sample = np.random.choice(self._t_model.n_class, len(im)) 155 | # a = np.array([1, 0, len(im)]) 156 | # b = np.zeros((len(im), self._t_model.n_class)) 157 | # b[np.arange(len(im)), y_real_sample] = 1 158 | # y_real_sample = b 159 | # print(y_real_sample) 160 | 161 | # train autoencoder 162 | _, loss, cur_summary = sess.run( 163 | [self._train_op, self._loss_op, self._train_summary_op], 164 | feed_dict={self._t_model.image: im, 165 | self._t_model.lr: self._lr, 166 | self._t_model.keep_prob: ae_dropout, 167 | self._t_model.label: label, 168 | self._t_model.real_distribution: z_real_sample, 169 | self._t_model.real_y: y_real_sample}) 170 | 171 | # z discriminator 172 | _, z_d_loss = sess.run( 173 | [self._train_d_op, self._d_loss_op], 174 | feed_dict={self._t_model.image: im, 175 | # self._t_model.label: label, 176 | self._t_model.lr: self._lr, 177 | self._t_model.keep_prob: 1., 178 | self._t_model.real_distribution: z_real_sample}) 179 | 180 | # z generator 181 | _, z_g_loss = sess.run( 182 | [self._train_g_op, self._g_loss_op], 183 | feed_dict={self._t_model.image: im, 184 | # self._t_model.label: label, 185 | self._t_model.lr: self._lr, 186 | self._t_model.keep_prob: 1.}) 187 | 188 | # y discriminator 189 | _, y_d_loss = sess.run( 190 | 
[self._train_cat_d_op, self._cat_d_loss_op], 191 | feed_dict={self._t_model.image: im, 192 | # self._t_model.label: label, 193 | self._t_model.lr: self._lr, 194 | self._t_model.keep_prob: 1., 195 | self._t_model.real_y: y_real_sample}) 196 | 197 | # y generator 198 | _, y_g_loss = sess.run( 199 | [self._train_cat_g_op, self._cat_g_loss_op], 200 | feed_dict={self._t_model.image: im, 201 | # self._t_model.label: label, 202 | self._t_model.lr: self._lr, 203 | self._t_model.keep_prob: 1.}) 204 | 205 | batch_data = label_data.next_batch_dict() 206 | im = batch_data['im'] 207 | label = batch_data['label'] 208 | # semisupervise 209 | if self.global_step % 10 == 0: 210 | _, cls_loss, cls_accuracy = sess.run( 211 | [self._cls_train_op, self._cls_loss_op, self._cls_accuracy_op], 212 | feed_dict={self._t_model.image: im, 213 | self._t_model.label: label, 214 | self._t_model.lr: self._lr, 215 | self._t_model.keep_prob: 1.}) 216 | cls_loss_sum += cls_loss 217 | cls_accuracy_sum += cls_accuracy 218 | 219 | 220 | loss_sum += loss 221 | z_d_loss_sum += z_d_loss 222 | z_g_loss_sum += z_g_loss 223 | y_d_loss_sum += y_d_loss 224 | y_g_loss_sum += y_g_loss 225 | 226 | 227 | if step % 100 == 0: 228 | display(self.global_step, 229 | step, 230 | [loss_sum, z_d_loss_sum, z_g_loss_sum, y_d_loss_sum, y_g_loss_sum, 231 | cls_loss_sum * 10, cls_accuracy_sum * 10], 232 | display_name_list, 233 | 'train', 234 | summary_val=cur_summary, 235 | summary_writer=summary_writer) 236 | 237 | print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr)) 238 | display(self.global_step, 239 | step, 240 | [loss_sum, z_d_loss_sum, z_g_loss_sum, y_d_loss_sum, y_g_loss_sum], 241 | display_name_list, 242 | 'train', 243 | summary_val=cur_summary, 244 | summary_writer=summary_writer) 245 | 246 | 247 | def train_z_gan_epoch(self, sess, ae_dropout=1.0, summary_writer=None): 248 | self._t_model.set_is_training(True) 249 | display_name_list = ['loss', 'd_loss', 'g_loss'] 250 | cur_summary = None 251 | # if self.epoch_id == 50: 252 | # self._lr = self._lr / 10 253 | # if self.epoch_id == 200: 254 | # self._lr = self._lr / 10 255 | if self.epoch_id == 100: 256 | self._lr = self._lr / 10 257 | if self.epoch_id == 300: 258 | self._lr = self._lr / 10 259 | 260 | cur_epoch = self._train_data.epochs_completed 261 | 262 | step = 0 263 | loss_sum = 0 264 | d_loss_sum = 0 265 | g_loss_sum = 0 266 | self.epoch_id += 1 267 | while cur_epoch == self._train_data.epochs_completed: 268 | self.global_step += 1 269 | step += 1 270 | 271 | # batch_data = self._train_data.next_batch_dict() 272 | # im = batch_data['im'] 273 | # label = batch_data['label'] 274 | 275 | # _, d_loss = sess.run( 276 | # [self._train_d_op, self._d_loss_op], 277 | # feed_dict={self._t_model.image: im, 278 | # self._t_model.lr: self._lr, 279 | # self._t_model.keep_prob: 1.}) 280 | 281 | batch_data = self._train_data.next_batch_dict() 282 | im = batch_data['im'] 283 | label = batch_data['label'] 284 | 285 | if self._use_label: 286 | label_indices = label 287 | else: 288 | label_indices = None 289 | 290 | if self._dist == 'gmm': 291 | real_sample = distribution.gaussian_mixture( 292 | len(im), n_dim=self._t_model.n_code, n_labels=10, 293 | x_var=0.5, y_var=0.1, label_indices=label_indices) 294 | else: 295 | real_sample = distribution.diagonal_gaussian( 296 | len(im), self._t_model.n_code, mean=0, var=1.0) 297 | 298 | # train autoencoder 299 | _, loss, cur_summary = sess.run( 300 | [self._train_op, self._loss_op, self._train_summary_op], 301 | feed_dict={self._t_model.image: im, 302 | 
self._t_model.lr: self._lr, 303 | self._t_model.keep_prob: ae_dropout, 304 | self._t_model.label: label, 305 | self._t_model.real_distribution: real_sample}) 306 | 307 | # train discriminator 308 | 309 | _, d_loss = sess.run( 310 | [self._train_d_op, self._d_loss_op], 311 | feed_dict={self._t_model.image: im, 312 | self._t_model.label: label, 313 | self._t_model.lr: self._lr, 314 | self._t_model.keep_prob: 1., 315 | self._t_model.real_distribution: real_sample}) 316 | 317 | # train generator 318 | _, g_loss = sess.run( 319 | [self._train_g_op, self._g_loss_op], 320 | feed_dict={self._t_model.image: im, 321 | self._t_model.label: label, 322 | self._t_model.lr: self._lr, 323 | self._t_model.keep_prob: 1.}) 324 | 325 | # batch_data = self._train_data.next_batch_dict() 326 | # im = batch_data['im'] 327 | # label = batch_data['label'] 328 | loss_sum += loss 329 | d_loss_sum += d_loss 330 | g_loss_sum += g_loss 331 | 332 | if step % 100 == 0: 333 | display(self.global_step, 334 | step, 335 | [loss_sum, d_loss_sum, g_loss_sum], 336 | display_name_list, 337 | 'train', 338 | summary_val=cur_summary, 339 | summary_writer=summary_writer) 340 | 341 | print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr)) 342 | display(self.global_step, 343 | step, 344 | [loss_sum, d_loss_sum, g_loss_sum], 345 | display_name_list, 346 | 'train', 347 | summary_val=cur_summary, 348 | summary_writer=summary_writer) 349 | 350 | def train_epoch(self, sess, summary_writer=None): 351 | self._t_model.set_is_training(True) 352 | display_name_list = ['loss'] 353 | cur_summary = None 354 | 355 | cur_epoch = self._train_data.epochs_completed 356 | 357 | step = 0 358 | loss_sum = 0 359 | self.epoch_id += 1 360 | while cur_epoch == self._train_data.epochs_completed: 361 | self.global_step += 1 362 | step += 1 363 | 364 | batch_data = self._train_data.next_batch_dict() 365 | im = batch_data['im'] 366 | label = batch_data['label'] 367 | _, loss, cur_summary = sess.run( 368 | [self._train_op, self._loss_op, self._train_summary_op], 369 | feed_dict={self._t_model.image: im, 370 | self._t_model.lr: self._lr, 371 | self._t_model.keep_prob: 0.9}) 372 | 373 | loss_sum += loss 374 | 375 | if step % 100 == 0: 376 | display(self.global_step, 377 | step, 378 | [loss_sum], 379 | display_name_list, 380 | 'train', 381 | summary_val=cur_summary, 382 | summary_writer=summary_writer) 383 | 384 | print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr)) 385 | display(self.global_step, 386 | step, 387 | [loss_sum], 388 | display_name_list, 389 | 'train', 390 | summary_val=cur_summary, 391 | summary_writer=summary_writer) 392 | 393 | def valid_epoch(self, sess, dataflow=None, moniter_generation=False, summary_writer=None): 394 | # self._g_model.set_is_training(True) 395 | # display_name_list = ['loss'] 396 | # cur_summary = None 397 | 398 | dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size) 399 | display_name_list = ['loss'] 400 | 401 | step = 0 402 | loss_sum = 0 403 | while dataflow.epochs_completed == 0: 404 | step += 1 405 | 406 | batch_data = dataflow.next_batch_dict() 407 | im = batch_data['im'] 408 | label = batch_data['label'] 409 | loss, valid_summary = sess.run( 410 | [self._loss_op, self._valid_summary_op], 411 | feed_dict={self._t_model.encoder_in: im, 412 | self._t_model.image: im, 413 | self._t_model.keep_prob: 1.0, 414 | self._t_model.label: label, 415 | }) 416 | loss_sum += loss 417 | 418 | print('[Valid]: ', end='') 419 | display(self.global_step, 420 | step, 421 | [loss_sum], 422 | display_name_list, 423 | 
'valid', 424 | summary_val=None, 425 | summary_writer=summary_writer) 426 | dataflow.setup(epoch_val=0, batch_size=dataflow.batch_size) 427 | 428 | gen_im = sess.run(self._generate_op) 429 | if moniter_generation and self._save_path: 430 | im_save_path = os.path.join(self._save_path, 431 | 'generate_step_{}.png'.format(self.global_step)) 432 | viz.viz_batch_im(batch_im=gen_im, grid_size=[10, 10], 433 | save_path=im_save_path, gap=0, gap_color=0, 434 | shuffle=False) 435 | if summary_writer: 436 | cur_summary = sess.run(self._generate_summary_op) 437 | summary_writer.add_summary(cur_summary, self.global_step) 438 | summary_writer.add_summary(valid_summary, self.global_step) 439 | -------------------------------------------------------------------------------- /src/models/aae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: aae.py 4 | # Author: Qian Ge 5 | 6 | import tensorflow as tf 7 | from src.models.base import BaseModel 8 | import src.models.layers as L 9 | import src.models.modules as modules 10 | import src.models.ops as ops 11 | 12 | INIT_W = tf.contrib.layers.variance_scaling_initializer() 13 | 14 | class AAE(BaseModel): 15 | """ model of Adversarial Autoencoders """ 16 | 17 | def __init__(self, im_size=[28, 28], n_channel=1, n_class=None, n_code=1000, 18 | use_label=False, use_supervise=False, add_noise=False, wd=0, 19 | enc_weight=1., gen_weight=1., dis_weight=1., 20 | cat_dis_weight=1., cat_gen_weight=1., cls_weight=1.): 21 | """ 22 | Args: 23 | im_size (int or list of length 2): size of input image 24 | n_channel (int): number of input image channels (1 or 3) 25 | n_class (int): number of classes 26 | n_code (int): dimension of code 27 | use_label (bool): whether to incorporate label information 28 | in the adversarial regularization or not 29 | use_supervise (bool): whether to use supervised training or not 30 | add_noise (bool): whether to add noise to the encoder input or not 31 | wd (float): weight decay 32 | enc_weight (float): weight of autoencoder loss 33 | gen_weight (float): weight of latent z generator loss 34 | dis_weight (float): weight of latent z discriminator loss 35 | cat_gen_weight (float): weight of label y generator loss 36 | cat_dis_weight (float): weight of label y discriminator loss 37 | cls_weight (float): weight of classification loss 38 | """ 39 | self._n_channel = n_channel 40 | self._wd = wd 41 | self.n_code = n_code 42 | self._im_size = im_size 43 | if use_supervise: 44 | use_label = False 45 | self._flag_label = use_label 46 | self._flag_supervise = use_supervise 47 | self._flag_noise = add_noise 48 | self.n_class = n_class 49 | self._enc_w = enc_weight 50 | self._gen_w = gen_weight 51 | self._dis_w = dis_weight 52 | self._cat_dis_w = cat_dis_weight 53 | self._cat_gen_w = cat_gen_weight 54 | self._cls_w = cls_weight 55 | self.layers = {} 56 | 57 | def _create_train_input(self): 58 | """ create input for training model in fig 1, 3, 6 and 8 in the paper """ 59 | self.image = tf.placeholder( 60 | tf.float32, name='image', 61 | shape=[None, self._im_size[0], self._im_size[1], self._n_channel]) 62 | self.label = tf.placeholder( 63 | tf.int64, name='label', shape=[None]) 64 | self.real_distribution = tf.placeholder( 65 | tf.float32, name='real_distribution', shape=[None, self.n_code]) 66 | self.real_y = tf.placeholder( 67 | tf.int64, name='real_y', shape=[None]) 68 | self.lr = tf.placeholder(tf.float32, name='lr') 69 | self.keep_prob = tf.placeholder(tf.float32,
name='keep_prob') 70 | 71 | def create_semisupervised_train_model(self): 72 | """ create training model in fig 8 in the paper """ 73 | self.set_is_training(True) 74 | self._create_train_input() 75 | with tf.variable_scope('AE', reuse=tf.AUTO_REUSE): 76 | encoder_in = self.image 77 | if self._flag_noise: 78 | # add gaussian noise to encoder input 79 | encoder_in += tf.random_normal( 80 | tf.shape(encoder_in), mean=0.0, stddev=0.6, dtype=tf.float32) 81 | self.encoder_in = encoder_in 82 | self.layers['encoder_out'] = self.encoder(self.encoder_in) 83 | # continuous latent variable 84 | self.layers['z'], self.layers['z_mu'], self.layers['z_std'], self.layers['z_log_std'] =\ 85 | self.sample_latent(self.layers['encoder_out']) 86 | # discrete class variable 87 | self.layers['cls_logits'] = self.cls_layer(self.layers['encoder_out']) 88 | 89 | self.layers['y'] = tf.argmax(self.layers['cls_logits'], axis=-1, 90 | name='label_predict') 91 | # one hot label is approximated by output of softmax for back-prop 92 | self.layers['one_hot_y_approx'] = tf.nn.softmax(self.layers['cls_logits'], axis=-1) 93 | 94 | decoder_in = tf.concat((self.layers['z'], self.layers['one_hot_y_approx']), axis=-1) 95 | self.layers['decoder_out'] = self.decoder(decoder_in) 96 | self.layers['sample_im'] = (self.layers['decoder_out'] + 1. ) / 2. 97 | 98 | with tf.variable_scope('regularization_z'): 99 | fake_in = self.layers['z'] 100 | real_in = self.real_distribution 101 | self.layers['fake_z'] = self.latent_discriminator(fake_in) 102 | self.layers['real_z'] = self.latent_discriminator(real_in) 103 | 104 | with tf.variable_scope('regularization_y'): 105 | fake_in = self.layers['one_hot_y_approx'] 106 | real_in = tf.one_hot(self.real_y, self.n_class) 107 | self.layers['fake_y'] = self.cat_discriminator(fake_in) 108 | self.layers['real_y'] = self.cat_discriminator(real_in) 109 | 110 | def create_train_model(self): 111 | """ create training model in fig 1, 3 and 6 in the paper """ 112 | self.set_is_training(True) 113 | self._create_train_input() 114 | with tf.variable_scope('AE', reuse=tf.AUTO_REUSE): 115 | encoder_in = self.image 116 | if self._flag_noise: 117 | # add gaussian noise to encoder input 118 | encoder_in += tf.random_normal( 119 | tf.shape(encoder_in), mean=0.0, stddev=0.6, dtype=tf.float32) 120 | self.encoder_in = encoder_in 121 | self.layers['encoder_out'] = self.encoder(self.encoder_in) 122 | self.layers['z'], self.layers['z_mu'], self.layers['z_std'], self.layers['z_log_std'] =\ 123 | self.sample_latent(self.layers['encoder_out']) 124 | 125 | self.decoder_in = self.layers['z'] 126 | if self._flag_supervise: 127 | one_hot_label = tf.one_hot(self.label, self.n_class) 128 | self.decoder_in = tf.concat((self.decoder_in, one_hot_label), axis=-1) 129 | self.layers['decoder_out'] = self.decoder(self.decoder_in) 130 | self.layers['sample_im'] = (self.layers['decoder_out'] + 1. ) / 2. 131 | 132 | with tf.variable_scope('regularization_z'): 133 | fake_in = self.layers['z'] 134 | real_in = self.real_distribution 135 | self.layers['fake_z'] = self.latent_discriminator(fake_in) 136 | self.layers['real_z'] = self.latent_discriminator(real_in) 137 | 138 | def _create_generate_input(self): 139 | """ create input for sampling model in fig 1, 3 and 6 in the paper """ 140 | self.z = tf.placeholder( 141 | tf.float32, name='latent_z', 142 | shape=[None, self.n_code]) 143 | self.keep_prob = 1. 
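# dropout is disabled at sampling time, so keep_prob is a plain constant rather than a placeholder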
144 | self.image = tf.placeholder( 145 | tf.float32, name='image', 146 | shape=[None, self._im_size[0], self._im_size[1], self._n_channel]) 147 | 148 | def create_generate_style_model(self, n_sample): 149 | """ create sampling model in fig 6 in the paper """ 150 | self.set_is_training(False) 151 | with tf.variable_scope('AE', reuse=tf.AUTO_REUSE): 152 | self._create_generate_input() 153 | label = [] 154 | for i in range(self.n_class): 155 | label.extend([i for k in range(n_sample)]) 156 | label = tf.convert_to_tensor(label) # [n_class*n_sample] 157 | one_hot_label = tf.one_hot(label, self.n_class) # [n_class*n_sample, n_class] 158 | 159 | z = ops.tf_sample_standard_diag_guassian(n_sample, self.n_code) 160 | z = tf.tile(z, [self.n_class, 1]) # [n_class*n_sample, n_code] 161 | decoder_in = tf.concat((z, one_hot_label), axis=-1) 162 | self.layers['generate'] = (self.decoder(decoder_in) + 1. ) / 2. 163 | 164 | def create_generate_model(self, b_size): 165 | """ create sampling model in fig 1 and 3 in the paper """ 166 | self.set_is_training(False) 167 | with tf.variable_scope('AE', reuse=tf.AUTO_REUSE): 168 | self._create_generate_input() 169 | # if self.z is not fed in, just sample from diagonal Gaussian 170 | self.z = ops.tf_sample_standard_diag_guassian(b_size, self.n_code) 171 | decoder_in = self.z 172 | self.layers['generate'] = (self.decoder(decoder_in) + 1. ) / 2. 173 | 174 | def _create_cls_input(self): 175 | """ create input for testing model in fig 8 in the paper """ 176 | self.keep_prob = 1. 177 | self.label = tf.placeholder(tf.int64, name='label', shape=[None]) 178 | self.image = tf.placeholder( 179 | tf.float32, name='image', 180 | shape=[None, self._im_size[0], self._im_size[1], self._n_channel]) 181 | 182 | def create_semisupervised_test_model(self): 183 | """ create testing model in fig 8 in the paper """ 184 | self.set_is_training(False) 185 | self._create_cls_input() 186 | with tf.variable_scope('AE', reuse=tf.AUTO_REUSE): 187 | encoder_in = self.image 188 | self.encoder_in = encoder_in 189 | self.layers['encoder_out'] = self.encoder(self.encoder_in) 190 | # discrete class variable 191 | self.layers['cls_logits'] = self.cls_layer(self.layers['encoder_out']) 192 | self.layers['y'] = tf.argmax(self.layers['cls_logits'], axis=-1, 193 | name='label_predict') 194 | 195 | def encoder(self, inputs): 196 | with tf.variable_scope('encoder'): 197 | fc_out = modules.encoder_FC( 198 | inputs, self.is_training, keep_prob=self.keep_prob, 199 | wd=self._wd, name='encoder_FC', init_w=INIT_W) 200 | return fc_out 201 | 202 | def decoder(self, inputs): 203 | with tf.variable_scope('decoder'): 204 | fc_out = modules.decoder_FC( 205 | inputs, self.is_training, keep_prob=self.keep_prob, 206 | wd=self._wd, name='decoder_FC', init_w=INIT_W) 207 | out_dim = self._im_size[0] * self._im_size[1] * self._n_channel 208 | decoder_out = L.linear( 209 | out_dim=out_dim, layer_dict=self.layers, 210 | inputs=fc_out, init_w=None, wd=self._wd, name='decoder_linear') 211 | decoder_out = tf.reshape( 212 | decoder_out, (-1, self._im_size[0], self._im_size[1], self._n_channel)) 213 | return tf.tanh(decoder_out) 214 | 215 | def cls_layer(self, encoder_out): 216 | """ estimate digit label for semi-supervised model """ 217 | cls_logits = L.linear( 218 | out_dim=self.n_class, layer_dict=self.layers, 219 | inputs=encoder_out, init_w=INIT_W, wd=self._wd, name='cls_layer') 220 | return cls_logits 221 | 222 | def sample_latent(self, encoder_out): 223 | with tf.variable_scope('sample_latent'): 224 | encoder_out = encoder_out 225 |
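# the two linear heads below predict the mean and (softplus) std of a diagonal
# Gaussian q(z|x); z is then drawn with the reparameterization trick
226 |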
z_mean = L.linear( 227 | out_dim=self.n_code, layer_dict=self.layers, 228 | inputs=encoder_out, init_w=INIT_W, wd=self._wd, name='latent_mean') 229 | z_std = L.linear( 230 | out_dim=self.n_code, layer_dict=self.layers, nl=L.softplus, 231 | inputs=encoder_out, init_w=INIT_W, wd=self._wd, name='latent_std') 232 | z_log_std = tf.log(z_std + 1e-8) 233 | 234 | b_size = tf.shape(encoder_out)[0] 235 | z = ops.tf_sample_diag_guassian(z_mean, z_std, b_size, self.n_code) 236 | return z, z_mean, z_std, z_log_std 237 | 238 | def latent_discriminator(self, inputs): 239 | with tf.variable_scope('latent_discriminator', reuse=tf.AUTO_REUSE): 240 | fc_out = modules.discriminator_FC( 241 | inputs, self.is_training, nl=L.leaky_relu, 242 | wd=self._wd, name='latent_discriminator_FC', init_w=INIT_W) 243 | return fc_out 244 | 245 | def cat_discriminator(self, inputs): 246 | with tf.variable_scope('cat_discriminator', reuse=tf.AUTO_REUSE): 247 | fc_out = modules.discriminator_FC( 248 | inputs, self.is_training, nl=L.leaky_relu, 249 | wd=self._wd, name='cat_discriminator_FC', init_w=INIT_W) 250 | return fc_out 251 | 252 | def get_generate_summary(self): 253 | with tf.name_scope('generate'): 254 | tf.summary.image( 255 | 'image', 256 | tf.cast(self.layers['generate'], tf.float32), 257 | collections=['generate']) 258 | return tf.summary.merge_all(key='generate') 259 | 260 | def get_valid_summary(self): 261 | with tf.name_scope('valid'): 262 | tf.summary.image( 263 | 'encoder input', 264 | tf.cast(self.encoder_in, tf.float32), 265 | collections=['valid']) 266 | tf.summary.image( 267 | 'decoder output', 268 | tf.cast(self.layers['sample_im'], tf.float32), 269 | collections=['valid']) 270 | return tf.summary.merge_all(key='valid') 271 | 272 | def get_train_summary(self): 273 | with tf.name_scope('train'): 274 | tf.summary.image( 275 | 'input image', 276 | tf.cast(self.image, tf.float32), 277 | collections=['train']) 278 | tf.summary.image( 279 | 'encoder input', 280 | tf.cast(self.encoder_in, tf.float32), 281 | collections=['train']) 282 | tf.summary.image( 283 | 'decoder output', 284 | tf.cast(self.layers['sample_im'], tf.float32), 285 | collections=['train']) 286 | 287 | tf.summary.histogram( 288 | name='z real distribution', values=self.real_distribution, 289 | collections=['train']) 290 | tf.summary.histogram( 291 | name='z encoder distribution', values=self.layers['z'], 292 | collections=['train']) 293 | try: 294 | tf.summary.histogram( 295 | name='y encoder distribution', values=self.layers['y'], 296 | collections=['train']) 297 | tf.summary.histogram( 298 | name='y real distribution', values=self.real_y, 299 | collections=['train']) 300 | except KeyError: 301 | pass 302 | 303 | return tf.summary.merge_all(key='train') 304 | 305 | def _get_reconstruction_loss(self): 306 | with tf.name_scope('reconstruction_loss'): 307 | p_hat = self.layers['decoder_out'] 308 | p = self.image 309 | autoencoder_loss = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.square(p - p_hat), axis=[1,2,3])) 310 | 311 | return autoencoder_loss * self._enc_w 312 | 313 | def get_reconstruction_loss(self): 314 | try: 315 | return self._reconstr_loss 316 | except AttributeError: 317 | self._reconstr_loss = self._get_reconstruction_loss() 318 | return self._reconstr_loss 319 | 320 | def get_reconstruction_train_op(self): 321 | with tf.name_scope('reconstruction_train'): 322 | opt = tf.train.AdamOptimizer(self.lr, beta1=0.5) 323 | loss = self.get_reconstruction_loss() 324 | var_list = tf.trainable_variables(scope='AE') 325 | # print(var_list) 326 | grads = 
tf.gradients(loss, var_list) 327 | return opt.apply_gradients(zip(grads, var_list)) 328 | 329 | def get_latent_generator_train_op(self): 330 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/encoder') +\ 331 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/sample_latent') 332 | self.latent_g_loss, train_op = modules.train_generator( 333 | fake_in=self.layers['fake_z'], 334 | loss_weight=self._gen_w, 335 | opt=tf.train.AdamOptimizer(self.lr, beta1=0.5), 336 | var_list=var_list, 337 | name='z_generate_train_op') 338 | return train_op 339 | 340 | def get_cat_generator_train_op(self): 341 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/encoder') +\ 342 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/cls_layer') 343 | self.cat_g_loss, train_op = modules.train_generator( 344 | fake_in=self.layers['fake_y'], 345 | loss_weight=self._cat_gen_w, 346 | opt=tf.train.AdamOptimizer(self.lr, beta1=0.5), 347 | var_list=var_list, 348 | name='y_generate_train_op') 349 | return train_op 350 | 351 | def get_latent_discrimator_train_op(self): 352 | self.latent_d_loss, train_op = modules.train_discrimator( 353 | fake_in=self.layers['fake_z'], 354 | real_in=self.layers['real_z'], 355 | loss_weight=self._dis_w, 356 | opt=tf.train.AdamOptimizer(self.lr, beta1=0.5), 357 | var_list=tf.trainable_variables(scope='regularization_z'), 358 | name='z_discrimator_train_op') 359 | return train_op 360 | 361 | def get_cat_discrimator_train_op(self): 362 | self.cat_d_loss, train_op = modules.train_discrimator( 363 | fake_in=self.layers['fake_y'], 364 | real_in=self.layers['real_y'], 365 | loss_weight=self._cat_dis_w, 366 | opt=tf.train.AdamOptimizer(self.lr, beta1=0.5), 367 | var_list=tf.trainable_variables(scope='regularization_y'), 368 | name='y_discrimator_train_op') 369 | return train_op 370 | 371 | def get_cls_train_op(self): 372 | with tf.name_scope('cls_train_op'): 373 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/encoder') +\ 374 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AE/cls_layer') 375 | loss = self.get_cls_loss() 376 | opt = tf.train.AdamOptimizer(self.lr, beta1=0.5) 377 | grads = tf.gradients(loss, var_list) 378 | return opt.apply_gradients(zip(grads, var_list)) 379 | 380 | def _get_cls_loss(self): 381 | with tf.name_scope('cls_loss'): 382 | logits = self.layers['cls_logits'] 383 | labels = self.label 384 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 385 | labels=labels, 386 | logits=logits, 387 | name='cross_entropy') 388 | cross_entropy = tf.reduce_mean(cross_entropy) 389 | 390 | return cross_entropy * self._cls_w 391 | 392 | def get_cls_loss(self): 393 | try: 394 | return self._cls_loss 395 | except AttributeError: 396 | self._cls_loss = self._get_cls_loss() 397 | return self._cls_loss 398 | 399 | def get_cls_accuracy(self): 400 | with tf.name_scope('cls_accuracy'): 401 | labels = self.label 402 | cls_predict = self.layers['y'] 403 | num_correct = tf.cast(tf.equal(labels, cls_predict), tf.float32) 404 | return tf.reduce_mean(num_correct) 405 | --------------------------------------------------------------------------------
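Note: `modules.train_generator` and `modules.train_discrimator` live in `src/models/modules.py`, which is not included in this listing. For reference, here is a minimal sketch of what such helpers might look like, assuming the standard sigmoid cross-entropy GAN losses and the `(loss, train_op)` return convention used by `aae.py` above; the actual losses in `modules.py` may differ.

```python
import tensorflow as tf

def train_discrimator(fake_in, real_in, loss_weight, opt, var_list, name):
    """Sketch: the discriminator pushes logits on real samples toward 1
    and logits on encoder ("fake") samples toward 0."""
    with tf.name_scope(name):
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_in), logits=real_in))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_in), logits=fake_in))
        d_loss = loss_weight * (d_loss_real + d_loss_fake)
        # only update the discriminator variables passed in var_list
        grads = tf.gradients(d_loss, var_list)
        return d_loss, opt.apply_gradients(zip(grads, var_list))

def train_generator(fake_in, loss_weight, opt, var_list, name):
    """Sketch: the encoder (acting as the generator of z or y) tries to
    make the discriminator assign 'real' logits to its samples."""
    with tf.name_scope(name):
        g_loss = loss_weight * tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(fake_in), logits=fake_in))
        grads = tf.gradients(g_loss, var_list)
        return g_loss, opt.apply_gradients(zip(grads, var_list))
```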