├── .gitignore ├── README.md ├── infogan ├── __init__.py ├── algos │ ├── __init__.py │ └── infogan_trainer.py ├── misc │ ├── __init__.py │ ├── custom_ops.py │ ├── datasets.py │ ├── distributions.py │ └── utils.py └── models │ ├── __init__.py │ └── regularized_gan.py ├── launchers ├── __init__.py └── run_mnist_exp.py ├── requirements.txt └── tests ├── __init__.py └── test_distributions.py /.gitignore: -------------------------------------------------------------------------------- 1 | MNIST 2 | *.pyc 3 | .idea 4 | imgs 5 | logs 6 | ckt 7 | .ipynb_checkpoints 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | # InfoGAN 4 | 5 | Code for reproducing key results in the paper [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657) by Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel. 6 | 7 | ## Dependencies 8 | 9 | This project currently requires the dev version of TensorFlow available on Github: https://github.com/tensorflow/tensorflow. As of the release, the latest commit is [79174a](https://github.com/tensorflow/tensorflow/commit/79174afa30046ecdc437b531812f2cb41a32695e). 10 | 11 | In addition, please `pip install` the following packages: 12 | - `prettytensor` 13 | - `progressbar` 14 | - `python-dateutil` 15 | 16 | ## Running in Docker 17 | 18 | ```bash 19 | $ git clone git@github.com:openai/InfoGAN.git 20 | $ docker run -v $(pwd)/InfoGAN:/InfoGAN -w /InfoGAN -it -p 8888:8888 gcr.io/tensorflow/tensorflow:r0.9rc0-devel 21 | root@X:/InfoGAN# pip install -r requirements.txt 22 | root@X:/InfoGAN# python launchers/run_mnist_exp.py 23 | ``` 24 | 25 | ## Running Experiment 26 | 27 | We provide the source code to run the MNIST example: 28 | 29 | ```bash 30 | PYTHONPATH='.' 
python launchers/run_mnist_exp.py 31 | ``` 32 | 33 | You can launch TensorBoard to view the generated images: 34 | 35 | ```bash 36 | tensorboard --logdir logs/mnist 37 | ``` 38 | -------------------------------------------------------------------------------- /infogan/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /infogan/algos/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /infogan/algos/infogan_trainer.py: -------------------------------------------------------------------------------- 1 | from infogan.models.regularized_gan import RegularizedGAN 2 | import prettytensor as pt 3 | import tensorflow as tf 4 | import numpy as np 5 | from progressbar import ETA, Bar, Percentage, ProgressBar 6 | from infogan.misc.distributions import Bernoulli, Gaussian, Categorical 7 | import sys 8 | 9 | TINY = 1e-8 10 | 11 | 12 | class InfoGANTrainer(object): 13 | def __init__(self, 14 | model, 15 | batch_size, 16 | dataset=None, 17 | exp_name="experiment", 18 | log_dir="logs", 19 | checkpoint_dir="ckt", 20 | max_epoch=100, 21 | updates_per_epoch=100, 22 | snapshot_interval=10000, 23 | info_reg_coeff=1.0, 24 | discriminator_learning_rate=2e-4, 25 | generator_learning_rate=2e-4, 26 | ): 27 | """ 28 | :type model: RegularizedGAN 29 | """ 30 | self.model = model 31 | self.dataset = dataset 32 | self.batch_size = batch_size 33 | self.max_epoch = max_epoch 34 | self.exp_name = exp_name 35 | self.log_dir = log_dir 36 | self.checkpoint_dir = checkpoint_dir 37 | self.snapshot_interval = snapshot_interval 38 | self.updates_per_epoch = updates_per_epoch 39 | self.generator_learning_rate = generator_learning_rate 40 | self.discriminator_learning_rate = discriminator_learning_rate 41 | self.info_reg_coeff = info_reg_coeff 42 | self.discriminator_trainer = None 43 | self.generator_trainer = None 44 | self.input_tensor = None 45 | self.log_vars = [] 46 | 47 | def init_opt(self): 48 | self.input_tensor = input_tensor = tf.placeholder(tf.float32, [self.batch_size, self.dataset.image_dim]) 49 | 50 | with pt.defaults_scope(phase=pt.Phase.train): 51 | z_var = self.model.latent_dist.sample_prior(self.batch_size) 52 | fake_x, _ = self.model.generate(z_var) 53 | real_d, _, _, _ = self.model.discriminate(input_tensor) 54 | fake_d, _, fake_reg_z_dist_info, _ = self.model.discriminate(fake_x) 55 | 56 | reg_z = self.model.reg_z(z_var) 57 | 58 | discriminator_loss = - tf.reduce_mean(tf.log(real_d + TINY) + tf.log(1. - fake_d + TINY)) 59 | generator_loss = - tf.reduce_mean(tf.log(fake_d + TINY)) 60 | 61 | self.log_vars.append(("discriminator_loss", discriminator_loss)) 62 | self.log_vars.append(("generator_loss", generator_loss)) 63 | 64 | mi_est = tf.constant(0.) 65 | cross_ent = tf.constant(0.) 
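# The two blocks below implement the variational lower bound on mutual information from the
# InfoGAN paper: I(c; G(z, c)) >= E_{c ~ P(c), x ~ G(z, c)}[log Q(c | x)] + H(c), where
# Q(c | x) is the auxiliary distribution produced by the discriminator's encoder head.
# `cross_ent` accumulates the estimate of E[-log Q(c | x)], the `*_ent` terms estimate H(c)
# from the prior log-likelihood of the sampled codes, and `mi_est = ent - cross_ent` is the
# resulting bound, which is subtracted (scaled by `info_reg_coeff`) from both losses so that
# generator and discriminator are both trained to maximize it.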
66 | 67 | # compute for discrete and continuous codes separately 68 | # discrete: 69 | if len(self.model.reg_disc_latent_dist.dists) > 0: 70 | disc_reg_z = self.model.disc_reg_z(reg_z) 71 | disc_reg_dist_info = self.model.disc_reg_dist_info(fake_reg_z_dist_info) 72 | disc_log_q_c_given_x = self.model.reg_disc_latent_dist.logli(disc_reg_z, disc_reg_dist_info) 73 | disc_log_q_c = self.model.reg_disc_latent_dist.logli_prior(disc_reg_z) 74 | disc_cross_ent = tf.reduce_mean(-disc_log_q_c_given_x) 75 | disc_ent = tf.reduce_mean(-disc_log_q_c) 76 | disc_mi_est = disc_ent - disc_cross_ent 77 | mi_est += disc_mi_est 78 | cross_ent += disc_cross_ent 79 | self.log_vars.append(("MI_disc", disc_mi_est)) 80 | self.log_vars.append(("CrossEnt_disc", disc_cross_ent)) 81 | discriminator_loss -= self.info_reg_coeff * disc_mi_est 82 | generator_loss -= self.info_reg_coeff * disc_mi_est 83 | 84 | if len(self.model.reg_cont_latent_dist.dists) > 0: 85 | cont_reg_z = self.model.cont_reg_z(reg_z) 86 | cont_reg_dist_info = self.model.cont_reg_dist_info(fake_reg_z_dist_info) 87 | cont_log_q_c_given_x = self.model.reg_cont_latent_dist.logli(cont_reg_z, cont_reg_dist_info) 88 | cont_log_q_c = self.model.reg_cont_latent_dist.logli_prior(cont_reg_z) 89 | cont_cross_ent = tf.reduce_mean(-cont_log_q_c_given_x) 90 | cont_ent = tf.reduce_mean(-cont_log_q_c) 91 | cont_mi_est = cont_ent - cont_cross_ent 92 | mi_est += cont_mi_est 93 | cross_ent += cont_cross_ent 94 | self.log_vars.append(("MI_cont", cont_mi_est)) 95 | self.log_vars.append(("CrossEnt_cont", cont_cross_ent)) 96 | discriminator_loss -= self.info_reg_coeff * cont_mi_est 97 | generator_loss -= self.info_reg_coeff * cont_mi_est 98 | 99 | for idx, dist_info in enumerate(self.model.reg_latent_dist.split_dist_info(fake_reg_z_dist_info)): 100 | if "stddev" in dist_info: 101 | self.log_vars.append(("max_std_%d" % idx, tf.reduce_max(dist_info["stddev"]))) 102 | self.log_vars.append(("min_std_%d" % idx, tf.reduce_min(dist_info["stddev"]))) 103 | 104 | self.log_vars.append(("MI", mi_est)) 105 | self.log_vars.append(("CrossEnt", cross_ent)) 106 | 107 | all_vars = tf.trainable_variables() 108 | d_vars = [var for var in all_vars if var.name.startswith('d_')] 109 | g_vars = [var for var in all_vars if var.name.startswith('g_')] 110 | 111 | self.log_vars.append(("max_real_d", tf.reduce_max(real_d))) 112 | self.log_vars.append(("min_real_d", tf.reduce_min(real_d))) 113 | self.log_vars.append(("max_fake_d", tf.reduce_max(fake_d))) 114 | self.log_vars.append(("min_fake_d", tf.reduce_min(fake_d))) 115 | 116 | discriminator_optimizer = tf.train.AdamOptimizer(self.discriminator_learning_rate, beta1=0.5) 117 | self.discriminator_trainer = pt.apply_optimizer(discriminator_optimizer, losses=[discriminator_loss], 118 | var_list=d_vars) 119 | 120 | generator_optimizer = tf.train.AdamOptimizer(self.generator_learning_rate, beta1=0.5) 121 | self.generator_trainer = pt.apply_optimizer(generator_optimizer, losses=[generator_loss], var_list=g_vars) 122 | 123 | for k, v in self.log_vars: 124 | tf.scalar_summary(k, v) 125 | 126 | with pt.defaults_scope(phase=pt.Phase.test): 127 | with tf.variable_scope("model", reuse=True) as scope: 128 | self.visualize_all_factors() 129 | 130 | def visualize_all_factors(self): 131 | with tf.Session(): 132 | fixed_noncat = np.concatenate([ 133 | np.tile( 134 | self.model.nonreg_latent_dist.sample_prior(10).eval(), 135 | [10, 1] 136 | ), 137 | self.model.nonreg_latent_dist.sample_prior(self.batch_size - 100).eval(), 138 | ], axis=0) 139 | fixed_cat = 
np.concatenate([ 140 | np.tile( 141 | self.model.reg_latent_dist.sample_prior(10).eval(), 142 | [10, 1] 143 | ), 144 | self.model.reg_latent_dist.sample_prior(self.batch_size - 100).eval(), 145 | ], axis=0) 146 | 147 | offset = 0 148 | for dist_idx, dist in enumerate(self.model.reg_latent_dist.dists): 149 | if isinstance(dist, Gaussian): 150 | assert dist.dim == 1, "Only dim=1 is currently supported" 151 | c_vals = [] 152 | for idx in xrange(10): 153 | c_vals.extend([-1.0 + idx * 2.0 / 9] * 10) 154 | c_vals.extend([0.] * (self.batch_size - 100)) 155 | vary_cat = np.asarray(c_vals, dtype=np.float32).reshape((-1, 1)) 156 | cur_cat = np.copy(fixed_cat) 157 | cur_cat[:, offset:offset+1] = vary_cat 158 | offset += 1 159 | elif isinstance(dist, Categorical): 160 | lookup = np.eye(dist.dim, dtype=np.float32) 161 | cat_ids = [] 162 | for idx in xrange(10): 163 | cat_ids.extend([idx] * 10) 164 | cat_ids.extend([0] * (self.batch_size - 100)) 165 | cur_cat = np.copy(fixed_cat) 166 | cur_cat[:, offset:offset+dist.dim] = lookup[cat_ids] 167 | offset += dist.dim 168 | elif isinstance(dist, Bernoulli): 169 | assert dist.dim == 1, "Only dim=1 is currently supported" 170 | lookup = np.eye(dist.dim, dtype=np.float32) 171 | cat_ids = [] 172 | for idx in xrange(10): 173 | cat_ids.extend([int(idx / 5)] * 10) 174 | cat_ids.extend([0] * (self.batch_size - 100)) 175 | cur_cat = np.copy(fixed_cat) 176 | cur_cat[:, offset:offset+dist.dim] = np.expand_dims(np.array(cat_ids), axis=-1) 177 | # import ipdb; ipdb.set_trace() 178 | offset += dist.dim 179 | else: 180 | raise NotImplementedError 181 | z_var = tf.constant(np.concatenate([fixed_noncat, cur_cat], axis=1)) 182 | 183 | _, x_dist_info = self.model.generate(z_var) 184 | 185 | # just take the mean image 186 | if isinstance(self.model.output_dist, Bernoulli): 187 | img_var = x_dist_info["p"] 188 | elif isinstance(self.model.output_dist, Gaussian): 189 | img_var = x_dist_info["mean"] 190 | else: 191 | raise NotImplementedError 192 | img_var = self.dataset.inverse_transform(img_var) 193 | rows = 10 194 | img_var = tf.reshape(img_var, [self.batch_size] + list(self.dataset.image_shape)) 195 | img_var = img_var[:rows * rows, :, :, :] 196 | imgs = tf.reshape(img_var, [rows, rows] + list(self.dataset.image_shape)) 197 | stacked_img = [] 198 | for row in xrange(rows): 199 | row_img = [] 200 | for col in xrange(rows): 201 | row_img.append(imgs[row, col, :, :, :]) 202 | stacked_img.append(tf.concat(1, row_img)) 203 | imgs = tf.concat(0, stacked_img) 204 | imgs = tf.expand_dims(imgs, 0) 205 | tf.image_summary("image_%d_%s" % (dist_idx, dist.__class__.__name__), imgs) 206 | 207 | 208 | def train(self): 209 | 210 | self.init_opt() 211 | 212 | init = tf.initialize_all_variables() 213 | 214 | with tf.Session() as sess: 215 | sess.run(init) 216 | 217 | summary_op = tf.merge_all_summaries() 218 | summary_writer = tf.train.SummaryWriter(self.log_dir, sess.graph) 219 | 220 | saver = tf.train.Saver() 221 | 222 | counter = 0 223 | 224 | log_vars = [x for _, x in self.log_vars] 225 | log_keys = [x for x, _ in self.log_vars] 226 | 227 | for epoch in range(self.max_epoch): 228 | widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()] 229 | pbar = ProgressBar(maxval=self.updates_per_epoch, widgets=widgets) 230 | pbar.start() 231 | 232 | all_log_vals = [] 233 | for i in range(self.updates_per_epoch): 234 | pbar.update(i) 235 | x, _ = self.dataset.train.next_batch(self.batch_size) 236 | feed_dict = {self.input_tensor: x} 237 | log_vals = sess.run([self.discriminator_trainer] + log_vars, 
feed_dict)[1:] 238 | sess.run(self.generator_trainer, feed_dict) 239 | all_log_vals.append(log_vals) 240 | counter += 1 241 | 242 | if counter % self.snapshot_interval == 0: 243 | snapshot_name = "%s_%s" % (self.exp_name, str(counter)) 244 | fn = saver.save(sess, "%s/%s.ckpt" % (self.checkpoint_dir, snapshot_name)) 245 | print("Model saved in file: %s" % fn) 246 | 247 | x, _ = self.dataset.train.next_batch(self.batch_size) 248 | 249 | summary_str = sess.run(summary_op, {self.input_tensor: x}) 250 | summary_writer.add_summary(summary_str, counter) 251 | 252 | avg_log_vals = np.mean(np.array(all_log_vals), axis=0) 253 | log_dict = dict(zip(log_keys, avg_log_vals)) 254 | 255 | log_line = "; ".join("%s: %s" % (str(k), str(v)) for k, v in zip(log_keys, avg_log_vals)) 256 | print("Epoch %d | " % (epoch) + log_line) 257 | sys.stdout.flush() 258 | if np.any(np.isnan(avg_log_vals)): 259 | raise ValueError("NaN detected!") 260 | -------------------------------------------------------------------------------- /infogan/misc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /infogan/misc/custom_ops.py: -------------------------------------------------------------------------------- 1 | import prettytensor as pt 2 | import tensorflow as tf 3 | from prettytensor.pretty_tensor_class import Phase 4 | import numpy as np 5 | 6 | 7 | class conv_batch_norm(pt.VarStoreMethod): 8 | """Code modification of http://stackoverflow.com/a/33950177""" 9 | 10 | def __call__(self, input_layer, epsilon=1e-5, momentum=0.1, name="batch_norm", 11 | in_dim=None, phase=Phase.train): 12 | self.ema = tf.train.ExponentialMovingAverage(decay=0.9) 13 | 14 | shape = input_layer.shape 15 | shp = in_dim or shape[-1] 16 | with tf.variable_scope(name) as scope: 17 | self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02)) 18 | self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.)) 19 | 20 | self.mean, self.variance = tf.nn.moments(input_layer.tensor, [0, 1, 2]) 21 | # sigh...tf's shape system is so.. 
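# tf.nn.moments drops the static shape here, so it is pinned back with set_shape below; the
# ExponentialMovingAverage then tracks running estimates of mean/variance, which the
# Phase.test branch uses in place of the per-batch statistics. (Note that the test branch
# references an undefined name `x`; `input_layer.tensor` appears to be what was intended.)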
22 | self.mean.set_shape((shp,)) 23 | self.variance.set_shape((shp,)) 24 | self.ema_apply_op = self.ema.apply([self.mean, self.variance]) 25 | 26 | if phase == Phase.train: 27 | with tf.control_dependencies([self.ema_apply_op]): 28 | normalized_x = tf.nn.batch_norm_with_global_normalization( 29 | input_layer.tensor, self.mean, self.variance, self.beta, self.gamma, epsilon, 30 | scale_after_normalization=True) 31 | else: 32 | normalized_x = tf.nn.batch_norm_with_global_normalization( 33 | x, self.ema.average(self.mean), self.ema.average(self.variance), self.beta, 34 | self.gamma, epsilon, 35 | scale_after_normalization=True) 36 | return input_layer.with_tensor(normalized_x, parameters=self.vars) 37 | 38 | 39 | pt.Register(assign_defaults=('phase'))(conv_batch_norm) 40 | 41 | 42 | @pt.Register(assign_defaults=('phase')) 43 | class fc_batch_norm(conv_batch_norm): 44 | def __call__(self, input_layer, *args, **kwargs): 45 | ori_shape = input_layer.shape 46 | if ori_shape[0] is None: 47 | ori_shape[0] = -1 48 | new_shape = [ori_shape[0], 1, 1, ori_shape[1]] 49 | x = tf.reshape(input_layer.tensor, new_shape) 50 | normalized_x = super(self.__class__, self).__call__(input_layer.with_tensor(x), *args, **kwargs) # input_layer) 51 | return normalized_x.reshape(ori_shape) 52 | 53 | 54 | def leaky_rectify(x, leakiness=0.01): 55 | assert leakiness <= 1 56 | ret = tf.maximum(x, leakiness * x) 57 | # import ipdb; ipdb.set_trace() 58 | return ret 59 | 60 | 61 | @pt.Register 62 | class custom_conv2d(pt.VarStoreMethod): 63 | def __call__(self, input_layer, output_dim, 64 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, in_dim=None, padding='SAME', 65 | name="conv2d"): 66 | with tf.variable_scope(name): 67 | w = self.variable('w', [k_h, k_w, in_dim or input_layer.shape[-1], output_dim], 68 | init=tf.truncated_normal_initializer(stddev=stddev)) 69 | conv = tf.nn.conv2d(input_layer.tensor, w, strides=[1, d_h, d_w, 1], padding=padding) 70 | 71 | biases = self.variable('biases', [output_dim], init=tf.constant_initializer(0.0)) 72 | # import ipdb; ipdb.set_trace() 73 | return input_layer.with_tensor(tf.nn.bias_add(conv, biases), parameters=self.vars) 74 | 75 | 76 | @pt.Register 77 | class custom_deconv2d(pt.VarStoreMethod): 78 | def __call__(self, input_layer, output_shape, 79 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 80 | name="deconv2d"): 81 | output_shape[0] = input_layer.shape[0] 82 | ts_output_shape = tf.pack(output_shape) 83 | with tf.variable_scope(name): 84 | # filter : [height, width, output_channels, in_channels] 85 | w = self.variable('w', [k_h, k_w, output_shape[-1], input_layer.shape[-1]], 86 | init=tf.random_normal_initializer(stddev=stddev)) 87 | 88 | try: 89 | deconv = tf.nn.conv2d_transpose(input_layer, w, 90 | output_shape=ts_output_shape, 91 | strides=[1, d_h, d_w, 1]) 92 | 93 | # Support for versions of TensorFlow before 0.7.0 94 | except AttributeError: 95 | deconv = tf.nn.deconv2d(input_layer, w, output_shape=ts_output_shape, 96 | strides=[1, d_h, d_w, 1]) 97 | 98 | biases = self.variable('biases', [output_shape[-1]], init=tf.constant_initializer(0.0)) 99 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), [-1] + output_shape[1:]) 100 | 101 | return deconv 102 | 103 | 104 | @pt.Register 105 | class custom_fully_connected(pt.VarStoreMethod): 106 | def __call__(self, input_layer, output_size, scope=None, in_dim=None, stddev=0.02, bias_start=0.0): 107 | shape = input_layer.shape 108 | input_ = input_layer.tensor 109 | try: 110 | if len(shape) == 4: 111 | input_ = tf.reshape(input_, 
tf.pack([tf.shape(input_)[0], np.prod(shape[1:])])) 112 | input_.set_shape([None, np.prod(shape[1:])]) 113 | shape = input_.get_shape().as_list() 114 | 115 | with tf.variable_scope(scope or "Linear"): 116 | matrix = self.variable("Matrix", [in_dim or shape[1], output_size], dt=tf.float32, 117 | init=tf.random_normal_initializer(stddev=stddev)) 118 | bias = self.variable("bias", [output_size], init=tf.constant_initializer(bias_start)) 119 | return input_layer.with_tensor(tf.matmul(input_, matrix) + bias, parameters=self.vars) 120 | except Exception: 121 | import ipdb; ipdb.set_trace() 122 | -------------------------------------------------------------------------------- /infogan/misc/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.examples.tutorials import mnist 3 | import os 4 | import numpy as np 5 | 6 | 7 | class Dataset(object): 8 | def __init__(self, images, labels=None): 9 | self._images = images.reshape(images.shape[0], -1) 10 | self._labels = labels 11 | self._epochs_completed = -1 12 | self._num_examples = images.shape[0] 13 | # shuffle on first run 14 | self._index_in_epoch = self._num_examples 15 | 16 | @property 17 | def images(self): 18 | return self._images 19 | 20 | @property 21 | def labels(self): 22 | return self._labels 23 | 24 | @property 25 | def num_examples(self): 26 | return self._num_examples 27 | 28 | @property 29 | def epochs_completed(self): 30 | return self._epochs_completed 31 | 32 | def next_batch(self, batch_size): 33 | """Return the next `batch_size` examples from this data set.""" 34 | start = self._index_in_epoch 35 | self._index_in_epoch += batch_size 36 | if self._index_in_epoch > self._num_examples: 37 | # Finished epoch 38 | self._epochs_completed += 1 39 | # Shuffle the data 40 | perm = np.arange(self._num_examples) 41 | np.random.shuffle(perm) 42 | self._images = self._images[perm] 43 | if self._labels is not None: 44 | self._labels = self._labels[perm] 45 | # Start next epoch 46 | start = 0 47 | self._index_in_epoch = batch_size 48 | assert batch_size <= self._num_examples 49 | end = self._index_in_epoch 50 | if self._labels is None: 51 | return self._images[start:end], None 52 | else: 53 | return self._images[start:end], self._labels[start:end] 54 | 55 | 56 | class MnistDataset(object): 57 | def __init__(self): 58 | data_directory = "MNIST" 59 | if not os.path.exists(data_directory): 60 | os.makedirs(data_directory) 61 | dataset = mnist.input_data.read_data_sets(data_directory) 62 | self.train = dataset.train 63 | # make sure that each type of digits have exactly 10 samples 64 | sup_images = [] 65 | sup_labels = [] 66 | rnd_state = np.random.get_state() 67 | np.random.seed(0) 68 | for cat in range(10): 69 | ids = np.where(self.train.labels == cat)[0] 70 | np.random.shuffle(ids) 71 | sup_images.extend(self.train.images[ids[:10]]) 72 | sup_labels.extend(self.train.labels[ids[:10]]) 73 | np.random.set_state(rnd_state) 74 | self.supervised_train = Dataset( 75 | np.asarray(sup_images), 76 | np.asarray(sup_labels), 77 | ) 78 | self.test = dataset.test 79 | self.validation = dataset.validation 80 | self.image_dim = 28 * 28 81 | self.image_shape = (28, 28, 1) 82 | 83 | def transform(self, data): 84 | return data 85 | 86 | def inverse_transform(self, data): 87 | return data 88 | -------------------------------------------------------------------------------- /infogan/misc/distributions.py: -------------------------------------------------------------------------------- 1 
| from __future__ import print_function 2 | from __future__ import absolute_import 3 | import itertools 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | TINY = 1e-8 8 | 9 | floatX = np.float32 10 | 11 | 12 | class Distribution(object): 13 | @property 14 | def dist_flat_dim(self): 15 | """ 16 | :rtype: int 17 | """ 18 | raise NotImplementedError 19 | 20 | @property 21 | def dim(self): 22 | """ 23 | :rtype: int 24 | """ 25 | raise NotImplementedError 26 | 27 | @property 28 | def effective_dim(self): 29 | """ 30 | The effective dimension when used for rescaling quantities. This can be different from the 31 | actual dimension when the actual values are using redundant representations (e.g. for categorical 32 | distributions we encode it in onehot representation) 33 | :rtype: int 34 | """ 35 | raise NotImplementedError 36 | 37 | def kl_prior(self, dist_info): 38 | return self.kl(dist_info, self.prior_dist_info(dist_info.values()[0].get_shape()[0])) 39 | 40 | def logli(self, x_var, dist_info): 41 | """ 42 | :param x_var: 43 | :param dist_info: 44 | :return: log likelihood of the data 45 | """ 46 | raise NotImplementedError 47 | 48 | def logli_prior(self, x_var): 49 | return self.logli(x_var, self.prior_dist_info(x_var.get_shape()[0])) 50 | 51 | def nonreparam_logli(self, x_var, dist_info): 52 | """ 53 | :param x_var: 54 | :param dist_info: 55 | :return: the non-reparameterizable part of the log likelihood 56 | """ 57 | raise NotImplementedError 58 | 59 | def activate_dist(self, flat_dist): 60 | """ 61 | :param flat_dist: flattened dist info without applying nonlinearity yet 62 | :return: a dictionary of dist infos 63 | """ 64 | raise NotImplementedError 65 | 66 | @property 67 | def dist_info_keys(self): 68 | """ 69 | :rtype: list[str] 70 | """ 71 | raise NotImplementedError 72 | 73 | def entropy(self, dist_info): 74 | """ 75 | :return: entropy for each minibatch entry 76 | """ 77 | raise NotImplementedError 78 | 79 | def marginal_entropy(self, dist_info): 80 | """ 81 | :return: the entropy of the mixture distribution averaged over all minibatch entries. Will return in the same 82 | shape as calling `:code:Distribution.entropy` 83 | """ 84 | raise NotImplementedError 85 | 86 | def marginal_logli(self, x_var, dist_info): 87 | """ 88 | :return: the log likelihood of the given variable under the mixture distribution averaged over all minibatch 89 | entries. 
90 | """ 91 | raise NotImplementedError 92 | 93 | def sample(self, dist_info): 94 | raise NotImplementedError 95 | 96 | def sample_prior(self, batch_size): 97 | return self.sample(self.prior_dist_info(batch_size)) 98 | 99 | def prior_dist_info(self, batch_size): 100 | """ 101 | :return: a dictionary containing distribution information about the standard prior distribution, the shape 102 | of which is jointly decided by batch_size and self.dim 103 | """ 104 | raise NotImplementedError 105 | 106 | 107 | class Categorical(Distribution): 108 | def __init__(self, dim): 109 | self._dim = dim 110 | 111 | @property 112 | def dim(self): 113 | return self._dim 114 | 115 | @property 116 | def dist_flat_dim(self): 117 | return self.dim 118 | 119 | @property 120 | def effective_dim(self): 121 | return 1 122 | 123 | def logli(self, x_var, dist_info): 124 | prob = dist_info["prob"] 125 | return tf.reduce_sum(tf.log(prob + TINY) * x_var, reduction_indices=1) 126 | 127 | def prior_dist_info(self, batch_size): 128 | prob = tf.ones([batch_size, self.dim]) * floatX(1.0 / self.dim) 129 | return dict(prob=prob) 130 | 131 | def marginal_logli(self, x_var, dist_info): 132 | prob = dist_info["prob"] 133 | avg_prob = tf.tile( 134 | tf.reduce_mean(prob, reduction_indices=0, keep_dims=True), 135 | tf.pack([tf.shape(prob)[0], 1]) 136 | ) 137 | return self.logli(x_var, dict(prob=avg_prob)) 138 | 139 | def nonreparam_logli(self, x_var, dist_info): 140 | return self.logli(x_var, dist_info) 141 | 142 | def kl(self, p, q): 143 | """ 144 | :param p: left dist info 145 | :param q: right dist info 146 | :return: KL(p||q) 147 | """ 148 | p_prob = p["prob"] 149 | q_prob = q["prob"] 150 | return tf.reduce_sum( 151 | p_prob * (tf.log(p_prob + TINY) - tf.log(q_prob + TINY)), 152 | reduction_indices=1 153 | ) 154 | 155 | def sample(self, dist_info): 156 | prob = dist_info["prob"] 157 | ids = tf.multinomial(tf.log(prob + TINY), num_samples=1)[:, 0] 158 | onehot = tf.constant(np.eye(self.dim, dtype=np.float32)) 159 | return tf.nn.embedding_lookup(onehot, ids) 160 | 161 | def activate_dist(self, flat_dist): 162 | return dict(prob=tf.nn.softmax(flat_dist)) 163 | 164 | def entropy(self, dist_info): 165 | prob = dist_info["prob"] 166 | return -tf.reduce_sum(prob * tf.log(prob + TINY), reduction_indices=1) 167 | 168 | def marginal_entropy(self, dist_info): 169 | prob = dist_info["prob"] 170 | avg_prob = tf.tile( 171 | tf.reduce_mean(prob, reduction_indices=0, keep_dims=True), 172 | tf.pack([tf.shape(prob)[0], 1]) 173 | ) 174 | return self.entropy(dict(prob=avg_prob)) 175 | 176 | @property 177 | def dist_info_keys(self): 178 | return ["prob"] 179 | 180 | 181 | class Gaussian(Distribution): 182 | def __init__(self, dim, fix_std=False): 183 | self._dim = dim 184 | self._fix_std = fix_std 185 | 186 | @property 187 | def dim(self): 188 | return self._dim 189 | 190 | @property 191 | def dist_flat_dim(self): 192 | return self._dim * 2 193 | 194 | @property 195 | def effective_dim(self): 196 | return self._dim 197 | 198 | def logli(self, x_var, dist_info): 199 | mean = dist_info["mean"] 200 | stddev = dist_info["stddev"] 201 | epsilon = (x_var - mean) / (stddev + TINY) 202 | return tf.reduce_sum( 203 | - 0.5 * np.log(2 * np.pi) - tf.log(stddev + TINY) - 0.5 * tf.square(epsilon), 204 | reduction_indices=1, 205 | ) 206 | 207 | def prior_dist_info(self, batch_size): 208 | mean = tf.zeros([batch_size, self.dim]) 209 | stddev = tf.ones([batch_size, self.dim]) 210 | return dict(mean=mean, stddev=stddev) 211 | 212 | def nonreparam_logli(self, x_var, 
dist_info): 213 | return tf.zeros_like(x_var[:, 0]) 214 | 215 | def kl(self, p, q): 216 | p_mean = p["mean"] 217 | p_stddev = p["stddev"] 218 | q_mean = q["mean"] 219 | q_stddev = q["stddev"] 220 | # means: (N*D) 221 | # std: (N*D) 222 | # formula: 223 | # { (\mu_1 - \mu_2)^2 + \sigma_1^2 - \sigma_2^2 } / (2\sigma_2^2) + ln(\sigma_2/\sigma_1) 224 | numerator = tf.square(p_mean - q_mean) + tf.square(p_stddev) - tf.square(q_stddev) 225 | denominator = 2. * tf.square(q_stddev) 226 | return tf.reduce_sum( 227 | numerator / (denominator + TINY) + tf.log(q_stddev + TINY) - tf.log(p_stddev + TINY), 228 | reduction_indices=1 229 | ) 230 | 231 | def sample(self, dist_info): 232 | mean = dist_info["mean"] 233 | stddev = dist_info["stddev"] 234 | epsilon = tf.random_normal(tf.shape(mean)) 235 | return mean + epsilon * stddev 236 | 237 | @property 238 | def dist_info_keys(self): 239 | return ["mean", "stddev"] 240 | 241 | def activate_dist(self, flat_dist): 242 | mean = flat_dist[:, :self.dim] 243 | if self._fix_std: 244 | stddev = tf.ones_like(mean) 245 | else: 246 | stddev = tf.sqrt(tf.exp(flat_dist[:, self.dim:])) 247 | return dict(mean=mean, stddev=stddev) 248 | 249 | 250 | class Uniform(Gaussian): 251 | """ 252 | This distribution will sample prior data from a uniform distribution, but 253 | the prior and posterior are still modeled as a Gaussian 254 | """ 255 | 256 | def kl_prior(self): 257 | raise NotImplementedError 258 | 259 | # def prior_dist_info(self, batch_size): 260 | # raise NotImplementedError 261 | 262 | # def logli_prior(self, x_var): 263 | # # 264 | # raise NotImplementedError 265 | 266 | def sample_prior(self, batch_size): 267 | return tf.random_uniform([batch_size, self.dim], minval=-1., maxval=1.) 268 | 269 | 270 | class Bernoulli(Distribution): 271 | def __init__(self, dim): 272 | self._dim = dim 273 | 274 | @property 275 | def dim(self): 276 | return self._dim 277 | 278 | @property 279 | def dist_flat_dim(self): 280 | return self._dim 281 | 282 | @property 283 | def effective_dim(self): 284 | return self._dim 285 | 286 | @property 287 | def dist_info_keys(self): 288 | return ["p"] 289 | 290 | def logli(self, x_var, dist_info): 291 | p = dist_info["p"] 292 | return tf.reduce_sum( 293 | x_var * tf.log(p + TINY) + (1.0 - x_var) * tf.log(1.0 - p + TINY), 294 | reduction_indices=1 295 | ) 296 | 297 | def nonreparam_logli(self, x_var, dist_info): 298 | return self.logli(x_var, dist_info) 299 | 300 | def activate_dist(self, flat_dist): 301 | return dict(p=tf.nn.sigmoid(flat_dist)) 302 | 303 | def sample(self, dist_info): 304 | p = dist_info["p"] 305 | return tf.cast(tf.less(tf.random_uniform(p.get_shape()), p), tf.float32) 306 | 307 | def prior_dist_info(self, batch_size): 308 | return dict(p=0.5 * tf.ones([batch_size, self.dim])) 309 | 310 | class MeanBernoulli(Bernoulli): 311 | """ 312 | Behaves almost the same as the usual Bernoulli distribution, except that when sampling from it, directly 313 | return the mean instead of sampling binary values 314 | """ 315 | 316 | def sample(self, dist_info): 317 | return dist_info["p"] 318 | 319 | def nonreparam_logli(self, x_var, dist_info): 320 | return tf.zeros_like(x_var[:, 0]) 321 | 322 | 323 | # class MeanCenteredUniform(MeanBernoulli): 324 | # """ 325 | # Behaves almost the same as the usual Bernoulli distribution, except that when sampling from it, directly 326 | # return the mean instead of sampling binary values 327 | # """ 328 | 329 | 330 | class Product(Distribution): 331 | def __init__(self, dists): 332 | """ 333 | :type dists: 
list[Distribution] 334 | """ 335 | self._dists = dists 336 | 337 | @property 338 | def dists(self): 339 | return list(self._dists) 340 | 341 | @property 342 | def dim(self): 343 | return sum(x.dim for x in self.dists) 344 | 345 | @property 346 | def effective_dim(self): 347 | return sum(x.effective_dim for x in self.dists) 348 | 349 | @property 350 | def dims(self): 351 | return [x.dim for x in self.dists] 352 | 353 | @property 354 | def dist_flat_dims(self): 355 | return [x.dist_flat_dim for x in self.dists] 356 | 357 | @property 358 | def dist_flat_dim(self): 359 | return sum(x.dist_flat_dim for x in self.dists) 360 | 361 | @property 362 | def dist_info_keys(self): 363 | ret = [] 364 | for idx, dist in enumerate(self.dists): 365 | for k in dist.dist_info_keys: 366 | ret.append("id_%d_%s" % (idx, k)) 367 | return ret 368 | 369 | def split_dist_info(self, dist_info): 370 | ret = [] 371 | for idx, dist in enumerate(self.dists): 372 | cur_dist_info = dict() 373 | for k in dist.dist_info_keys: 374 | cur_dist_info[k] = dist_info["id_%d_%s" % (idx, k)] 375 | ret.append(cur_dist_info) 376 | return ret 377 | 378 | def join_dist_infos(self, dist_infos): 379 | ret = dict() 380 | for idx, dist, dist_info_i in zip(itertools.count(), self.dists, dist_infos): 381 | for k in dist.dist_info_keys: 382 | ret["id_%d_%s" % (idx, k)] = dist_info_i[k] 383 | return ret 384 | 385 | def split_var(self, x): 386 | """ 387 | Split the tensor variable or value into per component. 388 | """ 389 | cum_dims = list(np.cumsum(self.dims)) 390 | out = [] 391 | for slice_from, slice_to, dist in zip([0] + cum_dims, cum_dims, self.dists): 392 | sliced = x[:, slice_from:slice_to] 393 | out.append(sliced) 394 | return out 395 | 396 | def join_vars(self, xs): 397 | """ 398 | Join the per component tensor variables into a whole tensor 399 | """ 400 | return tf.concat(1, xs) 401 | 402 | def split_dist_flat(self, dist_flat): 403 | """ 404 | Split flat dist info into per component 405 | """ 406 | cum_dims = list(np.cumsum(self.dist_flat_dims)) 407 | out = [] 408 | for slice_from, slice_to, dist in zip([0] + cum_dims, cum_dims, self.dists): 409 | sliced = dist_flat[:, slice_from:slice_to] 410 | out.append(sliced) 411 | return out 412 | 413 | def prior_dist_info(self, batch_size): 414 | ret = [] 415 | for dist_i in self.dists: 416 | ret.append(dist_i.prior_dist_info(batch_size)) 417 | return self.join_dist_infos(ret) 418 | 419 | def kl(self, p, q): 420 | ret = tf.constant(0.) 421 | for p_i, q_i, dist_i in zip(self.split_dist_info(p), self.split_dist_info(q), self.dists): 422 | ret += dist_i.kl(p_i, q_i) 423 | return ret 424 | 425 | def activate_dist(self, dist_flat): 426 | ret = dict() 427 | for idx, dist_flat_i, dist_i in zip(itertools.count(), self.split_dist_flat(dist_flat), self.dists): 428 | dist_info_i = dist_i.activate_dist(dist_flat_i) 429 | for k, v in dist_info_i.iteritems(): 430 | ret["id_%d_%s" % (idx, k)] = v 431 | return ret 432 | 433 | def sample(self, dist_info): 434 | ret = [] 435 | for dist_info_i, dist_i in zip(self.split_dist_info(dist_info), self.dists): 436 | ret.append(tf.cast(dist_i.sample(dist_info_i), tf.float32)) 437 | return tf.concat(1, ret) 438 | 439 | def sample_prior(self, batch_size): 440 | ret = [] 441 | for dist_i in self.dists: 442 | ret.append(tf.cast(dist_i.sample_prior(batch_size), tf.float32)) 443 | return tf.concat(1, ret) 444 | 445 | def logli(self, x_var, dist_info): 446 | ret = tf.constant(0.) 
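# Components of a Product are treated as independent, so the joint log-likelihood is the
# sum of each component's log-likelihood evaluated on its own slice of x_var.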
447 | for x_i, dist_info_i, dist_i in zip(self.split_var(x_var), self.split_dist_info(dist_info), self.dists): 448 | ret += dist_i.logli(x_i, dist_info_i) 449 | return ret 450 | 451 | def marginal_logli(self, x_var, dist_info): 452 | ret = tf.constant(0.) 453 | for x_i, dist_info_i, dist_i in zip(self.split_var(x_var), self.split_dist_info(dist_info), self.dists): 454 | ret += dist_i.marginal_logli(x_i, dist_info_i) 455 | return ret 456 | 457 | def entropy(self, dist_info): 458 | ret = tf.constant(0.) 459 | for dist_info_i, dist_i in zip(self.split_dist_info(dist_info), self.dists): 460 | ret += dist_i.entropy(dist_info_i) 461 | return ret 462 | 463 | def marginal_entropy(self, dist_info): 464 | ret = tf.constant(0.) 465 | for dist_info_i, dist_i in zip(self.split_dist_info(dist_info), self.dists): 466 | ret += dist_i.marginal_entropy(dist_info_i) 467 | return ret 468 | 469 | def nonreparam_logli(self, x_var, dist_info): 470 | ret = tf.constant(0.) 471 | for x_i, dist_info_i, dist_i in zip(self.split_var(x_var), self.split_dist_info(dist_info), self.dists): 472 | ret += dist_i.nonreparam_logli(x_i, dist_info_i) 473 | return ret 474 | -------------------------------------------------------------------------------- /infogan/misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | import errno 4 | import os 5 | 6 | 7 | def mkdir_p(path): 8 | try: 9 | os.makedirs(path) 10 | except OSError as exc: # Python >2.5 11 | if exc.errno == errno.EEXIST and os.path.isdir(path): 12 | pass 13 | else: 14 | raise 15 | -------------------------------------------------------------------------------- /infogan/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /infogan/models/regularized_gan.py: -------------------------------------------------------------------------------- 1 | from infogan.misc.distributions import Product, Distribution, Gaussian, Categorical, Bernoulli 2 | import prettytensor as pt 3 | import tensorflow as tf 4 | import infogan.misc.custom_ops 5 | from infogan.misc.custom_ops import leaky_rectify 6 | 7 | 8 | class RegularizedGAN(object): 9 | def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type): 10 | """ 11 | :type output_dist: Distribution 12 | :type latent_spec: list[(Distribution, bool)] 13 | :type batch_size: int 14 | :type network_type: string 15 | """ 16 | self.output_dist = output_dist 17 | self.latent_spec = latent_spec 18 | self.latent_dist = Product([x for x, _ in latent_spec]) 19 | self.reg_latent_dist = Product([x for x, reg in latent_spec if reg]) 20 | self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg]) 21 | self.batch_size = batch_size 22 | self.network_type = network_type 23 | self.image_shape = image_shape 24 | assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists) 25 | 26 | self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)]) 27 | self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))]) 28 | 29 | image_size = image_shape[0] 30 | if network_type == "mnist": 31 | with tf.variable_scope("d_net"): 32 | shared_template = \ 33 | 
(pt.template("input"). 34 | reshape([-1] + list(image_shape)). 35 | custom_conv2d(64, k_h=4, k_w=4). 36 | apply(leaky_rectify). 37 | custom_conv2d(128, k_h=4, k_w=4). 38 | conv_batch_norm(). 39 | apply(leaky_rectify). 40 | custom_fully_connected(1024). 41 | fc_batch_norm(). 42 | apply(leaky_rectify)) 43 | self.discriminator_template = shared_template.custom_fully_connected(1) 44 | self.encoder_template = \ 45 | (shared_template. 46 | custom_fully_connected(128). 47 | fc_batch_norm(). 48 | apply(leaky_rectify). 49 | custom_fully_connected(self.reg_latent_dist.dist_flat_dim)) 50 | 51 | with tf.variable_scope("g_net"): 52 | self.generator_template = \ 53 | (pt.template("input"). 54 | custom_fully_connected(1024). 55 | fc_batch_norm(). 56 | apply(tf.nn.relu). 57 | custom_fully_connected(image_size / 4 * image_size / 4 * 128). 58 | fc_batch_norm(). 59 | apply(tf.nn.relu). 60 | reshape([-1, image_size / 4, image_size / 4, 128]). 61 | custom_deconv2d([0, image_size / 2, image_size / 2, 64], k_h=4, k_w=4). 62 | conv_batch_norm(). 63 | apply(tf.nn.relu). 64 | custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4). 65 | flatten()) 66 | else: 67 | raise NotImplementedError 68 | 69 | def discriminate(self, x_var): 70 | d_out = self.discriminator_template.construct(input=x_var) 71 | d = tf.nn.sigmoid(d_out[:, 0]) 72 | reg_dist_flat = self.encoder_template.construct(input=x_var) 73 | reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat) 74 | return d, self.reg_latent_dist.sample(reg_dist_info), reg_dist_info, reg_dist_flat 75 | 76 | def generate(self, z_var): 77 | x_dist_flat = self.generator_template.construct(input=z_var) 78 | x_dist_info = self.output_dist.activate_dist(x_dist_flat) 79 | return self.output_dist.sample(x_dist_info), x_dist_info 80 | 81 | def disc_reg_z(self, reg_z_var): 82 | ret = [] 83 | for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)): 84 | if isinstance(dist_i, (Categorical, Bernoulli)): 85 | ret.append(z_i) 86 | return self.reg_disc_latent_dist.join_vars(ret) 87 | 88 | def cont_reg_z(self, reg_z_var): 89 | ret = [] 90 | for dist_i, z_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_var(reg_z_var)): 91 | if isinstance(dist_i, Gaussian): 92 | ret.append(z_i) 93 | return self.reg_cont_latent_dist.join_vars(ret) 94 | 95 | def disc_reg_dist_info(self, reg_dist_info): 96 | ret = [] 97 | for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)): 98 | if isinstance(dist_i, (Categorical, Bernoulli)): 99 | ret.append(dist_info_i) 100 | return self.reg_disc_latent_dist.join_dist_infos(ret) 101 | 102 | def cont_reg_dist_info(self, reg_dist_info): 103 | ret = [] 104 | for dist_i, dist_info_i in zip(self.reg_latent_dist.dists, self.reg_latent_dist.split_dist_info(reg_dist_info)): 105 | if isinstance(dist_i, Gaussian): 106 | ret.append(dist_info_i) 107 | return self.reg_cont_latent_dist.join_dist_infos(ret) 108 | 109 | def reg_z(self, z_var): 110 | ret = [] 111 | for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)): 112 | if reg_i: 113 | ret.append(z_i) 114 | return self.reg_latent_dist.join_vars(ret) 115 | 116 | def nonreg_z(self, z_var): 117 | ret = [] 118 | for (_, reg_i), z_i in zip(self.latent_spec, self.latent_dist.split_var(z_var)): 119 | if not reg_i: 120 | ret.append(z_i) 121 | return self.nonreg_latent_dist.join_vars(ret) 122 | 123 | def reg_dist_info(self, dist_info): 124 | ret = [] 125 | for (_, reg_i), dist_info_i in 
zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)): 126 | if reg_i: 127 | ret.append(dist_info_i) 128 | return self.reg_latent_dist.join_dist_infos(ret) 129 | 130 | def nonreg_dist_info(self, dist_info): 131 | ret = [] 132 | for (_, reg_i), dist_info_i in zip(self.latent_spec, self.latent_dist.split_dist_info(dist_info)): 133 | if not reg_i: 134 | ret.append(dist_info_i) 135 | return self.nonreg_latent_dist.join_dist_infos(ret) 136 | 137 | def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var): 138 | reg_z_vars = self.reg_latent_dist.split_var(reg_z_var) 139 | reg_idx = 0 140 | nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var) 141 | nonreg_idx = 0 142 | ret = [] 143 | for idx, (dist_i, reg_i) in enumerate(self.latent_spec): 144 | if reg_i: 145 | ret.append(reg_z_vars[reg_idx]) 146 | reg_idx += 1 147 | else: 148 | ret.append(nonreg_z_vars[nonreg_idx]) 149 | nonreg_idx += 1 150 | return self.latent_dist.join_vars(ret) 151 | 152 | def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info): 153 | reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info) 154 | reg_idx = 0 155 | nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(nonreg_dist_info) 156 | nonreg_idx = 0 157 | ret = [] 158 | for idx, (dist_i, reg_i) in enumerate(self.latent_spec): 159 | if reg_i: 160 | ret.append(reg_dist_infos[reg_idx]) 161 | reg_idx += 1 162 | else: 163 | ret.append(nonreg_dist_infos[nonreg_idx]) 164 | nonreg_idx += 1 165 | return self.latent_dist.join_dist_infos(ret) 166 | -------------------------------------------------------------------------------- /launchers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /launchers/run_mnist_exp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from infogan.misc.distributions import Uniform, Categorical, Gaussian, MeanBernoulli 4 | 5 | import tensorflow as tf 6 | import os 7 | from infogan.misc.datasets import MnistDataset 8 | from infogan.models.regularized_gan import RegularizedGAN 9 | from infogan.algos.infogan_trainer import InfoGANTrainer 10 | from infogan.misc.utils import mkdir_p 11 | import dateutil 12 | import dateutil.tz 13 | import datetime 14 | 15 | if __name__ == "__main__": 16 | 17 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 18 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 19 | 20 | root_log_dir = "logs/mnist" 21 | root_checkpoint_dir = "ckt/mnist" 22 | batch_size = 128 23 | updates_per_epoch = 100 24 | max_epoch = 50 25 | 26 | exp_name = "mnist_%s" % timestamp 27 | 28 | log_dir = os.path.join(root_log_dir, exp_name) 29 | checkpoint_dir = os.path.join(root_checkpoint_dir, exp_name) 30 | 31 | mkdir_p(log_dir) 32 | mkdir_p(checkpoint_dir) 33 | 34 | dataset = MnistDataset() 35 | 36 | latent_spec = [ 37 | (Uniform(62), False), 38 | (Categorical(10), True), 39 | (Uniform(1, fix_std=True), True), 40 | (Uniform(1, fix_std=True), True), 41 | ] 42 | 43 | model = RegularizedGAN( 44 | output_dist=MeanBernoulli(dataset.image_dim), 45 | latent_spec=latent_spec, 46 | batch_size=batch_size, 47 | image_shape=dataset.image_shape, 48 | network_type="mnist", 49 | ) 50 | 51 | algo = InfoGANTrainer( 52 | model=model, 53 | dataset=dataset, 54 | batch_size=batch_size, 55 | 
exp_name=exp_name, 56 | log_dir=log_dir, 57 | checkpoint_dir=checkpoint_dir, 58 | max_epoch=max_epoch, 59 | updates_per_epoch=updates_per_epoch, 60 | info_reg_coeff=1.0, 61 | generator_learning_rate=1e-3, 62 | discriminator_learning_rate=2e-4, 63 | ) 64 | 65 | algo.train() 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | prettytensor 2 | progressbar 3 | python-dateutil 4 | ipdb 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import -------------------------------------------------------------------------------- /tests/test_distributions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | from nose2.tools import such 5 | from misc.distributions import Categorical, Gaussian, Product, Bernoulli 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | sess = tf.Session() 10 | 11 | 12 | def random_softmax(ndim): 13 | x = np.random.uniform(size=(ndim,)) 14 | x = x - np.max(x) 15 | x = np.exp(x) / np.sum(np.exp(x)) 16 | return np.cast['float32'](x) 17 | 18 | 19 | with such.A("Product Distribution") as it: 20 | dist1 = Product([Categorical(5), Categorical(3)]) 21 | dist2 = Product([Gaussian(5), dist1]) 22 | 23 | 24 | @it.should 25 | def test_dist_info_keys(): 26 | it.assertEqual(set(dist1.dist_info_keys), {"id_0_prob", "id_1_prob"}) 27 | it.assertEqual(set(dist2.dist_info_keys), {"id_0_mean", "id_0_stddev", 28 | "id_1_id_0_prob", "id_1_id_1_prob"}) 29 | 30 | 31 | @it.should 32 | def test_kl_sym(): 33 | old_id_0_prob = np.array([random_softmax(5)]) 34 | old_id_1_prob = np.array([random_softmax(3)]) 35 | new_id_0_prob = np.array([random_softmax(5)]) 36 | new_id_1_prob = np.array([random_softmax(3)]) 37 | old_dist_info_vars = dict( 38 | id_0_prob=tf.constant(old_id_0_prob), 39 | id_1_prob=tf.constant(old_id_1_prob) 40 | ) 41 | new_dist_info_vars = dict( 42 | id_0_prob=tf.constant(new_id_0_prob), 43 | id_1_prob=tf.constant(new_id_1_prob) 44 | ) 45 | np.testing.assert_allclose( 46 | dist1.kl(old_dist_info_vars, new_dist_info_vars).eval(session=sess), 47 | Categorical(5).kl(dict(prob=old_id_0_prob), dict(prob=new_id_0_prob)).eval(session=sess) + 48 | Categorical(3).kl(dict(prob=old_id_1_prob), dict(prob=new_id_1_prob)).eval(session=sess) 49 | ) 50 | 51 | it.createTests(globals()) 52 | 53 | with such.A("Categorical") as it: 54 | @it.should 55 | def test_categorical(): 56 | cat = Categorical(3) 57 | new_prob = np.array( 58 | [random_softmax(3), random_softmax(3)], 59 | ) 60 | old_prob = np.array( 61 | [random_softmax(3), random_softmax(3)], 62 | ) 63 | 64 | x = np.array([ 65 | [0, 1, 0], 66 | [0, 0, 1], 67 | ], dtype=np.float32) 68 | 69 | new_prob_sym = tf.constant(new_prob) 70 | old_prob_sym = tf.constant(old_prob) 71 | 72 | x_sym = tf.constant(x) 73 | 74 | new_info_sym = dict(prob=new_prob_sym) 75 | old_info_sym = dict(prob=old_prob_sym) 76 | 77 | np.testing.assert_allclose( 78 | cat.kl(new_info_sym, new_info_sym).eval(session=sess), 79 | np.array([0., 0.]) 80 | ) 81 | np.testing.assert_allclose( 82 | cat.kl(old_info_sym, new_info_sym).eval(session=sess), 83 | np.sum(old_prob * (np.log(old_prob + 1e-8) - np.log(new_prob + 
1e-8)), axis=-1) 84 | ) 85 | np.testing.assert_allclose( 86 | cat.logli(x_sym, old_info_sym).eval(session=sess), 87 | [np.log(old_prob[0][1] + 1e-8), np.log(old_prob[1][2] + 1e-8)], 88 | rtol=1e-5 89 | ) 90 | 91 | it.createTests(globals()) 92 | 93 | with such.A("Bernoulli") as it: 94 | @it.should 95 | def test_bernoulli(): 96 | bernoulli = Bernoulli(3) 97 | 98 | new_p = np.array([[0.5, 0.5, 0.5], [.9, .9, .9]], dtype=np.float32) 99 | old_p = np.array([[.9, .9, .9], [.1, .1, .1]], dtype=np.float32) 100 | 101 | x = np.array([[1, 0, 1], [1, 1, 1]], dtype=np.float32) 102 | 103 | x_sym = tf.constant(x) 104 | new_p_sym = tf.constant(new_p) 105 | old_p_sym = tf.constant(old_p) 106 | 107 | new_info = dict(p=new_p) 108 | old_info = dict(p=old_p) 109 | 110 | new_info_sym = dict(p=new_p_sym) 111 | old_info_sym = dict(p=old_p_sym) 112 | 113 | # np.testing.assert_allclose( 114 | # np.sum(bernoulli.entropy(dist_info=new_info)), 115 | # np.sum(- new_p * np.log(new_p + 1e-8) - (1 - new_p) * np.log(1 - new_p + 1e-8)), 116 | # ) 117 | 118 | # np.testing.assert_allclose( 119 | # np.sum(bernoulli.kl(old_info_sym, new_info_sym).eval()), 120 | # np.sum(old_p * (np.log(old_p + 1e-8) - np.log(new_p + 1e-8)) + (1 - old_p) * (np.log(1 - old_p + 1e-8) - 121 | # np.log(1 - new_p + 1e-8))), 122 | # ) 123 | # np.testing.assert_allclose( 124 | # np.sum(bernoulli.kl(old_info, new_info)), 125 | # np.sum(old_p * (np.log(old_p + 1e-8) - np.log(new_p + 1e-8)) + (1 - old_p) * (np.log(1 - old_p + 1e-8) - 126 | # np.log(1 - new_p + 1e-8))), 127 | # ) 128 | # np.testing.assert_allclose( 129 | # bernoulli.likelihood_ratio_sym(x_sym, old_info_sym, new_info_sym).eval(), 130 | # np.prod((x * new_p + (1 - x) * (1 - new_p)) / (x * old_p + (1 - x) * (1 - old_p) + 1e-8), axis=-1) 131 | # ) 132 | np.testing.assert_allclose( 133 | bernoulli.logli(x_sym, old_info_sym).eval(session=sess), 134 | np.sum(x * np.log(old_p + 1e-8) + (1 - x) * np.log(1 - old_p + 1e-8), axis=-1) 135 | ) 136 | # np.testing.assert_allclose( 137 | # bernoulli.log_likelihood(x, old_info), 138 | # np.sum(x * np.log(old_p + 1e-8) + (1 - x) * np.log(1 - old_p + 1e-8), axis=-1) 139 | # ) 140 | 141 | it.createTests(globals()) 142 | --------------------------------------------------------------------------------
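The suite above does not exercise `Gaussian.logli` directly; a minimal sketch of such a check (not part of the original tests, reusing the same `sess`, imports, and nose2 `such` harness as `tests/test_distributions.py`) might look like:

```python
with such.A("Gaussian") as it:
    @it.should
    def test_gaussian_logli():
        gaussian = Gaussian(2)

        mean = np.array([[0.0, 1.0], [-1.0, 2.0]], dtype=np.float32)
        stddev = np.array([[1.0, 0.5], [2.0, 1.5]], dtype=np.float32)
        x = np.array([[0.5, 0.5], [-2.0, 3.0]], dtype=np.float32)

        x_sym = tf.constant(x)
        info_sym = dict(mean=tf.constant(mean), stddev=tf.constant(stddev))

        # Reference diagonal-Gaussian log-density, matching the TINY smoothing in distributions.py
        expected = np.sum(
            -0.5 * np.log(2 * np.pi)
            - np.log(stddev + 1e-8)
            - 0.5 * np.square((x - mean) / (stddev + 1e-8)),
            axis=-1,
        )

        np.testing.assert_allclose(
            gaussian.logli(x_sym, info_sym).eval(session=sess),
            expected,
            rtol=1e-5,
        )

    it.createTests(globals())
```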