├── tflib
│   ├── objs
│   │   ├── __init__.py
│   │   ├── discrete_variables.py
│   │   ├── kl.py
│   │   ├── gan.py
│   │   ├── mmd.py
│   │   ├── kl_aggregated.py
│   │   └── gan_inference.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── layernorm.py
│   │   ├── cond_batchnorm.py
│   │   ├── combination.py
│   │   ├── minibatch.py
│   │   ├── conv3d.py
│   │   ├── conv1d.py
│   │   ├── deconv2d.py
│   │   ├── conv2d.py
│   │   ├── batchnorm.py
│   │   └── linear.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── distance.py
│   ├── plot.py
│   ├── visualization.py
│   ├── cifar10.py
│   ├── svhn.py
│   ├── mnist.py
│   ├── celebA.py
│   ├── save_images.py
│   ├── chairs.py
│   ├── inception_score.py
│   ├── __init__.py
│   └── simple_moving_mnist.py
├── README.md
├── gan_inference_face.py
├── gmgan_inference_face.py
├── gan_inference_svhn.py
├── gan_inference_mnist.py
├── gan_inference_cifar10.py
├── gmgan_inference_svhn.py
└── gmgan_inference_cifar10.py
/tflib/objs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tflib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tflib/objs/discrete_variables.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tflib as lib 3 | 4 | def score_function(f_k, p_k, c_v): 5 | # estimate the gradients of E_p(k|params) f(k) 6 | # the result is (f(k) - cv) * Grad(log p(k|params)) 7 | # or equivalently Grad(stop_gradient(f(k) - cv) * log p(k|params)) 8 | 9 | return tf.stop_gradient(f_k - c_v) * tf.log(p_k) -------------------------------------------------------------------------------- /tflib/utils/distance.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def l2(x, y): 4 | return tf.reduce_mean(tf.pow(x - y, 2)) 5 | 6 | def l1(x, y): 7 | return tf.reduce_mean(tf.abs(x - y)) 8 | 9 | def distance(x, y, d_type): 10 | xs = tf.shape(x) 11 | x = tf.reshape(x, [-1, xs[-1]]) 12 | ys = tf.shape(y) 13 | y = tf.reshape(y, [-1, ys[-1]]) 14 | if d_type == 'l1': 15 | return l1(x,y) 16 | elif d_type == 'l2': 17 | return l2(x,y) -------------------------------------------------------------------------------- /tflib/ops/layernorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Layernorm(name, norm_axes, inputs): 7 | mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True) 8 | 9 | # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers. 10 | n_neurons = inputs.get_shape().as_list()[norm_axes[0]] 11 | 12 | offset = lib.param(name+'.offset', np.zeros(n_neurons, dtype='float32')) 13 | scale = lib.param(name+'.scale', np.ones(n_neurons, dtype='float32')) 14 | 15 | # Add broadcasting dims to offset and scale (e.g.
BCHW conv data) 16 | offset = tf.reshape(offset, [-1] + [1 for i in xrange(len(norm_axes)-1)]) 17 | scale = tf.reshape(scale, [-1] + [1 for i in xrange(len(norm_axes)-1)]) 18 | 19 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 20 | 21 | return result -------------------------------------------------------------------------------- /tflib/ops/cond_batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True, labels=None, n_labels=None): 7 | """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps""" 8 | if axes != [0,2,3]: 9 | raise Exception('unsupported') 10 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 11 | shape = mean.get_shape().as_list() # shape is [1,n,1,1] 12 | offset_m = lib.param(name+'.offset', np.zeros([n_labels,shape[1]], dtype='float32')) 13 | scale_m = lib.param(name+'.scale', np.ones([n_labels,shape[1]], dtype='float32')) 14 | offset = tf.nn.embedding_lookup(offset_m, labels) 15 | scale = tf.nn.embedding_lookup(scale_m, labels) 16 | result = tf.nn.batch_normalization(inputs, mean, var, offset[:,:,None,None], scale[:,:,None,None], 1e-5) 17 | return result -------------------------------------------------------------------------------- /tflib/ops/combination.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Ladder(inputs, input_dim, name): 7 | with tf.name_scope(name) as scope: 8 | zs = np.zeros(input_dim).astype('float32') 9 | os = np.ones(input_dim).astype('float32') 10 | 11 | a1 = lib.param(name + '.a1', zs) 12 | a2 = lib.param(name + '.a2', os) 13 | a3 = lib.param(name + '.a3', zs) 14 | a4 = lib.param(name + '.a4', zs) 15 | 16 | c1 = lib.param(name + '.c1', zs) 17 | c2 = lib.param(name + '.c2', os) 18 | c3 = lib.param(name + '.c3', zs) 19 | c4 = lib.param(name + '.c4', zs) 20 | 21 | b1 = lib.param(name + '.b1', zs) 22 | 23 | z_lat, u = inputs 24 | 25 | sigval = c1 + c2*z_lat 26 | sigval += c3*u + c4*z_lat*u 27 | sigval = tf.nn.sigmoid(sigval) 28 | z_est = a1 + a2 * z_lat + b1*sigval 29 | z_est += a3*u + a4*z_lat*u 30 | 31 | return z_est -------------------------------------------------------------------------------- /tflib/objs/kl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tflib as lib 3 | import math 4 | 5 | def kl_q_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std): 6 | q_z_var = tf.pow(q_z_std, 2) 7 | p_z_var = tf.pow(p_z_std, 2) 8 | mean_diff = tf.pow(p_z_mean - q_z_mean, 2) 9 | res_mat = .5*(tf.log(p_z_var/q_z_var) + (mean_diff + q_z_var) / p_z_var - 1) 10 | return tf.reduce_mean(tf.reduce_sum(res_mat, axis=1), axis=0) 11 | 12 | def neg_log_likelihood_diagnoal_gaussian(x, mu, std): 13 | res_mat = .5*(tf.pow((x - mu)/std, 2) + tf.log(2*math.pi) + 2*tf.log(std)) 14 | return tf.reduce_mean(tf.reduce_sum(res_mat, axis=1), axis=0) 15 | 16 | def vae(real_x, p_x_mean, p_x_std, q_z_mean, q_z_std, p_z_mean, p_z_std, gen_params, lr=2e-4, beta1=.5): 17 | gen_cost = kl_q_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std) 18 | gen_cost += neg_log_likelihood_diagnoal_gaussian(real_x, p_x_mean, p_x_std) 19 | 20 | gen_train_op = tf.train.AdamOptimizer( 21 | learning_rate=lr, 22 | 
beta1=beta1 23 | ).minimize(gen_cost, var_list=gen_params) 24 | 25 | return gen_cost, gen_train_op -------------------------------------------------------------------------------- /tflib/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import collections 9 | import time 10 | import cPickle as pickle 11 | 12 | _since_beginning = collections.defaultdict(lambda: {}) 13 | _since_last_flush = collections.defaultdict(lambda: {}) 14 | 15 | _iter = [0] 16 | def tick(): 17 | _iter[0] += 1 18 | 19 | def plot(name, value): 20 | _since_last_flush[name][_iter[0]] = value 21 | 22 | def flush(outf, logfile): 23 | prints = [] 24 | 25 | for name, vals in _since_last_flush.items(): 26 | prints.append("{}\t{}".format(name, np.mean(vals.values()))) 27 | _since_beginning[name].update(vals) 28 | 29 | x_vals = np.sort(_since_beginning[name].keys()) 30 | y_vals = [_since_beginning[name][x] for x in x_vals] 31 | 32 | plt.clf() 33 | plt.plot(x_vals, y_vals) 34 | plt.xlabel('iteration') 35 | plt.ylabel(name) 36 | plt.savefig(os.path.join(outf, name.replace(' ', '_')+'.jpg')) 37 | 38 | print "iter {}\t{}".format(_iter[0], "\t".join(prints)) 39 | with open(logfile,'a') as f: 40 | f.write("iter {}\t{}".format(_iter[0], "\t".join(prints)) + "\n") 41 | _since_last_flush.clear() -------------------------------------------------------------------------------- /tflib/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib as mpl 4 | mpl.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | import os 8 | 9 | 10 | def scatter(data, label, dir, file_name, mus=None, mark_size=2): 11 | if label.ndim == 2: 12 | label = np.argmax(label, axis=1) 13 | 14 | df = pd.DataFrame(data={'x':data[:,0], 'y':data[:,1], 'class':label}) 15 | sns_plot = sns.lmplot('x', 'y', data=df, hue='class', fit_reg=False, scatter_kws={'s':mark_size}) 16 | sns_plot.savefig(os.path.join(dir, file_name)) 17 | if mus is not None: 18 | df_mus = pd.DataFrame(data={'x':mus[:,0], 'y':mus[:,1], 'class':np.asarray(xrange(mus.shape[0])).astype(np.int32)}) 19 | sns_plot_mus = sns.lmplot('x', 'y', data=df_mus, hue='class', fit_reg=False, scatter_kws={'s':mark_size*20}) 20 | sns_plot_mus.savefig(os.path.join(dir, 'mus_'+file_name)) 21 | # data = np.vstack((data, mus)) 22 | # label = np.hstack((label, (np.ones(mus.shape[0])*(label.max()+1)).astype(np.int32))) 23 | # df = pd.DataFrame(data={'x':data[:,0], 'y':data[:,1], 'class':label}) 24 | # sns_plot = sns.lmplot('x', 'y', data=df, hue='class', fit_reg=False, scatter_kws={'s':mark_size}) 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graphical Generative Adversarial Networks (Graphical-GAN) 2 | ## [Chongxuan Li](https://github.com/zhenxuan00), Max Welling, Jun Zhu and Bo Zhang 3 | 4 | Code for reproducing most of the results in the [paper](https://arxiv.org/abs/1804.03429). Our method is referred to as LOCAL_EP in the code. We also provide implementations of several recent methods, which are of independent interest, including [VEGAN](https://arxiv.org/abs/1705.07642), [ALI](https://arxiv.org/abs/1606.00704) and [ALICE](https://arxiv.org/abs/1709.01215).
We also try some combinations of these methods; the most direct competitor of our method is ALI. 5 | 6 | Warning: the code is still under development. If you have any problem with the code, please send an email to chongxuanli1991@gmail.com. Any feedback will be appreciated! 7 | 8 | We thank the authors of [wgan-gp](https://github.com/igul222/improved_wgan_training) for providing their code. Our code is largely adapted from their repository. 9 | 10 | Except in the MNIST case, you may need to download the datasets and save them to the dataset folder. See the corresponding dataset files for details. 11 | 12 | If you find the code useful, please cite our paper! 13 | 14 | @article{li2018graphical, 15 | title={Graphical Generative Adversarial Networks}, 16 | author={Li, Chongxuan and Welling, Max and Zhu, Jun and Zhang, Bo}, 17 | journal={arXiv preprint arXiv:1804.03429}, 18 | year={2018} 19 | } 20 | -------------------------------------------------------------------------------- /tflib/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo) 11 | fo.close() 12 | return dict['data'], dict['labels'] 13 | 14 | def get_reconstruction_data(n_samples, data_dir): 15 | # fixed reconstruction samples for comparison 16 | np.random.seed(1234) 17 | data, _ = unpickle(data_dir + '/test_batch') 18 | np.random.shuffle(data) 19 | return data[:n_samples] 20 | 21 | def cifar_generator(filenames, batch_size, data_dir): 22 | all_data = [] 23 | all_labels = [] 24 | for filename in filenames: 25 | data, labels = unpickle(data_dir + '/' + filename) 26 | all_data.append(data) 27 | all_labels.append(labels) 28 | 29 | images = np.concatenate(all_data, axis=0) 30 | labels = np.concatenate(all_labels, axis=0) 31 | 32 | def get_epoch(): 33 | rng_state = np.random.get_state() 34 | np.random.shuffle(images) 35 | np.random.set_state(rng_state) 36 | np.random.shuffle(labels) 37 | 38 | for i in xrange(len(images) / batch_size): 39 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 40 | 41 | return get_epoch 42 | 43 | 44 | def load(batch_size, data_dir): 45 | return ( 46 | cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir), 47 | cifar_generator(['test_batch'], batch_size, data_dir) 48 | ) -------------------------------------------------------------------------------- /tflib/ops/minibatch.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _weights_stdev = None 7 | def set_weights_stdev(weights_stdev): 8 | global _weights_stdev 9 | _weights_stdev = weights_stdev 10 | 11 | def unset_weights_stdev(): 12 | global _weights_stdev 13 | _weights_stdev = None 14 | 15 | 16 | def MiniBatchLayer(name, num_inputs, num_kernels, dim_per_kernel, inputs): 17 | with tf.name_scope(name) as scope: 18 | def uniform(stdev, size): 19 | if _weights_stdev is not None: 20 | stdev = _weights_stdev 21 | return np.random.uniform( 22 | low=-stdev * np.sqrt(3), 23 | high=stdev * np.sqrt(3), 24 | size=size 25 | ).astype('float32') 26 | 27 | weight_values = uniform(np.sqrt(2./num_inputs),(num_inputs, num_kernels, dim_per_kernel)) 28 | 29 | weight = lib.param( 30 | name + '.W', 31 | weight_values 32 | ) 33 | 34 | bias
= lib.param( 35 | name + '.b', 36 | np.zeros((num_kernels,),dtype='float32') 37 | ) 38 | 39 | activation = tf.tensordot(inputs, weight, [[1], [0]]) 40 | abs_dif = (tf.reduce_sum(tf.abs(tf.expand_dims(activation, axis=-1) - tf.expand_dims(tf.transpose(activation, perm=[1, 2, 0]), axis=0)), axis=2)+ 1e6 * tf.expand_dims(tf.eye(tf.shape(inputs)[0]), axis=1)) 41 | 42 | f = tf.reduce_sum(tf.exp(-abs_dif), axis=2) 43 | f += tf.expand_dims(bias, axis=0) 44 | return tf.concat([inputs, f], axis=1) -------------------------------------------------------------------------------- /tflib/ops/conv3d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Conv3D(name, filter_len, input_dim, output_dim, filter_size, inputs, he_init=True, stride=1, stride_len=1, biases=True): 7 | """ 8 | inputs: tensor of shape (N, L, H, W, C) 9 | 10 | returns: tensor of shape (N, L, H, W, C) 11 | """ 12 | with tf.name_scope(name) as scope: 13 | def uniform(stdev, size): 14 | return np.random.uniform( 15 | low=-stdev * np.sqrt(3), 16 | high=stdev * np.sqrt(3), 17 | size=size 18 | ).astype('float32') 19 | 20 | fan_in = input_dim * filter_size**2 * filter_len 21 | fan_out = output_dim * filter_size**2 / (stride**2) * filter_len / stride_len 22 | 23 | if he_init: 24 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 25 | else: # Normalized init (Glorot & Bengio) 26 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 27 | 28 | filter_values = uniform( 29 | filters_stdev, 30 | (filter_len, filter_size, filter_size, input_dim, output_dim) 31 | ) 32 | 33 | filters = lib.param(name+'.Filters', filter_values) 34 | 35 | result = tf.nn.conv3d( 36 | input=inputs, 37 | filter=filters, 38 | strides=[1, stride_len, stride, stride, 1], 39 | padding='SAME', 40 | data_format='NDHWC' 41 | ) 42 | 43 | if biases: 44 | _biases = lib.param( 45 | name+'.Biases', 46 | np.zeros((1, 1, 1, 1, output_dim), dtype='float32') 47 | ) 48 | 49 | result = tf.add(result, _biases) 50 | 51 | return result 52 | -------------------------------------------------------------------------------- /tflib/svhn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os, sys 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | from scipy.io import loadmat 8 | 9 | def maybe_download(data_dir): 10 | if not os.path.exists(data_dir): 11 | os.makedirs(data_dir) 12 | def _progress(count, block_size, total_size): 13 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 14 | sys.stdout.flush() 15 | filepath, _ = urllib.urlretrieve('http://ufldl.stanford.edu/housenumbers/train_32x32.mat', data_dir+'/train_32x32.mat', _progress) 16 | filepath, _ = urllib.urlretrieve('http://ufldl.stanford.edu/housenumbers/test_32x32.mat', data_dir+'/test_32x32.mat', _progress) 17 | 18 | def svhn_generator(data, batch_size): 19 | images, labels = data 20 | 21 | def get_epoch(): 22 | rng_state = np.random.get_state() 23 | np.random.shuffle(images) 24 | np.random.set_state(rng_state) 25 | np.random.shuffle(labels) 26 | 27 | for i in xrange(len(images) / batch_size): 28 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 29 | 30 | return get_epoch 31 | 32 | def load(batch_size, data_dir): 33 | maybe_download(data_dir) 34 | train_data = loadmat(os.path.join(data_dir, 'train_32x32.mat')) 35 | trainx = train_data['X'] 36 |
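# Note (added comment, not in the original file): SVHN stores the digit 0
# under label 10, so the flatten-and-remap below brings the labels into the
# range 0-9.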
trainy = train_data['y'].flatten() 37 | trainy[trainy==10] = 0 38 | test_data = loadmat(os.path.join(data_dir, 'test_32x32.mat')) 39 | testx = test_data['X'] 40 | testy = test_data['y'].flatten() 41 | testy[testy==10] = 0 42 | trainx = np.transpose(trainx, [3, 2, 0, 1]) 43 | testx = np.transpose(testx, [3, 2, 0, 1]) 44 | trainx = trainx.reshape([-1, 32*32*3]) 45 | testx = testx.reshape([-1, 32*32*3]) 46 | return ( 47 | svhn_generator((trainx, trainy), batch_size), 48 | svhn_generator((testx, testy), batch_size) 49 | ) -------------------------------------------------------------------------------- /tflib/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | def mnist_generator(data, batch_size, n_labelled, limit=None): 9 | images, targets = data 10 | 11 | rng_state = numpy.random.get_state() 12 | numpy.random.shuffle(images) 13 | numpy.random.set_state(rng_state) 14 | numpy.random.shuffle(targets) 15 | if limit is not None: 16 | print "WARNING ONLY FIRST {} MNIST DIGITS".format(limit) 17 | images = images.astype('float32')[:limit] 18 | targets = targets.astype('int32')[:limit] 19 | if n_labelled is not None: 20 | labelled = numpy.zeros(len(images), dtype='int32') 21 | labelled[:n_labelled] = 1 22 | 23 | def get_epoch(): 24 | rng_state = numpy.random.get_state() 25 | numpy.random.shuffle(images) 26 | numpy.random.set_state(rng_state) 27 | numpy.random.shuffle(targets) 28 | 29 | if n_labelled is not None: 30 | numpy.random.set_state(rng_state) 31 | numpy.random.shuffle(labelled) 32 | 33 | image_batches = images.reshape(-1, batch_size, 784) 34 | target_batches = targets.reshape(-1, batch_size) 35 | 36 | if n_labelled is not None: 37 | labelled_batches = labelled.reshape(-1, batch_size) 38 | 39 | for i in xrange(len(image_batches)): 40 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i]), numpy.copy(labelled_batches[i])) 41 | 42 | else: 43 | 44 | for i in xrange(len(image_batches)): 45 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i])) 46 | 47 | return get_epoch 48 | 49 | def load(batch_size, test_batch_size, n_labelled=None): 50 | filepath = '/tmp/mnist.pkl.gz' 51 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 52 | 53 | if not os.path.isfile(filepath): 54 | print "Couldn't find MNIST dataset in /tmp, downloading..."
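# Note (added comment, not in the original file): urllib.urlretrieve is the
# Python 2 API used throughout this repo; under Python 3 the equivalent call
# is urllib.request.urlretrieve.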
55 | urllib.urlretrieve(url, filepath) 56 | 57 | with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: 58 | train_data, dev_data, test_data = pickle.load(f) 59 | 60 | return ( 61 | mnist_generator(train_data, batch_size, n_labelled), 62 | mnist_generator(dev_data, test_batch_size, n_labelled), 63 | mnist_generator(test_data, test_batch_size, n_labelled) 64 | ) -------------------------------------------------------------------------------- /tflib/celebA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import scipy.misc 4 | import os 5 | import urllib 6 | import gzip 7 | import cPickle as pickle 8 | from glob import glob 9 | from scipy.misc import imsave, imresize 10 | 11 | def celeba_generator(batch_size, images): 12 | def get_epoch(): 13 | rng_state = np.random.get_state() 14 | np.random.shuffle(images) 15 | 16 | for i in xrange(len(images) / batch_size): 17 | yield images[i*batch_size:(i+1)*batch_size] 18 | 19 | return get_epoch 20 | 21 | def load(batch_size, data_dir, num_dev=5000): 22 | data = np.load(os.path.join(data_dir, 'celebA_64x64.npy')) 23 | 24 | data = data.reshape(data.shape[0], -1) 25 | 26 | rng_state = np.random.get_state() 27 | np.random.shuffle(data) 28 | 29 | x_train = data[num_dev:] 30 | x_test = data[:num_dev] 31 | 32 | return ( 33 | celeba_generator(batch_size, x_train), 34 | celeba_generator(batch_size, x_test) 35 | ) 36 | 37 | def imread(path, grayscale=False): 38 | if (grayscale): 39 | return scipy.misc.imread(path, flatten = True).astype(np.float) 40 | else: 41 | return scipy.misc.imread(path).astype(np.float) 42 | 43 | def center_crop(x, resize_h=64, resize_w=64): 44 | h, w = x.shape[:2] 45 | assert(h >= w) 46 | new_h = int(h * resize_w / w) 47 | x = imresize(x, (new_h, resize_w)) 48 | margin = int(round((new_h - resize_h)/2)) 49 | return x[margin:margin+resize_h] 50 | 51 | def transform(image, resize_height=64, resize_width=64): 52 | cropped_image = center_crop(image, resize_height, resize_width) 53 | return np.array(cropped_image) 54 | 55 | def get_image(image_path, resize_height=64, resize_width=64, grayscale=False): 56 | image = imread(image_path, grayscale) 57 | return transform(image, resize_height, resize_width) 58 | 59 | def print_array(x): 60 | print x.shape, x.dtype, x.max(), x.min() 61 | 62 | def convert_to_numpy(data_path, size=64): 63 | data = glob(os.path.join(data_path, '*.jpg')) 64 | 65 | sample_files = data[0:202599] 66 | sample = [get_image(sample_file, 67 | resize_height=size, 68 | resize_width=size, 69 | grayscale=False) for sample_file in sample_files] 70 | sample_inputs = np.array(sample) 71 | sample_inputs = np.transpose(sample_inputs, [0, 3, 1, 2]) 72 | print_array(sample_inputs) 73 | np.save('celebA_64x64.npy', sample_inputs) -------------------------------------------------------------------------------- /tflib/save_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image grid saver, based on color_grid_vis from github.com/Newmu 3 | """ 4 | 5 | import numpy as np 6 | import scipy.misc 7 | from scipy.misc import imsave 8 | import imageio 9 | 10 | 11 | def large_image(X, size=None): 12 | # [0, 1] -> [0,255] 13 | if isinstance(X.flatten()[0], np.floating): 14 | X = (255.99*X).astype('uint8') 15 | 16 | n_samples = X.shape[0] 17 | 18 | if size == None: 19 | rows = int(np.sqrt(n_samples)) 20 | while n_samples % rows != 0: 21 | rows -= 1 22 | 23 | nh, nw = rows, n_samples/rows 24 | else: 25 | nh, nw = size 26 | assert(nh * nw == 
n_samples) 27 | 28 | if X.ndim == 2: 29 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 30 | 31 | if X.ndim == 4: 32 | # BCHW -> BHWC 33 | X = X.transpose(0,2,3,1) 34 | h, w = X[0].shape[:2] 35 | img = np.zeros((h*nh, w*nw, 3)) 36 | elif X.ndim == 3: 37 | h, w = X[0].shape[:2] 38 | img = np.zeros((h*nh, w*nw)) 39 | 40 | for n, x in enumerate(X): 41 | j = n/nw 42 | i = n%nw 43 | img[j*h:j*h+h, i*w:i*w+w] = x 44 | 45 | return img.astype('uint8') 46 | 47 | def save_gifs(x, save_path, size=None): 48 | final_list = [] 49 | for i in xrange(x.shape[1]): 50 | final_list.append(large_image(x[:,i,:,:,:], size=size)) 51 | imageio.mimsave(save_path, final_list) 52 | 53 | def save_images(X, save_path, size=None): 54 | # [0, 1] -> [0,255] 55 | if isinstance(X.flatten()[0], np.floating): 56 | X = (255.99*X).astype('uint8') 57 | 58 | n_samples = X.shape[0] 59 | 60 | if size == None: 61 | rows = int(np.sqrt(n_samples)) 62 | while n_samples % rows != 0: 63 | rows -= 1 64 | 65 | nh, nw = rows, n_samples/rows 66 | else: 67 | nh, nw = size 68 | assert(nh * nw == n_samples) 69 | 70 | if X.ndim == 2: 71 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 72 | 73 | if X.ndim == 4: 74 | # BCHW -> BHWC 75 | X = X.transpose(0,2,3,1) 76 | h, w = X[0].shape[:2] 77 | img = np.zeros((h*nh, w*nw, 3)) 78 | elif X.ndim == 3: 79 | h, w = X[0].shape[:2] 80 | img = np.zeros((h*nh, w*nw)) 81 | 82 | for n, x in enumerate(X): 83 | j = n/nw 84 | i = n%nw 85 | img[j*h:j*h+h, i*w:i*w+w] = x 86 | 87 | imsave(save_path, img) -------------------------------------------------------------------------------- /tflib/objs/gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tflib as lib 3 | 4 | def wgan(disc_fake, disc_real, gen_params, disc_params, lr=5e-5): 5 | gen_cost = -tf.reduce_mean(disc_fake) 6 | disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real) 7 | 8 | gen_train_op = tf.train.RMSPropOptimizer( 9 | learning_rate=lr 10 | ).minimize(gen_cost, var_list=gen_params) 11 | disc_train_op = tf.train.RMSPropOptimizer( 12 | learning_rate=lr 13 | ).minimize(disc_cost, var_list=disc_params) 14 | 15 | clip_ops = [] 16 | for var in lib.params_with_name('Discriminator'): 17 | clip_bounds = [-.01, .01] 18 | clip_ops.append( 19 | tf.assign( 20 | var, 21 | tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]) 22 | ) 23 | ) 24 | clip_disc_weights = tf.group(*clip_ops) 25 | 26 | return gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops 27 | 28 | def wgan_gp(disc_fake, disc_real, gradient_penalty, gen_params, disc_params, lr=1e-4): 29 | gen_cost = -tf.reduce_mean(disc_fake) 30 | disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real) 31 | 32 | disc_cost += gradient_penalty 33 | 34 | gen_train_op = tf.train.AdamOptimizer( 35 | learning_rate=lr, 36 | beta1=0.5, 37 | beta2=0.9 38 | ).minimize(gen_cost, var_list=gen_params) 39 | disc_train_op = tf.train.AdamOptimizer( 40 | learning_rate=lr, 41 | beta1=0.5, 42 | beta2=0.9 43 | ).minimize(disc_cost, var_list=disc_params) 44 | 45 | clip_disc_weights = None 46 | clip_ops = None 47 | 48 | return gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops 49 | 50 | def gan(disc_fake, disc_real, gen_params, disc_params, lr=2e-4): 51 | gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 52 | logits=disc_fake, 53 | labels=tf.ones_like(disc_fake) 54 | )) 55 | 56 | disc_cost = 
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 57 | logits=disc_fake, 58 | labels=tf.zeros_like(disc_fake) 59 | )) 60 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 61 | logits=disc_real, 62 | labels=tf.ones_like(disc_real) 63 | )) 64 | disc_cost /= 2. 65 | 66 | gen_train_op = tf.train.AdamOptimizer( 67 | learning_rate=lr, 68 | beta1=0.5 69 | ).minimize(gen_cost, var_list=gen_params) 70 | disc_train_op = tf.train.AdamOptimizer( 71 | learning_rate=lr, 72 | beta1=0.5 73 | ).minimize(disc_cost, var_list=disc_params) 74 | 75 | clip_disc_weights = None 76 | clip_ops = None 77 | 78 | return gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops 79 | -------------------------------------------------------------------------------- /tflib/chairs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | from glob import glob 8 | from scipy.misc import imsave, imresize, imread 9 | # import moviepy.editor as mpy 10 | 11 | def rand_clip(x, seq_length): 12 | start = np.random.randint(x.shape[0] - seq_length + 1) 13 | return x[start:start+seq_length] 14 | 15 | def chair_generator(batch_size, seq_length, data, size): 16 | def get_epoch(): 17 | if seq_length == 1: 18 | data_all = data.reshape((-1, size*size*3)) 19 | elif seq_length == 31: 20 | data_all = data.reshape((-1, 31, size*size*3)) 21 | elif seq_length == 4: 22 | data_all = [] 23 | for d in data: 24 | data_all.append(rand_clip(d, seq_length)) 25 | data_all = np.asarray(data_all) 26 | data_all = data_all[:,:seq_length,:] 27 | else: 28 | data_all = data[:,:seq_length,:] 29 | 30 | np.random.shuffle(data_all) 31 | #print 'data_shape', data_all.shape 32 | for i in xrange(data_all.shape[0] / batch_size): 33 | yield data_all[i*batch_size:(i+1)*batch_size] 34 | return get_epoch 35 | 36 | def load(seq_length, batch_size, size, data_dir, num_dev=200): 37 | data = np.load(os.path.join(data_dir, 'chairs_'+str(size)+'.npy')) 38 | data = np.transpose(data, [0, 1, 4, 2, 3]) 39 | data = data.reshape((-1, 31, size*size*3)) 40 | np.random.shuffle(data) 41 | return ( 42 | chair_generator(batch_size, seq_length, data[num_dev:], size), 43 | chair_generator(batch_size, seq_length, data[:num_dev], size) 44 | ) 45 | 46 | # def npy_to_gif(npy, size): 47 | # for i in xrange(npy.shape[0]): 48 | # clip = mpy.ImageSequenceClip(list(npy[i]), fps=5) 49 | # clip.write_gif(str(size)+'_'+str(i)+'.gif') 50 | 51 | def npy_to_image(npy, size): 52 | for i in xrange(npy.shape[0]): 53 | imsave(str(size)+'_'+str(i)+'.png', npy[i]) 54 | 55 | def imread(path, grayscale=False): 56 | if (grayscale): 57 | return scipy.misc.imread(path, flatten = True).astype(np.float) 58 | else: 59 | return scipy.misc.imread(path).astype(np.float) 60 | 61 | def center_crop(image, size): 62 | image = image[140:460, 140:460, :] 63 | image = imresize(image, (size, size)) 64 | return np.array(image) 65 | 66 | def get_image(image_path, size, grayscale): 67 | image = imread(image_path, grayscale) 68 | return center_crop(image, size) 69 | 70 | def print_array(x): 71 | print x.shape, x.dtype, x.max(), x.min() 72 | 73 | def convert_to_numpy(size): 74 | data = glob(os.path.join('*/renders/*.png')) 75 | data.sort() 76 | sample = [get_image(d, size, grayscale=False) for d in data] 77 | sample_inputs = np.array(sample).astype(np.int32) 78 | #npy_to_image(sample_inputs, size) 79 | sample_inputs = 
sample_inputs.reshape((-1, 31, size, size, 3)) 80 | #sample_inputs = np.transpose(sample_inputs, [1, 0, 2, 3, 4]) 81 | print_array(sample_inputs) 82 | #npy_to_gif(sample_inputs, size) 83 | np.save('chairs_'+str(size), sample_inputs) 84 | 85 | # convert_to_numpy(64) 86 | # convert_to_numpy(32) -------------------------------------------------------------------------------- /tflib/objs/mmd.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tflib as lib 3 | 4 | def maximum_mean_discripancy(sample, data, batch_size, sigma=[2. , 5., 10., 20., 40., 80.]): 5 | x = tf.concat([sample, data], axis=0) 6 | xx = tf.matmul(x, x, transpose_b=True) 7 | x2 = tf.reduce_sum(tf.multiply(x, x), axis=1, keep_dims=True) 8 | exponent = tf.add(tf.add(xx, tf.scalar_mul(-.5, x2)), tf.scalar_mul(-.5, tf.transpose(x2))) 9 | 10 | s_samples = tf.scalar_mul(1./batch_size, tf.ones([tf.shape(sample)[0], 1])) 11 | s_data = tf.scalar_mul(-1./batch_size, tf.ones([tf.shape(data)[0], 1])) 12 | s_all = tf.concat([s_samples, s_data], axis=0) 13 | s_mat = tf.matmul(s_all, s_all, transpose_b=True) 14 | mmd_loss = 0. 15 | for s in sigma: 16 | kernel_val = tf.exp(tf.scalar_mul(1./s, exponent)) 17 | mmd_loss += tf.reduce_sum(tf.multiply(s_mat, kernel_val)) 18 | return tf.sqrt(mmd_loss) 19 | 20 | def _mix_rbf_kernel(X, Y, sigmas, wts=None): 21 | if wts is None: 22 | wts = [1] * len(sigmas) 23 | 24 | XX = tf.matmul(X, X, transpose_b=True) 25 | XY = tf.matmul(X, Y, transpose_b=True) 26 | YY = tf.matmul(Y, Y, transpose_b=True) 27 | 28 | X_sqnorms = tf.diag_part(XX) 29 | Y_sqnorms = tf.diag_part(YY) 30 | 31 | r = lambda x: tf.expand_dims(x, 0) 32 | c = lambda x: tf.expand_dims(x, 1) 33 | 34 | K_XX, K_XY, K_YY = 0, 0, 0 35 | for sigma, wt in zip(sigmas, wts): 36 | gamma = 1 / (2 * sigma**2) 37 | K_XX += wt * tf.exp(-gamma * (-2 * XX + c(X_sqnorms) + r(X_sqnorms))) 38 | K_XY += wt * tf.exp(-gamma * (-2 * XY + c(X_sqnorms) + r(Y_sqnorms))) 39 | K_YY += wt * tf.exp(-gamma * (-2 * YY + c(Y_sqnorms) + r(Y_sqnorms))) 40 | 41 | return K_XX, K_XY, K_YY, tf.reduce_sum(wts) 42 | 43 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 44 | m = tf.cast(K_XX.get_shape()[0], tf.float32) 45 | n = tf.cast(K_YY.get_shape()[0], tf.float32) 46 | 47 | if biased: 48 | mmd2 = (tf.reduce_sum(K_XX) / (m * m) 49 | + tf.reduce_sum(K_YY) / (n * n) 50 | - 2 * tf.reduce_sum(K_XY) / (m * n)) 51 | else: 52 | if const_diagonal is not False: 53 | trace_X = m * const_diagonal 54 | trace_Y = n * const_diagonal 55 | else: 56 | trace_X = tf.trace(K_XX) 57 | trace_Y = tf.trace(K_YY) 58 | 59 | mmd2 = ((tf.reduce_sum(K_XX) - trace_X) / (m * (m - 1)) 60 | + (tf.reduce_sum(K_YY) - trace_Y) / (n * (n - 1)) 61 | - 2 * tf.reduce_sum(K_XY) / (m * n)) 62 | 63 | return mmd2 64 | 65 | def mix_rbf_mmd2(X, Y, sigmas=[2. 
, 5., 10., 20., 40., 80.], wts=None, biased=True): 66 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigmas, wts) 67 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 68 | 69 | def vegan_mmd(q_z, p_z, rec_penalty, gen_params, batch_size, lamb, lr=2e-4, beta1=.5): 70 | #mmd_cost = maximum_mean_discripancy(q_z, p_z, batch_size) 71 | gen_cost = lamb * mix_rbf_mmd2(q_z, p_z) 72 | gen_cost += rec_penalty 73 | 74 | gen_train_op = tf.train.AdamOptimizer( 75 | learning_rate=lr, 76 | beta1=beta1 77 | ).minimize(gen_cost, var_list=gen_params) 78 | 79 | return gen_cost, gen_train_op -------------------------------------------------------------------------------- /tflib/inception_score.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py 2 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os.path 8 | import sys 9 | import tarfile 10 | 11 | import numpy as np 12 | from six.moves import urllib 13 | import tensorflow as tf 14 | import glob 15 | import scipy.misc 16 | import math 17 | import sys 18 | 19 | MODEL_DIR = './inception_score_model' 20 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 21 | softmax = None 22 | 23 | # Call this function with list of images. Each of elements should be a 24 | # numpy array with values ranging from 0 to 255. 25 | def get_inception_score(images, splits=10): 26 | assert(type(images) == list) 27 | assert(type(images[0]) == np.ndarray) 28 | assert(len(images[0].shape) == 3) 29 | assert(np.max(images[0]) > 10) 30 | assert(np.min(images[0]) >= 0.0) 31 | inps = [] 32 | for img in images: 33 | img = img.astype(np.float32) 34 | inps.append(np.expand_dims(img, 0)) 35 | bs = 100 36 | with tf.Session() as sess: 37 | preds = [] 38 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 39 | for i in range(n_batches): 40 | # sys.stdout.write(".") 41 | # sys.stdout.flush() 42 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 43 | inp = np.concatenate(inp, 0) 44 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 45 | preds.append(pred) 46 | preds = np.concatenate(preds, 0) 47 | scores = [] 48 | for i in range(splits): 49 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 50 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 51 | kl = np.mean(np.sum(kl, 1)) 52 | scores.append(np.exp(kl)) 53 | return np.mean(scores), np.std(scores) 54 | 55 | # This function is called automatically. 
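# Usage sketch (added; illustrative, assuming `samples` is an NCHW float
# array in [0, 1] produced by a generator):
#
#   import numpy as np
#   from tflib.inception_score import get_inception_score
#   imgs = list((255. * samples).astype(np.float32).transpose(0, 2, 3, 1))  # list of HWC arrays in [0, 255]
#   mean, std = get_inception_score(imgs, splits=10)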
56 | def _init_inception(): 57 | global softmax 58 | if not os.path.exists(MODEL_DIR): 59 | os.makedirs(MODEL_DIR) 60 | filename = DATA_URL.split('/')[-1] 61 | filepath = os.path.join(MODEL_DIR, filename) 62 | if not os.path.exists(filepath): 63 | def _progress(count, block_size, total_size): 64 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 65 | filename, float(count * block_size) / float(total_size) * 100.0)) 66 | sys.stdout.flush() 67 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 68 | print() 69 | statinfo = os.stat(filepath) 70 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 71 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 72 | with tf.gfile.FastGFile(os.path.join( 73 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 74 | graph_def = tf.GraphDef() 75 | graph_def.ParseFromString(f.read()) 76 | _ = tf.import_graph_def(graph_def, name='') 77 | # Works with an arbitrary minibatch size. 78 | with tf.Session() as sess: 79 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 80 | ops = pool3.graph.get_operations() 81 | for op_idx, op in enumerate(ops): 82 | for o in op.outputs: 83 | shape = o.get_shape() 84 | shape = [s.value for s in shape] 85 | new_shape = [] 86 | for j, s in enumerate(shape): 87 | if s == 1 and j == 0: 88 | new_shape.append(None) 89 | else: 90 | new_shape.append(s) 91 | o._shape = tf.TensorShape(new_shape) 92 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 93 | logits = tf.matmul(tf.squeeze(pool3), w) 94 | softmax = tf.nn.softmax(logits) 95 | 96 | if softmax is None: 97 | _init_inception() 98 | -------------------------------------------------------------------------------- /tflib/ops/conv1d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | def Conv1D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1.): 12 | """ 13 | inputs: tensor of shape (batch size, num channels, width) 14 | mask_type: one of None, 'a', 'b' 15 | 16 | returns: tensor of shape (batch size, num channels, width) 17 | """ 18 | with tf.name_scope(name) as scope: 19 | 20 | if mask_type is not None: 21 | mask_type, mask_n_channels = mask_type 22 | 23 | mask = np.ones( 24 | (filter_size, input_dim, output_dim), 25 | dtype='float32' 26 | ) 27 | center = filter_size // 2 28 | 29 | # Mask out future locations 30 | # filter shape is (width, input channels, output channels) 31 | mask[center+1:, :, :] = 0. 32 | 33 | # Mask out future channels 34 | for i in xrange(mask_n_channels): 35 | for j in xrange(mask_n_channels): 36 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 37 | mask[ 38 | center, 39 | i::mask_n_channels, 40 | j::mask_n_channels 41 | ] = 0. 42 | 43 | 44 | def uniform(stdev, size): 45 | return np.random.uniform( 46 | low=-stdev * np.sqrt(3), 47 | high=stdev * np.sqrt(3), 48 | size=size 49 | ).astype('float32') 50 | 51 | fan_in = input_dim * filter_size 52 | fan_out = output_dim * filter_size / stride 53 | 54 | if mask_type is not None: # only approximately correct 55 | fan_in /= 2. 56 | fan_out /= 2. 
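# Note (added comment, not in the original file): np.random.uniform on
# [-a, a] has variance a**2/3, so the bound a = stdev*sqrt(3) used in
# uniform() above realizes the requested stdev; the 'he_init' branch below
# doubles the Glorot variance, using sqrt(4./(fan_in+fan_out)) instead of
# sqrt(2./(fan_in+fan_out)).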
57 | 58 | if he_init: 59 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 60 | else: # Normalized init (Glorot & Bengio) 61 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 62 | 63 | filter_values = uniform( 64 | filters_stdev, 65 | (filter_size, input_dim, output_dim) 66 | ) 67 | # print "WARNING IGNORING GAIN" 68 | filter_values *= gain 69 | 70 | filters = lib.param(name+'.Filters', filter_values) 71 | 72 | if weightnorm==None: 73 | weightnorm = _default_weightnorm 74 | if weightnorm: 75 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1))) 76 | target_norms = lib.param( 77 | name + '.g', 78 | norm_values 79 | ) 80 | with tf.name_scope('weightnorm') as scope: 81 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1])) 82 | filters = filters * (target_norms / norms) 83 | 84 | if mask_type is not None: 85 | with tf.name_scope('filter_mask'): 86 | filters = filters * mask 87 | 88 | result = tf.nn.conv1d( 89 | value=inputs, 90 | filters=filters, 91 | stride=stride, 92 | padding='SAME', 93 | data_format='NCHW' 94 | ) 95 | 96 | if biases: 97 | _biases = lib.param( 98 | name+'.Biases', 99 | np.zeros([output_dim], dtype='float32') 100 | ) 101 | 102 | # result = result + _biases 103 | 104 | result = tf.expand_dims(result, 3) 105 | result = tf.nn.bias_add(result, _biases, data_format='NCHW') 106 | result = tf.squeeze(result) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /tflib/ops/deconv2d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | _weights_stdev = None 12 | def set_weights_stdev(weights_stdev): 13 | global _weights_stdev 14 | _weights_stdev = weights_stdev 15 | 16 | def unset_weights_stdev(): 17 | global _weights_stdev 18 | _weights_stdev = None 19 | 20 | def Deconv2D( 21 | name, 22 | input_dim, 23 | output_dim, 24 | filter_size, 25 | inputs, 26 | he_init=True, 27 | weightnorm=None, 28 | biases=True, 29 | gain=1., 30 | mask_type=None, 31 | stride=2, 32 | padding='SAME' 33 | ): 34 | """ 35 | inputs: tensor of shape (batch size, height, width, input_dim) 36 | returns: tensor of shape (batch size, 2*height, 2*width, output_dim) 37 | """ 38 | with tf.name_scope(name) as scope: 39 | 40 | if mask_type != None: 41 | raise Exception('Unsupported configuration') 42 | 43 | def uniform(stdev, size): 44 | return np.random.uniform( 45 | low=-stdev * np.sqrt(3), 46 | high=stdev * np.sqrt(3), 47 | size=size 48 | ).astype('float32') 49 | 50 | 51 | fan_in = input_dim * filter_size**2 / (stride**2) 52 | fan_out = output_dim * filter_size**2 53 | 54 | if he_init: 55 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 56 | else: # Normalized init (Glorot & Bengio) 57 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 58 | 59 | 60 | if _weights_stdev is not None: 61 | filter_values = uniform( 62 | _weights_stdev, 63 | (filter_size, filter_size, output_dim, input_dim) 64 | ) 65 | else: 66 | filter_values = uniform( 67 | filters_stdev, 68 | (filter_size, filter_size, output_dim, input_dim) 69 | ) 70 | 71 | filter_values *= gain 72 | 73 | filters = lib.param( 74 | name+'.Filters', 75 | filter_values 76 | ) 77 | 78 | if weightnorm==None: 79 | weightnorm = _default_weightnorm 80 | if weightnorm: 81 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,3))) 82 
| target_norms = lib.param( 83 | name + '.g', 84 | norm_values 85 | ) 86 | with tf.name_scope('weightnorm') as scope: 87 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,3])) 88 | filters = filters * tf.expand_dims(target_norms / norms, 1) 89 | 90 | 91 | inputs = tf.transpose(inputs, [0,2,3,1], name='NCHW_to_NHWC') 92 | 93 | input_shape = tf.shape(inputs) 94 | try: # tf pre-1.0 (top) vs 1.0 (bottom) 95 | output_shape = tf.pack([input_shape[0], stride*input_shape[1], stride*input_shape[2], output_dim]) 96 | except Exception as e: 97 | output_shape = tf.stack([input_shape[0], stride*input_shape[1], stride*input_shape[2], output_dim]) 98 | if padding == 'VALID': 99 | output_shape = tf.stack([input_shape[0], stride*(input_shape[1] - 1) + filter_size, stride*(input_shape[2] - 1) + filter_size, output_dim]) 100 | 101 | result = tf.nn.conv2d_transpose( 102 | value=inputs, 103 | filter=filters, 104 | output_shape=output_shape, 105 | strides=[1, stride, stride, 1], 106 | padding=padding 107 | ) 108 | 109 | if biases: 110 | _biases = lib.param( 111 | name+'.Biases', 112 | np.zeros(output_dim, dtype='float32') 113 | ) 114 | result = tf.nn.bias_add(result, _biases) 115 | 116 | result = tf.transpose(result, [0,3,1,2], name='NHWC_to_NCHW') 117 | 118 | 119 | return result 120 | -------------------------------------------------------------------------------- /tflib/ops/conv2d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | _weights_stdev = None 12 | def set_weights_stdev(weights_stdev): 13 | global _weights_stdev 14 | _weights_stdev = weights_stdev 15 | 16 | def unset_weights_stdev(): 17 | global _weights_stdev 18 | _weights_stdev = None 19 | 20 | def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1., padding='SAME'): 21 | """ 22 | inputs: tensor of shape (batch size, num channels, height, width) 23 | mask_type: one of None, 'a', 'b' 24 | 25 | returns: tensor of shape (batch size, num channels, height, width) 26 | """ 27 | with tf.name_scope(name) as scope: 28 | 29 | if mask_type is not None: 30 | mask_type, mask_n_channels = mask_type 31 | 32 | mask = np.ones( 33 | (filter_size, filter_size, input_dim, output_dim), 34 | dtype='float32' 35 | ) 36 | center = filter_size // 2 37 | 38 | # Mask out future locations 39 | # filter shape is (height, width, input channels, output channels) 40 | mask[center+1:, :, :, :] = 0. 41 | mask[center, center+1:, :, :] = 0. 42 | 43 | # Mask out future channels 44 | for i in xrange(mask_n_channels): 45 | for j in xrange(mask_n_channels): 46 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 47 | mask[ 48 | center, 49 | center, 50 | i::mask_n_channels, 51 | j::mask_n_channels 52 | ] = 0. 53 | 54 | 55 | def uniform(stdev, size): 56 | return np.random.uniform( 57 | low=-stdev * np.sqrt(3), 58 | high=stdev * np.sqrt(3), 59 | size=size 60 | ).astype('float32') 61 | 62 | fan_in = input_dim * filter_size**2 63 | fan_out = output_dim * filter_size**2 / (stride**2) 64 | 65 | if mask_type is not None: # only approximately correct 66 | fan_in /= 2. 67 | fan_out /= 2.
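# Note (added comment, not in the original file): in the channel-mask loop
# above, type 'a' zeroes the center-tap weights whenever input group i >= j,
# so an output group never sees its own input group (the PixelCNN-style
# first-layer mask), while type 'b' zeroes only i > j; the fan_in/fan_out
# halving above roughly compensates the initialization for these masked
# weights.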
68 | 69 | if he_init: 70 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 71 | else: # Normalized init (Glorot & Bengio) 72 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 73 | 74 | if _weights_stdev is not None: 75 | filter_values = uniform( 76 | _weights_stdev, 77 | (filter_size, filter_size, input_dim, output_dim) 78 | ) 79 | else: 80 | filter_values = uniform( 81 | filters_stdev, 82 | (filter_size, filter_size, input_dim, output_dim) 83 | ) 84 | 85 | # print "WARNING IGNORING GAIN" 86 | filter_values *= gain 87 | 88 | filters = lib.param(name+'.Filters', filter_values) 89 | 90 | if weightnorm==None: 91 | weightnorm = _default_weightnorm 92 | if weightnorm: 93 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,2))) 94 | target_norms = lib.param( 95 | name + '.g', 96 | norm_values 97 | ) 98 | with tf.name_scope('weightnorm') as scope: 99 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,2])) 100 | filters = filters * (target_norms / norms) 101 | 102 | if mask_type is not None: 103 | with tf.name_scope('filter_mask'): 104 | filters = filters * mask 105 | 106 | result = tf.nn.conv2d( 107 | input=inputs, 108 | filter=filters, 109 | strides=[1, 1, stride, stride], 110 | padding=padding, 111 | data_format='NCHW' 112 | ) 113 | 114 | if biases: 115 | _biases = lib.param( 116 | name+'.Biases', 117 | np.zeros(output_dim, dtype='float32') 118 | ) 119 | 120 | result = tf.nn.bias_add(result, _biases, data_format='NCHW') 121 | 122 | 123 | return result 124 | -------------------------------------------------------------------------------- /tflib/ops/batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True): 7 | if ((axes == [0,2,3]) or (axes == [0,2])) and fused==True: 8 | if axes==[0,2]: 9 | inputs = tf.expand_dims(inputs, 3) 10 | # Old (working but pretty slow) implementation: 11 | ########## 12 | 13 | # inputs = tf.transpose(inputs, [0,2,3,1]) 14 | 15 | # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False) 16 | # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32')) 17 | # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32')) 18 | # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4) 19 | 20 | # return tf.transpose(result, [0,3,1,2]) 21 | 22 | # New (super fast but untested) implementation: 23 | offset = lib.param(name+'.offset', np.zeros(inputs.get_shape()[1], dtype='float32')) 24 | scale = lib.param(name+'.scale', np.ones(inputs.get_shape()[1], dtype='float32')) 25 | 26 | moving_mean = lib.param(name+'.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), trainable=False) 27 | moving_variance = lib.param(name+'.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), trainable=False) 28 | 29 | def _fused_batch_norm_training(): 30 | return tf.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-5, data_format='NCHW') 31 | def _fused_batch_norm_inference(): 32 | # Version which blends in the current item's statistics 33 | batch_size = tf.cast(tf.shape(inputs)[0], 'float32') 34 | mean, var = tf.nn.moments(inputs, [2,3], keep_dims=True) 35 | mean = ((1./batch_size)*mean) + (((batch_size-1.)/batch_size)*moving_mean)[None,:,None,None] 36 | var = ((1./batch_size)*var) + (((batch_size-1.)/batch_size)*moving_variance)[None,:,None,None] 
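# Note (added comment, not in the original file): unlike standard batch norm
# inference, this branch blends the fresh batch moments (weight 1/batch_size)
# with the stored moving statistics (weight (batch_size-1)/batch_size) before
# normalizing.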
37 | return tf.nn.batch_normalization(inputs, mean, var, offset[None,:,None,None], scale[None,:,None,None], 1e-5), mean, var 38 | 39 | # Standard version 40 | # return tf.nn.fused_batch_norm( 41 | # inputs, 42 | # scale, 43 | # offset, 44 | # epsilon=1e-2, 45 | # mean=moving_mean, 46 | # variance=moving_variance, 47 | # is_training=False, 48 | # data_format='NCHW' 49 | # ) 50 | 51 | if is_training is None: 52 | outputs, batch_mean, batch_var = _fused_batch_norm_training() 53 | else: 54 | outputs, batch_mean, batch_var = tf.cond(is_training, 55 | _fused_batch_norm_training, 56 | _fused_batch_norm_inference) 57 | if update_moving_stats: 58 | no_updates = lambda: outputs 59 | def _force_updates(): 60 | """Internal function forces updates moving_vars if is_training.""" 61 | float_stats_iter = tf.cast(stats_iter, tf.float32) 62 | 63 | update_moving_mean = tf.assign(moving_mean, ((float_stats_iter/(float_stats_iter+1))*moving_mean) + ((1/(float_stats_iter+1))*batch_mean)) 64 | update_moving_variance = tf.assign(moving_variance, ((float_stats_iter/(float_stats_iter+1))*moving_variance) + ((1/(float_stats_iter+1))*batch_var)) 65 | 66 | with tf.control_dependencies([update_moving_mean, update_moving_variance]): 67 | return tf.identity(outputs) 68 | outputs = tf.cond(is_training, _force_updates, no_updates) 69 | 70 | if axes == [0,2]: 71 | return outputs[:,:,:,0] # collapse last dim 72 | else: 73 | return outputs 74 | else: 75 | # raise Exception('old BN') 76 | # TODO we can probably use nn.fused_batch_norm here too for speedup 77 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 78 | shape = mean.get_shape().as_list() 79 | if 0 not in axes: 80 | print "WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name) 81 | shape[0] = 1 82 | offset = lib.param(name+'.offset', np.zeros(shape, dtype='float32')) 83 | scale = lib.param(name+'.scale', np.ones(shape, dtype='float32')) 84 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 85 | 86 | 87 | return result 88 | -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import locale 5 | locale.setlocale(locale.LC_ALL, '') 6 | 7 | _params = {} 8 | _param_aliases = {} 9 | def param(name, *args, **kwargs): 10 | """ 11 | A wrapper for `tf.Variable` which enables parameter sharing in models. 12 | 13 | Creates and returns theano shared variables similarly to `tf.Variable`, 14 | except if you try to create a param with the same name as a 15 | previously-created one, `param(...)` will just return the old one instead of 16 | making a new one. 17 | 18 | This constructor also adds a `param` attribute to the shared variables it 19 | creates, so that you can easily search a graph for all params. 
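(Added, illustrative note) For example, all such params in the current graph can be collected with [v for v in tf.global_variables() if hasattr(v, 'param')].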
20 | """ 21 | 22 | if name not in _params: 23 | kwargs['name'] = name 24 | param = tf.Variable(*args, **kwargs) 25 | param.param = True 26 | _params[name] = param 27 | result = _params[name] 28 | i = 0 29 | while result in _param_aliases: 30 | # print 'following alias {}: {} to {}'.format(i, result, _param_aliases[result]) 31 | i += 1 32 | result = _param_aliases[result] 33 | return result 34 | 35 | def params_with_name(name): 36 | return [p for n,p in _params.items() if name in n] 37 | 38 | def delete_all_params(): 39 | _params.clear() 40 | 41 | def alias_params(replace_dict): 42 | for old,new in replace_dict.items(): 43 | # print "aliasing {} to {}".format(old,new) 44 | _param_aliases[old] = new 45 | 46 | def delete_param_aliases(): 47 | _param_aliases.clear() 48 | 49 | # def search(node, critereon): 50 | # """ 51 | # Traverse the Theano graph starting at `node` and return a list of all nodes 52 | # which match the `critereon` function. When optimizing a cost function, you 53 | # can use this to get a list of all of the trainable params in the graph, like 54 | # so: 55 | 56 | # `lib.search(cost, lambda x: hasattr(x, "param"))` 57 | # """ 58 | 59 | # def _search(node, critereon, visited): 60 | # if node in visited: 61 | # return [] 62 | # visited.add(node) 63 | 64 | # results = [] 65 | # if isinstance(node, T.Apply): 66 | # for inp in node.inputs: 67 | # results += _search(inp, critereon, visited) 68 | # else: # Variable node 69 | # if critereon(node): 70 | # results.append(node) 71 | # if node.owner is not None: 72 | # results += _search(node.owner, critereon, visited) 73 | # return results 74 | 75 | # return _search(node, critereon, set()) 76 | 77 | # def print_params_info(params): 78 | # """Print information about the parameters in the given param set.""" 79 | 80 | # params = sorted(params, key=lambda p: p.name) 81 | # values = [p.get_value(borrow=True) for p in params] 82 | # shapes = [p.shape for p in values] 83 | # print "Params for cost:" 84 | # for param, value, shape in zip(params, values, shapes): 85 | # print "\t{0} ({1})".format( 86 | # param.name, 87 | # ",".join([str(x) for x in shape]) 88 | # ) 89 | 90 | # total_param_count = 0 91 | # for shape in shapes: 92 | # param_count = 1 93 | # for dim in shape: 94 | # param_count *= dim 95 | # total_param_count += param_count 96 | # print "Total parameter count: {0}".format( 97 | # locale.format("%d", total_param_count, grouping=True) 98 | # ) 99 | 100 | def print_model_settings(locals_): 101 | print "Uppercase local vars:" 102 | all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T' and k!='SETTINGS' and k!='ALL_SETTINGS')] 103 | all_vars = sorted(all_vars, key=lambda x: x[0]) 104 | for var_name, var_value in all_vars: 105 | print "\t{}: {}".format(var_name, var_value) 106 | 107 | def print_model_settings_to_file(locals_, logfile): 108 | print "Uppercase local vars:" 109 | all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T' and k!='SETTINGS' and k!='ALL_SETTINGS')] 110 | all_vars = sorted(all_vars, key=lambda x: x[0]) 111 | for var_name, var_value in all_vars: 112 | print "\t{}: {}".format(var_name, var_value) 113 | with open(logfile,'a') as f: 114 | f.write("\t{}: {}".format(var_name, var_value)) 115 | 116 | def print_model_settings_dict(settings): 117 | print "Settings dict:" 118 | all_vars = [(k,v) for (k,v) in settings.items()] 119 | all_vars = sorted(all_vars, key=lambda x: x[0]) 120 | for var_name, var_value in all_vars: 121 | print "\t{}: {}".format(var_name, var_value) 
-------------------------------------------------------------------------------- /tflib/objs/kl_aggregated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tflib as lib 4 | import math 5 | 6 | def mixture_gaussian(n_samples, n_coms, dim_z, mu, std): 7 | # prior 8 | pi = tf.constant(np.ones(n_coms).astype('float32')/n_coms) 9 | dist = tf.distributions.Categorical(probs=pi) 10 | # sample components 11 | k = tf.cast(tf.one_hot(indices=dist.sample(n_samples), depth=n_coms), tf.float32) 12 | # sample noise and transfer 13 | mu_k = tf.matmul(k, mu) 14 | std_k = tf.matmul(k, std) 15 | eps = tf.random_normal([n_samples, dim_z]) 16 | return tf.add(mu_k, tf.multiply(std_k, eps)) 17 | 18 | def log_likelihood_diagnoal_gaussian(x, mu, std): 19 | res_mat = -.5*(tf.pow((x - mu)/std, 2) + tf.log(2*math.pi) + 2*tf.log(std)) 20 | return tf.reduce_sum(res_mat, axis=-1) 21 | 22 | def log_likelihood_mixture_gaussian(x, mu, std): 23 | x = tf.expand_dims(x, axis=1) 24 | mu = tf.expand_dims(mu, axis=0) 25 | std = tf.expand_dims(std, axis=0) 26 | res_mat = log_likelihood_diagnoal_gaussian(x, mu, std) 27 | # log sum exp trick to aovid overflow 28 | res_max = tf.reduce_max(res_mat, axis=1) 29 | res_max_keep = tf.expand_dims(res_max, axis=1) 30 | return tf.log(tf.reduce_mean(tf.exp(res_mat - res_max_keep), axis=1)) + res_max 31 | 32 | def log_likelihood_mixture_mixture_gaussian(x, mu_q, std_q, mu_p, std_p, n_coms): 33 | x_q = tf.expand_dims(x, axis=1) # nz x 1 x dz 34 | mu_q = tf.expand_dims(mu_q, axis=0) # 1 x nx x dz 35 | std_q = tf.expand_dims(std_q, axis=0) 36 | res_mat_1 = log_likelihood_diagnoal_gaussian(x_q, mu_q, std_q) # nz x nx 37 | res_mat_2 = log_likelihood_diagnoal_gaussian(x, mu_p, std_p) # nz 38 | res_mat_2 = tf.tile(tf.expand_dims(res_mat_2, axis=1), [1, n_coms]) 39 | res_mat = tf.concat([res_mat_1, res_mat_2], axis=1) 40 | 41 | # log sum exp trick to aovid overflow 42 | res_max = tf.reduce_max(res_mat, axis=1) 43 | res_max_keep = tf.expand_dims(res_max, axis=1) 44 | return tf.log(tf.reduce_mean(tf.exp(res_mat - res_max_keep), axis=1)) + res_max 45 | 46 | def kl_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, n_samples, n_coms, dim_z): 47 | # sample z from q(z) 48 | z = mixture_gaussian(n_samples, n_coms, dim_z, q_z_mean, q_z_std) 49 | log_q = log_likelihood_mixture_gaussian(z, q_z_mean, q_z_std) 50 | log_p = log_likelihood_diagnoal_gaussian(z, p_z_mean, p_z_std) 51 | return tf.reduce_mean(log_q - log_p, axis=0) 52 | 53 | def ikl_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, n_samples, dim_z): 54 | # sample z from p(z) 55 | z = tf.random_normal([n_samples, dim_z]) 56 | log_q = log_likelihood_mixture_gaussian(z, q_z_mean, q_z_std) 57 | log_p = log_likelihood_diagnoal_gaussian(z, p_z_mean, p_z_std) 58 | return tf.reduce_mean(log_p - log_q, axis=0) 59 | 60 | def jsd_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, n_samples, n_coms, dim_z): 61 | # sample z from q(z) 62 | z_1 = mixture_gaussian(n_samples, n_coms, dim_z, q_z_mean, q_z_std) 63 | log_q = log_likelihood_mixture_gaussian(z_1, q_z_mean, q_z_std) 64 | log_m_1 = log_likelihood_mixture_mixture_gaussian(z_1, q_z_mean, q_z_std, p_z_mean, p_z_std, n_coms) 65 | 66 | # sample z from p(z) 67 | z_2 = tf.random_normal([n_samples, dim_z]) 68 | log_p = log_likelihood_diagnoal_gaussian(z_2, p_z_mean, p_z_std) 69 | log_m_2 = log_likelihood_mixture_mixture_gaussian(z_2, q_z_mean, q_z_std, 
71 | 
72 | def vegan_jsd(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params, z_samples, batchsize, dim_z, lamb, lr=2e-4, beta1=.5):
73 |     gen_cost = lamb * jsd_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, z_samples, batchsize, dim_z)
74 |     gen_cost += rec_penalty
75 | 
76 |     gen_train_op = tf.train.AdamOptimizer(
77 |         learning_rate=lr,
78 |         beta1=beta1
79 |     ).minimize(gen_cost, var_list=gen_params)
80 | 
81 |     return gen_cost, gen_train_op
82 | 
83 | def vegan_kl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params, z_samples, batchsize, dim_z, lamb, lr=2e-4, beta1=.5):
84 |     gen_cost = lamb * kl_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, z_samples, batchsize, dim_z)
85 |     gen_cost += rec_penalty
86 | 
87 |     gen_train_op = tf.train.AdamOptimizer(
88 |         learning_rate=lr,
89 |         beta1=beta1
90 |     ).minimize(gen_cost, var_list=gen_params)
91 | 
92 |     return gen_cost, gen_train_op
93 | 
94 | def vegan_ikl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params, z_samples, dim_z, lamb, lr=2e-4, beta1=.5):
95 |     gen_cost = lamb * ikl_q_aggregated_p_diagonal_gaussian(q_z_mean, q_z_std, p_z_mean, p_z_std, z_samples, dim_z)
96 |     gen_cost += rec_penalty
97 | 
98 |     gen_train_op = tf.train.AdamOptimizer(
99 |         learning_rate=lr,
100 |         beta1=beta1
101 |     ).minimize(gen_cost, var_list=gen_params)
102 | 
103 |     return gen_cost, gen_train_op
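# === Editor's note (hedged usage sketch, not part of the library) ===
# Assumed shapes, matching the callers in gan_inference_svhn.py below:
# q_z_mean/q_z_std are [batch, dim_z] per-datapoint posteriors (the aggregated
# posterior is their equal-weight mixture); p_z_mean/p_z_std are [n_samples, dim_z].
#
#   import tensorflow as tf
#   batch, dim_z, n_samples = 64, 8, 100
#   q_mean, q_std = tf.zeros([batch, dim_z]), tf.ones([batch, dim_z])
#   p_mean, p_std = tf.zeros([n_samples, dim_z]), tf.ones([n_samples, dim_z])
#   # Monte Carlo estimate of KL(q_agg || p); ~0 here since q_agg equals p.
#   kl_est = kl_q_aggregated_p_diagonal_gaussian(q_mean, q_std, p_mean, p_std, n_samples, batch, dim_z)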
--------------------------------------------------------------------------------
/tflib/ops/linear.py:
--------------------------------------------------------------------------------
1 | import tflib as lib
2 | 
3 | import numpy as np
4 | import tensorflow as tf
5 | 
6 | _default_weightnorm = False
7 | def enable_default_weightnorm():
8 |     global _default_weightnorm
9 |     _default_weightnorm = True
10 | 
11 | def disable_default_weightnorm():
12 |     global _default_weightnorm
13 |     _default_weightnorm = False
14 | 
15 | _weights_stdev = None
16 | def set_weights_stdev(weights_stdev):
17 |     global _weights_stdev
18 |     _weights_stdev = weights_stdev
19 | 
20 | def unset_weights_stdev():
21 |     global _weights_stdev
22 |     _weights_stdev = None
23 | 
24 | def Linear(
25 |         name,
26 |         input_dim,
27 |         output_dim,
28 |         inputs,
29 |         biases=True,
30 |         initialization=None,
31 |         weightnorm=None,
32 |         gain=1.
33 |         ):
34 |     """
35 |     initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`, `("uniform", range)`
36 |     """
37 |     with tf.name_scope(name) as scope:
38 | 
39 |         def uniform(stdev, size):
40 |             if _weights_stdev is not None:
41 |                 stdev = _weights_stdev
42 |             return np.random.uniform(
43 |                 low=-stdev * np.sqrt(3),
44 |                 high=stdev * np.sqrt(3),
45 |                 size=size
46 |             ).astype('float32')
47 | 
48 |         if initialization == 'lecun':# and input_dim != output_dim):
49 |             # disabling orth. init for now because it's too slow
50 |             weight_values = uniform(
51 |                 np.sqrt(1./input_dim),
52 |                 (input_dim, output_dim)
53 |             )
54 | 
55 |         elif initialization == 'glorot' or (initialization == None):
56 | 
57 |             weight_values = uniform(
58 |                 np.sqrt(2./(input_dim+output_dim)),
59 |                 (input_dim, output_dim)
60 |             )
61 | 
62 |         elif initialization == 'he':
63 | 
64 |             weight_values = uniform(
65 |                 np.sqrt(2./input_dim),
66 |                 (input_dim, output_dim)
67 |             )
68 | 
69 |         elif initialization == 'glorot_he':
70 | 
71 |             weight_values = uniform(
72 |                 np.sqrt(4./(input_dim+output_dim)),
73 |                 (input_dim, output_dim)
74 |             )
75 | 
76 |         elif initialization == 'orthogonal' or \
77 |             (initialization == None and input_dim == output_dim):
78 | 
79 |             # From lasagne
80 |             def sample(shape):
81 |                 if len(shape) < 2:
82 |                     raise RuntimeError("Only shapes of length 2 or more are "
83 |                                        "supported.")
84 |                 flat_shape = (shape[0], np.prod(shape[1:]))
85 |                 # TODO: why normal and not uniform?
86 |                 a = np.random.normal(0.0, 1.0, flat_shape)
87 |                 u, _, v = np.linalg.svd(a, full_matrices=False)
88 |                 # pick the one with the correct shape
89 |                 q = u if u.shape == flat_shape else v
90 |                 q = q.reshape(shape)
91 |                 return q.astype('float32')
92 |             weight_values = sample((input_dim, output_dim))
93 | 
94 |         elif initialization[0] == 'uniform':
95 | 
96 |             weight_values = np.random.uniform(
97 |                 low=-initialization[1],
98 |                 high=initialization[1],
99 |                 size=(input_dim, output_dim)
100 |             ).astype('float32')
101 | 
102 |         else:
103 | 
104 |             raise Exception('Invalid initialization!')
105 | 
106 |         weight_values *= gain
107 | 
108 |         weight = lib.param(
109 |             name + '.W',
110 |             weight_values
111 |         )
112 | 
113 |         if weightnorm==None:
114 |             weightnorm = _default_weightnorm
115 |         if weightnorm:
116 |             norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
117 |             # norm_values = np.linalg.norm(weight_values, axis=0)
118 | 
119 |             target_norms = lib.param(
120 |                 name + '.g',
121 |                 norm_values
122 |             )
123 | 
124 |             with tf.name_scope('weightnorm') as scope:
125 |                 norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
126 |                 weight = weight * (target_norms / norms)
127 | 
128 |         # if 'Discriminator' in name:
129 |         #     print "WARNING weight constraint on {}".format(name)
130 |         #     weight = tf.nn.softsign(10.*weight)*.1
131 | 
132 |         if inputs.get_shape().ndims == 2:
133 |             result = tf.matmul(inputs, weight)
134 |         else:
135 |             reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
136 |             result = tf.matmul(reshaped_inputs, weight)
137 |             result = tf.reshape(result, tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))  # tf.pack/tf.unpack were renamed tf.stack/tf.unstack in TF 1.0
138 | 
139 |         if biases:
140 |             result = tf.nn.bias_add(
141 |                 result,
142 |                 lib.param(
143 |                     name + '.b',
144 |                     np.zeros((output_dim,), dtype='float32')
145 |                 )
146 |             )
147 | 
148 |         return result
--------------------------------------------------------------------------------
/tflib/simple_moving_mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import os
4 | import urllib
5 | import gzip
6 | import cPickle as pickle
7 | 
8 | 
9 | def GetRandomTrajectory(step_length, seq_length, batch_size, image_size, digit_size):
10 |     canvas_size = image_size - digit_size
11 | 
12 |     # Initial position uniform random inside the box.
13 |     y = np.random.rand(batch_size)
14 |     x = np.random.rand(batch_size)
15 | 
16 |     # Choose a random velocity.
17 | theta = np.random.rand(batch_size) * 2 * np.pi 18 | v_y = np.sin(theta) 19 | v_x = np.cos(theta) 20 | 21 | start_y = np.zeros((seq_length, batch_size)) 22 | start_x = np.zeros((seq_length, batch_size)) 23 | for i in xrange(seq_length): 24 | # Take a step along velocity. 25 | y += v_y * step_length 26 | x += v_x * step_length 27 | 28 | # Bounce off edges. 29 | for j in xrange(batch_size): 30 | if x[j] <= 0: 31 | x[j] = 0 32 | v_x[j] = -v_x[j] 33 | if x[j] >= 1.0: 34 | x[j] = 1.0 35 | v_x[j] = -v_x[j] 36 | if y[j] <= 0: 37 | y[j] = 0 38 | v_y[j] = -v_y[j] 39 | if y[j] >= 1.0: 40 | y[j] = 1.0 41 | v_y[j] = -v_y[j] 42 | start_y[i, :] = y 43 | start_x[i, :] = x 44 | 45 | # Scale to the size of the canvas. 46 | start_y = (canvas_size * start_y).astype(np.int32) 47 | start_x = (canvas_size * start_x).astype(np.int32) 48 | return start_y, start_x 49 | 50 | def Overlap(a, b): 51 | return np.maximum(a, b) 52 | #return b 53 | 54 | def moving_mnist_generator_video(data_all, seq_length, batch_size): 55 | images, labels = data_all 56 | images = images.reshape([-1, 28, 28]) 57 | image_size = 64 58 | num_digits = 1 59 | step_length = 0.1 60 | digit_size = 28 61 | 62 | def get_epoch(): 63 | rng_state = np.random.get_state() 64 | np.random.shuffle(images) 65 | np.random.set_state(rng_state) 66 | np.random.shuffle(labels) 67 | 68 | start_y, start_x = GetRandomTrajectory(step_length = step_length, seq_length = seq_length, batch_size = images.shape[0]*num_digits, image_size = image_size, digit_size = digit_size) 69 | 70 | data = np.zeros((images.shape[0], seq_length, image_size, image_size), dtype=np.float32) 71 | 72 | for j in xrange(images.shape[0]): 73 | for n in xrange(num_digits): 74 | 75 | digit_image = images[j, :, :] 76 | 77 | # generate video 78 | for i in xrange(seq_length): 79 | top = start_y[i, j * num_digits + n] 80 | left = start_x[i, j * num_digits + n] 81 | bottom = top + digit_size 82 | right = left + digit_size 83 | 84 | data[j, i, top:bottom, left:right] = Overlap(data[j, i, top:bottom, left:right], digit_image) 85 | 86 | data = data.reshape(images.shape[0], seq_length, image_size*image_size) 87 | 88 | for ind in xrange(data.shape[0]/ batch_size): 89 | yield data[ind*batch_size:(ind+1)*batch_size], labels[ind*batch_size:(ind+1)*batch_size] 90 | 91 | return get_epoch 92 | 93 | def load_video(seq_length, batch_size, cla=None): 94 | filepath = '/tmp/mnist.pkl.gz' 95 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 96 | if not os.path.isfile(filepath): 97 | print "Couldn't find MNIST dataset in /tmp, downloading..." 
98 | urllib.urlretrieve(url, filepath) 99 | with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: 100 | train_data, dev_data, test_data = pickle.load(f) 101 | train_all_x = np.concatenate([train_data[0], dev_data[0]], axis=0) 102 | train_all_y = np.concatenate([train_data[1], dev_data[1]], axis=0) 103 | 104 | if cla is not None: 105 | train_all_x = train_all_x[train_all_y == cla] 106 | train_all_y = train_all_y[train_all_y == cla] 107 | test_x, test_y = test_data 108 | test_x = test_x[test_y == cla] 109 | test_y = test_y[test_y == cla] 110 | test_data = (test_x, test_y) 111 | 112 | return (moving_mnist_generator_video((train_all_x, train_all_y), seq_length, batch_size), moving_mnist_generator_video(test_data, seq_length, batch_size)) 113 | 114 | def moving_mnist_generator_image(image, seq_length, batch_size): 115 | assert batch_size % seq_length == 0 116 | video_gen = moving_mnist_generator_video(image, seq_length, batch_size/seq_length) 117 | data = [] 118 | label = [] 119 | for v, y in video_gen(): 120 | data.append(v.reshape([batch_size, 64*64])) 121 | label.append(np.tile(y.reshape(-1, 1), [1, seq_length]).reshape(-1)) 122 | data = np.vstack(data) 123 | label = np.concatenate(label, axis=0) 124 | def get_epoch(): 125 | rng_state = np.random.get_state() 126 | np.random.shuffle(data) 127 | np.random.set_state(rng_state) 128 | np.random.shuffle(label) 129 | 130 | for i in xrange(len(data) / batch_size): 131 | yield data[i*batch_size:(i+1)*batch_size], label[i*batch_size:(i+1)*batch_size] 132 | return get_epoch 133 | 134 | def load_image(seq_length, batch_size, cla=None): 135 | filepath = '/tmp/mnist.pkl.gz' 136 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 137 | if not os.path.isfile(filepath): 138 | print "Couldn't find MNIST dataset in /tmp, downloading..." 
139 | urllib.urlretrieve(url, filepath) 140 | with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: 141 | train_data, dev_data, test_data = pickle.load(f) 142 | train_all_x = np.concatenate([train_data[0], dev_data[0]], axis=0) 143 | train_all_y = np.concatenate([train_data[1], dev_data[1]], axis=0) 144 | 145 | if cla is not None: 146 | train_all_x = train_all_x[train_all_y == cla] 147 | train_all_y = train_all_y[train_all_y == cla] 148 | test_x, test_y = test_data 149 | test_x = test_x[test_y == cla] 150 | test_y = test_y[test_y == cla] 151 | test_data = (test_x, test_y) 152 | 153 | return (moving_mnist_generator_image((train_all_x, train_all_y), seq_length, batch_size), moving_mnist_generator_image(test_data, seq_length, batch_size)) -------------------------------------------------------------------------------- /gan_inference_face.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, time 2 | sys.path.append(os.getcwd()) 3 | 4 | import functools 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import sklearn.datasets 10 | from sklearn.manifold import TSNE 11 | import tensorflow as tf 12 | 13 | import tflib as lib 14 | import tflib.ops.linear 15 | import tflib.ops.conv2d 16 | import tflib.ops.batchnorm 17 | import tflib.ops.deconv2d 18 | import tflib.save_images 19 | import tflib.celebA 20 | import tflib.plot 21 | import tflib.visualization 22 | import tflib.objs.gan_inference 23 | import tflib.objs.mmd 24 | import tflib.objs.kl 25 | import tflib.objs.kl_aggregated 26 | import tflib.utils.distance 27 | 28 | DATA_DIR = './dataset/celebA' 29 | 30 | ''' 31 | hyperparameters 32 | ''' 33 | MODE = 'ali' # ali 34 | STD = .1 # For fix_std 35 | CRITIC_ITERS = 1 36 | BATCH_SIZE = 128 # Batch size 37 | LAMBDA = 1. # Balance reconstruction and regularization in vegan 38 | LR = 2e-4 39 | DECAY = False 40 | decay = 1. 41 | BETA1 = .5 42 | BETA2 = .999 43 | 44 | ITERS = 100000 # How many generator iterations to train for 45 | 46 | DIM_G = 32 # Model dimensionality 47 | DIM_D = 32 # Model dimensionality 48 | OUTPUT_DIM = 12288 # Number of pixels in celebA (3*64*64) 49 | DIM_LATENT = 128 50 | N_VIS = BATCH_SIZE*2 # Number of samples to be visualized 51 | 52 | 53 | ''' 54 | logs 55 | ''' 56 | filename_script=os.path.basename(os.path.realpath(__file__)) 57 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 58 | outf+='.MODE-' 59 | outf+=MODE 60 | outf+='.' 
61 | outf+=str(int(time.time())) 62 | if not os.path.exists(outf): 63 | os.makedirs(outf) 64 | logfile=os.path.join(outf, 'logfile.txt') 65 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 66 | lib.print_model_settings_to_file(locals().copy(), logfile) 67 | 68 | 69 | ''' 70 | models 71 | ''' 72 | def nonlinearity(x): 73 | return tf.nn.relu(x) 74 | 75 | def LeakyReLU(x, alpha=0.2): 76 | return tf.maximum(alpha*x, x) 77 | 78 | def Generator(noise): 79 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*8*DIM_G, noise) 80 | output = tf.nn.relu(output) 81 | output = tf.reshape(output, [-1, 8*DIM_G, 4, 4]) 82 | 83 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 8*DIM_G, 4*DIM_G, 5, output) 84 | output = tf.nn.relu(output) 85 | 86 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 4*DIM_G, 2*DIM_G, 5, output) 87 | output = tf.nn.relu(output) 88 | 89 | output = lib.ops.deconv2d.Deconv2D('Generator.4', 2*DIM_G, DIM_G, 5, output) 90 | output = tf.nn.relu(output) 91 | 92 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM_G, 3, 5, output) 93 | output = tf.tanh(output) 94 | 95 | return tf.reshape(output, [-1, OUTPUT_DIM]) 96 | 97 | def Extractor(inputs): 98 | output = tf.reshape(inputs, [-1, 3, 64, 64]) 99 | 100 | output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM_G, 5, output,stride=2) 101 | output = LeakyReLU(output) 102 | 103 | output = lib.ops.conv2d.Conv2D('Extractor.2', DIM_G, 2*DIM_G, 5, output, stride=2) 104 | output = LeakyReLU(output) 105 | 106 | output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM_G, 4*DIM_G, 5, output, stride=2) 107 | output = LeakyReLU(output) 108 | 109 | output = lib.ops.conv2d.Conv2D('Extractor.4', 4*DIM_G, 8*DIM_G, 5, output, stride=2) 110 | output = LeakyReLU(output) 111 | 112 | output = tf.reshape(output, [-1, 4*4*8*DIM_G]) 113 | 114 | output = lib.ops.linear.Linear('Extractor.Output', 4*4*8*DIM_G, DIM_LATENT, output) 115 | 116 | return tf.reshape(output, [-1, DIM_LATENT]) 117 | 118 | def Discriminator(x, z): 119 | output = tf.reshape(x, [-1, 3, 64, 64]) 120 | 121 | output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, DIM_D, 5,output, stride=2) 122 | output = LeakyReLU(output) 123 | output = tf.layers.dropout(output, rate=.2) 124 | 125 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM_D, 2*DIM_D, 5, output, stride=2) 126 | output = LeakyReLU(output) 127 | output = tf.layers.dropout(output, rate=.2) 128 | 129 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM_D, 4*DIM_D, 5, output, stride=2) 130 | output = LeakyReLU(output) 131 | output = tf.layers.dropout(output, rate=.2) 132 | 133 | output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*DIM_D, 8*DIM_D, 5, output, stride=2) 134 | output = LeakyReLU(output) 135 | output = tf.layers.dropout(output, rate=.2) 136 | 137 | output = tf.reshape(output, [-1, 4*4*8*DIM_D]) 138 | 139 | z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z) 140 | z_output = LeakyReLU(z_output) 141 | z_output = tf.layers.dropout(z_output, rate=.2) 142 | 143 | output = tf.concat([output, z_output], 1) 144 | output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*8*DIM_D+512, 512, output) 145 | output = LeakyReLU(output) 146 | output = tf.layers.dropout(output, rate=.2) 147 | 148 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 149 | 150 | return tf.reshape(output, [-1]) 151 | 152 | ''' 153 | losses 154 | ''' 155 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM]) 156 | real_x = tf.reshape(2*((tf.cast(real_x_int, 
tf.float32)/256.)-.5), [BATCH_SIZE, OUTPUT_DIM])
157 | real_x += tf.random_uniform(shape=[BATCH_SIZE,OUTPUT_DIM],minval=0.,maxval=1./128) # dequantize
158 | 
159 | q_z = Extractor(real_x)
160 | rec_x = Generator(q_z)
161 | p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT])
162 | fake_x = Generator(p_z)
163 | 
164 | disc_real = Discriminator(real_x, q_z)
165 | disc_fake = Discriminator(fake_x, p_z)
166 | 
167 | gen_params = lib.params_with_name('Generator')
168 | ext_params = lib.params_with_name('Extractor')
169 | disc_params = lib.params_with_name('Discriminator')
170 | 
171 | if MODE == 'ali':
172 |     rec_penalty = None
173 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR*decay, beta1=BETA1, beta2=BETA2)
174 | 
175 | else:
176 |     raise NotImplementedError(MODE)  # was raise('NotImplementedError'); raising a plain string is a TypeError
177 | 
178 | # For visualizing samples
179 | fixed_noise = tf.constant(np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32'))
180 | fixed_noise_samples = Generator(fixed_noise)
181 | def generate_image(frame, true_dist):
182 |     samples = session.run(fixed_noise_samples)
183 |     samples = ((samples+1.)*(255.99/2)).astype('int32')
184 |     lib.save_images.save_images(
185 |         samples.reshape((-1, 3, 64, 64)),
186 |         os.path.join(outf, '{}_samples_{}.png'.format(MODE, frame))
187 |     )
188 | 
189 | # Dataset iterator
190 | train_gen, dev_gen = lib.celebA.load(BATCH_SIZE, data_dir=DATA_DIR)
191 | def inf_train_gen():
192 |     while True:
193 |         for images in train_gen():
194 |             yield images
195 | 
196 | # For reconstruction
197 | fixed_data_int = dev_gen().next()
198 | def reconstruct_image(frame):
199 |     rec_samples = session.run(rec_x, feed_dict={real_x_int: fixed_data_int})
200 |     rec_samples = ((rec_samples+1.)*(255.99/2)).astype('int32')
201 |     tmp_list = []
202 |     for d, r in zip(fixed_data_int, rec_samples):
203 |         tmp_list.append(d)
204 |         tmp_list.append(r)
205 |     rec_samples = np.vstack(tmp_list)
206 |     lib.save_images.save_images(
207 |         rec_samples.reshape((-1, 3, 64, 64)),
208 |         os.path.join(outf, '{}_reconstruction_{}.png'.format(MODE, frame))
209 |     )
210 | saver = tf.train.Saver()
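# === Editor's note (illustrative, not part of the script) ===
# The dequantization near the top of this script maps uint8 pixels to [-1, 1):
# 2*((x/256.) - .5) gives the lower edge of each intensity bin, and the added
# U[0, 1/128) noise fills each bin of width 2/256 = 1/128 uniformly, e.g.:
#
#   import numpy as np
#   xs = np.array([0, 255])
#   lo = 2 * ((xs / 256.) - .5)   # bin lower edges: [-1.0, 0.9921875]
#   hi = lo + 1. / 128            # bin upper edges: [-0.9921875, 1.0]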
211 | 
212 | '''
213 | Train loop
214 | '''
215 | with tf.Session() as session:
216 | 
217 |     session.run(tf.global_variables_initializer())
218 |     gen = inf_train_gen()
219 | 
220 |     total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
221 |     print '\nTotal number of parameters', total_num
222 |     with open(logfile,'a') as f:
223 |         f.write('Total number of parameters' + str(total_num) + '\n')
224 | 
225 |     for iteration in xrange(ITERS):
226 |         start_time = time.time()
227 | 
228 |         if iteration > 0:
229 |             _data = gen.next()
230 |             _gen_cost, _ = session.run([gen_cost, gen_train_op],
231 |                 feed_dict={real_x_int: _data})
232 | 
233 |         for i in xrange(CRITIC_ITERS):
234 |             _data = gen.next()
235 |             _disc_cost, _ = session.run(
236 |                 [disc_cost, disc_train_op],
237 |                 feed_dict={real_x_int: _data}
238 |             )
239 | 
240 |         lib.plot.plot('train disc cost', _disc_cost)
241 |         lib.plot.plot('time', time.time() - start_time)
242 | 
243 |         # Calculate dev loss
244 |         if iteration % 100 == 99:
245 |             dev_gen_costs = []
246 |             for images in dev_gen():
247 |                 _dev_gen_cost = session.run(
248 |                     gen_cost,
249 |                     feed_dict={real_x_int: images}
250 |                 )
251 |                 dev_gen_costs.append(_dev_gen_cost)
252 |             lib.plot.plot('dev gen cost', np.mean(dev_gen_costs))
253 | 
254 |         # Write logs
255 |         if (iteration < 5) or (iteration % 100 == 99):
256 |             lib.plot.flush(outf, logfile)
257 |         lib.plot.tick()
258 | 
259 |         # Generation and reconstruction
260 |         if iteration % 1000 == 999:
261 |             generate_image(iteration, _data)
262 |             reconstruct_image(iteration)
263 | 
264 |         # Save model
265 |         if iteration == ITERS - 1:
266 |             save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(MODE, iteration)))
267 | 
268 |         if DECAY:
269 |             decay = tf.maximum(0., 1.-(tf.cast(iteration, tf.float32)/ITERS))  # NOTE: reassigning `decay` here has no effect on the train ops, which were already built with the initial value
--------------------------------------------------------------------------------
/tflib/objs/gan_inference.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tflib as lib
3 | 
4 | def wali(disc_fake, disc_real, gen_params, disc_params, lr=5e-5):
5 |     gen_cost = -tf.reduce_mean(disc_fake) + tf.reduce_mean(disc_real)  # sign fixed: generator cost is the negative critic cost, as in wali_gp below
6 |     disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
7 | 
8 |     gen_train_op = tf.train.RMSPropOptimizer(
9 |         learning_rate=lr
10 |     ).minimize(gen_cost, var_list=gen_params)
11 |     disc_train_op = tf.train.RMSPropOptimizer(
12 |         learning_rate=lr
13 |     ).minimize(disc_cost, var_list=disc_params)
14 | 
15 |     clip_ops = []
16 |     for var in lib.params_with_name('Discriminator'):
17 |         clip_bounds = [-.01, .01]
18 |         clip_ops.append(
19 |             tf.assign(
20 |                 var,
21 |                 tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
22 |             )
23 |         )
24 |     clip_disc_weights = tf.group(*clip_ops)
25 | 
26 |     return gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops
27 | 
28 | def wali_gp(disc_fake, disc_real, gradient_penalty, gen_params, disc_params, lr=1e-4):
29 |     gen_cost = -tf.reduce_mean(disc_fake) + tf.reduce_mean(disc_real)
30 |     disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
31 | 
32 |     disc_cost += gradient_penalty
33 | 
34 |     gen_train_op = tf.train.AdamOptimizer(
35 |         learning_rate=lr,
36 |         beta1=0.5,
37 |         beta2=0.9
38 |     ).minimize(gen_cost, var_list=gen_params)
39 |     disc_train_op = tf.train.AdamOptimizer(
40 |         learning_rate=lr,
41 |         beta1=0.5,
42 |         beta2=0.9
43 |     ).minimize(disc_cost, var_list=disc_params)
44 | 
45 |     return gen_cost, disc_cost, gen_train_op, disc_train_op
46 | 
47 | def ali(disc_fake, disc_real, gen_params, disc_params, lr=2e-4, beta1=0.5, beta2=0.999, s_f = None):
48 |     gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
49 |         logits=disc_fake,
50 |         labels=tf.ones_like(disc_fake)
51 |     ))
52 |     gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
53 |         logits=disc_real,
54 |         labels=tf.zeros_like(disc_real)
55 |     ))
56 | 
57 |     disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
58 |         logits=disc_fake,
59 |         labels=tf.zeros_like(disc_fake)
60 |     ))
61 |     disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
62 |         logits=disc_real,
63 |         labels=tf.ones_like(disc_real)
64 |     ))
65 |     if s_f is not None:
66 |         gen_cost += s_f
67 | 
68 |     gen_train_op = tf.train.AdamOptimizer(
69 |         learning_rate=lr,
70 |         beta1=beta1,
71 |         beta2=beta2
72 |     ).minimize(gen_cost, var_list=gen_params)
73 |     disc_train_op = tf.train.AdamOptimizer(
74 |         learning_rate=lr,
75 |         beta1=beta1,
76 |         beta2=beta2
77 |     ).minimize(disc_cost, var_list=disc_params)
78 | 
79 |     return gen_cost, disc_cost, gen_train_op, disc_train_op
80 | 
81 | def local_ep(disc_fake_list, disc_real_list, gen_params, disc_params, lr=2e-4, beta1=0.5, beta2=.999, s_f=None):
82 |     gen_cost = 0
83 |     disc_cost = 0
84 |     for disc_fake, disc_real in zip(disc_fake_list, disc_real_list):
85 |         gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
86 |             logits=disc_fake,
87 |             labels=tf.ones_like(disc_fake)
88 | 
)) 89 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 90 | logits=disc_real, 91 | labels=tf.zeros_like(disc_real) 92 | )) 93 | 94 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 95 | logits=disc_fake, 96 | labels=tf.zeros_like(disc_fake) 97 | )) 98 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 99 | logits=disc_real, 100 | labels=tf.ones_like(disc_real) 101 | )) 102 | if s_f is not None: 103 | gen_cost += s_f 104 | 105 | gen_cost /= len(disc_fake_list) 106 | disc_cost /= len(disc_fake_list) 107 | 108 | gen_train_op = tf.train.AdamOptimizer( 109 | learning_rate=lr, 110 | beta1=beta1, 111 | beta2=beta2 112 | ).minimize(gen_cost, var_list=gen_params) 113 | disc_train_op = tf.train.AdamOptimizer( 114 | learning_rate=lr, 115 | beta1=beta1, 116 | beta2=beta2 117 | ).minimize(disc_cost, var_list=disc_params) 118 | 119 | return gen_cost, disc_cost, gen_train_op, disc_train_op 120 | 121 | def local_epce(disc_fake_list, disc_real_list, rec_penalty, gen_params, disc_params, lr=2e-4, beta1=0.5, s_f = None): 122 | gen_cost = 0 123 | disc_cost = 0 124 | for disc_fake, disc_real in zip(disc_fake_list, disc_real_list): 125 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 126 | logits=disc_fake, 127 | labels=tf.ones_like(disc_fake) 128 | )) 129 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 130 | logits=disc_real, 131 | labels=tf.zeros_like(disc_real) 132 | )) 133 | 134 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 135 | logits=disc_fake, 136 | labels=tf.zeros_like(disc_fake) 137 | )) 138 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 139 | logits=disc_real, 140 | labels=tf.ones_like(disc_real) 141 | )) 142 | if s_f is not None: 143 | gen_cost += s_f 144 | 145 | gen_cost /= len(disc_fake_list) 146 | disc_cost /= len(disc_fake_list) 147 | 148 | gen_cost += rec_penalty 149 | 150 | gen_train_op = tf.train.AdamOptimizer( 151 | learning_rate=lr, 152 | beta1=beta1 153 | ).minimize(gen_cost, var_list=gen_params) 154 | disc_train_op = tf.train.AdamOptimizer( 155 | learning_rate=lr, 156 | beta1=beta1 157 | ).minimize(disc_cost, var_list=disc_params) 158 | 159 | return gen_cost, disc_cost, gen_train_op, disc_train_op 160 | 161 | def alice(disc_fake, disc_real, rec_penalty, gen_params, disc_params, lr=2e-4, beta1=0.5, s_f = None): 162 | gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 163 | logits=disc_fake, 164 | labels=tf.ones_like(disc_fake) 165 | )) 166 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 167 | logits=disc_real, 168 | labels=tf.zeros_like(disc_real) 169 | )) 170 | if s_f is not None: 171 | gen_cost += s_f 172 | gen_cost += rec_penalty 173 | 174 | disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 175 | logits=disc_fake, 176 | labels=tf.zeros_like(disc_fake) 177 | )) 178 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 179 | logits=disc_real, 180 | labels=tf.ones_like(disc_real) 181 | )) 182 | 183 | gen_train_op = tf.train.AdamOptimizer( 184 | learning_rate=lr, 185 | beta1=beta1 186 | ).minimize(gen_cost, var_list=gen_params) 187 | disc_train_op = tf.train.AdamOptimizer( 188 | learning_rate=lr, 189 | beta1=beta1 190 | ).minimize(disc_cost, var_list=disc_params) 191 | 192 | return gen_cost, disc_cost, gen_train_op, disc_train_op 193 | 194 | def vegan(disc_fake, disc_real, rec_penalty, gen_params, disc_params, lamb, lr=2e-4, beta1=.5, s_f = None): 195 | gen_cost = 
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 196 | logits=disc_fake, 197 | labels=tf.ones_like(disc_fake) 198 | )) 199 | if s_f is not None: 200 | gen_cost += s_f 201 | gen_cost *= lamb 202 | gen_cost += rec_penalty 203 | 204 | disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 205 | logits=disc_fake, 206 | labels=tf.zeros_like(disc_fake) 207 | )) 208 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 209 | logits=disc_real, 210 | labels=tf.ones_like(disc_real) 211 | )) 212 | disc_cost *= (lamb/2) 213 | 214 | gen_train_op = tf.train.AdamOptimizer( 215 | learning_rate=lr, 216 | beta1=beta1 217 | ).minimize(gen_cost, var_list=gen_params) 218 | disc_train_op = tf.train.AdamOptimizer( 219 | learning_rate=lr, 220 | beta1=beta1 221 | ).minimize(disc_cost, var_list=disc_params) 222 | 223 | return gen_cost, disc_cost, gen_train_op, disc_train_op 224 | 225 | def vegan_wgan_gp(disc_fake, disc_real, rec_penalty, gradient_penalty, gen_params, disc_params, lamb, lr=2e-4, beta1=.5): 226 | 227 | gen_cost = -tf.reduce_mean(disc_fake) + tf.reduce_mean(disc_real) 228 | gen_cost *= lamb 229 | gen_cost += rec_penalty 230 | 231 | disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real) 232 | disc_cost *= lamb 233 | disc_cost += gradient_penalty 234 | 235 | gen_train_op = tf.train.AdamOptimizer( 236 | learning_rate=lr, 237 | beta1=beta1 238 | ).minimize(gen_cost, var_list=gen_params) 239 | disc_train_op = tf.train.AdamOptimizer( 240 | learning_rate=lr, 241 | beta1=beta1 242 | ).minimize(disc_cost, var_list=disc_params) 243 | 244 | return gen_cost, disc_cost, gen_train_op, disc_train_op 245 | 246 | def local_ep_dynamic(disc_fake_zz, disc_real_zz, disc_fake_xz, disc_real_xz, gen_params, disc_params, lr=2e-4, beta1=0.5, beta2=.999, rec_penalty=None): 247 | gen_cost = 0 248 | disc_cost = 0 249 | for disc_fake, disc_real in zip(disc_fake_zz, disc_real_zz): 250 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 251 | logits=disc_fake, 252 | labels=tf.ones_like(disc_fake) 253 | )) 254 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 255 | logits=disc_real, 256 | labels=tf.zeros_like(disc_real) 257 | )) 258 | 259 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 260 | logits=disc_fake, 261 | labels=tf.zeros_like(disc_fake) 262 | )) 263 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 264 | logits=disc_real, 265 | labels=tf.ones_like(disc_real) 266 | )) 267 | 268 | if len(disc_fake_zz) > 0: 269 | gen_cost /= (len(disc_fake_zz)+1) 270 | disc_cost /= (len(disc_fake_zz)+1) 271 | 272 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 273 | logits=disc_fake_xz, 274 | labels=tf.ones_like(disc_fake_xz) 275 | )) 276 | gen_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 277 | logits=disc_real_xz, 278 | labels=tf.zeros_like(disc_real_xz) 279 | )) 280 | 281 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 282 | logits=disc_fake_xz, 283 | labels=tf.zeros_like(disc_fake_xz) 284 | )) 285 | disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 286 | logits=disc_real_xz, 287 | labels=tf.ones_like(disc_real_xz) 288 | )) 289 | 290 | if rec_penalty is not None: 291 | gen_cost += rec_penalty 292 | 293 | gen_train_op = tf.train.AdamOptimizer( 294 | learning_rate=lr, 295 | beta1=beta1, 296 | beta2=beta2 297 | ).minimize(gen_cost, var_list=gen_params) 298 | disc_train_op = tf.train.AdamOptimizer( 299 | learning_rate=lr, 300 | 
beta1=beta1, 301 | beta2=beta2 302 | ).minimize(disc_cost, var_list=disc_params) 303 | 304 | return gen_cost, disc_cost, gen_train_op, disc_train_op 305 | 306 | 307 | def weighted_local_epce(disc_fake_list, disc_real_list, ratio_list, gen_params, disc_params, lr=2e-4, beta1=0.5, rec_penalty = None): 308 | gen_cost = 0 309 | disc_cost = 0 310 | assert len(disc_fake_list) == ratio_list.shape[0] 311 | gen_debug_list, disc_debug_list = [],[] 312 | for disc_fake, disc_real, ratio in zip(disc_fake_list, disc_real_list, ratio_list): 313 | gen_cost += ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 314 | logits=disc_fake, 315 | labels=tf.ones_like(disc_fake) 316 | )) 317 | gen_cost += ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 318 | logits=disc_real, 319 | labels=tf.zeros_like(disc_real) 320 | )) 321 | gen_debug_list.append(ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 322 | logits=disc_fake, 323 | labels=tf.ones_like(disc_fake)))+ 324 | ratio* tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 325 | logits=disc_real, 326 | labels=tf.zeros_like(disc_real))) 327 | ) 328 | 329 | disc_cost += ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 330 | logits=disc_fake, 331 | labels=tf.zeros_like(disc_fake) 332 | )) 333 | disc_cost += ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 334 | logits=disc_real, 335 | labels=tf.ones_like(disc_real) 336 | )) 337 | disc_debug_list.append(ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 338 | logits=disc_fake, 339 | labels=tf.zeros_like(disc_fake)))+ 340 | ratio * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 341 | logits=disc_real, 342 | labels=tf.ones_like(disc_real))) 343 | ) 344 | 345 | 346 | if rec_penalty is not None: 347 | gen_cost += rec_penalty 348 | 349 | gen_train_op = tf.train.AdamOptimizer( 350 | learning_rate=lr, 351 | beta1=beta1 352 | ).minimize(gen_cost, var_list=gen_params) 353 | disc_train_op = tf.train.AdamOptimizer( 354 | learning_rate=lr, 355 | beta1=beta1 356 | ).minimize(disc_cost, var_list=disc_params) 357 | 358 | return gen_cost, disc_cost, gen_debug_list, disc_debug_list, gen_train_op, disc_train_op -------------------------------------------------------------------------------- /gmgan_inference_face.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, time 2 | sys.path.append(os.getcwd()) 3 | 4 | import functools 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import sklearn.datasets 10 | from sklearn.manifold import TSNE 11 | import tensorflow as tf 12 | 13 | import tflib as lib 14 | import tflib.ops.linear 15 | import tflib.ops.conv2d 16 | import tflib.ops.batchnorm 17 | import tflib.ops.deconv2d 18 | import tflib.save_images 19 | import tflib.celebA 20 | import tflib.plot 21 | import tflib.visualization 22 | import tflib.objs.gan_inference 23 | import tflib.objs.mmd 24 | import tflib.objs.kl 25 | import tflib.objs.kl_aggregated 26 | import tflib.objs.discrete_variables 27 | import tflib.utils.distance 28 | 29 | DATA_DIR = './dataset/celebA' 30 | 31 | 32 | ''' 33 | hyperparameters 34 | ''' 35 | MODE = 'local_ep' # ali, local_ep 36 | CRITIC_ITERS = 1 37 | BATCH_SIZE = 128 # Batch size 38 | LR = 2e-4 39 | DECAY = False 40 | decay = 1. 
41 | BETA1 = .5 42 | BETA2 = .999 43 | ITERS = 100000 # How many generator iterations to train for 44 | 45 | DIM_G = 32 # Model dimensionality 46 | DIM_D = 32 # Model dimensionality 47 | OUTPUT_DIM = 12288 # Number of pixels in celebA (3*64*64) 48 | DIM_LATENT = 128 49 | N_COMS = 100 50 | N_VIS = N_COMS*10 # Number of samples to be visualized 51 | assert(N_VIS%N_COMS==0) 52 | MODE_K = 'CONCRETE' 53 | TEMP_INIT = .1 54 | TEMP = TEMP_INIT 55 | 56 | 57 | ''' 58 | logs 59 | ''' 60 | filename_script=os.path.basename(os.path.realpath(__file__)) 61 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 62 | outf+='.MODE-' 63 | outf+=MODE 64 | outf+='.N_COMS-' 65 | outf+=str(N_COMS) 66 | outf+='.' 67 | outf+=str(int(time.time())) 68 | if not os.path.exists(outf): 69 | os.makedirs(outf) 70 | logfile=os.path.join(outf, 'logfile.txt') 71 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 72 | lib.print_model_settings_to_file(locals().copy(), logfile) 73 | 74 | 75 | ''' 76 | models 77 | ''' 78 | ### prior 79 | PI = tf.constant(np.asarray([1./N_COMS,]*N_COMS, dtype=np.float32)) 80 | prior_k = tf.distributions.Categorical(probs=PI) 81 | 82 | def sample_gumbel(shape, eps=1e-20): 83 | # Sample from Gumbel(0, 1) 84 | U = tf.random_uniform(shape,minval=0,maxval=1) 85 | return -tf.log(-tf.log(U + eps) + eps) 86 | 87 | def nonlinearity(x): 88 | return tf.nn.relu(x) 89 | 90 | def LeakyReLU(x, alpha=0.2): 91 | return tf.maximum(alpha*x, x) 92 | 93 | ### Very simple MoG 94 | def HyperGenerator(hyper_k, hyper_noise): 95 | com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32')) 96 | noise = tf.add(tf.matmul(tf.cast(hyper_k, tf.float32), com_mu), hyper_noise) 97 | return noise 98 | 99 | ### Very simple soft alignment 100 | def HyperExtractor(latent_z): 101 | com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32')) 102 | com_logits = -.5*tf.reduce_sum(tf.pow((tf.expand_dims(latent_z, axis=1) - tf.expand_dims(com_mu, axis=0)), 2), axis=-1) + tf.expand_dims(tf.log(PI), axis=0) 103 | 104 | k = tf.nn.softmax((com_logits + sample_gumbel(tf.shape(com_logits)))/TEMP) 105 | 106 | return com_logits, k 107 | 108 | def Generator(noise): 109 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*8*DIM_G, noise) 110 | output = tf.nn.relu(output) 111 | output = tf.reshape(output, [-1, 8*DIM_G, 4, 4]) 112 | 113 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 8*DIM_G, 4*DIM_G, 5, output) 114 | output = tf.nn.relu(output) 115 | 116 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 4*DIM_G, 2*DIM_G, 5, output) 117 | output = tf.nn.relu(output) 118 | 119 | output = lib.ops.deconv2d.Deconv2D('Generator.4', 2*DIM_G, DIM_G, 5, output) 120 | output = tf.nn.relu(output) 121 | 122 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM_G, 3, 5, output) 123 | output = tf.tanh(output) 124 | 125 | return tf.reshape(output, [-1, OUTPUT_DIM]) 126 | 127 | def Extractor(inputs): 128 | output = tf.reshape(inputs, [-1, 3, 64, 64]) 129 | 130 | output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM_G, 5, output,stride=2) 131 | output = LeakyReLU(output) 132 | 133 | output = lib.ops.conv2d.Conv2D('Extractor.2', DIM_G, 2*DIM_G, 5, output, stride=2) 134 | output = LeakyReLU(output) 135 | 136 | output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM_G, 4*DIM_G, 5, output, stride=2) 137 | output = LeakyReLU(output) 138 | 139 | output = lib.ops.conv2d.Conv2D('Extractor.4', 4*DIM_G, 8*DIM_G, 5, output, stride=2) 
140 | output = LeakyReLU(output) 141 | 142 | output = tf.reshape(output, [-1, 4*4*8*DIM_G]) 143 | 144 | output = lib.ops.linear.Linear('Extractor.Output', 4*4*8*DIM_G, DIM_LATENT, output) 145 | 146 | return tf.reshape(output, [-1, DIM_LATENT]) 147 | 148 | if MODE in ['local_ep', 'local_epce']: 149 | 150 | def HyperDiscriminator(z, k): 151 | output = tf.concat([z, k], 1) 152 | output = lib.ops.linear.Linear('Discriminator.HyperInput', DIM_LATENT+N_COMS, 512, output) 153 | output = LeakyReLU(output) 154 | output = tf.layers.dropout(output, rate=.2) 155 | 156 | output = lib.ops.linear.Linear('Discriminator.Hyper2', 512, 512, output) 157 | output = LeakyReLU(output) 158 | output = tf.layers.dropout(output, rate=.2) 159 | 160 | output = lib.ops.linear.Linear('Discriminator.Hyper3', 512, 512, output) 161 | output = LeakyReLU(output) 162 | output = tf.layers.dropout(output, rate=.2) 163 | 164 | output = lib.ops.linear.Linear('Discriminator.HyperOutput', 512, 1, output) 165 | 166 | return tf.reshape(output, [-1]) 167 | 168 | def Discriminator(x, z): 169 | output = tf.reshape(x, [-1, 3, 64, 64]) 170 | 171 | output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, DIM_D, 5,output, stride=2) 172 | output = LeakyReLU(output) 173 | output = tf.layers.dropout(output, rate=.2) 174 | 175 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM_D, 2*DIM_D, 5, output, stride=2) 176 | output = LeakyReLU(output) 177 | output = tf.layers.dropout(output, rate=.2) 178 | 179 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM_D, 4*DIM_D, 5, output, stride=2) 180 | output = LeakyReLU(output) 181 | output = tf.layers.dropout(output, rate=.2) 182 | 183 | output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*DIM_D, 8*DIM_D, 5, output, stride=2) 184 | output = LeakyReLU(output) 185 | output = tf.layers.dropout(output, rate=.2) 186 | 187 | output = tf.reshape(output, [-1, 4*4*8*DIM_D]) 188 | 189 | z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z) 190 | z_output = LeakyReLU(z_output) 191 | z_output = tf.layers.dropout(z_output, rate=.2) 192 | 193 | output = tf.concat([output, z_output], 1) 194 | output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*8*DIM_D+512, 512, output) 195 | output = LeakyReLU(output) 196 | output = tf.layers.dropout(output, rate=.2) 197 | 198 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 199 | 200 | return tf.reshape(output, [-1]) 201 | 202 | else: 203 | 204 | def Discriminator(x, z, k): 205 | output = tf.reshape(x, [-1, 3, 64, 64]) 206 | 207 | output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, DIM_D, 5,output, stride=2) 208 | output = LeakyReLU(output) 209 | output = tf.layers.dropout(output, rate=.2) 210 | 211 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM_D, 2*DIM_D, 5, output, stride=2) 212 | output = LeakyReLU(output) 213 | output = tf.layers.dropout(output, rate=.2) 214 | 215 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM_D, 4*DIM_D, 5, output, stride=2) 216 | output = LeakyReLU(output) 217 | output = tf.layers.dropout(output, rate=.2) 218 | 219 | output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*DIM_D, 8*DIM_D, 5, output, stride=2) 220 | output = LeakyReLU(output) 221 | output = tf.layers.dropout(output, rate=.2) 222 | 223 | output = tf.reshape(output, [-1, 4*4*8*DIM_D]) 224 | 225 | zk_output = tf.concat([z, k], 1) 226 | zk_output = lib.ops.linear.Linear('Discriminator.zk1', DIM_LATENT+N_COMS, 512, zk_output) 227 | zk_output = LeakyReLU(zk_output) 228 | zk_output = tf.layers.dropout(zk_output, rate=.2) 229 | 230 | 
output = tf.concat([output, zk_output], 1)
231 |         output = lib.ops.linear.Linear('Discriminator.zxk1', 4*4*8*DIM_D+512, 512, output)
232 |         output = LeakyReLU(output)
233 |         output = tf.layers.dropout(output, rate=.2)
234 | 
235 |         output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output)
236 | 
237 |         return tf.reshape(output, [-1])
238 | '''
239 | losses
240 | '''
241 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM])
242 | real_x = tf.reshape(2*((tf.cast(real_x_int, tf.float32)/256.)-.5), [BATCH_SIZE, OUTPUT_DIM])
243 | real_x += tf.random_uniform(shape=[BATCH_SIZE,OUTPUT_DIM],minval=0.,maxval=1./128) # dequantize
244 | q_z = Extractor(real_x)
245 | q_k_logits, q_k = HyperExtractor(q_z)
246 | q_k_probs = tf.nn.softmax(q_k_logits)
247 | rec_x = Generator(q_z)
248 | hyper_p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT])
249 | hyper_p_k = tf.one_hot(indices=prior_k.sample(BATCH_SIZE), depth=N_COMS)
250 | p_z = HyperGenerator(hyper_p_k, hyper_p_z)
251 | fake_x = Generator(p_z)
252 | 
253 | if MODE in ['local_ep', 'local_epce']:
254 |     disc_fake, disc_real = [],[]
255 |     disc_fake.append(HyperDiscriminator(p_z, hyper_p_k))
256 |     disc_real.append(HyperDiscriminator(q_z, q_k))
257 |     disc_fake.append(Discriminator(fake_x, p_z))
258 |     disc_real.append(Discriminator(real_x, q_z))
259 | 
260 | else:
261 |     disc_real = Discriminator(real_x, q_z, q_k)
262 |     disc_fake = Discriminator(fake_x, p_z, hyper_p_k)
263 | 
264 | gen_params = lib.params_with_name('Generator')
265 | ext_params = lib.params_with_name('Extractor')
266 | disc_params = lib.params_with_name('Discriminator')
267 | 
268 | if MODE == 'ali':
269 |     rec_penalty = None
270 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR*decay, beta1=BETA1, beta2=BETA2)
271 | 
272 | elif MODE == 'local_ep':
273 |     rec_penalty = None
274 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.local_ep(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, beta2=BETA2)
275 | 
276 | else:
277 |     raise NotImplementedError(MODE)  # was raise('NotImplementedError'); raising a plain string is a TypeError
278 | 
279 | # For visualizing samples
280 | # np_fixed_noise = np.repeat(np.random.normal(size=(N_VIS/N_COMS, DIM_LATENT)).astype('float32'), N_COMS, axis=0)
281 | np_fixed_noise = np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')
282 | np_fixed_k = np.tile(np.eye(N_COMS, dtype=int), (N_VIS/N_COMS, 1))
283 | hyper_fixed_noise = tf.constant(np_fixed_noise)
284 | hyper_fixed_k = tf.constant(np_fixed_k)
285 | fixed_noise = HyperGenerator(hyper_fixed_k, hyper_fixed_noise)
286 | fixed_noise_samples = Generator(fixed_noise)
287 | def generate_image(frame, true_dist):
288 |     samples = session.run(fixed_noise_samples)
289 |     samples = ((samples+1.)*(255.99/2)).astype('int32')
290 |     lib.save_images.save_images(
291 |         samples.reshape((-1, 3, 64, 64)),
292 |         os.path.join(outf, '{}_samples_{}.png'.format(frame, MODE)),
293 |         size = [N_VIS/N_COMS, N_COMS]
294 |     )
295 | 
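# === Editor's note (illustrative sketch, not part of the script) ===
# HyperExtractor above uses the Gumbel-softmax (Concrete) relaxation with
# temperature TEMP: adding Gumbel(0,1) noise to logits and taking an argmax
# draws an exact categorical sample, while a temperature-T softmax over the
# same noisy logits gives a differentiable, nearly one-hot relaxation:
#
#   toy_logits = tf.constant([[2., 0., 1.]])           # toy values
#   g = sample_gumbel(tf.shape(toy_logits))
#   hard_k = tf.argmax(toy_logits + g, axis=-1)        # exact sample (non-differentiable)
#   soft_k = tf.nn.softmax((toy_logits + g) / TEMP)    # relaxed sample (differentiable)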
296 | # Dataset iterator
297 | train_gen, dev_gen = lib.celebA.load(BATCH_SIZE, data_dir=DATA_DIR)
298 | def inf_train_gen():
299 |     while True:
300 |         for images in train_gen():
301 |             yield images
302 | 
303 | # For reconstruction
304 | fixed_data_int = dev_gen().next()
305 | def reconstruct_image(frame):
306 |     rec_samples = session.run(rec_x, feed_dict={real_x_int: fixed_data_int})
307 |     rec_samples = ((rec_samples+1.)*(255.99/2)).astype('int32')
308 |     tmp_list = []
309 |     for d, r in zip(fixed_data_int, rec_samples):
310 |         tmp_list.append(d)
311 |         tmp_list.append(r)
312 |     rec_samples = np.vstack(tmp_list)
313 |     lib.save_images.save_images(
314 |         rec_samples.reshape((-1, 3, 64, 64)),
315 |         os.path.join(outf, '{}_reconstruction_{}.png'.format(frame, MODE))
316 |     )
317 | saver = tf.train.Saver()
318 | 
319 | 
320 | '''
321 | Train loop
322 | '''
323 | with tf.Session() as session:
324 | 
325 |     session.run(tf.global_variables_initializer())
326 |     gen = inf_train_gen()
327 | 
328 |     total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
329 |     print '\nTotal number of parameters', total_num
330 |     with open(logfile,'a') as f:
331 |         f.write('Total number of parameters' + str(total_num) + '\n')
332 | 
333 |     for iteration in xrange(ITERS):
334 |         start_time = time.time()
335 | 
336 |         if iteration > 0:
337 |             _data = gen.next()
338 |             _gen_cost, _ = session.run([gen_cost, gen_train_op],
339 |                 feed_dict={real_x_int: _data}
340 |             )
341 | 
342 |         for i in xrange(CRITIC_ITERS):
343 |             _data = gen.next()
344 |             _disc_cost, _ = session.run(
345 |                 [disc_cost, disc_train_op],
346 |                 feed_dict={real_x_int: _data}
347 |             )
348 | 
349 |         lib.plot.plot('train disc cost', _disc_cost)
350 |         lib.plot.plot('time', time.time() - start_time)
351 | 
352 |         # Calculate dev loss
353 |         if iteration % 100 == 99:
354 |             dev_gen_costs = []
355 |             for images in dev_gen():
356 |                 _dev_gen_cost = session.run(
357 |                     gen_cost,
358 |                     feed_dict={real_x_int: images}
359 |                 )
360 |                 dev_gen_costs.append(_dev_gen_cost)
361 |             lib.plot.plot('dev gen cost', np.mean(dev_gen_costs))
362 | 
363 |         # Write logs
364 |         if (iteration < 5) or (iteration % 100 == 99):
365 |             lib.plot.flush(outf, logfile)
366 |         lib.plot.tick()
367 | 
368 |         # Generation and reconstruction
369 |         if iteration % 1000 == 999:
370 |             generate_image(iteration, _data)
371 |             reconstruct_image(iteration)
372 | 
373 |         # Save model
374 |         if iteration == ITERS - 1:
375 |             save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(iteration, MODE)))
376 | 
377 |         if DECAY:
378 |             decay = tf.maximum(0., 1.-(tf.cast(iteration, tf.float32)/ITERS))  # NOTE: reassigning `decay` here has no effect on the already-built train ops
--------------------------------------------------------------------------------
/gan_inference_svhn.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil, time
2 | sys.path.append(os.getcwd())
3 | 
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import sklearn.datasets
9 | from sklearn.manifold import TSNE
10 | import tensorflow as tf
11 | 
12 | import tflib as lib
13 | import tflib.ops.linear
14 | import tflib.ops.conv2d
15 | import tflib.ops.batchnorm
16 | import tflib.ops.deconv2d
17 | import tflib.save_images
18 | import tflib.svhn
19 | import tflib.plot
20 | import tflib.visualization
21 | import tflib.objs.gan_inference
22 | import tflib.objs.mmd
23 | import tflib.objs.kl
24 | import tflib.objs.kl_aggregated
25 | import tflib.utils.distance
26 | 
27 | DATA_DIR = './dataset/svhn'
28 | 
29 | '''
30 | hyperparameters
31 | '''
32 | MODE = 'ali' # ali, alice, alice-z, alice-x
33 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']:
34 |     TYPE_Q = 'learn_std' # learn_std, fix_std, no_std
35 |     TYPE_P = 'no_std'
36 |     Z_SAMPLES = 100 # MC estimation for D(q(z)||p(z))
37 | elif MODE == 'vae':  # was `MODE is 'vae'`; `is` compares identity, not string equality
38 |     TYPE_Q = 'learn_std'
39 |     TYPE_P = 'learn_std'
40 | else:
41 |     TYPE_Q = 'no_std'
42 |     TYPE_P = 'no_std'
43 | STD = .1 # For fix_std
44 | d_list = ['alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-ikl', 'vegan-jsd',
'vegan-mmd'] 45 | if MODE in d_list: 46 | DISTANCE_X = 'l2' # l1, l2 47 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 48 | CRITIC_ITERS = 0 # No discriminators 49 | elif MODE in ['vegan', 'vegan-wgan-gp', 'wali', 'wali-gp']: 50 | CRITIC_ITERS = 5 # 5 iters of D per iter of G 51 | else: 52 | CRITIC_ITERS = 1 53 | 54 | BATCH_SIZE = 64 # Batch size 55 | LAMBDA = 1. # Balance reconstruction and regularization in vegan 56 | LR = 2e-4 57 | if MODE in ['vae']: 58 | BETA1 = .9 59 | else: 60 | BETA1 = .5 61 | ITERS = 200000 # How many generator iterations to train for 62 | 63 | DIM = 64 # Model dimensionality 64 | OUTPUT_DIM = 3072 # Number of pixels in svhn (3*32*32) 65 | if MODE in ['vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-jsd', 'vegan-ikl']: 66 | BN_FLAG = False # Use batch_norm or not 67 | DIM_LATENT = 8 # Dimensionality of the latent z 68 | else: 69 | BN_FLAG = False 70 | DIM_LATENT = 128 71 | N_VIS = BATCH_SIZE*2 # Number of samples to be visualized 72 | DR_RATE = .2 73 | 74 | 75 | ''' 76 | logs 77 | ''' 78 | filename_script=os.path.basename(os.path.realpath(__file__)) 79 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 80 | outf+='.MODE-' 81 | outf+=MODE 82 | outf+='.' 83 | outf+=str(int(time.time())) 84 | if not os.path.exists(outf): 85 | os.makedirs(outf) 86 | logfile=os.path.join(outf, 'logfile.txt') 87 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 88 | lib.print_model_settings_to_file(locals().copy(), logfile) 89 | 90 | 91 | ''' 92 | models 93 | ''' 94 | unit_std_x = tf.constant((STD*np.ones(shape=(BATCH_SIZE, OUTPUT_DIM))).astype('float32')) 95 | unit_std_z = tf.constant((STD*np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 96 | 97 | def LeakyReLU(x, alpha=0.2): 98 | return tf.maximum(alpha*x, x) 99 | 100 | def ReLULayer(name, n_in, n_out, inputs): 101 | output = lib.ops.linear.Linear( 102 | name+'.Linear', 103 | n_in, 104 | n_out, 105 | inputs, 106 | initialization='he' 107 | ) 108 | return tf.nn.relu(output) 109 | 110 | def LeakyReLULayer(name, n_in, n_out, inputs): 111 | output = lib.ops.linear.Linear( 112 | name+'.Linear', 113 | n_in, 114 | n_out, 115 | inputs, 116 | initialization='he' 117 | ) 118 | return LeakyReLU(output) 119 | 120 | def GaussianNoiseLayer(input_layer, std): 121 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 122 | return input_layer + noise 123 | 124 | def Generator(noise): 125 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*4*DIM, noise) 126 | if BN_FLAG: 127 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output) 128 | output = tf.nn.relu(output) 129 | output = tf.reshape(output, [-1, 4*DIM, 4, 4]) 130 | 131 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output) 132 | if BN_FLAG: 133 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output) 134 | output = tf.nn.relu(output) 135 | 136 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output) 137 | if BN_FLAG: 138 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output) 139 | output = tf.nn.relu(output) 140 | 141 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 3, 5, output) 142 | output = tf.tanh(output) 143 | 144 | return tf.reshape(output, [-1, OUTPUT_DIM]), None, None 145 | 146 | def Extractor(inputs): 147 | output = tf.reshape(inputs, [-1, 3, 32, 32]) 148 | 149 | output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM, 5, output,stride=2) 150 | output = 
LeakyReLU(output)
151 | 
152 |     output = lib.ops.conv2d.Conv2D('Extractor.2', DIM, 2*DIM, 5, output, stride=2)
153 |     if BN_FLAG:
154 |         output = lib.ops.batchnorm.Batchnorm('Extractor.BN2', [0,2,3], output)
155 |     output = LeakyReLU(output)
156 | 
157 |     output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM, 4*DIM, 5, output, stride=2)
158 |     if BN_FLAG:
159 |         output = lib.ops.batchnorm.Batchnorm('Extractor.BN3', [0,2,3], output)
160 |     output = LeakyReLU(output)
161 | 
162 |     output = tf.reshape(output, [-1, 4*4*4*DIM])
163 | 
164 |     if TYPE_Q == 'learn_std':  # was `TYPE_Q is 'learn_std'`; `is` compares identity, not string equality
165 |         log_std = lib.ops.linear.Linear('Extractor.Std', 4*4*4*DIM, DIM_LATENT, output)
166 |         std = tf.exp(log_std)
167 |     elif TYPE_Q == 'fix_std':
168 |         std = unit_std_z
169 |     else:
170 |         std = None
171 |     mean = None
172 | 
173 |     output = lib.ops.linear.Linear('Extractor.Output', 4*4*4*DIM, DIM_LATENT, output)
174 | 
175 |     if TYPE_Q in ['learn_std', 'fix_std']:
176 |         epsilon = tf.random_normal(unit_std_z.shape)
177 |         mean = output
178 |         output = tf.add(mean, tf.multiply(epsilon, std))
179 | 
180 |     return tf.reshape(output, [-1, DIM_LATENT]), mean, std
181 | 
182 | if MODE in ['vegan', 'vegan-wgan-gp']:
183 |     # define a discriminator on z
184 |     def Discriminator(z):
185 |         output = GaussianNoiseLayer(z, std=.3)
186 |         output = lib.ops.linear.Linear('Discriminator.Input', DIM_LATENT, 1024, output)
187 |         if BN_FLAG:
188 |             output = lib.ops.batchnorm.Batchnorm('Discriminator.BN1', [0], output)
189 |         output = LeakyReLU(output)
190 |         output = GaussianNoiseLayer(output, std=.5)
191 | 
192 |         output = lib.ops.linear.Linear('Discriminator.2', 1024, 512, output)
193 |         if BN_FLAG:
194 |             output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0], output)
195 |         output = LeakyReLU(output)
196 |         output = GaussianNoiseLayer(output, std=.5)
197 | 
198 |         output = lib.ops.linear.Linear('Discriminator.3', 512, 256, output)
199 |         if BN_FLAG:
200 |             output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0], output)
201 |         output = LeakyReLU(output)
202 |         output = GaussianNoiseLayer(output, std=.5)
203 | 
204 |         output = lib.ops.linear.Linear('Discriminator.4', 256, 256, output)
205 |         if BN_FLAG:
206 |             output = lib.ops.batchnorm.Batchnorm('Discriminator.BN4', [0], output)
207 |         output = LeakyReLU(output)
208 | 
209 |         output = lib.ops.linear.Linear('Discriminator.Output', 256, 1, output)
210 | 
211 |         return tf.reshape(output, [-1])
212 | 
213 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
214 |     pass # no discriminator
215 | 
216 | else:
217 | 
218 |     def Discriminator(x, z):
219 |         output = tf.reshape(x, [-1, 3, 32, 32])
220 | 
221 |         output = lib.ops.conv2d.Conv2D('Discriminator.1',3,DIM,5,output,stride=2)
222 |         output = LeakyReLU(output)
223 |         output = tf.layers.dropout(output, rate=DR_RATE)
224 | 
225 |         output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2)
226 |         output = LeakyReLU(output)
227 |         output = tf.layers.dropout(output, rate=DR_RATE)
228 | 
229 |         output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2)
230 |         output = LeakyReLU(output)
231 |         output = tf.layers.dropout(output, rate=DR_RATE)
232 | 
233 |         output = tf.reshape(output, [-1, 4*4*4*DIM])
234 | 
235 |         z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z)
236 |         z_output = LeakyReLU(z_output)
237 |         z_output = tf.layers.dropout(z_output, rate=DR_RATE)
238 | 
239 |         output = tf.concat([output, z_output], 1)
240 |         output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*4*DIM+512, 512, output)
241 |         output = LeakyReLU(output)
242 |         output = tf.layers.dropout(output, rate=DR_RATE)
243 | 
244 |         output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output)
245 | 
246 |         return tf.reshape(output, [-1])
247 | 
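# === Editor's note (illustrative sketch, not part of the script) ===
# For TYPE_Q in ['learn_std', 'fix_std'], Extractor above applies the
# reparameterization trick: sampling z ~ N(mean, std^2) as mean + std * eps
# with eps ~ N(0, I), so gradients flow through mean and std while only eps
# is random:
#
#   toy_mean = tf.zeros([BATCH_SIZE, DIM_LATENT])   # toy stand-ins for the network outputs
#   toy_std = tf.ones([BATCH_SIZE, DIM_LATENT])
#   eps = tf.random_normal([BATCH_SIZE, DIM_LATENT])
#   toy_z = toy_mean + toy_std * eps                # differentiable w.r.t. toy_mean / toy_std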
248 | 
249 | '''
250 | losses
251 | '''
252 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM])
253 | real_x = 2*((tf.cast(real_x_int, tf.float32)/255.)-.5)
254 | q_z, q_z_mean, q_z_std = Extractor(real_x)
255 | rec_x, rec_x_mean, rec_x_std = Generator(q_z)
256 | p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT])
257 | fake_x, _, _ = Generator(p_z)
258 | rec_z, _, _ = Extractor(fake_x)
259 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']:
260 |     p_z_mean = tf.constant((np.zeros(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z))
261 |     p_z_std = tf.constant((np.ones(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32'))
262 | elif MODE == 'vae':  # was `MODE is 'vae'`; `is` compares identity, not string equality
263 |     p_z_mean = tf.constant((np.zeros(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z))
264 |     p_z_std = tf.constant((np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32'))
265 | 
266 | 
267 | if MODE in ['vegan', 'vegan-wgan-gp']:
268 |     disc_real = Discriminator(p_z) # discriminate code
269 |     disc_fake = Discriminator(q_z)
270 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
271 |     pass # no discriminators
272 | else:
273 |     disc_real = Discriminator(real_x, q_z) # discriminate code-data pair
274 |     disc_fake = Discriminator(fake_x, p_z)
275 | 
276 | gen_params = lib.params_with_name('Generator')
277 | ext_params = lib.params_with_name('Extractor')
278 | disc_params = lib.params_with_name('Discriminator')
279 | 
280 | if MODE == 'ali':
281 |     rec_penalty = None
282 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1)
283 | 
284 | elif MODE == 'alice-z':
285 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
286 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1)
287 | 
288 | elif MODE == 'alice-x':
289 |     rec_penalty = 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
290 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1)
291 | 
292 | elif MODE == 'alice':
293 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
294 |     rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
295 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1)
296 | 
297 | elif MODE == 'vegan':
298 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
299 |     # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
300 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1)
301 | 
302 | elif MODE == 'vegan-wgan-gp':
303 |     alpha = tf.random_uniform(
304 |         shape=[BATCH_SIZE,1],
305 |         minval=0.,
306 |         maxval=1.
307 | ) 308 | differences = q_z - p_z 309 | interpolates = p_z + (alpha*differences) 310 | gradients = tf.gradients(Discriminator(interpolates), interpolates)[0] 311 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 312 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 313 | 314 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 315 | 316 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan_wgan_gp(disc_fake, disc_real, rec_penalty, gradient_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1) 317 | 318 | elif MODE == 'vegan-mmd': 319 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 320 | gen_cost, gen_train_op = lib.objs.mmd.vegan_mmd(q_z, p_z, rec_penalty, gen_params+ext_params, BATCH_SIZE, LAMBDA, lr=LR, beta1=BETA1) 321 | 322 | elif MODE == 'vegan-kl': 323 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 324 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_kl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 325 | 326 | elif MODE == 'vegan-ikl': 327 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 328 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_ikl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 329 | 330 | elif MODE == 'vegan-jsd': 331 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 332 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_jsd(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 333 | 334 | elif MODE == 'vae': 335 | rec_penalty = None 336 | gen_cost, gen_train_op = lib.objs.kl.vae(real_x, rec_x_mean, rec_x_std, q_z_mean, q_z_std, p_z_mean, p_z_std, gen_params+ext_params, lr=LR, beta1=BETA1) 337 | 338 | elif MODE == 'wali': 339 | rec_penalty = None 340 | gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops = lib.objs.gan_inference.wali(disc_fake, disc_real, gen_params+ext_params, disc_params) 341 | 342 | elif MODE == 'wali-gp': 343 | rec_penalty = None 344 | alpha = tf.random_uniform( 345 | shape=[BATCH_SIZE,1], 346 | minval=0., 347 | maxval=1. 
348 | ) 349 | differences = fake_x - real_x 350 | interpolates = real_x + (alpha*differences) 351 | differences_z = p_z - q_z 352 | interpolates_z = q_z + (alpha*differences_z) 353 | gradients = tf.gradients(Discriminator(interpolates,interpolates_z), [interpolates,interpolates_z])[0] 354 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 355 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 356 | 357 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.wali_gp(disc_fake, disc_real, gradient_penalty, gen_params+ext_params, disc_params) 358 | else: 359 | raise NotImplementedError(MODE) 360 | 361 | # For visualizing samples 362 | fixed_noise = tf.constant(np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')) 363 | fixed_noise_samples, _, _ = Generator(fixed_noise) 364 | def generate_image(frame, true_dist): 365 | samples = session.run(fixed_noise_samples) 366 | samples = ((samples+1.)*(255./2)).astype('int32') 367 | lib.save_images.save_images( 368 | samples.reshape((-1, 3, 32, 32)), 369 | os.path.join(outf, '{}_samples_{}.png'.format(MODE, frame)) 370 | ) 371 | 372 | # Dataset iterator 373 | train_gen, dev_gen = lib.svhn.load(BATCH_SIZE, data_dir=DATA_DIR) 374 | def inf_train_gen(): 375 | while True: 376 | for images, targets in train_gen(): 377 | yield images 378 | 379 | # For reconstruction 380 | rand_data_int, _ = dev_gen().next() 381 | def reconstruct_image(frame, data_int): 382 | rec_samples = session.run(rec_x, feed_dict={real_x_int: data_int}) 383 | rec_samples = ((rec_samples+1.)*(255./2)).astype('int32') 384 | tmp_list = [] 385 | for d, r in zip(data_int, rec_samples): 386 | tmp_list.append(d) 387 | tmp_list.append(r) 388 | rec_samples = np.vstack(tmp_list) 389 | lib.save_images.save_images( 390 | rec_samples.reshape((-1, 3, 32, 32)), 391 | os.path.join(outf, '{}_reconstruction_{}.png'.format(MODE, frame)) 392 | ) 393 | saver = tf.train.Saver() 394 | 395 | ''' 396 | Train loop 397 | ''' 398 | with tf.Session() as session: 399 | 400 | session.run(tf.global_variables_initializer()) 401 | gen = inf_train_gen() 402 | 403 | total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 404 | print '\nTotal number of parameters', total_num 405 | with open(logfile,'a') as f: 406 | f.write('Total number of parameters: ' + str(total_num) + '\n') 407 | 408 | 409 | for iteration in xrange(ITERS): 410 | start_time = time.time() 411 | 412 | if iteration > 0: 413 | _data = gen.next() 414 | _gen_cost, _ = session.run([gen_cost, gen_train_op], 415 | feed_dict={real_x_int: _data}) 416 | 417 | for i in xrange(CRITIC_ITERS): 418 | _data = gen.next() 419 | _disc_cost, _ = session.run( 420 | [disc_cost, disc_train_op], 421 | feed_dict={real_x_int: _data} 422 | ) 423 | if MODE == 'wali': 424 | _ = session.run(clip_disc_weights) 425 | 426 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 427 | if iteration > 0: 428 | lib.plot.plot('train gen cost ', _gen_cost) 429 | else: 430 | lib.plot.plot('train disc cost', _disc_cost) 431 | lib.plot.plot('time', time.time() - start_time) 432 | 433 | # Calculate dev loss 434 | if iteration % 100 == 99: 435 | if rec_penalty is not None: 436 | dev_rec_costs = [] 437 | dev_reg_costs = [] 438 | for images,_ in dev_gen(): 439 | _dev_rec_cost, _dev_gen_cost = session.run( 440 | [rec_penalty, gen_cost], 441 | feed_dict={real_x_int: images} 442 | ) 443 | dev_rec_costs.append(_dev_rec_cost) 444 | dev_reg_costs.append(_dev_gen_cost - _dev_rec_cost) 445 | lib.plot.plot('dev rec cost', 
np.mean(dev_rec_costs)) 446 | lib.plot.plot('dev reg cost', np.mean(dev_reg_costs)) 447 | else: 448 | dev_gen_costs = [] 449 | for images,_ in dev_gen(): 450 | _dev_gen_cost = session.run( 451 | gen_cost, 452 | feed_dict={real_x_int: images} 453 | ) 454 | dev_gen_costs.append(_dev_gen_cost) 455 | lib.plot.plot('dev gen cost', np.mean(dev_gen_costs)) 456 | 457 | # Write logs 458 | if (iteration < 5) or (iteration % 100 == 99): 459 | lib.plot.flush(outf, logfile) 460 | lib.plot.tick() 461 | 462 | # Generation and reconstruction 463 | if iteration % 5000 == 4999: 464 | generate_image(iteration, _data) 465 | reconstruct_image(iteration, rand_data_int) 466 | # reconstruct_image(-iteration, fixed_data_int) 467 | 468 | # Save model 469 | if iteration == ITERS - 1: 470 | save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(MODE, iteration))) -------------------------------------------------------------------------------- /gan_inference_mnist.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, time 2 | sys.path.append(os.getcwd()) 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import sklearn.datasets 9 | from sklearn.manifold import TSNE 10 | import tensorflow as tf 11 | 12 | import tflib as lib 13 | import tflib.ops.linear 14 | import tflib.ops.conv2d 15 | import tflib.ops.batchnorm 16 | import tflib.ops.deconv2d 17 | import tflib.save_images 18 | import tflib.mnist 19 | import tflib.plot 20 | import tflib.visualization 21 | import tflib.objs.gan_inference 22 | import tflib.objs.mmd 23 | import tflib.objs.kl 24 | import tflib.objs.kl_aggregated 25 | import tflib.utils.distance 26 | 27 | 28 | ''' 29 | hyperparameters 30 | ''' 31 | MODE = 'ali' # ali, alice, alice-z, alice-x, vegan 32 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']: 33 | TYPE_Q = 'learn_std' # learn_std, fix_std, no_std 34 | TYPE_P = 'no_std' 35 | Z_SAMPLES = 100 # MC estimation for D(q(z)||p(z)) 36 | elif MODE is 'vae': 37 | TYPE_Q = 'learn_std' 38 | TYPE_P = 'learn_std' 39 | else: 40 | TYPE_Q = 'no_std' 41 | TYPE_P = 'no_std' 42 | STD = .1 # For fix_std 43 | d_list = ['alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vegan-mmd'] 44 | if MODE in d_list: 45 | DISTANCE_X = 'l2' # l1, l2 46 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 47 | CRITIC_ITERS = 0 # No discriminators 48 | elif MODE in ['vegan', 'vegan-wgan-gp', 'wali', 'wali-gp']: 49 | CRITIC_ITERS = 5 # 5 iters of D per iter of G 50 | else: 51 | CRITIC_ITERS = 1 52 | 53 | BATCH_SIZE = 50 # Batch size 54 | LAMBDA = 1. # Balance reconstruction and regularization in vegan 55 | LR = 2e-4 56 | if MODE in ['vae']: 57 | BETA1 = .9 58 | else: 59 | BETA1 = .5 60 | ITERS = 200000 # How many generator iterations to train for 61 | 62 | DIM = 64 # Model dimensionality 63 | OUTPUT_DIM = 784 # Number of pixels in MNIST (28*28) 64 | if MODE in ['vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-jsd', 'vegan-ikl']: 65 | BN_FLAG = False # Use batch_norm or not 66 | DIM_LATENT = 8 # Dimensionality of the latent z 67 | else: 68 | BN_FLAG = True 69 | DIM_LATENT = 128 70 | N_VIS = BATCH_SIZE*2 # Number of samples to be visualized 71 | 72 | 73 | ''' 74 | logs 75 | ''' 76 | filename_script=os.path.basename(os.path.realpath(__file__)) 77 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 78 | outf+='.MODE-' 79 | outf+=MODE 80 | outf+='.' 
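# NOTE (added): with the defaults above, outf ends up looking like
# 'result/gan_inference_mnist.MODE-ali.1520000000' (illustrative timestamp; the actual
# value is appended on the next line).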
81 | outf+=str(int(time.time())) 82 | if not os.path.exists(outf): 83 | os.makedirs(outf) 84 | logfile=os.path.join(outf, 'logfile.txt') 85 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 86 | lib.print_model_settings_to_file(locals().copy(), logfile) 87 | 88 | 89 | ''' 90 | models 91 | ''' 92 | unit_std_x = tf.constant((STD*np.ones(shape=(BATCH_SIZE, OUTPUT_DIM))).astype('float32')) 93 | unit_std_z = tf.constant((STD*np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 94 | 95 | def LeakyReLU(x, alpha=0.2): 96 | return tf.maximum(alpha*x, x) 97 | 98 | def ReLULayer(name, n_in, n_out, inputs): 99 | output = lib.ops.linear.Linear( 100 | name+'.Linear', 101 | n_in, 102 | n_out, 103 | inputs, 104 | initialization='he' 105 | ) 106 | return tf.nn.relu(output) 107 | 108 | def LeakyReLULayer(name, n_in, n_out, inputs): 109 | output = lib.ops.linear.Linear( 110 | name+'.Linear', 111 | n_in, 112 | n_out, 113 | inputs, 114 | initialization='he' 115 | ) 116 | return LeakyReLU(output) 117 | 118 | def GaussianNoiseLayer(input_layer, std): 119 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 120 | return input_layer + noise 121 | 122 | def Generator(noise): 123 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*4*DIM, noise) 124 | if BN_FLAG: 125 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output) 126 | output = tf.nn.relu(output) 127 | output = tf.reshape(output, [-1, 4*DIM, 4, 4]) 128 | 129 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output) 130 | if BN_FLAG: 131 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output) 132 | output = tf.nn.relu(output) 133 | 134 | output = output[:,:,:7,:7] 135 | 136 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output) 137 | if BN_FLAG: 138 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output) 139 | output = tf.nn.relu(output) 140 | 141 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 1, 5, output) 142 | output = tf.nn.sigmoid(output) 143 | 144 | return tf.reshape(output, [-1, OUTPUT_DIM]), None, None 145 | 146 | def Extractor(inputs): 147 | output = tf.reshape(inputs, [-1, 1, 28, 28]) 148 | 149 | output = lib.ops.conv2d.Conv2D('Extractor.1',1,DIM,5,output,stride=2) 150 | output = LeakyReLU(output) 151 | 152 | output = lib.ops.conv2d.Conv2D('Extractor.2', DIM, 2*DIM, 5, output, stride=2) 153 | if BN_FLAG: 154 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN2', [0,2,3], output) 155 | output = LeakyReLU(output) 156 | 157 | output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM, 4*DIM, 5, output, stride=2) 158 | if BN_FLAG: 159 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN3', [0,2,3], output) 160 | output = LeakyReLU(output) 161 | 162 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 163 | 164 | if TYPE_Q is 'learn_std': 165 | log_std = lib.ops.linear.Linear('Extractor.Std', 4*4*4*DIM, DIM_LATENT, output) 166 | std = tf.exp(log_std) 167 | elif TYPE_Q is 'fix_std': 168 | std = unit_std_z 169 | else: 170 | std = None 171 | mean = None 172 | 173 | output = lib.ops.linear.Linear('Extractor.Output', 4*4*4*DIM, DIM_LATENT, output) 174 | 175 | if TYPE_Q in ['learn_std', 'fix_std']: 176 | epsilon = tf.random_normal(unit_std_z.shape) 177 | mean = output 178 | output = tf.add(mean, tf.multiply(epsilon, std)) 179 | 180 | return tf.reshape(output, [-1, DIM_LATENT]), mean, std 181 | 182 | if MODE in ['vegan', 'vegan-wgan-gp']: 183 | # define a discriminator on z 184 | def 
Discriminator(z): 185 | output = GaussianNoiseLayer(z, std=.3) 186 | output = lib.ops.linear.Linear('Discriminator.Input', DIM_LATENT, 1024, output) 187 | if BN_FLAG: 188 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN1', [0], output) 189 | output = LeakyReLU(output) 190 | output = GaussianNoiseLayer(output, std=.5) 191 | 192 | output = lib.ops.linear.Linear('Discriminator.2', 1024, 512, output) 193 | if BN_FLAG: 194 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0], output) 195 | output = LeakyReLU(output) 196 | output = GaussianNoiseLayer(output, std=.5) 197 | 198 | output = lib.ops.linear.Linear('Discriminator.3', 512, 256, output) 199 | if BN_FLAG: 200 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0], output) 201 | output = LeakyReLU(output) 202 | output = GaussianNoiseLayer(output, std=.5) 203 | 204 | output = lib.ops.linear.Linear('Discriminator.4', 256, 256, output) 205 | if BN_FLAG: 206 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN4', [0], output) 207 | output = LeakyReLU(output) 208 | 209 | output = lib.ops.linear.Linear('Discriminator.Output', 256, 1, output) 210 | 211 | return tf.reshape(output, [-1]) 212 | 213 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 214 | pass # no discriminator 215 | 216 | else: 217 | def Discriminator(x, z): 218 | output = tf.reshape(x, [-1, 1, 28, 28]) 219 | 220 | output = lib.ops.conv2d.Conv2D('Discriminator.1',1,DIM,5,output,stride=2) 221 | output = LeakyReLU(output) 222 | 223 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2) 224 | if BN_FLAG: 225 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0,2,3], output) 226 | output = LeakyReLU(output) 227 | 228 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2) 229 | if BN_FLAG: 230 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0,2,3], output) 231 | output = LeakyReLU(output) 232 | 233 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 234 | 235 | z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z) 236 | z_output = LeakyReLU(z_output) 237 | z_output = tf.layers.dropout(z_output, rate=.2) 238 | z_output = lib.ops.linear.Linear('Discriminator.2', 512, 512, z_output) 239 | z_output = LeakyReLU(z_output) 240 | z_output = tf.layers.dropout(z_output, rate=.2) 241 | 242 | output = tf.concat([output, z_output], 1) 243 | output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*4*DIM+512, 512, output) 244 | output = LeakyReLU(output) 245 | output = tf.layers.dropout(output, rate=.2) 246 | output = lib.ops.linear.Linear('Discriminator.zx2', 512, 512, output) 247 | output = LeakyReLU(output) 248 | output = tf.layers.dropout(output, rate=.2) 249 | 250 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 251 | 252 | return tf.reshape(output, [-1]) 253 | 254 | ''' 255 | losses 256 | ''' 257 | real_x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, OUTPUT_DIM]) 258 | q_z, q_z_mean, q_z_std = Extractor(real_x) 259 | rec_x, rec_x_mean, rec_x_std = Generator(q_z) 260 | p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT]) 261 | fake_x, _, _ = Generator(p_z) 262 | rec_z, _, _ = Extractor(fake_x) 263 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']: 264 | p_z_mean = tf.constant((np.zeros(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z)) 265 | p_z_std = tf.constant((np.ones(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32')) 266 | elif MODE is 'vae': 267 | p_z_mean = 
tf.constant((np.zeros(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z)) 268 | p_z_std = tf.constant((np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 269 | 270 | 271 | if MODE in ['vegan', 'vegan-wgan-gp']: 272 | disc_real = Discriminator(p_z) # discriminate code 273 | disc_fake = Discriminator(q_z) 274 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 275 | pass # no discriminators 276 | else: 277 | disc_real = Discriminator(real_x, q_z) # discriminate code-data pair 278 | disc_fake = Discriminator(fake_x, p_z) 279 | 280 | gen_params = lib.params_with_name('Generator') 281 | ext_params = lib.params_with_name('Extractor') 282 | disc_params = lib.params_with_name('Discriminator') 283 | 284 | if MODE == 'ali': 285 | rec_penalty = None 286 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 287 | 288 | elif MODE == 'alice-z': 289 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 290 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 291 | 292 | elif MODE == 'alice-x': 293 | rec_penalty = 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 294 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 295 | 296 | elif MODE == 'alice': 297 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 298 | rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 299 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 300 | 301 | elif MODE == 'vegan': 302 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 303 | # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 304 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1) 305 | 306 | elif MODE == 'vegan-wgan-gp': 307 | alpha = tf.random_uniform( 308 | shape=[BATCH_SIZE,1], 309 | minval=0., 310 | maxval=1. 
311 | ) 312 | differences = q_z - p_z 313 | interpolates = p_z + (alpha*differences) 314 | gradients = tf.gradients(Discriminator(interpolates), interpolates)[0] 315 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 316 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 317 | 318 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 319 | 320 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan_wgan_gp(disc_fake, disc_real, rec_penalty, gradient_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1) 321 | 322 | elif MODE == 'vegan-mmd': 323 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 324 | gen_cost, gen_train_op = lib.objs.mmd.vegan_mmd(q_z, p_z, rec_penalty, gen_params+ext_params, BATCH_SIZE, LAMBDA, lr=LR, beta1=BETA1) 325 | 326 | elif MODE == 'vegan-kl': 327 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 328 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_kl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 329 | 330 | elif MODE == 'vegan-ikl': 331 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 332 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_ikl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 333 | 334 | elif MODE == 'vegan-jsd': 335 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 336 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_jsd(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 337 | 338 | elif MODE == 'vae': 339 | rec_penalty = None 340 | gen_cost, gen_train_op = lib.objs.kl.vae(real_x, rec_x_mean, rec_x_std, q_z_mean, q_z_std, p_z_mean, p_z_std, gen_params+ext_params, lr=LR, beta1=BETA1) 341 | 342 | elif MODE == 'wali': 343 | rec_penalty = None 344 | gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops = lib.objs.gan_inference.wali(disc_fake, disc_real, gen_params+ext_params, disc_params) 345 | 346 | elif MODE == 'wali-gp': 347 | rec_penalty = None 348 | alpha = tf.random_uniform( 349 | shape=[BATCH_SIZE,1], 350 | minval=0., 351 | maxval=1. 
352 | ) 353 | differences = fake_x - real_x 354 | interpolates = real_x + (alpha*differences) 355 | differences_z = p_z - q_z 356 | interpolates_z = q_z + (alpha*differences_z) 357 | gradients = tf.gradients(Discriminator(interpolates,interpolates_z), [interpolates,interpolates_z])[0] 358 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 359 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 360 | 361 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.wali_gp(disc_fake, disc_real, gradient_penalty, gen_params+ext_params, disc_params) 362 | else: 363 | raise NotImplementedError(MODE) 364 | 365 | # For visualizing samples 366 | fixed_noise = tf.constant(np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')) 367 | fixed_noise_samples, _, _ = Generator(fixed_noise) 368 | def generate_image(frame, true_dist): 369 | samples = session.run(fixed_noise_samples) 370 | lib.save_images.save_images( 371 | samples.reshape((N_VIS, 28, 28)), 372 | os.path.join(outf, '{}_mnist_samples_{}.png'.format(MODE, frame)) 373 | ) 374 | 375 | # Dataset iterator 376 | train_gen, dev_gen, test_gen = lib.mnist.load(BATCH_SIZE, BATCH_SIZE) 377 | def inf_train_gen(): 378 | while True: 379 | for images,targets in train_gen(): 380 | yield images 381 | 382 | # For reconstruction 383 | fixed_data, _ = dev_gen().next() 384 | fixed_q_z, _, _ = Extractor(fixed_data) 385 | fixed_rec, _, _ = Generator(fixed_q_z) 386 | def reconstruct_image(frame): 387 | rec_samples = session.run(fixed_rec) 388 | tmp_list = [] 389 | for d, r in zip(fixed_data, rec_samples): 390 | tmp_list.append(d) 391 | tmp_list.append(r) 392 | rec_samples = np.vstack(tmp_list) 393 | lib.save_images.save_images( 394 | rec_samples.reshape((BATCH_SIZE*2, 28, 28)), 395 | os.path.join(outf, '{}_mnist_reconstruction_{}.png'.format(MODE, frame)) 396 | ) 397 | saver = tf.train.Saver() 398 | 399 | ''' 400 | Train loop 401 | ''' 402 | with tf.Session() as session: 403 | 404 | session.run(tf.global_variables_initializer()) 405 | gen = inf_train_gen() 406 | 407 | total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 408 | print '\nTotal number of parameters', total_num 409 | with open(logfile,'a') as f: 410 | f.write('Total number of parameters: ' + str(total_num) + '\n') 411 | 412 | for iteration in xrange(ITERS): 413 | start_time = time.time() 414 | 415 | if iteration > 0: 416 | _data = gen.next() 417 | _gen_cost, _ = session.run([gen_cost, gen_train_op], 418 | feed_dict={real_x: _data}) 419 | 420 | for i in xrange(CRITIC_ITERS): 421 | _data = gen.next() 422 | _disc_cost, _ = session.run( 423 | [disc_cost, disc_train_op], 424 | feed_dict={real_x: _data} 425 | ) 426 | if MODE == 'wali': 427 | _ = session.run(clip_disc_weights) 428 | 429 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 430 | if iteration > 0: 431 | lib.plot.plot('train gen cost ', _gen_cost) 432 | else: 433 | lib.plot.plot('train disc cost', _disc_cost) 434 | lib.plot.plot('time', time.time() - start_time) 435 | 436 | # Calculate dev loss 437 | if iteration % 100 == 99: 438 | if rec_penalty is not None: 439 | dev_rec_costs = [] 440 | dev_reg_costs = [] 441 | for images,_ in dev_gen(): 442 | _dev_rec_cost, _dev_gen_cost = session.run( 443 | [rec_penalty, gen_cost], 444 | feed_dict={real_x: images} 445 | ) 446 | dev_rec_costs.append(_dev_rec_cost) 447 | dev_reg_costs.append(_dev_gen_cost - _dev_rec_cost) 448 | lib.plot.plot('dev rec cost', np.mean(dev_rec_costs)) 449 | lib.plot.plot('dev reg cost', 
np.mean(dev_reg_costs)) 450 | else: 451 | dev_gen_costs = [] 452 | for images,_ in dev_gen(): 453 | _dev_gen_cost = session.run( 454 | gen_cost, 455 | feed_dict={real_x: images} 456 | ) 457 | dev_gen_costs.append(_dev_gen_cost) 458 | lib.plot.plot('dev gen cost', np.mean(dev_gen_costs)) 459 | 460 | 461 | # Write logs 462 | if (iteration < 5) or (iteration % 100 == 99): 463 | lib.plot.flush(outf, logfile) 464 | 465 | lib.plot.tick() 466 | 467 | # Generation and reconstruction 468 | if iteration % 5000 == 4999: 469 | generate_image(iteration, _data) 470 | reconstruct_image(iteration) 471 | 472 | # Latent space visualization 473 | if iteration % 50000 == 49999: 474 | z_dev, z_mean_dev, y_dev = [],[],[] 475 | for xb, yb in dev_gen(): 476 | zb = session.run(q_z,feed_dict={real_x: xb}) 477 | z_dev.append(zb) 478 | y_dev.append(yb) 479 | z_dev_2D = TSNE().fit_transform(np.vstack(z_dev)) 480 | lib.visualization.scatter(data=z_dev_2D, label=np.hstack(y_dev), dir=outf, file_name='{}_mnist_manifold_{}.png'.format(MODE, iteration)) 481 | 482 | # Save model 483 | if iteration == ITERS - 1: 484 | save_path = saver.save(session, os.path.join(outf, '{}_mnist_model_{}.ckpt'.format(MODE, iteration))) -------------------------------------------------------------------------------- /gan_inference_cifar10.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, time 2 | sys.path.append(os.getcwd()) 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import sklearn.datasets 9 | from sklearn.manifold import TSNE 10 | import tensorflow as tf 11 | 12 | import tflib as lib 13 | import tflib.ops.linear 14 | import tflib.ops.conv2d 15 | import tflib.ops.batchnorm 16 | import tflib.ops.deconv2d 17 | import tflib.save_images 18 | import tflib.cifar10 19 | import tflib.inception_score 20 | import tflib.plot 21 | import tflib.visualization 22 | import tflib.objs.gan_inference 23 | import tflib.objs.mmd 24 | import tflib.objs.kl 25 | import tflib.objs.kl_aggregated 26 | import tflib.utils.distance 27 | 28 | 29 | # Download CIFAR-10 (Python version) at 30 | # https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the 31 | # extracted files here! 32 | DATA_DIR = './dataset/cifar10/cifar-10-batches-py' 33 | if len(DATA_DIR) == 0: 34 | raise Exception('Please specify path to data directory in gan_cifar.py!') 35 | 36 | ''' 37 | hyperparameters 38 | ''' 39 | MODE = 'ali' # ali, alice, alice-z, alice-x 40 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']: 41 | TYPE_Q = 'learn_std' # learn_std, fix_std, no_std 42 | TYPE_P = 'no_std' 43 | Z_SAMPLES = 100 # MC estimation for D(q(z)||p(z)) 44 | elif MODE is 'vae': 45 | TYPE_Q = 'learn_std' 46 | TYPE_P = 'learn_std' 47 | else: 48 | TYPE_Q = 'no_std' 49 | TYPE_P = 'no_std' 50 | STD = .1 # For fix_std 51 | d_list = ['alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vegan-mmd'] 52 | if MODE in d_list: 53 | DISTANCE_X = 'l2' # l1, l2 54 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 55 | CRITIC_ITERS = 0 # No discriminators 56 | elif MODE in ['vegan', 'vegan-wgan-gp', 'wali', 'wali-gp']: 57 | CRITIC_ITERS = 5 # 5 iters of D per iter of G 58 | else: 59 | CRITIC_ITERS = 1 60 | 61 | BATCH_SIZE = 64 # Batch size 62 | LAMBDA = 1. 
# Balance reconstruction and regularization in vegan 63 | LR = 2e-4 64 | if MODE in ['vae']: 65 | BETA1 = .9 66 | else: 67 | BETA1 = .5 68 | ITERS = 200000 # How many generator iterations to train for 69 | 70 | DIM = 64 # Model dimensionality 71 | OUTPUT_DIM = 3072 # Number of pixels in CIFAR10 (3*32*32) 72 | if MODE in ['vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-jsd', 'vegan-ikl']: 73 | BN_FLAG = False # Use batch_norm or not 74 | DIM_LATENT = 8 # Dimensionality of the latent z 75 | else: 76 | BN_FLAG = True 77 | DIM_LATENT = 128 78 | N_VIS = BATCH_SIZE*2 # Number of samples to be visualized 79 | DR_RATE = .2 80 | 81 | 82 | ''' 83 | logs 84 | ''' 85 | filename_script=os.path.basename(os.path.realpath(__file__)) 86 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 87 | outf+='.MODE-' 88 | outf+=MODE 89 | outf+='.DR_RATE-' 90 | outf+=str(DR_RATE) 91 | outf+='.' 92 | outf+=str(int(time.time())) 93 | if not os.path.exists(outf): 94 | os.makedirs(outf) 95 | logfile=os.path.join(outf, 'logfile.txt') 96 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 97 | lib.print_model_settings_to_file(locals().copy(), logfile) 98 | 99 | 100 | ''' 101 | models 102 | ''' 103 | unit_std_x = tf.constant((STD*np.ones(shape=(BATCH_SIZE, OUTPUT_DIM))).astype('float32')) 104 | unit_std_z = tf.constant((STD*np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 105 | 106 | def LeakyReLU(x, alpha=0.2): 107 | return tf.maximum(alpha*x, x) 108 | 109 | def ReLULayer(name, n_in, n_out, inputs): 110 | output = lib.ops.linear.Linear( 111 | name+'.Linear', 112 | n_in, 113 | n_out, 114 | inputs, 115 | initialization='he' 116 | ) 117 | return tf.nn.relu(output) 118 | 119 | def LeakyReLULayer(name, n_in, n_out, inputs): 120 | output = lib.ops.linear.Linear( 121 | name+'.Linear', 122 | n_in, 123 | n_out, 124 | inputs, 125 | initialization='he' 126 | ) 127 | return LeakyReLU(output) 128 | 129 | def GaussianNoiseLayer(input_layer, std): 130 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 131 | return input_layer + noise 132 | 133 | def Generator(noise): 134 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*4*DIM, noise) 135 | if BN_FLAG: 136 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output) 137 | output = tf.nn.relu(output) 138 | output = tf.reshape(output, [-1, 4*DIM, 4, 4]) 139 | 140 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output) 141 | if BN_FLAG: 142 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output) 143 | output = tf.nn.relu(output) 144 | 145 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output) 146 | if BN_FLAG: 147 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output) 148 | output = tf.nn.relu(output) 149 | 150 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 3, 5, output) 151 | output = tf.tanh(output) 152 | 153 | return tf.reshape(output, [-1, OUTPUT_DIM]), None, None 154 | 155 | def Extractor(inputs): 156 | output = tf.reshape(inputs, [-1, 3, 32, 32]) 157 | 158 | output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM, 5, output,stride=2) 159 | output = LeakyReLU(output) 160 | 161 | output = lib.ops.conv2d.Conv2D('Extractor.2', DIM, 2*DIM, 5, output, stride=2) 162 | if BN_FLAG: 163 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN2', [0,2,3], output) 164 | output = LeakyReLU(output) 165 | 166 | output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM, 4*DIM, 5, output, stride=2) 167 | if 
BN_FLAG: 168 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN3', [0,2,3], output) 169 | output = LeakyReLU(output) 170 | 171 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 172 | 173 | if TYPE_Q is 'learn_std': 174 | log_std = lib.ops.linear.Linear('Extractor.Std', 4*4*4*DIM, DIM_LATENT, output) 175 | std = tf.exp(log_std) 176 | elif TYPE_Q is 'fix_std': 177 | std = unit_std_z 178 | else: 179 | std = None 180 | mean = None 181 | 182 | output = lib.ops.linear.Linear('Extractor.Output', 4*4*4*DIM, DIM_LATENT, output) 183 | 184 | if TYPE_Q in ['learn_std', 'fix_std']: 185 | epsilon = tf.random_normal(unit_std_z.shape) 186 | mean = output 187 | output = tf.add(mean, tf.multiply(epsilon, std)) 188 | 189 | return tf.reshape(output, [-1, DIM_LATENT]), mean, std 190 | 191 | if MODE in ['vegan', 'vegan-wgan-gp']: 192 | # define a discriminator on z 193 | def Discriminator(z): 194 | output = GaussianNoiseLayer(z, std=.3) 195 | output = lib.ops.linear.Linear('Discriminator.Input', DIM_LATENT, 1024, output) 196 | if BN_FLAG: 197 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN1', [0], output) 198 | output = LeakyReLU(output) 199 | output = GaussianNoiseLayer(output, std=.5) 200 | 201 | output = lib.ops.linear.Linear('Discriminator.2', 1024, 512, output) 202 | if BN_FLAG: 203 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0], output) 204 | output = LeakyReLU(output) 205 | output = GaussianNoiseLayer(output, std=.5) 206 | 207 | output = lib.ops.linear.Linear('Discriminator.3', 512, 256, output) 208 | if BN_FLAG: 209 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0], output) 210 | output = LeakyReLU(output) 211 | output = GaussianNoiseLayer(output, std=.5) 212 | 213 | output = lib.ops.linear.Linear('Discriminator.4', 256, 256, output) 214 | if BN_FLAG: 215 | output = lib.ops.batchnorm.Batchnorm('Discriminator.BN4', [0], output) 216 | output = LeakyReLU(output) 217 | 218 | output = lib.ops.linear.Linear('Discriminator.Output', 256, 1, output) 219 | 220 | return tf.reshape(output, [-1]) 221 | 222 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 223 | pass # no discriminator 224 | 225 | else: 226 | 227 | def Discriminator(x, z): 228 | output = tf.reshape(x, [-1, 3, 32, 32]) 229 | 230 | output = lib.ops.conv2d.Conv2D('Discriminator.1',3,DIM,5,output,stride=2) 231 | output = LeakyReLU(output) 232 | output = tf.layers.dropout(output, rate=DR_RATE) 233 | 234 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2) 235 | output = LeakyReLU(output) 236 | output = tf.layers.dropout(output, rate=DR_RATE) 237 | 238 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2) 239 | output = LeakyReLU(output) 240 | output = tf.layers.dropout(output, rate=DR_RATE) 241 | 242 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 243 | 244 | z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z) 245 | z_output = LeakyReLU(z_output) 246 | z_output = tf.layers.dropout(z_output, rate=DR_RATE) 247 | 248 | output = tf.concat([output, z_output], 1) 249 | output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*4*DIM+512, 512, output) 250 | output = LeakyReLU(output) 251 | output = tf.layers.dropout(output, rate=DR_RATE) 252 | 253 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 254 | 255 | return tf.reshape(output, [-1]) 256 | 257 | 258 | ''' 259 | losses 260 | ''' 261 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM]) 262 | real_x = 2*((tf.cast(real_x_int, 
tf.float32)/255.)-.5) 263 | q_z, q_z_mean, q_z_std = Extractor(real_x) 264 | rec_x, rec_x_mean, rec_x_std = Generator(q_z) 265 | p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT]) 266 | fake_x, _, _ = Generator(p_z) 267 | rec_z, _, _ = Extractor(fake_x) 268 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']: 269 | p_z_mean = tf.constant((np.zeros(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z)) 270 | p_z_std = tf.constant((np.ones(shape=(Z_SAMPLES, DIM_LATENT))).astype('float32')) 271 | elif MODE is 'vae': 272 | p_z_mean = tf.constant((np.zeros(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) # prior for estimating D(q(z) || p(z)) 273 | p_z_std = tf.constant((np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 274 | 275 | 276 | if MODE in ['vegan', 'vegan-wgan-gp']: 277 | disc_real = Discriminator(p_z) # discriminate code 278 | disc_fake = Discriminator(q_z) 279 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 280 | pass # no discriminators 281 | else: 282 | disc_real = Discriminator(real_x, q_z) # discriminate code-data pair 283 | disc_fake = Discriminator(fake_x, p_z) 284 | 285 | gen_params = lib.params_with_name('Generator') 286 | ext_params = lib.params_with_name('Extractor') 287 | disc_params = lib.params_with_name('Discriminator') 288 | 289 | if MODE == 'ali': 290 | rec_penalty = None 291 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 292 | 293 | elif MODE == 'alice-z': 294 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 295 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 296 | 297 | elif MODE == 'alice-x': 298 | rec_penalty = 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 299 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 300 | 301 | elif MODE == 'alice': 302 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 303 | rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 304 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1) 305 | 306 | elif MODE == 'vegan': 307 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 308 | # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X) 309 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1) 310 | 311 | elif MODE == 'vegan-wgan-gp': 312 | alpha = tf.random_uniform( 313 | shape=[BATCH_SIZE,1], 314 | minval=0., 315 | maxval=1. 
316 | ) 317 | differences = q_z - p_z 318 | interpolates = p_z + (alpha*differences) 319 | gradients = tf.gradients(Discriminator(interpolates), interpolates)[0] 320 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 321 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 322 | 323 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 324 | 325 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan_wgan_gp(disc_fake, disc_real, rec_penalty, gradient_penalty, gen_params+ext_params, disc_params, LAMBDA,lr=LR, beta1=BETA1) 326 | 327 | elif MODE == 'vegan-mmd': 328 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 329 | gen_cost, gen_train_op = lib.objs.mmd.vegan_mmd(q_z, p_z, rec_penalty, gen_params+ext_params, BATCH_SIZE, LAMBDA, lr=LR, beta1=BETA1) 330 | 331 | elif MODE == 'vegan-kl': 332 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 333 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_kl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 334 | 335 | elif MODE == 'vegan-ikl': 336 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 337 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_ikl(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 338 | 339 | elif MODE == 'vegan-jsd': 340 | rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X) 341 | gen_cost, gen_train_op = lib.objs.kl_aggregated.vegan_jsd(q_z_mean, q_z_std, p_z_mean, p_z_std, rec_penalty, gen_params+ext_params, Z_SAMPLES, BATCH_SIZE, DIM_LATENT, LAMBDA, lr=LR, beta1=BETA1) 342 | 343 | elif MODE == 'vae': 344 | rec_penalty = None 345 | gen_cost, gen_train_op = lib.objs.kl.vae(real_x, rec_x_mean, rec_x_std, q_z_mean, q_z_std, p_z_mean, p_z_std, gen_params+ext_params, lr=LR, beta1=BETA1) 346 | 347 | elif MODE == 'wali': 348 | rec_penalty = None 349 | gen_cost, disc_cost, clip_disc_weights, gen_train_op, disc_train_op, clip_ops = lib.objs.gan_inference.wali(disc_fake, disc_real, gen_params+ext_params, disc_params) 350 | 351 | elif MODE == 'wali-gp': 352 | rec_penalty = None 353 | alpha = tf.random_uniform( 354 | shape=[BATCH_SIZE,1], 355 | minval=0., 356 | maxval=1. 
357 | ) 358 | differences = fake_x - real_x 359 | interpolates = real_x + (alpha*differences) 360 | differences_z = p_z - q_z 361 | interpolates_z = q_z + (alpha*differences_z) 362 | gradients = tf.gradients(Discriminator(interpolates,interpolates_z), [interpolates,interpolates_z])[0] 363 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 364 | gradient_penalty = 10.*(tf.reduce_mean((slopes-1.)**2)) 365 | 366 | gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.wali_gp(disc_fake, disc_real, gradient_penalty, gen_params+ext_params, disc_params) 367 | else: 368 | raise NotImplementedError(MODE) 369 | 370 | # For visualizing samples 371 | fixed_noise = tf.constant(np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')) 372 | fixed_noise_samples, _, _ = Generator(fixed_noise) 373 | def generate_image(frame, true_dist): 374 | samples = session.run(fixed_noise_samples) 375 | samples = ((samples+1.)*(255./2)).astype('int32') 376 | lib.save_images.save_images( 377 | samples.reshape((-1, 3, 32, 32)), 378 | os.path.join(outf, '{}_samples_{}.png'.format(MODE, frame)) 379 | ) 380 | 381 | # For calculating inception score 382 | p_z_100 = tf.random_normal([100, DIM_LATENT]) 383 | samples_100, _, _ = Generator(p_z_100) 384 | def get_inception_score(): 385 | all_samples = [] 386 | for i in xrange(500): 387 | all_samples.append(session.run(samples_100)) 388 | all_samples = np.concatenate(all_samples, axis=0) 389 | all_samples = ((all_samples+1.)*(255./2)).astype('int32') 390 | all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0,2,3,1) 391 | return lib.inception_score.get_inception_score(list(all_samples)) 392 | 393 | # Dataset iterator 394 | train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR) 395 | def inf_train_gen(): 396 | while True: 397 | for images, targets in train_gen(): 398 | yield images 399 | 400 | # For reconstruction 401 | fixed_data_int = lib.cifar10.get_reconstruction_data(BATCH_SIZE, data_dir=DATA_DIR) 402 | rand_data_int, _ = dev_gen().next() 403 | def reconstruct_image(frame, data_int): 404 | rec_samples = session.run(rec_x, feed_dict={real_x_int: data_int}) 405 | rec_samples = ((rec_samples+1.)*(255./2)).astype('int32') 406 | tmp_list = [] 407 | for d, r in zip(data_int, rec_samples): 408 | tmp_list.append(d) 409 | tmp_list.append(r) 410 | rec_samples = np.vstack(tmp_list) 411 | lib.save_images.save_images( 412 | rec_samples.reshape((-1, 3, 32, 32)), 413 | os.path.join(outf, '{}_reconstruction_{}.png'.format(MODE, frame)) 414 | ) 415 | saver = tf.train.Saver() 416 | 417 | ''' 418 | Train loop 419 | ''' 420 | with tf.Session() as session: 421 | 422 | session.run(tf.global_variables_initializer()) 423 | gen = inf_train_gen() 424 | 425 | total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 426 | print '\nTotal number of parameters', total_num 427 | with open(logfile,'a') as f: 428 | f.write('Total number of parameters: ' + str(total_num) + '\n') 429 | 430 | 431 | for iteration in xrange(ITERS): 432 | start_time = time.time() 433 | 434 | if iteration > 0: 435 | _data = gen.next() 436 | _gen_cost, _ = session.run([gen_cost, gen_train_op], 437 | feed_dict={real_x_int: _data}) 438 | 439 | for i in xrange(CRITIC_ITERS): 440 | _data = gen.next() 441 | _disc_cost, _ = session.run( 442 | [disc_cost, disc_train_op], 443 | feed_dict={real_x_int: _data} 444 | ) 445 | if MODE == 'wali': 446 | _ = session.run(clip_disc_weights) 447 | 448 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 
'vae']: 449 | if iteration > 0: 450 | lib.plot.plot('train gen cost ', _gen_cost) 451 | else: 452 | lib.plot.plot('train disc cost', _disc_cost) 453 | lib.plot.plot('time', time.time() - start_time) 454 | 455 | # Calculate dev loss 456 | if iteration % 100 == 99: 457 | if rec_penalty is not None: 458 | dev_rec_costs = [] 459 | dev_reg_costs = [] 460 | for images,_ in dev_gen(): 461 | _dev_rec_cost, _dev_gen_cost = session.run( 462 | [rec_penalty, gen_cost], 463 | feed_dict={real_x_int: images} 464 | ) 465 | dev_rec_costs.append(_dev_rec_cost) 466 | dev_reg_costs.append(_dev_gen_cost - _dev_rec_cost) 467 | lib.plot.plot('dev rec cost', np.mean(dev_rec_costs)) 468 | lib.plot.plot('dev reg cost', np.mean(dev_reg_costs)) 469 | else: 470 | dev_gen_costs = [] 471 | for images,_ in dev_gen(): 472 | _dev_gen_cost = session.run( 473 | gen_cost, 474 | feed_dict={real_x_int: images} 475 | ) 476 | dev_gen_costs.append(_dev_gen_cost) 477 | lib.plot.plot('dev gen cost', np.mean(dev_gen_costs)) 478 | 479 | # Write logs 480 | if (iteration < 5) or (iteration % 100 == 99): 481 | lib.plot.flush(outf, logfile) 482 | 483 | # Calculate inception score 484 | if iteration % 10000 == 9999: 485 | inception_score = get_inception_score() 486 | lib.plot.plot('inception score', inception_score[0]) 487 | lib.plot.plot('inception score std', inception_score[1]) 488 | 489 | lib.plot.tick() 490 | 491 | # Generation and reconstruction 492 | if iteration % 5000 == 4999: 493 | generate_image(iteration, _data) 494 | reconstruct_image(iteration, rand_data_int) 495 | reconstruct_image(-iteration, fixed_data_int) 496 | 497 | # Save model 498 | if iteration == ITERS - 1: 499 | save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(MODE, iteration))) -------------------------------------------------------------------------------- /gmgan_inference_svhn.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, time 2 | sys.path.append(os.getcwd()) 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import sklearn.datasets 9 | from sklearn.manifold import TSNE 10 | import tensorflow as tf 11 | 12 | import tflib as lib 13 | import tflib.ops.linear 14 | import tflib.ops.conv2d 15 | import tflib.ops.batchnorm 16 | import tflib.ops.deconv2d 17 | import tflib.save_images 18 | import tflib.svhn 19 | import tflib.plot 20 | import tflib.visualization 21 | import tflib.objs.gan_inference 22 | import tflib.objs.mmd 23 | import tflib.objs.kl 24 | import tflib.objs.kl_aggregated 25 | import tflib.objs.discrete_variables 26 | import tflib.utils.distance 27 | 28 | DATA_DIR = './dataset/svhn' 29 | 30 | ''' 31 | hyperparameters 32 | ''' 33 | MODE = 'local_ep' # ali, local_ep, alice, local_epce, vegan 34 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']: 35 | TYPE_Q = 'learn_std' # learn_std, fix_std, no_std 36 | TYPE_P = 'no_std' 37 | Z_SAMPLES = 100 # MC estimation for D(q(z)||p(z)) 38 | elif MODE is 'vae': 39 | TYPE_Q = 'learn_std' 40 | TYPE_P = 'learn_std' 41 | else: 42 | TYPE_Q = 'no_std' 43 | TYPE_P = 'no_std' 44 | STD = .1 # For fix_std 45 | d_list = ['alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vegan-mmd', 'local_epce'] 46 | if MODE in d_list: 47 | DISTANCE_X = 'l2' # l1, l2 48 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 49 | CRITIC_ITERS = 0 # No discriminators 50 | elif MODE in ['vegan', 'vegan-wgan-gp', 'wali', 
'wali-gp']: 51 | CRITIC_ITERS = 5 # 5 iters of D per iter of G 52 | else: 53 | CRITIC_ITERS = 1 54 | 55 | BATCH_SIZE = 64 # Batch size 56 | LAMBDA = 1. # Balance reconstruction and regularization in vegan 57 | LR = 2e-4 58 | if MODE in ['vae']: 59 | BETA1 = .9 60 | else: 61 | BETA1 = .5 62 | ITERS = 200000 # How many generator iterations to train for 63 | 64 | DIM = 64 # Model dimensionality 65 | OUTPUT_DIM = 3072 # Number of pixels in SVHN (3*32*32) 66 | if MODE in ['vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-jsd', 'vegan-ikl']: 67 | BN_FLAG = False # Use batch_norm or not 68 | DIM_LATENT = 8 # Dimensionality of the latent z 69 | else: 70 | BN_FLAG = False 71 | DIM_LATENT = 128 72 | N_COMS = 50 73 | N_VIS = N_COMS*10 # Number of samples to be visualized 74 | assert(N_VIS%N_COMS==0) 75 | MODE_K = 'CONCRETE' # CONCRETE, REINFORCE, STRAIGHT_THROUGHT_CONCRETE, STRAIGHT_THROUGHT 76 | if MODE_K is 'REINFORCE': 77 | CONTROL_VARIATE = .0 78 | elif MODE_K in ['CONCRETE', 'STRAIGHT_THROUGHT_CONCRETE']: 79 | TEMP_INIT = .1 80 | TEMP = TEMP_INIT 81 | DR_RATE = .2 82 | 83 | 84 | ''' 85 | logs 86 | ''' 87 | filename_script=os.path.basename(os.path.realpath(__file__)) 88 | outf=os.path.join("result", os.path.splitext(filename_script)[0]) 89 | outf+='.MODE-' 90 | outf+=MODE 91 | outf+='.N_COMS-' 92 | outf+=str(N_COMS) 93 | outf+='.' 94 | outf+=str(int(time.time())) 95 | if not os.path.exists(outf): 96 | os.makedirs(outf) 97 | logfile=os.path.join(outf, 'logfile.txt') 98 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script)) 99 | lib.print_model_settings_to_file(locals().copy(), logfile) 100 | 101 | 102 | ''' 103 | models 104 | ''' 105 | unit_std_x = tf.constant((STD*np.ones(shape=(BATCH_SIZE, OUTPUT_DIM))).astype('float32')) 106 | unit_std_z = tf.constant((STD*np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32')) 107 | ### prior 108 | PI = tf.constant(np.asarray([1./N_COMS,]*N_COMS, dtype=np.float32)) 109 | prior_k = tf.distributions.Categorical(probs=PI) 110 | 111 | def sample_gumbel(shape, eps=1e-20): 112 | # Sample from Gumbel(0, 1) 113 | U = tf.random_uniform(shape,minval=0,maxval=1) 114 | return -tf.log(-tf.log(U + eps) + eps) 115 | 116 | def LeakyReLU(x, alpha=0.2): 117 | return tf.maximum(alpha*x, x) 118 | 119 | def ReLULayer(name, n_in, n_out, inputs): 120 | output = lib.ops.linear.Linear( 121 | name+'.Linear', 122 | n_in, 123 | n_out, 124 | inputs, 125 | initialization='he' 126 | ) 127 | return tf.nn.relu(output) 128 | 129 | def LeakyReLULayer(name, n_in, n_out, inputs): 130 | output = lib.ops.linear.Linear( 131 | name+'.Linear', 132 | n_in, 133 | n_out, 134 | inputs, 135 | initialization='he' 136 | ) 137 | return LeakyReLU(output) 138 | 139 | def GaussianNoiseLayer(input_layer, std): 140 | noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 141 | return input_layer + noise 142 | 143 | ### Very simple MoG 144 | def HyperGenerator(hyper_k, hyper_noise): 145 | com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32')) 146 | noise = tf.add(tf.matmul(tf.cast(hyper_k, tf.float32), com_mu), hyper_noise) 147 | return noise 148 | 149 | ### Very simple soft alignment 150 | def HyperExtractor(latent_z): 151 | com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32')) 152 | com_logits = -.5*tf.reduce_sum(tf.pow((tf.expand_dims(latent_z, axis=1) - tf.expand_dims(com_mu, axis=0)), 2), axis=-1) + tf.expand_dims(tf.log(PI), axis=0) 153 | 154 | if 
MODE_K == 'REINFORCE': 155 | k = tf.one_hot(indices=tf.argmax(com_logits, axis=-1), depth=N_COMS) 156 | elif MODE_K == 'CONCRETE': 157 | k = tf.nn.softmax((com_logits + sample_gumbel(tf.shape(com_logits)))/TEMP) 158 | elif MODE_K == 'STRAIGHT_THROUGHT_CONCRETE': 159 | k = tf.nn.softmax((com_logits + sample_gumbel(tf.shape(com_logits)))/TEMP) 160 | k_hard = tf.one_hot(indices=tf.argmax(k, axis=-1), depth=N_COMS) 161 | k = tf.stop_gradient(k_hard - k) + k # forward pass uses the hard one-hot k_hard; gradients flow through the soft k 162 | 163 | elif MODE_K == 'STRAIGHT_THROUGHT': 164 | k_hard = tf.one_hot(indices=tf.argmax(com_logits, axis=-1), depth=N_COMS) 165 | k = tf.stop_gradient(k_hard - com_logits) + com_logits 166 | 167 | return com_logits, k 168 | 169 | def Generator(noise): 170 | output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*4*DIM, noise) 171 | if BN_FLAG: 172 | output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output) 173 | output = tf.nn.relu(output) 174 | output = tf.reshape(output, [-1, 4*DIM, 4, 4]) 175 | 176 | output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output) 177 | if BN_FLAG: 178 | output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output) 179 | output = tf.nn.relu(output) 180 | 181 | output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output) 182 | if BN_FLAG: 183 | output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output) 184 | output = tf.nn.relu(output) 185 | 186 | output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 3, 5, output) 187 | output = tf.tanh(output) 188 | 189 | return tf.reshape(output, [-1, OUTPUT_DIM]), None, None 190 | 191 | def Extractor(inputs): 192 | output = tf.reshape(inputs, [-1, 3, 32, 32]) 193 | 194 | output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM, 5, output,stride=2) 195 | output = LeakyReLU(output) 196 | 197 | output = lib.ops.conv2d.Conv2D('Extractor.2', DIM, 2*DIM, 5, output, stride=2) 198 | if BN_FLAG: 199 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN2', [0,2,3], output) 200 | output = LeakyReLU(output) 201 | 202 | output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM, 4*DIM, 5, output, stride=2) 203 | if BN_FLAG: 204 | output = lib.ops.batchnorm.Batchnorm('Extractor.BN3', [0,2,3], output) 205 | output = LeakyReLU(output) 206 | 207 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 208 | 209 | if TYPE_Q == 'learn_std': 210 | log_std = lib.ops.linear.Linear('Extractor.Std', 4*4*4*DIM, DIM_LATENT, output) 211 | std = tf.exp(log_std) 212 | elif TYPE_Q == 'fix_std': 213 | std = unit_std_z 214 | else: 215 | std = None 216 | mean = None 217 | 218 | output = lib.ops.linear.Linear('Extractor.Output', 4*4*4*DIM, DIM_LATENT, output) 219 | 220 | if TYPE_Q in ['learn_std', 'fix_std']: 221 | epsilon = tf.random_normal(unit_std_z.shape) 222 | mean = output 223 | output = tf.add(mean, tf.multiply(epsilon, std)) 224 | 225 | return tf.reshape(output, [-1, DIM_LATENT]), mean, std 226 | 227 | if MODE in ['vegan', 'vegan-wgan-gp']: 228 | 229 | def Discriminator(z, k): 230 | output = tf.concat([z, k], 1) 231 | output = lib.ops.linear.Linear('Discriminator.HyperInput', DIM_LATENT+N_COMS, 512, output) 232 | output = LeakyReLU(output) 233 | output = tf.layers.dropout(output, rate=DR_RATE) 234 | 235 | output = lib.ops.linear.Linear('Discriminator.Hyper2', 512, 512, output) 236 | output = LeakyReLU(output) 237 | output = tf.layers.dropout(output, rate=DR_RATE) 238 | 239 | output = lib.ops.linear.Linear('Discriminator.Hyper3', 512, 512, output) 240 | output = LeakyReLU(output) 241 | output = tf.layers.dropout(output, rate=DR_RATE) 242 | 243 |
output = lib.ops.linear.Linear('Discriminator.HyperOutput', 512, 1, output) 244 | 245 | return tf.reshape(output, [-1]) 246 | 247 | elif MODE in ['local_ep', 'local_epce']: 248 | 249 | def HyperDiscriminator(z, k): 250 | output = tf.concat([z, k], 1) 251 | output = lib.ops.linear.Linear('Discriminator.HyperInput', DIM_LATENT+N_COMS, 512, output) 252 | output = LeakyReLU(output) 253 | output = tf.layers.dropout(output, rate=DR_RATE) 254 | 255 | output = lib.ops.linear.Linear('Discriminator.Hyper2', 512, 512, output) 256 | output = LeakyReLU(output) 257 | output = tf.layers.dropout(output, rate=DR_RATE) 258 | 259 | output = lib.ops.linear.Linear('Discriminator.Hyper3', 512, 512, output) 260 | output = LeakyReLU(output) 261 | output = tf.layers.dropout(output, rate=DR_RATE) 262 | 263 | output = lib.ops.linear.Linear('Discriminator.HyperOutput', 512, 1, output) 264 | 265 | return tf.reshape(output, [-1]) 266 | 267 | def Discriminator(x, z): 268 | output = tf.reshape(x, [-1, 3, 32, 32]) 269 | 270 | output = lib.ops.conv2d.Conv2D('Discriminator.1',3,DIM,5,output,stride=2) 271 | output = LeakyReLU(output) 272 | output = tf.layers.dropout(output, rate=DR_RATE) 273 | 274 | output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2) 275 | output = LeakyReLU(output) 276 | output = tf.layers.dropout(output, rate=DR_RATE) 277 | 278 | output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2) 279 | output = LeakyReLU(output) 280 | output = tf.layers.dropout(output, rate=DR_RATE) 281 | 282 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 283 | 284 | z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z) 285 | z_output = LeakyReLU(z_output) 286 | z_output = tf.layers.dropout(z_output, rate=DR_RATE) 287 | 288 | output = tf.concat([output, z_output], 1) 289 | output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*4*DIM+512, 512, output) 290 | output = LeakyReLU(output) 291 | output = tf.layers.dropout(output, rate=DR_RATE) 292 | 293 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 294 | 295 | return tf.reshape(output, [-1]) 296 | 297 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']: 298 | pass # no discriminator 299 | 300 | else: 301 | def Discriminator(x, z, k): 302 | output = tf.reshape(x, [-1, 3, 32, 32]) 303 | 304 | output = lib.ops.conv2d.Conv2D('Discriminator.x1',3,DIM,5,output,stride=2) 305 | output = LeakyReLU(output) 306 | output = tf.layers.dropout(output, rate=DR_RATE) 307 | 308 | output = lib.ops.conv2d.Conv2D('Discriminator.x2', DIM, 2*DIM, 5, output, stride=2) 309 | output = LeakyReLU(output) 310 | output = tf.layers.dropout(output, rate=DR_RATE) 311 | 312 | output = lib.ops.conv2d.Conv2D('Discriminator.x3', 2*DIM, 4*DIM, 5, output, stride=2) 313 | output = LeakyReLU(output) 314 | output = tf.layers.dropout(output, rate=DR_RATE) 315 | 316 | output = tf.reshape(output, [-1, 4*4*4*DIM]) 317 | 318 | zk_output = tf.concat([z, k], 1) 319 | zk_output = lib.ops.linear.Linear('Discriminator.zk1', DIM_LATENT+N_COMS, 512, zk_output) 320 | zk_output = LeakyReLU(zk_output) 321 | zk_output = tf.layers.dropout(zk_output, rate=DR_RATE) 322 | 323 | output = tf.concat([output, zk_output], 1) 324 | output = lib.ops.linear.Linear('Discriminator.zkx1', 4*4*4*DIM+512, 512, output) 325 | output = LeakyReLU(output) 326 | output = tf.layers.dropout(output, rate=DR_RATE) 327 | 328 | output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output) 329 | 330 | return tf.reshape(output, [-1]) 331 | 332 
332 | '''
333 | losses
334 | '''
335 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM])
336 | real_x = 2*((tf.cast(real_x_int, tf.float32)/255.)-.5)
337 | q_z, q_z_mean, q_z_std = Extractor(real_x)
338 | q_k_logits, q_k = HyperExtractor(q_z)
339 | q_k_probs = tf.nn.softmax(q_k_logits)
340 | if MODE_K == 'REINFORCE':
341 |     q_k_prob_max = tf.reduce_max(q_k_probs, axis=1)
342 | rec_x, rec_x_mean, rec_x_std = Generator(q_z)
343 | hyper_p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT])
344 | hyper_p_k = tf.one_hot(indices=prior_k.sample(BATCH_SIZE), depth=N_COMS)
345 | p_z = HyperGenerator(hyper_p_k, hyper_p_z)
346 | fake_x, _, _ = Generator(p_z)
347 | rec_z, _, _ = Extractor(fake_x)
348 | rec_q_k_logits, rec_q_k = HyperExtractor(rec_z)
349 |
350 | if MODE_K != 'REINFORCE':
351 |     score_function = None # only the REINFORCE estimator needs one
352 |
353 | if MODE == 'vegan':
354 |     disc_fake = Discriminator(p_z, hyper_p_k)
355 |     disc_real = Discriminator(q_z, q_k)
356 |     if MODE_K == 'REINFORCE':
357 |         score_function = lib.objs.discrete_variables.score_function(disc_real, q_k_prob_max, CONTROL_VARIATE)
358 |
359 | elif MODE in ['local_ep', 'local_epce']:
360 |     disc_fake, disc_real = [], []
361 |     disc_fake.append(HyperDiscriminator(p_z, hyper_p_k))
362 |     disc_real.append(HyperDiscriminator(q_z, q_k))
363 |     disc_fake.append(Discriminator(fake_x, p_z))
364 |     disc_real.append(Discriminator(real_x, q_z))
365 |
366 |     if MODE_K == 'REINFORCE':
367 |         score_function = lib.objs.discrete_variables.score_function(disc_real[0], q_k_prob_max, CONTROL_VARIATE)
368 |
369 | else:
370 |     disc_real = Discriminator(real_x, q_z, q_k)
371 |     disc_fake = Discriminator(fake_x, p_z, hyper_p_k)
372 |     if MODE_K == 'REINFORCE':
373 |         score_function = lib.objs.discrete_variables.score_function(disc_real, q_k_prob_max, CONTROL_VARIATE)
374 |
375 | gen_params = lib.params_with_name('Generator')
376 | ext_params = lib.params_with_name('Extractor')
377 | disc_params = lib.params_with_name('Discriminator')
378 |
379 | if MODE == 'ali':
380 |     rec_penalty = None
381 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
382 |
383 | elif MODE == 'alice':
384 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
385 |     # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
386 |     # rec_penalty += 1.*tf.nn.softmax_cross_entropy_with_logits(labels=hyper_p_k, logits=rec_q_k_logits)
387 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
388 |
389 | elif MODE == 'local_ep':
390 |     rec_penalty = None
391 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.local_ep(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
392 |
393 | elif MODE == 'local_epce':
394 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
395 |     # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
396 |     # rec_penalty += 1.*tf.nn.softmax_cross_entropy_with_logits(labels=hyper_p_k, logits=rec_q_k_logits)
397 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.local_epce(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
398 |
399 | elif MODE == 'vegan':
400 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
401 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, LAMBDA, lr=LR, beta1=BETA1, s_f=score_function)
402 |
403 | else:
404 |     raise NotImplementedError()
405 |
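# ------------------------------------------------------------------------------
# Aside: a toy NumPy check (not part of the original script) that the
# score-function/REINFORCE estimator wired in above via
# lib.objs.discrete_variables.score_function is unbiased. For k ~ Ber(p),
# d/dp E[f(k)] = f(1) - f(0), and (f(k) - c) * d/dp log p(k) matches this in
# expectation for any constant baseline c; CONTROL_VARIATE only affects the
# variance, not the mean.
import numpy as np

def reinforce_grad_estimate(p=0.3, c=0.5, n=200000, seed=0):
    rng = np.random.RandomState(seed)
    f = lambda k: (k - 2.0)**2              # arbitrary test function
    k = (rng.rand(n) < p).astype('float64') # k ~ Ber(p)
    dlogp = k/p - (1. - k)/(1. - p)         # d/dp log Ber(k; p)
    return np.mean((f(k) - c) * dlogp)      # ~= f(1) - f(0) = 1 - 4 = -3

# reinforce_grad_estimate() returns approximately -3.0
# ------------------------------------------------------------------------------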
406 | # For visualizing samples
407 | # np_fixed_noise = np.repeat(np.random.normal(size=(N_VIS/N_COMS, DIM_LATENT)).astype('float32'), N_COMS, axis=0)
408 | np_fixed_noise = np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')
409 | np_fixed_k = np.tile(np.eye(N_COMS, dtype=int), (N_VIS/N_COMS, 1))
410 | hyper_fixed_noise = tf.constant(np_fixed_noise)
411 | hyper_fixed_k = tf.constant(np_fixed_k)
412 | fixed_noise = HyperGenerator(hyper_fixed_k, hyper_fixed_noise)
413 | fixed_noise_samples, _, _ = Generator(fixed_noise)
414 | def generate_image(frame, true_dist):
415 |     samples = session.run(fixed_noise_samples)
416 |     samples = ((samples+1.)*(255./2)).astype('int32')
417 |     lib.save_images.save_images(
418 |         samples.reshape((-1, 3, 32, 32)),
419 |         os.path.join(outf, '{}_samples_{}.png'.format(frame, MODE)),
420 |         size=[N_VIS/N_COMS, N_COMS]
421 |     )
422 |
423 | # Dataset iterator
424 | train_gen, dev_gen = lib.svhn.load(BATCH_SIZE, data_dir=DATA_DIR)
425 | def inf_train_gen():
426 |     while True:
427 |         for images, targets in train_gen():
428 |             yield images
429 |
430 | # For reconstruction
431 | rand_data_int, _ = dev_gen().next()
432 | def reconstruct_image(frame, data_int):
433 |     rec_samples = session.run(rec_x, feed_dict={real_x_int: data_int})
434 |     rec_samples = ((rec_samples+1.)*(255./2)).astype('int32')
435 |     tmp_list = []
436 |     for d, r in zip(data_int, rec_samples):
437 |         tmp_list.append(d)
438 |         tmp_list.append(r)
439 |     rec_samples = np.vstack(tmp_list) # interleave originals and reconstructions
440 |     lib.save_images.save_images(
441 |         rec_samples.reshape((-1, 3, 32, 32)),
442 |         os.path.join(outf, '{}_reconstruction_{}.png'.format(MODE, frame))
443 |     )
444 | saver = tf.train.Saver()
445 |
446 | '''
447 | Train loop
448 | '''
449 | with tf.Session() as session:
450 |
451 |     session.run(tf.global_variables_initializer())
452 |     gen = inf_train_gen()
453 |
454 |     total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
455 |     print '\nTotal number of parameters', total_num
456 |     with open(logfile, 'a') as f:
457 |         f.write('Total number of parameters: ' + str(total_num) + '\n')
458 |
459 |     for iteration in xrange(ITERS):
460 |         start_time = time.time()
461 |
462 |         if iteration > 0:
463 |             _data = gen.next()
464 |             _gen_cost, _ = session.run([gen_cost, gen_train_op],
465 |                 feed_dict={real_x_int: _data}
466 |             )
467 |
468 |         for i in xrange(CRITIC_ITERS):
469 |             _data = gen.next()
470 |             _disc_cost, _ = session.run(
471 |                 [disc_cost, disc_train_op],
472 |                 feed_dict={real_x_int: _data}
473 |             )
474 |             if MODE == 'wali':
475 |                 _ = session.run(clip_disc_weights)
476 |
477 |         if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
478 |             if iteration > 0:
479 |                 lib.plot.plot('train gen cost ', _gen_cost)
480 |         else:
481 |             lib.plot.plot('train disc cost', _disc_cost)
482 |         lib.plot.plot('time', time.time() - start_time)
483 |
484 |         # Calculate dev loss
485 |         if iteration % 100 == 99:
486 |             if rec_penalty is not None:
487 |                 dev_rec_costs = []
488 |                 dev_reg_costs = []
489 |                 for images, _ in dev_gen():
490 |                     _dev_rec_cost, _dev_gen_cost = session.run(
491 |                         [rec_penalty, gen_cost],
492 |                         feed_dict={real_x_int: images}
493 |                     )
494 |                     dev_rec_costs.append(_dev_rec_cost)
495 |                     dev_reg_costs.append(_dev_gen_cost - _dev_rec_cost)
496 |                 lib.plot.plot('dev rec cost', np.mean(dev_rec_costs))
497 |                 lib.plot.plot('dev reg cost', np.mean(dev_reg_costs))
498 |             else:
499 |                 dev_gen_costs = []
500 |                 for images, _ in dev_gen():
501 |                     _dev_gen_cost = session.run(
502 |                         gen_cost,
503 |                         feed_dict={real_x_int: images}
504 |                     )
505 |                     dev_gen_costs.append(_dev_gen_cost)
506 |                 lib.plot.plot('dev gen cost', np.mean(dev_gen_costs))
507 |
508 |         # Write logs
509 |         if (iteration < 5) or (iteration % 100 == 99):
510 |             lib.plot.flush(outf, logfile)
511 |         lib.plot.tick()
512 |
513 |         # Generation and reconstruction
514 |         if iteration % 5000 == 4999:
515 |             generate_image(iteration, _data)
516 |             reconstruct_image(iteration, rand_data_int)
517 |             # reconstruct_image(-iteration, fixed_data_int)
518 |
519 |         # Save model
520 |         if iteration == ITERS - 1:
521 |             save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(iteration, MODE)))
--------------------------------------------------------------------------------
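# ------------------------------------------------------------------------------
# Aside: both scripts share the same two-level generative path, sampled
# ancestrally: k ~ Cat(1/N_COMS), then z = mu_k + eps (HyperGenerator), then
# x = Generator(z). A standalone NumPy sketch of the latent part, with shapes
# matching the CIFAR-10 script below (sample_mog_latents is a hypothetical
# helper name, illustration only):
import numpy as np

def sample_mog_latents(n, n_coms=30, dim_latent=128, seed=0):
    rng = np.random.RandomState(seed)
    com_mu = rng.normal(size=(n_coms, dim_latent))  # mixture component means
    k = rng.randint(n_coms, size=n)                 # uniform categorical prior
    one_hot = np.eye(n_coms)[k]
    eps = rng.normal(size=(n, dim_latent))          # analogue of hyper_p_z
    z = one_hot.dot(com_mu) + eps                   # same math as HyperGenerator
    return k, z

# k, z = sample_mog_latents(64); z would then feed the deconv Generator.
# ------------------------------------------------------------------------------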
/gmgan_inference_cifar10.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil, time
2 | sys.path.append(os.getcwd())
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import sklearn.datasets
9 | from sklearn.manifold import TSNE
10 | import tensorflow as tf
11 |
12 | import tflib as lib
13 | import tflib.ops.linear
14 | import tflib.ops.conv2d
15 | import tflib.ops.batchnorm
16 | import tflib.ops.deconv2d
17 | import tflib.save_images
18 | import tflib.cifar10
19 | import tflib.inception_score
20 | import tflib.plot
21 | import tflib.visualization
22 | import tflib.objs.gan_inference
23 | import tflib.objs.mmd
24 | import tflib.objs.kl
25 | import tflib.objs.kl_aggregated
26 | import tflib.objs.discrete_variables
27 | import tflib.utils.distance
28 |
29 | # Download CIFAR-10 (Python version) at
30 | # https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
31 | # extracted files here!
32 | DATA_DIR = './dataset/cifar10/cifar-10-batches-py'
33 | if len(DATA_DIR) == 0:
34 |     raise Exception('Please specify path to data directory in gan_cifar.py!')
35 |
36 | '''
37 | hyperparameters
38 | '''
39 | MODE = 'local_ep' # ali, local_ep, alice, local_epce, vegan
40 | if MODE in ['vegan-kl', 'vegan-ikl', 'vegan-jsd']:
41 |     TYPE_Q = 'learn_std' # learn_std, fix_std, no_std
42 |     TYPE_P = 'no_std'
43 |     Z_SAMPLES = 100 # MC estimation for D(q(z)||p(z))
44 | elif MODE == 'vae':
45 |     TYPE_Q = 'learn_std'
46 |     TYPE_P = 'learn_std'
47 | else:
48 |     TYPE_Q = 'no_std'
49 |     TYPE_P = 'no_std'
50 | STD = .1 # For fix_std
51 | d_list = ['alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vegan-mmd', 'local_epce']
52 | if MODE in d_list:
53 |     DISTANCE_X = 'l2' # l1, l2
54 | if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
55 |     CRITIC_ITERS = 0 # No discriminators
56 | elif MODE in ['vegan', 'vegan-wgan-gp', 'wali', 'wali-gp']:
57 |     CRITIC_ITERS = 5 # 5 iters of D per iter of G
58 | else:
59 |     CRITIC_ITERS = 1
60 |
61 | BATCH_SIZE = 64 # Batch size
62 | LAMBDA = 1. # Balance reconstruction and regularization in vegan
63 | LR = 2e-4
64 | if MODE in ['vae']:
65 |     BETA1 = .9
66 | else:
67 |     BETA1 = .5
68 | ITERS = 200000 # How many generator iterations to train for
69 |
70 | DIM = 64 # Model dimensionality
71 | OUTPUT_DIM = 3072 # Number of pixels in CIFAR10 (3*32*32)
72 | if MODE in ['vegan', 'vegan-wgan-gp', 'vegan-kl', 'vegan-jsd', 'vegan-ikl']:
73 |     BN_FLAG = False # Use batch_norm or not
74 |     DIM_LATENT = 8 # Dimensionality of the latent z
75 | else:
76 |     BN_FLAG = True
77 |     DIM_LATENT = 128
78 | N_COMS = 30
79 | N_VIS = N_COMS*10 # Number of samples to be visualized
80 | assert N_VIS % N_COMS == 0
81 | MODE_K = 'CONCRETE' # CONCRETE, REINFORCE, STRAIGHT_THROUGHT_CONCRETE, STRAIGHT_THROUGHT
82 | if MODE_K == 'REINFORCE':
83 |     CONTROL_VARIATE = .0
84 | elif MODE_K in ['CONCRETE', 'STRAIGHT_THROUGHT_CONCRETE']:
85 |     TEMP_INIT = .1
86 |     TEMP = TEMP_INIT
87 | DR_RATE = .2
88 |
89 |
90 | '''
91 | logs
92 | '''
93 | filename_script = os.path.basename(os.path.realpath(__file__))
94 | outf = os.path.join("result", os.path.splitext(filename_script)[0])
95 | outf += '.MODE-'
96 | outf += MODE
97 | outf += '.N_COMS-'
98 | outf += str(N_COMS)
99 | outf += '.'
100 | outf += str(int(time.time()))
101 | if not os.path.exists(outf):
102 |     os.makedirs(outf)
103 | logfile = os.path.join(outf, 'logfile.txt')
104 | shutil.copy(os.path.realpath(__file__), os.path.join(outf, filename_script))
105 | lib.print_model_settings_to_file(locals().copy(), logfile)
106 |
107 |
108 | '''
109 | models
110 | '''
111 | unit_std_x = tf.constant((STD*np.ones(shape=(BATCH_SIZE, OUTPUT_DIM))).astype('float32'))
112 | unit_std_z = tf.constant((STD*np.ones(shape=(BATCH_SIZE, DIM_LATENT))).astype('float32'))
113 | ### prior
114 | PI = tf.constant(np.asarray([1./N_COMS,]*N_COMS, dtype=np.float32))
115 | prior_k = tf.distributions.Categorical(probs=PI)
116 |
117 | def sample_gumbel(shape, eps=1e-20):
118 |     # Sample from Gumbel(0, 1)
119 |     U = tf.random_uniform(shape, minval=0, maxval=1)
120 |     return -tf.log(-tf.log(U + eps) + eps)
121 |
122 | def LeakyReLU(x, alpha=0.2):
123 |     return tf.maximum(alpha*x, x)
124 |
125 | def ReLULayer(name, n_in, n_out, inputs):
126 |     output = lib.ops.linear.Linear(
127 |         name+'.Linear',
128 |         n_in,
129 |         n_out,
130 |         inputs,
131 |         initialization='he'
132 |     )
133 |     return tf.nn.relu(output)
134 |
135 | def LeakyReLULayer(name, n_in, n_out, inputs):
136 |     output = lib.ops.linear.Linear(
137 |         name+'.Linear',
138 |         n_in,
139 |         n_out,
140 |         inputs,
141 |         initialization='he'
142 |     )
143 |     return LeakyReLU(output)
144 |
145 | def GaussianNoiseLayer(input_layer, std):
146 |     noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32)
147 |     return input_layer + noise
148 |
149 | ### Very simple MoG
150 | def HyperGenerator(hyper_k, hyper_noise):
151 |     com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32'))
152 |     noise = tf.add(tf.matmul(tf.cast(hyper_k, tf.float32), com_mu), hyper_noise)
153 |     return noise
154 |
155 | ### Very simple soft alignment
156 | def HyperExtractor(latent_z):
157 |     com_mu = lib.param('Generator.Hyper.Mu', np.random.normal(size=(N_COMS, DIM_LATENT)).astype('float32')) # same param name as in HyperGenerator, so the component means are shared
158 |     com_logits = -.5*tf.reduce_sum(tf.pow((tf.expand_dims(latent_z, axis=1) - tf.expand_dims(com_mu, axis=0)), 2), axis=-1) + tf.expand_dims(tf.log(PI), axis=0) # log responsibilities (up to a constant): log PI_k - .5*||z - mu_k||^2
159 |
160 |     if MODE_K == 'REINFORCE':
161 |         k = tf.one_hot(indices=tf.argmax(com_logits, axis=-1), depth=N_COMS)
162 |     elif MODE_K == 'CONCRETE':
163 |         k = tf.nn.softmax((com_logits + sample_gumbel(tf.shape(com_logits)))/TEMP)
164 |     elif MODE_K == 'STRAIGHT_THROUGHT_CONCRETE':
165 |         k = tf.nn.softmax((com_logits + sample_gumbel(tf.shape(com_logits)))/TEMP)
166 |         k_hard = tf.one_hot(indices=tf.argmax(k, axis=-1), depth=N_COMS)
167 |         k = tf.stop_gradient(k_hard - k) + k # forward pass uses k_hard; gradients flow through the soft k
168 |
169 |     elif MODE_K == 'STRAIGHT_THROUGHT':
170 |         k_hard = tf.one_hot(indices=tf.argmax(com_logits, axis=-1), depth=N_COMS)
171 |         k = tf.stop_gradient(k_hard - com_logits) + com_logits
172 |
173 |     return com_logits, k
174 |
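# ------------------------------------------------------------------------------
# Aside: a NumPy sketch (not from the original file) of the Gumbel-softmax
# relaxation used when MODE_K == 'CONCRETE' above. softmax((logits + g)/T)
# with g ~ Gumbel(0, 1) approaches a one-hot sample from
# Categorical(softmax(logits)) as the temperature T -> 0, while remaining
# differentiable for T > 0 (the script uses T = TEMP = .1).
import numpy as np

def gumbel_softmax_np(logits, temp, rng):
    g = -np.log(-np.log(rng.rand(*logits.shape) + 1e-20) + 1e-20)
    a = (logits + g) / temp
    a = a - a.max(axis=-1, keepdims=True)  # numerically stable softmax
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.RandomState(0)
logits = np.log(np.array([.7, .2, .1]))
soft = gumbel_softmax_np(logits, 1.0, rng)   # soft, high-entropy sample
hard = gumbel_softmax_np(logits, 0.01, rng)  # nearly one-hot sample
# ------------------------------------------------------------------------------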
175 | def Generator(noise):
176 |     output = lib.ops.linear.Linear('Generator.Input', DIM_LATENT, 4*4*4*DIM, noise)
177 |     if BN_FLAG:
178 |         output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output)
179 |     output = tf.nn.relu(output)
180 |     output = tf.reshape(output, [-1, 4*DIM, 4, 4])
181 |
182 |     output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output)
183 |     if BN_FLAG:
184 |         output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output)
185 |     output = tf.nn.relu(output)
186 |
187 |     output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output)
188 |     if BN_FLAG:
189 |         output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output)
190 |     output = tf.nn.relu(output)
191 |
192 |     output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 3, 5, output)
193 |     output = tf.tanh(output)
194 |
195 |     return tf.reshape(output, [-1, OUTPUT_DIM]), None, None
196 |
197 | def Extractor(inputs):
198 |     output = tf.reshape(inputs, [-1, 3, 32, 32])
199 |
200 |     output = lib.ops.conv2d.Conv2D('Extractor.1', 3, DIM, 5, output, stride=2)
201 |     output = LeakyReLU(output)
202 |
203 |     output = lib.ops.conv2d.Conv2D('Extractor.2', DIM, 2*DIM, 5, output, stride=2)
204 |     if BN_FLAG:
205 |         output = lib.ops.batchnorm.Batchnorm('Extractor.BN2', [0,2,3], output)
206 |     output = LeakyReLU(output)
207 |
208 |     output = lib.ops.conv2d.Conv2D('Extractor.3', 2*DIM, 4*DIM, 5, output, stride=2)
209 |     if BN_FLAG:
210 |         output = lib.ops.batchnorm.Batchnorm('Extractor.BN3', [0,2,3], output)
211 |     output = LeakyReLU(output)
212 |
213 |     output = tf.reshape(output, [-1, 4*4*4*DIM])
214 |
215 |     if TYPE_Q == 'learn_std':
216 |         log_std = lib.ops.linear.Linear('Extractor.Std', 4*4*4*DIM, DIM_LATENT, output)
217 |         std = tf.exp(log_std)
218 |     elif TYPE_Q == 'fix_std':
219 |         std = unit_std_z
220 |     else:
221 |         std = None
222 |     mean = None
223 |
224 |     output = lib.ops.linear.Linear('Extractor.Output', 4*4*4*DIM, DIM_LATENT, output)
225 |
226 |     if TYPE_Q in ['learn_std', 'fix_std']:
227 |         epsilon = tf.random_normal(unit_std_z.shape)
228 |         mean = output
229 |         output = tf.add(mean, tf.multiply(epsilon, std)) # reparameterization: z = mean + eps*std
230 |
231 |     return tf.reshape(output, [-1, DIM_LATENT]), mean, std
232 |
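# ------------------------------------------------------------------------------
# Aside: the reparameterization trick the Extractor applies for TYPE_Q in
# ['learn_std', 'fix_std'], in isolation (NumPy sketch, not original code):
# sampling z = mean + eps*std with eps ~ N(0, I) gives z ~ N(mean, std^2)
# while keeping z differentiable with respect to mean and std.
import numpy as np

rng = np.random.RandomState(0)
mean = np.array([1.0, -2.0])
std = np.array([0.5, 2.0])
eps = rng.normal(size=(100000, 2))
z = mean + eps * std
# Empirically: z.mean(axis=0) ~= [1., -2.] and z.std(axis=0) ~= [.5, 2.]
# ------------------------------------------------------------------------------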
233 | if MODE in ['vegan', 'vegan-wgan-gp']:
234 |
235 |     def Discriminator(z, k):
236 |         output = tf.concat([z, k], 1)
237 |         output = lib.ops.linear.Linear('Discriminator.HyperInput', DIM_LATENT+N_COMS, 512, output)
238 |         output = LeakyReLU(output)
239 |         output = tf.layers.dropout(output, rate=DR_RATE)
240 |
241 |         output = lib.ops.linear.Linear('Discriminator.Hyper2', 512, 512, output)
242 |         output = LeakyReLU(output)
243 |         output = tf.layers.dropout(output, rate=DR_RATE)
244 |
245 |         output = lib.ops.linear.Linear('Discriminator.Hyper3', 512, 512, output)
246 |         output = LeakyReLU(output)
247 |         output = tf.layers.dropout(output, rate=DR_RATE)
248 |
249 |         output = lib.ops.linear.Linear('Discriminator.HyperOutput', 512, 1, output)
250 |
251 |         return tf.reshape(output, [-1])
252 |
253 | elif MODE in ['local_ep', 'local_epce']:
254 |
255 |     def HyperDiscriminator(z, k):
256 |         output = tf.concat([z, k], 1)
257 |         output = lib.ops.linear.Linear('Discriminator.HyperInput', DIM_LATENT+N_COMS, 512, output)
258 |         output = LeakyReLU(output)
259 |         output = tf.layers.dropout(output, rate=DR_RATE)
260 |
261 |         output = lib.ops.linear.Linear('Discriminator.Hyper2', 512, 512, output)
262 |         output = LeakyReLU(output)
263 |         output = tf.layers.dropout(output, rate=DR_RATE)
264 |
265 |         output = lib.ops.linear.Linear('Discriminator.Hyper3', 512, 512, output)
266 |         output = LeakyReLU(output)
267 |         output = tf.layers.dropout(output, rate=DR_RATE)
268 |
269 |         output = lib.ops.linear.Linear('Discriminator.HyperOutput', 512, 1, output)
270 |
271 |         return tf.reshape(output, [-1])
272 |
273 |     def Discriminator(x, z):
274 |         output = tf.reshape(x, [-1, 3, 32, 32])
275 |
276 |         output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, DIM, 5, output, stride=2)
277 |         output = LeakyReLU(output)
278 |         output = tf.layers.dropout(output, rate=DR_RATE)
279 |
280 |         output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2)
281 |         output = LeakyReLU(output)
282 |         output = tf.layers.dropout(output, rate=DR_RATE)
283 |
284 |         output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2)
285 |         output = LeakyReLU(output)
286 |         output = tf.layers.dropout(output, rate=DR_RATE)
287 |
288 |         output = tf.reshape(output, [-1, 4*4*4*DIM])
289 |
290 |         z_output = lib.ops.linear.Linear('Discriminator.z1', DIM_LATENT, 512, z)
291 |         z_output = LeakyReLU(z_output)
292 |         z_output = tf.layers.dropout(z_output, rate=DR_RATE)
293 |
294 |         output = tf.concat([output, z_output], 1)
295 |         output = lib.ops.linear.Linear('Discriminator.zx1', 4*4*4*DIM+512, 512, output)
296 |         output = LeakyReLU(output)
297 |         output = tf.layers.dropout(output, rate=DR_RATE)
298 |
299 |         output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output)
300 |
301 |         return tf.reshape(output, [-1])
302 |
303 | elif MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
304 |     pass # no discriminator
305 |
306 | else:
307 |     def Discriminator(x, z, k):
308 |         output = tf.reshape(x, [-1, 3, 32, 32])
309 |
310 |         output = lib.ops.conv2d.Conv2D('Discriminator.x1', 3, DIM, 5, output, stride=2)
311 |         output = LeakyReLU(output)
312 |         output = tf.layers.dropout(output, rate=DR_RATE)
313 |
314 |         output = lib.ops.conv2d.Conv2D('Discriminator.x2', DIM, 2*DIM, 5, output, stride=2)
315 |         output = LeakyReLU(output)
316 |         output = tf.layers.dropout(output, rate=DR_RATE)
317 |
318 |         output = lib.ops.conv2d.Conv2D('Discriminator.x3', 2*DIM, 4*DIM, 5, output, stride=2)
319 |         output = LeakyReLU(output)
320 |         output = tf.layers.dropout(output, rate=DR_RATE)
321 |
322 |         output = tf.reshape(output, [-1, 4*4*4*DIM])
323 |
324 |         zk_output = tf.concat([z, k], 1)
325 |         zk_output = lib.ops.linear.Linear('Discriminator.zk1', DIM_LATENT+N_COMS, 512, zk_output)
326 |         zk_output = LeakyReLU(zk_output)
327 |         zk_output = tf.layers.dropout(zk_output, rate=DR_RATE)
328 |
329 |         output = tf.concat([output, zk_output], 1)
330 |         output = lib.ops.linear.Linear('Discriminator.zkx1', 4*4*4*DIM+512, 512, output)
331 |         output = LeakyReLU(output)
332 |         output = tf.layers.dropout(output, rate=DR_RATE)
333 |
334 |         output = lib.ops.linear.Linear('Discriminator.Output', 512, 1, output)
335 |
336 |         return tf.reshape(output, [-1])
337 |
338 | '''
339 | losses
340 | '''
341 | real_x_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM])
342 | real_x = 2*((tf.cast(real_x_int, tf.float32)/255.)-.5)
343 | q_z, q_z_mean, q_z_std = Extractor(real_x)
344 | q_k_logits, q_k = HyperExtractor(q_z)
345 | q_k_probs = tf.nn.softmax(q_k_logits)
346 | if MODE_K == 'REINFORCE':
347 |     q_k_prob_max = tf.reduce_max(q_k_probs, axis=1)
348 | rec_x, rec_x_mean, rec_x_std = Generator(q_z)
349 | hyper_p_z = tf.random_normal([BATCH_SIZE, DIM_LATENT])
350 | hyper_p_k = tf.one_hot(indices=prior_k.sample(BATCH_SIZE), depth=N_COMS)
351 | p_z = HyperGenerator(hyper_p_k, hyper_p_z)
352 | fake_x, _, _ = Generator(p_z)
353 | rec_z, _, _ = Extractor(fake_x)
354 | rec_q_k_logits, rec_q_k = HyperExtractor(rec_z)
355 |
356 | if MODE_K != 'REINFORCE':
357 |     score_function = None # only the REINFORCE estimator needs one
358 |
359 | if MODE == 'vegan':
360 |     disc_fake = Discriminator(p_z, hyper_p_k)
361 |     disc_real = Discriminator(q_z, q_k)
362 |     if MODE_K == 'REINFORCE':
363 |         score_function = lib.objs.discrete_variables.score_function(disc_real, q_k_prob_max, CONTROL_VARIATE)
364 |
365 | elif MODE in ['local_ep', 'local_epce']:
366 |     disc_fake, disc_real = [], []
367 |     disc_fake.append(HyperDiscriminator(p_z, hyper_p_k))
368 |     disc_real.append(HyperDiscriminator(q_z, q_k))
369 |     disc_fake.append(Discriminator(fake_x, p_z))
370 |     disc_real.append(Discriminator(real_x, q_z))
371 |
372 |     if MODE_K == 'REINFORCE':
373 |         score_function = lib.objs.discrete_variables.score_function(disc_real[0], q_k_prob_max, CONTROL_VARIATE)
374 |
375 | else:
376 |     disc_real = Discriminator(real_x, q_z, q_k)
377 |     disc_fake = Discriminator(fake_x, p_z, hyper_p_k)
378 |     if MODE_K == 'REINFORCE':
379 |         score_function = lib.objs.discrete_variables.score_function(disc_real, q_k_prob_max, CONTROL_VARIATE)
380 |
381 | gen_params = lib.params_with_name('Generator')
382 | ext_params = lib.params_with_name('Extractor')
383 | disc_params = lib.params_with_name('Discriminator')
384 |
385 | if MODE == 'ali':
386 |     rec_penalty = None
387 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.ali(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
388 |
389 | elif MODE == 'alice':
390 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
391 |     # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
392 |     # rec_penalty += 1.*tf.nn.softmax_cross_entropy_with_logits(labels=hyper_p_k, logits=rec_q_k_logits)
393 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.alice(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
394 |
395 | elif MODE == 'local_ep':
396 |     rec_penalty = None
397 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.local_ep(disc_fake, disc_real, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
398 |
399 | elif MODE == 'local_epce':
400 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
401 |     # rec_penalty += 1.*lib.utils.distance.distance(p_z, rec_z, DISTANCE_X)
402 |     # rec_penalty += 1.*tf.nn.softmax_cross_entropy_with_logits(labels=hyper_p_k, logits=rec_q_k_logits)
403 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.local_epce(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, lr=LR, beta1=BETA1, s_f=score_function)
404 |
405 | elif MODE == 'vegan':
406 |     rec_penalty = 1.*lib.utils.distance.distance(real_x, rec_x, DISTANCE_X)
407 |     gen_cost, disc_cost, gen_train_op, disc_train_op = lib.objs.gan_inference.vegan(disc_fake, disc_real, rec_penalty, gen_params+ext_params, disc_params, LAMBDA, lr=LR, beta1=BETA1, s_f=score_function)
408 |
409 | else:
410 |     raise NotImplementedError()
411 |
412 | # For visualizing samples
413 | # np_fixed_noise = np.repeat(np.random.normal(size=(N_VIS/N_COMS, DIM_LATENT)).astype('float32'), N_COMS, axis=0)
414 | np_fixed_noise = np.random.normal(size=(N_VIS, DIM_LATENT)).astype('float32')
415 | np_fixed_k = np.tile(np.eye(N_COMS, dtype=int), (N_VIS/N_COMS, 1))
416 | hyper_fixed_noise = tf.constant(np_fixed_noise)
417 | hyper_fixed_k = tf.constant(np_fixed_k)
418 | fixed_noise = HyperGenerator(hyper_fixed_k, hyper_fixed_noise)
419 | fixed_noise_samples, _, _ = Generator(fixed_noise)
420 | def generate_image(frame, true_dist):
421 |     samples = session.run(fixed_noise_samples)
422 |     samples = ((samples+1.)*(255./2)).astype('int32')
423 |     lib.save_images.save_images(
424 |         samples.reshape((-1, 3, 32, 32)),
425 |         os.path.join(outf, '{}_samples_{}.png'.format(frame, MODE)),
426 |         size=[N_VIS/N_COMS, N_COMS]
427 |     )
428 |
429 | # For calculating inception score
430 | hyper_p_z_is = tf.random_normal([BATCH_SIZE, DIM_LATENT])
431 | hyper_p_k_is = tf.one_hot(indices=prior_k.sample(BATCH_SIZE), depth=N_COMS)
432 | p_z_is = HyperGenerator(hyper_p_k_is, hyper_p_z_is)
433 | samples_is, _, _ = Generator(p_z_is)
434 | def get_inception_score():
435 |     all_samples = []
436 |     for i in xrange(50000/BATCH_SIZE): # 781 full batches of 64 -> 49984 samples
437 |         all_samples.append(session.run(samples_is))
438 |     all_samples = np.concatenate(all_samples, axis=0)
439 |     all_samples = ((all_samples+1.)*(255./2)).astype('int32')
440 |     all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0,2,3,1)
441 |     return lib.inception_score.get_inception_score(list(all_samples))
442 |
443 | # Dataset iterator
444 | train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
445 | def inf_train_gen():
446 |     while True:
447 |         for images, targets in train_gen():
448 |             yield images
449 |
450 | # For reconstruction
451 | fixed_data_int = lib.cifar10.get_reconstruction_data(BATCH_SIZE, data_dir=DATA_DIR)
452 | rand_data_int, _ = dev_gen().next()
453 | def reconstruct_image(frame, data_int):
454 |     rec_samples = session.run(rec_x, feed_dict={real_x_int: data_int})
455 |     rec_samples = ((rec_samples+1.)*(255./2)).astype('int32')
456 |     tmp_list = []
457 |     for d, r in zip(data_int, rec_samples):
458 |         tmp_list.append(d)
459 |         tmp_list.append(r)
460 |     rec_samples = np.vstack(tmp_list) # interleave originals and reconstructions
461 |     lib.save_images.save_images(
462 |         rec_samples.reshape((-1, 3, 32, 32)),
463 |         os.path.join(outf, '{}_reconstruction_{}.png'.format(MODE, frame))
464 |     )
465 | saver = tf.train.Saver()
466 |
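# ------------------------------------------------------------------------------
# Aside: a quick NumPy check (illustrative, not original code) of the pixel
# scaling round trip used by real_x, generate_image and reconstruct_image:
# uint8 pixels in [0, 255] map to floats in [-1, 1] and back.
import numpy as np

x_int = np.array([0, 128, 255], dtype='int32')
x = 2.*((x_int.astype('float32')/255.) - .5)  # forward map, as for real_x
x_back = ((x + 1.)*(255./2)).astype('int32')  # inverse map, as in the scripts
# x_back recovers [0, 128, 255] (up to truncation from astype('int32'))
# ------------------------------------------------------------------------------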
467 | '''
468 | Train loop
469 | '''
470 | with tf.Session() as session:
471 |
472 |     session.run(tf.global_variables_initializer())
473 |     gen = inf_train_gen()
474 |
475 |     total_num = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
476 |     print '\nTotal number of parameters', total_num
477 |     with open(logfile, 'a') as f:
478 |         f.write('Total number of parameters: ' + str(total_num) + '\n')
479 |
480 |     for iteration in xrange(ITERS):
481 |         start_time = time.time()
482 |
483 |         if iteration > 0:
484 |             _data = gen.next()
485 |             _gen_cost, _ = session.run([gen_cost, gen_train_op],
486 |                 feed_dict={real_x_int: _data}
487 |             )
488 |
489 |         for i in xrange(CRITIC_ITERS):
490 |             _data = gen.next()
491 |             _disc_cost, _ = session.run(
492 |                 [disc_cost, disc_train_op],
493 |                 feed_dict={real_x_int: _data}
494 |             )
495 |             if MODE == 'wali': # note: 'wali' is not among the modes this script builds, so clip_disc_weights is never defined here
496 |                 _ = session.run(clip_disc_weights)
497 |
498 |         if MODE in ['vegan-mmd', 'vegan-kl', 'vegan-ikl', 'vegan-jsd', 'vae']:
499 |             if iteration > 0:
500 |                 lib.plot.plot('train gen cost ', _gen_cost)
501 |         else:
502 |             lib.plot.plot('train disc cost', _disc_cost)
503 |         lib.plot.plot('time', time.time() - start_time)
504 |
505 |         # Calculate dev loss
506 |         if iteration % 100 == 99:
507 |             if rec_penalty is not None:
508 |                 dev_rec_costs = []
509 |                 dev_reg_costs = []
510 |                 for images, _ in dev_gen():
511 |                     _dev_rec_cost, _dev_gen_cost = session.run(
512 |                         [rec_penalty, gen_cost],
513 |                         feed_dict={real_x_int: images}
514 |                     )
515 |                     dev_rec_costs.append(_dev_rec_cost)
516 |                     dev_reg_costs.append(_dev_gen_cost - _dev_rec_cost)
517 |                 lib.plot.plot('dev rec cost', np.mean(dev_rec_costs))
518 |                 lib.plot.plot('dev reg cost', np.mean(dev_reg_costs))
519 |             else:
520 |                 dev_gen_costs = []
521 |                 for images, _ in dev_gen():
522 |                     _dev_gen_cost = session.run(
523 |                         gen_cost,
524 |                         feed_dict={real_x_int: images}
525 |                     )
526 |                     dev_gen_costs.append(_dev_gen_cost)
527 |                 lib.plot.plot('dev gen cost', np.mean(dev_gen_costs))
528 |
529 |         # Write logs
530 |         if (iteration < 5) or (iteration % 100 == 99):
531 |             lib.plot.flush(outf, logfile)
532 |
533 |         # Calculate inception score
534 |         if iteration % 10000 == 9999:
535 |             inception_score = get_inception_score()
536 |             lib.plot.plot('inception score', inception_score[0])
537 |             lib.plot.plot('inception score std', inception_score[1])
538 |
539 |         lib.plot.tick()
540 |
541 |         # Generation and reconstruction
542 |         if iteration % 5000 == 4999:
543 |             generate_image(iteration, _data)
544 |             reconstruct_image(iteration, rand_data_int)
545 |             reconstruct_image(-iteration, fixed_data_int)
546 |
547 |         # Save model
548 |         if iteration == ITERS - 1:
549 |             save_path = saver.save(session, os.path.join(outf, '{}_model_{}.ckpt'.format(iteration, MODE)))
--------------------------------------------------------------------------------