├── .gitignore
├── README.md
└── VRAE.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Jupyter notebooks
*.ipynb

# PyCharm
.idea/

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Created by .ignore support plugin (hsz.mobi)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LSTM Variational Recurrent Auto-Encoder
An implementation of the Variational Recurrent Auto-Encoder
(http://arxiv.org/pdf/1412.6581.pdf) that uses a single-layer LSTM as both the
encoder and the decoder.

Based on GitHub user RyotaKatoh's chainer-Variational-Recurrent-Autoencoder
(https://github.com/RyotaKatoh/chainer-Variational-Recurrent-Autoencoder).

## Dependencies
* Anaconda
* Chainer
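## Usage
The repository does not ship a training script. Below is a minimal, untested
sketch of how `LSTMVRAE` might be trained with a Chainer 1.x-style optimizer
loop; the array `sequences` and all hyperparameters are placeholder values,
not part of this repository.

```python
import numpy as np
from chainer import functions as F
from chainer import optimizers

from VRAE import LSTMVRAE

# Placeholder data: 100 sequences, each 20 timesteps of 8 features.
sequences = np.random.randn(100, 20, 8).astype(np.float32)

model = LSTMVRAE(n_input=8, n_hidden=64, n_latent=2,
                 loss_func=F.mean_squared_error)
optimizer = optimizers.Adam()
optimizer.setup(model)

for epoch in range(10):
    for seq in sequences:
        state = model.make_initial_state()
        output, rec_loss, KLD, state = model.forward(seq, state)
        loss = rec_loss + KLD  # negative ELBO
        model.zerograds()      # model.cleargrads() on newer Chainer versions
        loss.backward()
        optimizer.update()
```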
--------------------------------------------------------------------------------
/VRAE.py:
--------------------------------------------------------------------------------
import numpy as np
from chainer import Variable, Chain
from chainer import functions as F
from chainer import links as L


class LSTMVRAE(Chain):
    """
    Class: LSTMVRAE
    ===============
    Implements the Variational Recurrent Auto-Encoder described at
    http://arxiv.org/pdf/1412.6581.pdf. This architecture uses a single-layer
    LSTM for both the encoder and the decoder.
    """

    def __init__(self, n_input, n_hidden, n_latent, loss_func):
        """
        :param n_input: number of input dimensions
        :param n_hidden: number of LSTM cells in both the encoder and decoder
        :param n_latent: number of dimensions of the latent code (z)
        :param loss_func: loss function used for the reconstruction error
            (e.g. F.mean_squared_error)
        """
        super(LSTMVRAE, self).__init__(
            # Encoder (recognition model). The *4 accounts for the four
            # LSTM gates (input, forget, output, cell).
            recog_x_h=L.Linear(n_input, n_hidden * 4),
            recog_h_h=L.Linear(n_hidden, n_hidden * 4),
            recog_mean=L.Linear(n_hidden, n_latent),
            recog_log_sigma=L.Linear(n_hidden, n_latent),

            # Decoder (generative model)
            gen_z_h=L.Linear(n_latent, n_hidden * 4),
            gen_x_h=L.Linear(n_input, n_hidden * 4),
            gen_h_h=L.Linear(n_hidden, n_hidden * 4),
            output=L.Linear(n_hidden, n_input),
        )
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_latent = n_latent
        self.loss_func = loss_func

    def make_initial_state(self):
        """Returns the initial (all-zero) cell and hidden states of both LSTMs."""
        return {
            'h_rec': Variable(np.zeros((1, self.n_hidden), dtype=np.float32)),
            'c_rec': Variable(np.zeros((1, self.n_hidden), dtype=np.float32)),
            'h_gen': Variable(np.zeros((1, self.n_hidden), dtype=np.float32)),
            'c_gen': Variable(np.zeros((1, self.n_hidden), dtype=np.float32)),
        }

    def forward(self, x_data, state):
        """
        Encodes and then decodes one input sequence.

        :param x_data: input sequence as a numpy.ndarray of shape
            (n_timesteps, n_input)
        :param state: previous RNN state, as returned by make_initial_state()
        :return: (output, rec_loss, KLD, state) - the reconstructed sequence,
            the reconstruction loss, the KL divergence, and the final state
        """
        # =====[ Step 1: Encode - compute q(z|x) and sample z ]=====
        # Run the encoder LSTM over the full sequence.
        for i in range(x_data.shape[0]):
            x = Variable(x_data[i].reshape((1, x_data.shape[1])))
            h_in = self.recog_x_h(x) + self.recog_h_h(state['h_rec'])
            c_t, h_t = F.lstm(state['c_rec'], h_in)
            state.update({'c_rec': c_t, 'h_rec': h_t})
        # Parameters of q(z|x); q_log_sigma models log(sigma)
        q_mean = self.recog_mean(state['h_rec'])
        q_log_sigma = 0.5 * self.recog_log_sigma(state['h_rec'])
        # KL divergence between q(z|x) and the N(0, I) prior:
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * F.sum(1 + 2 * q_log_sigma - q_mean ** 2 - F.exp(2 * q_log_sigma))
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        eps = Variable(np.random.normal(0, 1, q_log_sigma.data.shape).astype(np.float32))
        z = q_mean + F.exp(q_log_sigma) * eps

        # =====[ Step 2: Decode - compute p(x|z) ]=====
        # Initialize the decoder LSTM from the latent code.
        output = []
        h_in = self.gen_z_h(z)
        c_t, h_t = F.lstm(state['c_gen'], h_in)
        state.update({'c_gen': c_t, 'h_gen': h_t})
        rec_loss = Variable(np.zeros((), dtype=np.float32))
        for i in range(x_data.shape[0]):
            # Emit the next output and accumulate the reconstruction loss.
            x_t = self.output(h_t)
            output.append(x_t.data)
            rec_loss += self.loss_func(x_t, Variable(x_data[i].reshape((1, x_data.shape[1]))))
            # Feed the output back in to compute the next hidden state.
            h_in = self.gen_x_h(x_t) + self.gen_h_h(state['h_gen'])
            c_t, h_t = F.lstm(state['c_gen'], h_in)
            state.update({'c_gen': c_t, 'h_gen': h_t})

        return output, rec_loss, KLD, state
--------------------------------------------------------------------------------
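Note that `forward` only reconstructs its input; the trained decoder can also
be used generatively. The sketch below is not part of the repository (the
function name `sample_sequence` and the argument `n_steps` are invented here
for illustration); it mirrors the decoder loop in `forward`, but seeds the
decoder LSTM with a latent code drawn from the N(0, I) prior instead of an
encoding of real data.

```python
import numpy as np
from chainer import Variable
from chainer import functions as F


def sample_sequence(model, n_steps):
    """Decode a sequence of n_steps from a latent code drawn from the prior."""
    state = model.make_initial_state()
    z = Variable(np.random.normal(0, 1, (1, model.n_latent)).astype(np.float32))
    # Initialize the decoder LSTM from z, exactly as LSTMVRAE.forward does.
    h_in = model.gen_z_h(z)
    c_t, h_t = F.lstm(state['c_gen'], h_in)
    outputs = []
    for _ in range(n_steps):
        x_t = model.output(h_t)
        outputs.append(x_t.data)
        # Feed each generated output back in as the next decoder input.
        h_in = model.gen_x_h(x_t) + model.gen_h_h(h_t)
        c_t, h_t = F.lstm(c_t, h_in)
    return np.concatenate(outputs, axis=0)  # shape: (n_steps, n_input)
```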