├── 10x_73k ├── __init__.py ├── clus_wgan.py └── mlp.py ├── Gene_ClusterGAN.py ├── Gene_wgan_gp.py ├── Image_ClusterGAN.py ├── Image_wgan_gp.py ├── README.md ├── fashion ├── __init__.py ├── clus_wgan.py └── dcgan.py ├── metric.py ├── mnist ├── __init__.py ├── clus_wgan.py └── dcgan.py ├── pen_ClusterGAN.py ├── pen_wgan_gp.py ├── pendigit ├── __init__.py ├── clus_wgan.py └── mlp.py ├── util.py └── visualize.py /10x_73k/__init__.py: -------------------------------------------------------------------------------- 1 | import scipy.sparse 2 | import scipy.io 3 | import numpy as np 4 | 5 | # Original Labels are within 0 to 9. But proper label mapping is required as there are 8 classes. 6 | 7 | class DataSampler(object): 8 | def __init__(self): 9 | self.total_size = 73233 10 | self.train_size = 60000 11 | self.test_size = 13233 12 | self.X_train, self.X_test = self._load_gene_mtx() 13 | self.y_train, self.y_test = self._load_labels() 14 | 15 | 16 | def _read_mtx(self, filename): 17 | buf = scipy.io.mmread(filename) 18 | return buf 19 | 20 | def _load_gene_mtx(self): 21 | data_path = './data/10x_73k/sub_set-720.mtx' 22 | data = self._read_mtx(data_path) 23 | data = data.toarray() 24 | data = np.log2(data + 1) 25 | scale = np.max(data) 26 | data = data / scale 27 | 28 | np.random.seed(0) 29 | indx = np.random.permutation(np.arange(self.total_size)) 30 | data_train = data[indx[0:self.train_size], :] 31 | data_test = data[indx[self.train_size:], :] 32 | 33 | return data_train, data_test 34 | 35 | 36 | def _load_labels(self): 37 | data_path = './data/10x_73k/labels.txt' 38 | labels = np.loadtxt(data_path).astype(int) 39 | 40 | np.random.seed(0) 41 | indx = np.random.permutation(np.arange(self.total_size)) 42 | labels_train = labels[indx[0:self.train_size]] 43 | labels_test = labels[indx[self.train_size:]] 44 | return labels_train, labels_test 45 | 46 | 47 | def train(self, batch_size, label = False): 48 | indx = np.random.randint(low = 0, high = self.train_size, size = batch_size) 49 | 50 | if label: 51 | return self.X_train[indx, :], self.y_train[indx].flatten() 52 | else: 53 | return self.X_train[indx, :] 54 | 55 | def validation(self): 56 | return self.X_train[-5000:,:], self.y_train[-5000:].flatten() 57 | 58 | def test(self): 59 | return self.X_test, self.y_test 60 | 61 | def load_all(self): 62 | return np.concatenate((self.X_train, self.X_test)), np.concatenate((self.y_train, self.y_test)) 63 | 64 | -------------------------------------------------------------------------------- /10x_73k/clus_wgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | def leaky_relu(x, alpha=0.2): 6 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 7 | 8 | class Discriminator(object): 9 | def __init__(self, x_dim=720): 10 | self.x_dim = x_dim 11 | self.name = '10x_73k/clus_wgan/d_net' 12 | 13 | def __call__(self, x, reuse=True): 14 | with tf.variable_scope(self.name) as vs: 15 | if reuse: 16 | vs.reuse_variables() 17 | 18 | fc1 = tc.layers.fully_connected( 19 | x, 256, 20 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 21 | activation_fn=tf.identity 22 | ) 23 | fc1 = leaky_relu(fc1) 24 | 25 | fc2 = tc.layers.fully_connected( 26 | fc1, 256, 27 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 28 | activation_fn=tf.identity 29 | ) 30 | fc2 = leaky_relu(fc2) 31 | 32 | fc3 = tc.layers.fully_connected(fc2, 1, 33 | 
weights_initializer=tf.random_normal_initializer(stddev=0.02), 34 | activation_fn=tf.identity 35 | ) 36 | return fc3 37 | 38 | @property 39 | def vars(self): 40 | return [var for var in tf.global_variables() if self.name in var.name] 41 | 42 | 43 | class Generator(object): 44 | def __init__(self, z_dim = 38, x_dim = 720): 45 | self.z_dim = z_dim 46 | self.x_dim = x_dim 47 | self.name = '10x_73k/clus_wgan/g_net' 48 | 49 | def __call__(self, z): 50 | with tf.variable_scope(self.name) as vs: 51 | fc1 = tcl.fully_connected( 52 | z, 256, 53 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 54 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 55 | activation_fn=tf.identity 56 | ) 57 | fc1 = leaky_relu(fc1) 58 | fc2 = tcl.fully_connected( 59 | fc1, 256, 60 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 61 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 62 | activation_fn=tf.identity 63 | ) 64 | fc2 = leaky_relu(fc2) 65 | 66 | fc3 = tc.layers.fully_connected( 67 | fc2, self.x_dim, 68 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 69 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 70 | activation_fn=tf.identity 71 | ) 72 | return fc3 73 | 74 | @property 75 | def vars(self): 76 | return [var for var in tf.global_variables() if self.name in var.name] 77 | 78 | class Encoder(object): 79 | def __init__(self, z_dim = 38, dim_gen = 30, x_dim = 720): 80 | self.z_dim = z_dim 81 | self.dim_gen = dim_gen 82 | self.x_dim = x_dim 83 | self.name = '10x_73k/clus_wgan/enc_net' 84 | 85 | def __call__(self, x, reuse=True): 86 | 87 | with tf.variable_scope(self.name) as vs: 88 | if reuse: 89 | vs.reuse_variables() 90 | 91 | fc1 = tc.layers.fully_connected( 92 | x, 256, 93 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 94 | activation_fn=tf.identity 95 | ) 96 | fc1 = leaky_relu(fc1) 97 | 98 | fc2 = tc.layers.fully_connected( 99 | fc1, 256, 100 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 101 | activation_fn=tf.identity 102 | ) 103 | fc2 = leaky_relu(fc2) 104 | 105 | fc3 = tc.layers.fully_connected(fc2, self.z_dim, activation_fn=tf.identity) 106 | logits = fc3[:, self.dim_gen:] 107 | y = tf.nn.softmax(logits) 108 | return fc3[:, 0:self.dim_gen], y, logits 109 | 110 | 111 | @property 112 | def vars(self): 113 | return [var for var in tf.global_variables() if self.name in var.name] 114 | -------------------------------------------------------------------------------- /10x_73k/mlp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | def leaky_relu(x, alpha=0.2): 6 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 7 | 8 | 9 | class Discriminator(object): 10 | def __init__(self, x_dim=720): 11 | self.x_dim = x_dim 12 | self.name = '10x_73k/mlp/d_net' 13 | 14 | def __call__(self, x, reuse=True): 15 | with tf.variable_scope(self.name) as vs: 16 | if reuse: 17 | vs.reuse_variables() 18 | 19 | fc1 = tc.layers.fully_connected( 20 | x, 256, 21 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 22 | activation_fn=tf.identity 23 | ) 24 | fc1 = leaky_relu(fc1) 25 | 26 | fc2 = tc.layers.fully_connected( 27 | fc1, 256, 28 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 29 | activation_fn=tf.identity 30 | ) 31 | fc2 = leaky_relu(fc2) 32 | 33 | fc3 = tc.layers.fully_connected(fc2, 1, 34 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 
35 | activation_fn=tf.identity 36 | ) 37 | return fc3 38 | 39 | @property 40 | def vars(self): 41 | return [var for var in tf.global_variables() if self.name in var.name] 42 | 43 | 44 | class Generator(object): 45 | def __init__(self, z_dim = 38, x_dim = 720): 46 | self.z_dim = z_dim 47 | self.x_dim = x_dim 48 | self.name = '10x_73k/mlp/g_net' 49 | 50 | def __call__(self, z): 51 | with tf.variable_scope(self.name) as vs: 52 | fc1 = tcl.fully_connected( 53 | z, 256, 54 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 55 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 56 | activation_fn=tf.identity 57 | ) 58 | fc1 = leaky_relu(fc1) 59 | fc2 = tcl.fully_connected( 60 | fc1, 256, 61 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 62 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 63 | activation_fn=tf.identity 64 | ) 65 | fc2 = leaky_relu(fc2) 66 | 67 | fc3 = tc.layers.fully_connected( 68 | fc2, self.x_dim, 69 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 70 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 71 | activation_fn=tf.identity 72 | ) 73 | return fc3 74 | 75 | @property 76 | def vars(self): 77 | return [var for var in tf.global_variables() if self.name in var.name] 78 | -------------------------------------------------------------------------------- /Gene_ClusterGAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | import numpy as np 9 | from sklearn.cluster import KMeans 10 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 11 | 12 | import metric 13 | import util 14 | 15 | tf.set_random_seed(0) 16 | 17 | 18 | class clusGAN(object): 19 | def __init__(self, g_net, d_net, enc_net, x_sampler, z_sampler, data, model, sampler, 20 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label): 21 | self.model = model 22 | self.data = data 23 | self.sampler = sampler 24 | self.g_net = g_net 25 | self.d_net = d_net 26 | self.enc_net = enc_net 27 | self.x_sampler = x_sampler 28 | self.z_sampler = z_sampler 29 | self.num_classes = num_classes 30 | self.dim_gen = dim_gen 31 | self.n_cat = n_cat 32 | self.batch_size = batch_size 33 | scale = 10.0 34 | self.beta_cycle_gen = beta_cycle_gen 35 | self.beta_cycle_label = beta_cycle_label 36 | 37 | self.x_dim = self.d_net.x_dim 38 | self.z_dim = self.g_net.z_dim 39 | 40 | 41 | self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x') 42 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 43 | 44 | self.z_gen = self.z[:,0:self.dim_gen] 45 | self.z_hot = self.z[:,self.dim_gen:] 46 | 47 | self.x_ = self.g_net(self.z) 48 | self.z_enc_gen, self.z_enc_label, self.z_enc_logits = self.enc_net(self.x_, reuse=False) 49 | self.z_infer_gen, self.z_infer_label, self.z_infer_logits = self.enc_net(self.x) 50 | 51 | 52 | self.d = self.d_net(self.x, reuse=False) 53 | self.d_ = self.d_net(self.x_) 54 | 55 | 56 | self.g_loss = tf.reduce_mean(self.d_) + \ 57 | self.beta_cycle_gen * tf.reduce_mean(tf.square(self.z_gen - self.z_enc_gen)) +\ 58 | self.beta_cycle_label * tf.reduce_mean( 59 | tf.nn.softmax_cross_entropy_with_logits(logits=self.z_enc_logits,labels=self.z_hot)) 60 | 61 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 62 | 63 | epsilon = tf.random_uniform([], 0.0, 1.0) 64 | x_hat = epsilon * self.x + (1 - epsilon) * 
self.x_ 65 | d_hat = self.d_net(x_hat) 66 | 67 | ddx = tf.gradients(d_hat, x_hat)[0] 68 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 69 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale) 70 | 71 | self.d_loss = self.d_loss + ddx 72 | 73 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 74 | .minimize(self.d_loss, var_list=self.d_net.vars) 75 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 76 | .minimize(self.g_loss, var_list=[self.g_net.vars, self.enc_net.vars]) 77 | 78 | # Reconstruction Nodes 79 | self.recon_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1) 80 | self.compute_grad = tf.gradients(self.recon_loss, self.z) 81 | 82 | self.saver = tf.train.Saver() 83 | 84 | run_config = tf.ConfigProto() 85 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 86 | run_config.gpu_options.allow_growth = True 87 | self.sess = tf.Session(config=run_config) 88 | 89 | def train(self, num_batches=100000): 90 | 91 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 92 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 93 | 94 | batch_size = self.batch_size 95 | 96 | self.sess.run(tf.global_variables_initializer()) 97 | start_time = time.time() 98 | print( 99 | 'Training {} on {}, sampler = {}, z = {} dimension, beta_n = {}, beta_c = {}'. 100 | format(self.model, self.data, self.sampler, self.z_dim, self.beta_cycle_gen, self.beta_cycle_label)) 101 | 102 | 103 | for t in range(0, num_batches): 104 | 105 | d_iters = 5 106 | 107 | for _ in range(0, d_iters): 108 | bx = self.x_sampler.train(batch_size) 109 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 110 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 111 | 112 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 113 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 114 | 115 | if (t+1) % 100 == 0: 116 | bx = self.x_sampler.train(batch_size) 117 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 118 | 119 | d_loss = self.sess.run( 120 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 121 | ) 122 | g_loss = self.sess.run( 123 | self.g_loss, feed_dict={self.z: bz} 124 | ) 125 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 126 | (t+1, time.time() - start_time, d_loss, g_loss)) 127 | 128 | self.recon_enc(timestamp, val=True) 129 | self.save(timestamp) 130 | 131 | def save(self, timestamp): 132 | 133 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 134 | self.z_dim, self.beta_cycle_label, 135 | self.beta_cycle_gen) 136 | 137 | if not os.path.exists(checkpoint_dir): 138 | os.makedirs(checkpoint_dir) 139 | 140 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 141 | 142 | def load(self, pre_trained = False, timestamp = ''): 143 | 144 | if pre_trained == True: 145 | print('Loading Pre-trained Model...') 146 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler, 147 | self.z_dim, self.beta_cycle_label, self.beta_cycle_gen) 148 | else: 149 | if timestamp == '': 150 | print('Best Timestamp not provided. 
Abort !') 151 | checkpoint_dir = '' 152 | else: 153 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 154 | self.z_dim, self.beta_cycle_label, 155 | self.beta_cycle_gen) 156 | 157 | 158 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 159 | print('Restored model weights.') 160 | 161 | 162 | def _gen_samples(self, num_samples): 163 | 164 | batch_size = self.batch_size 165 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 166 | fake_samples = self.sess.run(self.x_, feed_dict = {self.z : bz}) 167 | for t in range(num_samples // batch_size): 168 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 169 | samp = self.sess.run(self.x_, feed_dict = {self.z : bz}) 170 | fake_samples = np.vstack((fake_samples, samp)) 171 | 172 | print(' Generated {} samples .'.format(fake_samples.shape[0])) 173 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_samples) 174 | 175 | def recon_enc(self, timestamp, val = True): 176 | 177 | if val: 178 | data_recon, label_recon = self.x_sampler.validation() 179 | else: 180 | data_recon, label_recon = self.x_sampler.test() 181 | #data_recon, label_recon = self.x_sampler.load_all() 182 | 183 | num_pts_to_plot = data_recon.shape[0] 184 | recon_batch_size = self.batch_size 185 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 186 | 187 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 188 | for b in range(int(np.ceil(num_pts_to_plot*1.0 / recon_batch_size))): 189 | 190 | if (b+1)*recon_batch_size > num_pts_to_plot: 191 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 192 | else: 193 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 194 | xtrue = data_recon[pt_indx, :] 195 | 196 | zhats_gen, zhats_label = self.sess.run([self.z_infer_gen, self.z_infer_label], feed_dict={self.x : xtrue}) 197 | 198 | latent[pt_indx, :] = np.concatenate((zhats_gen, zhats_label), axis=1) 199 | 200 | if self.beta_cycle_gen == 0: 201 | self._eval_cluster(latent[:, self.dim_gen:], label_recon, timestamp, val) 202 | else: 203 | self._eval_cluster(latent, label_recon, timestamp, val) 204 | 205 | 206 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 207 | 208 | map_labels = {0: 0, 1: 1, 2: 2, 4: 3, 6: 4, 7: 5, 8: 6, 9: 7} 209 | labels_true = np.array([map_labels[i] for i in labels_true]) 210 | 211 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 212 | labels_pred = km.labels_ 213 | 214 | purity = metric.compute_purity(labels_pred, labels_true) 215 | ari = adjusted_rand_score(labels_true, labels_pred) 216 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 217 | 218 | if val: 219 | data_split = 'Validation' 220 | else: 221 | data_split = 'Test' 222 | #data_split = 'All' 223 | 224 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_label = {}, beta_gen = {} ' 225 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen)) 226 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 227 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 228 | 229 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 230 | f.write('{}, {} : K = {}, z_dim = {}, beta_label = {}, beta_gen = {}, sampler = {}, Purity = {}, 
NMI = {}, ARI = {}\n' 231 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, 232 | self.sampler, purity, nmi, ari)) 233 | f.flush() 234 | 235 | 236 | if __name__ == '__main__': 237 | parser = argparse.ArgumentParser('') 238 | parser.add_argument('--data', type=str, default='10x_73k') 239 | parser.add_argument('--model', type=str, default='clus_wgan') 240 | parser.add_argument('--sampler', type=str, default='one_hot') 241 | parser.add_argument('--K', type=int, default=8) 242 | parser.add_argument('--dz', type=int, default=30) 243 | parser.add_argument('--bs', type=int, default=64) 244 | parser.add_argument('--beta_n', type=float, default=10.0) 245 | parser.add_argument('--beta_c', type=float, default=10.0) 246 | parser.add_argument('--timestamp', type=str, default='') 247 | parser.add_argument('--train', type=str, default='False') 248 | args = parser.parse_args() 249 | data = importlib.import_module(args.data) 250 | model = importlib.import_module(args.data + '.' + args.model) 251 | 252 | num_classes = args.K 253 | dim_gen = args.dz 254 | n_cat = 1 255 | batch_size = args.bs 256 | beta_cycle_gen = args.beta_n 257 | beta_cycle_label = args.beta_c 258 | timestamp = args.timestamp 259 | 260 | z_dim = dim_gen + num_classes * n_cat 261 | d_net = model.Discriminator() 262 | g_net = model.Generator(z_dim=z_dim) 263 | enc_net = model.Encoder(z_dim=z_dim, dim_gen = dim_gen) 264 | xs = data.DataSampler() 265 | zs = util.sample_Z 266 | 267 | cl_gan = clusGAN(g_net, d_net, enc_net, xs, zs, args.data, args.model, args.sampler, 268 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label) 269 | 270 | if args.train == 'True': 271 | cl_gan.train() 272 | else: 273 | 274 | print('Attempting to Restore Model ...') 275 | if timestamp == '': 276 | cl_gan.load(pre_trained=True) 277 | timestamp = 'pre-trained' 278 | else: 279 | cl_gan.load(pre_trained=False, timestamp = timestamp) 280 | 281 | cl_gan.recon_enc(timestamp, val=False) 282 | 283 | -------------------------------------------------------------------------------- /Gene_wgan_gp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | import numpy as np 9 | from sklearn.cluster import KMeans 10 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 11 | 12 | import metric 13 | import util 14 | 15 | tf.set_random_seed(0) 16 | 17 | class WassersteinGAN(object): 18 | def __init__(self, g_net, d_net, x_sampler, z_sampler, data, model, sampler, num_classes, dim_gen, n_cat, 19 | batch_size, beta_reg): 20 | self.model = model 21 | self.data = data 22 | self.sampler = sampler 23 | self.g_net = g_net 24 | self.d_net = d_net 25 | self.x_sampler = x_sampler 26 | self.z_sampler = z_sampler 27 | self.num_classes = num_classes 28 | self.dim_gen = dim_gen 29 | self.n_cat = n_cat 30 | self.batch_size = batch_size 31 | scale = 10.0 32 | self.beta_reg = beta_reg 33 | 34 | self.x_dim = self.d_net.x_dim 35 | self.z_dim = self.g_net.z_dim 36 | 37 | if sampler == 'mul_cat': 38 | self.clip_lim = [-0.6, 0.6] 39 | elif sampler == 'one_hot': 40 | self.clip_lim = [-0.6, 0.6] 41 | elif sampler == 'clus': 42 | self.clip_lim = [-1.0, 1.0] 43 | elif sampler == 'uniform': 44 | self.clip_lim = [-1.0, 1.0] 45 | elif sampler == 'normal': 46 | self.clip_lim = [-1.0, 1.0] 47 | elif sampler == 'mix_gauss': 48 | 
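            # Note: self.clip_lim is only used later, in recon_enc, where the latent
            # estimates zhats are clipped to this range at every backprop-decoding step;
            # the ranges presumably track the support of each latent prior.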
self.clip_lim = [-1.0, 2.0] 49 | elif sampler == 'pca_kmeans': 50 | self.clip_lim = [-2.0, 2.0] 51 | 52 | self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x') 53 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 54 | 55 | self.x_ = self.g_net(self.z) 56 | 57 | self.d = self.d_net(self.x, reuse=False) 58 | self.d_ = self.d_net(self.x_) 59 | 60 | self.g_loss = tf.reduce_mean(self.d_) 61 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 62 | 63 | epsilon = tf.random_uniform([], 0.0, 1.0) 64 | x_hat = epsilon * self.x + (1 - epsilon) * self.x_ 65 | d_hat = self.d_net(x_hat) 66 | 67 | ddx = tf.gradients(d_hat, x_hat)[0] 68 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 69 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale) 70 | 71 | self.d_loss = self.d_loss + ddx 72 | 73 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 74 | .minimize(self.d_loss, var_list=self.d_net.vars) 75 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 76 | .minimize(self.g_loss, var_list=self.g_net.vars) 77 | 78 | self.saver = tf.train.Saver() 79 | 80 | run_config = tf.ConfigProto() 81 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 82 | run_config.gpu_options.allow_growth = True 83 | self.sess = tf.Session(config=run_config) 84 | 85 | def train(self, num_batches=200000): 86 | 87 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 88 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 89 | 90 | batch_size = self.batch_size 91 | self.sess.run(tf.global_variables_initializer()) 92 | start_time = time.time() 93 | print('Training {} on {}, sampler = {}, z = {} dimension'.format(self.model, self.data, self.sampler, self.z_dim)) 94 | 95 | 96 | for t in range(0, num_batches): 97 | d_iters = 5 98 | 99 | for _ in range(0, d_iters): 100 | bx = self.x_sampler.train(batch_size) 101 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 102 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 103 | 104 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 105 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 106 | 107 | if t % 100 == 0: 108 | bx = self.x_sampler.train(batch_size) 109 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 110 | 111 | d_loss = self.sess.run( 112 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 113 | ) 114 | g_loss = self.sess.run( 115 | self.g_loss, feed_dict={self.z: bz} 116 | ) 117 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 118 | (t+1, time.time() - start_time, d_loss, g_loss)) 119 | 120 | self.recon_enc(timestamp, val=True) 121 | self.save(timestamp) 122 | 123 | def save(self, timestamp): 124 | 125 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, timestamp, self.model, self.sampler, 126 | self.z_dim) 127 | 128 | if not os.path.exists(checkpoint_dir): 129 | os.makedirs(checkpoint_dir) 130 | 131 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 132 | 133 | def load(self, pre_trained=False, timestamp=''): 134 | 135 | if pre_trained == True: 136 | print('Loading Pre-trained Model...') 137 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}'.format(self.data, self.model, self.sampler, self.z_dim) 138 | else: 139 | if timestamp == '': 140 | print('Best Timestamp not provided !') 141 | checkpoint_dir = '' 142 | else: 143 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, 
timestamp, self.model, self.sampler, 144 | self.z_dim) 145 | 146 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 147 | print('Restored model weights.') 148 | 149 | def _gen_samples(self, num_samples): 150 | 151 | batch_size = self.batch_size 152 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 153 | fake_samples = self.sess.run(self.x_, feed_dict = {self.z : bz}) 154 | for t in range(num_samples // batch_size): 155 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 156 | samp = self.sess.run(self.x_, feed_dict = {self.z : bz}) 157 | fake_samples = np.vstack((fake_samples, samp)) 158 | 159 | print(' Generated {} samples .'.format(fake_samples.shape[0])) 160 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_samples) 161 | 162 | 163 | def recon_enc(self, timestamp, val = True): 164 | 165 | if val: 166 | data_recon, label_recon = self.x_sampler.validation() 167 | else: 168 | data_recon, label_recon = self.x_sampler.test() 169 | #data_recon, label_recon = self.x_sampler.load_all() 170 | 171 | num_pts_to_plot = data_recon.shape[0] 172 | recon_batch_size = 1000 173 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 174 | clip_lim = self.clip_lim 175 | 176 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 177 | 178 | # Regularized Reconstruction objective 179 | 180 | self.recon_reg_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1) + \ 181 | self.beta_reg * tf.reduce_mean(tf.square(self.z[:, 0:self.dim_gen]), 1) 182 | self.compute_reg_grad = tf.gradients(self.recon_reg_loss, self.z) 183 | 184 | for b in range(int(np.ceil(num_pts_to_plot*1.0 / recon_batch_size))): 185 | 186 | if (b+1)*recon_batch_size > num_pts_to_plot: 187 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 188 | else: 189 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 190 | xtrue = data_recon[pt_indx, :] 191 | 192 | num_backprop_iter = 5000 193 | num_restarts = self.num_classes 194 | seed_labels = np.tile(np.arange(self.num_classes), int(np.ceil(len(pt_indx) * 1.0 / self.num_classes))) 195 | seed_labels = seed_labels[0:len(pt_indx)] 196 | best_zhats = np.zeros(shape=(len(pt_indx), self.z_dim)) 197 | best_loss = np.inf * np.ones(len(pt_indx)) 198 | mu_mat = 1.0 * np.eye(self.num_classes) 199 | alg = 'adam' 200 | for t in range(num_restarts): 201 | print('Backprop Decoding [{} / {} ] ...'.format(t + 1, num_restarts)) 202 | 203 | if self.sampler == 'one_hot': 204 | label_index = (seed_labels + t) % self.num_classes 205 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 206 | n_cat=1, label_index=label_index) 207 | elif self.sampler == 'mul_cat': 208 | label_index = (seed_labels + t) % self.num_classes 209 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 210 | n_cat=self.n_cat, 211 | label_index=label_index) 212 | elif self.sampler == 'mix_gauss': 213 | label_index = (seed_labels + t) % self.num_classes 214 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 215 | n_cat=0, label_index=label_index) 216 | else: 217 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler) 218 | if alg == 'adam': 219 | beta1 = 0.9 220 | beta2 = 0.999 221 | lr = 0.01 222 | eps = 1e-8 223 | m = 0 224 | v = 0 225 | elif alg == 'grad_descent': 226 | lr = 1.00 227 | 228 | for i in 
range(num_backprop_iter): 229 | 230 | L, g = self.sess.run([self.recon_reg_loss, self.compute_reg_grad], 231 | feed_dict={self.z: zhats, self.x: xtrue}) 232 | 233 | if alg == 'adam': 234 | m_prev = np.copy(m) 235 | v_prev = np.copy(v) 236 | m = beta1 * m_prev + (1 - beta1) * g[0] 237 | v = beta2 * v_prev + (1 - beta2) * np.multiply(g[0], g[0]) 238 | m_hat = m / (1 - beta1 ** (i + 1)) 239 | v_hat = v / (1 - beta2 ** (i + 1)) 240 | zhats += - np.true_divide(lr * m_hat, (np.sqrt(v_hat) + eps)) 241 | 242 | elif alg == 'grad_descent': 243 | zhats += - lr * g[0] 244 | 245 | zhats = np.clip(zhats, a_min=clip_lim[0], a_max=clip_lim[1]) 246 | 247 | if self.sampler == 'one_hot': 248 | zhats[:, -self.num_classes:] = mu_mat[label_index, :] 249 | elif self.sampler == 'mul_hot': 250 | zhats[:, self.dim_gen:] = np.tile(mu_mat[label_index, :], (1, self.n_cat)) 251 | 252 | change_index = best_loss > L 253 | best_zhats[change_index, :] = zhats[change_index, :] 254 | best_loss[change_index] = L[change_index] 255 | 256 | latent[pt_indx, :] = best_zhats 257 | print(' [{} / {} ] ...'.format(pt_indx[-1]+1, num_pts_to_plot)) 258 | 259 | self._eval_cluster(latent, label_recon, timestamp, val) 260 | 261 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 262 | 263 | map_labels = {0:0, 1:1, 2:2, 4:3, 6:4, 7:5, 8:6, 9:7} 264 | labels_true = np.array([map_labels[i] for i in labels_true]) 265 | 266 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 267 | labels_pred = km.labels_ 268 | 269 | purity = metric.compute_purity(labels_pred, labels_true) 270 | ari = adjusted_rand_score(labels_true, labels_pred) 271 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 272 | 273 | if val: 274 | data_split = 'Validation' 275 | else: 276 | data_split = 'Test' 277 | #data_split = 'All' 278 | 279 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_reg = {}' 280 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_reg)) 281 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 282 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 283 | 284 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 285 | f.write('{}, {} : K = {}, z_dim = {}, beta_reg = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n' 286 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_reg, 287 | self.sampler, purity, nmi, ari)) 288 | f.flush() 289 | 290 | if __name__ == '__main__': 291 | parser = argparse.ArgumentParser('') 292 | parser.add_argument('--data', type=str, default='10x_73k') 293 | parser.add_argument('--model', type=str, default='mlp') 294 | parser.add_argument('--sampler', type=str, default='one_hot') 295 | parser.add_argument('--K', type=int, default=8) 296 | parser.add_argument('--dz', type=int, default=30) 297 | parser.add_argument('--bs', type=int, default=64) 298 | parser.add_argument('--beta_reg', type=float, default=10.0) 299 | parser.add_argument('--timestamp', type=str, default='') 300 | parser.add_argument('--train', type=str, default='False') 301 | 302 | args = parser.parse_args() 303 | data = importlib.import_module(args.data) 304 | model = importlib.import_module(args.data + '.' 
+ args.model) 305 | 306 | num_classes = args.K 307 | dim_gen = args.dz 308 | n_cat = 1 309 | batch_size = args.bs 310 | beta_reg = args.beta_reg 311 | timestamp = args.timestamp 312 | 313 | z_dim = dim_gen + num_classes * n_cat 314 | d_net = model.Discriminator() 315 | g_net = model.Generator(z_dim=z_dim) 316 | xs = data.DataSampler() 317 | zs = util.sample_Z 318 | 319 | wgan = WassersteinGAN(g_net, d_net, xs, zs, args.data, args.model, args.sampler, 320 | num_classes, dim_gen, n_cat, batch_size, beta_reg) 321 | 322 | if args.train == 'True': 323 | wgan.train() 324 | else: 325 | 326 | print('Attempting to Restore Model ...') 327 | if timestamp == '': 328 | wgan.load(pre_trained=True) 329 | timestamp = 'pre-trained' 330 | else: 331 | wgan.load(pre_trained=False, timestamp=timestamp) 332 | 333 | wgan.recon_enc(timestamp, val=False) 334 | 335 | 336 | 337 | -------------------------------------------------------------------------------- /Image_ClusterGAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | from scipy.misc import imsave 9 | import numpy as np 10 | from sklearn.cluster import KMeans 11 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 12 | 13 | import metric 14 | from visualize import * 15 | import util 16 | 17 | tf.set_random_seed(0) 18 | 19 | class clusGAN(object): 20 | def __init__(self, g_net, d_net, enc_net, x_sampler, z_sampler, data, model, sampler, 21 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label): 22 | self.model = model 23 | self.data = data 24 | self.sampler = sampler 25 | self.g_net = g_net 26 | self.d_net = d_net 27 | self.enc_net = enc_net 28 | self.x_sampler = x_sampler 29 | self.z_sampler = z_sampler 30 | self.num_classes = num_classes 31 | self.dim_gen = dim_gen 32 | self.n_cat = n_cat 33 | self.batch_size = batch_size 34 | scale = 10.0 35 | self.beta_cycle_gen = beta_cycle_gen 36 | self.beta_cycle_label = beta_cycle_label 37 | 38 | 39 | self.x_dim = self.d_net.x_dim 40 | self.z_dim = self.g_net.z_dim 41 | 42 | self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x') 43 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 44 | 45 | self.z_gen = self.z[:,0:self.dim_gen] 46 | self.z_hot = self.z[:,self.dim_gen:] 47 | 48 | self.x_ = self.g_net(self.z) 49 | self.z_enc_gen, self.z_enc_label, self.z_enc_logits = self.enc_net(self.x_, reuse=False) 50 | self.z_infer_gen, self.z_infer_label, self.z_infer_logits = self.enc_net(self.x) 51 | 52 | 53 | self.d = self.d_net(self.x, reuse=False) 54 | self.d_ = self.d_net(self.x_) 55 | 56 | 57 | self.g_loss = tf.reduce_mean(self.d_) + \ 58 | self.beta_cycle_gen * tf.reduce_mean(tf.square(self.z_gen - self.z_enc_gen)) +\ 59 | self.beta_cycle_label * tf.reduce_mean( 60 | tf.nn.softmax_cross_entropy_with_logits(logits=self.z_enc_logits,labels=self.z_hot)) 61 | 62 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 63 | 64 | epsilon = tf.random_uniform([], 0.0, 1.0) 65 | x_hat = epsilon * self.x + (1 - epsilon) * self.x_ 66 | d_hat = self.d_net(x_hat) 67 | 68 | ddx = tf.gradients(d_hat, x_hat)[0] 69 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 70 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale) 71 | 72 | self.d_loss = self.d_loss + ddx 73 | 74 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 75 | 
.minimize(self.d_loss, var_list=self.d_net.vars) 76 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 77 | .minimize(self.g_loss, var_list=[self.g_net.vars, self.enc_net.vars]) 78 | 79 | # Reconstruction Nodes 80 | self.recon_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1) 81 | self.compute_grad = tf.gradients(self.recon_loss, self.z) 82 | 83 | self.saver = tf.train.Saver() 84 | 85 | run_config = tf.ConfigProto() 86 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 87 | run_config.gpu_options.allow_growth = True 88 | self.sess = tf.Session(config=run_config) 89 | 90 | def train(self, num_batches=500000): 91 | 92 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 93 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 94 | 95 | batch_size = self.batch_size 96 | plt.ion() 97 | self.sess.run(tf.global_variables_initializer()) 98 | start_time = time.time() 99 | print( 100 | 'Training {} on {}, sampler = {}, z = {} dimension, beta_n = {}, beta_c = {}'. 101 | format(self.model, self.data, self.sampler, self.z_dim, self.beta_cycle_gen, self.beta_cycle_label)) 102 | 103 | 104 | im_save_dir = 'logs/{}/{}/{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler, self.z_dim, 105 | self.beta_cycle_label, self.beta_cycle_gen) 106 | if not os.path.exists(im_save_dir): 107 | os.makedirs(im_save_dir) 108 | 109 | for t in range(0, num_batches): 110 | d_iters = 5 111 | 112 | for _ in range(0, d_iters): 113 | bx = self.x_sampler.train(batch_size) 114 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 115 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 116 | 117 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 118 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 119 | 120 | if (t+1) % 100 == 0: 121 | bx = self.x_sampler.train(batch_size) 122 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 123 | 124 | 125 | d_loss = self.sess.run( 126 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 127 | ) 128 | g_loss = self.sess.run( 129 | self.g_loss, feed_dict={self.z: bz} 130 | ) 131 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 132 | (t+1, time.time() - start_time, d_loss, g_loss)) 133 | 134 | 135 | if (t+1) % 5000 == 0: 136 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 137 | bx = self.sess.run(self.x_, feed_dict={self.z: bz}) 138 | bx = xs.data2img(bx) 139 | bx = grid_transform(bx, xs.shape) 140 | 141 | imsave('logs/{}/{}/{}_z{}_cyc{}_gen{}/{}.png'.format(self.data, self.model, self.sampler, 142 | self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, (t+1) / 100), bx) 143 | 144 | self.recon_enc(timestamp, val = True) 145 | self.save(timestamp) 146 | 147 | def save(self, timestamp): 148 | 149 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 150 | self.z_dim, self.beta_cycle_label, 151 | self.beta_cycle_gen) 152 | 153 | if not os.path.exists(checkpoint_dir): 154 | os.makedirs(checkpoint_dir) 155 | 156 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 157 | 158 | def load(self, pre_trained = False, timestamp = ''): 159 | 160 | if pre_trained == True: 161 | print('Loading Pre-trained Model...') 162 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler, 163 | self.z_dim, self.beta_cycle_label, self.beta_cycle_gen) 164 | 
else: 165 | if timestamp == '': 166 | print('Best Timestamp not provided. Abort !') 167 | checkpoint_dir = '' 168 | else: 169 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 170 | self.z_dim, self.beta_cycle_label, 171 | self.beta_cycle_gen) 172 | 173 | 174 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 175 | print('Restored model weights.') 176 | 177 | 178 | 179 | def _gen_samples(self, num_images): 180 | 181 | batch_size = self.batch_size 182 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 183 | fake_im = self.sess.run(self.x_, feed_dict = {self.z : bz}) 184 | for t in range(num_images // batch_size): 185 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 186 | im = self.sess.run(self.x_, feed_dict = {self.z : bz}) 187 | fake_im = np.vstack((fake_im, im)) 188 | 189 | print(' Generated {} images .'.format(fake_im.shape[0])) 190 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_im) 191 | 192 | 193 | def gen_from_all_modes(self): 194 | 195 | if self.sampler == 'one_hot': 196 | batch_size = 1000 197 | label_index = np.tile(np.arange(self.num_classes), int(np.ceil(batch_size * 1.0 / self.num_classes))) 198 | 199 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, num_class=self.num_classes, 200 | n_cat= self.n_cat, label_index=label_index) 201 | bx = self.sess.run(self.x_, feed_dict={self.z: bz}) 202 | 203 | for m in range(self.num_classes): 204 | print('Generating samples from mode {} ...'.format(m)) 205 | mode_index = np.where(label_index == m)[0] 206 | mode_bx = bx[mode_index, :] 207 | mode_bx = xs.data2img(mode_bx) 208 | mode_bx = grid_transform(mode_bx, xs.shape) 209 | 210 | imsave('logs/{}/{}/{}_z{}_cyc{}_gen{}/mode{}_samples.png'.format(self.data, self.model, self.sampler, 211 | self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, m), mode_bx) 212 | 213 | def recon_enc(self, timestamp, val = True): 214 | 215 | if val: 216 | data_recon, label_recon = self.x_sampler.validation() 217 | else: 218 | data_recon, label_recon = self.x_sampler.test() 219 | #data_recon, label_recon = self.x_sampler.load_all() 220 | 221 | num_pts_to_plot = data_recon.shape[0] 222 | recon_batch_size = self.batch_size 223 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 224 | 225 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 226 | for b in range(int(np.ceil(num_pts_to_plot * 1.0 / recon_batch_size))): 227 | if (b+1)*recon_batch_size > num_pts_to_plot: 228 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 229 | else: 230 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 231 | xtrue = data_recon[pt_indx, :] 232 | 233 | zhats_gen, zhats_label = self.sess.run([self.z_infer_gen, self.z_infer_label], feed_dict={self.x : xtrue}) 234 | 235 | latent[pt_indx, :] = np.concatenate((zhats_gen, zhats_label), axis=1) 236 | 237 | 238 | if self.beta_cycle_gen == 0: 239 | self._eval_cluster(latent[:, self.dim_gen:], label_recon, timestamp, val) 240 | else: 241 | self._eval_cluster(latent, label_recon, timestamp, val) 242 | 243 | 244 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 245 | 246 | if self.data == 'fashion' and self.num_classes == 5: 247 | map_labels = {0 : 0, 1 : 1, 2 : 2, 3 : 0, 4 : 2, 5 : 3, 6 : 2, 7 : 3, 8 : 4, 9 : 3} 248 | labels_true = np.array([map_labels[i] 
for i in labels_true]) 249 | 250 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 251 | labels_pred = km.labels_ 252 | 253 | purity = metric.compute_purity(labels_pred, labels_true) 254 | ari = adjusted_rand_score(labels_true, labels_pred) 255 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 256 | 257 | 258 | if val: 259 | data_split = 'Validation' 260 | else: 261 | data_split = 'Test' 262 | #data_split = 'All' 263 | 264 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_label = {}, beta_gen = {} ' 265 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen)) 266 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 267 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 268 | 269 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 270 | f.write('{}, {} : K = {}, z_dim = {}, beta_label = {}, beta_gen = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n' 271 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, 272 | self.sampler, purity, nmi, ari)) 273 | f.flush() 274 | 275 | 276 | if __name__ == '__main__': 277 | parser = argparse.ArgumentParser('') 278 | parser.add_argument('--data', type=str, default='mnist') 279 | parser.add_argument('--model', type=str, default='clus_wgan') 280 | parser.add_argument('--sampler', type=str, default='one_hot') 281 | parser.add_argument('--K', type=int, default=10) 282 | parser.add_argument('--dz', type=int, default=30) 283 | parser.add_argument('--bs', type=int, default=64) 284 | parser.add_argument('--beta_n', type=float, default=10.0) 285 | parser.add_argument('--beta_c', type=float, default=10.0) 286 | parser.add_argument('--timestamp', type=str, default='') 287 | parser.add_argument('--train', type=str, default='False') 288 | 289 | args = parser.parse_args() 290 | data = importlib.import_module(args.data) 291 | model = importlib.import_module(args.data + '.' 
+ args.model) 292 | 293 | num_classes = args.K 294 | dim_gen = args.dz 295 | n_cat = 1 296 | batch_size = args.bs 297 | beta_cycle_gen = args.beta_n 298 | beta_cycle_label = args.beta_c 299 | timestamp = args.timestamp 300 | 301 | z_dim = dim_gen + num_classes * n_cat 302 | d_net = model.Discriminator() 303 | g_net = model.Generator(z_dim=z_dim) 304 | enc_net = model.Encoder(z_dim=z_dim, dim_gen = dim_gen) 305 | xs = data.DataSampler() 306 | zs = util.sample_Z 307 | 308 | 309 | cl_gan = clusGAN(g_net, d_net, enc_net, xs, zs, args.data, args.model, args.sampler, 310 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label) 311 | if args.train == 'True': 312 | cl_gan.train() 313 | else: 314 | 315 | print('Attempting to Restore Model ...') 316 | if timestamp == '': 317 | cl_gan.load(pre_trained=True) 318 | timestamp = 'pre-trained' 319 | else: 320 | cl_gan.load(pre_trained=False, timestamp = timestamp) 321 | 322 | cl_gan.recon_enc(timestamp, val=False) 323 | 324 | 325 | 326 | -------------------------------------------------------------------------------- /Image_wgan_gp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | from scipy.misc import imsave 9 | import numpy as np 10 | from sklearn.cluster import KMeans 11 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 12 | 13 | import metric 14 | from visualize import * 15 | import util 16 | 17 | tf.set_random_seed(0) 18 | 19 | class WassersteinGAN(object): 20 | def __init__(self, g_net, d_net, x_sampler, z_sampler, data, model, sampler, num_classes, dim_gen, n_cat, 21 | batch_size, beta_reg): 22 | self.model = model 23 | self.data = data 24 | self.sampler = sampler 25 | self.g_net = g_net 26 | self.d_net = d_net 27 | self.x_sampler = x_sampler 28 | self.z_sampler = z_sampler 29 | self.num_classes = num_classes 30 | self.dim_gen = dim_gen 31 | self.n_cat = n_cat 32 | self.batch_size = batch_size 33 | scale = 10.0 34 | self.beta_reg = beta_reg 35 | 36 | self.x_dim = self.d_net.x_dim 37 | self.z_dim = self.g_net.z_dim 38 | 39 | if sampler == 'mul_cat': 40 | self.clip_lim = [-0.6, 0.6] 41 | elif sampler == 'one_hot': 42 | self.clip_lim = [-0.6, 0.6] 43 | elif sampler == 'clus': 44 | self.clip_lim = [-1.0, 1.0] 45 | elif sampler == 'uniform': 46 | self.clip_lim = [-1.0, 1.0] 47 | elif sampler == 'normal': 48 | self.clip_lim = [-1.0, 1.0] 49 | elif sampler == 'mix_gauss': 50 | self.clip_lim = [-1.0, 2.0] 51 | elif sampler == 'pca_kmeans': 52 | self.clip_lim = [-2.0, 2.0] 53 | 54 | 55 | self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x') 56 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 57 | 58 | self.x_ = self.g_net(self.z) 59 | 60 | self.d = self.d_net(self.x, reuse=False) 61 | self.d_ = self.d_net(self.x_) 62 | 63 | self.g_loss = tf.reduce_mean(self.d_) 64 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 65 | 66 | epsilon = tf.random_uniform([], 0.0, 1.0) 67 | x_hat = epsilon * self.x + (1 - epsilon) * self.x_ 68 | d_hat = self.d_net(x_hat) 69 | 70 | ddx = tf.gradients(d_hat, x_hat)[0] 71 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 72 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale) 73 | 74 | self.d_loss = self.d_loss + ddx 75 | 76 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 77 | .minimize(self.d_loss, 
var_list=self.d_net.vars) 78 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)\ 79 | .minimize(self.g_loss, var_list=self.g_net.vars) 80 | 81 | self.saver = tf.train.Saver() 82 | 83 | run_config = tf.ConfigProto() 84 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 85 | run_config.gpu_options.allow_growth = True 86 | self.sess = tf.Session(config=run_config) 87 | 88 | def train(self, num_batches=200000): 89 | 90 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 91 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 92 | 93 | batch_size = self.batch_size 94 | plt.ion() 95 | self.sess.run(tf.global_variables_initializer()) 96 | start_time = time.time() 97 | print('Training {} on {}, sampler = {}, z = {} dimension'.format(self.model, self.data, self.sampler, self.z_dim)) 98 | 99 | im_save_dir = 'logs/{}/{}/{}_z{}'.format(self.data, self.model, self.sampler, self.z_dim) 100 | 101 | if not os.path.exists(im_save_dir): 102 | os.makedirs(im_save_dir) 103 | 104 | for t in range(0, num_batches): 105 | d_iters = 5 106 | 107 | for _ in range(0, d_iters): 108 | 109 | bx = self.x_sampler.train(batch_size) 110 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 111 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 112 | 113 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 114 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 115 | 116 | if (t+1) % 100 == 0: 117 | bx = self.x_sampler.train(batch_size) 118 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 119 | 120 | d_loss = self.sess.run( 121 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 122 | ) 123 | g_loss = self.sess.run( 124 | self.g_loss, feed_dict={self.z: bz} 125 | ) 126 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 127 | (t+1, time.time() - start_time, d_loss, g_loss)) 128 | 129 | if (t+1) % 5000 == 0: 130 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 131 | bx = self.sess.run(self.x_, feed_dict={self.z: bz}) 132 | bx = xs.data2img(bx) 133 | bx = grid_transform(bx, xs.shape) 134 | 135 | imsave('logs/{}/{}/{}_z{}/{}.png'.format(self.data, self.model, self.sampler, self.z_dim, (t+1) / 100), bx) 136 | 137 | self.recon_enc(timestamp, val=True) 138 | self.save(timestamp) 139 | 140 | def save(self, timestamp): 141 | 142 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, timestamp, self.model, self.sampler, self.z_dim) 143 | 144 | if not os.path.exists(checkpoint_dir): 145 | os.makedirs(checkpoint_dir) 146 | 147 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 148 | 149 | 150 | def load(self, pre_trained = False, timestamp = ''): 151 | 152 | if pre_trained == True: 153 | print('Loading Pre-trained Model...') 154 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}'.format(self.data, self.model, self.sampler, self.z_dim) 155 | else: 156 | if timestamp == '': 157 | print('Best Timestamp not provided !') 158 | checkpoint_dir = '' 159 | else: 160 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, timestamp, self.model, self.sampler, self.z_dim) 161 | 162 | 163 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 164 | print('Restored model weights.') 165 | 166 | 167 | 168 | def _gen_samples(self, num_images): 169 | 170 | batch_size = self.batch_size 171 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, 
self.n_cat) 172 | fake_im = self.sess.run(self.x_, feed_dict = {self.z : bz}) 173 | for t in range(num_images // batch_size): 174 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 175 | im = self.sess.run(self.x_, feed_dict = {self.z : bz}) 176 | fake_im = np.vstack((fake_im, im)) 177 | 178 | print(' Generated {} images .'.format(fake_im.shape[0])) 179 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_im) 180 | 181 | 182 | 183 | def gen_from_all_modes(self): 184 | 185 | if self.sampler == 'one_hot': 186 | batch_size = 1000 187 | label_index = np.tile(np.arange(self.num_classes), int(np.ceil(batch_size * 1.0 / self.num_classes))) 188 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, num_class=self.num_classes, 189 | n_cat = self.n_cat,label_index = label_index) 190 | bx = self.sess.run(self.x_, feed_dict={self.z: bz}) 191 | 192 | for m in range(self.num_classes): 193 | print('Generating samples from mode {} ...'.format(m)) 194 | mode_index = np.where(label_index == m)[0] 195 | mode_bx = bx[mode_index,:] 196 | mode_bx = xs.data2img(mode_bx) 197 | mode_bx = grid_transform(mode_bx, xs.shape) 198 | 199 | imsave('logs/{}/{}/{}_z{}/mode{}_samples.png'.format(self.data, self.model, self.sampler, 200 | self.z_dim, m), mode_bx) 201 | 202 | 203 | def recon_enc(self, timestamp, val = True): 204 | 205 | if val: 206 | data_recon, label_recon = self.x_sampler.validation() 207 | else: 208 | data_recon, label_recon = self.x_sampler.test() 209 | #data_recon, label_recon = self.x_sampler.load_all() 210 | 211 | num_pts_to_plot = data_recon.shape[0] 212 | recon_batch_size = 1000 213 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 214 | clip_lim = self.clip_lim 215 | 216 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 217 | 218 | # Regularized Reconstruction objective 219 | 220 | self.recon_reg_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1) + \ 221 | self.beta_reg * tf.reduce_mean(tf.square(self.z[:, 0:self.dim_gen]), 1) 222 | self.compute_reg_grad = tf.gradients(self.recon_reg_loss, self.z) 223 | 224 | for b in range(int(np.ceil(num_pts_to_plot * 1.0 / recon_batch_size))): 225 | 226 | if (b+1)*recon_batch_size > num_pts_to_plot: 227 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 228 | else: 229 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 230 | xtrue = data_recon[pt_indx, :] 231 | 232 | num_backprop_iter = 5000 233 | num_restarts = self.num_classes 234 | seed_labels = np.tile(np.arange(self.num_classes), int(np.ceil(recon_batch_size *1.0 / self.num_classes))) 235 | seed_labels = seed_labels[0:len(pt_indx)] 236 | best_zhats = np.zeros(shape=(len(pt_indx), self.z_dim)) 237 | best_loss = np.inf * np.ones(len(pt_indx)) 238 | mu_mat = 1.0 * np.eye(self.num_classes) 239 | alg = 'adam' 240 | for t in range(num_restarts): 241 | print('Backprop Decoding [{} / {} ] ...'.format(t + 1, num_restarts)) 242 | 243 | if self.sampler == 'one_hot': 244 | label_index = (seed_labels + t) % self.num_classes 245 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 246 | n_cat=1, label_index=label_index) 247 | elif self.sampler == 'mul_cat': 248 | label_index = (seed_labels + t) % self.num_classes 249 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 250 | n_cat=self.n_cat, 251 | label_index=label_index) 252 | elif self.sampler == 'mix_gauss': 
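                    # Restart t rotates the seeded cluster index, so over
                    # num_restarts (= self.num_classes) passes every point is
                    # initialized from every categorical mode; only the lowest-loss
                    # zhats found across restarts are kept in best_zhats.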
253 | label_index = (seed_labels + t) % self.num_classes 254 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 255 | n_cat=0, label_index=label_index) 256 | else: 257 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler) 258 | if alg == 'adam': 259 | beta1 = 0.9 260 | beta2 = 0.999 261 | lr = 0.01 262 | eps = 1e-8 263 | m = 0 264 | v = 0 265 | elif alg == 'grad_descent': 266 | lr = 1.00 267 | 268 | 269 | for i in range(num_backprop_iter): 270 | 271 | L, g = self.sess.run([self.recon_reg_loss, self.compute_reg_grad], 272 | feed_dict={self.z: zhats, self.x: xtrue}) 273 | 274 | if alg == 'adam': 275 | m_prev = np.copy(m) 276 | v_prev = np.copy(v) 277 | m = beta1 * m_prev + (1 - beta1) * g[0] 278 | v = beta2 * v_prev + (1 - beta2) * np.multiply(g[0], g[0]) 279 | m_hat = m / (1 - beta1 ** (i + 1)) 280 | v_hat = v / (1 - beta2 ** (i + 1)) 281 | zhats += - np.true_divide(lr * m_hat, (np.sqrt(v_hat) + eps)) 282 | 283 | elif alg == 'grad_descent': 284 | zhats += - lr * g[0] 285 | 286 | zhats = np.clip(zhats, a_min=clip_lim[0], a_max=clip_lim[1]) 287 | 288 | if self.sampler == 'one_hot': 289 | zhats[:, -self.num_classes:] = mu_mat[label_index,:] 290 | elif self.sampler == 'mul_hot': 291 | zhats[:, self.dim_gen:] = np.tile(mu_mat[label_index,:], (1, self.n_cat)) 292 | 293 | change_index = best_loss > L 294 | best_zhats[change_index, :] = zhats[change_index, :] 295 | best_loss[change_index] = L[change_index] 296 | 297 | latent[pt_indx, :] = best_zhats 298 | print(' [{} / {} ] ...'.format((b + 1) * recon_batch_size, num_pts_to_plot)) 299 | 300 | 301 | self._eval_cluster(latent, label_recon, timestamp, val) 302 | 303 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 304 | 305 | if self.data == 'fashion' and self.num_classes == 5: 306 | map_labels = {0 : 0, 1 : 1, 2 : 2, 3 : 0, 4 : 2, 5 : 3, 6 : 2, 7 : 3, 8 : 4, 9 : 3} 307 | labels_true = np.array([map_labels[i] for i in labels_true]) 308 | 309 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 310 | labels_pred = km.labels_ 311 | 312 | purity = metric.compute_purity(labels_pred, labels_true) 313 | ari = adjusted_rand_score(labels_true, labels_pred) 314 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 315 | 316 | if val: 317 | data_split = 'Validation' 318 | else: 319 | data_split = 'Test' 320 | #data_split = 'All' 321 | 322 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_reg = {}' 323 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_reg)) 324 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 325 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 326 | 327 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 328 | f.write('{}, {} : K = {}, z_dim = {}, beta_reg = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n' 329 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_reg, 330 | self.sampler, purity, nmi, ari)) 331 | f.flush() 332 | 333 | 334 | 335 | if __name__ == '__main__': 336 | parser = argparse.ArgumentParser('') 337 | parser.add_argument('--data', type=str, default='mnist') 338 | parser.add_argument('--model', type=str, default='dcgan') 339 | parser.add_argument('--sampler', type=str, default='one_hot') 340 | parser.add_argument('--K', type=int, default=10) 341 | parser.add_argument('--dz', type=int, default=30) 342 | parser.add_argument('--bs', type=int, default=64) 343 | 
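    # --beta_reg weights the L2 penalty on the continuous slice z[:, 0:dim_gen] in the
    # regularized reconstruction objective built inside recon_enc (backprop decoding).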
parser.add_argument('--beta_reg', type=float, default=10.0) 344 | parser.add_argument('--timestamp', type=str, default='') 345 | parser.add_argument('--train', type=str, default='False') 346 | 347 | args = parser.parse_args() 348 | data = importlib.import_module(args.data) 349 | model = importlib.import_module(args.data + '.' + args.model) 350 | 351 | num_classes = args.K 352 | dim_gen = args.dz 353 | n_cat = 1 354 | batch_size = args.bs 355 | beta_reg = args.beta_reg 356 | timestamp = args.timestamp 357 | 358 | z_dim = dim_gen + num_classes * n_cat 359 | d_net = model.Discriminator() 360 | g_net = model.Generator(z_dim = z_dim) 361 | xs = data.DataSampler() 362 | zs = util.sample_Z 363 | 364 | wgan = WassersteinGAN(g_net, d_net, xs, zs, args.data, args.model, args.sampler, 365 | num_classes, dim_gen, n_cat, batch_size, beta_reg) 366 | 367 | if args.train == 'True': 368 | wgan.train() 369 | else: 370 | 371 | print('Attempting to Restore Model ...') 372 | if timestamp == '': 373 | wgan.load(pre_trained=True) 374 | timestamp = 'pre-trained' 375 | else: 376 | wgan.load(pre_trained=False, timestamp = timestamp) 377 | 378 | wgan.recon_enc(timestamp, val=False) 379 | 380 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClusterGAN 2 | 3 | Code for reproducing key results in the paper [ClusterGAN : Latent Space Clustering in Generative Adversarial Networks](https://arxiv.org/abs/1809.03627) by Sudipto Mukherjee, Himanshu Asnani, Eugene Lin and Sreeram Kannan. If you use the code, please cite our paper. 4 | 5 | ## Dependencies 6 | 7 | The code has been tested with the following versions of packages. 8 | - Python 2.7.12 9 | - Tensorflow 1.4.0 10 | - Numpy 1.14.2 11 | 12 | ## Datasets 13 | 14 | The datasets used in the paper can be downloaded from the Google Drive link (https://drive.google.com/open?id=1XnGkSamF5DiwnpHFG0OexmoqAwe27ucR). 15 | Unzip the folder so that the path is : ./ClusterGAN/data/ 16 | 17 | ## Training 18 | 19 | You can either train your own models on the datasets or use pre-trained models. Even though we have used a fixed seed using tf.set_random_seed(0), there will still be randomness introduced by CUDA. So, to reproduce the results, train 5 models and compare the Validation purity in the logs directory. Each model can be trained as follows : 20 | 21 | ```bash 22 | $ python Image_ClusterGAN.py --data mnist --K 10 --dz 30 --beta_n 10 --beta_c 10 --train True 23 | ``` 24 | 25 | This will save the model along with a timestamp in checkpoint_dir/. Also, the Validation set performance will be written to logs/Res_<data>_<model>.txt. Then run the best model (with the highest Validation Purity) on the Test set. 26 | 27 | ```bash 28 | $ python Image_ClusterGAN.py --data mnist --K 10 --dz 30 --beta_n 10 --beta_c 10 --timestamp <timestamp> 29 | ``` 30 | 31 | Training the models for other datasets has a similar format.
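
For example, the five runs suggested above can be launched and compared with a small shell loop. This is an illustrative sketch, not a script shipped with the repository; it assumes the default `--model clus_wgan` for `Image_ClusterGAN.py`, so the validation results land in `logs/Res_mnist_clus_wgan.txt`.

```bash
# train 5 independent runs; each appends its Validation Purity / NMI / ARI to the log
for i in 1 2 3 4 5; do
    python Image_ClusterGAN.py --data mnist --K 10 --dz 30 --beta_n 10 --beta_c 10 --train True
done

# inspect the log and note the timestamp of the run with the highest Validation Purity
grep 'Validation' logs/Res_mnist_clus_wgan.txt
```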
32 | 33 | Fashion-10 : 34 | ```bash 35 | $ python Image_ClusterGAN.py --data fashion --K 10 --dz 40 --beta_n 0 --beta_c 10 --train True 36 | ``` 37 | 38 | Fashion-5 : 39 | ```bash 40 | $ python Image_ClusterGAN.py --data fashion --K 5 --dz 40 --beta_n 0 --beta_c 10 --train True 41 | ``` 42 | 43 | Single Cell 10x genomics : 44 | ```bash 45 | $ python Gene_ClusterGAN.py --data 10x_73k --K 8 --dz 30 --beta_n 10 --beta_c 10 --train True 46 | ``` 47 | 48 | Pendigits : 49 | ```bash 50 | $ python pen_ClusterGAN.py --data pendigit --K 10 --dz 5 --beta_n 10 --beta_c 10 --train True 51 | ``` 52 | 53 | Provide the timestamp of best saved model to obtain the Test set clustering performance on all the datasets (similar to MNIST above). 54 | 55 | ## Pre-trained models 56 | 57 | Run the following code : 58 | 59 | ```bash 60 | $ python Image_ClusterGAN.py --data mnist --K 10 --dz 30 --beta_n 10 --beta_c 10 61 | ``` 62 | 63 | Similarly for the other datasets. 64 | 65 | ## Clustering Performance 66 | 67 | Table shows the mean +- standard deviation of 10 runs of ClusterGAN (with the reported hyperparameter settings in the paper) for various datasets. 68 | 69 | | Dataset | ACC | NMI | ARI | 70 | |:-------------:|:-------------------:|:-------------------:|:---------------------:| 71 | | MNIST | 0.9097 +- 0.0398 | 0.8544 +- 0.0361 | 0.8290 +- 0.0621 | 72 | | Fashion-10 | 0.6119 +- 0.0230 | 0.6157 +- 0.0112 | 0.4617 +- 0.0226 | 73 | | Fashion-5 | 0.7218 +- 0.0089 | 0.6163 +- 0.0243 | 0.5035 +- 0.0228 | 74 | | 10x_73k | 0.8172 +- 0.0262 | 0.7272 +- 0.0322 | 0.6786 +- 0.0369 | 75 | | Pendigits | 0.7638 +- 0.0120 | 0.7343 +- 0.0120 | 0.6336 +- 0.0177 | 76 | 77 | 78 | ## Feedback 79 | 80 | Please feel free to provide any feedback about the code to sudipto.ece.ju@gmail.com 81 | 82 | 83 | -------------------------------------------------------------------------------- /fashion/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.examples.tutorials.mnist import input_data 3 | from sklearn.decomposition import PCA 4 | from sklearn.cluster import KMeans 5 | mnist = input_data.read_data_sets('./data/fashion') 6 | 7 | 8 | class DataSampler(object): 9 | def __init__(self): 10 | self.shape = [28, 28, 1] 11 | 12 | def train(self, batch_size, label=False): 13 | if label: 14 | return mnist.train.next_batch(batch_size) 15 | else: 16 | return mnist.train.next_batch(batch_size)[0] 17 | 18 | def test(self): 19 | return mnist.test.images, mnist.test.labels 20 | 21 | def validation(self): 22 | return mnist.validation.images, mnist.validation.labels 23 | 24 | 25 | def data2img(self, data): 26 | return np.reshape(data, [data.shape[0]] + self.shape) 27 | 28 | def load_all(self): 29 | 30 | X_train = mnist.train.images 31 | X_val = mnist.validation.images 32 | X_test = mnist.test.images 33 | 34 | Y_train = mnist.train.labels 35 | Y_val = mnist.validation.labels 36 | Y_test = mnist.test.labels 37 | 38 | X = np.concatenate((X_train, X_val, X_test)) 39 | Y = np.concatenate((Y_train, Y_val, Y_test)) 40 | 41 | return X, Y.flatten() 42 | 43 | 44 | class NoiseSampler(object): 45 | 46 | def __init__(self, z_dim = 100, mode='uniform'): 47 | self.mode = mode 48 | self.z_dim = z_dim 49 | self.K = 10 50 | 51 | if self.mode == 'mix_gauss': 52 | self.mu_mat = (1.0) * np.eye(self.K, self.z_dim) 53 | self.sig = 0.1 54 | 55 | elif self.mode == 'one_hot': 56 | self.mu_mat = (1.0) * np.eye(self.K) 57 | self.sig = 0.10 58 | 59 | 60 | elif self.mode == 'pca_kmeans': 61 | 62 | 
data_x = mnist.train.images 63 | feature_mean = np.mean(data_x, axis = 0) 64 | data_x -= feature_mean 65 | data_embed = PCA(n_components=self.z_dim, random_state=0).fit_transform(data_x) 66 | data_x += feature_mean 67 | kmeans = KMeans(n_clusters=self.K, random_state=0) 68 | kmeans.fit(data_embed) 69 | self.mu_mat = kmeans.cluster_centers_ 70 | shift = np.min(self.mu_mat) 71 | scale = np.max(self.mu_mat - shift) 72 | self.mu_mat = (self.mu_mat - shift)/scale 73 | self.sig = 0.10 74 | 75 | 76 | def __call__(self, batch_size, z_dim): 77 | if self.mode == 'uniform': 78 | return np.random.uniform(-1.0, 1.0, [batch_size, z_dim]) 79 | elif self.mode == 'normal': 80 | return 0.15*np.random.randn(batch_size, z_dim) 81 | elif self.mode == 'mix_gauss': 82 | k = np.random.randint(low = 0, high = self.K, size=batch_size) 83 | return self.sig*np.random.randn(batch_size, z_dim) + self.mu_mat[k] 84 | elif self.mode == 'pca_kmeans': 85 | k = np.random.randint(low=0, high=self.K, size=batch_size) 86 | return self.sig * np.random.randn(batch_size, z_dim) + self.mu_mat[k] 87 | elif self.mode == 'one_hot': 88 | k = np.random.randint(low=0, high=self.K, size=batch_size) 89 | return np.hstack((self.sig * np.random.randn(batch_size, z_dim-self.K), self.mu_mat[k])) 90 | 91 | 92 | 93 | if __name__=='__main__': 94 | 95 | data_x = mnist.train.images 96 | from sklearn.decomposition import PCA 97 | from sklearn.cluster import KMeans 98 | import pdb 99 | mu = np.mean(data_x, axis = 0) 100 | data_x -= mu 101 | print('Computing PCA ...') 102 | data_embed = PCA(n_components=10, random_state=0).fit_transform(data_x) 103 | print('Done !') 104 | data_x += mu 105 | 106 | print('Computing kmeans ...') 107 | kmeans = KMeans(n_clusters=4, random_state=0) 108 | kmeans.fit(data_embed) 109 | print('Done !') 110 | mu_mat = kmeans.cluster_centers_ 111 | shift = np.min(mu_mat) 112 | scale = np.max(mu_mat - shift) 113 | mu_mat = (mu_mat - shift)/scale 114 | print(mu_mat.shape) 115 | print("Done !") 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /fashion/clus_wgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim = 784): 12 | self.x_dim = x_dim 13 | self.name = 'fashion/clus_wgan/d_net' 14 | 15 | def __call__(self, x, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | bs = tf.shape(x)[0] 20 | 21 | x = tf.reshape(x, [bs, 28, 28, 1]) 22 | conv1 = tc.layers.convolution2d( 23 | x, 64, [4, 4], [2, 2], 24 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 25 | activation_fn=tf.identity 26 | ) 27 | conv1 = leaky_relu(conv1) 28 | conv2 = tc.layers.convolution2d( 29 | conv1, 128, [4, 4], [2, 2], 30 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 31 | activation_fn=tf.identity 32 | ) 33 | conv2 = leaky_relu(conv2) 34 | conv2 = tcl.flatten(conv2) 35 | 36 | fc1 = tc.layers.fully_connected( 37 | conv2, 1024, 38 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 39 | activation_fn=tf.identity 40 | ) 41 | fc1 = leaky_relu(fc1) 42 | fc2 = tc.layers.fully_connected(fc1, 1, activation_fn=tf.identity) 43 | return fc2 44 | 45 | @property 46 | def vars(self): 47 | return [var for var in 
tf.global_variables() if self.name in var.name] 48 | 49 | 50 | class Generator(object): 51 | def __init__(self, z_dim = 40, x_dim = 784): 52 | self.z_dim = z_dim 53 | self.x_dim = x_dim 54 | self.name = 'fashion/clus_wgan/g_net' 55 | 56 | def __call__(self, z): 57 | with tf.variable_scope(self.name) as vs: 58 | bs = tf.shape(z)[0] 59 | fc1 = tc.layers.fully_connected( 60 | z, 1024, 61 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 62 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 63 | activation_fn=tf.identity 64 | ) 65 | fc1 = tc.layers.batch_norm(fc1) 66 | fc1 = tf.nn.relu(fc1) 67 | fc2 = tc.layers.fully_connected( 68 | fc1, 7 * 7 * 128, 69 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 70 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 71 | activation_fn=tf.identity 72 | ) 73 | fc2 = tf.reshape(fc2, tf.stack([bs, 7, 7, 128])) 74 | fc2 = tc.layers.batch_norm(fc2) 75 | fc2 = tf.nn.relu(fc2) 76 | conv1 = tc.layers.convolution2d_transpose( 77 | fc2, 64, [4, 4], [2, 2], 78 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 79 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 80 | activation_fn=tf.identity 81 | ) 82 | conv1 = tc.layers.batch_norm(conv1) 83 | conv1 = tf.nn.relu(conv1) 84 | conv2 = tc.layers.convolution2d_transpose( 85 | conv1, 1, [4, 4], [2, 2], 86 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 87 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 88 | activation_fn=tf.sigmoid 89 | ) 90 | conv2 = tf.reshape(conv2, tf.stack([bs, self.x_dim])) 91 | return conv2 92 | 93 | @property 94 | def vars(self): 95 | return [var for var in tf.global_variables() if self.name in var.name] 96 | 97 | 98 | class Encoder(object): 99 | def __init__(self, z_dim = 40, dim_gen = 30, x_dim = 784): 100 | self.z_dim = z_dim 101 | self.dim_gen = dim_gen 102 | self.x_dim = x_dim 103 | self.name = 'fashion/clus_wgan/enc_net' 104 | 105 | def __call__(self, x, reuse=True): 106 | 107 | with tf.variable_scope(self.name) as vs: 108 | if reuse: 109 | vs.reuse_variables() 110 | bs = tf.shape(x)[0] 111 | x = tf.reshape(x, [bs, 28, 28, 1]) 112 | conv1 = tc.layers.convolution2d( 113 | x, 64, [4, 4], [2, 2], 114 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 115 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 116 | activation_fn=tf.identity 117 | ) 118 | conv1 = leaky_relu(conv1) 119 | conv2 = tc.layers.convolution2d( 120 | conv1, 128, [4, 4], [2, 2], 121 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 122 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 123 | activation_fn=tf.identity 124 | ) 125 | conv2 = leaky_relu(conv2) 126 | conv2 = tcl.flatten(conv2) 127 | fc1 = tc.layers.fully_connected( 128 | conv2, 1024, 129 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 130 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 131 | activation_fn=tf.identity 132 | ) 133 | fc1 = leaky_relu(fc1) 134 | fc2 = tc.layers.fully_connected(fc1, self.z_dim, activation_fn=tf.identity) 135 | logits = fc2[:, self.dim_gen:] 136 | y = tf.nn.softmax(logits) 137 | return fc2[:, 0:self.dim_gen], y, logits 138 | 139 | 140 | @property 141 | def vars(self): 142 | return [var for var in tf.global_variables() if self.name in var.name] 143 | -------------------------------------------------------------------------------- /fashion/dcgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 
3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim = 784): 12 | self.x_dim = x_dim 13 | self.name = 'fashion/dcgan/d_net' 14 | 15 | def __call__(self, x, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | bs = tf.shape(x)[0] 20 | x = tf.reshape(x, [bs, 28, 28, 1]) 21 | conv1 = tc.layers.convolution2d( 22 | x, 64, [4, 4], [2, 2], 23 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 24 | activation_fn=tf.identity 25 | ) 26 | conv1 = leaky_relu(conv1) 27 | conv2 = tc.layers.convolution2d( 28 | conv1, 128, [4, 4], [2, 2], 29 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 30 | activation_fn=tf.identity 31 | ) 32 | conv2 = leaky_relu(conv2) 33 | conv2 = tcl.flatten(conv2) 34 | fc1 = tc.layers.fully_connected( 35 | conv2, 1024, 36 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 37 | activation_fn=tf.identity 38 | ) 39 | fc1 = leaky_relu(fc1) 40 | fc2 = tc.layers.fully_connected(fc1, 1, activation_fn=tf.identity) 41 | return fc2 42 | 43 | @property 44 | def vars(self): 45 | return [var for var in tf.global_variables() if self.name in var.name] 46 | 47 | 48 | class Generator(object): 49 | def __init__(self, z_dim = 10, x_dim = 784): 50 | self.z_dim = z_dim 51 | self.x_dim = x_dim 52 | self.name = 'fashion/dcgan/g_net' 53 | 54 | def __call__(self, z): 55 | with tf.variable_scope(self.name) as vs: 56 | bs = tf.shape(z)[0] 57 | fc1 = tc.layers.fully_connected( 58 | z, 1024, 59 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 60 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 61 | activation_fn=tf.identity 62 | ) 63 | fc1 = tc.layers.batch_norm(fc1) 64 | fc1 = tf.nn.relu(fc1) 65 | fc2 = tc.layers.fully_connected( 66 | fc1, 7 * 7 * 128, 67 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 68 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 69 | activation_fn=tf.identity 70 | ) 71 | fc2 = tf.reshape(fc2, tf.stack([bs, 7, 7, 128])) 72 | fc2 = tc.layers.batch_norm(fc2) 73 | fc2 = tf.nn.relu(fc2) 74 | conv1 = tc.layers.convolution2d_transpose( 75 | fc2, 64, [4, 4], [2, 2], 76 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 77 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 78 | activation_fn=tf.identity 79 | ) 80 | conv1 = tc.layers.batch_norm(conv1) 81 | conv1 = tf.nn.relu(conv1) 82 | conv2 = tc.layers.convolution2d_transpose( 83 | conv1, 1, [4, 4], [2, 2], 84 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 85 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 86 | activation_fn=tf.sigmoid 87 | ) 88 | conv2 = tf.reshape(conv2, tf.stack([bs, 784])) 89 | return conv2 90 | 91 | @property 92 | def vars(self): 93 | return [var for var in tf.global_variables() if self.name in var.name] 94 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_purity(y_pred, y_true): 4 | """ 5 | Calculate the purity, a measurement of quality for the clustering 6 | results. 7 | 8 | Each cluster is assigned to the class which is most frequent in the 9 | cluster. Using these classes, the percent accuracy is then calculated. 10 | 11 | Returns: 12 | A number between 0 and 1. 
Poor clusterings have a purity close to 0 13 | while a perfect clustering has a purity of 1. 14 | 15 | """ 16 | 17 | # get the set of unique cluster ids 18 | clusters = set(y_pred) 19 | 20 | # find out what class is most frequent in each cluster 21 | cluster_classes = {} 22 | correct = 0 23 | for cluster in clusters: 24 | # get the indices of rows in this cluster 25 | indices = np.where(y_pred == cluster)[0] 26 | 27 | cluster_labels = y_true[indices] 28 | majority_label = np.argmax(np.bincount(cluster_labels)) 29 | correct += np.sum(cluster_labels == majority_label) 30 | 31 | return float(correct) / len(y_pred) 32 | 33 | -------------------------------------------------------------------------------- /mnist/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.examples.tutorials.mnist import input_data 3 | from sklearn.decomposition import PCA 4 | from sklearn.cluster import KMeans 5 | mnist = input_data.read_data_sets('./data/mnist') 6 | 7 | 8 | class DataSampler(object): 9 | def __init__(self): 10 | self.shape = [28, 28, 1] 11 | 12 | def train(self, batch_size, label=False): 13 | if label: 14 | return mnist.train.next_batch(batch_size) 15 | else: 16 | return mnist.train.next_batch(batch_size)[0] 17 | 18 | def test(self): 19 | return mnist.test.images, mnist.test.labels 20 | 21 | def validation(self): 22 | return mnist.validation.images, mnist.validation.labels 23 | 24 | 25 | def data2img(self, data): 26 | return np.reshape(data, [data.shape[0]] + self.shape) 27 | 28 | def load_all(self): 29 | 30 | X_train = mnist.train.images 31 | X_val = mnist.validation.images 32 | X_test = mnist.test.images 33 | 34 | Y_train = mnist.train.labels 35 | Y_val = mnist.validation.labels 36 | Y_test = mnist.test.labels 37 | 38 | X = np.concatenate((X_train, X_val, X_test)) 39 | Y = np.concatenate((Y_train, Y_val, Y_test)) 40 | 41 | return X, Y.flatten() 42 | 43 | 44 | class NoiseSampler(object): 45 | 46 | def __init__(self, z_dim = 100, mode='uniform'): 47 | self.mode = mode 48 | self.z_dim = z_dim 49 | self.K = 10 50 | 51 | if self.mode == 'mix_gauss': 52 | self.mu_mat = (1.0) * np.eye(self.K, self.z_dim) 53 | self.sig = 0.1 54 | 55 | elif self.mode == 'one_hot': 56 | self.mu_mat = (1.0) * np.eye(self.K) 57 | self.sig = 0.10 58 | 59 | 60 | elif self.mode == 'pca_kmeans': 61 | 62 | data_x = mnist.train.images 63 | feature_mean = np.mean(data_x, axis = 0) 64 | data_x -= feature_mean 65 | data_embed = PCA(n_components=self.z_dim, random_state=0).fit_transform(data_x) 66 | data_x += feature_mean 67 | kmeans = KMeans(n_clusters=self.K, random_state=0) 68 | kmeans.fit(data_embed) 69 | self.mu_mat = kmeans.cluster_centers_ 70 | shift = np.min(self.mu_mat) 71 | scale = np.max(self.mu_mat - shift) 72 | self.mu_mat = (self.mu_mat - shift)/scale 73 | self.sig = 0.15 74 | 75 | 76 | def __call__(self, batch_size, z_dim): 77 | if self.mode == 'uniform': 78 | return np.random.uniform(-1.0, 1.0, [batch_size, z_dim]) 79 | elif self.mode == 'normal': 80 | return 0.15*np.random.randn(batch_size, z_dim) 81 | elif self.mode == 'mix_gauss': 82 | k = np.random.randint(low = 0, high = self.K, size=batch_size) 83 | return self.sig*np.random.randn(batch_size, z_dim) + self.mu_mat[k] 84 | elif self.mode == 'pca_kmeans': 85 | k = np.random.randint(low=0, high=self.K, size=batch_size) 86 | return self.sig * np.random.randn(batch_size, z_dim) + self.mu_mat[k] 87 | elif self.mode == 'one_hot': 88 | k = np.random.randint(low=0, high=self.K, 
size=batch_size) 89 | return np.hstack((self.sig * np.random.randn(batch_size, z_dim-self.K), self.mu_mat[k])) 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /mnist/clus_wgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim = 784): 12 | self.x_dim = x_dim 13 | self.name = 'mnist/clus_wgan/d_net' 14 | 15 | def __call__(self, x, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | bs = tf.shape(x)[0] 20 | 21 | x = tf.reshape(x, [bs, 28, 28, 1]) 22 | conv1 = tc.layers.convolution2d( 23 | x, 64, [4, 4], [2, 2], 24 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 25 | activation_fn=tf.identity 26 | ) 27 | conv1 = leaky_relu(conv1) 28 | conv2 = tc.layers.convolution2d( 29 | conv1, 128, [4, 4], [2, 2], 30 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 31 | activation_fn=tf.identity 32 | ) 33 | conv2 = leaky_relu(conv2) 34 | conv2 = tcl.flatten(conv2) 35 | 36 | fc1 = tc.layers.fully_connected( 37 | conv2, 1024, 38 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 39 | activation_fn=tf.identity 40 | ) 41 | fc1 = leaky_relu(fc1) 42 | fc2 = tc.layers.fully_connected(fc1, 1, activation_fn=tf.identity) 43 | return fc2 44 | 45 | @property 46 | def vars(self): 47 | return [var for var in tf.global_variables() if self.name in var.name] 48 | 49 | 50 | class Generator(object): 51 | def __init__(self, z_dim = 10, x_dim = 784): 52 | self.z_dim = z_dim 53 | self.x_dim = x_dim 54 | self.name = 'mnist/clus_wgan/g_net' 55 | 56 | def __call__(self, z): 57 | with tf.variable_scope(self.name) as vs: 58 | bs = tf.shape(z)[0] 59 | fc1 = tc.layers.fully_connected( 60 | z, 1024, 61 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 62 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 63 | activation_fn=tf.identity 64 | ) 65 | fc1 = tc.layers.batch_norm(fc1) 66 | fc1 = tf.nn.relu(fc1) 67 | fc2 = tc.layers.fully_connected( 68 | fc1, 7 * 7 * 128, 69 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 70 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 71 | activation_fn=tf.identity 72 | ) 73 | fc2 = tf.reshape(fc2, tf.stack([bs, 7, 7, 128])) 74 | fc2 = tc.layers.batch_norm(fc2) 75 | fc2 = tf.nn.relu(fc2) 76 | conv1 = tc.layers.convolution2d_transpose( 77 | fc2, 64, [4, 4], [2, 2], 78 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 79 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 80 | activation_fn=tf.identity 81 | ) 82 | conv1 = tc.layers.batch_norm(conv1) 83 | conv1 = tf.nn.relu(conv1) 84 | conv2 = tc.layers.convolution2d_transpose( 85 | conv1, 1, [4, 4], [2, 2], 86 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 87 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 88 | activation_fn=tf.sigmoid 89 | ) 90 | conv2 = tf.reshape(conv2, tf.stack([bs, self.x_dim])) 91 | return conv2 92 | 93 | @property 94 | def vars(self): 95 | return [var for var in tf.global_variables() if self.name in var.name] 96 | 97 | 98 | class Encoder(object): 99 | def __init__(self, z_dim = 10, dim_gen = 10, x_dim = 784): 100 | self.z_dim = z_dim 101 | self.dim_gen = dim_gen 102 | self.x_dim = x_dim 
103 | self.name = 'mnist/clus_wgan/enc_net' 104 | 105 | def __call__(self, x, reuse=True): 106 | 107 | with tf.variable_scope(self.name) as vs: 108 | if reuse: 109 | vs.reuse_variables() 110 | bs = tf.shape(x)[0] 111 | x = tf.reshape(x, [bs, 28, 28, 1]) 112 | conv1 = tc.layers.convolution2d( 113 | x, 64, [4, 4], [2, 2], 114 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 115 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 116 | activation_fn=tf.identity 117 | ) 118 | conv1 = leaky_relu(conv1) 119 | conv2 = tc.layers.convolution2d( 120 | conv1, 128, [4, 4], [2, 2], 121 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 122 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 123 | activation_fn=tf.identity 124 | ) 125 | conv2 = leaky_relu(conv2) 126 | conv2 = tcl.flatten(conv2) 127 | fc1 = tc.layers.fully_connected( 128 | conv2, 1024, 129 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 130 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 131 | activation_fn=tf.identity 132 | ) 133 | fc1 = leaky_relu(fc1) 134 | fc2 = tc.layers.fully_connected(fc1, self.z_dim, activation_fn=tf.identity) 135 | logits = fc2[:, self.dim_gen:] 136 | y = tf.nn.softmax(logits) 137 | return fc2[:, 0:self.dim_gen], y, logits 138 | 139 | 140 | @property 141 | def vars(self): 142 | return [var for var in tf.global_variables() if self.name in var.name] 143 | -------------------------------------------------------------------------------- /mnist/dcgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim = 784): 12 | self.x_dim = x_dim 13 | self.name = 'mnist/dcgan/d_net' 14 | 15 | def __call__(self, x, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | bs = tf.shape(x)[0] 20 | x = tf.reshape(x, [bs, 28, 28, 1]) 21 | conv1 = tc.layers.convolution2d( 22 | x, 64, [4, 4], [2, 2], 23 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 24 | activation_fn=tf.identity 25 | ) 26 | conv1 = leaky_relu(conv1) 27 | conv2 = tc.layers.convolution2d( 28 | conv1, 128, [4, 4], [2, 2], 29 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 30 | activation_fn=tf.identity 31 | ) 32 | conv2 = leaky_relu(conv2) 33 | conv2 = tcl.flatten(conv2) 34 | fc1 = tc.layers.fully_connected( 35 | conv2, 1024, 36 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 37 | activation_fn=tf.identity 38 | ) 39 | fc1 = leaky_relu(fc1) 40 | fc2 = tc.layers.fully_connected(fc1, 1, activation_fn=tf.identity) 41 | return fc2 42 | 43 | @property 44 | def vars(self): 45 | return [var for var in tf.global_variables() if self.name in var.name] 46 | 47 | 48 | class Generator(object): 49 | def __init__(self, z_dim = 10, x_dim = 784): 50 | self.z_dim = z_dim 51 | self.x_dim = x_dim 52 | self.name = 'mnist/dcgan/g_net' 53 | 54 | def __call__(self, z): 55 | with tf.variable_scope(self.name) as vs: 56 | bs = tf.shape(z)[0] 57 | fc1 = tc.layers.fully_connected( 58 | z, 1024, 59 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 60 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 61 | activation_fn=tf.identity 62 | ) 63 | fc1 = tc.layers.batch_norm(fc1) 64 | fc1 = tf.nn.relu(fc1) 65 | fc2 = 
tc.layers.fully_connected( 66 | fc1, 7 * 7 * 128, 67 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 68 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 69 | activation_fn=tf.identity 70 | ) 71 | fc2 = tf.reshape(fc2, tf.stack([bs, 7, 7, 128])) 72 | fc2 = tc.layers.batch_norm(fc2) 73 | fc2 = tf.nn.relu(fc2) 74 | conv1 = tc.layers.convolution2d_transpose( 75 | fc2, 64, [4, 4], [2, 2], 76 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 77 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 78 | activation_fn=tf.identity 79 | ) 80 | conv1 = tc.layers.batch_norm(conv1) 81 | conv1 = tf.nn.relu(conv1) 82 | conv2 = tc.layers.convolution2d_transpose( 83 | conv1, 1, [4, 4], [2, 2], 84 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 85 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 86 | activation_fn=tf.sigmoid 87 | ) 88 | conv2 = tf.reshape(conv2, tf.stack([bs, 784])) 89 | return conv2 90 | 91 | @property 92 | def vars(self): 93 | return [var for var in tf.global_variables() if self.name in var.name] 94 | -------------------------------------------------------------------------------- /pen_ClusterGAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | import numpy as np 9 | from sklearn.cluster import KMeans 10 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 11 | 12 | import metric 13 | import util 14 | 15 | tf.set_random_seed(0) 16 | 17 | class clusGAN(object): 18 | def __init__(self, g_net, d_net, enc_net, x_sampler, z_sampler, data, model, sampler, 19 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label): 20 | self.model = model 21 | self.data = data 22 | self.sampler = sampler 23 | self.g_net = g_net 24 | self.d_net = d_net 25 | self.enc_net = enc_net 26 | self.x_sampler = x_sampler 27 | self.z_sampler = z_sampler 28 | self.num_classes = num_classes 29 | self.dim_gen = dim_gen 30 | self.n_cat = n_cat 31 | self.batch_size = batch_size 32 | scale = 10.0 33 | self.beta_cycle_gen = beta_cycle_gen 34 | self.beta_cycle_label = beta_cycle_label 35 | 36 | 37 | self.x_dim = self.d_net.x_dim 38 | self.z_dim = self.g_net.z_dim 39 | 40 | 41 | self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x') 42 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 43 | 44 | self.z_gen = self.z[:,0:self.dim_gen] 45 | self.z_hot = self.z[:,self.dim_gen:] 46 | 47 | self.x_ = self.g_net(self.z) 48 | self.z_enc_gen, self.z_enc_label, self.z_enc_logits = self.enc_net(self.x_, reuse=False) 49 | self.z_infer_gen, self.z_infer_label, self.z_infer_logits = self.enc_net(self.x) 50 | 51 | 52 | self.d = self.d_net(self.x, reuse=False) 53 | self.d_ = self.d_net(self.x_) 54 | 55 | 56 | self.g_loss = tf.reduce_mean(self.d_) + \ 57 | self.beta_cycle_gen * tf.reduce_mean(tf.square(self.z_gen - self.z_enc_gen)) +\ 58 | self.beta_cycle_label * tf.reduce_mean( 59 | tf.nn.softmax_cross_entropy_with_logits(logits=self.z_enc_logits,labels=self.z_hot)) 60 | 61 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 62 | 63 | epsilon = tf.random_uniform([], 0.0, 1.0) 64 | x_hat = epsilon * self.x + (1 - epsilon) * self.x_ 65 | d_hat = self.d_net(x_hat) 66 | 67 | ddx = tf.gradients(d_hat, x_hat)[0] 68 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 69 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * 
scale) 70 | 71 | self.d_loss = self.d_loss + ddx 72 | 73 | 74 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 75 | .minimize(self.d_loss, var_list=self.d_net.vars) 76 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 77 | .minimize(self.g_loss, var_list=[self.g_net.vars, self.enc_net.vars]) 78 | 79 | 80 | self.saver = tf.train.Saver() 81 | 82 | run_config = tf.ConfigProto() 83 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 84 | run_config.gpu_options.allow_growth = True 85 | self.sess = tf.Session(config=run_config) 86 | 87 | def train(self, num_batches=500000): 88 | 89 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 90 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 91 | 92 | batch_size = self.batch_size 93 | self.sess.run(tf.global_variables_initializer()) 94 | start_time = time.time() 95 | print( 96 | 'Training {} on {}, sampler = {}, z = {} dimension, beta_n = {}, beta_c = {}'. 97 | format(self.model, self.data, self.sampler, self.z_dim, self.beta_cycle_gen, self.beta_cycle_label)) 98 | 99 | 100 | for t in range(0, num_batches): 101 | d_iters = 5 102 | 103 | for _ in range(0, d_iters): 104 | bx = self.x_sampler.train(batch_size) 105 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 106 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 107 | 108 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 109 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 110 | 111 | if (t+1) % 100 == 0: 112 | bx = self.x_sampler.train(batch_size) 113 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 114 | 115 | 116 | d_loss = self.sess.run( 117 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 118 | ) 119 | g_loss = self.sess.run( 120 | self.g_loss, feed_dict={self.z: bz} 121 | ) 122 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 123 | (t+1, time.time() - start_time, d_loss, g_loss)) 124 | 125 | self.recon_enc(timestamp, val=True) 126 | self.save(timestamp) 127 | 128 | def save(self, timestamp): 129 | 130 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 131 | self.z_dim, self.beta_cycle_label, 132 | self.beta_cycle_gen) 133 | 134 | if not os.path.exists(checkpoint_dir): 135 | os.makedirs(checkpoint_dir) 136 | 137 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 138 | 139 | def load(self, pre_trained = False, timestamp = ''): 140 | 141 | if pre_trained == True: 142 | print('Loading Pre-trained Model...') 143 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler, 144 | self.z_dim, self.beta_cycle_label, self.beta_cycle_gen) 145 | else: 146 | if timestamp == '': 147 | print('Best Timestamp not provided. 
Abort !') 148 | checkpoint_dir = '' 149 | else: 150 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler, 151 | self.z_dim, self.beta_cycle_label, 152 | self.beta_cycle_gen) 153 | 154 | 155 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 156 | print('Restored model weights.') 157 | 158 | def _gen_samples(self, num_samples): 159 | 160 | batch_size = self.batch_size 161 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 162 | fake_samples = self.sess.run(self.x_, feed_dict = {self.z : bz}) 163 | for t in range(num_samples // batch_size): 164 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 165 | samp = self.sess.run(self.x_, feed_dict = {self.z : bz}) 166 | fake_samples = np.vstack((fake_samples, samp)) 167 | 168 | print(' Generated {} samples .'.format(fake_samples.shape[0])) 169 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'. 170 | format(self.data, self.model, self.sampler, self.num_classes), fake_samples) 171 | 172 | def recon_enc(self, timestamp, val = True): 173 | 174 | if val: 175 | data_recon, label_recon = self.x_sampler.validation() 176 | else: 177 | data_recon, label_recon = self.x_sampler.test() 178 | #data_recon, label_recon = self.x_sampler.load_all() 179 | 180 | num_pts_to_plot = data_recon.shape[0] 181 | recon_batch_size = self.batch_size 182 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 183 | 184 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 185 | for b in range(int(np.ceil(num_pts_to_plot*1.0 / recon_batch_size))): 186 | 187 | if (b+1)*recon_batch_size > num_pts_to_plot: 188 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 189 | else: 190 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 191 | xtrue = data_recon[pt_indx, :] 192 | 193 | zhats_gen, zhats_label = self.sess.run([self.z_infer_gen, self.z_infer_label], feed_dict={self.x : xtrue}) 194 | latent[pt_indx, :] = np.concatenate((zhats_gen, zhats_label), axis=1) 195 | 196 | 197 | if self.beta_cycle_gen == 0: 198 | self._eval_cluster(latent[:, self.dim_gen:], label_recon, timestamp, val) 199 | else: 200 | self._eval_cluster(latent, label_recon, timestamp, val) 201 | 202 | 203 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 204 | 205 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 206 | labels_pred = km.labels_ 207 | 208 | purity = metric.compute_purity(labels_pred, labels_true) 209 | ari = adjusted_rand_score(labels_true, labels_pred) 210 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 211 | 212 | if val: 213 | data_split = 'Validation' 214 | else: 215 | data_split = 'Test' 216 | #data_split = 'All' 217 | 218 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_label = {}, beta_gen = {} ' 219 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen)) 220 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 221 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 222 | 223 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 224 | f.write('{}, {} : K = {}, z_dim = {}, beta_label = {}, beta_gen = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n' 225 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, 226 | 
self.sampler, purity, nmi, ari)) 227 | f.flush() 228 | 229 | 230 | if __name__ == '__main__': 231 | parser = argparse.ArgumentParser('') 232 | parser.add_argument('--data', type=str, default='pendigit') 233 | parser.add_argument('--model', type=str, default='clus_wgan') 234 | parser.add_argument('--sampler', type=str, default='one_hot') 235 | parser.add_argument('--K', type=int, default=10) 236 | parser.add_argument('--dz', type=int, default=5) 237 | parser.add_argument('--bs', type=int, default=64) 238 | parser.add_argument('--beta_n', type=float, default=10.0) 239 | parser.add_argument('--beta_c', type=float, default=10.0) 240 | parser.add_argument('--timestamp', type=str, default='') 241 | parser.add_argument('--train', type=str, default='False') 242 | args = parser.parse_args() 243 | data = importlib.import_module(args.data) 244 | model = importlib.import_module(args.data + '.' + args.model) 245 | 246 | num_classes = args.K 247 | dim_gen = args.dz 248 | n_cat = 1 249 | batch_size = args.bs 250 | beta_cycle_gen = args.beta_n 251 | beta_cycle_label = args.beta_c 252 | timestamp = args.timestamp 253 | 254 | z_dim = dim_gen + num_classes * n_cat 255 | d_net = model.Discriminator() 256 | g_net = model.Generator(z_dim=z_dim) 257 | enc_net = model.Encoder(z_dim=z_dim, dim_gen = dim_gen) 258 | xs = data.DataSampler() 259 | zs = util.sample_Z 260 | 261 | cl_gan = clusGAN(g_net, d_net, enc_net, xs, zs, args.data, args.model, args.sampler, 262 | num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label) 263 | if args.train == 'True': 264 | cl_gan.train() 265 | else: 266 | 267 | print('Attempting to Restore Model ...') 268 | if timestamp == '': 269 | cl_gan.load(pre_trained=True) 270 | timestamp = 'pre-trained' 271 | else: 272 | cl_gan.load(pre_trained=False, timestamp=timestamp) 273 | 274 | cl_gan.recon_enc(timestamp, val=False) 275 | -------------------------------------------------------------------------------- /pen_wgan_gp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import dateutil.tz 4 | import datetime 5 | import argparse 6 | import importlib 7 | import tensorflow as tf 8 | import numpy as np 9 | from sklearn.cluster import KMeans 10 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 11 | 12 | import metric 13 | import util 14 | 15 | tf.set_random_seed(0) 16 | 17 | class WassersteinGAN(object): 18 | def __init__(self, g_net, d_net, x_sampler, z_sampler, data, model, sampler, num_classes, dim_gen, n_cat, 19 | batch_size, beta_reg): 20 | self.model = model 21 | self.data = data 22 | self.sampler = sampler 23 | self.g_net = g_net 24 | self.d_net = d_net 25 | self.x_sampler = x_sampler 26 | self.z_sampler = z_sampler 27 | self.num_classes = num_classes 28 | self.dim_gen = dim_gen 29 | self.n_cat = n_cat 30 | self.batch_size = batch_size 31 | scale = 10.0 32 | self.beta_reg = beta_reg 33 | 34 | self.x_dim = self.d_net.x_dim 35 | self.z_dim = self.g_net.z_dim 36 | 37 | if sampler == 'mul_cat': 38 | self.clip_lim = [-0.6, 0.6] 39 | elif sampler == 'one_hot': 40 | self.clip_lim = [-0.6, 0.6] 41 | elif sampler == 'clus': 42 | self.clip_lim = [-1.0, 1.0] 43 | elif sampler == 'uniform': 44 | self.clip_lim = [-1.0, 1.0] 45 | elif sampler == 'normal': 46 | self.clip_lim = [-1.0, 1.0] 47 | elif sampler == 'mix_gauss': 48 | self.clip_lim = [-1.0, 2.0] 49 | elif sampler == 'pca_kmeans': 50 | self.clip_lim = [-2.0, 2.0] 51 | 52 | self.x = tf.placeholder(tf.float32, [None, 
self.x_dim], name='x') 53 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 54 | 55 | self.x_ = self.g_net(self.z) 56 | 57 | self.d = self.d_net(self.x, reuse=False) 58 | self.d_ = self.d_net(self.x_) 59 | 60 | self.g_loss = tf.reduce_mean(self.d_) 61 | self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_) 62 | 63 | epsilon = tf.random_uniform([], 0.0, 1.0) 64 | x_hat = epsilon * self.x + (1 - epsilon) * self.x_ 65 | d_hat = self.d_net(x_hat) 66 | 67 | ddx = tf.gradients(d_hat, x_hat)[0] 68 | ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1)) 69 | ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale) 70 | 71 | self.d_loss = self.d_loss + ddx 72 | 73 | self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 74 | .minimize(self.d_loss, var_list=self.d_net.vars) 75 | self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \ 76 | .minimize(self.g_loss, var_list=self.g_net.vars) 77 | 78 | self.saver = tf.train.Saver() 79 | 80 | run_config = tf.ConfigProto() 81 | run_config.gpu_options.per_process_gpu_memory_fraction = 1.0 82 | run_config.gpu_options.allow_growth = True 83 | self.sess = tf.Session(config=run_config) 84 | 85 | def train(self, num_batches=50000): 86 | 87 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 88 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 89 | 90 | batch_size = self.batch_size 91 | self.sess.run(tf.global_variables_initializer()) 92 | start_time = time.time() 93 | print('Training {} on {}, sampler = {}, z = {} dimension'.format(self.model, self.data, self.sampler, self.z_dim)) 94 | 95 | 96 | for t in range(0, num_batches): 97 | d_iters = 5 98 | 99 | for _ in range(0, d_iters): 100 | bx = self.x_sampler.train(batch_size) 101 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 102 | self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz}) 103 | 104 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 105 | self.sess.run(self.g_adam, feed_dict={self.z: bz}) 106 | 107 | if (t+1) % 100 == 0: 108 | bx = self.x_sampler.train(batch_size) 109 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 110 | 111 | d_loss = self.sess.run( 112 | self.d_loss, feed_dict={self.x: bx, self.z: bz} 113 | ) 114 | g_loss = self.sess.run( 115 | self.g_loss, feed_dict={self.z: bz} 116 | ) 117 | print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' % 118 | (t+1, time.time() - start_time, d_loss, g_loss)) 119 | 120 | self.recon_enc(timestamp, val=True) 121 | self.save(timestamp) 122 | 123 | def save(self, timestamp): 124 | 125 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, timestamp, self.model, self.sampler, 126 | self.z_dim) 127 | 128 | if not os.path.exists(checkpoint_dir): 129 | os.makedirs(checkpoint_dir) 130 | 131 | self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 132 | 133 | def load(self, pre_trained=False, timestamp=''): 134 | 135 | if pre_trained == True: 136 | print('Loading Pre-trained Model...') 137 | checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}'.format(self.data, self.model, self.sampler, self.z_dim) 138 | else: 139 | if timestamp == '': 140 | print('Best Timestamp not provided !') 141 | checkpoint_dir = '' 142 | else: 143 | checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}'.format(self.data, timestamp, self.model, self.sampler, 144 | self.z_dim) 145 | 146 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt')) 147 | 
print('Restored model weights.') 148 | 149 | def _gen_samples(self, num_samples): 150 | 151 | batch_size = self.batch_size 152 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 153 | fake_samples = self.sess.run(self.x_, feed_dict = {self.z : bz}) 154 | for t in range(num_samples // batch_size): 155 | bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat) 156 | samp = self.sess.run(self.x_, feed_dict = {self.z : bz}) 157 | fake_samples = np.vstack((fake_samples, samp)) 158 | 159 | print(' Generated {} samples .'.format(fake_samples.shape[0])) 160 | np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_samples) 161 | 162 | 163 | def recon_enc(self, timestamp, val = True): 164 | 165 | if val: 166 | data_recon, label_recon = self.x_sampler.validation() 167 | else: 168 | data_recon, label_recon = self.x_sampler.test() 169 | #data_recon, label_recon = self.x_sampler.load_all() 170 | 171 | num_pts_to_plot = data_recon.shape[0] 172 | recon_batch_size = 1000 173 | latent = np.zeros(shape=(num_pts_to_plot, self.z_dim)) 174 | clip_lim = self.clip_lim 175 | 176 | print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape)) 177 | 178 | # Regularized Reconstruction objective 179 | 180 | self.recon_reg_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1) + \ 181 | self.beta_reg * tf.reduce_mean(tf.square(self.z[:, 0:self.dim_gen]), 1) 182 | self.compute_reg_grad = tf.gradients(self.recon_reg_loss, self.z) 183 | 184 | for b in range(int(np.ceil(num_pts_to_plot*1.0 / recon_batch_size))): 185 | 186 | if (b+1)*recon_batch_size > num_pts_to_plot: 187 | pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot) 188 | else: 189 | pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size) 190 | xtrue = data_recon[pt_indx, :] 191 | 192 | num_backprop_iter = 5000 193 | num_restarts = self.num_classes 194 | seed_labels = np.tile(np.arange(self.num_classes), int(np.ceil(len(pt_indx) * 1.0 / self.num_classes))) 195 | seed_labels = seed_labels[0:len(pt_indx)] 196 | best_zhats = np.zeros(shape=(len(pt_indx), self.z_dim)) 197 | best_loss = np.inf * np.ones(len(pt_indx)) 198 | mu_mat = 1.0 * np.eye(self.num_classes) 199 | alg = 'adam' 200 | for t in range(num_restarts): 201 | print('Backprop Decoding [{} / {} ] ...'.format(t + 1, num_restarts)) 202 | 203 | if self.sampler == 'one_hot': 204 | label_index = (seed_labels + t) % self.num_classes 205 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 206 | n_cat=1, label_index=label_index) 207 | elif self.sampler == 'mul_cat': 208 | label_index = (seed_labels + t) % self.num_classes 209 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 210 | n_cat=self.n_cat, 211 | label_index=label_index) 212 | elif self.sampler == 'mix_gauss': 213 | label_index = (seed_labels + t) % self.num_classes 214 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler, num_class=self.num_classes, 215 | n_cat=0, label_index=label_index) 216 | else: 217 | zhats = util.sample_Z(len(pt_indx), self.z_dim, self.sampler) 218 | if alg == 'adam': 219 | beta1 = 0.9 220 | beta2 = 0.999 221 | lr = 0.01 222 | eps = 1e-8 223 | m = 0 224 | v = 0 225 | elif alg == 'grad_descent': 226 | lr = 1.00 227 | 228 | for i in range(num_backprop_iter): 229 | 230 | L, g = self.sess.run([self.recon_reg_loss, self.compute_reg_grad], 231 | feed_dict={self.z: zhats, self.x: 
xtrue}) 232 | 233 | if alg == 'adam': 234 | m_prev = np.copy(m) 235 | v_prev = np.copy(v) 236 | m = beta1 * m_prev + (1 - beta1) * g[0] 237 | v = beta2 * v_prev + (1 - beta2) * np.multiply(g[0], g[0]) 238 | m_hat = m / (1 - beta1 ** (i + 1)) 239 | v_hat = v / (1 - beta2 ** (i + 1)) 240 | zhats += - np.true_divide(lr * m_hat, (np.sqrt(v_hat) + eps)) 241 | 242 | elif alg == 'grad_descent': 243 | zhats += - lr * g[0] 244 | 245 | zhats = np.clip(zhats, a_min=clip_lim[0], a_max=clip_lim[1]) 246 | 247 | if self.sampler == 'one_hot': 248 | zhats[:, -self.num_classes:] = mu_mat[label_index, :] 249 | elif self.sampler == 'mul_hot': 250 | zhats[:, self.dim_gen:] = np.tile(mu_mat[label_index, :], (1, self.n_cat)) 251 | 252 | change_index = best_loss > L 253 | best_zhats[change_index, :] = zhats[change_index, :] 254 | best_loss[change_index] = L[change_index] 255 | 256 | latent[pt_indx, :] = best_zhats 257 | print(' [{} / {} ] ...'.format(pt_indx[-1]+1, num_pts_to_plot)) 258 | 259 | 260 | self._eval_cluster(latent, label_recon, timestamp, val) 261 | 262 | def _eval_cluster(self, latent_rep, labels_true, timestamp, val): 263 | 264 | km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep) 265 | labels_pred = km.labels_ 266 | 267 | purity = metric.compute_purity(labels_pred, labels_true) 268 | ari = adjusted_rand_score(labels_true, labels_pred) 269 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 270 | 271 | if val: 272 | data_split = 'Validation' 273 | else: 274 | data_split = 'Test' 275 | #data_split = 'All' 276 | 277 | print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_reg = {}' 278 | .format(self.data, self.model, self.sampler, self.z_dim, self.beta_reg)) 279 | print(' #Points = {}, K = {}, Purity = {}, NMI = {}, ARI = {}, ' 280 | .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari)) 281 | 282 | with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f: 283 | f.write('{}, {} : K = {}, z_dim = {}, beta_reg = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n' 284 | .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_reg, 285 | self.sampler, purity, nmi, ari)) 286 | f.flush() 287 | 288 | 289 | if __name__ == '__main__': 290 | parser = argparse.ArgumentParser('') 291 | parser.add_argument('--data', type=str, default='pendigit') 292 | parser.add_argument('--model', type=str, default='mlp') 293 | parser.add_argument('--sampler', type=str, default='one_hot') 294 | parser.add_argument('--K', type=int, default=10) 295 | parser.add_argument('--dz', type=int, default=5) 296 | parser.add_argument('--bs', type=int, default=64) 297 | parser.add_argument('--beta_reg', type=float, default=10.0) 298 | parser.add_argument('--timestamp', type=str, default='') 299 | parser.add_argument('--train', type=str, default='False') 300 | 301 | args = parser.parse_args() 302 | data = importlib.import_module(args.data) 303 | model = importlib.import_module(args.data + '.' 
+ args.model) 304 | 305 | num_classes = args.K 306 | dim_gen = args.dz 307 | n_cat = 1 308 | batch_size = args.bs 309 | beta_reg = args.beta_reg 310 | timestamp = args.timestamp 311 | 312 | z_dim = dim_gen + num_classes * n_cat 313 | d_net = model.Discriminator() 314 | g_net = model.Generator(z_dim=z_dim) 315 | xs = data.DataSampler() 316 | zs = util.sample_Z 317 | 318 | wgan = WassersteinGAN(g_net, d_net, xs, zs, args.data, args.model, args.sampler, 319 | num_classes, dim_gen, n_cat, batch_size, beta_reg) 320 | 321 | if args.train == 'True': 322 | wgan.train() 323 | else: 324 | 325 | print('Attempting to Restore Model ...') 326 | if timestamp == '': 327 | wgan.load(pre_trained=True) 328 | timestamp = 'pre-trained' 329 | else: 330 | wgan.load(pre_trained=False, timestamp=timestamp) 331 | 332 | wgan.recon_enc(timestamp, val=False) 333 | -------------------------------------------------------------------------------- /pendigit/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # Original Labels are from 0 to 9. 3 | 4 | class DataSampler(object): 5 | def __init__(self): 6 | finp_tr = './data/pendigit/pendigits.tra.txt' 7 | finp_tes = './data/pendigit/pendigits.tes.txt' 8 | data_tr = np.loadtxt(finp_tr, delimiter=',') 9 | self.X_train = data_tr[:, 0:16] 10 | self.X_train /= 100.0 11 | self.Y_train = data_tr[:, -1].astype(int) 12 | 13 | data_tes = np.loadtxt(finp_tes, delimiter=',') 14 | self.X_test = data_tes[:, 0:16] 15 | self.X_test /= 100.0 16 | self.Y_test = data_tes[:, -1].astype(int) 17 | 18 | self.X = np.concatenate((self.X_train, self.X_test)) 19 | self.Y = np.concatenate((self.Y_train, self.Y_test)).astype(int) 20 | 21 | self.train_size = self.X_train.shape[0] 22 | self.test_size = self.X_test.shape[0] 23 | self.data_size = self.X.shape[0] 24 | 25 | def train(self, batch_size, label=False): 26 | indx = np.random.randint(low=0, high=self.train_size, size=batch_size) 27 | 28 | if label: 29 | return self.X_train[indx, :], self.Y_train[indx].flatten() 30 | else: 31 | return self.X_train[indx, :] 32 | 33 | def validation(self): 34 | return self.X_train[-1000:, :], self.Y_train[-1000:].flatten() 35 | 36 | def test(self): 37 | return self.X_test, self.Y_test.flatten() 38 | 39 | 40 | def load_all(self): 41 | return self.X, self.Y.flatten() 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /pendigit/clus_wgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim=16): 12 | self.x_dim = x_dim 13 | self.name = 'pendigit/clus_wgan/d_net' 14 | 15 | def __call__(self, x, keep=1.0, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | 20 | fc1 = tc.layers.fully_connected( 21 | x, 256, 22 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 23 | activation_fn=tf.identity 24 | ) 25 | 26 | fc1 = leaky_relu(tc.layers.batch_norm(fc1)) 27 | 28 | fc2 = tc.layers.fully_connected( 29 | fc1, 256, 30 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 31 | activation_fn=tf.identity 32 | ) 33 | 34 | fc2 = leaky_relu(tc.layers.batch_norm(fc2)) 35 | 36 | fc3 = 
tc.layers.fully_connected(fc2, 1, 37 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 38 | activation_fn=tf.identity 39 | ) 40 | return fc3 41 | 42 | @property 43 | def vars(self): 44 | return [var for var in tf.global_variables() if self.name in var.name] 45 | 46 | 47 | class Generator(object): 48 | def __init__(self, z_dim=15, x_dim=16): 49 | self.z_dim = z_dim 50 | self.x_dim = x_dim 51 | self.name = 'pendigit/clus_wgan/g_net' 52 | 53 | def __call__(self, z, keep=1.0): 54 | with tf.variable_scope(self.name) as vs: 55 | fc1 = tcl.fully_connected( 56 | z, 256, 57 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 58 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 59 | activation_fn=tf.identity 60 | ) 61 | 62 | fc1 = leaky_relu(tc.layers.batch_norm(fc1)) 63 | 64 | fc2 = tcl.fully_connected( 65 | fc1, 256, 66 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 67 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 68 | activation_fn=tf.identity 69 | ) 70 | 71 | fc2 = leaky_relu(tc.layers.batch_norm(fc2)) 72 | 73 | fc3 = tc.layers.fully_connected( 74 | fc2, self.x_dim, 75 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 76 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 77 | activation_fn=tf.sigmoid 78 | ) 79 | return fc3 80 | 81 | @property 82 | def vars(self): 83 | return [var for var in tf.global_variables() if self.name in var.name] 84 | 85 | 86 | class Encoder(object): 87 | def __init__(self, z_dim=15, dim_gen=5, x_dim=16): 88 | self.z_dim = z_dim 89 | self.dim_gen = dim_gen 90 | self.x_dim = x_dim 91 | self.name = 'pendigit/clus_wgan/enc_net' 92 | 93 | def __call__(self, x, keep=1.0, reuse=True): 94 | with tf.variable_scope(self.name) as vs: 95 | if reuse: 96 | vs.reuse_variables() 97 | 98 | fc1 = tc.layers.fully_connected( 99 | x, 256, 100 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 101 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 102 | activation_fn=tf.identity 103 | ) 104 | 105 | fc1 = leaky_relu(tc.layers.batch_norm(fc1)) 106 | 107 | fc2 = tc.layers.fully_connected( 108 | fc1, 256, 109 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 110 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 111 | activation_fn=tf.identity 112 | ) 113 | 114 | fc2 = leaky_relu(tc.layers.batch_norm(fc2)) 115 | 116 | fc3 = tc.layers.fully_connected( 117 | fc2, self.z_dim, 118 | weights_initializer=tf.random_normal_initializer(stddev=0.02), 119 | weights_regularizer=tc.layers.l2_regularizer(2.5e-5), 120 | activation_fn=tf.identity) 121 | logits = fc3[:, self.dim_gen:] 122 | y = tf.nn.softmax(logits) 123 | return fc3[:, 0:self.dim_gen], y, logits 124 | 125 | @property 126 | def vars(self): 127 | return [var for var in tf.global_variables() if self.name in var.name] 128 | -------------------------------------------------------------------------------- /pendigit/mlp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | import tensorflow.contrib.layers as tcl 4 | 5 | 6 | def leaky_relu(x, alpha=0.2): 7 | return tf.maximum(tf.minimum(0.0, alpha * x), x) 8 | 9 | 10 | class Discriminator(object): 11 | def __init__(self, x_dim=16): 12 | self.x_dim = x_dim 13 | self.name = 'pendigit/mlp/d_net' 14 | 15 | def __call__(self, x, keep=1.0, reuse=True): 16 | with tf.variable_scope(self.name) as vs: 17 | if reuse: 18 | vs.reuse_variables() 19 | 20 | fc1 = tc.layers.fully_connected( 21 | x, 256, 
22 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
23 |                 activation_fn=tf.identity
24 |             )
25 |             fc1 = leaky_relu(tc.layers.batch_norm(fc1))
26 |
27 |             fc2 = tc.layers.fully_connected(
28 |                 fc1, 256,
29 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
30 |                 activation_fn=tf.identity
31 |             )
32 |             fc2 = leaky_relu(tc.layers.batch_norm(fc2))
33 |
34 |             fc3 = tc.layers.fully_connected(fc2, 1,
35 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
36 |                 activation_fn=tf.identity
37 |             )
38 |             return fc3
39 |
40 |     @property
41 |     def vars(self):
42 |         return [var for var in tf.global_variables() if self.name in var.name]
43 |
44 |
45 | class Generator(object):
46 |     def __init__(self, z_dim=15, x_dim=16):
47 |         self.z_dim = z_dim
48 |         self.x_dim = x_dim
49 |         self.name = 'pendigit/mlp/g_net'
50 |
51 |     def __call__(self, z, keep=1.0):
52 |         with tf.variable_scope(self.name) as vs:
53 |             fc1 = tcl.fully_connected(
54 |                 z, 256,
55 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
56 |                 weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
57 |                 activation_fn=tcl.batch_norm
58 |             )
59 |             fc1 = leaky_relu(fc1)
60 |             fc2 = tcl.fully_connected(
61 |                 fc1, 256,
62 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
63 |                 weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
64 |                 activation_fn=tcl.batch_norm
65 |             )
66 |             fc2 = leaky_relu(fc2)
67 |
68 |             fc3 = tc.layers.fully_connected(
69 |                 fc2, self.x_dim,
70 |                 weights_initializer=tf.random_normal_initializer(stddev=0.02),
71 |                 weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
72 |                 activation_fn=tf.sigmoid
73 |             )
74 |             return fc3
75 |
76 |     @property
77 |     def vars(self):
78 |         return [var for var in tf.global_variables() if self.name in var.name]
79 |

--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------

1 |
2 | import numpy as np
3 |
4 | def sample_Z(batch, z_dim, sampler='one_hot', num_class=10, n_cat=1, label_index=None):
5 |     if sampler == 'mul_cat':
6 |         if label_index is None:
7 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
8 |         return np.hstack((0.10 * np.random.randn(batch, z_dim - num_class*n_cat),
9 |                           np.tile(np.eye(num_class)[label_index], (1, n_cat))))
10 |     elif sampler == 'one_hot':
11 |         if label_index is None:
12 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
13 |         return np.hstack((0.10 * np.random.randn(batch, z_dim - num_class), np.eye(num_class)[label_index]))
14 |     elif sampler == 'uniform':
15 |         return np.random.uniform(-1., 1., size=[batch, z_dim])
16 |     elif sampler == 'normal':
17 |         return 0.15 * np.random.randn(batch, z_dim)
18 |     elif sampler == 'mix_gauss':
19 |         if label_index is None:
20 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
21 |         return (0.1 * np.random.randn(batch, z_dim) + np.eye(num_class, z_dim)[label_index])
22 |
23 |
24 | def sample_labelled_Z(batch, z_dim, sampler='one_hot', num_class=10, n_cat=1, label_index=None):
25 |
26 |     if sampler == 'mul_cat':
27 |         if label_index is None:
28 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
29 |         return label_index, np.hstack((0.10 * np.random.randn(batch, z_dim - num_class*n_cat),
30 |                                        np.tile(np.eye(num_class)[label_index], (1, n_cat))))
31 |     elif sampler == 'one_hot':
32 |         if label_index is None:
33 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
34 |         return label_index, np.hstack((0.10 * np.random.randn(batch, z_dim - num_class), np.eye(num_class)[label_index]))
35 |     elif sampler == 'mix_gauss':
36 |         if label_index is None:
37 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
38 |         return label_index, (0.1 * np.random.randn(batch, z_dim) + np.eye(num_class, z_dim)[label_index])
39 |
40 |
41 | def reshape_mnist(X):
42 |     return X.reshape(X.shape[0], 28, 28, 1)
43 |
44 |
45 | def clus_sample_Z(batch, dim_gen=20, dim_c=2, num_class=10, label_index=None):
46 |
47 |     if label_index is None:
48 |         label_index = np.random.randint(low=0, high=num_class, size=batch)
49 |     batch_mat = np.zeros((batch, num_class * dim_c))
50 |     for b in range(batch):
51 |         batch_mat[b, label_index[b] * dim_c:(label_index[b] + 1) * dim_c] = np.random.normal(loc=1.0, scale=0.05, size=(1, dim_c))
52 |     return np.hstack((0.10 * np.random.randn(batch, dim_gen), batch_mat))
53 |
54 |
55 | def clus_sample_labelled_Z(batch, dim_gen=20, dim_c=2, num_class=10, label_index=None):
56 |     if label_index is None:
57 |         label_index = np.random.randint(low=0, high=num_class, size=batch)
58 |     batch_mat = np.zeros((batch, num_class * dim_c))
59 |     for b in range(batch):
60 |         batch_mat[b, label_index[b] * dim_c:(label_index[b] + 1) * dim_c] = np.random.normal(loc=1.0, scale=0.05, size=(1, dim_c))
61 |     return label_index, np.hstack((0.10 * np.random.randn(batch, dim_gen), batch_mat))
62 |
63 |
64 |
65 | def sample_info(batch, z_dim, sampler='one_hot', num_class=10, n_cat=1, label_index=None):
66 |     if sampler == 'one_hot':
67 |         if label_index is None:
68 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
69 |         return label_index, np.hstack(
70 |             (np.random.randn(batch, z_dim - num_class), np.eye(num_class)[label_index]))
71 |     elif sampler == 'mul_cat':
72 |         if label_index is None:
73 |             label_index = np.random.randint(low=0, high=num_class, size=batch)
74 |         return label_index, np.hstack((np.random.randn(batch, z_dim - num_class*n_cat),
75 |                                        np.tile(np.eye(num_class)[label_index], (1, n_cat))))
76 |
77 |
78 | if __name__ == '__main__':
79 |
80 |     l = sample_Z(10, 22, 'mul_cat', 10, 2)
81 |     print(l)
82 |
83 |

--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------

1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('Agg')
4 |
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def split(x):
9 |     assert type(x) == int
10 |     t = int(np.floor(np.sqrt(x)))
11 |     for a in range(t, 0, -1):
12 |         if x % a == 0:
13 |             return a, x // a
14 |
15 |
16 | def grid_transform(x, size):
17 |     a, b = split(x.shape[0])
18 |     h, w, c = size[0], size[1], size[2]
19 |     x = np.reshape(x, [a, b, h, w, c])
20 |     x = np.transpose(x, [0, 2, 1, 3, 4])
21 |     x = np.reshape(x, [a * h, b * w, c])
22 |     if x.shape[2] == 1:
23 |         x = np.squeeze(x, axis=2)
24 |     return x
25 |
26 |
27 | def grid_show(fig, x, size):
28 |     ax = fig.add_subplot(111)
29 |     x = grid_transform(x, size)
30 |     if len(x.shape) > 2:
31 |         ax.imshow(x)
32 |     else:
33 |         ax.imshow(x, cmap='gray')
34 |
35 |
36 |
37 | if __name__ == '__main__':
38 |
39 |     from keras.datasets import cifar10
40 |     from scipy.misc import imsave
41 |     import pdb
42 |
43 |     (x_train, y_train), (x_test, y_test) = cifar10.load_data()
44 |
45 |     shape = x_train[0].shape
46 |     bx = x_train[0:64, :]
47 |     bx = grid_transform(bx, shape)
48 |
49 |     imsave('cifar_batch.png', bx)
50 |
51 |     pdb.set_trace()
52 |
53 |     print('Done !')
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
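
The samplers in util.py and the Encoder classes above share one latent layout: the first dim_gen columns of z are low-variance Gaussian noise and the remaining columns carry the one-hot (or replicated categorical) cluster code, which is exactly the split the Encoders undo via fc3[:, 0:self.dim_gen] and fc3[:, self.dim_gen:]. The NumPy-only sketch below is not part of the repository; the batch size and prints are illustrative, and the pendigit defaults (z_dim=15, dim_gen=5, num_class=10) are assumed.

import numpy as np

from util import sample_Z  # repository module dumped above

batch, z_dim, num_class = 4, 15, 10
dim_gen = z_dim - num_class          # 5 continuous dims, matching Encoder(z_dim=15, dim_gen=5)

z = sample_Z(batch, z_dim, sampler='one_hot', num_class=num_class)
print(z.shape)                       # (4, 15)

# First dim_gen columns: N(0, 0.1^2) noise; last num_class columns: the one-hot
# cluster code that the Encoder's softmax head is trained to recover.
z_noise, z_onehot = z[:, :dim_gen], z[:, dim_gen:]
print(z_onehot.sum(axis=1))          # each row sums to 1.0
print(z_onehot.argmax(axis=1))       # sampled cluster index per row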