├── .gitignore
├── LICENSE
├── README.md
├── lgss_example.py
└── variational_smc.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Blei Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Sequential Monte Carlo

This repository contains code for the variational sequential Monte Carlo (VSMC) algorithm for approximate Bayesian inference:
```
Variational Sequential Monte Carlo.
Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, and David M. Blei.
Proceedings of the 21st International Conference on Artificial Intelligence and Statistics, 2018,
Lanzarote, Spain.
```
Furthermore, it contains a simulation example (a linear Gaussian state space model) showing how to use the VSMC module.
Note that this example learns both model parameters and proposal parameters, so the final lower bound will not be a lower
bound on the exact log-marginal likelihood for the parameters that generated the data.
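A minimal usage sketch (the data `y`, the SMC object `smc_obj`, and the parameter pytrees `model_params`/`prop_params_init` are placeholders here; `lgss_example.py` is the complete, runnable version):

```python
import autograd.numpy.random as npr
from autograd import grad
from autograd.misc.optimizers import adam
from variational_smc import vsmc_lower_bound, sim_q

rs = npr.RandomState(0)

# Placeholders: y is a (T, Dy) array of observations, smc_obj implements
# sim_prop/log_weights (see lgss_example.py), and prop_params_init holds
# the initial proposal parameters.
def objective(prop_params, t):
    return -vsmc_lower_bound(prop_params, model_params, y, smc_obj, rs)

opt_prop_params = adam(grad(objective), prop_params_init,
                       step_size=0.001, num_iters=1000)
x_sample = sim_q(opt_prop_params, model_params, y, smc_obj, rs)
```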
--------------------------------------------------------------------------------
/lgss_example.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('./')

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.misc.optimizers import adam
from variational_smc import *

def init_model_params(Dx, Dy, alpha, r, obs, rs = npr.RandomState(0)):
    mu0 = np.zeros(Dx)
    Sigma0 = np.eye(Dx)

    A = np.zeros((Dx,Dx))
    for i in range(Dx):
        for j in range(Dx):
            A[i,j] = alpha**(abs(i-j)+1)

    Q = np.eye(Dx)
    C = np.zeros((Dy,Dx))
    if obs == 'sparse':
        C[:Dy,:Dy] = np.eye(Dy)
    else:
        C = rs.normal(size=(Dy,Dx))
    R = r * np.eye(Dy)

    return (mu0, Sigma0, A, Q, C, R)

def init_prop_params(T, Dx, scale = 0.5, rs = npr.RandomState(0)):
    return [(scale * rs.randn(Dx),      # Bias
             1. + scale * rs.randn(Dx), # Linear times A/mu0
             scale * rs.randn(Dx))      # Log-var
            for t in range(T)]

def generate_data(model_params, T = 5, rs = npr.RandomState(0)):
    mu0, Sigma0, A, Q, C, R = model_params
    Dx = mu0.shape[0]
    Dy = R.shape[0]

    x_true = np.zeros((T,Dx))
    y_true = np.zeros((T,Dy))

    for t in range(T):
        if t > 0:
            x_true[t,:] = rs.multivariate_normal(np.dot(A,x_true[t-1,:]),Q)
        else:
            x_true[0,:] = rs.multivariate_normal(mu0,Sigma0)
        y_true[t,:] = rs.multivariate_normal(np.dot(C,x_true[t,:]),R)

    return x_true, y_true

def log_marginal_likelihood(model_params, T, y_true):
    mu0, Sigma0, A, Q, C, R = model_params
    Dx = mu0.shape[0]
    Dy = R.shape[1]

    log_likelihood = 0.
    xfilt = np.zeros(Dx)
    Pfilt = np.zeros((Dx,Dx))
    xpred = mu0
    Ppred = Sigma0

    for t in range(T):
        if t > 0:
            # Predict
            xpred = np.dot(A,xfilt)
            Ppred = np.dot(A,np.dot(Pfilt,A.T)) + Q

        # Update
        yt = y_true[t,:] - np.dot(C,xpred)
        S = np.dot(C,np.dot(Ppred,C.T)) + R
        K = np.linalg.solve(S,np.dot(C,Ppred)).T
        xfilt = xpred + np.dot(K,yt)
        Pfilt = Ppred - np.dot(K,np.dot(C,Ppred))

        sign, logdet = np.linalg.slogdet(S)
        log_likelihood += -0.5*(np.sum(yt*np.linalg.solve(S,yt)) + logdet + Dy*np.log(2.*np.pi))

    return log_likelihood
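# For reference (a summary note, not original code), the recursion above is
# the standard Kalman filter:
#   Predict: x_pred = A x_filt,    P_pred = A P_filt A^T + Q
#   Update:  S = C P_pred C^T + R, K = P_pred C^T S^{-1}
#            x_filt = x_pred + K (y_t - C x_pred)
#            P_filt = P_pred - K C P_pred
# with each step contributing log N(y_t; C x_pred, S) to the log-marginal
# likelihood.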
class lgss_smc:
    """
    Class for defining functions used in variational SMC.
    """
    def __init__(self, T, Dx, Dy, N):
        self.T = T
        self.Dx = Dx
        self.Dy = Dy
        self.N = N

    def log_normal(self, x, mu, Sigma):
        dim = Sigma.shape[0]
        sign, logdet = np.linalg.slogdet(Sigma)
        log_norm = -0.5*dim*np.log(2.*np.pi) - 0.5*logdet
        Prec = np.linalg.inv(Sigma)
        return log_norm - 0.5*np.sum((x-mu)*np.dot(Prec,(x-mu).T).T,axis=1)

    def log_prop(self, t, Xc, Xp, y, prop_params, model_params):
        mu0, Sigma0, A, Q, C, R = model_params
        mut, lint, log_s2t = prop_params[t]
        s2t = np.exp(log_s2t)

        if t > 0:
            mu = mut + np.dot(A, Xp.T).T*lint
        else:
            mu = mut + lint*mu0

        return self.log_normal(Xc, mu, np.diag(s2t))

    def log_target(self, t, Xc, Xp, y, prop_params, model_params):
        mu0, Sigma0, A, Q, C, R = model_params
        if t > 0:
            logF = self.log_normal(Xc, np.dot(A,Xp.T).T, Q)
        else:
            logF = self.log_normal(Xc, mu0, Sigma0)
        logG = self.log_normal(np.dot(C,Xc.T).T, y[t], R)
        return logF + logG

    # The following two methods are the only ones required by variational_smc.py.
    def log_weights(self, t, Xc, Xp, y, prop_params, model_params):
        return self.log_target(t, Xc, Xp, y, prop_params, model_params) - \
               self.log_prop(t, Xc, Xp, y, prop_params, model_params)

    def sim_prop(self, t, Xp, y, prop_params, model_params, rs = npr.RandomState(0)):
        mu0, Sigma0, A, Q, C, R = model_params
        mut, lint, log_s2t = prop_params[t]
        s2t = np.exp(log_s2t)

        if t > 0:
            mu = mut + np.dot(A, Xp.T).T*lint
        else:
            mu = mut + lint*mu0
        return mu + rs.randn(*Xp.shape)*np.sqrt(s2t)
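# An illustrative sketch, not part of the original example: any object exposing
# the attributes Dx and N plus the two methods above satisfies the interface
# that vsmc_lower_bound and sim_q need. As a hypothetical second instance, here
# is the same interface for a 1-D Gaussian random walk, x_t ~ N(x_{t-1}, 1),
# y_t ~ N(x_t, 1), with a mean-field Gaussian proposal per time step, where
# prop_params[t] = (mu_t, log_var_t), each of shape (1,).
class gaussian_rw_smc:
    def __init__(self, T, N):
        self.T = T
        self.Dx = 1
        self.N = N

    def _log_normal_diag(self, x, mu, var):
        # Independent Gaussian log-density, summed over the state dimension.
        return np.sum(-0.5*(np.log(2.*np.pi*var) + (x - mu)**2/var), axis=1)

    def sim_prop(self, t, Xp, y, prop_params, model_params, rs = npr.RandomState(0)):
        mu_t, log_var_t = prop_params[t]
        return mu_t + rs.randn(*Xp.shape)*np.sqrt(np.exp(log_var_t))

    def log_weights(self, t, Xc, Xp, y, prop_params, model_params):
        mu_t, log_var_t = prop_params[t]
        log_q = self._log_normal_diag(Xc, mu_t, np.exp(log_var_t))   # proposal
        log_f = self._log_normal_diag(Xc, Xp if t > 0 else 0., 1.)   # transition / prior
        log_g = self._log_normal_diag(y[t], Xc, 1.)                  # observation
        return log_f + log_g - log_q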

if __name__ == '__main__':
    # Model hyper-parameters
    T = 10
    Dx = 5
    Dy = 3
    alpha = 0.42
    r = .1
    obs = 'sparse'

    # Training parameters
    param_scale = 0.5
    num_epochs = 1000
    step_size = 0.001

    N = 4

    data_seed = npr.RandomState(0)
    model_params = init_model_params(Dx, Dy, alpha, r, obs, data_seed)

    print("Generating data...")
    x_true, y_true = generate_data(model_params, T, data_seed)

    lml = log_marginal_likelihood(model_params, T, y_true)
    print("True log-marginal likelihood: "+str(lml))

    seed = npr.RandomState(0)

    # Initialize proposal parameters
    prop_params = init_prop_params(T, Dx, param_scale, seed)
    combined_init_params = (model_params, prop_params)

    lgss_smc_obj = lgss_smc(T, Dx, Dy, N)

    # Define training objective
    def objective(combined_params, iter):
        model_params, prop_params = combined_params
        return -vsmc_lower_bound(prop_params, model_params, y_true, lgss_smc_obj, seed)

    # Get gradients of objective using autograd.
    objective_grad = grad(objective)

    print(" Epoch | ELBO ")
    f_head = './lgss_vsmc_biased_T'+str(T)+'_N'+str(N)+'_step'+str(step_size)
    with open(f_head+'_ELBO.csv', 'w') as f_handle:
        f_handle.write("iter,ELBO\n")

    def print_perf(combined_params, iter, grad):
        if iter % 10 == 0:
            model_params, prop_params = combined_params
            bound = -objective(combined_params, iter)
            message = "{:15}|{:20}".format(iter, bound)

            with open(f_head+'_ELBO.csv', 'a') as f_handle:
                np.savetxt(f_handle, [[iter,bound]], fmt='%i,%f')

            print(message)

    # SGD with adaptive step-size "adam"
    optimized_params = adam(objective_grad, combined_init_params, step_size=step_size,
                            num_iters=num_epochs, callback=print_perf)
    opt_model_params, opt_prop_params = optimized_params
    print(sim_q(opt_prop_params, opt_model_params, y_true, lgss_smc_obj, seed))
--------------------------------------------------------------------------------
/variational_smc.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function

import autograd.numpy as np
from autograd.extend import notrace_primitive

@notrace_primitive
def resampling(w, rs):
    """
    Stratified resampling, wrapped with notrace_primitive to ensure autograd
    takes no derivatives through it.
    """
    N = w.shape[0]
    bins = np.cumsum(w)
    ind = np.arange(N)
    u = (ind + rs.rand(N))/N

    return np.digitize(u, bins)
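# Illustrative note (not original code): stratified resampling draws one
# uniform in each stratum [i/N, (i+1)/N) and maps it through the CDF of the
# normalized weights. E.g., with w = [0.1, 0.2, 0.3, 0.4] (so bins =
# [0.1, 0.3, 0.6, 1.0]) and u = [0.05, 0.30, 0.55, 0.80], np.digitize returns
# ancestors [0, 2, 2, 3]. Each particle's expected offspring count is N*w_i,
# as with multinomial resampling, but with lower variance.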
def vsmc_lower_bound(prop_params, model_params, y, smc_obj, rs, verbose=False, adapt_resamp=False):
    """
    Estimate the VSMC lower bound. Amenable to (biased) reparameterization
    gradients.

    .. math::
        ELBO(\theta,\lambda) =
        \mathbb{E}_{\phi}\left[\log \hat p(y_{1:T})\right]

    Requires an SMC object with two member functions:
    -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
    -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N,Dx))
    Xp = np.zeros((N,Dx))
    logW = np.zeros(N)
    W = np.exp(logW)
    W /= np.sum(W)
    logZ = 0.
    ESS = 1./np.sum(W**2)/N   # effective sample size as a fraction of N

    for t in range(T):
        # Resampling
        if adapt_resamp:
            if ESS < 0.5:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
                # W has already been normalized, so recover the log-sum of the
                # unnormalized weights from logW when updating logZ.
                max_logW = np.max(logW)
                logZ = logZ + max_logW + np.log(np.sum(np.exp(logW - max_logW))) - np.log(N)
                logW = np.zeros(N)
            else:
                Xp = X
        else:
            if t > 0:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
            else:
                Xp = X

        # Propagation
        X = smc_obj.sim_prop(t, Xp, y, prop_params, model_params, rs)

        # Weighting
        if adapt_resamp:
            logW = logW + smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        else:
            logW = smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        max_logW = np.max(logW)
        W = np.exp(logW-max_logW)
        if adapt_resamp:
            if t == T-1:
                logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        else:
            logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        W /= np.sum(W)
        ESS = 1./np.sum(W**2)/N
        if verbose:
            print('ESS: '+str(ESS))
    return logZ

def sim_q(prop_params, model_params, y, smc_obj, rs, verbose=False):
    """
    Simulates a single sample from the VSMC approximation.

    Requires an SMC object with two member functions:
    -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
    -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N,T,Dx))
    logW = np.zeros(N)
    W = np.zeros((N,T))
    ESS = np.zeros(T)

    for t in range(T):
        # Resampling
        if t > 0:
            ancestors = resampling(W[:,t-1], rs)
            X[:,:t,:] = X[ancestors,:t,:]

        # Propagation
        X[:,t,:] = smc_obj.sim_prop(t, X[:,t-1,:], y, prop_params, model_params, rs)

        # Weighting
        logW = smc_obj.log_weights(t, X[:,t,:], X[:,t-1,:], y, prop_params, model_params)
        max_logW = np.max(logW)
        W[:,t] = np.exp(logW-max_logW)
        W[:,t] /= np.sum(W[:,t])
        ESS[t] = 1./np.sum(W[:,t]**2)

    # Sample from the empirical approximation
    bins = np.cumsum(W[:,-1])
    u = rs.rand()
    B = np.digitize(u, bins)

    if verbose:
        print('Mean ESS', np.mean(ESS)/N)
        print('Min ESS', np.min(ESS))

    return X[B,:,:]
--------------------------------------------------------------------------------