├── .gitignore
├── README.md
├── autoencoder.py
├── images
│   ├── ae_sample.jpg
│   ├── conditional.gif
│   ├── conditioned_samples.png
│   ├── gated_cnn.png
│   ├── loss.png
│   └── sample.jpg
├── layers.py
├── main.py
├── models.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
ckpts
samples
data
.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Image Generation with Gated PixelCNN Decoders

This is a TensorFlow implementation of [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328), which introduces the Gated PixelCNN model built on the PixelCNN architecture originally proposed in [Pixel Recurrent Neural Networks](https://arxiv.org/abs/1601.06759). The model can be conditioned on a latent representation of labels or images to generate images accordingly; images can also be modelled unconditionally. The Gated PixelCNN can also act as a powerful decoder, replacing deconvolution (transposed convolution) in autoencoders and GANs. A detailed summary of the paper can be found [here](https://gist.github.com/anantzoid/b2dca657003998027c2861f3121c43b7).

These are some conditioned samples generated by the authors of the paper:

![Paper Sample](images/conditioned_samples.png)

## Architecture

This is the architecture of the Gated PixelCNN used in the model:

![Gated PCNN](images/gated_cnn.png)

The gating lets the network retain context and model more complex interactions, much like an LSTM. The stack on the left is the vertical stack, which removes the blind spot that masked convolutions would otherwise introduce (refer to the Pixel RNN paper for details on masking). Residual connections significantly improve model performance.
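For reference, each layer computes the gated activation unit from the paper, `y = tanh(W_f ∗ x) ⊙ σ(W_g ∗ x)` (plus conditional bias terms when a label embedding is supplied). A minimal NumPy sketch of just the gate, assuming the two masked convolutions have already produced feature maps `f` and `g` (an illustration, not code from this repo):

```
import numpy as np

def gated_activation(f, g):
    # y = tanh(f) * sigmoid(g), applied element-wise, as in layers.py
    return np.tanh(f) * (1.0 / (1.0 + np.exp(-g)))

rng = np.random.default_rng(0)
f = rng.normal(size=(1, 28, 28, 32))  # output of the tanh-branch convolution
g = rng.normal(size=(1, 28, 28, 32))  # output of the sigmoid-branch convolution
y = gated_activation(f, g)            # same shape, entries in (-1, 1)
```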
## Usage

This implementation consists of the following models based on the Gated PixelCNN architecture:

- **Unconditional image generation**:
```
python main.py
```

Sample generated after training on the MNIST dataset for 70 epochs, with a cross-entropy loss of 0.104610:

![Unconditional image](images/sample.jpg)

- **Conditional image generation based on class labels**:
```
python main.py --model=conditional
```

As mentioned in the paper, conditionally generated images are more visually appealing even though the loss is nearly the same: 0.102719 after 40 epochs:

![Conditional image](images/conditional.gif)

- **Autoencoder with PixelCNN decoder**:
```
python main.py --model=autoencoder
```

The encoder follows the architecture described in [Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction](https://pdfs.semanticscholar.org/1c6d/990c80e60aa0b0059415444cdf94b3574f0f.pdf). The representation is encoded into a 10-dimensional vector. Sample generated after 10 epochs, with a loss of 0.115306:

![AE image](images/ae_sample.jpg)

To generate images without training, append the `--epochs=0` flag to the command.

To train any model on the CIFAR-10 dataset, add the `--data=cifar` flag.

See `main.py` for the other flags available for hyperparameter tuning.

## Training Details

The system was trained on a single AWS p2.xlarge spot instance. Training was only carried out on the MNIST dataset; for comparison, generating samples conditioned on CIFAR-10 images took the paper's authors 32 GPUs running for 60 hours.

To visualize the graph and loss during training, run:
```
tensorboard --logdir=logs
```

Loss minimization for the autoencoder model:

![Loss](images/loss.png)
--------------------------------------------------------------------------------
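The cross-entropy numbers quoted above are the per-pixel sigmoid cross-entropy computed in `models.py`, averaged over every pixel in the batch. A toy NumPy version of that objective on random data (an illustration of the formula only, not repo code):

```
import numpy as np

def pixel_nll(logits, targets):
    # Numerically stable mean of -[t*log(p) + (1-t)*log(1-p)], p = sigmoid(logits),
    # matching tf.nn.sigmoid_cross_entropy_with_logits averaged over all pixels.
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 28, 28, 1))               # stand-in for model logits
targets = (rng.uniform(size=logits.shape) < 0.5) * 1.0   # binarized pixels
print(pixel_nll(logits, targets))
```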
/autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os

from utils import *
from models import *

def trainAE(conf, data):
    encoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
    decoder_X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])

    encoder = ConvolutionalEncoder(encoder_X, conf)
    # The 10-d encoding conditions the PixelCNN decoder; it must be passed as
    # the keyword argument h, not positionally (the third positional parameter
    # of PixelCNN is full_horizontal).
    decoder = PixelCNN(decoder_X, conf, h=encoder.pred)
    y = decoder.pred
    tf.summary.scalar('loss', decoder.loss)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(decoder.loss)
    # Clip each gradient element-wise to [-grad_clip, grad_clip]
    clipped_gradients = [(tf.clip_by_value(grad, -conf.grad_clip, conf.grad_clip), var) for grad, var in gradients]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter(conf.summary_path, sess.graph)

        sess.run(tf.global_variables_initializer())

        if os.path.exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")

        # TODO The training part below and in main.py could be generalized
        if conf.epochs > 0:
            print("Started Model Training...")
            pointer = 0
            step = 0
            for i in range(conf.epochs):
                for j in range(conf.num_batches):
                    if conf.data == 'mnist':
                        batch_X = binarize(data.train.next_batch(conf.batch_size)[0].reshape(conf.batch_size, conf.img_height, conf.img_width, conf.channel))
                    else:
                        batch_X, pointer = get_batch(data, pointer, conf.batch_size)

                    _, l, summary = sess.run([optimizer, decoder.loss, merged], feed_dict={encoder_X: batch_X, decoder_X: batch_X})
                    writer.add_summary(summary, step)
                    step += 1

                print("Epoch: %d, Cost: %f" % (i, l))
                if (i + 1) % 10 == 0:
                    saver.save(sess, conf.ckpt_file)
                    generate_ae(sess, encoder_X, decoder_X, y, data, conf, str(i))

        writer.close()
        generate_ae(sess, encoder_X, decoder_X, y, data, conf, '')
--------------------------------------------------------------------------------
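`trainAE` (and `train` in `main.py`) reads everything it needs from the `conf` object. A rough sketch of that contract, with hypothetical values mirroring the MNIST defaults; in the real pipeline `main.py` builds this object from the argparse flags plus `utils.makepaths`:

```
from argparse import Namespace

# Hypothetical config for illustration only; the real paths come from makepaths().
conf = Namespace(
    data='mnist', epochs=50, batch_size=100, num_batches=550,
    img_height=28, img_width=28, channel=1, num_classes=10,
    layers=12, f_map=32, grad_clip=1, conditional=True,
    ckpt_file='ckpts/model.ckpt', summary_path='logs', samples_path='samples',
)
# trainAE(conf, data)  # data: the MNIST Datasets object assembled in main.py
```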
/images/ae_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/ae_sample.jpg
--------------------------------------------------------------------------------
/images/conditional.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/conditional.gif
--------------------------------------------------------------------------------
/images/conditioned_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/conditioned_samples.png
--------------------------------------------------------------------------------
/images/gated_cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/gated_cnn.png
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/loss.png
--------------------------------------------------------------------------------
/images/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anantzoid/Conditional-PixelCNN-decoder/ecb0c814604c7503fc9cf2ff704e86086ffb6bd9/images/sample.jpg
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

def get_weights(shape, name, horizontal, mask_mode='noblind', mask=None):
    weights_initializer = tf.contrib.layers.xavier_initializer()
    W = tf.get_variable(name, shape, tf.float32, weights_initializer)

    '''
    Use of masking to hide the current and subsequent pixel values
    '''
    if mask:
        filter_mid_y = shape[0]//2
        filter_mid_x = shape[1]//2
        mask_filter = np.ones(shape, dtype=np.float32)
        if mask_mode == 'noblind':
            if horizontal:
                # All rows below the center must be zero
                mask_filter[filter_mid_y+1:, :, :, :] = 0.0
                # All columns after the center in the center row must be zero
                mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.0
            else:
                if mask == 'a':
                    # In the first layer, can ONLY access pixels above the current one
                    mask_filter[filter_mid_y:, :, :, :] = 0.0
                else:
                    # In later layers, can access pixels above or even with the current one.
                    # The reason is that the pixels to the right or left of the current pixel
                    # only have a receptive field of the layer above the current layer and up.
                    mask_filter[filter_mid_y+1:, :, :, :] = 0.0

            if mask == 'a':
                # The center must also be zero in the first layer
                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0
        else:
            mask_filter[filter_mid_y, filter_mid_x+1:, :, :] = 0.
            mask_filter[filter_mid_y+1:, :, :, :] = 0.

            if mask == 'a':
                mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.

        W *= mask_filter
    return W

def get_bias(shape, name):
    # tf.zeros_initializer must be instantiated in TF 1.x
    return tf.get_variable(name, shape, tf.float32, tf.zeros_initializer())

def conv_op(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

class GatedCNN():
    def __init__(self, W_shape, fan_in, horizontal, gated=True, payload=None, mask=None, activation=True, conditional=None, conditional_image=None):
        self.fan_in = fan_in
        in_dim = self.fan_in.get_shape()[-1]
        self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
        self.b_shape = W_shape[2]

        self.in_dim = in_dim
        self.payload = payload
        self.mask = mask
        self.activation = activation
        self.conditional = conditional
        self.conditional_image = conditional_image
        self.horizontal = horizontal

        if gated:
            self.gated_conv()
        else:
            self.simple_conv()

    def gated_conv(self):
        W_f = get_weights(self.W_shape, "v_W", self.horizontal, mask=self.mask)
        W_g = get_weights(self.W_shape, "h_W", self.horizontal, mask=self.mask)

        b_f_total = get_bias(self.b_shape, "v_b")
        b_g_total = get_bias(self.b_shape, "h_b")
        if self.conditional is not None:
            # Class-conditional biases: h is projected to the gate width and
            # broadcast over the spatial dimensions
            h_shape = int(self.conditional.get_shape()[1])
            V_f = get_weights([h_shape, self.W_shape[3]], "v_V", self.horizontal)
            b_f = tf.matmul(self.conditional, V_f)
            V_g = get_weights([h_shape, self.W_shape[3]], "h_V", self.horizontal)
            b_g = tf.matmul(self.conditional, V_g)

            b_f_shape = tf.shape(b_f)
            b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1]))
            b_g_shape = tf.shape(b_g)
            b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1]))

            b_f_total = b_f_total + b_f
            b_g_total = b_g_total + b_g
        if self.conditional_image is not None:
            # 1x1-convolve the conditioning image to the gate width (W_shape[3],
            # the number of output feature maps) so the addition broadcasts
            b_f_total = b_f_total + tf.layers.conv2d(self.conditional_image, self.W_shape[3], 1, use_bias=False, name="ci_f")
            b_g_total = b_g_total + tf.layers.conv2d(self.conditional_image, self.W_shape[3], 1, use_bias=False, name="ci_g")

        conv_f = conv_op(self.fan_in, W_f)
        conv_g = conv_op(self.fan_in, W_g)

        if self.payload is not None:
            # Contribution from the vertical stack feeding into the horizontal stack
            conv_f += self.payload
            conv_g += self.payload

        self.fan_out = tf.multiply(tf.tanh(conv_f + b_f_total), tf.sigmoid(conv_g + b_g_total))

    def simple_conv(self):
        W = get_weights(self.W_shape, "W", self.horizontal, mask_mode="standard", mask=self.mask)
        b = get_bias(self.b_shape, "b")
        conv = conv_op(self.fan_in, W)
        if self.activation:
            self.fan_out = tf.nn.relu(tf.add(conv, b))
        else:
            self.fan_out = tf.add(conv, b)

    def output(self):
        return self.fan_out
--------------------------------------------------------------------------------
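The spatial patterns that `get_weights` produces are easy to inspect. A standalone NumPy sketch (not repo code) of the 'noblind' horizontal-stack mask for a 3x3 filter; mask `'a'` and `'b'` differ only in whether the current pixel itself is visible:

```
import numpy as np

def horizontal_mask(k, mask_type):
    # Mirrors the 'noblind' horizontal branch of get_weights, spatial dims only
    m = np.ones((k, k), dtype=np.float32)
    mid = k // 2
    m[mid + 1:, :] = 0.0       # rows below the center
    m[mid, mid + 1:] = 0.0     # columns right of the center, in the center row
    if mask_type == 'a':
        m[mid, mid] = 0.0      # first layer: the current pixel is hidden too
    return m

print(horizontal_mask(3, 'a'))
# [[1. 1. 1.]
#  [1. 0. 0.]
#  [0. 0. 0.]]
print(horizontal_mask(3, 'b'))
# [[1. 1. 1.]
#  [1. 1. 0.]
#  [0. 0. 0.]]
```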
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import argparse
import os
from models import PixelCNN
from autoencoder import *
from utils import *

def train(conf, data):
    X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
    model = PixelCNN(X, conf)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(model.loss)
    # Clip each gradient element-wise to [-grad_clip, grad_clip]
    clipped_gradients = [(tf.clip_by_value(grad, -conf.grad_clip, conf.grad_clip), var) for grad, var in gradients]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")

        if conf.epochs > 0:
            print("Started Model Training...")
            pointer = 0
            for i in range(conf.epochs):
                for j in range(conf.num_batches):
                    if conf.data == "mnist":
                        batch_X, batch_y = data.train.next_batch(conf.batch_size)
                        batch_X = binarize(batch_X.reshape([conf.batch_size, \
                                conf.img_height, conf.img_width, conf.channel]))
                        batch_y = one_hot(batch_y, conf.num_classes)
                    else:
                        # NOTE: get_batch does not return labels, so conditional
                        # training is only wired up for MNIST
                        batch_X, pointer = get_batch(data, pointer, conf.batch_size)
                    data_dict = {X: batch_X}
                    if conf.conditional is True:
                        data_dict[model.h] = batch_y
                    _, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)

                print("Epoch: %d, Cost: %f" % (i, cost))
                if (i + 1) % 10 == 0:
                    saver.save(sess, conf.ckpt_file)
                    generate_samples(sess, X, model.h, model.pred, conf, "")

        generate_samples(sess, X, model.h, model.pred, conf, "")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--layers', type=int, default=12)
    parser.add_argument('--f_map', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--grad_clip', type=int, default=1)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--data_path', type=str, default='data')
    parser.add_argument('--ckpt_path', type=str, default='ckpts')
    parser.add_argument('--samples_path', type=str, default='samples')
    parser.add_argument('--summary_path', type=str, default='logs')
    conf = parser.parse_args()

    if conf.data == 'mnist':
        from tensorflow.examples.tutorials.mnist import input_data
        if not os.path.exists(conf.data_path):
            os.makedirs(conf.data_path)
        data = input_data.read_data_sets(conf.data_path)
        conf.num_classes = 10
        conf.img_height = 28
        conf.img_width = 28
        conf.channel = 1
        conf.num_batches = data.train.num_examples // conf.batch_size
    else:
        from keras.datasets import cifar10
        data = cifar10.load_data()
        labels = data[0][1]
        data = data[0][0].astype(np.float32)
        # Per-channel mean subtraction; assumes the channels-first (N, 3, 32, 32)
        # layout returned by older Keras versions
        data[:, 0, :, :] -= np.mean(data[:, 0, :, :])
        data[:, 1, :, :] -= np.mean(data[:, 1, :, :])
        data[:, 2, :, :] -= np.mean(data[:, 2, :, :])
        data = np.transpose(data, (0, 2, 3, 1))
        conf.img_height = 32
        conf.img_width = 32
        conf.channel = 3
        conf.num_classes = 10
        conf.num_batches = data.shape[0] // conf.batch_size

    conf = makepaths(conf)
    if conf.model == '':
        conf.conditional = False
        train(conf, data)
    elif conf.model.lower() == 'conditional':
        conf.conditional = True
        train(conf, data)
    elif conf.model.lower() == 'autoencoder':
        conf.conditional = True
        trainAE(conf, data)
--------------------------------------------------------------------------------
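The channel indexing above silently depends on the array layout `cifar10.load_data()` returns, which changed across Keras versions. A layout-agnostic sketch of the same per-channel centering (an alternative for illustration, not the repo's code):

```
import numpy as np

def center_per_channel(data):
    # Accepts either (N, 3, H, W) or (N, H, W, 3) and returns centered NHWC
    if data.shape[1] == 3:
        data = np.transpose(data, (0, 2, 3, 1))   # channels-first -> NHWC
    return data - data.mean(axis=(0, 1, 2), keepdims=True)
```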
/models.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from layers import *

class PixelCNN(object):
    def __init__(self, X, conf, full_horizontal=True, h=None):
        self.X = X
        if conf.data == "mnist":
            self.X_norm = X
        else:
            '''
            Image normalization for CIFAR-10 was supposed to be done here
            '''
            self.X_norm = X
        v_stack_in, h_stack_in = self.X_norm, self.X_norm

        if conf.conditional is True:
            if h is not None:
                self.h = h
            else:
                self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
        else:
            self.h = None

        for i in range(conf.layers):
            filter_size = 3 if i > 0 else 7
            mask = 'b' if i > 0 else 'a'
            residual = True if i > 0 else False
            i = str(i)
            with tf.variable_scope("v_stack"+i):
                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, False, mask=mask, conditional=self.h).output()
                v_stack_in = v_stack

            with tf.variable_scope("v_stack_1"+i):
                v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=None).output()

            with tf.variable_scope("h_stack"+i):
                h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask, conditional=self.h).output()

            with tf.variable_scope("h_stack_1"+i):
                h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=None).output()
                if residual:
                    h_stack_1 += h_stack_in  # Residual connection
                h_stack_in = h_stack_1

        with tf.variable_scope("fc_1"):
            fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, True, gated=False, mask='b').output()

        if conf.data == "mnist":
            with tf.variable_scope("fc_2"):
                self.fc2 = GatedCNN([1, 1, 1], fc1, True, gated=False, mask='b', activation=False).output()
            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fc2, labels=self.X))
            self.pred = tf.nn.sigmoid(self.fc2)
        else:
            color_dim = 256
            with tf.variable_scope("fc_2"):
                self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, True, gated=False, mask='b', activation=False).output()
                self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

            self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.fc2, labels=tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)))

            '''
            Since this code was not run on CIFAR-10, I'm not sure which
            would be a suitable way to generate 3-channel images. Below are
            the two methods that may be used, with self.pred (argmax) being
            the more likely choice.
            '''
            # tf.multinomial expects unnormalized logits, so fc2 is passed directly
            self.pred_sampling = tf.reshape(tf.multinomial(self.fc2, num_samples=1, seed=100), tf.shape(self.X))
            # fc2 is reshaped to (-1, color_dim), so the class axis is axis 1
            self.pred = tf.reshape(tf.argmax(tf.nn.softmax(self.fc2), axis=1), tf.shape(self.X))


class ConvolutionalEncoder(object):
    def __init__(self, X, conf):
        '''
        This is the 6-layer architecture for the convolutional autoencoder
        described in the original paper:
        Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction

        Note that only the encoder part is implemented, as the PixelCNN serves
        as the decoder. The flattening below assumes 28x28 inputs
        (two 2x2 poolings: 28 -> 14 -> 7).
        '''

        # horizontal is unused when no mask is applied, so False is passed here
        W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1", horizontal=False)
        b_conv1 = get_bias([100], "b_conv1")
        conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
        pool1 = max_pool_2x2(conv1)

        W_conv2 = get_weights([5, 5, 100, 150], "W_conv2", horizontal=False)
        b_conv2 = get_bias([150], "b_conv2")
        conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
        pool2 = max_pool_2x2(conv2)

        W_conv3 = get_weights([3, 3, 150, 200], "W_conv3", horizontal=False)
        b_conv3 = get_bias([200], "b_conv3")
        conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)
        conv3_reshape = tf.reshape(conv3, (-1, 7*7*200))

        W_fc = get_weights([7*7*200, 10], "W_fc", horizontal=False)
        b_fc = get_bias([10], "b_fc")
        self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))
--------------------------------------------------------------------------------
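To make the difference between the two prediction ops concrete, here is a toy NumPy illustration (not repo code) for one sub-pixel's 256-way color distribution:

```
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=256)                  # stand-in for one row of fc2
probs = np.exp(logits - logits.max())
probs /= probs.sum()                           # softmax

mode_color = int(np.argmax(probs))             # what self.pred picks (deterministic)
sampled_color = int(rng.choice(256, p=probs))  # what self.pred_sampling draws
```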
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import scipy.misc
from datetime import datetime
import tensorflow as tf

def binarize(images):
    # Sample each pixel as a Bernoulli draw with probability equal to its intensity
    return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def generate_samples(sess, X, h, pred, conf, suff):
    print("Generating Sample Images...")
    n_row, n_col = 10, 10
    samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
    # TODO make it generic
    labels = one_hot(np.array([0,1,2,3,4,5,6,7,8,9]*10), conf.num_classes)

    # Autoregressive sampling: pixels are generated in raster order, one
    # sub-pixel per forward pass, each conditioned on everything drawn so far
    for i in range(conf.img_height):
        for j in range(conf.img_width):
            for k in range(conf.channel):
                data_dict = {X: samples}
                if conf.conditional is True:
                    data_dict[h] = labels
                next_sample = sess.run(pred, feed_dict=data_dict)
                if conf.data == "mnist":
                    next_sample = binarize(next_sample)
                samples[:, i, j, k] = next_sample[:, i, j, k]

    save_images(samples, n_row, n_col, conf, suff)


def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
    print("Generating Sample Images...")
    n_row, n_col = 10, 10
    samples = np.zeros((n_row*n_col, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
    if conf.data == 'mnist':
        labels = binarize(data.train.next_batch(n_row*n_col)[0].reshape(n_row*n_col, conf.img_height, conf.img_width, conf.channel))
    else:
        # get_batch returns [batch, pointer]; only the batch is needed here
        labels, _ = get_batch(data, 0, n_row*n_col)

    for i in range(conf.img_height):
        for j in range(conf.img_width):
            for k in range(conf.channel):
                next_sample = sess.run(y, {encoder_X: labels, decoder_X: samples})
                if conf.data == 'mnist':
                    next_sample = binarize(next_sample)
                samples[:, i, j, k] = next_sample[:, i, j, k]

    save_images(samples, n_row, n_col, conf, suff)


def save_images(samples, n_row, n_col, conf, suff):
    images = samples
    if conf.data == "mnist":
        images = images.reshape((n_row, n_col, conf.img_height, conf.img_width))
        images = images.transpose(1, 2, 0, 3)
        images = images.reshape((conf.img_height * n_row, conf.img_width * n_col))
    else:
        images = images.reshape((n_row, n_col, conf.img_height, conf.img_width, conf.channel))
        images = images.transpose(1, 2, 0, 3, 4)
        images = images.reshape((conf.img_height * n_row, conf.img_width * n_col, conf.channel))

    filename = datetime.now().strftime('%Y_%m_%d_%H_%M') + suff + ".jpg"
    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))


def get_batch(data, pointer, batch_size):
    # Wrap around once the next slice would run past the end of the data
    if batch_size * (pointer + 1) > data.shape[0]:
        pointer = 0
    batch = data[batch_size * pointer : batch_size * (pointer + 1)]
    pointer += 1
    return [batch, pointer]


def one_hot(batch_y, num_classes):
    y_ = np.zeros((batch_y.shape[0], num_classes))
    y_[np.arange(batch_y.shape[0]), batch_y] = 1
    return y_


def makepaths(conf):
    ckpt_full_path = os.path.join(conf.ckpt_path, "data=%s_bs=%d_layers=%d_fmap=%d" % (conf.data, conf.batch_size, conf.layers, conf.f_map))
    if not os.path.exists(ckpt_full_path):
        os.makedirs(ckpt_full_path)
    conf.ckpt_file = os.path.join(ckpt_full_path, "model.ckpt")

    conf.samples_path = os.path.join(conf.samples_path, "epoch=%d_bs=%d_layers=%d_fmap=%d" % (conf.epochs, conf.batch_size, conf.layers, conf.f_map))
    if not os.path.exists(conf.samples_path):
        os.makedirs(conf.samples_path)

    if tf.gfile.Exists(conf.summary_path):
        tf.gfile.DeleteRecursively(conf.summary_path)
    tf.gfile.MakeDirs(conf.summary_path)

    return conf
--------------------------------------------------------------------------------
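A final note on sampling cost: `generate_samples` and `generate_ae` above run one full forward pass per sub-pixel in raster order, so producing a grid of samples takes `img_height * img_width * channel` calls to `sess.run` (the 10x10 grid itself is amortized through the batch dimension):

```
# Forward passes needed to sample one grid of images:
passes_mnist = 28 * 28 * 1   # 784
passes_cifar = 32 * 32 * 3   # 3072
```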