├── .gitignore ├── LICENSE ├── README.md └── src ├── aae_class.py ├── data ├── t10k-images-idx3-ubyte.gz ├── t10k-labels-idx1-ubyte.gz ├── train-images-idx3-ubyte.gz └── train-labels-idx1-ubyte.gz ├── demo_aae.ipynb ├── demo_kdpp.ipynb └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | 
/site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sungjoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Implementations of deep autoencoders (VAE, AAE, and others) 3 | 4 | Finished implementing the basic AAE. 
# 5 |
# --------------------------------------------------------------------------------
# /src/aae_class.py:
# --------------------------------------------------------------------------------
"""Adversarial autoencoder (AAE) with a Gaussian-mixture prior on the latent code."""
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from util import kdpp, get_mnist, remove_warnings, gpu_sess, plot_imgs


class gmm_sampler_class(object):
    """Sampler for a uniformly weighted Gaussian mixture in ``_z_dim`` dimensions.

    The ``_k`` component means are chosen by ``kdpp`` from 1,000 points drawn
    uniformly in ``[-1, 1]^z_dim``. Every component uses the same variance
    ``_var`` in every dimension (diagonal covariance).
    """

    def __init__(self, _name='gmm', _z_dim=16, _k=5, _var=0.05):
        self.name = _name
        self.z_dim = _z_dim
        self.k = _k
        # Determine k means within [-1, +1]
        self.mu, _ = kdpp(_X=2*np.random.rand(1000, self.z_dim)-1, _k=self.k)
        # Every component gets variance `_var` in every dimension
        self.var = _var*np.ones(shape=(self.k, self.z_dim))

    def sample(self, _n):
        """Return an ``[_n x z_dim]`` array of mixture samples."""
        samples = np.zeros(shape=(_n, self.z_dim))
        for i in range(_n):
            k = np.random.randint(low=0, high=self.k)  # mixture component, chosen uniformly
            mu_k = self.mu[k, :]
            var_k = self.var[k, :]
            samples[i, :] = mu_k + np.sqrt(var_k)*np.random.randn(self.z_dim)
        return samples

    def plot(self, _n=1000, _title='Samples', _tfs=18):
        """Scatter-plot the first two latent dimensions of ``_n`` samples."""
        samples = self.sample(_n=_n)
        plt.figure(figsize=(6, 6))
        plt.plot(samples[:, 0], samples[:, 1], 'k.')
        plt.xlim(-2.0, 2.0); plt.ylim(-2.0, 2.0)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.title(_title, fontsize=_tfs)
        plt.show()


tfd = tf.contrib.distributions
tfrni = tf.random_normal_initializer
tfci = tf.constant_initializer
tfrui = tf.random_uniform_initializer
tfscewl = tf.nn.sigmoid_cross_entropy_with_logits


