├── .gitignore ├── assets ├── digit_style_0.png ├── digit_style_1.png ├── digit_style_2.png ├── captured_digit_style.png ├── decorrelator-network-01.png ├── digit_style_interpolation_0.png ├── digit_style_interpolation_1.png ├── digit_style_interpolation_2.png └── digit_style_interpolation_3.png ├── LICENSE ├── README.md ├── components.py ├── main_aae_semi.py ├── data.py └── aae_semi.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | 4 | log.txt -------------------------------------------------------------------------------- /assets/digit_style_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_0.png -------------------------------------------------------------------------------- /assets/digit_style_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_1.png -------------------------------------------------------------------------------- /assets/digit_style_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_2.png -------------------------------------------------------------------------------- /assets/captured_digit_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/captured_digit_style.png -------------------------------------------------------------------------------- /assets/decorrelator-network-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/decorrelator-network-01.png -------------------------------------------------------------------------------- /assets/digit_style_interpolation_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_interpolation_0.png -------------------------------------------------------------------------------- /assets/digit_style_interpolation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_interpolation_1.png -------------------------------------------------------------------------------- /assets/digit_style_interpolation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_interpolation_2.png -------------------------------------------------------------------------------- /assets/digit_style_interpolation_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/HEAD/assets/digit_style_interpolation_3.png -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The source code in this repository is made available under the MIT License: 2 | 3 | Copyright 2017, Patrick Gadd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # decorrelated-adversarial-autoencoder 2 | TensorFlow implementation of Adversarial Autoencoders (with an extra option to decorrelate style and class) 3 | 4 | ## How it works 5 | 6 | This is a semi-supervised adversarial autoencoder as described in the original paper by Alireza Makhzani et al. ([arxiv.org/abs/1511.05644](https://arxiv.org/abs/1511.05644)), with the following modification: 7 | 8 | I added the option to de-correlate the style and the class, which allows for clean sampling of the style space. Without it, q(y|X) could be predicted from q(z|X) with 60-80% accuracy, meaning the two were heavily correlated: q(z|X) contained not only information about the style but also about the digit in the image, effectively letting class information bypass q(y|X). With de-correlation, the classifying part of the network only achieves about 10% accuracy when predicting q(y|X) from q(z|X), which for the 10 MNIST classes is chance level, i.e. no remaining correlation. 9 | 10 | ![Network architecture](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/decorrelator-network-01.png) 11 | 12 | (modified version of Figure 8 from the original paper) 13 | 14 | # Examples 15 | 16 | The following examples were generated by a network trained in a semi-supervised fashion on just 10 labelled examples of each of the MNIST classes, together with all of the unlabelled training data. The hyperparameters used for this training are those in the source code. 17 | 18 | ## Capturing the style of real examples 19 | 20 | One can feed existing MNIST examples **X** through the encoder to capture their style q(**z** | **X**) and then generate all the possible digits from 0 to 9 in that same style, as sketched below.
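For concreteness, here is a minimal NumPy sketch of this style-capture step, mirroring what `generate_similar_style()` in `aae_semi.py` does. `encode_z` and `decode` are hypothetical stand-ins for the trained encoder q(**z** | **X**) and decoder, stubbed with the right shapes so the snippet runs on its own:

```python
import numpy as np

n_classes, z_dim, img_dim = 10, 9, 28 * 28

def encode_z(X):
    # Hypothetical stand-in for q(z|X): one style vector per input image.
    return np.random.normal(0.0, 1.0, (X.shape[0], z_dim))

def decode(p_yz):
    # Hypothetical stand-in for the decoder: maps [one-hot class, style] to a flat image.
    return np.zeros((p_yz.shape[0], img_dim))

X = np.random.rand(1, img_dim)                  # one "real" example
z = np.repeat(encode_z(X), n_classes, axis=0)   # reuse its captured style ten times
y = np.eye(n_classes)                           # one-hot classes 0..9
imgs = decode(np.concatenate([y, z], axis=1))   # same style, all ten digits
print(imgs.shape)                               # (10, 784)
```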
21 | 22 | The first column below contains samples from the MNIST dataset, and each row next to these samples shows the network's interpretation of their style applied to the 10 digits: 23 | 24 | ![Capturing of the style of MNIST digits](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/captured_digit_style.png) 25 | 26 | ## Sampling the style space 27 | 28 | Since the style space **z** is an N-dimensional space whose components are independently Gaussian-distributed with mean 0 and variance 1, it can be sampled directly, and the digits from 0 to 9 can be generated in the sampled style: 29 | 30 | ![Randomly styled digits, no. 1](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_0.png) 31 | 32 | ![Randomly styled digits, no. 2](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_1.png) 33 | 34 | ![Randomly styled digits, no. 3](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_2.png) 35 | 36 | ## Style interpolation 37 | 38 | Styles can be interpolated by sampling two points in **z** and interpolating linearly (or otherwise) between them: 39 | 40 | ![Interpolation between random styles, no. 1](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_interpolation_0.png) 41 | 42 | ![Interpolation between random styles, no. 2](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_interpolation_1.png) 43 | 44 | ![Interpolation between random styles, no. 3](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_interpolation_2.png) 45 | 46 | ![Interpolation between random styles, no.
4](https://raw.githubusercontent.com/patrickgadd/decorrelated-adversarial-autoencoder/master/assets/digit_style_interpolation_3.png) -------------------------------------------------------------------------------- /components.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import layers 3 | 4 | def semi_supervised_encoder_convolutional(input_tensor, z_dim, y_dim, batch_size, network_scale=1.0, img_res=28, img_channels=1): 5 | f_multiplier = network_scale 6 | 7 | net = tf.reshape(input_tensor, [-1, img_res, img_res, img_channels]) 8 | 9 | net = layers.conv2d(net, int(16*f_multiplier), 3, stride=2) 10 | net = layers.conv2d(net, int(16*f_multiplier), 3, stride=1) 11 | net = layers.conv2d(net, int(32*f_multiplier), 3, stride=2) 12 | net = layers.conv2d(net, int(32*f_multiplier), 3, stride=1) 13 | net = layers.conv2d(net, int(64*f_multiplier), 3, stride=2) 14 | net = layers.conv2d(net, int(64*f_multiplier), 3, stride=1) 15 | net = layers.conv2d(net, int(128*f_multiplier), 3, stride=2) 16 | 17 | net = tf.reshape(net, [batch_size, -1]) 18 | net = layers.fully_connected(net, 1000) 19 | 20 | y = layers.fully_connected(net, y_dim, activation_fn=None, normalizer_fn=None) 21 | 22 | z = layers.fully_connected(net, z_dim, activation_fn=None) 23 | 24 | return y, z 25 | 26 | def semi_supervised_encoder_fully_connected(input_tensor, z_dim, y_dim, network_scale=1.0): 27 | hidden_size = int(1000 * network_scale) 28 | net = layers.fully_connected(input_tensor, hidden_size) 29 | net = layers.fully_connected(net, hidden_size) 30 | 31 | y = layers.fully_connected(net, y_dim, activation_fn=None, normalizer_fn=None) 32 | 33 | z = layers.fully_connected(net, z_dim, activation_fn=None) 34 | 35 | return y, z 36 | 37 | def semi_supervised_encoder(input_tensor, z_dim, y_dim, batch_size, do_convolutional, network_scale=1.0, img_res=28, img_channels=1): 38 | if do_convolutional: 39 | return semi_supervised_encoder_convolutional(input_tensor, z_dim, y_dim, batch_size, network_scale, img_res, img_channels) 40 | else: 41 | return semi_supervised_encoder_fully_connected(input_tensor, z_dim, y_dim, network_scale) 42 | 43 | def semi_supervised_decoder_convolutional(input_tensor, batch_size, n_dimensions, network_scale=1.0, img_res=28, img_channels=1): 44 | f_multiplier = network_scale 45 | 46 | net = layers.fully_connected(input_tensor, 2*2*int(128*f_multiplier)) 47 | net = tf.reshape(net, [-1, 2, 2, int(128*f_multiplier)]) 48 | 49 | assert(img_res in [28, 32]) 50 | 51 | if img_res==28: 52 | net = layers.conv2d_transpose(net, int(64*f_multiplier), 3, stride=2) 53 | net = layers.conv2d_transpose(net, int(64*f_multiplier), 3, stride=1) 54 | net = layers.conv2d_transpose(net, int(32*f_multiplier), 4, stride=1, padding='VALID') 55 | net = layers.conv2d_transpose(net, int(32*f_multiplier), 4, stride=1) 56 | net = layers.conv2d_transpose(net, int(16*f_multiplier), 3, stride=2) 57 | net = layers.conv2d_transpose(net, int(16*f_multiplier), 3, stride=1) 58 | net = layers.conv2d_transpose(net, int(8*f_multiplier), 3, stride=2) 59 | net = layers.conv2d_transpose(net, int(8*f_multiplier), 3, stride=1) 60 | else: 61 | net = layers.conv2d_transpose(net, int(64*f_multiplier), 3, stride=2) 62 | net = layers.conv2d_transpose(net, int(64*f_multiplier), 3, stride=1) 63 | net = layers.conv2d_transpose(net, int(32*f_multiplier), 3, stride=2) 64 | net = layers.conv2d_transpose(net, int(32*f_multiplier), 3, stride=1) 65 | net = layers.conv2d_transpose(net, 
int(16*f_multiplier), 3, stride=2) 66 | net = layers.conv2d_transpose(net, int(16*f_multiplier), 3, stride=1) 67 | net = layers.conv2d_transpose(net, int(8*f_multiplier), 3, stride=2) 68 | net = layers.conv2d_transpose(net, int(8*f_multiplier), 3, stride=1) 69 | 70 | net = layers.conv2d_transpose(net, img_channels, 5, stride=1, activation_fn=tf.nn.sigmoid) 71 | net = layers.flatten(net) 72 | 73 | return net 74 | 75 | 76 | def semi_supervised_decoder_fully_connected(input_tensor, batch_size, n_dimensions, network_scale=1.0, img_res=28, img_channels=1): 77 | output_size = img_res*img_res*img_channels 78 | n_hid = int(1000*network_scale) 79 | 80 | net = layers.fully_connected(input_tensor, n_hid) 81 | net = layers.fully_connected(net, n_hid) 82 | 83 | net = layers.fully_connected(net, output_size, activation_fn=tf.nn.sigmoid) 84 | 85 | return net 86 | 87 | 88 | def semi_supervised_decoder(input_tensor, batch_size, n_dimensions, do_convolutional, network_scale=1.0, img_res=28, img_channels=1): 89 | if do_convolutional: 90 | return semi_supervised_decoder_convolutional(input_tensor, batch_size, n_dimensions, network_scale, img_res, img_channels) 91 | else: 92 | return semi_supervised_decoder_fully_connected(input_tensor, batch_size, n_dimensions, network_scale, img_res, img_channels) 93 | 94 | def aa_discriminator(input_tensor, batch_size, n_dimensions): 95 | n_hid = 1000 96 | 97 | net = layers.fully_connected(input_tensor, n_hid) 98 | net = layers.fully_connected(net, n_hid) 99 | 100 | return layers.fully_connected(net, 2, activation_fn=None) 101 | 102 | def correlation_classifier(input_tensor, batch_size, n_classes=10): 103 | n_hid = 1000 104 | 105 | net = layers.fully_connected(input_tensor, n_hid) 106 | net = layers.fully_connected(net, n_hid) 107 | net = layers.fully_connected(net, n_classes, activation_fn=None) 108 | 109 | return net -------------------------------------------------------------------------------- /main_aae_semi.py: -------------------------------------------------------------------------------- 1 | '''TensorFlow implementation of https://arxiv.org/abs/1511.05644 (with variations)''' 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import os 6 | import tensorflow as tf 7 | from tensorflow.contrib.layers import batch_norm 8 | 9 | from progressbar import ETA, Bar, Percentage, ProgressBar 10 | 11 | from data import Data 12 | from aae_semi import AAE_Semi 13 | 14 | flags = tf.flags 15 | logging = tf.logging 16 | 17 | flags.DEFINE_integer("batch_size", 100, "batch size") 18 | flags.DEFINE_integer("max_epoch", 150, "max epoch") 19 | flags.DEFINE_integer("updates_per_iteration", 500, "number of updates per iteration") 20 | flags.DEFINE_float("learning_rate", 0.0003, "learning rate") 21 | flags.DEFINE_string("working_directory", "./data", "the directory in which the results will be stored") 22 | flags.DEFINE_integer("z_dim", 9, "dimensionality of the z space") 23 | flags.DEFINE_float("network_scale", 2.0, "scaling the number of neurons/filters in the network") 24 | flags.DEFINE_float("decorrelation_importance", 0.5, "The importance of the de-correlation of q(y|X) and q(z|X)") 25 | flags.DEFINE_integer("cnt_per_class", 10, "Number of labelled examples per class") 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | if __name__ == "__main__": 30 | # is 10 for both MNIST and SVHN 31 | n_classes = 10 32 | 33 | updates_per_epoch = 50000 / FLAGS.batch_size 34 | max_iteration = int(FLAGS.max_epoch * updates_per_epoch / FLAGS.updates_per_iteration ) 35 | 36 | dataset = 
'MNIST' 37 | data = Data(dataset=dataset, n_classes=n_classes, cnt_per_class=FLAGS.cnt_per_class, 38 | working_directory=FLAGS.working_directory, batch_size=FLAGS.batch_size) 39 | if dataset == 'MNIST': 40 | img_res = 28 41 | img_channels = 1 42 | elif dataset == 'SVHN': 43 | img_res = 32 44 | img_channels = 3 45 | else: 46 | assert(False) 47 | 48 | # Important: Even if batch-norm is on, it's not applied for q(y|X) (personal taste) 49 | normalizer_fn = batch_norm # alternatively, use None 50 | 51 | log_path = 'log.txt' 52 | with open(log_path, 'a') as log: 53 | log.write('epoch\ttest acc.\ttrain acc.\ty from z acc.\n') 54 | 55 | model = AAE_Semi(n_classes, FLAGS.z_dim, FLAGS.batch_size, normalizer_fn, 56 | decorr_scale=FLAGS.decorrelation_importance, network_scale=FLAGS.network_scale, 57 | img_res=img_res, img_channels=img_channels) 58 | 59 | learning_rate = FLAGS.learning_rate 60 | 61 | saver = tf.train.Saver() 62 | ckpt = tf.train.get_checkpoint_state('checkpoints/') # get latest checkpoint (if any) 63 | if ckpt and ckpt.model_checkpoint_path: 64 | # if checkpoint exists, restore the parameters and set epoch_n and i_iter 65 | saver.restore(model.sess, ckpt.model_checkpoint_path) 66 | start_iteration_n = int(ckpt.model_checkpoint_path.split('-')[1]) 67 | print('Restored iteration no.: {0}'.format(start_iteration_n)) 68 | else: 69 | # no checkpoint exists. create checkpoints directory if it does not exist. 70 | if not os.path.exists('checkpoints'): 71 | os.makedirs('checkpoints') 72 | if tf.__version__ == '0.10.0': 73 | init = tf.initialize_all_variables() 74 | else: 75 | init = tf.global_variables_initializer() 76 | model.sess.run(init) 77 | start_iteration_n = 0 78 | 79 | 80 | 81 | for iteration_n in range(start_iteration_n, max_iteration+1 ): 82 | epoch = iteration_n * FLAGS.updates_per_iteration / updates_per_epoch 83 | print('Beginning epoch {0}'.format(epoch)) 84 | 85 | if epoch >= 20: 86 | learning_rate = 0.00003 87 | if epoch >= 50: 88 | learning_rate = 0.000003 89 | 90 | reconstruction_loss = 0.0 91 | discriminative_loss = 0.0 92 | generative_loss = 0.0 93 | classification_loss = 0.0 94 | corr_classification_loss = 0.0 95 | decorr_classification_loss = 0.0 96 | 97 | pbar = ProgressBar() 98 | for i in pbar(range(FLAGS.updates_per_iteration)): 99 | img_batch_unlabelled, _ = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='train') 100 | img_batch_labelled, y_batch_labelled = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='train_few') 101 | 102 | 103 | loss_value = model.reconstruction_phase(img_batch_unlabelled, learning_rate) 104 | reconstruction_loss += loss_value 105 | 106 | loss_value = model.discriminator_phase(img_batch_unlabelled, learning_rate) 107 | discriminative_loss += loss_value 108 | 109 | loss_value = model.generator_phase(img_batch_unlabelled, learning_rate) 110 | generative_loss += loss_value 111 | 112 | loss_value = model.supervised_phase(img_batch_labelled, y_batch_labelled, learning_rate) 113 | classification_loss += loss_value 114 | 115 | loss_value = model.correlation_classifier_phase(img_batch_unlabelled, learning_rate) 116 | corr_classification_loss += loss_value 117 | if epoch >= 1: 118 | loss_value = model.decorrelation_phase(img_batch_unlabelled, learning_rate) 119 | decorr_classification_loss += loss_value 120 | 121 | 122 | if int(epoch * 100) % 20 == 0: 123 | reconstruction_loss = reconstruction_loss / (FLAGS.updates_per_iteration * FLAGS.batch_size) 124 | print('Reconstruction loss: {0}'.format(reconstruction_loss)) 125 | 
126 | discriminative_loss = discriminative_loss / (FLAGS.updates_per_iteration * FLAGS.batch_size) 127 | print('Discriminative loss: {0}'.format(discriminative_loss)) 128 | 129 | generative_loss = generative_loss / (FLAGS.updates_per_iteration * FLAGS.batch_size) 130 | print('Generative loss: {0}'.format(generative_loss)) 131 | 132 | classification_loss = classification_loss / (FLAGS.updates_per_iteration * FLAGS.batch_size) 133 | print('Classification loss: {0}'.format(classification_loss)) 134 | 135 | corr_classification_loss = corr_classification_loss / (FLAGS.updates_per_iteration * FLAGS.batch_size) 136 | print('Corr-Classification loss: {0}'.format(corr_classification_loss)) 137 | 138 | # Printing some stats about q(y|X) and q(z|X) 139 | img_batch_unlabelled, _ = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='train') 140 | model.test_print_q_z_given_x(img_batch_unlabelled) 141 | 142 | # Produce imagery 143 | X, y = data.get_first_x_mnist(FLAGS.batch_size, n_classes) 144 | model.generate_similar_style(X, y, FLAGS.batch_size, FLAGS.working_directory, img_res, img_channels, n_classes, FLAGS.z_dim) 145 | model.generate_digits(FLAGS.batch_size, FLAGS.working_directory, img_res, img_channels, n_classes, FLAGS.z_dim) 146 | model.interpolate_digits(FLAGS.batch_size, FLAGS.working_directory, img_res, img_channels, n_classes, FLAGS.z_dim) 147 | 148 | # Logging 149 | if int(epoch * 100) % 50 == 0: 150 | print('Computing the accuracies for train and test') 151 | test_acc_n = 10 152 | test_acc_sum = 0.0 153 | for i in range(test_acc_n): 154 | images, y_ = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='test') 155 | test_acc_sum += model.compute_accuracy(images, y_) 156 | test_acc = 100*test_acc_sum/float(test_acc_n) 157 | print('Avg. acc for {0} test samples: {1:.2f} %'.format(FLAGS.batch_size*test_acc_n, test_acc)) 158 | 159 | train_acc_n = 5 160 | train_acc_sum = 0.0 161 | for i in range(train_acc_n): 162 | images, y_ = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='train_few') 163 | train_acc_sum += model.compute_accuracy(images, y_) 164 | train_acc = 100*train_acc_sum/float(train_acc_n) 165 | print('Avg. acc for {0} training samples: {1:.2f} %'.format(FLAGS.batch_size*test_acc_n, train_acc)) 166 | 167 | zy_acc_n = 5 168 | zy_acc_sum = 0.0 169 | for i in range(zy_acc_n): 170 | images, y_ = data.get_random_minibatch(FLAGS.batch_size, n_classes, purpose='train') 171 | zy_acc_sum += model.compute_accuracy_2(images) 172 | zy_acc = 100*zy_acc_sum/float(zy_acc_n) 173 | print('Avg. acc for {0} classification samples: {1:.2f} %'.format(FLAGS.batch_size*test_acc_n, zy_acc)) 174 | 175 | with open(log_path, 'a') as log: 176 | log.write('{0}\t{1}\t{2}\t{3}\n'.format(epoch, test_acc, train_acc, zy_acc)) 177 | 178 | # Saving the network 179 | if int(epoch * 100) % 100 == 0: 180 | print('Saving the model') 181 | saver.save(model.sess, 'checkpoints/model.ckpt', iteration_n) 182 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from tensorflow.examples.tutorials.mnist import input_data 2 | import os 3 | import scipy.io as sio 4 | import numpy as np 5 | from random import sample 6 | from random import uniform 7 | 8 | def one_hot_to_int(one_hot): 9 | # the digits are encoded as 0 = idx0, 9 == idx9, all good and straightforward 10 | for i in range(len(one_hot)): 11 | if one_hot[i] > 0.1: # greater than zero, really. 
The >0.1 is float-paranoia 12 | return i 13 | 14 | print(one_hot) 15 | assert(False) 16 | return -1 17 | 18 | 19 | def extract_few_mnist_labels(cnt_per_class, n_classes, batch_size, mnist): 20 | # NB: there's a risk of repeated samples because of the way this is implemented 21 | print('Creating a subset of labelled data for the semi-supervised learning.') 22 | def one_hot_to_int(one_hot): 23 | # the digits are encoded as 0 = idx0, 9 == idx9, all good and straightforward 24 | for i in range(len(one_hot)): 25 | if one_hot[i] > 0.1: # greater than zero, really. The >0.1 is float-paranoia 26 | return i 27 | 28 | assert(False) 29 | return -1 30 | 31 | label_cnts = np.zeros((n_classes)) 32 | labelled_cnt = cnt_per_class * n_classes 33 | few_annotated_imgs = np.zeros((labelled_cnt, 28*28)) 34 | few_annotated_ys = np.zeros((labelled_cnt, n_classes)) 35 | idx = 0 36 | 37 | test_cnt = 0 38 | # For training the model on different *random* annotated subsets 39 | # Not optimal, but better than using the same labels and it's been easy to implement 40 | skip_prob = 0.999 41 | 42 | am_i_done = False 43 | while not am_i_done: 44 | test_cnt += batch_size 45 | images, y_ = mnist.train.next_batch(batch_size) 46 | for i in range(y_.shape[0]): 47 | 48 | label_int = one_hot_to_int(y_[i]) 49 | 50 | if(label_cnts[label_int] < cnt_per_class): 51 | # Found a datum which is relevant 52 | if uniform(0,1) < skip_prob: 53 | continue 54 | 55 | few_annotated_imgs[idx] = images[i] 56 | few_annotated_ys[idx] = y_[i] 57 | idx += 1 58 | label_cnts[label_int] += 1 59 | 60 | if np.sum(label_cnts) == labelled_cnt: 61 | am_i_done = True 62 | 63 | print('No. of data points evaluated when creating the labelled subset: {0}'.format(test_cnt)) 64 | 65 | return few_annotated_imgs, few_annotated_ys 66 | 67 | 68 | 69 | class Data: 70 | def __load_svhn(self, n_classes, cnt_per_class): 71 | res = 32 72 | 73 | test_data = sio.loadmat('../SVHN/test_32x32.mat') 74 | test_data['X'] = np.transpose(test_data['X'], axes=[3,0,1,2]) 75 | 76 | train_data = sio.loadmat('../SVHN/train_32x32.mat') 77 | train_data['X'] = np.transpose(train_data['X'], axes=[3,0,1,2]) 78 | 79 | extra_data = sio.loadmat('../SVHN/extra_32x32.mat') 80 | extra_data['X'] = np.transpose(extra_data['X'], axes=[3,0,1,2]) 81 | 82 | # Zero is originally labelled as 10, convert that to 0 83 | test_data['y'] = test_data['y'] % 10 84 | train_data['y'] = train_data['y'] % 10 85 | extra_data['y'] = extra_data['y'] % 10 86 | 87 | 88 | self.X_train = np.concatenate([train_data['X'], extra_data['X']], axis=0) 89 | self.X_train = self.X_train.reshape(self.X_train.shape[0], -1) 90 | 91 | self.y_train = np.concatenate([train_data['y'], extra_data['y']], axis=0) 92 | self.y_train = self.y_train.reshape(self.y_train.shape[0]) 93 | 94 | few_annotated_imgs, few_annotated_ys = self.__extract_few_labels_svhn(self.X_train, self.y_train, res, cnt_per_class, n_classes) 95 | self.X_train_few = few_annotated_imgs 96 | self.y_train_few = few_annotated_ys 97 | 98 | self.X_test = test_data['X'] 99 | self.X_test = self.X_test.reshape(self.X_test.shape[0], -1) 100 | 101 | self.y_test = test_data['y'] 102 | self.y_test = self.y_test.reshape(self.y_test.shape[0]) 103 | 104 | def __load_mnist(self, n_classes, cnt_per_class, working_directory, batch_size): 105 | data_directory = os.path.join(working_directory, "MNIST") 106 | if not os.path.exists(data_directory): 107 | os.makedirs(data_directory) 108 | 109 | self.mnist = input_data.read_data_sets(data_directory, one_hot=True) 110 | 111 | # labelled_cnt = 
cnt_per_class * n_classes # How many labelled data points will we use for the semi-supervised learning process? 112 | self.X_train_few, self.y_train_few = extract_few_mnist_labels(cnt_per_class, n_classes, batch_size, self.mnist) 113 | 114 | 115 | def __init__(self, dataset='MNIST', n_classes=10, cnt_per_class=10, working_directory=None, batch_size=-1): 116 | self.dataset = dataset 117 | if dataset == 'SVHN': 118 | self.__load_svhn(n_classes, cnt_per_class) 119 | elif dataset == 'MNIST': 120 | self.__load_mnist(n_classes, cnt_per_class, working_directory, batch_size) 121 | 122 | def __get_random_minibatch_svhn(self, batch_size, n_classes, purpose): 123 | if purpose == 'train_few': 124 | Xs = self.X_train_few 125 | ys = self.y_train_few 126 | elif purpose == 'train': 127 | Xs = self.X_train 128 | ys = self.y_train 129 | elif purpose == 'test': 130 | Xs = self.X_test 131 | ys = self.y_test 132 | else: 133 | assert(False) 134 | 135 | rnd_idxs = sample(range(ys.shape[0]), batch_size) 136 | y_batch = ys[rnd_idxs] 137 | X_batch = Xs[rnd_idxs, :] 138 | 139 | X_batch = X_batch.astype('float32') 140 | X_batch = X_batch / np.max(X_batch) 141 | 142 | y_batch = y_batch.astype('int32') 143 | 144 | # One-hot encoding of ys? 145 | b = np.zeros((batch_size, n_classes)) 146 | b[np.arange(batch_size), y_batch] = 1 147 | y_batch = b 148 | 149 | return X_batch, y_batch 150 | 151 | def __get_random_minibatch_mnist(self, batch_size, n_classes, purpose): 152 | if purpose == 'train_few': 153 | if batch_size <= self.y_train_few.shape[0]: 154 | rnd_idxs = sample(range(self.y_train_few.shape[0]), batch_size) 155 | else: 156 | assert(batch_size % self.y_train_few.shape[0] == 0) 157 | rnd_idxs = range(self.y_train_few.shape[0]) 158 | rnd_idxs = sample(rnd_idxs, len(rnd_idxs)) 159 | # Not random, but it can't be really, so just shuffle it 160 | while len(rnd_idxs) < batch_size: 161 | rnd_idxs.extend(range(self.y_train_few.shape[0])) 162 | 163 | y_batch = self.y_train_few[rnd_idxs, :] 164 | X_batch = self.X_train_few[rnd_idxs, :] 165 | elif purpose == 'train': 166 | X_batch, y_batch = self.mnist.train.next_batch(batch_size) 167 | elif purpose == 'test': 168 | X_batch, y_batch = self.mnist.test.next_batch(batch_size) 169 | else: 170 | assert(False) 171 | 172 | return X_batch, y_batch 173 | 174 | def get_first_x_mnist(self, batch_size, n_classes): 175 | # Used for getting the style of a digit, and then reproducing it along with the other 9 digits. 
176 | 177 | # TODO: implement for SVHN 178 | assert(self.dataset == 'MNIST') 179 | assert(batch_size % n_classes == 0) 180 | n = int(batch_size / n_classes) 181 | 182 | X = self.mnist.test.images[:n] 183 | y = self.mnist.test.labels[:n] 184 | 185 | return X, y 186 | 187 | def get_random_minibatch(self, batch_size, n_classes, purpose='train_few'): 188 | assert(purpose in ['train_few', 'train', 'test']) 189 | 190 | if self.dataset == 'SVHN': 191 | return self.__get_random_minibatch_svhn(batch_size, n_classes, purpose) 192 | elif self.dataset == 'MNIST': 193 | return self.__get_random_minibatch_mnist(batch_size, n_classes, purpose) 194 | 195 | def __extract_few_labels_svhn(self, X_train, y_train, res, cnt_per_class, n_classes): 196 | print('Creating a subset of labelled data for the semi-supervised learning.') 197 | data_cnt = X_train.shape[0] 198 | 199 | 200 | label_cnts = np.zeros((n_classes)) 201 | labelled_cnt = cnt_per_class * n_classes 202 | few_annotated_imgs = np.zeros((labelled_cnt, res*res*3)) 203 | few_annotated_ys = np.zeros((labelled_cnt)) 204 | idx = 0 205 | 206 | test_cnt = 0 207 | sample_size = 10 208 | 209 | # For training the model on different *random* annotated subsets 210 | # Not optimal, but better than using the same labels and it's been easy to implement 211 | 212 | am_i_done = False 213 | while not am_i_done: 214 | test_cnt += sample_size 215 | sample_idxs = sample(range(data_cnt), sample_size) 216 | 217 | images = X_train[sample_idxs,:] 218 | y_ = y_train[sample_idxs] 219 | 220 | # images, y_ = mnist.train.next_batch(batch_size) 221 | for i in range(y_.shape[0]): 222 | label_int = y_[i] 223 | 224 | if(label_cnts[label_int] < cnt_per_class): 225 | # NB: there's a risk of repeated samples because of the way this is implemented 226 | 227 | few_annotated_imgs[idx] = images[i] 228 | few_annotated_ys[idx] = label_int 229 | idx += 1 230 | label_cnts[label_int] += 1 231 | 232 | if np.sum(label_cnts) == labelled_cnt: 233 | am_i_done = True 234 | 235 | print('No. 
of data points evaluated when creating the labelled subset: {0}'.format(test_cnt)) 236 | 237 | assert((np.sum(few_annotated_ys == 2)) == cnt_per_class) 238 | 239 | return few_annotated_imgs, few_annotated_ys 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /aae_semi.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | 4 | from scipy.misc import imsave 5 | import os 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from tensorflow.contrib import layers 11 | from tensorflow.contrib.framework import arg_scope 12 | 13 | 14 | from components import aa_discriminator, semi_supervised_encoder, semi_supervised_decoder, correlation_classifier 15 | 16 | def random_one_hot(batch_size, n_classes): 17 | rnd_indices = tf.random_uniform([batch_size], minval=0, maxval=n_classes, dtype=tf.int32) 18 | p_y = tf.one_hot(rnd_indices, n_classes, on_value=1.0, off_value=0.0, dtype=tf.float32) 19 | 20 | return p_y 21 | 22 | class AAE_Semi(): 23 | def __init__(self, n_classes, z_dim, batch_size, normalizer_fn, img_res=28, img_channels=1, do_convolutional=True, 24 | decorr_scale=0.5, network_scale=1.0, adversarial_mean=0.0, adversarial_stddev=1.0): 25 | 26 | self.learning_rate = tf.placeholder(tf.float32, shape=[]) 27 | learning_rate = self.learning_rate 28 | 29 | self.input_x = tf.placeholder( 30 | tf.float32, [batch_size, img_res * img_res * img_channels]) 31 | 32 | self.z_tensor = tf.placeholder( 33 | tf.float32, [batch_size, z_dim]) 34 | 35 | self.target_y = tf.placeholder( 36 | tf.float32, [batch_size, n_classes]) 37 | 38 | self.dummy_p_yz = tf.placeholder( 39 | tf.float32, [batch_size, n_classes + z_dim]) 40 | 41 | 42 | with arg_scope([layers.conv2d, layers.conv2d_transpose, layers.fully_connected], 43 | activation_fn=tf.nn.relu, 44 | normalizer_fn=normalizer_fn, 45 | normalizer_params={'scale': True}): 46 | 47 | with tf.variable_scope("encoder") as scope: 48 | noise = tf.random_normal((batch_size, img_res*img_res*img_channels), mean=0.0, stddev=0.3, dtype=tf.float32) 49 | input_img = tf.add(self.input_x, noise) 50 | 51 | self.unnormalized_q_y_given_x, q_z_given_x = semi_supervised_encoder(input_img, z_dim, n_classes, batch_size, do_convolutional, network_scale, img_res, img_channels) 52 | 53 | self.q_z_given_x = q_z_given_x 54 | self.q_y_given_x = tf.nn.softmax(self.unnormalized_q_y_given_x) 55 | if tf.__version__ == '0.10.0': 56 | self.q_yz_given_x = tf.concat(1, [self.q_y_given_x, self.q_z_given_x]) 57 | else: 58 | self.q_yz_given_x = tf.concat([self.q_y_given_x, self.q_z_given_x],1 ) 59 | 60 | encoder_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder') 61 | 62 | if normalizer_fn == None: 63 | G_shared_params = encoder_params[:-4] 64 | G_y_params = G_shared_params + encoder_params[-4:-2] 65 | G_z_params = G_shared_params + encoder_params[-2:] 66 | else: 67 | G_shared_params = encoder_params[:-6] 68 | G_y_params = G_shared_params + encoder_params[-6:-3] 69 | G_z_params = G_shared_params + encoder_params[-3:] 70 | 71 | with tf.variable_scope("decoder") as scope: 72 | output_x = semi_supervised_decoder(self.q_yz_given_x, batch_size, n_classes+z_dim, do_convolutional, network_scale, img_res, img_channels) 73 | decoder_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder') 74 | 75 | 76 | with tf.variable_scope("z_discriminator") as scope: 77 | # Should predict "false" 78 | self.z_D2 = 
aa_discriminator(self.q_z_given_x, batch_size, z_dim) 79 | 80 | with tf.variable_scope("z_discriminator", reuse=True) as scope: 81 | # Not the full density function, just a random sample 82 | self.p_z = tf.random_normal([batch_size, z_dim], mean=adversarial_mean, stddev=adversarial_stddev) 83 | # The output of the discriminator for p(z) 84 | self.z_D1 = aa_discriminator(self.p_z, batch_size, z_dim) 85 | 86 | D_z_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='z_discriminator') 87 | 88 | with tf.variable_scope("y_discriminator") as scope: 89 | # Not the full categorical density function, just a random sample 90 | # The output of the discriminator for p(y) (assuming all classes are equally likely) 91 | p_y = random_one_hot(batch_size, n_classes) 92 | self.y_D1 = aa_discriminator(p_y, batch_size, n_classes) 93 | 94 | D_y_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='y_discriminator') 95 | 96 | with tf.variable_scope("y_discriminator", reuse=True) as scope: 97 | self.y_D2 = aa_discriminator(self.q_y_given_x, batch_size, n_classes) 98 | 99 | with tf.variable_scope("decoder", reuse=True) as scope: 100 | # For generating digits with the same style 101 | self.sampled_style_digits = semi_supervised_decoder(self.dummy_p_yz, batch_size, n_classes+z_dim, do_convolutional, network_scale, img_res, img_channels) 102 | 103 | with tf.variable_scope("decoder", reuse=True) as scope: 104 | # For generating images from the original images 105 | self.x_given_yz_given_x = semi_supervised_decoder(self.q_yz_given_x, batch_size, n_classes+z_dim, do_convolutional, network_scale, img_res, img_channels) 106 | 107 | with tf.variable_scope("correlation_classifier") as scope: 108 | # testing if y and z are correlated. If they are, try to make this not happen, as this should mean that information other than "style", but also class, is floating through z. 
109 | self.q_y_given_z = correlation_classifier(self.q_z_given_x, batch_size, n_classes=n_classes) 110 | 111 | corr_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='correlation_classifier') 112 | 113 | reconstruction_loss = self.__reconstruction_loss(output_x, self.input_x) 114 | D_z_loss = self.__discriminator_loss(self.z_D1, self.z_D2, batch_size) 115 | D_y_loss = self.__discriminator_loss(self.y_D1, self.y_D2, batch_size) 116 | G_z_loss = self.__generator_loss(self.z_D2, batch_size) 117 | G_y_loss = self.__generator_loss(self.y_D2, batch_size) 118 | classification_loss = self.__classification_loss(self.target_y, self.unnormalized_q_y_given_x, batch_size, n_classes) 119 | correlation_classification_loss = self.__classification_loss(self.q_y_given_x, self.q_y_given_z, batch_size, n_classes) 120 | 121 | global_step = tf.contrib.framework.get_or_create_global_step() 122 | 123 | optimizer = 'Adam' 124 | 125 | ae_params = encoder_params + decoder_params 126 | self.train_reconstruction = layers.optimize_loss( 127 | reconstruction_loss, global_step, learning_rate, optimizer=optimizer, variables=ae_params, update_ops=[]) 128 | 129 | self.train_z_generator = layers.optimize_loss( 130 | G_z_loss, global_step, learning_rate, optimizer=optimizer, variables=G_z_params, update_ops=[]) 131 | self.train_y_generator = layers.optimize_loss( 132 | G_y_loss, global_step, learning_rate, optimizer=optimizer, variables=G_y_params, update_ops=[]) 133 | 134 | self.train_z_discrimator = layers.optimize_loss( 135 | D_z_loss, global_step, learning_rate, optimizer=optimizer, variables=D_z_params, update_ops=[]) 136 | self.train_y_discrimator = layers.optimize_loss( 137 | D_y_loss, global_step, learning_rate, optimizer=optimizer, variables=D_y_params, update_ops=[]) 138 | 139 | self.train_y_classifier = layers.optimize_loss( 140 | classification_loss, global_step, learning_rate, optimizer=optimizer, variables=G_y_params, update_ops=[]) 141 | 142 | self.train_correlation_classifier = layers.optimize_loss( 143 | correlation_classification_loss, global_step, learning_rate, optimizer=optimizer, variables=corr_params, update_ops=[]) 144 | 145 | self.train_decorrelation = layers.optimize_loss( 146 | -correlation_classification_loss*decorr_scale, global_step, learning_rate*decorr_scale, optimizer=optimizer, variables=encoder_params, update_ops=[]) 147 | 148 | self.sess = tf.Session() 149 | self.sess.run(tf.initialize_all_variables()) 150 | 151 | def reconstruction_phase(self, input_x, learning_rate): 152 | return self.sess.run(self.train_reconstruction, {self.input_x: input_x, self.learning_rate: learning_rate}) 153 | 154 | def discriminator_phase(self, input_x, learning_rate): 155 | return self.sess.run(self.train_z_discrimator, {self.input_x: input_x, self.learning_rate: learning_rate}) + \ 156 | self.sess.run(self.train_y_discrimator, {self.input_x: input_x, self.learning_rate: learning_rate}) 157 | 158 | def generator_phase(self, input_x, learning_rate): 159 | return self.sess.run(self.train_z_generator, {self.input_x: input_x, self.learning_rate: learning_rate}) +\ 160 | self.sess.run(self.train_y_generator, {self.input_x: input_x, self.learning_rate: learning_rate}) 161 | 162 | def supervised_phase(self, input_x, target_y, learning_rate): 163 | return self.sess.run(self.train_y_classifier, {self.input_x: input_x, self.target_y: target_y, self.learning_rate: learning_rate}) 164 | 165 | def correlation_classifier_phase(self, input_x, learning_rate): 166 | return 
self.sess.run(self.train_correlation_classifier, {self.input_x: input_x, self.learning_rate: learning_rate}) 167 | 168 | def decorrelation_phase(self, input_x, learning_rate): 169 | return self.sess.run(self.train_decorrelation, {self.input_x: input_x, self.learning_rate: learning_rate}) 170 | 171 | 172 | # The various losses 173 | def __classification_loss(self, target_y, pred_y, batch_size, n_classes): 174 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target_y, logits=pred_y)) 175 | 176 | def __discriminator_loss(self, D1, D2, batch_size): 177 | return (tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.ones([batch_size],dtype=tf.int32), logits=D1))) + 178 | tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.zeros([batch_size],dtype=tf.int32), logits=D2)))) 179 | 180 | def __generator_loss(self, D2, batch_size): 181 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.ones([batch_size],dtype=tf.int32), logits=D2)) 182 | 183 | def __reconstruction_loss(self, output_tensor, target_tensor): 184 | return tf.reduce_mean(tf.square(target_tensor - output_tensor)) 185 | 186 | # Just for sanity-checking 187 | def test_print_q_z_given_x(self, input_x): 188 | # q(z|X) should have mean 0.0 and std. dev. 1.0 189 | q_z = self.sess.run(self.q_z_given_x, {self.input_x: input_x}) 190 | print('q_z:') 191 | print('mean(q_z):\n{0}'.format(np.mean(q_z))) 192 | print('stddev(q_z):\n{0}'.format(np.std(q_z))) 193 | print('Note-to-self: It seems that it\'s good to minimize the GAN learning rate as long as stddev(q_z) is close to 1.0') 194 | 195 | # for 10 classes q(y|X) should have mean 0.1 and std. dev. 0.3 196 | q_y, unnormalized_q_y = self.sess.run([self.q_y_given_x, self.unnormalized_q_y_given_x], {self.input_x: input_x}) 197 | print('mean(q_y):\n{0}'.format(np.mean(q_y))) 198 | print('stddev(q_y):\n{0}'.format(np.std(q_y))) 199 | 200 | 201 | def compute_accuracy_2(self, input_x): 202 | # Compute the accuracy of the classifier predicting q(y|X) given q(z|X) 203 | correct_prediction = tf.equal(tf.argmax(self.q_y_given_z,1), tf.argmax(self.q_y_given_x,1)) 204 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 205 | accuracy = self.sess.run(accuracy, {self.input_x: input_x}) 206 | return accuracy 207 | 208 | def compute_accuracy(self, input_x, target_y): 209 | # Compute the accurace of the classifier predicting the labelled examples 210 | correct_prediction = tf.equal(tf.argmax(self.q_y_given_x,1), tf.argmax(self.target_y,1)) 211 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 212 | accuracy = self.sess.run(accuracy, {self.input_x: input_x, self.target_y: target_y}) 213 | return accuracy 214 | 215 | def generate_similar_style(self, X_labelled, y_labelled, batch_size, directory, img_res, img_channels, n_classes, z_dim): 216 | assert(batch_size % n_classes == 0) 217 | 218 | # TODO: implement for SVHN 219 | assert(img_channels == 1) 220 | 221 | n = int(batch_size / n_classes) 222 | X_labelled = np.repeat(X_labelled, n, axis=0) 223 | 224 | q_yz_given_x = self.sess.run(self.q_yz_given_x, {self.input_x: X_labelled}) 225 | q_z_given_x = q_yz_given_x[:,n_classes:] # Only pick the first N samples, and only the style hereof 226 | 227 | indices = [] 228 | for i in range(batch_size): 229 | indices.append(i % n_classes) 230 | indices = np.asarray(indices) 231 | p_y = np.zeros((batch_size, n_classes)) 232 | p_y[np.arange(batch_size), indices] = 1 233 | 234 | p_yz = np.concatenate([p_y, 
q_z_given_x], axis=1) 235 | 236 | imgs = self.sess.run(self.sampled_style_digits, {self.dummy_p_yz: p_yz}) 237 | 238 | 239 | combined_img = np.zeros((n*img_res, (n_classes+1)*img_res)) 240 | # The left column the is the original images 241 | for i in range(n): 242 | combined_img[i*img_res:(i+1)*img_res, 0:img_res] = X_labelled[i*n_classes].reshape(img_res, img_res) 243 | 244 | # The remaining ones are the digits produced from the style captured from the original imgs 245 | for r in range(n): 246 | for c in range(n_classes): 247 | 248 | img = imgs[r*n_classes+c].reshape(img_res, img_res) 249 | combined_img[r*img_res:(1+r)*img_res, (c+1)*img_res:(c+2)*img_res] = img 250 | 251 | imgs_folder = os.path.join(directory, 'imgs') 252 | if not os.path.exists(imgs_folder): 253 | os.makedirs(imgs_folder) 254 | 255 | imsave(os.path.join(imgs_folder, 'captured_digit_style.png'), combined_img) 256 | 257 | 258 | def generate_digits(self, batch_size, directory, img_res, img_channels, n_classes, z_dim): 259 | assert(batch_size % n_classes == 0) 260 | 261 | # TODO: implement for SVHN 262 | assert(img_channels == 1) 263 | 264 | for n in range(3): 265 | indices = [] 266 | for i in range(batch_size): 267 | indices.append(i % n_classes) 268 | indices = np.asarray(indices) 269 | p_y = np.zeros((batch_size, n_classes)) 270 | p_y[np.arange(batch_size), indices] = 1 271 | 272 | 273 | p_z = np.zeros((batch_size, z_dim), dtype='float32') 274 | for i in range(int(batch_size / n_classes)): 275 | rnd_z = np.random.normal(0, 1.0, (1, z_dim)) 276 | 277 | # Use the same random Z for digits 0..9 to see the different digits of the same style. 278 | for j in range(n_classes): 279 | p_z[i*n_classes + j] = rnd_z 280 | 281 | 282 | p_yz = np.concatenate([p_y, p_z], axis=1) 283 | 284 | imgs = self.sess.run(self.sampled_style_digits, {self.dummy_p_yz: p_yz}) 285 | 286 | combined_img = np.zeros((int(batch_size/n_classes)*img_res, n_classes*img_res)) 287 | 288 | for r in range(int(batch_size/n_classes)): 289 | for c in range(n_classes): 290 | 291 | img = imgs[r*n_classes+c].reshape(img_res, img_res) 292 | combined_img[r*img_res:(1+r)*img_res, c*img_res:(c+1)*img_res] = img 293 | 294 | imgs_folder = os.path.join(directory, 'imgs') 295 | if not os.path.exists(imgs_folder): 296 | os.makedirs(imgs_folder) 297 | 298 | imsave(os.path.join(imgs_folder, 'digit_style_{0}.png'.format(n)), combined_img) 299 | 300 | def interpolate_digits(self, batch_size, directory, img_res, img_channels, n_classes, z_dim): 301 | assert(batch_size % n_classes == 0) 302 | 303 | # TODO: implement for SVHN 304 | assert(img_channels == 1) 305 | 306 | for n in range(3): 307 | indices = [] 308 | for i in range(batch_size): 309 | indices.append(i % n_classes) 310 | indices = np.asarray(indices) 311 | p_y = np.zeros((batch_size, n_classes)) 312 | p_y[np.arange(batch_size), indices] = 1 313 | 314 | p_z_1 = np.random.normal(0, 1.0, (1, z_dim)) 315 | p_z_2 = np.random.normal(0, 1.0, (1, z_dim)) 316 | p_z = np.zeros((batch_size, z_dim), dtype='float32') 317 | N = int(batch_size / n_classes) 318 | for i in range(N): 319 | interpol = i / float(N-1) 320 | rnd_z = p_z_1 * interpol + p_z_2 * (1-interpol) 321 | 322 | # Use the same random Z for digits 0..9 to see the different digits of the same style. 
323 | for j in range(n_classes): 324 | p_z[i*n_classes + j] = rnd_z 325 | 326 | 327 | p_yz = np.concatenate([p_y, p_z], axis=1) 328 | 329 | imgs = self.sess.run(self.sampled_style_digits, {self.dummy_p_yz: p_yz}) 330 | 331 | combined_img = np.zeros((int(batch_size/n_classes)*img_res, n_classes*img_res)) 332 | 333 | for r in range(int(batch_size/n_classes)): 334 | for c in range(n_classes): 335 | 336 | img = imgs[r*n_classes+c].reshape(img_res, img_res) 337 | combined_img[r*img_res:(1+r)*img_res, c*img_res:(c+1)*img_res] = img 338 | 339 | imgs_folder = os.path.join(directory, 'imgs') 340 | if not os.path.exists(imgs_folder): 341 | os.makedirs(imgs_folder) 342 | 343 | imsave(os.path.join(imgs_folder, 'digit_style_interpolation_{0}.png'.format(n)), combined_img) 344 | --------------------------------------------------------------------------------
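As a footnote to `interpolate_digits()` above: the decoder input grid it builds is simply a linear blend between two random style vectors, with each blend paired with all ten one-hot class vectors. The following standalone NumPy sketch reproduces that grid construction (the decoder call itself is omitted; shapes assume the defaults `z_dim = 9` and `n_classes = 10`, and the variable names are illustrative only):

```python
import numpy as np

n_classes, z_dim, n_rows = 10, 9, 10
z_a = np.random.normal(0.0, 1.0, (z_dim,))   # first random style
z_b = np.random.normal(0.0, 1.0, (z_dim,))   # second random style

rows = []
for i in range(n_rows):
    t = i / float(n_rows - 1)                # interpolation weight, 0.0 -> 1.0
    z = (1.0 - t) * z_a + t * z_b            # one blended style per row
    y = np.eye(n_classes)                    # digits 0..9 within that row
    rows.append(np.concatenate([y, np.tile(z, (n_classes, 1))], axis=1))

p_yz = np.concatenate(rows, axis=0)          # (100, n_classes + z_dim) decoder input
print(p_yz.shape)
```

In the repository this grid is fed to the decoder via the `sampled_style_digits` op, producing one row of images per interpolation step.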