├── .gitignore
├── README.md
├── aae_supervised.py
├── aae_unsupervised.py
├── data_factory.py
├── model.py
├── utils.py
└── visualize.ipynb
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
cache/*
*.pyc
*.json
*.params
.ipynb_checkpoints/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Adversarial AutoEncoder
----------------------------

[Adversarial Autoencoder [arXiv:1511.05644]](http://arxiv.org/abs/1511.05644) implemented with [MXNet](https://github.com/dmlc/mxnet).

### Requirements
* MXNet
* numpy
* matplotlib
* scikit-learn
* OpenCV

### Unsupervised Adversarial Autoencoder
Run `aae_unsupervised.py` to train the model, then set the task to `unsupervised` in `visualize.ipynb` to display the results. The target prior distribution of the 2-D latent variable can be any of `gaussian`, `gaussian_mixture`, `swiss_roll` or `uniform` (each prior can also be sampled directly from `data_factory.py`; see the example at the end of this README). In this case, no label information is used during training.

Some results:

p(z) and q(z) with z_prior set to a Gaussian distribution.

![p(z) gaussian](http://closure11.com/images/post/2016/10/gaussian_unsupervised_pz.png)
![q(z) gaussian](http://closure11.com/images/post/2016/10/gaussian_unsupervised_qz.png)

p(z) and q(z) with z_prior set to a mixture of 10 Gaussians.

![p(z) gaussian mixture](http://closure11.com/images/post/2016/10/gaussian_mixture_unsupervised_pz.png)
![q(z) gaussian mixture](http://closure11.com/images/post/2016/10/gaussian_mixture_unsupervised_qz.png)

p(z) and q(z) with z_prior set to a swiss roll distribution.

![p(z) swiss roll](http://closure11.com/images/post/2016/10/swiss_roll_unsupervised_pz.png)
![q(z) swiss roll](http://closure11.com/images/post/2016/10/swiss_roll_unsupervised_qz.png)

### Supervised Adversarial Autoencoder
Run `aae_supervised.py` to train the model, then set the task to `supervised` in `visualize.ipynb` to display the results. The target prior distribution of the 2-D latent variable can be any of `gaussian_mixture`, `swiss_roll` or `uniform`. In this case, the discriminator also receives the one-hot label of each sample, for both the prior samples and the encoded data, so each label is steered towards its own region of the prior.

Some results:

p(z), q(z) and images decoded from prior samples, with z_prior set to a mixture of 10 Gaussians.

![p(z) gaussian mixture](http://closure11.com/images/post/2016/10/gaussian_mixture_supervised_pz.png)
![q(z) gaussian mixture](http://closure11.com/images/post/2016/10/gaussian_mixture_supervised_qz.png)
![images decoded from gaussian mixture prior samples](http://closure11.com/images/post/2016/10/gaussian_mixture_supervised_output.png)

p(z) and q(z) with z_prior set to a swiss roll distribution.

![p(z) swiss roll](http://closure11.com/images/post/2016/10/swiss_roll_supervised_pz.png)
![q(z) swiss roll](http://closure11.com/images/post/2016/10/swiss_roll_supervised_qz.png)

p(z) and q(z) with z_prior set to a uniform distribution over 10 label regions.

![p(z) uniform](http://closure11.com/images/post/2016/10/uniform_supervised_pz.png)
![q(z) uniform](http://closure11.com/images/post/2016/10/uniform_supervised_qz.png)


### Semi-Supervised Adversarial Autoencoder
Not implemented yet.
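
### Sampling the Priors
All priors are implemented in `data_factory.py` and can be inspected directly. A minimal sketch (the sample count and coloring are arbitrary choices) that draws labeled samples from the swiss roll prior and scatters them:

```python
import numpy as np
import matplotlib.pyplot as plt
from data_factory import swiss_roll

labels = np.random.randint(0, 10, 1000)  # one of 10 label slices per sample
z = swiss_roll(1000, n_dim=2, n_labels=10, label_indices=labels)
plt.scatter(z[:, 0], z[:, 1], c=labels, s=8, linewidth=0)
plt.show()
```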
--------------------------------------------------------------------------------
/aae_supervised.py:
--------------------------------------------------------------------------------
import os
import sys
import mxnet as mx
import numpy as np
from data_factory import RandIter, get_mnist
from model import make_aae_sym


if __name__ == '__main__':
    # =============setting============
    ctx = mx.gpu(3)
    display_step = 20
    epoch_num = 40
    check_point = 20
    early_stop = True
    dataset = 'mnist'
    n_labels = 10
    z_prior = 'swiss_roll'
    if z_prior == 'uniform':
        z_args = {
            'minv': -2.0,
            'maxv': 2.0
        }
    elif z_prior == 'gaussian_mixture':
        z_args = {
            'x_var': 0.5,
            'y_var': 0.1
        }
    elif z_prior == 'swiss_roll':
        z_args = {}
    elif z_prior == 'gaussian':
        raise ValueError('please use gaussian_mixture for supervised training')
    else:
        raise ValueError('unknown z_prior')
    batch_size = 100
    n_dim = 2
    n_encoder = 1000
    n_decoder = 1000
    n_discriminator = 500
    enc_mult = 1
    dec_mult = 1
    dis_mult = 1
    std = 0.01
    lr = 0.002
    lr_factor = 1
    lr_factor_step = 2000
    wd = 0.0
    beta1 = 0.1
    optimizer_args = {
        'optimizer': 'adam',
        'optimizer_params': {
            'clip_gradient': 5.0,
            'learning_rate': lr,
            'lr_scheduler': mx.lr_scheduler.FactorScheduler(lr_factor_step, lr_factor),
            'wd': wd,
            'beta1': beta1,
        }
    }

    # checkpoints are written to cache/models below, so create the full path
    if not os.path.exists('cache/models'):
        os.makedirs('cache/models')

    if dataset == 'mnist':
        X_train, X_test, Y_train, Y_test = get_mnist(root_dir='./cache', train_ratio=0.9)
        train_iter = mx.io.NDArrayIter(X_train, label=Y_train, batch_size=batch_size)
        # rename the label from the default softmax_label to label_n
        train_iter.label = [('label_n', train_iter.label[0][1])]
    else:
        raise NotImplementedError
    rand_iter = RandIter(batch_size, n_dim, z_prior=z_prior, n_labels=n_labels, with_label=True, **z_args)
    label_pq = mx.nd.zeros((batch_size,), ctx=ctx)

    sym_enc, sym_dec, sym_dis = make_aae_sym(data_dim=784,
                                             n_dim=n_dim,
                                             n_encoder=n_encoder,
                                             n_decoder=n_decoder,
                                             n_discriminator=n_discriminator,
                                             enc_mult=enc_mult,
                                             dec_mult=dec_mult,
                                             dis_mult=dis_mult,
                                             with_bn=False,
                                             supervised=True)

    mod_enc = mx.mod.Module(symbol=sym_enc, data_names=('data',), label_names=(), context=ctx)
    mod_enc.bind(data_shapes=train_iter.provide_data,
                 label_shapes=[],
                 inputs_need_grad=False)
    mod_enc.init_params(initializer=mx.init.Normal(std))
    mod_enc.init_optimizer(**optimizer_args)

    mod_dec = mx.mod.Module(symbol=sym_dec, data_names=('z',), label_names=('data',), context=ctx)
    mod_dec.bind(data_shapes=rand_iter.provide_data,
                 label_shapes=train_iter.provide_data,
                 inputs_need_grad=True)
    mod_dec.init_params(initializer=mx.init.Normal(std))
    mod_dec.init_optimizer(**optimizer_args)

    mod_dis = mx.mod.Module(symbol=sym_dis, data_names=('z', 'label_n'), label_names=('label_pq',), context=ctx)
    mod_dis.bind(data_shapes=rand_iter.provide_data+rand_iter.provide_label,
                 label_shapes=[('label_pq', (batch_size,))],
                 inputs_need_grad=True)
    mod_dis.init_params(initializer=mx.init.Normal(std))
    mod_dis.init_optimizer(**optimizer_args)
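
    # Training alternates between two phases, following the AAE paper:
    # 1. Reconstruction phase: a batch flows through mod_enc and mod_dec; the
    #    decoder's input gradients are backpropagated into the encoder, and
    #    both modules are updated to minimize reconstruction error.
    # 2. Regularization phase: mod_dis is updated to tell prior samples p(z)
    #    (label 1) from encoder outputs q(z) (label 0); then the encoder is
    #    updated with the discriminator's input gradients so that q(z) fools
    #    the discriminator.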

    def facc(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    def frmse(label, pred):
        dim = label.size / label.shape[0]
        label = label.reshape((-1, dim))
        pred = pred.reshape((-1, dim))
        return np.linalg.norm(label-pred, axis=1).mean()

    metric_dec_rmse = mx.metric.CustomMetric(frmse)
    metric_dis_accuracy = mx.metric.CustomMetric(facc)
    metric_dis_entropy = mx.metric.CustomMetric(fentropy)
    metric_fool_dis_accuracy = mx.metric.CustomMetric(facc)
    metric_fool_dis_entropy = mx.metric.CustomMetric(fentropy)

    print 'Training ...'
    for epoch in xrange(epoch_num):
        train_iter.reset()
        for t, batch in enumerate(train_iter):
            # one-hot encode the batch labels for the discriminator input
            batch_label_one_hot = np.zeros((batch_size, n_labels), dtype=np.float32)
            batch_label_np = batch.label[0].asnumpy()
            for i in xrange(batch_size):
                batch_label_one_hot[i, int(batch_label_np[i])] = 1
            batch_label_one_hot = mx.nd.array(batch_label_one_hot)

            rbatch = rand_iter.next()

            # reconstruction phase: update encoder and decoder
            mod_enc.forward(batch, is_train=True)
            qz = mod_enc.get_outputs()
            mod_dec.forward(mx.io.DataBatch(qz, batch.data), is_train=True)
            mod_dec.backward()
            diff_dec = mod_dec.get_input_grads()
            mod_enc.backward(diff_dec)
            mod_enc.update()
            mod_dec.update()

            metric_dec_rmse.update(batch.data, mod_dec.get_outputs())

            # regularization phase
            # step 1: update discriminator on both q(z) (label 0) and p(z) (label 1),
            # accumulating the gradients of the two passes before a single update
            label_pq[:] = 0
            mod_dis.forward(mx.io.DataBatch(qz+[batch_label_one_hot], [label_pq]), is_train=True)
            mod_dis.backward()
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in mod_dis._exec_group.grad_arrays]
            metric_dis_accuracy.update([label_pq], mod_dis.get_outputs())
            metric_dis_entropy.update([label_pq], mod_dis.get_outputs())

            label_pq[:] = 1
            pz = rbatch.data
            mod_dis.forward(mx.io.DataBatch(pz+rbatch.label, [label_pq]), is_train=True)
            mod_dis.backward()
            for gradsr, gradsf in zip(mod_dis._exec_group.grad_arrays, gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr += gradf
            mod_dis.update()
            metric_dis_accuracy.update([label_pq], mod_dis.get_outputs())
            metric_dis_entropy.update([label_pq], mod_dis.get_outputs())

            # step 2: update encoder (fool the discriminator)
            label_pq[:] = 1
            mod_enc.forward(batch, is_train=True)
            qz = mod_enc.get_outputs()
            mod_dis.forward(mx.io.DataBatch(qz+[batch_label_one_hot], [label_pq]), is_train=True)
            mod_dis.backward()
            diff_dis = mod_dis.get_input_grads()
            # backpropagate only the gradient w.r.t. z into the encoder
            mod_enc.backward([diff_dis[1]])
            mod_enc.update()

            # metric update: the supervised discriminator also takes the one-hot labels
            mod_dis.forward(mx.io.DataBatch(qz+[batch_label_one_hot], [label_pq]), is_train=False)
            metric_fool_dis_accuracy.update([label_pq], mod_dis.get_outputs())
            metric_fool_dis_entropy.update([label_pq], mod_dis.get_outputs())

            if t % display_step == 0:
                print '\rEpoch %d, iter %d: dec_rmse=%.2f, dis_acc=%.4f, dis_entropy=%.2f, fool_dis_acc=%.4f, fool_dis_entropy=%.4f' % (epoch, t, metric_dec_rmse.get()[1], metric_dis_accuracy.get()[1], metric_dis_entropy.get()[1], metric_fool_dis_accuracy.get()[1], metric_fool_dis_entropy.get()[1]),
                sys.stdout.flush()

        metric_dec_rmse.reset()
        metric_dis_accuracy.reset()
        metric_dis_entropy.reset()
        metric_fool_dis_accuracy.reset()
        metric_fool_dis_entropy.reset()

        if (epoch+1) % check_point == 0:
            print 'Saving...'
            sym_dec.save('cache/models/%s_%s_dec_supervised.json' % (dataset, z_prior))
            mod_dec.save_params('cache/models/%s_%s_dec_supervised_%04d.params' % (dataset, z_prior, epoch))
            sym_enc.save('cache/models/%s_%s_enc_supervised.json' % (dataset, z_prior))
            mod_enc.save_params('cache/models/%s_%s_enc_supervised_%04d.params' % (dataset, z_prior, epoch))
            sym_dis.save('cache/models/%s_%s_dis_supervised.json' % (dataset, z_prior))
            mod_dis.save_params('cache/models/%s_%s_dis_supervised_%04d.params' % (dataset, z_prior, epoch))
--------------------------------------------------------------------------------
/aae_unsupervised.py:
--------------------------------------------------------------------------------
import os
import sys
import mxnet as mx
import numpy as np
from data_factory import RandIter, get_mnist
from model import make_aae_sym


if __name__ == '__main__':
    # =============setting============
    display_step = 20
    check_point = 20
    ctx = mx.gpu(3)
    epoch_num = 40
    early_stop = True
    dataset = 'mnist'
    z_prior = 'swiss_roll'
    if z_prior == 'uniform':
        z_args = {
            'minv': -2.0,
            'maxv': 2.0
        }
    elif z_prior == 'gaussian_mixture':
        z_args = {
            'x_var': 0.5,
            'y_var': 0.1
        }
    elif z_prior == 'swiss_roll':
        z_args = {}
    elif z_prior == 'gaussian':
        z_args = {
            'mean': 0,
            'var': 1
        }
    else:
        raise ValueError('unknown z_prior')
    batch_size = 100
    n_dim = 2
    n_encoder = 1000
    n_decoder = 1000
    n_discriminator = 500
    enc_mult = 1
    dec_mult = 1
    dis_mult = 0.2
    std = 0.01
    lr = 0.002
    lr_factor = 1
    lr_factor_step = 2000
    wd = 0.0
    beta1 = 0.1
    optimizer_args = {
        'optimizer': 'adam',
        'optimizer_params': {
            'clip_gradient': 5.0,
            'learning_rate': lr,
            'lr_scheduler': mx.lr_scheduler.FactorScheduler(lr_factor_step, lr_factor),
            'wd': wd,
            'beta1': beta1,
        }
    }

    # checkpoints are written to cache/models below, so create the full path
    if not os.path.exists('cache/models'):
        os.makedirs('cache/models')

    if dataset == 'mnist':
        X_train, X_test, Y_train, Y_test = get_mnist(root_dir='./cache', train_ratio=0.9)
        train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size)
    else:
        raise NotImplementedError
    rand_iter = RandIter(batch_size, n_dim, z_prior=z_prior, **z_args)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    sym_enc, sym_dec, sym_dis = make_aae_sym(data_dim=784,
                                             n_dim=n_dim,
                                             n_encoder=n_encoder,
                                             n_decoder=n_decoder,
                                             n_discriminator=n_discriminator,
                                             enc_mult=enc_mult,
                                             dec_mult=dec_mult,
                                             dis_mult=dis_mult,
                                             with_bn=False)

    mod_enc = mx.mod.Module(symbol=sym_enc, data_names=('data',), label_names=(), context=ctx)
    mod_enc.bind(data_shapes=train_iter.provide_data,
                 label_shapes=[],
                 inputs_need_grad=False)
    mod_enc.init_params(initializer=mx.init.Normal(std))
    mod_enc.init_optimizer(**optimizer_args)

    mod_dec = mx.mod.Module(symbol=sym_dec, data_names=('z',), label_names=('data',), context=ctx)
    mod_dec.bind(data_shapes=rand_iter.provide_data,
                 label_shapes=train_iter.provide_data,
                 inputs_need_grad=True)
    mod_dec.init_params(initializer=mx.init.Normal(std))
    mod_dec.init_optimizer(**optimizer_args)

    mod_dis = mx.mod.Module(symbol=sym_dis, data_names=('z',), label_names=('label_pq',), context=ctx)
    mod_dis.bind(data_shapes=rand_iter.provide_data,
                 label_shapes=[('label_pq', (batch_size,))],
                 inputs_need_grad=True)
    mod_dis.init_params(initializer=mx.init.Normal(std))
    mod_dis.init_optimizer(**optimizer_args)

    def facc(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    def frmse(label, pred):
        dim = label.size / label.shape[0]
        label = label.reshape((-1, dim))
        pred = pred.reshape((-1, dim))
        return np.linalg.norm(label-pred, axis=1).mean()

    metric_dec_rmse = mx.metric.CustomMetric(frmse)
    metric_dis_accuracy = mx.metric.CustomMetric(facc)
    metric_dis_entropy = mx.metric.CustomMetric(fentropy)
    metric_fool_dis_accuracy = mx.metric.CustomMetric(facc)
    metric_fool_dis_entropy = mx.metric.CustomMetric(fentropy)

    print 'Training ...'
    for epoch in xrange(epoch_num):
        train_iter.reset()
        for t, batch in enumerate(train_iter):
            rbatch = rand_iter.next()

            # reconstruction phase: update encoder and decoder
            mod_enc.forward(batch, is_train=True)
            qz = mod_enc.get_outputs()
            mod_dec.forward(mx.io.DataBatch(qz, batch.data), is_train=True)
            mod_dec.backward()
            diff_dec = mod_dec.get_input_grads()
            mod_enc.backward(diff_dec)
            mod_enc.update()
            mod_dec.update()

            metric_dec_rmse.update(batch.data, mod_dec.get_outputs())

            # regularization phase
            # step 1: update discriminator on both q(z) (label 0) and p(z) (label 1),
            # accumulating the gradients of the two passes before a single update
            label[:] = 0
            mod_dis.forward(mx.io.DataBatch(qz, [label]), is_train=True)
            mod_dis.backward()
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in mod_dis._exec_group.grad_arrays]
            metric_dis_accuracy.update([label], mod_dis.get_outputs())
            metric_dis_entropy.update([label], mod_dis.get_outputs())

            label[:] = 1
            pz = rbatch.data
            mod_dis.forward(mx.io.DataBatch(pz, [label]), is_train=True)
            mod_dis.backward()
            for gradsr, gradsf in zip(mod_dis._exec_group.grad_arrays, gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr += gradf
            mod_dis.update()
            metric_dis_accuracy.update([label], mod_dis.get_outputs())
            metric_dis_entropy.update([label], mod_dis.get_outputs())

            # step 2: update encoder (fool the discriminator)
            label[:] = 1
            mod_enc.forward(batch, is_train=True)
            qz = mod_enc.get_outputs()
            mod_dis.forward(mx.io.DataBatch(qz, [label]), is_train=True)
            mod_dis.backward()
            diff_dis = mod_dis.get_input_grads()
            mod_enc.backward(diff_dis)
            mod_enc.update()

            # metric update
            mod_dis.forward(mx.io.DataBatch(qz, [label]), is_train=False)
            metric_fool_dis_accuracy.update([label], mod_dis.get_outputs())
            metric_fool_dis_entropy.update([label], mod_dis.get_outputs())

            if t % display_step == 0:
                print '\rEpoch %d, iter %d: dec_rmse=%.2f, dis_acc=%.4f, dis_entropy=%.2f, fool_dis_acc=%.4f, fool_dis_entropy=%.4f' % (epoch, t, metric_dec_rmse.get()[1], metric_dis_accuracy.get()[1], metric_dis_entropy.get()[1], metric_fool_dis_accuracy.get()[1], metric_fool_dis_entropy.get()[1]),
                sys.stdout.flush()

        metric_dec_rmse.reset()
        metric_dis_accuracy.reset()
        metric_dis_entropy.reset()
        metric_fool_dis_accuracy.reset()
        metric_fool_dis_entropy.reset()

        if (epoch+1) % check_point == 0:
            print 'Saving...'
            sym_dec.save('cache/models/%s_%s_dec_unsupervised.json' % (dataset, z_prior))
            mod_dec.save_params('cache/models/%s_%s_dec_unsupervised_%04d.params' % (dataset, z_prior, epoch))
            sym_enc.save('cache/models/%s_%s_enc_unsupervised.json' % (dataset, z_prior))
            mod_enc.save_params('cache/models/%s_%s_enc_unsupervised_%04d.params' % (dataset, z_prior, epoch))
            sym_dis.save('cache/models/%s_%s_dis_unsupervised.json' % (dataset, z_prior))
            mod_dis.save_params('cache/models/%s_%s_dis_unsupervised_%04d.params' % (dataset, z_prior, epoch))
--------------------------------------------------------------------------------
/data_factory.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy as np
import mxnet as mx
from math import sin, cos, sqrt
from sklearn.datasets import fetch_mldata

def onehot_categorical(batch_size, n_labels):
    y = np.zeros((batch_size, n_labels), dtype=np.float32)
    indices = np.random.randint(0, n_labels, batch_size)
    for b in xrange(batch_size):
        y[b, indices[b]] = 1
    return y

def uniform(batch_size, n_dim, n_labels=10, minv=-1, maxv=1, label_indices=None):
    # lay the labels out on a ceil(sqrt(n_labels)) x ceil(sqrt(n_labels)) grid
    # and sample uniformly inside the cell belonging to each label
    def sample(label, n_labels):
        num = int(np.ceil(np.sqrt(n_labels)))
        size = (maxv-minv)*1.0/num
        x, y = np.random.uniform(-size/2, size/2, (2,))
        i = label // num
        j = label % num
        x += j*size+minv+0.5*size
        y += i*size+minv+0.5*size
        return np.array([x, y]).reshape((2,))

    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for batch in xrange(batch_size):
        for zi in xrange(n_dim // 2):
            if label_indices is not None:
                z[batch, zi*2:zi*2+2] = sample(label_indices[batch], n_labels)
            else:
                z[batch, zi*2:zi*2+2] = sample(np.random.randint(0, n_labels), n_labels)
    return z

def gaussian(batch_size, n_dim, mean=0, var=1):
    z = np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32)
    return z
def gaussian_mixture(batch_size, n_dim=2, n_labels=10, x_var=0.5, y_var=0.1, label_indices=None):
    if n_dim % 2 != 0:
        raise Exception("n_dim must be a multiple of 2.")

    # one narrow Gaussian per label, rotated around the origin and shifted
    # outwards so the components fan out from the center
    def sample(x, y, label, n_labels):
        shift = 1.4
        r = 2.0 * np.pi / float(n_labels) * float(label)
        new_x = x * cos(r) - y * sin(r)
        new_y = x * sin(r) + y * cos(r)
        new_x += shift * cos(r)
        new_y += shift * sin(r)
        return np.array([new_x, new_y]).reshape((2,))

    x = np.random.normal(0, x_var, (batch_size, n_dim // 2))
    y = np.random.normal(0, y_var, (batch_size, n_dim // 2))
    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for batch in xrange(batch_size):
        for zi in xrange(n_dim // 2):
            if label_indices is not None:
                z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels)
            else:
                z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels)

    return z

def swiss_roll(batch_size, n_dim=2, n_labels=10, label_indices=None):
    # each label owns a contiguous slice of the roll's parameter range
    def sample(label, n_labels):
        uni = np.random.uniform(0.0, 1.0) / float(n_labels) + float(label) / float(n_labels)
        r = sqrt(uni) * 3.0
        rad = np.pi * 4.0 * sqrt(uni)
        x = r * cos(rad)
        y = r * sin(rad)
        return np.array([x, y]).reshape((2,))

    z = np.zeros((batch_size, n_dim), dtype=np.float32)
    for batch in xrange(batch_size):
        for zi in xrange(n_dim // 2):
            if label_indices is not None:
                z[batch, zi*2:zi*2+2] = sample(label_indices[batch], n_labels)
            else:
                z[batch, zi*2:zi*2+2] = sample(np.random.randint(0, n_labels), n_labels)
    return z


class RandIter(mx.io.DataIter):
    """Endless iterator that draws z from the chosen prior, optionally with
    one-hot labels for the supervised discriminator."""
    def __init__(self, batch_size, n_dim, z_prior='gaussian', n_labels=10, with_label=False, **zargs):
        self.batch_size = batch_size
        self.n_dim = n_dim
        self.provide_data = [('z', (batch_size, n_dim))]
        self.n_labels = n_labels
        if with_label:
            self.provide_label = [('label_n', (batch_size, n_labels))]
            assert z_prior in ['gaussian_mixture', 'swiss_roll', 'uniform']
        else:
            self.provide_label = []
        self.z_prior = z_prior
        self.with_label = with_label
        self.zargs = zargs
        self.tmp_labels = None

    def iter_next(self):
        if self.with_label:
            self.tmp_labels = np.random.randint(0, self.n_labels, (self.batch_size,))
        return True

    def getlabel(self):
        if self.with_label:
            label = np.zeros((self.batch_size, self.n_labels), dtype=np.float32)
            for i in xrange(self.batch_size):
                label[i, self.tmp_labels[i]] = 1
            return [mx.nd.array(label)]
        else:
            return []

    def getdata(self):
        if self.with_label:
            self.zargs['label_indices'] = self.tmp_labels

        if self.z_prior == 'gaussian':
            return [mx.nd.array(gaussian(self.batch_size, self.n_dim, **self.zargs))]
        elif self.z_prior == 'uniform':
            return [mx.nd.array(uniform(self.batch_size, self.n_dim, n_labels=self.n_labels, **self.zargs))]
        elif self.z_prior == 'gaussian_mixture':
            return [mx.nd.array(gaussian_mixture(self.batch_size, self.n_dim, n_labels=self.n_labels, **self.zargs))]
        elif self.z_prior == 'swiss_roll':
            return [mx.nd.array(swiss_roll(self.batch_size, self.n_dim, n_labels=self.n_labels, **self.zargs))]
        else:
            raise NotImplementedError


def get_mnist(root_dir='~', train_ratio=0.5, resize=None):
    mnist = fetch_mldata('MNIST original', data_home=os.path.join(root_dir, 'scikit_learn_data'))
    np.random.seed(1234)  # set seed for deterministic ordering
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p]
    X = X.reshape((70000, 28, 28))
    Y = mnist.target[p]

    train_num = int(X.shape[0]*train_ratio)

    if resize:
        X = np.asarray([cv2.resize(x, resize) for x in X])
    X = X.astype(np.float32) / 255.0
    # add the channel axis, keeping whatever spatial size the images have now
    X = X.reshape((X.shape[0], 1, X.shape[1], X.shape[2]))
    X_train = X[:train_num]
    Y_train = Y[:train_num]
    X_test = X[train_num:]
    Y_test = Y[train_num:]

    return X_train, X_test, Y_train, Y_test
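

# A quick visual sanity check for the labeled samplers (a minimal sketch; the
# sample count and plot styling are arbitrary choices, not part of the
# training scripts):
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    labels = np.random.randint(0, 10, 1000)
    for name, z in [('uniform', uniform(1000, 2, label_indices=labels)),
                    ('gaussian_mixture', gaussian_mixture(1000, 2, label_indices=labels)),
                    ('swiss_roll', swiss_roll(1000, 2, label_indices=labels))]:
        plt.figure()
        plt.title(name)
        plt.scatter(z[:, 0], z[:, 1], c=labels, s=8, linewidth=0)
    plt.show()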
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import mxnet as mx


def make_aae_sym(data_dim=784, n_dim=2, n_encoder=1000, n_decoder=1000, n_discriminator=500, enc_mult=1, dec_mult=1, dis_mult=1, with_bn=True, supervised=False):
    """Build the encoder, decoder and discriminator symbols of the AAE."""
    data = mx.sym.Variable('data')
    data = mx.sym.Flatten(data=data)
    z = mx.sym.Variable('z')
    z = mx.sym.Flatten(data=z)

    # Encoder: data -> z
    enc = mx.sym.FullyConnected(data=data, num_hidden=n_encoder, attr={'lr_mult': str(enc_mult)}, name='enc_fc1')
    if with_bn:
        enc = mx.sym.BatchNorm(data=enc, name='enc_bn1')
    enc = mx.sym.Activation(data=enc, name='enc_relu1', act_type='relu')
    enc = mx.sym.FullyConnected(data=enc, num_hidden=n_encoder, attr={'lr_mult': str(enc_mult)}, name='enc_fc2')
    if with_bn:
        enc = mx.sym.BatchNorm(data=enc, name='enc_bn2')
    enc = mx.sym.Activation(data=enc, name='enc_relu2', act_type='relu')
    enc = mx.sym.FullyConnected(data=enc, num_hidden=n_dim, attr={'lr_mult': str(enc_mult)}, name='enc_fc3')
    if with_bn:
        enc = mx.sym.BatchNorm(data=enc, name='enc_bn3')

    # Decoder: z -> reconstructed data
    dec = mx.sym.FullyConnected(data=z, num_hidden=n_decoder, attr={'lr_mult': str(dec_mult)}, name='dec_fc1')
    if with_bn:
        dec = mx.sym.BatchNorm(data=dec, name='dec_bn1')
    dec = mx.sym.Activation(data=dec, name='dec_relu1', act_type='relu')
    dec = mx.sym.FullyConnected(data=dec, num_hidden=n_decoder, attr={'lr_mult': str(dec_mult)}, name='dec_fc2')
    if with_bn:
        dec = mx.sym.BatchNorm(data=dec, name='dec_bn2')
    dec = mx.sym.Activation(data=dec, name='dec_relu2', act_type='relu')
    dec = mx.sym.FullyConnected(data=dec, num_hidden=data_dim, attr={'lr_mult': str(dec_mult)}, name='dec_fc3')
    dec = mx.sym.Activation(data=dec, name='dec_out', act_type='sigmoid')
    dec = mx.sym.LinearRegressionOutput(data=dec, label=data, name='dec_loss')

    # Discriminator: z (optionally concatenated with a one-hot label) -> p(z)/q(z)
    label_pq = mx.sym.Variable('label_pq')
    if supervised:
        label_n = mx.sym.Variable('label_n')
        z = mx.sym.Concat(label_n, z)

    dis = mx.sym.FullyConnected(data=z, num_hidden=n_discriminator, attr={'lr_mult': str(dis_mult)}, name='dis_fc1')
    if with_bn:
        dis = mx.sym.BatchNorm(data=dis, name='dis_bn1')
    dis = mx.sym.Activation(data=dis, act_type='relu', name='dis_relu1')
    dis = mx.sym.FullyConnected(data=dis, num_hidden=n_discriminator, attr={'lr_mult': str(dis_mult)}, name='dis_fc2')
    if with_bn:
        dis = mx.sym.BatchNorm(data=dis, name='dis_bn2')
    dis = mx.sym.Activation(data=dis, act_type='relu', name='dis_relu2')
    dis = mx.sym.FullyConnected(data=dis, num_hidden=1, attr={'lr_mult': str(dis_mult)}, name='dis_clf')
    dis = mx.sym.LogisticRegressionOutput(data=dis, name='dis_pred', label=label_pq)

    return enc, dec, dis
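

# A small shape sanity check (a minimal sketch; the batch size of 100 is an
# arbitrary choice): infer_shape should report a (100, 2) latent code and a
# (100, 784) reconstruction for MNIST-sized inputs.
if __name__ == '__main__':
    enc, dec, dis = make_aae_sym(data_dim=784, n_dim=2, with_bn=False)
    _, enc_out, _ = enc.infer_shape(data=(100, 1, 28, 28))
    _, dec_out, _ = dec.infer_shape(z=(100, 2), data=(100, 1, 28, 28))
    print 'encoder output shapes:', enc_out
    print 'decoder output shapes:', dec_out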
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_latent_variable(X, Y):
    # project to 2-D first if the latent space has more dimensions
    if X.shape[1] != 2:
        pca = PCA(n_components=2)
        X = pca.fit_transform(X)
        print pca.explained_variance_ratio_
    plt.figure(figsize=(8, 8))
    plt.axes().set_aspect('equal')
    color = plt.cm.rainbow(np.linspace(0, 1, 10))
    for l, c in enumerate(color):
        inds = np.where(Y == l)
        plt.scatter(X[inds, 0], X[inds, 1], c=c, label=l, linewidth=0, s=8)
    # plt.xlim([-5.0, 5.0])
    # plt.ylim([-5.0, 5.0])
    plt.legend()
    plt.show()
--------------------------------------------------------------------------------