class AAE_class(object):
    """Adversarial autoencoder whose latent prior is a ``gmm_sampler_class``.

    * Q(z|x): encoder, ``x_real`` -> ``z_real``.
    * P(x|z): decoder, applied to ``z_real`` (-> ``x_recon``) and to prior
      samples ``z_sample`` (-> ``x_sample``) with shared weights.
    * D(z): discriminator. It is trained to output label 0 on encoded codes
      (``z_real``) and label 1 on prior samples (``z_sample``); the encoder is
      trained to make D output 1 on its codes.

    ``kp`` is a dropout KEEP probability: feed < 1.0 while training and 1.0
    for evaluation/sampling.

    Args (beyond the obvious layer sizes / activations):
        _l2_reg_coef: coefficient of the L2 penalty on all model variables.
        _opmz: optimizer class; instantiated as ``_opmz(lr)``.
        _sess: TF session used by ``train``.
        _seed: graph-level TF seed (and numpy seed inside ``train``).
        _gmm_k, _gmm_var: number of mixture components / per-dim variance of
            the GMM prior.
    """

    def __init__(self, _name='aae', _x_dim=784, _z_dim=16,
                 _h_dims_Q=[64, 16], _h_dims_P=[64, 16], _h_dims_D=[64, 16],
                 _actv_Q=tf.nn.relu, _actv_P=tf.nn.relu, _actv_D=tf.nn.relu,
                 _l2_reg_coef=1e-4,
                 _opmz=tf.train.RMSPropOptimizer, _lr=1e-3,
                 _sess=None, _seed=0,
                 _VERBOSE=True, _gmm_k=5, _gmm_var=0.02):
        # NOTE: the list defaults above are only read, never mutated.
        self.name = _name
        self.x_dim = _x_dim
        self.z_dim = _z_dim
        self.h_dims_Q = _h_dims_Q
        self.h_dims_P = _h_dims_P
        self.h_dims_D = _h_dims_D
        self.actv_Q = _actv_Q
        self.actv_P = _actv_P
        self.actv_D = _actv_D
        self.l2_reg_coef = _l2_reg_coef
        self.opmz = _opmz
        self.lr = _lr
        self.sess = _sess
        self.seed = _seed
        self.VERBOSE = _VERBOSE
        # Define sampler (latent prior)
        self.sampler = gmm_sampler_class(_z_dim=self.z_dim, _k=_gmm_k, _var=_gmm_var)
        if self.VERBOSE:
            self.sampler.plot()
        # Build graph
        self._build_graph()
        # Check parameters
        self._check_params()

    def _dropout(self, _x):
        """Apply dropout with keep probability ``self.kp``.

        ``self.kp`` is a keep probability, so the drop rate is ``1 - kp``.
        ``training=True`` makes the layer actually apply the mask; feeding
        ``kp=1.0`` turns it into the identity.
        """
        return tf.layers.dropout(_x, rate=1.0-self.kp, training=True)

    def _mlp(self, _x, _h_dims, _actv, _prefix):
        """Stack dense+dropout hidden layers named '<_prefix>_<idx>'; return the last output."""
        net = _x
        for h_idx, hid in enumerate(_h_dims):
            net = tf.layers.dense(net, hid, activation=_actv,
                                  kernel_initializer=tfrni(stddev=0.1), bias_initializer=tfci(0),
                                  name='%s_%d' % (_prefix, h_idx))
            net = self._dropout(net)
        return net

    def _dense_out(self, _x, _dim, _name):
        """Linear output layer with the same initializers as the hidden layers."""
        return tf.layers.dense(_x, _dim, activation=None,
                               kernel_initializer=tfrni(stddev=0.1), bias_initializer=tfci(0),
                               name=_name)

    # Build graph
    def _build_graph(self):
        """Create placeholders, the Q/P/D networks, the losses and the optimizers."""
        # The graph-level seed must be set BEFORE the random ops (initializers,
        # dropout) are created for it to make them reproducible.
        tf.set_random_seed(self.seed)
        with tf.variable_scope(self.name, reuse=False):
            # Placeholders
            self.x_real = tf.placeholder(shape=[None, self.x_dim], dtype=tf.float32, name='x')  # [n x x_dim]
            self.z_sample = tf.placeholder(shape=[None, self.z_dim], dtype=tf.float32, name='z')  # [n x z_dim]
            self.kp = tf.placeholder(shape=[], dtype=tf.float32, name='kp')  # dropout keep probability

            # Encoder network Q(z|x): x_real => z_real
            with tf.variable_scope('Q', reuse=False):
                self.net = self._mlp(self.x_real, self.h_dims_Q, self.actv_Q, 'hid_Q')
                self.z_real = self._dense_out(self.net, self.z_dim, 'z_real')  # [n x z_dim]

            # Decoder network P(x|z): z_real => x_recon
            with tf.variable_scope('P', reuse=False):
                self.net = self._mlp(self.z_real, self.h_dims_P, self.actv_P, 'hid_P')
                self.x_recon = self._dense_out(self.net, self.x_dim, 'x_recon')  # [n x x_dim]

            # Decoder network P(x|z), shared weights: z_sample => x_sample
            with tf.variable_scope('P', reuse=True):
                self.net = self._mlp(self.z_sample, self.h_dims_P, self.actv_P, 'hid_P')
                self.x_sample = self._dense_out(self.net, self.x_dim, 'x_recon')  # [n x x_dim]

            # Discriminator D(z): z_real => d_real
            with tf.variable_scope('D', reuse=False):
                self.net = self._mlp(self.z_real, self.h_dims_D, self.actv_D, 'hid_D')
                self.d_real_logits = self._dense_out(self.net, 1, 'd_logits')  # [n x 1]
                self.d_real = tf.sigmoid(self.d_real_logits, name='d')  # [n x 1]

            # Discriminator D(z), shared weights: z_sample => d_fake
            with tf.variable_scope('D', reuse=True):
                self.net = self._mlp(self.z_sample, self.h_dims_D, self.actv_D, 'hid_D')
                self.d_fake_logits = self._dense_out(self.net, 1, 'd_logits')  # [n x 1]
                self.d_fake = tf.sigmoid(self.d_fake_logits, name='d')  # [n x 1]

            # Loss functions. D labels encoded codes 0 and prior samples 1;
            # the encoder (generator) wants D to label its codes 1.
            self.d_loss_reals = tfscewl(logits=self.d_real_logits, labels=tf.zeros_like(self.d_real_logits))  # [n x 1]
            self.d_loss_fakes = tfscewl(logits=self.d_fake_logits, labels=tf.ones_like(self.d_fake_logits))  # [n x 1]
            self.d_losses = self.d_loss_reals + self.d_loss_fakes  # [n x 1]
            self.g_losses = tfscewl(logits=self.d_real_logits, labels=tf.ones_like(self.d_real_logits))  # [n x 1]
            # Half the per-example L1 reconstruction error
            self.ae_losses = 0.5*tf.norm(self.x_recon-self.x_real, ord=1, axis=1)  # [n]
            self.d_loss = tf.reduce_mean(self.d_losses)  # scalar
            self.g_loss = tf.reduce_mean(self.g_losses)  # scalar
            self.ae_loss = tf.reduce_mean(self.ae_losses)  # scalar
            self.t_vars = tf.trainable_variables()
            self.c_vars = [var for var in self.t_vars if '%s/' % (self.name) in var.name]
            self.l2_reg = self.l2_reg_coef*tf.reduce_sum(tf.stack([tf.nn.l2_loss(v) for v in self.c_vars]))  # scalar

            # Optimizers. Each membership test must be evaluated separately;
            # `'a' or 'b' in s` is always truthy and would select every variable.
            q_scope, p_scope = '%s/Q' % (self.name), '%s/P' % (self.name)
            self.ae_vars = [var for var in self.t_vars if (q_scope in var.name) or (p_scope in var.name)]
            self.d_vars = [var for var in self.t_vars if '%s/D' % (self.name) in var.name]
            self.g_vars = [var for var in self.t_vars if q_scope in var.name]
            self.optm_ae = self.opmz(self.lr).minimize(self.ae_loss+self.l2_reg, var_list=self.ae_vars)
            self.optm_d = self.opmz(self.lr/2.).minimize(self.d_loss+self.l2_reg, var_list=self.d_vars)
            self.optm_g = self.opmz(self.lr).minimize(self.g_loss+self.l2_reg, var_list=self.g_vars)

    # Check parameters
    def _check_params(self):
        """Print every global variable under this model's scope (if VERBOSE)."""
        # Local list on purpose: self.g_vars holds the generator's trainable
        # variables (set in _build_graph) and must not be overwritten here.
        model_vars = [var for var in tf.global_variables() if '%s/' % (self.name) in var.name]
        if self.VERBOSE:
            print("==== Global Variables ====")
            for i, var in enumerate(model_vars):
                print(" [%02d/%d] Name:[%s] Shape:%s" % (i, len(model_vars), var.name, var.get_shape().as_list()))

    def _plot_samples(self):
        """Decode 10 prior samples and show them as 28x28 images."""
        z_samples4img = self.sampler.sample(10)
        feeds = {self.z_sample: z_samples4img, self.kp: 1.0}
        sampled_images = self.sess.run(self.x_sample, feed_dict=feeds)
        plot_imgs(_imgs=sampled_images, _imgSz=(28, 28),
                  _nR=1, _nC=10, _figsize=(15, 2), _title='Sampled Images', _tfs=18)

    def _plot_z_space(self, _X, _Y=None):
        """Scatter prior samples against encoded inputs (first two latent dims).

        If ``_Y`` is given (a 2-D ``[n x ydim]`` array; its argmax per row is
        used as the class), encoded points are coloured per class.
        """
        n_x = _X.shape[0]
        rand_idx = np.random.permutation(n_x)[:min(n_x, 2000)]  # up to 2,000 inputs
        x_batch = _X[rand_idx, :]
        z_real = self.sess.run(self.z_real, feed_dict={self.x_real: x_batch, self.kp: 1.0})
        z_samples = self.sampler.sample(1000)
        plt.figure(figsize=(6, 6))
        h_sample, = plt.plot(z_samples[:, 0], z_samples[:, 1], 'kx')
        if _Y is None:
            h_real, = plt.plot(z_real[:, 0], z_real[:, 1], 'b.')
            plt.legend([h_sample, h_real], ['Prior', 'Encoded'], fontsize=15)
        else:
            hs, strs = [h_sample], ['Prior']
            ys = np.argmax(_Y[rand_idx, :], axis=1)
            ydim = np.shape(_Y)[1]
            cmap = plt.get_cmap('gist_rainbow')
            colors = [cmap(ii) for ii in np.linspace(0, 1, ydim)]
            for i in range(ydim):
                yi_idx = np.flatnonzero(ys == i)  # always 1-D, even for 0 or 1 matches
                hi, = plt.plot(z_real[yi_idx, 0], z_real[yi_idx, 1], '.', color=colors[i])
                hs.append(hi)
                strs.append('Encoded (%d)' % (i))
            plt.legend(hs, strs, fontsize=12, bbox_to_anchor=(1.04, 1), loc="upper left")
        plt.xlim(-2.0, 2.0); plt.ylim(-2.0, 2.0)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.title('Z-space', fontsize=18)
        plt.show()

    # Train
    def train(self, _X, _Y=None, _max_iter=1e4, _batch_size=256,
              _PRINT_EVERY=1e3, _PLOT_EVERY=1e3):
        """Train the AAE on the rows of ``_X``.

        Each iteration runs one AE step, one D step and two G (encoder) steps
        on the same mini-batch. ``_PRINT_EVERY`` / ``_PLOT_EVERY`` <= 0
        disable logging / plotting. ``_Y`` is only used to colour the z-space plot.
        """
        tf.set_random_seed(self.seed); np.random.seed(self.seed)  # fix seeds
        self.sess.run(tf.global_variables_initializer())  # initialize variables
        n_x = _X.shape[0]  # number of training data
        for _iter in range(int(_max_iter)):
            rand_idx = np.random.permutation(n_x)[:_batch_size]
            x_batch = _X[rand_idx, :]
            # Same number of prior samples as inputs, so the per-example
            # D losses can be summed element-wise even when n_x < _batch_size.
            z_sample = self.sampler.sample(x_batch.shape[0])
            feeds = {self.x_real: x_batch, self.kp: 0.8, self.z_sample: z_sample}
            _, ae_loss_val = self.sess.run([self.optm_ae, self.ae_loss], feed_dict=feeds)
            _, d_loss_val = self.sess.run([self.optm_d, self.d_loss], feed_dict=feeds)
            for _ in range(2):
                _, g_loss_val = self.sess.run([self.optm_g, self.g_loss], feed_dict=feeds)

            # Print-out (check > 0 first so 0 disables instead of dividing by zero)
            if (_PRINT_EVERY > 0) and (((_iter+1) % _PRINT_EVERY) == 0):
                total_loss = ae_loss_val+d_loss_val+g_loss_val
                print("[%04d/%d]Loss AE:%.3f D:%.3f G:%.3f total loss:%.3f" %
                      (_iter+1, _max_iter, ae_loss_val, d_loss_val, g_loss_val, total_loss))

            # Plot samples and the latent space
            if (_PLOT_EVERY > 0) and ((_iter == 0) or (((_iter+1) % _PLOT_EVERY) == 0)):
                self._plot_samples()
                self._plot_z_space(_X, _Y)

        print("[train] Done.")

# --------------------------------------------------------------------------------
# /src/data/t10k-images-idx3-ubyte.gz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sjchoi86/deep-autoencoders/d2564a7a80dfb65cc055851bfffe1a8e8d8c1d76/src/data/t10k-images-idx3-ubyte.gz
# --------------------------------------------------------------------------------
# /src/data/t10k-labels-idx1-ubyte.gz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sjchoi86/deep-autoencoders/d2564a7a80dfb65cc055851bfffe1a8e8d8c1d76/src/data/t10k-labels-idx1-ubyte.gz
# --------------------------------------------------------------------------------
# /src/data/train-images-idx3-ubyte.gz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sjchoi86/deep-autoencoders/d2564a7a80dfb65cc055851bfffe1a8e8d8c1d76/src/data/train-images-idx3-ubyte.gz
# --------------------------------------------------------------------------------
# /src/data/train-labels-idx1-ubyte.gz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sjchoi86/deep-autoencoders/d2564a7a80dfb65cc055851bfffe1a8e8d8c1d76/src/data/train-labels-idx1-ubyte.gz
# --------------------------------------------------------------------------------
# /src/util.py:
# --------------------------------------------------------------------------------
"""Shared helpers: SE kernel, greedy k-DPP selection, MNIST loading, TF session, plotting."""
import warnings
import tensorflow as tf
import numpy as np
from scipy.spatial.distance import pdist, squareform, cdist
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
try:
    import gym
except ImportError:  # gym is optional: it is only used to quiet its logger
    gym = None


def kernel_se(_X1, _X2, _hyp={'gain': 1, 'len': 1, 'noise': 1e-8}):
    """Squared-exponential kernel matrix between the rows of ``_X1`` and ``_X2``.

    ``K[i, j] = gain**2 * exp(-||_X1[i] - _X2[j]||**2 / len**2)``, shape
    ``[n1 x n2]``. A larger length-scale ``len`` gives a smoother kernel.
    ``_hyp['noise']`` is accepted for interface compatibility but not used.
    The default ``_hyp`` dict is only read, never mutated.
    """
    hyp_gain = float(_hyp['gain'])**2
    hyp_len = float(_hyp['len'])
    pairwise_dists = cdist(_X1, _X2, 'euclidean')
    K = hyp_gain*np.exp(-pairwise_dists**2/(hyp_len**2))
    return K


def kdpp(_X, _k):
    """Greedily select ``_k`` diverse rows of ``_X`` (k-DPP-style MAP heuristic).

    The first row is chosen uniformly at random. Each later row is the
    not-yet-selected one that maximises ``det(K)`` of the SE kernel over the
    selected set, using the median pairwise distance of ``_X`` as length-scale.

    Returns:
        out: ``[_k x d]`` array of the selected rows.
        idx: list of the selected row indices into ``_X``.
    """
    n, d = _X.shape[0], _X.shape[1]
    mid_dist = np.median(cdist(_X, _X, 'euclidean'))
    out, idx = np.zeros(shape=(_k, d)), []
    for i in range(_k):
        if i == 0:
            rand_idx = np.random.randint(n)
            idx.append(rand_idx)  # append index
            out[i, :] = _X[rand_idx, :]  # append inputs
        else:
            det_vals = np.zeros(n)
            for j in range(n):
                if j in idx:
                    det_vals[j] = -np.inf  # never pick a row twice
                else:
                    X_curr = _X[idx + [j], :]
                    K = kernel_se(X_curr, X_curr, {'gain': 1, 'len': mid_dist, 'noise': 1e-4})
                    det_vals[j] = np.linalg.det(K)
            max_idx = np.argmax(det_vals)
            idx.append(max_idx)
            out[i, :] = _X[max_idx, :]  # append inputs
    return out, idx


def remove_warnings():
    """Silence Python warnings and TensorFlow (and gym, if installed) logging."""
    if gym is not None:
        gym.logger.set_level(40)  # 40 corresponds to the ERROR level
    warnings.filterwarnings("ignore")
    tf.logging.set_verbosity(tf.logging.ERROR)


def numpy_setting():
    """Print numpy arrays with 3 decimals."""
    np.set_printoptions(precision=3)


def get_mnist():
    """Load MNIST (one-hot labels) from the local 'data' directory via the TF tutorial reader."""
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('data', one_hot=True)
    return mnist


def gpu_sess():
    """Create a tf.Session that allocates GPU memory on demand (allow_growth)."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    return sess


def plot_imgs(_imgs, _imgSz=(28, 28), _nR=1, _nC=10, _figsize=(15, 2),
              _title=None, _titles=None, _tfs=15,
              _wspace=0.05, _hspace=0.05):
    """Show ``_imgs`` on an ``_nR x _nC`` grid in grayscale.

    Flat (1-D) images are reshaped to ``_imgSz``; pixel values are clipped to
    the colour range [0, 1]. ``_titles`` gives optional per-image titles.
    """
    nr, nc = _nR, _nC
    fig = plt.figure(figsize=_figsize)
    if _title is not None:
        fig.suptitle(_title, size=15)
    gs = gridspec.GridSpec(nr, nc)
    gs.update(wspace=_wspace, hspace=_hspace)
    for i, img in enumerate(_imgs):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        if len(img.shape) == 1:
            img = np.reshape(img, newshape=_imgSz)
        plt.imshow(img, cmap='Greys_r', interpolation='none')
        plt.clim(0.0, 1.0)
        if _titles is not None:
            plt.title(_titles[i], size=_tfs)
    plt.show()