├── .gitignore ├── .travis.yml ├── LICENSE-MIT ├── MANIFEST.in ├── README.md ├── aux ├── demo_ll.png └── demo_smooth.png ├── examples ├── EM.py ├── bernoulli_lds.py ├── diagonal_gibbs.py ├── diagonal_meanfield.py ├── gibbs.py ├── meanfield.py ├── missing_data.py ├── poisson_lds.py ├── simple_demo.py └── zeroinflated_bernoulli_lds.py ├── notes ├── info_form.pdf ├── info_form.tex └── macros.sty ├── pylds ├── __init__.py ├── cyutil.pxd ├── distributions.py ├── laplace.py ├── lds_info_messages.pyx ├── lds_messages.pyx ├── lds_messages_interface.py ├── lds_messages_python.py ├── models.py ├── states.py └── util.py ├── setup.py └── tests ├── test_dense.py ├── test_diagonal_plus_lowrank.py ├── test_infofilter.py ├── test_laplace.py ├── test_randomwalk.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /dist 3 | /pylds/*.c 4 | /pylds/*.so 5 | MANIFEST 6 | *.pyc 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | python: 4 | - "2.7" 5 | - "3.5" 6 | - "3.6" 7 | notifications: 8 | email: false 9 | before_install: 10 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 11 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 12 | else 13 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 14 | fi 15 | - bash miniconda.sh -b -p $HOME/miniconda 16 | - export PATH="$HOME/miniconda/bin:$PATH" 17 | - conda update --yes conda 18 | - conda config --add channels conda-forge 19 | install: 20 | - conda install --yes python=$TRAVIS_PYTHON_VERSION pip numpy scipy matplotlib cython nose future 21 | - pip install -e . 22 | script: nosetests tests 23 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Matthew James Johnson, 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | 21 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include pylds *.pyx 2 | recursive-include pylds *.pxd 3 | recursive-include pylds *.cpp 4 | recursive-include pylds *.c 5 | recursive-include pylds *.h 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyLDS: Bayesian inference for linear dynamical systems [![Test status](https://travis-ci.org/mattjj/pylds.svg?branch=master)](https://travis-ci.org/mattjj/pylds) 2 | _Authors: Matt Johnson and Scott Linderman_ 3 | 4 | This package supports Bayesian learning and inference via Gibbs sampling, 5 | structured mean field, and expectation maximization (EM) for 6 | dynamical systems with linear Gaussian state dynamics and 7 | either linear Gaussian or count observations. For count data, 8 | we support either Pólya-gamma augmentation or Laplace approximation. 9 | All inference algorithms benefit from fast message passing code 10 | written in Cython with direct calls to the BLAS and LAPACK routines 11 | linked to the scipy build. 12 | 13 | # Installation 14 | 15 | To install from pypi, just run 16 | 17 | ``` 18 | pip install pylds 19 | ``` 20 | 21 | To install from a clone of the git repository, you need to install Cython. 22 | Here's one way to do it: 23 | 24 | ``` 25 | pip install cython 26 | git clone https://github.com/mattjj/pylds.git 27 | pip install -e pylds 28 | ``` 29 | 30 | To handle count data, you'll also need 31 | [pypolyagamma](https://github.com/slinderman/pypolyagamma), which can be 32 | installed with 33 | 34 | ``` 35 | pip install pypolyagamma>=1.1 36 | ``` 37 | 38 | # Example 39 | PyLDS exposes a variety of classes and functions for working with linear 40 | dynamical systems. For example, the following snippet will 41 | generate synthetic data from a random model: 42 | ```python 43 | import numpy.random as npr 44 | from pylds.models import DefaultLDS 45 | 46 | D_obs = 1 # Observed data dimension 47 | D_latent = 2 # Latent state dimension 48 | D_input = 0 # Exogenous input dimension 49 | T = 2000 # Number of time steps to simulate 50 | 51 | true_model = DefaultLDS(D_obs, D_latent, D_input) 52 | inputs = npr.randn(T, D_input) 53 | data, stateseq = true_model.generate(T, inputs=inputs) 54 | 55 | # Compute the log likelihood of the data with the true params 56 | true_ll = true_model.log_likelihood() 57 | ``` 58 | The `DefaultLDS` constructor initializes an LDS with a 59 | random rotational dynamics matrix. The outputs are `data`, a `T x D_obs` 60 | matrix of observations, and `stateseq`, a `T x D_latent` matrix 61 | of latent states. 62 | 63 | Now create another LDS and try to infer the latent states and 64 | learn the parameters given the observed data. 
65 | 66 | ```python 67 | # Create a separate model and add the observed data 68 | test_model = DefaultLDS(D_obs, D_latent, D_input) 69 | test_model.add_data(data) 70 | 71 | # Run the Gibbs sampler 72 | N_samples = 100 73 | def update(model): 74 | model.resample_model() 75 | return model.log_likelihood() 76 | 77 | lls = [update(test_model) for _ in range(N_samples)] 78 | ``` 79 | 80 | We can plot the log likelihood over iterations to assess the 81 | convergence of the sampling algorithm: 82 | 83 | ```python 84 | # Plot the log likelihoods 85 | plt.figure() 86 | plt.plot([0, N_samples], true_ll * np.ones(2), '--k', label="true") 87 | plt.plot(np.arange(N_samples), lls, color=colors[0], label="test") 88 | plt.xlabel("iteration") 89 | plt.ylabel("training likelihood") 90 | plt.legend(loc="lower right") 91 | ``` 92 | ![Log Likelihood](aux/demo_ll.png) 93 | 94 | We can also smooth the observations with the test model. 95 | ```python 96 | # Smooth the data 97 | smoothed_data = test_model.smooth(data, inputs) 98 | 99 | plt.figure() 100 | plt.plot(data, color=colors[0], lw=2, label="observed") 101 | plt.plot(smoothed_data, color=colors[1], lw=1, label="smoothed") 102 | plt.xlabel("Time") 103 | plt.xlim(0, 500) 104 | plt.ylabel("Smoothed Data") 105 | plt.legend(loc="upper center", ncol=2) 106 | ``` 107 | 108 | ![Smoothed Data](aux/demo_smooth.png) 109 | 110 | Check out the [examples](/examples) directory for demos of other 111 | types of inference, as well as examples of how to work with count 112 | data and missing observations. 113 | 114 | For a lower-level interface to the fast message passing functions, 115 | see [lds_messages.pyx](pylds/lds_messages.pyx), 116 | [lds_info_messages.pyx](pylds/lds_info_messages.pyx), and 117 | [lds_messages_interface.py](pylds/lds_messages_interface.py). 
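Beyond the Gibbs sampler shown above, the same model object supports EM and structured mean field updates. The snippet below is a minimal sketch that mirrors [examples/EM.py](/examples/EM.py) and [examples/meanfield.py](/examples/meanfield.py), using a few Gibbs sweeps for initialization just as those scripts do:

```python
# EM: initialize with Gibbs, then take EM steps (cf. examples/EM.py)
em_model = DefaultLDS(D_obs, D_latent, D_input)
em_model.add_data(data, inputs=inputs)
for _ in range(10):
    em_model.resample_model()

def em_update(model):
    model.EM_step()
    return model.log_likelihood()

em_lls = [em_update(em_model) for _ in range(50)]

# Structured mean field: each step returns the variational lower bound
# (cf. examples/meanfield.py)
mf_model = DefaultLDS(D_obs, D_latent, D_input)
mf_model.add_data(data, inputs=inputs)
for _ in range(10):
    mf_model.resample_model()

vlbs = [mf_model.meanfield_coordinate_descent_step() for _ in range(50)]
```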
118 | 119 | 120 | -------------------------------------------------------------------------------- /aux/demo_ll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattjj/pylds/e946bfa5aa76e8f8284614561a0f40ffd5d868fb/aux/demo_ll.png -------------------------------------------------------------------------------- /aux/demo_smooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattjj/pylds/e946bfa5aa76e8f8284614561a0f40ffd5d868fb/aux/demo_smooth.png -------------------------------------------------------------------------------- /examples/EM.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.util.text import progprint_xrange 7 | 8 | from pylds.models import DefaultLDS 9 | 10 | npr.seed(0) 11 | 12 | # Set parameters 13 | D_obs = 1 14 | D_latent = 2 15 | D_input = 0 16 | T = 2000 17 | 18 | # Simulate from one LDS 19 | truemodel = DefaultLDS(D_obs, D_latent, D_input) 20 | inputs = np.random.randn(T, D_input) 21 | data, stateseq = truemodel.generate(T, inputs=inputs) 22 | 23 | # Fit with another LDS 24 | model = DefaultLDS(D_obs, D_latent, D_input) 25 | model.add_data(data, inputs=inputs) 26 | 27 | # Initialize with a few iterations of Gibbs 28 | for _ in progprint_xrange(10): 29 | model.resample_model() 30 | 31 | # Run EM 32 | def update(model): 33 | model.EM_step() 34 | return model.log_likelihood() 35 | 36 | lls = [update(model) for _ in progprint_xrange(50)] 37 | 38 | # Plot the log likelihoods 39 | plt.figure() 40 | plt.plot(lls) 41 | plt.xlabel('iteration') 42 | plt.ylabel('training likelihood') 43 | 44 | # Predict forward in time 45 | T_given = 1800 46 | T_predict = 200 47 | given_data= data[:T_given] 48 | given_inputs = inputs[:T_given] 49 | 50 | preds = \ 51 | model.sample_predictions( 52 | given_data, inputs=given_inputs, 53 | Tpred=T_predict, 54 | inputs_pred=inputs[T_given:T_given + T_predict]) 55 | 56 | # Plot the predictions 57 | plt.figure() 58 | plt.plot(np.arange(T), data, 'b-', label="true") 59 | plt.plot(T_given + np.arange(T_predict), preds, 'r--', label="prediction") 60 | ylim = plt.ylim() 61 | plt.plot([T_given, T_given], ylim, '-k') 62 | plt.xlabel('time index') 63 | plt.xlim(max(0, T_given - 200), T) 64 | plt.ylabel('prediction') 65 | plt.ylim(ylim) 66 | plt.legend() 67 | plt.show() 68 | 69 | -------------------------------------------------------------------------------- /examples/bernoulli_lds.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.distributions import Regression 7 | from pybasicbayes.util.text import progprint_xrange 8 | from pypolyagamma.distributions import BernoulliRegression 9 | from pylds.models import CountLDS, DefaultBernoulliLDS 10 | 11 | npr.seed(1) 12 | 13 | # Parameters 14 | D_obs = 10 15 | D_latent = 2 16 | D_input = 0 17 | T = 2000 18 | 19 | # True LDS Parameters 20 | mu_init = np.array([0.,1.]) 21 | sigma_init = 0.01*np.eye(2) 22 | 23 | A = 0.99*np.array([[np.cos(np.pi/24), -np.sin(np.pi/24)], 24 | [np.sin(np.pi/24), np.cos(np.pi/24)]]) 25 | B = np.zeros((D_latent, D_input)) 26 | sigma_states = 0.01*np.eye(2) 27 | 28 | C = 
np.random.randn(D_obs, D_latent) 29 | D = np.zeros((D_obs, D_input)) 30 | b = -2.0 * np.ones((D_obs, 1)) 31 | 32 | # Simulate from a Bernoulli LDS 33 | truemodel = CountLDS( 34 | dynamics_distn=Regression(A=np.hstack((A,B)), sigma=sigma_states), 35 | emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input, 36 | A=np.hstack((C,D)), b=b)) 37 | 38 | inputs = np.random.randn(T, D_input) 39 | data, stateseq = truemodel.generate(T, inputs=inputs) 40 | 41 | # Make a model 42 | model = CountLDS( 43 | dynamics_distn=Regression(nu_0=D_latent + 2, 44 | S_0=D_latent * np.eye(D_latent), 45 | M_0=np.zeros((D_latent, D_latent + D_input)), 46 | K_0=(D_latent + D_input) * np.eye(D_latent + D_input)), 47 | emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input)) 48 | model.add_data(data, inputs=inputs, stateseq=np.zeros((T, D_latent))) 49 | 50 | # Run a Gibbs sampler with Polya-gamma augmentation 51 | N_samples = 50 52 | def gibbs_update(model): 53 | model.resample_model() 54 | smoothed_obs = model.states_list[0].smooth() 55 | ll = model.log_likelihood() 56 | return ll, model.states_list[0].gaussian_states, smoothed_obs 57 | 58 | lls_gibbs, x_smpls_gibbs, y_smooth_gibbs = \ 59 | zip(*[gibbs_update(model) for _ in progprint_xrange(N_samples)]) 60 | 61 | # Fit with a Bernoulli LDS using Laplace approximation for comparison 62 | model = DefaultBernoulliLDS(D_obs, D_latent, D_input=D_input, 63 | C=0.01 * np.random.randn(D_obs, D_latent), 64 | D=0.01 * np.random.randn(D_obs, D_input)) 65 | model.add_data(data, inputs=inputs, stateseq=np.zeros((T, D_latent))) 66 | 67 | N_iters = 50 68 | def em_update(model): 69 | model.EM_step(verbose=True) 70 | smoothed_obs = model.states_list[0].smooth() 71 | ll = model.log_likelihood() 72 | return ll, model.states_list[0].gaussian_states, smoothed_obs 73 | 74 | lls_em, x_smpls_em, y_smooth_em = \ 75 | zip(*[em_update(model) for _ in progprint_xrange(N_iters)]) 76 | 77 | # Plot the log likelihood over iterations 78 | plt.figure(figsize=(10,6)) 79 | plt.plot(lls_gibbs, label="gibbs") 80 | plt.plot(lls_em, label="em") 81 | plt.plot([0,N_samples], truemodel.log_likelihood() * np.ones(2), '-k', label="true") 82 | plt.xlabel('iteration') 83 | plt.ylabel('log likelihood') 84 | plt.legend(loc="lower right") 85 | 86 | # Plot the smoothed observations 87 | fig = plt.figure(figsize=(10,10)) 88 | N_subplots = min(D_obs, 6) 89 | 90 | ylims = (-0.1, 1.1) 91 | xlims = (0, min(T,1000)) 92 | 93 | n_to_plot = np.arange(min(N_subplots, D_obs)) 94 | for i,j in enumerate(n_to_plot): 95 | ax = fig.add_subplot(N_subplots,1,i+1) 96 | # Plot spike counts 97 | given_ts = np.where(data[:,j]==1)[0] 98 | ax.plot(given_ts, np.ones_like(given_ts), 'ko', markersize=5) 99 | 100 | ax.plot([0], [0], 'ko', lw=2, label="data") 101 | ax.plot(y_smooth_gibbs[-1][:, j], lw=2, label="gibbs probs") 102 | ax.plot(y_smooth_em[-1][:, j], lw=2, label="em probs") 103 | 104 | if i == 0: 105 | plt.legend(loc="upper center", ncol=4, bbox_to_anchor=(0.5, 2.)) 106 | if i == N_subplots - 1: 107 | plt.xlabel('time index') 108 | ax.set_xlim(xlims) 109 | ax.set_ylim(0, 1.1) 110 | ax.set_ylabel("$x_%d(t)$" % (j+1)) 111 | 112 | plt.show() 113 | -------------------------------------------------------------------------------- /examples/diagonal_gibbs.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.distributions import 
Regression, DiagonalRegression 7 | from pybasicbayes.util.text import progprint_xrange 8 | 9 | from pylds.models import LDS, DefaultLDS 10 | 11 | npr.seed(0) 12 | 13 | 14 | # Parameters 15 | D_obs = 1 16 | D_latent = 2 17 | D_input = 0 18 | T = 2000 19 | 20 | # Simulate from an LDS with diagonal observation noise 21 | truemodel = DefaultLDS(D_obs, D_latent, D_input, sigma_obs=0.1 * np.eye(D_obs)) 22 | inputs = np.random.randn(T, D_input) 23 | data, stateseq = truemodel.generate(T, inputs=inputs) 24 | 25 | # Fit with an LDS with diagonal observation noise 26 | diag_model = LDS( 27 | dynamics_distn=Regression(nu_0=D_latent + 2, 28 | S_0=D_latent * np.eye(D_latent), 29 | M_0=np.zeros((D_latent, D_latent + D_input)), 30 | K_0=(D_latent + D_input) * np.eye(D_latent + D_input)), 31 | emission_distn=DiagonalRegression(D_obs, D_latent+D_input)) 32 | diag_model.add_data(data, inputs=inputs) 33 | 34 | # Also fit a model with a full covariance matrix 35 | full_model = DefaultLDS(D_obs, D_latent, D_input) 36 | full_model.add_data(data, inputs=inputs) 37 | 38 | # Fit with Gibbs sampling 39 | def update(model): 40 | model.resample_model() 41 | return model.log_likelihood() 42 | 43 | N_steps = 100 44 | diag_lls = [update(diag_model) for _ in progprint_xrange(N_steps)] 45 | full_lls = [update(full_model) for _ in progprint_xrange(N_steps)] 46 | 47 | plt.figure() 48 | plt.plot([0, N_steps], truemodel.log_likelihood() * np.ones(2), '--k', label="true") 49 | plt.plot(diag_lls, label="diag cov.") 50 | plt.plot(full_lls, label="full cov.") 51 | plt.xlabel('iteration') 52 | plt.ylabel('log likelihood') 53 | plt.legend() 54 | 55 | # Predict forward in time 56 | T_given = 1800 57 | T_predict = 200 58 | given_data= data[:T_given] 59 | given_inputs = inputs[:T_given] 60 | 61 | preds = \ 62 | diag_model.sample_predictions( 63 | given_data, inputs=given_inputs, 64 | Tpred=T_predict, 65 | inputs_pred=inputs[T_given:T_given + T_predict]) 66 | 67 | # Plot the predictions 68 | plt.figure() 69 | plt.plot(np.arange(T), data, 'b-', label="true") 70 | plt.plot(T_given + np.arange(T_predict), preds, 'r--', label="prediction") 71 | ylim = plt.ylim() 72 | plt.plot([T_given, T_given], ylim, '-k') 73 | plt.xlabel('time index') 74 | plt.xlim(max(0, T_given - 200), T) 75 | plt.ylabel('prediction') 76 | plt.ylim(ylim) 77 | plt.legend() 78 | 79 | # Smooth the data (TODO: Clean this up) 80 | ys = diag_model.smooth(data, inputs) 81 | 82 | plt.figure() 83 | plt.plot(data, 'b-', label="true") 84 | plt.plot(ys, 'r-', lw=2, label="smoothed") 85 | plt.xlabel("Time") 86 | plt.xlim(max(0, T_given-200), T) 87 | plt.ylabel("Smoothed Data") 88 | plt.legend() 89 | 90 | plt.show() 91 | -------------------------------------------------------------------------------- /examples/diagonal_meanfield.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.distributions import Regression, DiagonalRegression 7 | from pybasicbayes.util.text import progprint_xrange 8 | 9 | from pylds.models import LDS, DefaultLDS 10 | 11 | npr.seed(0) 12 | 13 | 14 | # Parameters 15 | D_obs = 1 16 | D_latent = 2 17 | D_input = 0 18 | T = 2000 19 | 20 | # Simulate from an LDS 21 | truemodel = DefaultLDS(D_obs, D_latent, D_input) 22 | inputs = np.random.randn(T, D_input) 23 | data, stateseq = truemodel.generate(T, inputs=inputs) 24 | 25 | # Fit with an LDS with diagonal observation noise 26 | model = 
LDS( 27 | dynamics_distn=Regression(nu_0=D_latent + 2, 28 | S_0=D_latent * np.eye(D_latent), 29 | M_0=np.zeros((D_latent, D_latent + D_input)), 30 | K_0=(D_latent + D_input) * np.eye(D_latent + D_input)), 31 | emission_distn=DiagonalRegression(D_obs, D_latent+D_input)) 32 | model.add_data(data, inputs=inputs) 33 | 34 | # Fit with mean field 35 | def update(model): 36 | return model.meanfield_coordinate_descent_step() 37 | 38 | for _ in progprint_xrange(100): 39 | model.resample_model() 40 | 41 | N_steps = 100 42 | vlbs = [update(model) for _ in progprint_xrange(N_steps)] 43 | 44 | plt.figure(figsize=(3,4)) 45 | plt.plot([0, N_steps], truemodel.log_likelihood() * np.ones(2), '--k') 46 | plt.plot(vlbs) 47 | plt.xlabel('iteration') 48 | plt.ylabel('variational lower bound') 49 | 50 | # Predict forward in time 51 | T_given = 1800 52 | T_predict = 200 53 | given_data= data[:T_given] 54 | given_inputs = inputs[:T_given] 55 | 56 | preds = \ 57 | model.sample_predictions( 58 | given_data, inputs=given_inputs, 59 | Tpred=T_predict, 60 | inputs_pred=inputs[T_given:T_given + T_predict]) 61 | 62 | # Plot the predictions 63 | plt.figure() 64 | plt.plot(np.arange(T), data, 'b-', label="true") 65 | plt.plot(T_given + np.arange(T_predict), preds, 'r--', label="prediction") 66 | ylim = plt.ylim() 67 | plt.plot([T_given, T_given], ylim, '-k') 68 | plt.xlabel('time index') 69 | plt.xlim(max(0, T_given - 200), T) 70 | plt.ylabel('prediction') 71 | plt.ylim(ylim) 72 | plt.legend() 73 | 74 | # Smooth the data (TODO: Clean this up) 75 | E_CD,_,_,_ = model.emission_distn.mf_expectations 76 | E_C, E_D = E_CD[:,:D_latent], E_CD[:,D_latent:] 77 | ys = model.states_list[0].smoothed_mus.dot(E_C.T) + inputs.dot(E_D.T) 78 | 79 | plt.figure() 80 | plt.plot(data, 'b-', label="true") 81 | plt.plot(ys, 'r-', lw=2, label="smoothed") 82 | plt.xlabel("Time") 83 | plt.xlim(max(0, T_given-200), T) 84 | plt.ylabel("Smoothed Data") 85 | plt.legend() 86 | 87 | plt.show() 88 | -------------------------------------------------------------------------------- /examples/gibbs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as npr 3 | import matplotlib.pyplot as plt 4 | 5 | from pybasicbayes.util.text import progprint_xrange 6 | 7 | from pylds.models import DefaultLDS 8 | 9 | npr.seed(0) 10 | 11 | # Set parameters 12 | D_obs = 1 13 | D_latent = 2 14 | D_input = 1 15 | T = 2000 16 | 17 | # Simulate from one LDS 18 | truemodel = DefaultLDS(D_obs, D_latent, D_input) 19 | inputs = np.random.randn(T, D_input) 20 | data, stateseq = truemodel.generate(T, inputs=inputs) 21 | 22 | # Fit with another LDS 23 | input_model = DefaultLDS(D_obs, D_latent, D_input) 24 | input_model.add_data(data, inputs=inputs) 25 | 26 | # Fit a separate model without the inputs 27 | noinput_model = DefaultLDS(D_obs, D_latent, D_input=0) 28 | noinput_model.add_data(data) 29 | 30 | # Run the Gibbs sampler 31 | def update(model): 32 | model.resample_model() 33 | return model.log_likelihood() 34 | 35 | input_lls = [update(input_model) for _ in progprint_xrange(100)] 36 | noinput_lls = [update(noinput_model) for _ in progprint_xrange(100)] 37 | 38 | # Plot the log likelihoods 39 | plt.figure() 40 | plt.plot(input_lls, label="with inputs") 41 | plt.plot(noinput_lls, label="wo inputs") 42 | plt.xlabel('iteration') 43 | plt.ylabel('training likelihood') 44 | plt.legend() 45 | 46 | # Predict forward in time 47 | T_given = 1800 48 | T_predict = 200 49 | given_data= data[:T_given] 50 | 
given_inputs = inputs[:T_given] 51 | 52 | preds = \ 53 | input_model.sample_predictions( 54 | given_data, inputs=given_inputs, 55 | Tpred=T_predict, 56 | inputs_pred=inputs[T_given:T_given + T_predict]) 57 | 58 | # Plot the predictions 59 | plt.figure() 60 | plt.plot(np.arange(T), data, 'b-', label="true") 61 | plt.plot(T_given + np.arange(T_predict), preds, 'r--', label="prediction") 62 | ylim = plt.ylim() 63 | plt.plot([T_given, T_given], ylim, '-k') 64 | plt.xlabel('time index') 65 | plt.xlim(max(0, T_given - 200), T) 66 | plt.ylabel('prediction') 67 | plt.ylim(ylim) 68 | plt.legend() 69 | 70 | # Smooth the data 71 | input_ys = input_model.smooth(data, inputs) 72 | noinput_ys = noinput_model.smooth(data) 73 | 74 | plt.figure() 75 | plt.plot(data, 'b-', label="true") 76 | plt.plot(input_ys, 'r-', lw=2, label="with input") 77 | plt.xlabel("Time") 78 | plt.xlim(max(0, T_given-200), T) 79 | plt.ylabel("Smoothed Data") 80 | plt.legend() 81 | 82 | plt.show() 83 | -------------------------------------------------------------------------------- /examples/meanfield.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.util.text import progprint_xrange 7 | 8 | from pylds.models import DefaultLDS 9 | 10 | npr.seed(0) 11 | 12 | # Set parameters 13 | D_obs = 1 14 | D_latent = 2 15 | D_input = 1 16 | T = 2000 17 | 18 | # Simulate from one LDS 19 | truemodel = DefaultLDS(D_obs, D_latent, D_input) 20 | inputs = np.random.randn(T, D_input) 21 | data, stateseq = truemodel.generate(T, inputs=inputs) 22 | 23 | # Fit with another LDS 24 | model = DefaultLDS(D_obs, D_latent, D_input) 25 | model.add_data(data, inputs=inputs) 26 | 27 | # Initialize with a few iterations of Gibbs 28 | for _ in progprint_xrange(10): 29 | model.resample_model() 30 | 31 | # Run EM 32 | def update(model): 33 | vlb = model.meanfield_coordinate_descent_step() 34 | return vlb 35 | 36 | vlbs = [update(model) for _ in progprint_xrange(50)] 37 | 38 | # Sample from the mean field posterior 39 | model.resample_from_mf() 40 | 41 | # Plot the log likelihoods 42 | plt.figure() 43 | plt.plot(vlbs) 44 | plt.xlabel('iteration') 45 | plt.ylabel('variational lower bound') 46 | 47 | # Predict forward in time 48 | T_given = 1800 49 | T_predict = 200 50 | given_data= data[:T_given] 51 | given_inputs = inputs[:T_given] 52 | 53 | preds = \ 54 | model.sample_predictions( 55 | given_data, inputs=given_inputs, 56 | Tpred=T_predict, 57 | inputs_pred=inputs[T_given:T_given + T_predict]) 58 | 59 | # Plot the predictions 60 | plt.figure() 61 | plt.plot(np.arange(T), data, 'b-', label="true") 62 | plt.plot(T_given + np.arange(T_predict), preds, 'r--', label="prediction") 63 | ylim = plt.ylim() 64 | plt.plot([T_given, T_given], ylim, '-k') 65 | plt.xlabel('time index') 66 | plt.xlim(max(0, T_given - 200), T) 67 | plt.ylabel('prediction') 68 | plt.ylim(ylim) 69 | plt.legend() 70 | 71 | 72 | # Smooth the data 73 | ys = model.smooth(data, inputs) 74 | 75 | plt.figure() 76 | plt.plot(data, 'b-', label="true") 77 | plt.plot(ys, 'r-', lw=2, label="smoothed") 78 | plt.xlabel("Time") 79 | plt.xlim(max(0, T_given-200), T) 80 | plt.ylabel("Smoothed Data") 81 | plt.legend() 82 | 83 | plt.show() 84 | -------------------------------------------------------------------------------- /examples/missing_data.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from pybasicbayes.distributions import Regression, DiagonalRegression 7 | from pybasicbayes.util.text import progprint_xrange 8 | 9 | from pylds.models import DefaultLDS, MissingDataLDS 10 | 11 | npr.seed(0) 12 | 13 | # Model parameters 14 | D_obs = 4 15 | D_latent = 4 16 | T = 1000 17 | 18 | # Simulate from an LDS 19 | truemodel = DefaultLDS(D_obs, D_latent) 20 | data, stateseq = truemodel.generate(T) 21 | 22 | # Mask off a chunk of data 23 | mask = np.ones_like(data, dtype=bool) 24 | chunksz = 100 25 | for i,offset in enumerate(range(0,T,chunksz)): 26 | j = i % (D_obs + 1) 27 | if j < D_obs: 28 | mask[offset:min(offset+chunksz, T), j] = False 29 | if j == D_obs: 30 | mask[offset:min(offset+chunksz, T), :] = False 31 | 32 | # Fit with another LDS 33 | model = MissingDataLDS( 34 | dynamics_distn=Regression( 35 | nu_0=D_latent+3, 36 | S_0=D_latent*np.eye(D_latent), 37 | M_0=np.zeros((D_latent, D_latent)), 38 | K_0=D_latent*np.eye(D_latent)), 39 | emission_distn=DiagonalRegression(D_obs, D_latent, alpha_0=2.0, beta_0=1.0)) 40 | model.add_data(data=data, mask=mask) 41 | 42 | 43 | # Fit the model 44 | N_samples = 500 45 | sigma_obs_smpls = [] 46 | def gibbs_update(model): 47 | model.resample_model() 48 | sigma_obs_smpls.append(model.sigma_obs_flat) 49 | return model.log_likelihood() 50 | 51 | def em_update(model): 52 | model.EM_step() 53 | sigma_obs_smpls.append(model.sigma_obs_flat) 54 | return model.log_likelihood() 55 | 56 | def meanfield_update(model): 57 | model.meanfield_coordinate_descent_step() 58 | sigma_obs_smpls.append(model.emission_distn.mf_beta / model.emission_distn.mf_alpha) 59 | model.resample_from_mf() 60 | return model.log_likelihood() 61 | 62 | def svi_update(model, stepsize, minibatchsize): 63 | # Sample a minibatch 64 | start = np.random.randint(0,T-minibatchsize+1) 65 | minibatch = data[start:start+minibatchsize] 66 | minibatch_mask = mask[start:start+minibatchsize] 67 | prob = minibatchsize/float(T) 68 | model.meanfield_sgdstep(minibatch, prob, stepsize, masks=minibatch_mask) 69 | 70 | sigma_obs_smpls.append(model.emission_distn.mf_beta / model.emission_distn.mf_alpha) 71 | model.resample_from_mf() 72 | return model.log_likelihood(data, mask=mask) 73 | 74 | 75 | # Gibbs 76 | lls = [gibbs_update(model) for _ in progprint_xrange(N_samples)] 77 | 78 | ## EM -- initialized with a few Gibbs iterations 79 | # [model.resample_model() for _ in progprint_xrange(100)] 80 | # lls = [em_update(model) for _ in progprint_xrange(N_samples)] 81 | 82 | ## Mean field 83 | # lls = [meanfield_update(model) for _ in progprint_xrange(N_samples)] 84 | 85 | ## SVI 86 | # delay = 10.0 87 | # forgetting_rate = 0.5 88 | # stepsizes = (np.arange(N_samples) + delay)**(-forgetting_rate) 89 | # minibatchsize = 500 90 | # # [model.resample_model() for _ in progprint_xrange(100)] 91 | # lls = [svi_update(model, stepsizes[itr], minibatchsize) for itr in progprint_xrange(N_samples)] 92 | 93 | # Plot the log likelihood 94 | plt.figure() 95 | plt.plot(lls,'-b') 96 | dummymodel = MissingDataLDS( 97 | dynamics_distn=truemodel.dynamics_distn, 98 | emission_distn=truemodel.emission_distn) 99 | plt.plot([0,N_samples], dummymodel.log_likelihood(data, mask=mask) * np.ones(2), '-k') 100 | plt.xlabel('iteration') 101 | plt.ylabel('log likelihood') 102 | 103 | # Plot the inferred observation noise 104 | plt.figure() 105 | plt.plot(sigma_obs_smpls) 106 | plt.xlabel("iteration") 107 | 
plt.ylabel("sigma_obs") 108 | 109 | # Smooth over missing data 110 | smoothed_obs = model.states_list[0].smooth() 111 | sample_predictive_obs = model.states_list[0].gaussian_states.dot(model.C.T) 112 | 113 | plt.figure() 114 | given_data = data.copy() 115 | given_data[~mask] = np.nan 116 | masked_data = data.copy() 117 | masked_data[mask] = np.nan 118 | ylims = (-1.1*abs(data).max(), 1.1*abs(data).max()) 119 | xlims = (0, min(T,1000)) 120 | 121 | N_subplots = min(D_obs,4) 122 | for i in range(N_subplots): 123 | plt.subplot(N_subplots,1,i+1,aspect="auto") 124 | 125 | plt.plot(given_data[:,i], 'k', label="observed") 126 | plt.plot(masked_data[:,i], ':k', label="masked") 127 | plt.plot(smoothed_obs[:,i], 'b', lw=2, label="smoothed") 128 | 129 | plt.imshow(1-mask[:,i][None,:],cmap="Greys",alpha=0.25,extent=(0,T) + ylims, aspect="auto") 130 | 131 | if i == 0: 132 | plt.legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.5)) 133 | 134 | if i == N_subplots - 1: 135 | plt.xlabel('time index') 136 | 137 | plt.ylabel("$x_%d(t)$" % (i+1)) 138 | plt.ylim(ylims) 139 | plt.xlim(xlims) 140 | 141 | plt.show() 142 | 143 | -------------------------------------------------------------------------------- /examples/poisson_lds.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | 6 | from scipy.stats import poisson 7 | 8 | from pybasicbayes.util.text import progprint_xrange 9 | from pylds.models import DefaultPoissonLDS 10 | 11 | npr.seed(0) 12 | 13 | # Parameters 14 | D_obs = 10 15 | D_latent = 2 16 | T = 5000 17 | 18 | # True LDS Parameters 19 | mu_init = np.array([0.,1.]) 20 | sigma_init = 0.01*np.eye(2) 21 | 22 | A = 0.99*np.array([[np.cos(np.pi/24), -np.sin(np.pi/24)], 23 | [np.sin(np.pi/24), np.cos(np.pi/24)]]) 24 | sigma_states = 0.01*np.eye(2) 25 | C = np.random.randn(D_obs, D_latent) 26 | 27 | # Simulate from a Poisson LDS 28 | truemodel = DefaultPoissonLDS(D_obs, D_latent, A=A, sigma_states=sigma_states, C=C) 29 | data, stateseq = truemodel.generate(T) 30 | 31 | # Fit with a Poisson LDS 32 | model = DefaultPoissonLDS(D_obs, D_latent) 33 | model.add_data(data) 34 | model.states_list[0].gaussian_states *= 0 35 | 36 | N_iters = 50 37 | def em_update(model): 38 | model.EM_step(verbose=True) 39 | ll = model.log_likelihood() 40 | return ll 41 | 42 | lls = [em_update(model) for _ in progprint_xrange(N_iters)] 43 | 44 | # Compute baseline likelihood under Poisson MLE model 45 | rates = data.mean(0) 46 | baseline = 0 47 | for n in range(D_obs): 48 | baseline += poisson(rates[n]).logpmf(data[:,n]).sum() 49 | 50 | # Plot the log likelihood over iterations 51 | plt.plot(np.array(lls) / T / D_obs, '-', lw=2, label="model") 52 | plt.plot([0, N_iters-1], baseline * np.ones(2) / T / D_obs, ':k', lw=2, label="baseline") 53 | plt.xlabel('iteration') 54 | plt.ylabel('log likelihood per datapoint') 55 | plt.legend(loc="lower right") 56 | plt.tight_layout() 57 | 58 | # Plot the smoothed observations 59 | fig = plt.figure(figsize=(6, 6)) 60 | smoothed_obs = model.states_list[0].smooth() 61 | true_smoothed_obs = truemodel.states_list[0].smooth() 62 | 63 | ylims = (-0.1, 1.1) 64 | xlims = (0, min(T,1000)) 65 | 66 | n_subplots = min(D_obs, 6) 67 | n_to_plot = np.arange(n_subplots) 68 | for i,j in enumerate(n_to_plot): 69 | ax = fig.add_subplot(n_subplots,1,i+1) 70 | 71 | # Plot the inferred rate 72 | ax.plot([0], [0], 'ko', lw=2, label="observed data") 73 | 
ax.plot(true_smoothed_obs[:,j], 'k', lw=3, label="true mean") 74 | ax.plot(smoothed_obs[:,j], '--r', lw=2, label="inf mean") 75 | 76 | # Plot spike counts 77 | yl = ax.get_ylim() 78 | given_ts = np.where(data[:, j] == 1)[0] 79 | ax.plot(given_ts, (yl[1] * 1.05) * np.ones_like(given_ts), 'ko', markersize=5) 80 | 81 | if i == 0: 82 | plt.legend(loc="upper center", ncol=4, bbox_to_anchor=(0.5, 1.8)) 83 | if i == n_subplots - 1: 84 | plt.xlabel('time index') 85 | 86 | ax.set_xlim(xlims) 87 | ax.set_ylim(yl[0], yl[1] * 1.1) 88 | ax.set_ylabel("$x_%d(t)$" % (j+1)) 89 | 90 | plt.tight_layout() 91 | plt.show() 92 | -------------------------------------------------------------------------------- /examples/simple_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.random as npr 3 | import matplotlib.pyplot as plt 4 | 5 | # Fancy plotting 6 | try: 7 | import seaborn as sns 8 | sns.set_style("white") 9 | sns.set_context("talk") 10 | 11 | color_names = ["windows blue", 12 | "red", 13 | "amber", 14 | "faded green", 15 | "dusty purple", 16 | "crimson", 17 | "greyish"] 18 | colors = sns.xkcd_palette(color_names) 19 | except: 20 | colors = ['b' ,'r', 'y', 'g'] 21 | 22 | from pybasicbayes.util.text import progprint_xrange 23 | from pylds.models import DefaultLDS 24 | 25 | npr.seed(3) 26 | 27 | # Set parameters 28 | D_obs = 1 29 | D_latent = 2 30 | D_input = 0 31 | T = 2000 32 | 33 | # Simulate from one LDS 34 | true_model = DefaultLDS(D_obs, D_latent, D_input, sigma_obs=np.eye(D_obs)) 35 | inputs = npr.randn(T, D_input) 36 | data, stateseq = true_model.generate(T, inputs=inputs) 37 | 38 | # Fit with another LDS 39 | test_model = DefaultLDS(D_obs, D_latent, D_input) 40 | test_model.add_data(data, inputs=inputs) 41 | 42 | # Run the Gibbs sampler 43 | N_samples = 100 44 | def update(model): 45 | model.resample_model() 46 | return model.log_likelihood() 47 | 48 | lls = [update(test_model) for _ in progprint_xrange(N_samples)] 49 | 50 | # Plot the log likelihoods 51 | plt.figure(figsize=(5,3)) 52 | plt.plot([0, N_samples], true_model.log_likelihood() * np.ones(2), '--k', label="true") 53 | plt.plot(np.arange(N_samples), lls, color=colors[0], label="test") 54 | plt.xlabel('iteration') 55 | plt.ylabel('training likelihood') 56 | plt.legend(loc="lower right") 57 | plt.tight_layout() 58 | plt.savefig("aux/demo_ll.png") 59 | 60 | # Smooth the data 61 | smoothed_data = test_model.smooth(data, inputs) 62 | 63 | plt.figure(figsize=(5,3)) 64 | plt.plot(data, color=colors[0], lw=2, label="observed") 65 | plt.plot(smoothed_data, color=colors[1], lw=1, label="smoothed") 66 | plt.xlabel("Time") 67 | plt.xlim(0, min(T, 500)) 68 | plt.ylabel("Smoothed Data") 69 | plt.ylim(1.2 * np.array(plt.ylim())) 70 | plt.legend(loc="upper center", ncol=2) 71 | plt.tight_layout() 72 | plt.savefig("aux/demo_smooth.png") 73 | plt.show() 74 | -------------------------------------------------------------------------------- /examples/zeroinflated_bernoulli_lds.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import numpy.random as npr 4 | import matplotlib.pyplot as plt 5 | # Fancy plotting 6 | try: 7 | import seaborn as sns 8 | sns.set_style("white") 9 | sns.set_context("talk") 10 | 11 | color_names = ["windows blue", 12 | "red", 13 | "amber", 14 | "faded green", 15 | "dusty purple", 16 | "crimson", 17 | "greyish"] 18 | colors = sns.xkcd_palette(color_names) 19 | except: 
20 | colors = ['b' ,'r', 'y', 'g'] 21 | 22 | 23 | from pybasicbayes.distributions import Regression 24 | from pybasicbayes.util.text import progprint_xrange 25 | from pypolyagamma.distributions import BernoulliRegression 26 | from pylds.models import ZeroInflatedCountLDS, LDS 27 | 28 | npr.seed(0) 29 | 30 | # Parameters 31 | rho = 0.5 # Sparsity (1-probability of deterministic zero) 32 | D_obs = 10 33 | D_latent = 2 34 | D_input = 0 35 | T = 2000 36 | 37 | ### True LDS Parameters 38 | mu_init = np.array([0.,1.]) 39 | sigma_init = 0.01*np.eye(2) 40 | 41 | A = 0.99*np.array([[np.cos(np.pi/24), -np.sin(np.pi/24)], 42 | [np.sin(np.pi/24), np.cos(np.pi/24)]]) 43 | B = np.ones((D_latent, D_input)) 44 | sigma_states = 0.01*np.eye(2) 45 | 46 | C = np.random.randn(D_obs, D_latent) 47 | D = np.zeros((D_obs, D_input)) 48 | b = -2.0 * np.ones((D_obs, 1)) 49 | 50 | ### Simulate from a Bernoulli LDS 51 | truemodel = ZeroInflatedCountLDS( 52 | rho=rho, 53 | dynamics_distn=Regression(A=np.hstack((A,B)), sigma=sigma_states), 54 | emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input, 55 | A=np.hstack((C,D)), b=b)) 56 | 57 | inputs = np.random.randn(T, D_input) 58 | data, stateseq = truemodel.generate(T, inputs=inputs) 59 | dense_data = data.toarray() 60 | true_rate = rho * truemodel.emission_distn.mean(np.hstack((stateseq, inputs))) 61 | 62 | ### First fit a zero inflated model 63 | zi_model = ZeroInflatedCountLDS( 64 | rho=rho, 65 | dynamics_distn=Regression(nu_0=D_latent + 2, 66 | S_0=D_latent * np.eye(D_latent), 67 | M_0=np.zeros((D_latent, D_latent + D_input)), 68 | K_0=(D_latent + D_input) * np.eye(D_latent + D_input)), 69 | emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input)) 70 | zi_model.add_data(data, inputs=inputs) 71 | 72 | # Run a Gibbs sampler 73 | N_samples = 500 74 | def gibbs_update(model): 75 | model.resample_model() 76 | return model.log_likelihood(), \ 77 | model.states_list[0].gaussian_states, \ 78 | model.states_list[0].smooth() 79 | 80 | zi_lls, zi_x_smpls, zi_smoothed_obss = \ 81 | zip(*[gibbs_update(zi_model) for _ in progprint_xrange(N_samples)]) 82 | 83 | ### Now fit a standard model 84 | std_model = LDS( 85 | dynamics_distn=Regression(nu_0=D_latent + 2, 86 | S_0=D_latent * np.eye(D_latent), 87 | M_0=np.zeros((D_latent, D_latent + D_input)), 88 | K_0=(D_latent + D_input) * np.eye(D_latent + D_input)), 89 | emission_distn=BernoulliRegression(D_out=D_obs, D_in=D_latent + D_input)) 90 | std_model.add_data(dense_data, inputs=inputs) 91 | 92 | 93 | # Run a Gibbs sampler 94 | std_lls, std_x_smpls, std_smoothed_obss = \ 95 | zip(*[gibbs_update(std_model) for _ in progprint_xrange(N_samples)]) 96 | 97 | # Plot the log likelihood over iterations 98 | # plt.figure(figsize=(10,6)) 99 | # plt.plot(lls,'-b') 100 | # plt.plot([0,N_samples], truemodel.log_likelihood() * np.ones(2), '-k') 101 | # plt.xlabel('iteration') 102 | # plt.ylabel('log likelihood') 103 | 104 | # Plot the smoothed observations 105 | fig = plt.figure(figsize=(10,10)) 106 | N_subplots = min(D_obs, 6) 107 | 108 | ylims = (-0.1, 1.1) 109 | xlims = (0, min(T,1000)) 110 | 111 | n_to_plot = np.arange(min(N_subplots, D_obs)) 112 | for i,j in enumerate(n_to_plot): 113 | ax = fig.add_subplot(N_subplots,1,i+1) 114 | # Plot spike counts 115 | given_ts = np.where(dense_data[:,j]==1)[0] 116 | ax.plot(given_ts, np.ones_like(given_ts), 'ko', markersize=5) 117 | 118 | # Plot the inferred rate 119 | ax.plot([-10], [0], 'ko', lw=2, label="obs.") 120 | ax.plot(zi_smoothed_obss[-1][:, j], '-', 
color=colors[0], label="zero inflated") 121 | ax.plot(std_smoothed_obss[-1][:, j], '-', color=colors[1], label="standard") 122 | ax.plot(true_rate[:, j], '--k', lw=2, label="true rate") 123 | 124 | if i == 0: 125 | plt.legend(loc="upper center", ncol=4, bbox_to_anchor=(0.5, 1.5)) 126 | 127 | if i == N_subplots - 1: 128 | plt.xlabel('time index') 129 | else: 130 | ax.set_xticklabels([]) 131 | 132 | ax.set_xlim(xlims) 133 | ax.set_ylim(0, 1.1) 134 | ax.set_ylabel("$x_%d(t)$" % (j+1)) 135 | 136 | plt.savefig("aux/zeroinflation.png") 137 | plt.show() 138 | 139 | -------------------------------------------------------------------------------- /notes/info_form.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattjj/pylds/e946bfa5aa76e8f8284614561a0f40ffd5d868fb/notes/info_form.pdf -------------------------------------------------------------------------------- /notes/info_form.tex: -------------------------------------------------------------------------------- 1 | \documentclass{article} 2 | 3 | \usepackage[numbers, compress]{natbib} 4 | 5 | \usepackage[utf8]{inputenc} % allow utf-8 input 6 | \usepackage[T1]{fontenc} % use 8-bit T1 fonts 7 | \usepackage{hyperref} % hyperlinks 8 | \usepackage{url} % simple URL typesetting 9 | \usepackage{booktabs} % professional-quality tables 10 | \usepackage{amsfonts} % blackboard math symbols 11 | \usepackage{nicefrac} % compact symbols for 1/2, etc. 12 | \usepackage{microtype} % microtypography 13 | 14 | \usepackage{amsthm,amsmath,amssymb} 15 | \usepackage{macros} 16 | \usepackage{subcaption} 17 | \usepackage[textfont=small, labelfont=small]{caption} 18 | \usepackage{graphicx} 19 | \DeclareGraphicsExtensions{.pdf,.png,.jpg,.eps} 20 | 21 | \usepackage{algorithm} 22 | \usepackage{algorithmic} 23 | 24 | \usepackage[color=yellow]{todonotes} 25 | \usepackage{booktabs} 26 | \usepackage[inline]{enumitem} 27 | \usepackage{verbatim} 28 | 29 | \usepackage[margin=1in]{geometry} 30 | 31 | \usepackage{setspace} 32 | 33 | 34 | \title{Information form filtering and smoothing for Gaussian linear dynamical systems} 35 | 36 | \author{Scott W. Linderman 37 | \and 38 | Matthew J. Johnson 39 | } 40 | 41 | 42 | \begin{document} 43 | 44 | \singlespacing 45 | \maketitle 46 | \onehalfspacing 47 | 48 | The \emph{information form} of the Gaussian density of~$x \in \reals^n$ is defined as, 49 | \begin{align} 50 | p(x \given J, h) 51 | &= \exp \left \{ -\frac{1}{2} x^\trans J x + h^\trans x - \log Z \right \}, 52 | \end{align} 53 | where 54 | \begin{align} 55 | \log Z = \frac{1}{2} h^\trans J^{-1} h - \frac{1}{2}\log |J| +\frac{n}{2} \log 2\pi. 56 | \end{align} 57 | The standard formulation is recovered by the transformations, $\Sigma = J^{-1}$, and 58 | $\mu = J^{-1} h$. The advantage of working in the information form is that 59 | it corresponds to the natural parameterization of the Gaussian distribution, and 60 | mean field variational inference is considerably easier with this form. 61 | 62 | In order to perform Kalman filtering and smoothing, we must be able to perform 63 | two operations: \emph{conditioning} and \emph{marginalization}. 64 | 65 | \paragraph{Conditioning} 66 | If, 67 | \begin{align} 68 | p(x) &= \distNormal(x \given J, h) \\ 69 | p(y \given x) &\propto \distNormal(x \given J_{\sf obs}, h_{\sf obs}) 70 | \end{align} 71 | then, 72 | \begin{align} 73 | p(x \given y) &= \distNormal(x \given J + J_{\sf obs}, h + h_{\sf obs}). 
74 | \end{align} 75 | 76 | \paragraph{Marginalization} 77 | If, 78 | \begin{align} 79 | \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} 80 | &\sim 81 | \distNormal \left( 82 | \begin{bmatrix} 83 | J_{11} & J_{12} \\ 84 | J_{12}^\trans & J_{22} 85 | \end{bmatrix}, 86 | \begin{bmatrix} h_1 \\ h_2 \end{bmatrix} 87 | \right), 88 | \end{align} 89 | then, 90 | \begin{align} 91 | x_2 &\sim \distNormal( 92 | J_{22} - J_{21}J_{11}^{-1} J_{21}^\trans, \; 93 | h_2 - J_{21} J_{11}^{-1} h_1) 94 | \end{align} 95 | 96 | \begin{proof} 97 | We use the following integral identity, 98 | \begin{align} 99 | \int \exp \left \{ - \frac{1}{2} x^\trans J x + h^\trans x\right\} \mathrm{d} x 100 | &= \exp \left\{ \frac{1}{2} h^\trans J^{-1} h - \frac{1}{2} \log |J| + \frac{n}{2} \log 2 \pi \right\}. 101 | \end{align} 102 | For~$x = [x_1, x_2]^\trans$, we marginalize out~$x_1$ via, 103 | \begin{align} 104 | \int \exp \left \{ - \frac{1}{2} 105 | \begin{bmatrix} x_1^\trans & x_2^\trans \end{bmatrix} 106 | \begin{bmatrix} J_{11} & J_{12} \\ J_{21} & J_{22} \end{bmatrix} 107 | \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} 108 | + \begin{bmatrix} h_1^\trans & h_2^\trans \end{bmatrix} 109 | \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} 110 | - \log Z 111 | \right \} 112 | \mathrm{d} x_1. 113 | \end{align} 114 | Rearranging terms, we have, 115 | \begin{align} 116 | &\exp \left \{ -\frac{1}{2} x_2^\trans J_{22} x_2 + h_2^\trans x_2 - \log Z \right \} 117 | \int \exp \left \{- \frac{1}{2} x_1^\trans J_{11} x_1 + (h_1 - J_{12}x_2)^\trans x_1 \right \} \mathrm{d}x_1 \\ 118 | &= \exp \left \{ -\frac{1}{2} x_2^\trans J_{22} x_2 + h_2^\trans x_2 - \log Z \right \} 119 | \exp \left \{ \frac{1}{2} (h_1 - J_{12}x_2)^\trans J_{11}^{-1} (h_1 - J_{12} x_2) - \frac{1}{2} \log |J_{11}| + \frac{n}{2} \log 2 \pi \right \} \\ 120 | &= \exp \left \{ -\frac{1}{2} x_2^\trans (J_{22} - J_{21}J_{11}^{-1} J_{12}) x_2 121 | + (h_2 - J_{21} J_{11}^{-1} h_1)^\trans x_2 - \log Z \right \} \nonumber \\ 122 | &\qquad \qquad \times 123 | \exp \left \{ \frac{1}{2} h_1^\trans J_{11}^{-1} h_1 - \frac{1}{2} \log |J_{11}| + \frac{n}{2} \log 2 \pi \right \}. 124 | \end{align} 125 | We recognize this as a Gaussian potential on~$x_2$ of the form, 126 | \begin{align} 127 | p(x_2) &= \exp \left \{-\frac{1}{2} x_2^\trans \widetilde{J}_2 x_2 + \widetilde{h}_2^\trans x_2 - \log \widetilde{Z}_2 \right \} \\ 128 | \widetilde{J}_2 &= J_{22} - J_{21}J_{11}^{-1} J_{12} \\ 129 | \widetilde{h}_2 &= h_2 - J_{21} J_{11}^{-1} h_1 \\ 130 | \log \widetilde{Z}_2 &= \log Z -\frac{1}{2} h_1^\trans J_{11}^{-1} h_1 + \frac{1}{2} \log |J_{11}| - \frac{n}{2} \log 2 \pi. 131 | \end{align} 132 | \end{proof} 133 | 134 | \section*{Filtering, Sampling, and Smoothing} 135 | By interleaving these two steps we can filter, sample, and smooth the latent states 136 | in a linear dynamical system. Take the model, 137 | \begin{align} 138 | x_1 &\sim \distNormal(\mu_1, Q_1) \\ 139 | x_{t+1} &\sim \distNormal(A_t x_t + B_t u_t, Q_t) \\ 140 | y_t &\sim \distNormal(C_t x_t + D_t u_t, R_t). 141 | \end{align} 142 | In information form, the initial distribution is, 143 | \begin{align} 144 | % initi 145 | x_1 &\sim \distNormal(J=Q_1^{-1}, h = Q_1^{-1} \mu_1). 
146 | \end{align} 147 | The dynamics are given by, 148 | \begin{align} 149 | % dynamics 150 | p(x_{t+1} \given x_t) & \propto 151 | \distNormal \left( 152 | \begin{bmatrix} x_t \\ x_{t+1} \end{bmatrix} 153 | \, \bigg| \, 154 | \begin{bmatrix} 155 | J_{11} & J_{12} \\ 156 | J_{12}^\trans & J_{22} 157 | \end{bmatrix}, 158 | \begin{bmatrix} h_{1} \\ h_2 \end{bmatrix} 159 | \right), 160 | \end{align} 161 | with, 162 | \begin{align} 163 | J_{11} &= A_t^\trans Q_t^{-1} A_t, \quad 164 | J_{12} = -A_t^\trans Q_t^{-1}, \quad 165 | J_{22} = Q_t^{-1}, \quad 166 | h_1 = -u_t^\trans B_t^\trans Q_t^{-1} A_t, \quad 167 | h_2 = u_t^\trans B_t^\trans Q_t^{-1}. 168 | \end{align} 169 | Finally, the observations are given by, 170 | \begin{align} 171 | p(y_t \given x_t) 172 | &\propto \distNormal(x_t \given J_{\sf obs}, h_{\sf obs}) 173 | \end{align} 174 | with 175 | \begin{align} 176 | J_{\sf obs} = C_t^\trans R_t^{-1} C_t, \quad 177 | h_{\sf obs} = (y_t - D_t u_t)^\trans R_t^{-1} C_t. 178 | \end{align} 179 | 180 | \subsection*{Filtering} 181 | We seek the conditional distribution,~$p(x_t \given y_{1:t})$, which 182 | will be Gaussian. We begin with the initial distribution, 183 | \begin{align} 184 | p(x_1) &= \distNormal(x_1 \given J_{1 | 0}, h_{1 | 0}). 185 | \end{align} 186 | Assume, inductively, that~$x_t \given y_{1:t-1} \sim \distNormal(J_{t|t-1}, h_{t|t-1})$. Conditioning on the~$t$-th observation yields, 187 | 188 | \begin{align} 189 | p(x_t \given y_{1:t}) &= \distNormal(x_t \given J_{t|t}, h_{t|t}), \\ 190 | J_{t|t} &= J_{t|t-1} + J_{\sf obs} \\ 191 | h_{t|t} &= h_{t|t-1} + h_{\sf obs}. 192 | \end{align} 193 | Then, we predict the next latent state by writing the joint distribution of~$x_t$ and~$x_{t+1}$ and marginalizing out~$x_t$. 194 | \begin{align} 195 | p(x_{t+1} \given y_{1:t}) &= \int p(x_t \given y_{1:t}) \, p(x_{t+1} \given x_t) \, \mathrm{d} x_t \\ 196 | &= \distNormal(x_{t+1} \given J_{t+1|t}, h_{t+1|t}) \\ 197 | J_{t+1|t} &= J_{22} - J_{21} (J_{t|t} + J_{11})^{-1}J_{21}^\trans \\ 198 | h_{t+1|t} &= h_2 - J_{21} (J_{t|t} + J_{11})^{-1} (h_{t|t} + h_1). 199 | \end{align} 200 | This completes one iteration and provides the input to the next. To start the recursion, we initialize, 201 | \begin{align} 202 | J_{1|0} = \Sigma_{\sf init}^{-1}, \quad 203 | h_{1|0} = \Sigma_{\sf init}^{-1} \, \mu_{\sf init}. 204 | \end{align} 205 | 206 | \subsection*{Marginal Likelihood} 207 | This filtering algorithm corresponds to message passing 208 | in a chain-structured Gaussian graphical model. To compute 209 | the marginal likelihood,~$p(y_{1:T})$, we observe that it 210 | is the normalization constant of the graphical model, 211 | \begin{align} 212 | p(x_{1:T}) &= 213 | \frac{1}{Z} \prod_{t=1}^T \psi(x_t) \prod_{t=1}^{T-1} \psi(x_t, x_{t+1}), & & \\ 214 | \psi(x_1) &= p(x_1) \, p(y_1 \given x_1) & & \\ 215 | \psi(x_t) &= p(y_t \given x_t) & \text{for } t&=2 \ldots T \\ 216 | \psi(x_t, x_{t+1}) &= p(x_{t+1} \given x_t) & \text{for } t&=1 \ldots T-1. 217 | \end{align} 218 | To compute the normalization constant~$Z$, we recursively eliminate nodes 219 | via message passing. 220 | 221 | \paragraph{Base case.} Let us work backward from the final step in which 222 | we are left with a graph with a single, unnormalized Gaussian potential.
223 | \begin{align} 224 | p(x_T) &= \frac{1}{Z} \psi(x_T) 225 | = \frac{1}{Z} \exp \left \{-\frac{1}{2} x_T^\trans J_{T} x_T + h_{T}^\trans x_T - \log Z_{T} \right \} 226 | \end{align} 227 | To compute the normalizing constant,~$Z$, we integrate over~$x_T$ and use 228 | the normalizing constant for Gaussian distributions: 229 | \begin{align} 230 | Z &= \int \exp \left \{-\frac{1}{2} x_T^\trans J_{T} x_T + h_{T}^\trans x_T - \log Z_{T} \right \} \mathrm{d} x_T \\ 231 | &= \exp \left \{ -\log Z_{T} \right \} 232 | \int \exp \left \{-\frac{1}{2} x_T^\trans J_{T} x_T + h_{T}^\trans x_T \right \} \mathrm{d} x_T \\ 233 | &= \exp \left \{ -\log Z_{T} +\frac{1}{2} h_{T}^\trans J_{T}^{-1} h_{T} - \frac{1}{2} \log |J_{T}| +\frac{n}{2} \log 2 \pi \right \} 234 | \end{align} 235 | 236 | \paragraph{Condition Step.} In the second to last step, we have two potentials, 237 | one from the dynamics induced by the preceding step and one from the observation: 238 | \begin{align} 239 | p(x_T) &= \frac{1}{Z} \psi_{\sf dyn}(x_T) \, \psi_{\sf obs}(x_T) \\ 240 | &= \exp \left \{-\frac{1}{2} x_T^\trans J_{T|T-1} x_T + h_{T|T-1}^\trans x_T - \log Z_{T|T-1} \right \} 241 | \exp \left \{-\frac{1}{2} x_T^\trans J_{\sf obs} x_T + h_{\sf obs}^\trans x_T - \log Z_{\sf obs} \right \} 242 | \end{align} 243 | We reduce this to the base case by simply summing the natural parameters and 244 | log normalizers, 245 | \begin{align} 246 | J_{T} &= J_{T|T-1} + J_{\sf obs} \\ 247 | h_{T} &= h_{T|T-1} + h_{\sf obs} \\ 248 | \log Z_{T} &= \log Z_{T|T-1} + \log Z_{\sf obs}. 249 | \end{align} 250 | 251 | \paragraph{Predict Step.} Now consider a model with two latent states. 252 | \begin{align} 253 | p(x_{T-1}, x_T) &= \frac{1}{Z} \psi(x_{T-1}) \, \psi_{\sf dyn}(x_{T-1}, x_T) \, \psi_{\sf obs}(x_T) 254 | \end{align} 255 | We will eliminate the variable~$x_{T-1}$ by marginalizing, or integrating it out. 
256 | \begin{align} 257 | p(x_T) &= \frac{1}{Z} \psi_{\sf obs}(x_T) \int \psi(x_{T-1}) \, \psi_{\sf dyn}(x_{T-1}, x_T) \mathrm{d}x_{T-1} 258 | \end{align} 259 | The integrand is of the form, 260 | \begin{align} 261 | &\psi(x_{T-1}) \, \psi_{\sf dyn}(x_{T-1}, x_T) 262 | \\ 263 | &= \exp \left \{ -\frac{1}{2} x_{T-1}^\trans J_{T-1} x_{T-1} + h_{T-1}^\trans x_{T-1} - \log Z_{T-1} \right \} \\ 264 | &\qquad \times 265 | \exp \left \{ - \frac{1}{2} 266 | \begin{bmatrix} x_{T-1}^\trans & x_T^\trans \end{bmatrix} 267 | \begin{bmatrix} J_{11} & J_{12} \\ J_{21} & J_{22} \end{bmatrix} 268 | \begin{bmatrix} x_{T-1} \\ x_T \end{bmatrix} 269 | + \begin{bmatrix} h_1^\trans & h_2^\trans \end{bmatrix} 270 | \begin{bmatrix} x_{T-1} \\ x_T \end{bmatrix} 271 | - \log Z_{\sf dyn} 272 | \right \} 273 | \\ 274 | &= 275 | \exp \left \{ - \frac{1}{2} 276 | \begin{bmatrix} x_{T-1}^\trans & x_T^\trans \end{bmatrix} 277 | \begin{bmatrix} J_{11} + J_{T-1} & J_{12} \\ J_{21} & J_{22} \end{bmatrix} 278 | \begin{bmatrix} x_{T-1} \\ x_T \end{bmatrix} 279 | + \begin{bmatrix} (h_1 + h_{T-1})^\trans & h_2^\trans \end{bmatrix} 280 | \begin{bmatrix} x_{T-1} \\ x_T \end{bmatrix} 281 | - (\log Z_{T-1} + \log Z_{\sf dyn}) 282 | \right \} 283 | \end{align} 284 | Appealing to the marginalization propositions above, this implies, 285 | \begin{align} 286 | \int \psi(x_{T-1}) \, \psi_{\sf dyn}(x_{T-1}, x_T) \, \mathrm{d} x_{T-1} 287 | &= \exp \left \{-\frac{1}{2} x_T^\trans J_{T|T-1} x_T + h_{T|T-1}^\trans x_T 288 | - \log Z_{T|T-1} \right \}, 289 | \end{align} 290 | where 291 | \begin{multline} 292 | \log Z_{T|T-1} = \\ 293 | \log Z_{T-1} + \log Z_{\sf dyn} 294 | -\frac{1}{2} (h_1 + h_{T-1}^\trans) (J_{11} + J_{T-1})^{-1} (h_1 + h_{T-1}^\trans) 295 | + \frac{1}{2} \log |J_{11} + J_{T-1}| - \frac{n}{2} \log 2 \pi. 296 | \end{multline} 297 | Thus, the log normalizer passed into the condition step is an accumulation of 298 | log normalizers from previous time steps ($\log Z_{T-1}$), the log normalizer 299 | of the dynamics potential ($\log Z_{\sf dyn}$), and a term that comes from 300 | marginalizing out the local variable~$x_{T-1}$. Once we have computed 301 | ~$\log Z_{T|T-1}$, it is passed into the condition step and then into the final 302 | integration to compute the marginal likelihood. 303 | 304 | This process of predicting and conditioning is recursively applied, marginalizing 305 | out variables one at a time, starting with~$x_1$ and ending with~$x_T$. At the 306 | end of this procedure, we are left with the marginal likelihood of the observations. 
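\paragraph{A NumPy sketch.} The predict and condition recursions above translate directly into code. The following sketch assumes time-invariant parameters and favors readability over speed and numerical care (it forms explicit matrix inverses; the optimized implementation lives in \texttt{pylds/lds\_info\_messages.pyx}). It accumulates the log normalizers exactly as described and returns~$\log p(y_{1:T} \given u_{1:T})$.
\begin{verbatim}
import numpy as np

def info_form_log_likelihood(mu_init, sigma_init, A, B, Q, C, D, R, us, ys):
    n, T = A.shape[0], ys.shape[0]
    Qi, Ri = np.linalg.inv(Q), np.linalg.inv(R)

    # Pairwise dynamics potential (J11, J12, J22 are constant over time here)
    J11, J12, J22 = A.T @ Qi @ A, -A.T @ Qi, Qi
    # Input-independent part of log Z_dyn
    logZ_dyn_const = 0.5 * np.linalg.slogdet(Q)[1] + 0.5 * n * np.log(2 * np.pi)

    # Initial message J_{1|0}, h_{1|0} and its log normalizer
    Si = np.linalg.inv(sigma_init)
    Jp, hp = Si, Si @ mu_init
    logZ = 0.5 * mu_init @ Si @ mu_init \
           + 0.5 * np.linalg.slogdet(sigma_init)[1] + 0.5 * n * np.log(2 * np.pi)

    for t in range(T):
        # Condition step: add the observation potential for y_t
        r = ys[t] - D @ us[t]
        Jf = Jp + C.T @ Ri @ C
        hf = hp + C.T @ Ri @ r
        logZ += 0.5 * r @ Ri @ r + 0.5 * np.linalg.slogdet(R)[1] \
                + 0.5 * len(r) * np.log(2 * np.pi)

        if t == T - 1:
            break

        # Predict step: marginalize x_t out of the pairwise potential
        h1, h2 = -A.T @ Qi @ B @ us[t], Qi @ B @ us[t]
        L = np.linalg.inv(J11 + Jf)
        Jp = J22 - J12.T @ L @ J12
        hp = h2 - J12.T @ L @ (h1 + hf)
        logZ += logZ_dyn_const + 0.5 * us[t] @ B.T @ Qi @ B @ us[t] \
                - 0.5 * (h1 + hf) @ L @ (h1 + hf) \
                + 0.5 * np.linalg.slogdet(J11 + Jf)[1] - 0.5 * n * np.log(2 * np.pi)

    # Final integration over x_T gives the marginal likelihood
    return -logZ + 0.5 * hf @ np.linalg.inv(Jf) @ hf \
           - 0.5 * np.linalg.slogdet(Jf)[1] + 0.5 * n * np.log(2 * np.pi)
\end{verbatim}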
307 | 308 | \subsection*{Standard Form Marginal Likelihood} 309 | The marginal likelihood of the observed data is given by, 310 | \begin{align} 311 | \log p(y_{1:T} \given u_{1:T}) 312 | &= \sum_{t=1}^T \log p(y_t \given y_{1:t-1}, u_{1:t}) \\ 313 | &= \sum_{t=1}^T \log \distNormal(y_t \given C \mu_{t | t-1} + D u_t, S_t), 314 | \end{align} 315 | where 316 | \begin{align} 317 | S_t &= C \Sigma_{t | t-1} C^\trans + R_t \\ 318 | \mu_{t | t-1} &= J_{t | t-1}^{-1} h_{t|t-1} \\ 319 | \Sigma_{t | t-1} &= J_{t|t-1}^{-1} 320 | \end{align} 321 | Expanding the above, we have, 322 | \begin{align} 323 | \log p(y_{1:T} \given u_{1:T}) 324 | &= \sum_{t=1}^T \left[ -\frac{N}{2} \log 2 \pi - \frac{1}{2} \log \big| S_t \big| 325 | - \frac{1}{2} (y_t - C \mu_{t | t-1} - D u_t)^\trans S_t^{-1} 326 | (y_t - C \mu_{t | t-1} - D u_t) \right] 327 | \end{align} 328 | 329 | \subsection*{Backward Sampling} 330 | Having computed~$J_{t|t}$ and~$h_{t|t}$, we then proceed backward in time to draw a joint sample of the latent states. 331 | Given~$J_{t|t}$,~$h_{t|t}$, and~$x_{t+1}$, we have, 332 | \begin{align} 333 | p(x_{t} \given y_{1:t}, x_{t+1}) &\propto p(x_t \given y_{1:t}) \, p(x_{t+1} \given x_t) \\ 334 | &\propto \distNormal(x_t \given J_{t|t}, h_{t|t}) \; 335 | \distNormal(x_t \given J_{11}, \, h_1 - J_{12} x_{t+1} ) \\ 336 | &\propto \distNormal(x_t \given J_{t|t} + J_{11}, \; h_{t|t} + h_1 - J_{12} x_{t+1} ) 337 | \end{align} 338 | We sample~$x_t$ from this conditional, then use it to sample~$x_{t-1}$, and repeat until we reach~$x_1$. 339 | 340 | \subsection*{Rauch-Tung-Striebel Smoothing} 341 | Next we seek the conditional distribution given all the data,~$p(x_t \given y_{1:T})$. 342 | This will again be Gaussian, and we will call its parameters~$J_{t|T}$ and~$h_{t|T}$. 343 | Assume, inductively, that we have computed~$J_{t+1|T}$ and~$h_{t+1|T}$. We show how to 344 | compute the parameters for time~$t$. 345 | 346 | From the Markov properties of the model and the conditional distribution derived above, we have, 347 | \begin{align} 348 | p(x_t \given x_{t+1}, y_{1:T}) 349 | &= \distNormal(x_t \given J_{t|t} + J_{11}, \; h_{t|t} + h_1 - J_{12} x_{t+1}).
350 | \end{align} 351 | Expanding, taking care to note that~$x_{t+1}$ appears in the normalizing constant, yields, 352 | \begin{multline} 353 | p(x_t \given x_{t+1}, y_{1:T}) 354 | = \exp \bigg \{-\frac{1}{2} x_t^\trans (J_{t|t} + J_{11}) x_t + (h_{t|t} + h_1)^\trans x_t - x_{t+1}^\trans J_{12} x_t 355 | \\ 356 | -\frac{1}{2} x_{t+1}^\trans J_{12}^\trans (J_{t|t} + J_{11})^{-1} J_{12} x_{t+1} 357 | +(h_{t|t} + h_1)^\trans (J_{t|t} + J_{11})^{-1} J_{12} x_{t+1} \\ 358 | - \frac{1}{2} (h_{t|t} + h_1)^\trans (J_{t|t} + J_{11})^{-1} (h_{t|t} + h_1) 359 | \bigg \} 360 | \end{multline} 361 | 362 | Now consider the joint distribution of~$x_t$ and~$x_{t+1}$ given all the data, 363 | \begin{align} 364 | p(x_t, x_{t+1} \given y_{1:T}) 365 | &= p(x_t \given x_{t+1}, y_{1:T}) p(x_{t+1} \given y_{1:T}) \\ 366 | &\propto \distNormal \left( 367 | \begin{bmatrix} x_t \\ x_{t+1} \end{bmatrix} 368 | \, \bigg| \, 369 | \begin{bmatrix} 370 | \widetilde{J}_{11} & \widetilde{J}_{12} \\ 371 | \widetilde{J}_{12}^\trans & \widetilde{J}_{22} 372 | \end{bmatrix}, 373 | \begin{bmatrix} 374 | \widetilde{h}_1 \\ 375 | \widetilde{h}_2 376 | \end{bmatrix} 377 | \right), 378 | \end{align} 379 | with, 380 | \begin{align} 381 | \widetilde{J}_{11} &= J_{t|t} + J_{11} \\ 382 | \widetilde{J}_{12} &= J_{12} \\ 383 | \widetilde{J}_{22} &= J_{t+1|T} + J_{12}^\trans (J_{t|t} + J_{11})^{-1} J_{12} \\ 384 | \widetilde{h}_1 &= h_{t|t} + h_1 \\ 385 | \widetilde{h}_2 &= h_{t+1|T} + (h_{t|t} + h_1)^\trans (J_{t|t} + J_{11})^{-1} J_{12}. 386 | \end{align} 387 | Recall that, 388 | \begin{align} 389 | J_{t+1|t} &= J_{22} - J_{21} (J_{t|t} + J_{11})^{-1}J_{21}^\trans \\ 390 | h_{t+1|t} &= h_2 - J_{21} (J_{t|t} + J_{11})^{-1} (h_{t|t} + h_1). 391 | \end{align} 392 | Thus, 393 | \begin{align} 394 | \widetilde{J}_{22} &= J_{t+1|T} - J_{t+1|t} + J_{22} \\ 395 | \widetilde{h}_{2} &= h_{t+1|T} - h_{t+1|t} + h_2. 396 | \end{align} 397 | 398 | 399 | Finally, marginalize, 400 | \begin{align} 401 | p(x_t \given y_{1:T}) 402 | &= \distNormal(x_t \given 403 | \widetilde{J}_{11} - \widetilde{J}_{12} \widetilde{J}_{22}^{-1} \widetilde{J}_{12}^\trans, \; 404 | \widetilde{h}_{1} - \widetilde{J}_{12} \widetilde{J}_{22}^{-1} \widetilde{h}_2) \\ 405 | &= \distNormal(x_t \given J_{t|T}, h_{t|T}). 406 | \end{align} 407 | Substituting the simplified forms above yields, 408 | \begin{align} 409 | J_{t|T} &= J_{t|t} + J_{11} - J_{12} (J_{t+1|T} - J_{t+1|t} + J_{22})^{-1} J_{12}^\trans \\ 410 | h_{t|T} &= h_{t|t} + h_1 - J_{12} (J_{t+1|T} - J_{t+1|t} + J_{22})^{-1} (h_{t+1|T} - h_{t+1|t} +h_2). 
411 | \end{align} 412 | 413 | \appendix 414 | 415 | \subsection*{Working out marginal likelihood in a simple example} 416 | \begin{align} 417 | p(x) &= \frac{1}{Z} \distNormal(x \given 0, 1) \, \distNormal(1 \given x, \sigma^2) \\ 418 | &= \frac{1}{Z} 419 | \exp \left \{ -\frac{1}{2} x^2 -\frac{1}{2} \log 2 \pi \right \} 420 | \exp \left \{ -\frac{1}{2} \frac{(x-1)^2}{\sigma^2} -\frac{1}{2} \log 2 \pi -\frac{1}{2} \log \sigma^2 \right \} \\ 421 | &= \frac{1}{Z} \exp \left \{ -\frac{1}{2} (x^2 + \frac{x^2}{\sigma^2}) -\frac{1}{2} \frac{-2x}{\sigma^2} -\frac{1}{2}\frac{1}{\sigma^2} - \log 2 \pi -\frac{1}{2} \log \sigma^2 \right \} \\ 422 | &= \frac{1}{Z} \exp \left \{ -\frac{1}{2} (1 + \frac{1}{\sigma^2}) x^2 + \frac{x}{\sigma^2} -\frac{1}{2\sigma^2} - \log 2 \pi -\frac{1}{2} \log \sigma^2 \right \} 423 | \end{align} 424 | This implies that the normalization constant is, 425 | \begin{align} 426 | Z &= \exp \left \{ - \log 2 \pi -\frac{1}{2\sigma^2} + \frac{1}{2} (\frac{1}{\sigma^2}(1+\frac{1}{\sigma^2})^{-1} \frac{1}{\sigma^2}) - \frac{1}{2} \log |1+\frac{1}{\sigma^2}| + \frac{1}{2} \log 2 \pi -\frac{1}{2} \log \sigma^2\right \} \\ 427 | &= \exp \left \{ -\frac{1}{2} \log 2 \pi - \frac{1}{2\sigma^2} + \frac{1}{2\sigma^2} \frac{1}{1+\sigma^2} - \frac{1}{2} \log(1+\sigma^2) + \frac{1}{2}\log \sigma^2 -\frac{1}{2} \log \sigma^2\right \}\\ 428 | &= \exp \left \{ -\frac{1}{2} \log 2 \pi - \frac{1}{2} \frac{1}{1+\sigma^2} - \frac{1}{2} \log(1+\sigma^2) \right \} 429 | \end{align} 430 | From the standard marginal likelihood, we know that this is like, 431 | \begin{align} 432 | \log p(y=1) &= \log \distNormal(1 \given 0, 1 + \sigma^2) \\ 433 | &= -\frac{1}{2} \log 2 \pi -\frac{1}{2} \log (1+\sigma^2) - \frac{1}{2} \frac{1^2}{1+\sigma^2} \\ 434 | % &= -\frac{1}{2} \log 2 \pi -\frac{1}{2} \log 2 - \frac{1}{4} 435 | \end{align} 436 | Thankfully, all checks out! 437 | 438 | \subsection*{Working out marginal likelihood in an input example} 439 | \begin{align} 440 | p(x) &= \frac{1}{Z} \distNormal(x \given 0, 1) \, \distNormal(1 \given x + d, 1) \\ 441 | &= \frac{1}{Z} 442 | \exp \left \{ -\frac{1}{2} x^2 -\frac{1}{2} \log 2 \pi \right \} 443 | \exp \left \{ -\frac{1}{2} (x+d-1)^2 -\frac{1}{2} \log 2 \pi \right \} \\ 444 | &= \frac{1}{Z} \exp \left \{ -\frac{1}{2} (x^2 + x^2) -\frac{1}{2} 2x(d-1) -\frac{1}{2}(d-1)^2 - \log 2 \pi \right \} \\ 445 | &= \frac{1}{Z} \exp \left \{ -\frac{1}{2} 2x^2 + x(1-d) -\frac{1}{2}(d-1)^2 - \log 2 \pi \right \} 446 | \end{align} 447 | This implies that the normalization constant is, 448 | \begin{align} 449 | Z &= \exp \left \{ - \log 2 \pi -\frac{1}{2}(d-1)^2 + \frac{1}{2} \frac{(1-d)^2}{2} - \frac{1}{2} \log |2| + \frac{1}{2} \log 2 \pi \right \} \\ 450 | &= \exp \left \{ - \frac{1}{2}\log 2 \pi -\frac{1}{4}(d-1)^2 - \frac{1}{2} \log |2| \right \} 451 | \end{align} 452 | From the standard marginal likelihood, we know that this is like, 453 | \begin{align} 454 | \log p(y=1) &= \log \distNormal(1 \given d, 2) \\ 455 | &= -\frac{1}{2} \log 2 \pi -\frac{1}{2} \log 2 - \frac{1}{2} \frac{(1-d)^2}{2} \\ 456 | % &= -\frac{1}{2} \log 2 \pi -\frac{1}{2} \log 2 - \frac{1}{4} 457 | \end{align} 458 | Thankfully, all checks out! 
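For a concrete numerical check, setting $d = 0$ in both expressions gives
\begin{align}
\log Z &= -\frac{1}{2} \log 2 \pi - \frac{1}{4} - \frac{1}{2} \log 2
= \log \distNormal(1 \given 0, 2),
\end{align}
so the information-form and standard-form computations agree at this point as well.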
459 | 460 | 461 | \end{document} 462 | -------------------------------------------------------------------------------- /notes/macros.sty: -------------------------------------------------------------------------------- 1 | \newcommand{\bA}{\boldsymbol{A}} 2 | \newcommand{\bB}{\boldsymbol{B}} 3 | \newcommand{\bC}{\boldsymbol{C}} 4 | \newcommand{\bD}{\boldsymbol{D}} 5 | \newcommand{\bE}{\boldsymbol{E}} 6 | \newcommand{\bF}{\boldsymbol{F}} 7 | \newcommand{\bG}{\boldsymbol{G}} 8 | \newcommand{\bH}{\boldsymbol{H}} 9 | \newcommand{\bI}{\boldsymbol{I}} 10 | \newcommand{\bJ}{\boldsymbol{J}} 11 | \newcommand{\bK}{\boldsymbol{K}} 12 | \newcommand{\bL}{\boldsymbol{L}} 13 | \newcommand{\bM}{\boldsymbol{M}} 14 | \newcommand{\bN}{\boldsymbol{N}} 15 | \newcommand{\bO}{\boldsymbol{O}} 16 | \newcommand{\bP}{\boldsymbol{P}} 17 | \newcommand{\bQ}{\boldsymbol{Q}} 18 | \newcommand{\bR}{\boldsymbol{R}} 19 | \newcommand{\bS}{\boldsymbol{S}} 20 | \newcommand{\bT}{\boldsymbol{T}} 21 | \newcommand{\bU}{\boldsymbol{U}} 22 | \newcommand{\bV}{\boldsymbol{V}} 23 | \newcommand{\bW}{\boldsymbol{W}} 24 | \newcommand{\bX}{\boldsymbol{X}} 25 | \newcommand{\bY}{\boldsymbol{Y}} 26 | \newcommand{\bZ}{\boldsymbol{Z}} 27 | \newcommand{\ba}{\boldsymbol{a}} 28 | \newcommand{\bb}{\boldsymbol{b}} 29 | \newcommand{\bc}{\boldsymbol{c}} 30 | \newcommand{\bd}{\boldsymbol{d}} 31 | \newcommand{\be}{\boldsymbol{e}} 32 | \newcommand{\bbf}{\boldsymbol{f}} 33 | \newcommand{\bg}{\boldsymbol{g}} 34 | \newcommand{\bh}{\boldsymbol{h}} 35 | \newcommand{\bi}{\boldsymbol{i}} 36 | \newcommand{\bj}{\boldsymbol{j}} 37 | \newcommand{\bk}{\boldsymbol{k}} 38 | \newcommand{\bl}{\boldsymbol{l}} 39 | \newcommand{\bell}{\boldsymbol{\ell}} 40 | \newcommand{\bbm}{\boldsymbol{m}} 41 | \newcommand{\bn}{\boldsymbol{n}} 42 | \newcommand{\bo}{\boldsymbol{o}} 43 | \newcommand{\bp}{\boldsymbol{p}} 44 | \newcommand{\bq}{\boldsymbol{q}} 45 | \newcommand{\br}{\boldsymbol{r}} 46 | \newcommand{\bs}{\boldsymbol{s}} 47 | \newcommand{\bt}{\boldsymbol{t}} 48 | \newcommand{\bu}{\boldsymbol{u}} 49 | \newcommand{\bv}{\boldsymbol{v}} 50 | \newcommand{\bw}{\boldsymbol{w}} 51 | \newcommand{\bx}{\boldsymbol{x}} 52 | \newcommand{\by}{\boldsymbol{y}} 53 | \newcommand{\bz}{\boldsymbol{z}} 54 | 55 | \newcommand{\balpha}{\boldsymbol{\alpha}} 56 | \newcommand{\bbeta}{\boldsymbol{\beta}} 57 | \newcommand{\boldeta}{\boldsymbol{\eta}} 58 | \newcommand{\bkappa}{\boldsymbol{\kappa}} 59 | \newcommand{\bgamma}{\boldsymbol{\gamma}} 60 | \newcommand{\blambda}{\boldsymbol{\lambda}} 61 | \newcommand{\bmu}{\boldsymbol{\mu}} 62 | \newcommand{\bnu}{\boldsymbol{\nu}} 63 | \newcommand{\bphi}{\boldsymbol{\phi}} 64 | \newcommand{\bpi}{\boldsymbol{\pi}} 65 | \newcommand{\bpsi}{\boldsymbol{\psi}} 66 | \newcommand{\bsigma}{\boldsymbol{\sigma}} 67 | \newcommand{\btheta}{\boldsymbol{\theta}} 68 | \newcommand{\bvartheta}{\boldsymbol{\vartheta}} 69 | \newcommand{\bxi}{\boldsymbol{\xi}} 70 | \newcommand{\bomega}{\boldsymbol{\omega}} 71 | \newcommand{\brho}{\boldsymbol{\rho}} 72 | 73 | \newcommand{\bGamma}{\boldsymbol{\Gamma}} 74 | \newcommand{\bLambda}{\boldsymbol{\Lambda}} 75 | \newcommand{\bOmega}{\boldsymbol{\Omega}} 76 | \newcommand{\bPhi}{\boldsymbol{\Phi}} 77 | \newcommand{\bPi}{\boldsymbol{\Pi}} 78 | \newcommand{\bPsi}{\boldsymbol{\Psi}} 79 | \newcommand{\bSigma}{\boldsymbol{\Sigma}} 80 | \newcommand{\bTheta}{\boldsymbol{\Theta}} 81 | \newcommand{\bUpsilon}{\boldsymbol{\Upsilon}} 82 | \newcommand{\bXi}{\boldsymbol{\Xi}} 83 | \newcommand{\bepsilon}{\boldsymbol{\epsilon}} 84 | 85 | 
\newcommand{\bzero}{\boldsymbol{0}} 86 | \newcommand{\bone}{\boldsymbol{1}} 87 | 88 | 89 | \newcommand{\mcA}{\mathcal{A}} 90 | \newcommand{\mcB}{\mathcal{B}} 91 | \newcommand{\mcC}{\mathcal{C}} 92 | \newcommand{\mcD}{\mathcal{D}} 93 | \newcommand{\mcE}{\mathcal{E}} 94 | \newcommand{\mcF}{\mathcal{F}} 95 | \newcommand{\mcG}{\mathcal{G}} 96 | \newcommand{\mcH}{\mathcal{H}} 97 | \newcommand{\mcI}{\mathcal{I}} 98 | \newcommand{\mcJ}{\mathcal{J}} 99 | \newcommand{\mcK}{\mathcal{K}} 100 | \newcommand{\mcL}{\mathcal{L}} 101 | \newcommand{\mcM}{\mathcal{M}} 102 | \newcommand{\mcN}{\mathcal{N}} 103 | \newcommand{\mcO}{\mathcal{O}} 104 | \newcommand{\mcP}{\mathcal{P}} 105 | \newcommand{\mcQ}{\mathcal{Q}} 106 | \newcommand{\mcR}{\mathcal{R}} 107 | \newcommand{\mcS}{\mathcal{S}} 108 | \newcommand{\mcT}{\mathcal{T}} 109 | \newcommand{\mcU}{\mathcal{U}} 110 | \newcommand{\mcV}{\mathcal{V}} 111 | \newcommand{\mcW}{\mathcal{W}} 112 | \newcommand{\mcX}{\mathcal{X}} 113 | \newcommand{\mcY}{\mathcal{Y}} 114 | \newcommand{\mcZ}{\mathcal{Z}} 115 | 116 | \newcommand{\trans}{\mathsf{T}} 117 | \newcommand{\naturals}{\mathbb{N}} 118 | \newcommand{\reals}{\mathbb{R}} 119 | \newcommand{\given}{\,|\,} 120 | \def\argmax{\operatornamewithlimits{arg\,max}} 121 | \def\argmin{\operatornamewithlimits{arg\,min}} 122 | 123 | \newcommand{\distNormal}{\mathcal{N}} 124 | \newcommand{\distGaussian}{\mathcal{N}} 125 | \newcommand{\distPoisson}{\mathrm{Poisson}} 126 | \newcommand{\distGamma}{\mathrm{Gamma}} 127 | \newcommand{\distInvGamma}{\mathrm{InvGamma}} 128 | \newcommand{\distExponential}{\mathrm{Exp}} 129 | \newcommand{\distBernoulli}{\mathrm{Bern}} 130 | \newcommand{\distBinomial}{\mathrm{Bin}} 131 | \newcommand{\distDirichlet}{\mathrm{Dir}} 132 | \newcommand{\distBeta}{\mathrm{Beta}} 133 | \newcommand{\distMultinomial}{\mathrm{Mult}} 134 | \newcommand{\distCategorical}{\mathrm{Cat}} 135 | \newcommand{\distPolyaGamma}{\mathrm{PG}} 136 | \newcommand{\distNegBinomial}{\mathrm{NB}} 137 | \newcommand{\distNormalInvWishart}{\mathrm{NIW}} 138 | \newcommand{\polyagamma}{~P\'{o}lya-gamma~} 139 | 140 | \newcommand{\prt}[1]{\frac{\partial}{\partial #1}} 141 | \newcommand{\deriv}[1]{\frac{\mathrm{d}}{\mathrm{d} #1}} 142 | \newcommand{\tr}{\text{tr}} 143 | 144 | 145 | \newcommand{\TODO}[1]{\todo[inline]{#1}} 146 | 147 | \newcommand{\diag}{\mathrm{diag}} 148 | \newcommand{\bbI}{\mathbb{I}} 149 | \newcommand{\bbE}{\mathbb{E}} 150 | \newcommand{\Var}{\mathrm{Var}} 151 | \newcommand{\PP}{\mathcal{PP}} 152 | \newcommand{\DP}{\mathcal{DP}} 153 | \newcommand{\dtv}{{D_{\mathrm{TV}}}} 154 | \newcommand{\from}{\leftarrow} 155 | \newcommand{\KL}{\mathrm{KL}} 156 | \newcommand{\barbar}{{\,||\,}} 157 | \newcommand{\pa}{\mathsf{pa}} 158 | \newcommand{\ch}{\mathsf{ch}} 159 | \newcommand{\neigh}{\mathsf{ne}} 160 | \newcommand{\lambdamax}{\lambda_{\mathsf{max}}} 161 | 162 | -------------------------------------------------------------------------------- /pylds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattjj/pylds/e946bfa5aa76e8f8284614561a0f40ffd5d868fb/pylds/__init__.py -------------------------------------------------------------------------------- /pylds/cyutil.pxd: -------------------------------------------------------------------------------- 1 | # distutils: extra_compile_args = -O3 -w 2 | # cython: boundscheck = False, nonecheck = False, wraparound = False, cdivision = True 3 | 4 | from cython cimport floating 5 | from libc.math cimport sqrt 6 | 7 | from 
scipy.linalg.cython_blas cimport srot, srotg, drot, drotg 8 | 9 | # TODO for higher-rank updates, Householder reflections may be preferrable 10 | cdef inline void chol_update(int n, floating *R, floating *z) nogil: 11 | cdef int k 12 | cdef int inc = 1 13 | cdef floating a, b, c, s 14 | if floating is double: 15 | for k in range(n): 16 | a, b = R[k*n+k], z[k] 17 | drotg(&a,&b,&c,&s) 18 | drot(&n,&R[k*n],&inc,&z[0],&inc,&c,&s) 19 | else: 20 | for k in range(n): 21 | a, b = R[k*n+k], z[k] 22 | srotg(&a,&b,&c,&s) 23 | srot(&n,&R[k*n],&inc,&z[0],&inc,&c,&s) 24 | 25 | cdef inline void chol_downdate(int n, floating *R, floating *z) nogil: 26 | cdef int k, j 27 | cdef floating rbar 28 | for k in range(n): 29 | rbar = sqrt((R[k*n+k] - z[k])*(R[k*n+k] + z[k])) 30 | for j in range(k+1,n): 31 | R[k*n+j] = (R[k*n+k]*R[k*n+j] - z[k]*z[j]) / rbar 32 | z[j] = (rbar*z[j] - z[k]*R[k*n+j]) / R[k*n+k] 33 | R[k*n+k] = rbar 34 | 35 | cdef inline void copy_transpose(int m, int n, floating *x, floating *y) nogil: 36 | # NOTE: x is (m,n) and stored in Fortran order 37 | cdef int i, j 38 | for i in range(m): 39 | for j in range(n): 40 | y[n*i+j] = x[m*j+i] 41 | 42 | cdef inline void copy_upper_lower(int n, floating *x) nogil: 43 | cdef int i, j 44 | for i in range(n): 45 | for j in range(i): 46 | x[n*i+j] = x[n*j+i] 47 | -------------------------------------------------------------------------------- /pylds/distributions.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd import value_and_grad 3 | from autograd.scipy.special import gammaln 4 | 5 | from scipy.optimize import minimize 6 | 7 | from pybasicbayes.distributions import Regression 8 | from pybasicbayes.util.text import progprint_xrange 9 | 10 | 11 | class PoissonRegression(Regression): 12 | """ 13 | Poisson regression with Gaussian distributed inputs and exp link: 14 | 15 | y ~ Poisson(exp(Ax)) 16 | 17 | where x ~ N(mu, sigma) 18 | 19 | Currently, we only support maximum likelihood estimation of the 20 | parameters A given the distribution over inputs, x, and 21 | the observed outputs, y. 22 | 23 | We compute the expected log likelihood in closed form (since 24 | we can do this with the exp link function), and we use Autograd 25 | to compute its gradients. 26 | """ 27 | 28 | def __init__(self, D_out, D_in, A=None, verbose=False): 29 | self._D_out, self._D_in = D_out, D_in 30 | self.verbose = verbose 31 | 32 | if A is not None: 33 | assert A.shape == (D_out, D_in) 34 | self.A = A.copy() 35 | else: 36 | self.A = 0.01 * np.random.randn(D_out, D_in) 37 | 38 | self.sigma = None 39 | 40 | @property 41 | def D_in(self): 42 | return self._D_in 43 | 44 | @property 45 | def D_out(self): 46 | return self._D_out 47 | 48 | def log_likelihood(self,xy): 49 | assert isinstance(xy, tuple) 50 | x, y = xy 51 | loglmbda = x.dot(self.A.T) 52 | lmbda = np.exp(loglmbda) 53 | return -gammaln(y+1) - lmbda + y * loglmbda 54 | 55 | def expected_log_likelihood(self, mus, sigmas, y): 56 | """ 57 | Compute the expected log likelihood for a mean and 58 | covariance of x and an observed value of y. 
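
        With the exp link this is available in closed form: for each output
        dimension with weight vector a,
            E[ y a^T x - exp(a^T x) ] = y a^T mu - exp(a^T mu + a^T Sigma a / 2),
        up to the additive -log(y!) constant, which is omitted here.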
59 | """ 60 | 61 | # Flatten the covariance 62 | T = mus.shape[0] 63 | D = self.D_in 64 | sigs_vec = sigmas.reshape((T, D ** 2)) 65 | 66 | # Compute the log likelihood of each column 67 | ll = np.zeros((T, self.D_out)) 68 | for n in range(self.D_out): 69 | 70 | an = self.A[n] 71 | 72 | E_loglmbda = np.dot(mus, an) 73 | ll[:,n] += y[:,n] * E_loglmbda 74 | 75 | # Vectorized log likelihood calculation 76 | aa_vec = np.outer(an, an).reshape((D ** 2,)) 77 | ll[:,n] = -np.exp(E_loglmbda + 0.5 * np.dot(sigs_vec, aa_vec)) 78 | 79 | return ll 80 | 81 | def predict(self, x): 82 | return np.exp(x.dot(self.A.T)) 83 | 84 | def rvs(self,x=None,size=1,return_xy=True): 85 | x = np.random.normal(size=(size, self.D_in)) if x is None else x 86 | y = np.random.poisson(self.predict(x)) 87 | return np.hstack((x, y)) if return_xy else y 88 | 89 | def max_likelihood(self, data, weights=None,stats=None): 90 | """ 91 | Maximize the likelihood for a given value of x 92 | :param data: 93 | :param weights: 94 | :param stats: 95 | :return: 96 | """ 97 | raise NotImplementedError 98 | 99 | def max_expected_likelihood(self, stats, verbose=False): 100 | # These aren't really "sufficient" statistics, since we 101 | # need the mean and covariance for each time bin. 102 | EyxuT = np.sum([s[0] for s in stats], axis=0) 103 | mus = np.vstack([s[1] for s in stats]) 104 | sigmas = np.vstack([s[2] for s in stats]) 105 | inputs = np.vstack([s[3] for s in stats]) 106 | masks = np.vstack(s[4] for s in stats) 107 | T = mus.shape[0] 108 | 109 | D_latent = mus.shape[1] 110 | sigmas_vec = sigmas.reshape((T, D_latent**2)) 111 | 112 | # Optimize each row of A independently 113 | ns = progprint_xrange(self.D_out) if verbose else range(self.D_out) 114 | for n in ns: 115 | 116 | # Flatten the covariance to enable vectorized calculations 117 | def ll_vec(an): 118 | 119 | ll = 0 120 | ll += np.dot(an, EyxuT[n]) 121 | 122 | # Vectorized log likelihood calculation 123 | loglmbda = np.dot(mus, an) 124 | aa_vec = np.outer(an[:D_latent], an[:D_latent]).reshape((D_latent ** 2,)) 125 | trms = np.exp(loglmbda + 0.5 * np.dot(sigmas_vec, aa_vec)) 126 | ll -= np.sum(trms[masks[:, n]]) 127 | 128 | if not np.isfinite(ll): 129 | return -np.inf 130 | 131 | return ll / T 132 | obj = lambda x: -ll_vec(x) 133 | 134 | itr = [0] 135 | def cbk(x): 136 | itr[0] += 1 137 | print("M_step iteration ", itr[0]) 138 | 139 | res = minimize(value_and_grad(obj), self.A[n], 140 | jac=True, 141 | callback=cbk if verbose else None) 142 | assert res.success 143 | self.A[n] = res.x 144 | 145 | 146 | class BernoulliRegression(Regression): 147 | """ 148 | Bernoulli regression with Gaussian distributed inputs and logistic link: 149 | 150 | y ~ Bernoulli(logistic(Ax)) 151 | 152 | where x ~ N(mu, sigma) 153 | 154 | Currently, we only support maximum likelihood estimation of the 155 | parameter A given the distribution over inputs, x, and 156 | the observed outputs, y. 157 | 158 | We approximate the expected log likelihood with Monte Carlo. 
159 | """ 160 | 161 | def __init__(self, D_out, D_in, A=None, verbose=False): 162 | self._D_out, self._D_in = D_out, D_in 163 | self.verbose = verbose 164 | 165 | if A is not None: 166 | assert A.shape == (D_out, D_in) 167 | self.A = A.copy() 168 | else: 169 | self.A = 0.01 * np.random.randn(D_out, D_in) 170 | 171 | self.sigma = None 172 | 173 | @property 174 | def D_in(self): 175 | return self._D_in 176 | 177 | @property 178 | def D_out(self): 179 | return self._D_out 180 | 181 | def log_likelihood(self,xy): 182 | assert isinstance(xy, tuple) 183 | x, y = xy 184 | psi = x.dot(self.A.T) 185 | 186 | # First term is linear 187 | ll = y * psi 188 | 189 | # Compute second term with log-sum-exp trick (see above) 190 | logm = np.maximum(0, psi) 191 | ll -= np.sum(logm) 192 | ll -= np.sum(np.log(np.exp(-logm) + np.exp(psi - logm))) 193 | 194 | return ll 195 | 196 | def predict(self, x): 197 | return 1 / (1 + np.exp(-x.dot(self.A.T))) 198 | 199 | def rvs(self, x=None, size=1, return_xy=True): 200 | x = np.random.normal(size=(size, self.D_in)) if x is None else x 201 | y = np.random.rand(x.shape[0], self.D_out) < self.predict(x) 202 | return np.hstack((x, y)) if return_xy else y 203 | 204 | def max_likelihood(self, data, weights=None, stats=None): 205 | """ 206 | Maximize the likelihood for given data 207 | :param data: 208 | :param weights: 209 | :param stats: 210 | :return: 211 | """ 212 | if isinstance(data, list): 213 | x = np.vstack([d[0] for d in data]) 214 | y = np.vstack([d[1] for d in data]) 215 | elif isinstance(data, tuple): 216 | assert len(data) == 2 217 | elif isinstance(data, np.ndarray): 218 | x, y = data[:,:self.D_in], data[:, self.D_in:] 219 | else: 220 | raise Exception("Invalid data type") 221 | 222 | from sklearn.linear_model import LogisticRegression 223 | for n in progprint_xrange(self.D_out): 224 | lr = LogisticRegression(fit_intercept=False) 225 | lr.fit(x, y[:,n]) 226 | self.A[n] = lr.coef_ 227 | 228 | 229 | def max_expected_likelihood(self, stats, verbose=False, n_smpls=1): 230 | 231 | # These aren't really "sufficient" statistics, since we 232 | # need the mean and covariance for each time bin. 
233 | EyxuT = np.sum([s[0] for s in stats], axis=0) 234 | mus = np.vstack([s[1] for s in stats]) 235 | sigmas = np.vstack([s[2] for s in stats]) 236 | inputs = np.vstack([s[3] for s in stats]) 237 | T = mus.shape[0] 238 | 239 | D_latent = mus.shape[1] 240 | 241 | # Draw Monte Carlo samples of x 242 | sigmas_chol = np.linalg.cholesky(sigmas) 243 | x_smpls = mus[:, :, None] + np.matmul(sigmas_chol, np.random.randn(T, D_latent, n_smpls)) 244 | 245 | # Optimize each row of A independently 246 | ns = progprint_xrange(self.D_out) if verbose else range(self.D_out) 247 | for n in ns: 248 | 249 | def ll_vec(an): 250 | ll = 0 251 | 252 | # todo include mask 253 | # First term is linear in psi 254 | ll += np.dot(an, EyxuT[n]) 255 | 256 | # Second term depends only on x and cannot be computed in closed form 257 | # Instead, Monte Carlo sample x 258 | psi_smpls = np.einsum('tdm, d -> tm', x_smpls, an[:D_latent]) 259 | psi_smpls = psi_smpls + np.dot(inputs, an[D_latent:])[:, None] 260 | logm = np.maximum(0, psi_smpls) 261 | trm2_smpls = logm + np.log(np.exp(-logm) + np.exp(psi_smpls - logm)) 262 | ll -= np.sum(trm2_smpls) / n_smpls 263 | 264 | if not np.isfinite(ll): 265 | return -np.inf 266 | 267 | return ll / T 268 | 269 | obj = lambda x: -ll_vec(x) 270 | 271 | itr = [0] 272 | def cbk(x): 273 | itr[0] += 1 274 | print("M_step iteration ", itr[0]) 275 | 276 | res = minimize(value_and_grad(obj), self.A[n], 277 | jac=True, 278 | # callback=cbk if verbose else None) 279 | callback=None) 280 | assert res.success 281 | self.A[n] = res.x 282 | 283 | -------------------------------------------------------------------------------- /pylds/laplace.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd.scipy.special import gammaln 3 | from autograd import grad, hessian 4 | 5 | from pylds.states import _LDSStates 6 | from pylds.util import symm_block_tridiag_matmul 7 | from pylds.lds_messages_interface import info_E_step 8 | 9 | 10 | class _LaplaceApproxLDSStatesBase(_LDSStates): 11 | """ 12 | Support variational inference via Laplace approximation. 13 | The key is a definition of the log likelihood, 14 | 15 | log p(y_t | x_t, \theta) 16 | 17 | Combining this with a Gaussian LDS prior on the states, 18 | we can compute the gradient and Hessian of the log likelihood. 19 | """ 20 | def local_log_likelihood(self, xt, yt, ut): 21 | """ 22 | Return log p(yt | xt). Implement this in base classes. 
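
        Only this scalar function is strictly required: by default the gradient
        and Hessian of the log likelihood are obtained from it with autograd,
        though subclasses may override those methods with vectorized forms.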
23 | """ 24 | raise NotImplementedError 25 | 26 | def log_conditional_likelihood(self, x): 27 | """ 28 | likelihood \sum_t log p(y_t | x_t) 29 | Optionally override this in base classes 30 | """ 31 | T, D = self.T, self.D_latent 32 | assert x.shape == (T, D) 33 | 34 | ll = 0 35 | for t in range(self.T): 36 | ll += self.local_log_likelihood(x[t], self.data[t], self.inputs[t]) 37 | return ll 38 | 39 | def grad_local_log_likelihood(self, x): 40 | """ 41 | return d/dxt log p(yt | xt) evaluated at xt 42 | Optionally override this in base classes 43 | """ 44 | T, D = self.T, self.D_latent 45 | assert x.shape == (T, D) 46 | gfun = grad(self.local_log_likelihood) 47 | 48 | g = np.zeros((T, D)) 49 | for t in range(T): 50 | g[t] += gfun(x[t], self.data[t], self.inputs[t]) 51 | return g 52 | 53 | def hessian_local_log_likelihood(self, x): 54 | """ 55 | return d^2/dxt^2 log p(y | x) for each time bin 56 | Optionally override this in base classes 57 | """ 58 | T, D = self.T, self.D_latent 59 | assert x.shape == (T, D) 60 | 61 | hfun = hessian(self.local_log_likelihood) 62 | H_diag = np.zeros((T, D, D)) 63 | for t in range(T): 64 | H_diag[t] = hfun(x[t], self.data[t], self.inputs[t]) 65 | return H_diag 66 | 67 | @property 68 | def sparse_J_prior(self): 69 | T, D = self.T, self.D_latent 70 | J_init, _, _ = self.info_init_params 71 | J_11, J_21, J_22, _, _, _ = self.info_dynamics_params 72 | 73 | # Collect the Gaussian LDS prior terms 74 | J_diag = np.zeros((T, D, D)) 75 | J_diag[0] += J_init 76 | J_diag[:-1] += J_11 77 | J_diag[1:] += J_22 78 | 79 | J_upper_diag = np.repeat(J_21.T[None, :, :], T - 1, axis=0) 80 | return J_diag, J_upper_diag 81 | 82 | def log_joint(self, x): 83 | """ 84 | Compute the log joint probability p(x, y) 85 | """ 86 | T, D = self.T, self.D_latent 87 | assert x.shape == (T, D) 88 | 89 | # prior log p(x) -- quadratic terms 90 | J_diag, J_upper_diag = self.sparse_J_prior 91 | lp = -0.5 * np.sum(x * symm_block_tridiag_matmul(J_diag, J_upper_diag, x)) 92 | 93 | # prior log p(x) -- linear terms 94 | _, h_init, log_Z_init = self.info_init_params 95 | _, _, _, h1, h2, log_Z_dyn = self.info_dynamics_params 96 | lp += x[0].dot(h_init) 97 | lp += np.sum(x[:-1] * h1) 98 | lp += np.sum(x[1:] * h2) 99 | 100 | # prior log p(x) -- normalization constants 101 | lp += log_Z_init 102 | lp += np.sum(log_Z_dyn) 103 | 104 | # likelihood log p(y | x) 105 | lp += self.log_conditional_likelihood(x) 106 | 107 | return lp 108 | 109 | def sparse_hessian_log_joint(self, x): 110 | """ 111 | The Hessian includes the quadratic terms of the Gaussian LDS prior 112 | as well as the Hessian of the local log likelihood. 113 | """ 114 | T, D = self.T, self.D_latent 115 | assert x.shape == (T, D) 116 | 117 | # Collect the Gaussian LDS prior terms 118 | J_diag, J_upper_diag = self.sparse_J_prior 119 | H_diag, H_upper_diag = -J_diag, -J_upper_diag 120 | 121 | # Collect the likelihood terms 122 | H_diag += self.hessian_local_log_likelihood(x) 123 | 124 | # Subtract a little bit to ensure negative definiteness 125 | H_diag -= 1e-8 * np.eye(D) 126 | 127 | return H_diag, H_upper_diag 128 | 129 | def hessian_vector_product_log_joint(self, x, v): 130 | H_diag, H_upper_diag = self.sparse_hessian_log_joint(x) 131 | return symm_block_tridiag_matmul(H_diag, H_upper_diag, v) 132 | 133 | def gradient_log_joint(self, x): 134 | """ 135 | The gradient of the log joint probability. 136 | 137 | For the Gaussian terms, this is 138 | 139 | d/dx [-1/2 x^T J x + h^T x] = -Jx + h. 
140 | 141 | For the likelihood terms, we have for each time t 142 | 143 | d/dx log p(yt | xt) 144 | """ 145 | T, D = self.T, self.D_latent 146 | assert x.shape == (T, D) 147 | 148 | # Collect the Gaussian LDS prior terms 149 | _, h_init, _ = self.info_init_params 150 | _, _, _, h1, h2, _ = self.info_dynamics_params 151 | H_diag, H_upper_diag = self.sparse_J_prior 152 | 153 | # Compute the gradient from the prior 154 | g = -1 * symm_block_tridiag_matmul(H_diag, H_upper_diag, x) 155 | g[0] += h_init 156 | g[:-1] += h1 157 | g[1:] += h2 158 | 159 | # Compute gradient from the likelihood terms 160 | g += self.grad_local_log_likelihood(x) 161 | 162 | return g 163 | 164 | def laplace_approximation(self, method="newton", verbose=False, tol=1e-7, **kwargs): 165 | if method.lower() == "newton": 166 | return self._laplace_approximation_newton(verbose=verbose, tol=tol, **kwargs) 167 | elif method.lower() == "bfgs": 168 | return self._laplace_approximation_bfgs(verbose=verbose, tol=tol, **kwargs) 169 | else: 170 | raise Exception("Invalid method: {}".format(method)) 171 | 172 | def _laplace_approximation_bfgs(self, tol=1e-7, verbose=False): 173 | from scipy.optimize import minimize 174 | 175 | # Gradient ascent on the log joint probability to get mu 176 | T, D = self.T, self.D_latent 177 | scale = self.T * self.D_emission 178 | obj = lambda xflat: -self.log_joint(xflat.reshape((T, D))) / scale 179 | jac = lambda xflat: -self.gradient_log_joint(xflat.reshape((T, D))).ravel() / scale 180 | hvp = lambda xflat, v: -self.hessian_vector_product_log_joint( 181 | xflat.reshape((T, D)), v.reshape((T, D))).ravel() / scale 182 | 183 | x0 = self.gaussian_states.reshape((T * D,)) 184 | 185 | # Make callback 186 | itr = [0] 187 | 188 | def cbk(x): 189 | print("Iteration: ", itr[0], 190 | "\tObjective: ", obj(x).round(2), 191 | "\tAvg Grad: ", jac(x).mean().round(2)) 192 | itr[0] += 1 193 | 194 | # Second order method 195 | if verbose: 196 | print("Fitting Laplace approximation") 197 | 198 | res = minimize(obj, x0, 199 | tol=tol, 200 | method="Newton-CG", 201 | jac=jac, 202 | hessp=hvp, 203 | callback=cbk if verbose else None) 204 | assert res.success 205 | mu = res.x 206 | assert np.all(np.isfinite(mu)) 207 | if verbose: print("Done") 208 | 209 | # Unflatten and compute the expected sufficient statistics 210 | return mu.reshape((T, D)) 211 | 212 | def _laplace_approximation_newton(self, tol=1e-6, stepsz=0.9, verbose=False): 213 | """ 214 | Solve a block tridiagonal system with message passing. 
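
        Each Newton step computes the ascent direction dx = (-H)^{-1} g with a
        block tridiagonal solve, where g and H are the gradient and Hessian of
        the log joint, and updates x <- x + stepsz * dx until the mean absolute
        change in x falls below tol.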
215 | """ 216 | from pylds.util import solve_symm_block_tridiag, scipy_solve_symm_block_tridiag 217 | scale = self.T * self.D_emission 218 | 219 | def newton_step(x, stepsz): 220 | assert 0 <= stepsz <= 1 221 | g = self.gradient_log_joint(x) 222 | H_diag, H_upper_diag = self.sparse_hessian_log_joint(x) 223 | Hinv_g = -scipy_solve_symm_block_tridiag(-H_diag / scale, 224 | -H_upper_diag / scale, 225 | g / scale) 226 | return x - stepsz * Hinv_g 227 | 228 | if verbose: 229 | print("Fitting Laplace approximation") 230 | 231 | itr = [0] 232 | def cbk(x): 233 | print("Iteration: ", itr[0], 234 | "\tObjective: ", (self.log_joint(x) / scale).round(4), 235 | "\tAvg Grad: ", (self.gradient_log_joint(x).mean() / scale).round(4)) 236 | itr[0] += 1 237 | 238 | # Solve for optimal x with Newton's method 239 | x = self.gaussian_states 240 | dx = np.inf 241 | while dx >= tol: 242 | xnew = newton_step(x, stepsz) 243 | dx = np.mean(abs(xnew - x)) 244 | x = xnew 245 | 246 | if verbose: 247 | cbk(x) 248 | 249 | assert np.all(np.isfinite(x)) 250 | if verbose: 251 | print("Done") 252 | 253 | return x 254 | 255 | def log_likelihood(self): 256 | if self._normalizer is None: 257 | self.E_step() 258 | return self._normalizer 259 | 260 | def E_step(self, verbose=False): 261 | self.gaussian_states = self.laplace_approximation(verbose=verbose) 262 | 263 | # Compute normalizer and covariances with E step 264 | T, D = self.T, self.D_latent 265 | H_diag, H_upper_diag = self.sparse_hessian_log_joint(self.gaussian_states) 266 | J_init = J_11 = J_22 = np.zeros((D, D)) 267 | h_init = h_1 = h_2 = np.zeros((D,)) 268 | 269 | # Negate the Hessian since precision is -H 270 | J_21 = np.swapaxes(-H_upper_diag, -1, -2) 271 | J_node = -H_diag 272 | h_node = np.zeros((T, D)) 273 | 274 | logZ, _, self.smoothed_sigmas, E_xtp1_xtT = \ 275 | info_E_step(J_init, h_init, 0, 276 | J_11, J_21, J_22, h_1, h_2, np.zeros((T - 1)), 277 | J_node, h_node, np.zeros(T)) 278 | 279 | # Laplace approximation -- normalizer is the joint times 280 | # the normalizer from the Gaussian approx. 
281 | self._normalizer = self.log_joint(self.gaussian_states) + logZ 282 | 283 | self._set_expected_stats(self.gaussian_states, self.smoothed_sigmas, E_xtp1_xtT) 284 | 285 | def _set_expected_stats(self, mu, sigmas, E_xtp1_xtT): 286 | # Get the emission stats 287 | p, n, d, T, inputs, y = \ 288 | self.D_emission, self.D_latent, self.D_input, self.T, \ 289 | self.inputs, self.data 290 | 291 | E_x_xT = sigmas + mu[:, :, None] * mu[:, None, :] 292 | E_x_uT = mu[:, :, None] * self.inputs[:, None, :] 293 | E_u_uT = self.inputs[:, :, None] * self.inputs[:, None, :] 294 | 295 | E_xu_xuT = np.concatenate(( 296 | np.concatenate((E_x_xT, E_x_uT), axis=2), 297 | np.concatenate((np.transpose(E_x_uT, (0, 2, 1)), E_u_uT), axis=2)), 298 | axis=1) 299 | E_xut_xutT = E_xu_xuT[:-1].sum(0) 300 | 301 | E_xtp1_xtp1T = E_x_xT[1:].sum(0) 302 | E_xtp1_xtT = E_xtp1_xtT.sum(0) 303 | 304 | E_xtp1_utT = (mu[1:, :, None] * inputs[:-1, None, :]).sum(0) 305 | E_xtp1_xutT = np.hstack((E_xtp1_xtT, E_xtp1_utT)) 306 | 307 | self.E_dynamics_stats = np.array( 308 | [E_xtp1_xtp1T, E_xtp1_xutT, E_xut_xutT, self.T - 1]) 309 | 310 | # Compute the expectations for the observations 311 | E_yxT = np.sum(y[:, :, None] * mu[:, None, :], axis=0) 312 | E_yuT = y.T.dot(inputs) 313 | E_yxuT = np.hstack((E_yxT, E_yuT)) 314 | self.E_emission_stats = np.array([E_yxuT, mu, sigmas, inputs, np.ones_like(y, dtype=bool)]) 315 | 316 | def smooth(self): 317 | return self.emission_distn.predict(np.hstack((self.gaussian_states, self.inputs))) 318 | 319 | 320 | class LaplaceApproxPoissonLDSStates(_LaplaceApproxLDSStatesBase): 321 | """ 322 | Poisson observations 323 | """ 324 | def local_log_likelihood(self, xt, yt, ut): 325 | # Observation likelihoods 326 | C, D = self.C, self.D 327 | 328 | loglmbda = np.dot(C, xt) + np.dot(D, ut) 329 | lmbda = np.exp(loglmbda) 330 | 331 | ll = np.sum(yt * loglmbda) 332 | ll -= np.sum(lmbda) 333 | ll -= np.sum(gammaln(yt + 1)) 334 | return ll 335 | 336 | # Override likelihood, gradient, and hessian with vectorized forms 337 | def log_conditional_likelihood(self, x): 338 | # Observation likelihoods 339 | C, D = self.C, self.D 340 | 341 | loglmbda = np.dot(x, C.T) + np.dot(self.inputs, D.T) 342 | lmbda = np.exp(loglmbda) 343 | 344 | ll = np.sum(self.data * loglmbda) 345 | ll -= np.sum(lmbda) 346 | ll -= np.sum(gammaln(self.data + 1)) 347 | return ll 348 | 349 | def grad_local_log_likelihood(self, x): 350 | """ 351 | d/dx y^T Cx + y^T d - exp(Cx+d) 352 | = y^T C - exp(Cx+d)^T C 353 | = (y - lmbda)^T C 354 | """ 355 | # Observation likelihoods 356 | lmbda = np.exp(np.dot(x, self.C.T) + np.dot(self.inputs, self.D.T)) 357 | return (self.data - lmbda).dot(self.C) 358 | 359 | def hessian_local_log_likelihood(self, x): 360 | """ 361 | d/dx (y - lmbda)^T C = d/dx -exp(Cx + d)^T C 362 | = -C^T exp(Cx + d)^T C 363 | """ 364 | # Observation likelihoods 365 | lmbda = np.exp(np.dot(x, self.C.T) + np.dot(self.inputs, self.D.T)) 366 | return np.einsum('tn, ni, nj ->tij', -lmbda, self.C, self.C) 367 | 368 | # Test hooks 369 | def test_joint_probability(self, x): 370 | # A differentiable function to compute the joint probability for a given 371 | # latent state sequence 372 | import autograd.numpy as anp 373 | T = self.T 374 | ll = 0 375 | 376 | # Initial likelihood 377 | mu_init, sigma_init = self.mu_init, self.sigma_init 378 | ll += -0.5 * anp.dot(x[0] - mu_init, anp.linalg.solve(sigma_init, x[0] - mu_init)) 379 | 380 | # Transition likelihoods 381 | A, B, Q = self.A, self.B, self.sigma_states 382 | xpred = anp.dot(x[:T-1], A.T) + 
anp.dot(self.inputs[:T-1], B.T) 383 | dx = x[1:] - xpred 384 | ll += -0.5 * (dx.T * anp.linalg.solve(Q, dx.T)).sum() 385 | 386 | # Observation likelihoods 387 | y = self.data 388 | C, D = self.C, self.D 389 | loglmbda = (anp.dot(x, C.T) + anp.dot(self.inputs, D.T)) 390 | lmbda = anp.exp(loglmbda) 391 | 392 | ll += anp.sum(y * loglmbda) 393 | ll -= anp.sum(lmbda) 394 | 395 | if anp.isnan(ll): 396 | ll = -anp.inf 397 | 398 | return ll 399 | 400 | def test_gradient_log_joint(self, x): 401 | return grad(self.test_joint_probability)(x) 402 | 403 | def test_hessian_log_joint(self, x): 404 | return hessian(self.test_joint_probability)(x) 405 | 406 | 407 | class LaplaceApproxBernoulliLDSStates(_LaplaceApproxLDSStatesBase): 408 | """ 409 | Bernoulli observations with Laplace approximation 410 | 411 | Let \psi_t = C x_t + D u_t 412 | 413 | p(y_t = 1 | x_t) = \sigma(\psi_t) 414 | 415 | log p(y_t | x_t) = y_t \log \sigma(\psi_t) + 416 | (1-y_t) \log \sigma(-\psi_t) + 417 | 418 | = y_t \psi_t - log (1 + exp(\psi_t)) 419 | 420 | use the log-sum-exp trick to compute this: 421 | = y_t \psi_t - log {exp(0) + exp(\psi_t)} 422 | = y_t \psi_t - log {m [exp(0 - log m) + exp(\psi_t - log m)]} 423 | = y_t \psi_t - log m - log {exp(-log m) + exp(\psi_t - log m)} 424 | 425 | set log m = max(0, psi) 426 | 427 | """ 428 | def local_log_likelihood(self, xt, yt, ut): 429 | # Observation likelihoods 430 | C, D = self.C, self.D 431 | 432 | psi = C.dot(xt) + D.dot(ut) 433 | 434 | ll = np.sum(yt * psi) 435 | 436 | # Compute second term with log-sum-exp trick (see above) 437 | logm = np.maximum(0, psi) 438 | ll -= np.sum(logm) 439 | ll -= np.sum(np.log(np.exp(-logm) + np.exp(psi - logm))) 440 | return ll 441 | 442 | # Override likelihood, gradient, and hessian with vectorized forms 443 | def log_conditional_likelihood(self, x): 444 | # Observation likelihoods 445 | C, D, u, y = self.C, self.D, self.inputs, self.data 446 | psi = x.dot(C.T) + u.dot(D.T) 447 | 448 | # First term is linear in psi 449 | ll = np.sum(y * psi) 450 | 451 | # Compute second term with log-sum-exp trick (see above) 452 | logm = np.maximum(0, psi) 453 | ll -= np.sum(logm) 454 | ll -= np.sum(np.log(np.exp(-logm) + np.exp(psi - logm))) 455 | 456 | return ll 457 | 458 | def grad_local_log_likelihood(self, x): 459 | """ 460 | d/d \psi y \psi - log (1 + exp(\psi)) 461 | = y - exp(\psi) / (1 + exp(\psi)) 462 | = y - sigma(psi) 463 | = y - p 464 | 465 | d \psi / dx = C 466 | 467 | d / dx = (y - sigma(psi)) * C 468 | """ 469 | C, D, u, y = self.C, self.D, self.inputs, self.data 470 | psi = x.dot(C.T) + u.dot(D.T) 471 | p = 1. / (1 + np.exp(-psi)) 472 | return (y - p).dot(C) 473 | 474 | def hessian_local_log_likelihood(self, x): 475 | """ 476 | d/dx (y - p) * C 477 | = -dpsi/dx (dp/d\psi) C 478 | = -C p (1-p) C 479 | """ 480 | C, D, u, y = self.C, self.D, self.inputs, self.data 481 | psi = x.dot(C.T) + u.dot(D.T) 482 | p = 1. 
/ (1 + np.exp(-psi)) 483 | dp_dpsi = p * (1 - p) 484 | return np.einsum('tn, ni, nj ->tij', -dp_dpsi, self.C, self.C) 485 | 486 | # Test hooks 487 | def test_joint_probability(self, x): 488 | # A differentiable function to compute the joint probability for a given 489 | # latent state sequence 490 | import autograd.numpy as anp 491 | T = self.T 492 | ll = 0 493 | 494 | # Initial likelihood 495 | mu_init, sigma_init = self.mu_init, self.sigma_init 496 | ll += -0.5 * anp.dot(x[0] - mu_init, anp.linalg.solve(sigma_init, x[0] - mu_init)) 497 | 498 | # Transition likelihoods 499 | A, B, Q = self.A, self.B, self.sigma_states 500 | xpred = anp.dot(x[:T-1], A.T) + anp.dot(self.inputs[:T-1], B.T) 501 | dx = x[1:] - xpred 502 | ll += -0.5 * (dx.T * anp.linalg.solve(Q, dx.T)).sum() 503 | 504 | # Observation likelihoods 505 | y = self.data 506 | C, D = self.C, self.D 507 | psi = (anp.dot(x, C.T) + anp.dot(self.inputs, D.T)) 508 | 509 | ll += anp.sum(y * psi) 510 | ll -= anp.sum(np.log(1 + np.exp(psi))) 511 | 512 | if anp.isnan(ll): 513 | ll = -anp.inf 514 | 515 | return ll 516 | 517 | def test_gradient_log_joint(self, x): 518 | return grad(self.test_joint_probability)(x) 519 | 520 | def test_hessian_log_joint(self, x): 521 | return hessian(self.test_joint_probability)(x) 522 | -------------------------------------------------------------------------------- /pylds/lds_info_messages.pyx: -------------------------------------------------------------------------------- 1 | # distutils: extra_compile_args = -O2 -w 2 | # distutils: include_dirs = pylds/ 3 | # cython: boundscheck = False, nonecheck = False, wraparound = False, cdivision = True 4 | 5 | import numpy as np 6 | from numpy.lib.stride_tricks import as_strided 7 | 8 | cimport numpy as np 9 | cimport cython 10 | from libc.math cimport log, sqrt 11 | from numpy.math cimport INFINITY, PI 12 | 13 | from scipy.linalg.cython_blas cimport dsymm, dcopy, dgemm, dgemv, daxpy, dsyrk, \ 14 | dtrmv, dger, dnrm2, ddot 15 | from scipy.linalg.cython_lapack cimport dpotrf, dpotrs, dpotri, dtrtrs 16 | from cyutil cimport copy_transpose, copy_upper_lower 17 | 18 | 19 | # TODO instead of specializing last step in info filtering and rts, we could 20 | # instead just pad the input J's and h's by zeroes 21 | 22 | 23 | ################################# 24 | # information-form operations # 25 | ################################# 26 | 27 | def kalman_info_filter( 28 | double[:,:] J_init, double[:] h_init, double log_Z_init, 29 | double[:,:,:] J_pair_11, double[:,:,:] J_pair_21, double[:,:,:] J_pair_22, 30 | double[:,:] h_pair_1, double[:,:] h_pair_2, double[:] log_Z_pair, 31 | double[:,:,:] J_node, double[:,:] h_node, double[:] log_Z_node): 32 | 33 | # allocate temporaries and internals 34 | cdef int T = J_node.shape[0], n = J_node.shape[1] 35 | cdef int t 36 | 37 | cdef double[:,:] J_predict = np.copy(J_init) 38 | cdef double[:] h_predict = np.copy(h_init) 39 | 40 | cdef double[::1] temp_n = np.empty((n,), order='F') 41 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 42 | cdef double[::1,:] temp_nn2 = np.empty((n,n),order='F') 43 | 44 | # allocate output 45 | cdef double[:,:,::1] filtered_Js = np.empty((T,n,n)) 46 | cdef double[:,::1] filtered_hs = np.empty((T,n)) 47 | cdef double lognorm = 0 48 | 49 | # Initialize 50 | lognorm += log_Z_init 51 | 52 | # run filter forwards 53 | for t in range(T-1): 54 | lognorm += info_condition_on( 55 | J_predict, h_predict, J_node[t], h_node[t], log_Z_node[t], 56 | filtered_Js[t], filtered_hs[t]) 57 | lognorm += info_predict( 58 
| filtered_Js[t], filtered_hs[t], 59 | J_pair_11[t], J_pair_21[t], J_pair_22[t], 60 | h_pair_1[t], h_pair_2[t], log_Z_pair[t], 61 | J_predict, h_predict, 62 | temp_n, temp_nn, temp_nn2) 63 | lognorm += info_condition_on( 64 | J_predict, h_predict, J_node[T-1], h_node[T-1], log_Z_node[T-1], 65 | filtered_Js[T-1], filtered_hs[T-1]) 66 | lognorm += info_lognorm( 67 | filtered_Js[T-1], filtered_hs[T-1], temp_n, temp_nn) 68 | 69 | return lognorm, np.asarray(filtered_Js), np.asarray(filtered_hs) 70 | 71 | def info_E_step( 72 | double[:,::1] J_init, double[::1] h_init, double log_Z_init, 73 | double[:,:,:] J_pair_11, double[:,:,:] J_pair_21, double[:,:,:] J_pair_22, 74 | double[:,:] h_pair_1, double[:,:] h_pair_2, double[:] log_Z_pair, 75 | double[:,:,:] J_node, double[:,:] h_node, double[:] log_Z_node): 76 | 77 | # allocate temporaries and internals 78 | cdef int T = J_node.shape[0], n = J_node.shape[1] 79 | cdef int t 80 | 81 | cdef double[:,:,::1] filtered_Js = np.empty((T,n,n)) 82 | cdef double[:,::1] filtered_hs = np.empty((T,n)) 83 | cdef double[:,:,::1] predict_Js = np.empty((T,n,n)) 84 | cdef double[:,::1] predict_hs = np.empty((T,n)) 85 | 86 | cdef double[::1] temp_n = np.empty((n,), order='F') 87 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 88 | cdef double[::1,:] temp_nn2 = np.empty((n,n),order='F') 89 | 90 | # allocate output 91 | cdef double[:,::1] smoothed_mus = np.empty((T,n)) 92 | cdef double[:,:,::1] smoothed_sigmas = np.empty((T,n,n)) 93 | cdef double[:,:,::1] ExnxT = np.empty((T-1,n,n)) # 'n' for next 94 | cdef double lognorm = 0. 95 | 96 | # initialize 97 | lognorm += log_Z_init 98 | 99 | # run filter forwards 100 | predict_Js[0,:,:] = J_init 101 | predict_hs[0,:] = h_init 102 | for t in range(T-1): 103 | lognorm += info_condition_on( 104 | predict_Js[t], predict_hs[t], J_node[t], h_node[t], log_Z_node[t], 105 | filtered_Js[t], filtered_hs[t]) 106 | lognorm += info_predict( 107 | filtered_Js[t], filtered_hs[t], 108 | J_pair_11[t], J_pair_21[t], J_pair_22[t], 109 | h_pair_1[t], h_pair_2[t], log_Z_pair[t], 110 | predict_Js[t+1], predict_hs[t+1], 111 | temp_n, temp_nn, temp_nn2) 112 | lognorm += info_condition_on( 113 | predict_Js[T-1], predict_hs[T-1], J_node[T-1], h_node[T-1], log_Z_node[T-1], 114 | filtered_Js[T-1], filtered_hs[T-1]) 115 | lognorm += info_lognorm( 116 | filtered_Js[T-1], filtered_hs[T-1], temp_n, temp_nn) 117 | 118 | # run info-form rts update backwards 119 | # overwriting the filtered params with smoothed ones 120 | info_to_distn( 121 | filtered_Js[T-1], filtered_hs[T-1], 122 | smoothed_mus[T-1], smoothed_sigmas[T-1]) 123 | for t in range(T-2,-1,-1): 124 | info_rts_backward_step( 125 | J_pair_11[t], J_pair_21[t], J_pair_22[t], 126 | h_pair_1[t], h_pair_2[t], 127 | predict_Js[t+1], filtered_Js[t], filtered_Js[t+1], # filtered_Js[t] is mutated 128 | predict_hs[t+1], filtered_hs[t], filtered_hs[t+1], # filtered_hs[t] is mutated 129 | smoothed_mus[t], smoothed_mus[t+1], smoothed_sigmas[t], ExnxT[t], 130 | temp_n, temp_nn, temp_nn2) 131 | 132 | return lognorm, np.asarray(smoothed_mus), \ 133 | np.asarray(smoothed_sigmas), np.swapaxes(ExnxT, 1, 2) 134 | 135 | 136 | def info_sample( 137 | double[:,::1] J_init, double[::1] h_init, double log_Z_init, 138 | double[:,:,:] J_pair_11, double[:,:,:] J_pair_21, double[:,:,:] J_pair_22, 139 | double[:,:] h_pair_1, double[:,:] h_pair_2, double[:] log_Z_pair, 140 | double[:,:,:] J_node, double[:,:] h_node, double[:] log_Z_node): 141 | 142 | cdef int T = J_node.shape[0], n = J_node.shape[1] 143 | cdef int t 144 | 
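    # Strategy: run the information-form filter forward to get p(x_t | y_{1:t}),
    # then sample backward, drawing x_T first and then each x_t from
    # p(x_t | x_{t+1}, y_{1:t}). The standard normal draws in randseq are
    # overwritten in place with the sampled states.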
145 | cdef double[:,:,::1] filtered_Js = np.empty((T,n,n)) 146 | cdef double[:,::1] filtered_hs = np.empty((T,n)) 147 | cdef double[:,:,::1] predict_Js = np.empty((T,n,n)) 148 | cdef double[:,::1] predict_hs = np.empty((T,n)) 149 | 150 | cdef double[::1] temp_n = np.empty((n,), order='F') 151 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 152 | cdef double[::1,:] temp_nn2 = np.empty((n,n),order='F') 153 | 154 | # allocate output 155 | cdef double[:,::1] randseq = np.random.randn(T,n) 156 | cdef double lognorm = 0. 157 | 158 | # dgemv requires these things 159 | cdef int inc = 1 160 | cdef double neg1 = -1., one = 1., zero = 0. 161 | 162 | # initialize 163 | lognorm += log_Z_init 164 | 165 | # run filter forwards 166 | predict_Js[0,:,:] = J_init 167 | predict_hs[0,:] = h_init 168 | for t in range(T-1): 169 | lognorm += info_condition_on( 170 | predict_Js[t], predict_hs[t], J_node[t], h_node[t], log_Z_node[t], 171 | filtered_Js[t], filtered_hs[t]) 172 | lognorm += info_predict( 173 | filtered_Js[t], filtered_hs[t], 174 | J_pair_11[t], J_pair_21[t], J_pair_22[t], 175 | h_pair_1[t], h_pair_2[t], log_Z_pair[t], 176 | predict_Js[t+1], predict_hs[t+1], 177 | temp_n, temp_nn, temp_nn2) 178 | lognorm += info_condition_on( 179 | predict_Js[T-1], predict_hs[T-1], J_node[T-1], h_node[T-1], log_Z_node[T-1], 180 | filtered_Js[T-1], filtered_hs[T-1]) 181 | lognorm += info_lognorm( 182 | filtered_Js[T-1], filtered_hs[T-1], temp_n, temp_nn) 183 | 184 | # sample backward 185 | info_sample_gaussian(filtered_Js[T-1], filtered_hs[T-1], randseq[T-1]) 186 | for t in range(T-2,-1,-1): 187 | # temp_n = h_1 - J_12^T x_{t+1} 188 | # J_pair_21 is C-major, so it is actually J12 to blas! 189 | dcopy(&n, &h_pair_1[t,0], &inc, &temp_n[0], &inc) 190 | dgemv('N', &n, &n, &neg1, &J_pair_21[t,0,0], &n, &randseq[t+1,0], 191 | &inc, &one, &temp_n[0], &inc) 192 | 193 | info_condition_on( 194 | filtered_Js[t], filtered_hs[t], J_pair_11[t], temp_n, 0, 195 | filtered_Js[t], filtered_hs[t]) 196 | info_sample_gaussian(filtered_Js[t], filtered_hs[t], randseq[t]) 197 | 198 | return lognorm, np.asarray(randseq) 199 | 200 | 201 | ########################### 202 | # information-form util # 203 | ########################### 204 | 205 | cdef inline double info_condition_on( 206 | double[:,:] J1, double[:] h1, 207 | double[:,:] J2, double[:] h2, 208 | double log_Z, 209 | double[:,:] Jout, double[:] hout, 210 | ) nogil: 211 | cdef int n = J1.shape[0] 212 | cdef int i 213 | 214 | for i in range(n): 215 | hout[i] = h1[i] + h2[i] 216 | 217 | for i in range(n): 218 | for j in range(n): 219 | Jout[i,j] = J1[i,j] + J2[i,j] 220 | 221 | return log_Z 222 | 223 | 224 | cdef inline double info_predict( 225 | double[:,:] J, double[:] h, 226 | double[:,:] J11, double[:,:] J21, double[:,:] J22, 227 | double[:] h1, double[:] h2, double log_Z, 228 | double[:,:] Jpredict, double[:] hpredict, 229 | double[:] temp_n, double[:,:] temp_nn, double[:,:] temp_nn2, 230 | ) nogil: 231 | 232 | # NOTE: J21 is in C-major order, so BLAS and LAPACK function calls mark it as 233 | # transposed 234 | 235 | cdef int n = J.shape[0] 236 | cdef int nn = n*n 237 | cdef int inc = 1, info = 0 238 | cdef double one = 1., zero = 0., neg1 = -1., lognorm = 0. 
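    # This implements the information-form prediction step from the notes:
    #   J_predict = J22 - J21 (J + J11)^{-1} J21^T
    #   h_predict = h2  - J21 (J + J11)^{-1} (h + h1)
    # while accumulating the log normalizer of the marginalized component.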
239 | 240 | # Copy J to temp_nn and add J_11 241 | dcopy(&nn, &J[0,0], &inc, &temp_nn[0,0], &inc) 242 | daxpy(&nn, &one, &J11[0,0], &inc, &temp_nn[0,0], &inc) 243 | 244 | # Copy h to temp_n and add h_1 245 | dcopy(&n, &h[0], &inc, &temp_n[0], &inc) 246 | daxpy(&n, &one, &h1[0], &inc, &temp_n[0], &inc) 247 | 248 | # Initialize J_predict to J_22, h_predict to h_2, and temp_nn2 with J_21 249 | dcopy(&nn, &J22[0,0], &inc, &Jpredict[0,0], &inc) 250 | dcopy(&n, &h2[0], &inc, &hpredict[0], &inc) 251 | dcopy(&nn, &J21[0,0], &inc, &temp_nn2[0,0], &inc) 252 | 253 | # Inputs: temp_nn = J_{t|t} + J_11 254 | # temp_n = h_{t|t} + h_1 255 | # L = cholesky(J_{t|t} + J_11) 256 | # v = solve_triangular(L, h_{t|t} + h_1) 257 | # lognorm = 1./2 * np.dot(v,v) - np.sum(np.log(np.diag(L))) 258 | lognorm += info_lognorm_destructive(temp_nn, temp_n) 259 | # mutates temp_n and temp_nn 260 | # Now temp_nn = chol(J+J11), temp_n = chol(J+J11)^{-1} (h+h1) 261 | # Solve again so that temp_n = (J+J11)^{-1} (h+h1) 262 | dtrtrs('L', 'T', 'N', &n, &inc, &temp_nn[0,0], &n, &temp_n[0], &n, &info) 263 | # Finally, subtract J21 (J+J11)^{-1} (h+h1) 264 | # NOTE: transpose because J21 is in C-major order 265 | dgemv('T', &n, &n, &neg1, &J21[0,0], &n, &temp_n[0], &inc, &one, &hpredict[0], &inc) 266 | 267 | # Solve again to get temp_nn2 = (J+J11)^{-1}J_12 268 | dtrtrs('L', 'N', 'N', &n, &n, &temp_nn[0,0], &n, &temp_nn2[0,0], &n, &info) 269 | # Finally, subtract to get Jp = J22 - J21 (Jf+J11)^{-1} J21 270 | # TODO this call aliases pointers, should really call dsyrk and copy lower to upper 271 | dgemm('T', 'N', &n, &n, &n, &neg1, &temp_nn2[0,0], &n, &temp_nn2[0,0], &n, &one, &Jpredict[0,0], &n) 272 | # dsyrk('L', 'T', &n, &n, &neg1, &temp_nn2[0,0], &n, &one, &Jpredict[0,0], &n) 273 | 274 | return lognorm + log_Z 275 | 276 | 277 | cdef inline double info_lognorm_destructive(double[:,:] J, double[:] h) nogil: 278 | # NOTE: mutates input to chol(J) and solve_triangular(chol(J),h), resp. 279 | 280 | cdef int n = J.shape[0] 281 | cdef int nn = n*n 282 | cdef int inc = 1, info = 0 283 | cdef double lognorm = 0. 284 | 285 | dpotrf('L', &n, &J[0,0], &n, &info) 286 | dtrtrs('L', 'N', 'N', &n, &inc, &J[0,0], &n, &h[0], &n, &info) 287 | 288 | lognorm += (1./2) * dnrm2(&n, &h[0], &inc)**2 289 | for i in range(n): 290 | lognorm -= log(J[i,i]) 291 | lognorm += n/2. * log(2*PI) 292 | 293 | return lognorm 294 | 295 | 296 | cdef inline double info_lognorm( 297 | double[:,:] J, double[:] h, 298 | double[:] temp_n, double[:,:] temp_nn, 299 | ) nogil: 300 | cdef int n = J.shape[0] 301 | cdef int nn = n*n, inc = 1 302 | 303 | dcopy(&nn, &J[0,0], &inc, &temp_nn[0,0], &inc) 304 | dcopy(&n, &h[0], &inc, &temp_n[0], &inc) 305 | 306 | return info_lognorm_destructive(temp_nn, temp_n) 307 | 308 | 309 | cdef inline void info_rts_backward_step( 310 | double[:,:] J11, double[:,:] J21, double[:,:] J22, 311 | double[:] h1, double[:] h2, 312 | double[:,:] Jpred_tp1, double[:,:] Jfilt_t, double[:,:] Jsmooth_tp1, # Jfilt_t is mutated! 313 | double[:] hpred_tp1, double[:] hfilt_t, double[:] hsmooth_tp1, # hfilt_t is mutated! 
314 | double[:] mu_t, double[:] mu_tp1, double[:,:] sigma_t, double[:,:] ExnxT, 315 | double[:] temp_n, double[:,:] temp_nn, double[:,:] temp_nn2, 316 | ) nogil: 317 | 318 | # NOTE: this function mutates Jfilt_t and hfilt_t to be Jsmooth_t and 319 | # hsmooth_t, respectively 320 | # NOTE: J21 is in C-major order, so BLAS and LAPACK function calls mark it as 321 | # transposed 322 | 323 | cdef int n = J11.shape[0] 324 | cdef int nn = n*n 325 | cdef int inc = 1, info = 0 326 | cdef double one = 1., zero = 0., neg1 = -1. 327 | 328 | # temp_nn = Jsmooth_tp1 - J_pred_tp1 + J22 329 | dcopy(&nn, &Jsmooth_tp1[0,0], &inc, &temp_nn[0,0], &inc) 330 | daxpy(&nn, &neg1, &Jpred_tp1[0,0], &inc, &temp_nn[0,0], &inc) 331 | daxpy(&nn, &one, &J22[0,0], &inc, &temp_nn[0,0], &inc) 332 | 333 | # temp_nn2 = J_12.T (recall C order) 334 | copy_transpose(n, n, &J21[0,0], &temp_nn2[0,0]) 335 | 336 | # temp_nn2 = temp_nn^{-1} temp_nn2 337 | # = (Jsmooth_tp1 - J_pred_tp1 + J22)^{-1/2} J_12.T 338 | dpotrf('L', &n, &temp_nn[0,0], &n, &info) 339 | dtrtrs('L', 'N', 'N', &n, &n, &temp_nn[0,0], &n, &temp_nn2[0,0], &n, &info) 340 | 341 | # Jfilt_t = J_filt_t + J11 342 | daxpy(&nn, &one, &J11[0,0], &inc, &Jfilt_t[0,0], &inc) 343 | 344 | # J_filt_t = J_filt_t + J11 - J_12 (J_smooth - J_pred_tp1 + J_22)^{-1} J_12.T 345 | dgemm('T', 'N', &n, &n, &n, &neg1, &temp_nn2[0,0], &n, &temp_nn2[0,0], &n, &one, &Jfilt_t[0,0], &n) 346 | 347 | 348 | # hfilt_t = h_filt_t + h1 349 | daxpy(&n, &one, &h1[0], &inc, &hfilt_t[0], &inc) 350 | 351 | # temp_n = h_smooth_tp1 - h_pred_tp1 + h2 352 | dcopy(&n, &hsmooth_tp1[0], &inc, &temp_n[0], &inc) 353 | daxpy(&n, &neg1, &hpred_tp1[0], &inc, &temp_n[0], &inc) 354 | daxpy(&n, &one, &h2[0], &inc, &temp_n[0], &inc) 355 | 356 | # temp_n = (Jsmooth_tp1 - J_pred_tp1 + J22)^{-1} (h_smooth_tp1 - h_pred_tp1 + h2) 357 | dpotrs('L', &n, &inc, &temp_nn[0,0], &n, &temp_n[0], &n, &info) 358 | 359 | # h_filt_t = h_filt_t + h1 - J_12.T (Jsmooth_tp1 - J_pred_tp1 + J22)^{-1} (h_smooth_tp1 - h_pred_tp1 + h2) 360 | dgemv('N', &n, &n, &neg1, &J21[0,0], &n, &temp_n[0], &inc, &one, &hfilt_t[0], &inc) 361 | 362 | # Convert to distribution form 363 | info_to_distn(Jfilt_t, hfilt_t, mu_t, sigma_t) 364 | 365 | # Compute expected sufficient statistics 366 | dgemm('T', 'N', &n, &n, &n, &neg1, &J21[0,0], &n, &sigma_t[0,0], &n, &zero, &ExnxT[0,0], &n) 367 | dpotrs('L', &n, &n, &temp_nn[0,0], &n, &ExnxT[0,0], &n, &info) 368 | dger(&n, &n, &one, &mu_tp1[0], &inc, &mu_t[0], &inc, &ExnxT[0,0], &n) 369 | 370 | 371 | cdef inline void info_to_distn( 372 | double[:,:] J, double[:] h, double[:] mu, double[:,:] Sigma, 373 | ) nogil: 374 | cdef int n = J.shape[0] 375 | cdef int nn = n*n 376 | cdef int inc = 1, info = 0 377 | cdef double zero = 0., one = 1. 378 | 379 | dcopy(&nn, &J[0,0], &inc, &Sigma[0,0], &inc) 380 | dpotrf('L', &n, &Sigma[0,0], &n, &info) 381 | dpotri('L', &n, &Sigma[0,0], &n, &info) 382 | copy_upper_lower(n, &Sigma[0,0]) # NOTE: 'L' in Fortran order, but upper for C order 383 | dgemv('N', &n, &n, &one, &Sigma[0,0], &n, &h[0], &inc, &zero, &mu[0], &inc) 384 | 385 | 386 | cdef inline void info_sample_gaussian( 387 | double[:,:] J, double[:] h, 388 | double[:] randvec, 389 | ) nogil: 390 | cdef int n = h.shape[0] 391 | cdef int inc = 1, info = 0 392 | cdef double one = 1. 
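    # Draws x ~ N(J^{-1} h, J^{-1}): factor J = L L^T, transform the standard
    # normal randvec by L^{-T} (so its covariance is J^{-1}), overwrite h with
    # the mean J^{-1} h, and add the two. NOTE: J and h are mutated in place.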
393 | 394 | dpotrf('L', &n, &J[0,0], &n, &info) 395 | dtrtrs('L', 'T', 'N', &n, &inc, &J[0,0], &n, &randvec[0], &n, &info) 396 | dpotrs('L', &n, &inc, &J[0,0], &n, &h[0], &n, &info) 397 | daxpy(&n, &one, &h[0], &inc, &randvec[0], &inc) 398 | 399 | 400 | ################### 401 | # test bindings # 402 | ################### 403 | 404 | def info_predict_test(J,h,J11,J21,J22,h1,h2,logZ, Jpredict,hpredict): 405 | temp_n = np.random.randn(*h.shape) 406 | temp_nn = np.random.randn(*J.shape) 407 | temp_nn2 = np.random.randn(*J.shape) 408 | 409 | return info_predict(J,h,J11,J21,J22,h1,h2,logZ,Jpredict,hpredict,temp_n,temp_nn,temp_nn2) 410 | -------------------------------------------------------------------------------- /pylds/lds_messages.pyx: -------------------------------------------------------------------------------- 1 | # distutils: extra_compile_args = -O2 -w 2 | # distutils: include_dirs = pylds/ 3 | # cython: boundscheck = False, nonecheck = False, wraparound = False, cdivision = True 4 | 5 | import numpy as np 6 | from numpy.lib.stride_tricks import as_strided 7 | 8 | cimport numpy as np 9 | cimport cython 10 | from libc.math cimport log, sqrt 11 | from numpy.math cimport INFINITY, PI 12 | 13 | from scipy.linalg.cython_blas cimport dsymm, dcopy, dgemm, dgemv, daxpy, dsyrk, \ 14 | dtrmv, dger, dnrm2, ddot 15 | from scipy.linalg.cython_lapack cimport dpotrf, dpotrs, dpotri, dtrtrs 16 | from cyutil cimport copy_transpose, copy_upper_lower 17 | 18 | # NOTE: because the matrix operations are done in Fortran order but the code 19 | # expects C ordered arrays as input, the BLAS and LAPACK function calls mark 20 | # input matrices as transposed. temporaries, which don't get sliced, are left in 21 | # Fortran order. for symmetric matrices, F/C order doesn't matter. 22 | # NOTE: I tried the dsymm / dsyrk version and it was slower, even for larger p! 23 | # NOTE: using typed memoryview syntax instead of raw pointers is slightly slower 24 | # due to function call struct passing overhead, but much prettier 25 | 26 | # TODO try an Eigen version! faster for small matrices (numerically and in 27 | # function call overhead) 28 | # TODO cholesky update/downdate versions (square root filter) 29 | 30 | 31 | ################################## 32 | # distribution-form operations # 33 | ################################## 34 | 35 | def kalman_filter( 36 | double[:] mu_init, double[:,:] sigma_init, 37 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 38 | double[:,:,:] C, double[:,:,:] D, double[:,:,:] sigma_obs, 39 | double[:,:] inputs, double[:,::1] data): 40 | 41 | # allocate temporaries and internals 42 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 43 | cdef int t 44 | 45 | cdef double[::1] mu_predict = np.copy(mu_init) 46 | cdef double[:,:] sigma_predict = np.copy(sigma_init) 47 | 48 | cdef double[::1,:] temp_pp = np.empty((p,p),order='F') 49 | cdef double[::1,:] temp_pn = np.empty((p,n),order='F') 50 | cdef double[::1] temp_p = np.empty((p,), order='F') 51 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 52 | 53 | # allocate output 54 | cdef double[:,::1] filtered_mus = np.empty((T,n)) 55 | cdef double[:,:,::1] filtered_sigmas = np.empty((T,n,n)) 56 | cdef double ll = 0. 
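    # The filter alternates a measurement update (condition_on) and a time
    # update (predict); ll accumulates the predictive log densities
    # log p(y_t | y_{1:t-1}, u_{1:t}), i.e. the standard-form marginal
    # likelihood from the notes.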
57 | 58 | # run filter forwards 59 | for t in range(T): 60 | ll += condition_on( 61 | mu_predict, sigma_predict, 62 | C[t], D[t], sigma_obs[t], 63 | inputs[t], data[t], 64 | filtered_mus[t], filtered_sigmas[t], 65 | temp_p, temp_pn, temp_pp) 66 | predict( 67 | filtered_mus[t], filtered_sigmas[t], inputs[t], 68 | A[t], B[t], sigma_states[t], 69 | mu_predict, sigma_predict, 70 | temp_nn) 71 | 72 | return ll, np.asarray(filtered_mus), np.asarray(filtered_sigmas) 73 | 74 | 75 | def rts_smoother( 76 | double[::1] mu_init, double[:,::1] sigma_init, 77 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 78 | double[:,:,:] C, double[:,:,:] D, double[:,:,:] sigma_obs, 79 | double[:,:] inputs, double[:,::1] data): 80 | 81 | # allocate temporaries and internals 82 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 83 | cdef int t 84 | 85 | cdef double[:,::1] mu_predicts = np.empty((T+1,n)) 86 | cdef double[:,:,:] sigma_predicts = np.empty((T+1,n,n)) 87 | 88 | cdef double[::1,:] temp_pp = np.empty((p,p),order='F') 89 | cdef double[::1,:] temp_pn = np.empty((p,n),order='F') 90 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 91 | cdef double[::1,:] temp_nn2 = np.empty((n,n),order='F') 92 | cdef double[::1] temp_p = np.empty((p,), order='F') 93 | 94 | # allocate output 95 | cdef double[:,::1] smoothed_mus = np.empty((T,n)) 96 | cdef double[:,:,::1] smoothed_sigmas = np.empty((T,n,n)) 97 | cdef double ll = 0. 98 | 99 | # run filter forwards, saving predictions 100 | mu_predicts[0] = mu_init 101 | sigma_predicts[0] = sigma_init 102 | for t in range(T): 103 | ll += condition_on( 104 | mu_predicts[t], sigma_predicts[t], 105 | C[t], D[t], sigma_obs[t], 106 | inputs[t], data[t], 107 | smoothed_mus[t], smoothed_sigmas[t], 108 | temp_p, temp_pn, temp_pp) 109 | predict( 110 | smoothed_mus[t], smoothed_sigmas[t], inputs[t], 111 | A[t], B[t], sigma_states[t], 112 | mu_predicts[t+1], sigma_predicts[t+1], 113 | temp_nn) 114 | 115 | # run rts update backwards, using predictions 116 | for t in range(T-2,-1,-1): 117 | rts_backward_step( 118 | A[t], sigma_states[t], 119 | smoothed_mus[t], smoothed_sigmas[t], 120 | mu_predicts[t+1], sigma_predicts[t+1], 121 | smoothed_mus[t+1], smoothed_sigmas[t+1], 122 | temp_nn, temp_nn2) 123 | 124 | return ll, np.asarray(smoothed_mus), np.asarray(smoothed_sigmas) 125 | 126 | 127 | def filter_and_sample( 128 | double[:] mu_init, double[:,:] sigma_init, 129 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 130 | double[:,:,:] C, double[:,:,:] D, double[:,:,:] sigma_obs, 131 | double[:,:] inputs, double[:,::1] data): 132 | 133 | # allocate temporaries and internals 134 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 135 | cdef int t 136 | 137 | cdef double[::1] mu_predict = np.copy(mu_init) 138 | cdef double[:,:] sigma_predict = np.copy(sigma_init) 139 | 140 | cdef double[::1,:] temp_pp = np.empty((p,p),order='F') 141 | cdef double[::1,:] temp_pn = np.empty((p,n),order='F') 142 | cdef double[::1] temp_p = np.empty((p,), order='F') 143 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 144 | cdef double[::1] temp_n = np.empty((n,), order='F') 145 | 146 | cdef double[:,::1] filtered_mus = np.empty((T,n)) 147 | cdef double[:,:,::1] filtered_sigmas = np.empty((T,n,n)) 148 | 149 | # allocate output and generate randomness 150 | cdef double[:,::1] randseq = np.random.randn(T,n) 151 | cdef double ll = 0. 
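    # Forward filter, backward sample: x_T is drawn from p(x_T | y_{1:T}) and
    # each earlier x_t from p(x_t | x_{t+1}, y_{1:t}), as in the backward
    # sampling section of the notes; the N(0,1) draws in randseq are
    # overwritten in place with the sampled state sequence.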
152 | 153 | # run filter forwards 154 | for t in range(T): 155 | ll += condition_on( 156 | mu_predict, sigma_predict, 157 | C[t], D[t], sigma_obs[t], 158 | inputs[t], data[t], 159 | filtered_mus[t], filtered_sigmas[t], 160 | temp_p, temp_pn, temp_pp) 161 | predict( 162 | filtered_mus[t], filtered_sigmas[t], inputs[t], 163 | A[t], B[t], sigma_states[t], 164 | mu_predict, sigma_predict, 165 | temp_nn) 166 | 167 | # sample backwards 168 | sample_gaussian(filtered_mus[T-1], filtered_sigmas[T-1], randseq[T-1]) 169 | for t in range(T-2,-1,-1): 170 | condition_on( 171 | filtered_mus[t], filtered_sigmas[t], 172 | A[t], B[t], sigma_states[t], 173 | inputs[t], randseq[t+1], 174 | filtered_mus[t], filtered_sigmas[t], 175 | temp_n, temp_nn, sigma_predict) 176 | sample_gaussian(filtered_mus[t], filtered_sigmas[t], randseq[t]) 177 | 178 | return ll, np.asarray(randseq) 179 | 180 | 181 | def E_step( 182 | double[:] mu_init, double[:,:] sigma_init, 183 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 184 | double[:,:,:] C, double[:,:,:] D, double[:,:,:] sigma_obs, 185 | double[:,:] inputs, double[:,::1] data): 186 | 187 | # NOTE: this is almost the same as the RTS smoother except 188 | # 1. we collect statistics along the way, and 189 | # 2. we use the RTS gain matrix to do it 190 | 191 | # allocate temporaries and internals 192 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 193 | cdef int t 194 | 195 | cdef double[:,:] mu_predicts = np.empty((T+1,n)) 196 | cdef double[:,:,:] sigma_predicts = np.empty((T+1,n,n)) 197 | 198 | cdef double[::1,:] temp_pp = np.empty((p,p),order='F') 199 | cdef double[::1,:] temp_pn = np.empty((p,n),order='F') 200 | cdef double[::1,:] temp_nn = np.empty((n,n),order='F') 201 | cdef double[::1,:] temp_nn2 = np.empty((n,n),order='F') 202 | cdef double[::1] temp_p = np.empty((p,), order='F') 203 | 204 | # allocate output 205 | cdef double[:,::1] smoothed_mus = np.empty((T,n)) 206 | cdef double[:,:,::1] smoothed_sigmas = np.empty((T,n,n)) 207 | cdef double[:,:,::1] ExnxT = np.empty((T-1,n,n)) # 'n' for next 208 | cdef double ll = 0. 
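    # Besides the smoothed means and covariances, the E-step collects the pairwise
    # statistics E[x_{t+1} x_t^T] (stored in ExnxT) that the M-step needs to update
    # the dynamics regression; see set_dynamics_stats below.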
209 | 210 | # run filter forwards, saving predictions 211 | mu_predicts[0] = mu_init 212 | sigma_predicts[0] = sigma_init 213 | for t in range(T): 214 | ll += condition_on( 215 | mu_predicts[t], sigma_predicts[t], 216 | C[t], D[t], sigma_obs[t], 217 | inputs[t], data[t], 218 | smoothed_mus[t], smoothed_sigmas[t], 219 | temp_p, temp_pn, temp_pp) 220 | predict( 221 | smoothed_mus[t], smoothed_sigmas[t], inputs[t], 222 | A[t], B[t], sigma_states[t], 223 | mu_predicts[t+1], sigma_predicts[t+1], 224 | temp_nn) 225 | 226 | # run rts update backwards, using predictions and setting E[x_t x_{t+1}^T] 227 | for t in range(T-2,-1,-1): 228 | rts_backward_step( 229 | A[t], sigma_states[t], 230 | smoothed_mus[t], smoothed_sigmas[t], 231 | mu_predicts[t+1], sigma_predicts[t+1], 232 | smoothed_mus[t+1], smoothed_sigmas[t+1], 233 | temp_nn, temp_nn2) 234 | set_dynamics_stats( 235 | smoothed_mus[t], smoothed_mus[t+1], smoothed_sigmas[t+1], 236 | temp_nn, ExnxT[t]) 237 | 238 | return ll, np.asarray(smoothed_mus), np.asarray(smoothed_sigmas), np.asarray(ExnxT) 239 | 240 | 241 | ### diagonal emission distributions (D is diagonal) 242 | 243 | def kalman_filter_diagonal( 244 | double[:] mu_init, double[:,:] sigma_init, 245 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 246 | double[:,:,:] C, double[:,:,:] D, double[:,:] sigma_obs, 247 | double[:,:] inputs, double[:,::1] data): 248 | 249 | # allocate temporaries and internals 250 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 251 | cdef int t 252 | 253 | cdef double[::1] mu_predict = np.copy(mu_init) 254 | cdef double[:,:] sigma_predict = np.copy(sigma_init) 255 | 256 | cdef double[::1,:] temp_pn = np.empty((p,n), order='F') 257 | cdef double[::1,:] temp_pn2 = np.empty((p,n), order='F') 258 | cdef double[::1,:] temp_pn3 = np.empty((p,n), order='F') 259 | cdef double[::1] temp_p = np.empty((p,), order='F') 260 | cdef double[::1,:] temp_nn = np.empty((n,n), order='F') 261 | cdef double[::1,:] temp_pk = np.empty((p,n+1),order='F') 262 | cdef double[::1,:] temp_nk = np.empty((n,n+1),order='F') 263 | 264 | # allocate output 265 | cdef double[:,::1] filtered_mus = np.empty((T,n)) 266 | cdef double[:,:,::1] filtered_sigmas = np.empty((T,n,n)) 267 | cdef double ll = 0. 
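    # Same forward recursion as kalman_filter above, but the measurement update
    # exploits the diagonal observation noise: condition_on_diagonal solves the
    # (diagonal + low-rank) system via the Woodbury identity instead of forming and
    # factoring the dense p x p innovation covariance.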
268 | 269 | # run filter forwards 270 | for t in range(T): 271 | ll += condition_on_diagonal( 272 | mu_predict, sigma_predict, 273 | C[t], D[t], sigma_obs[t], 274 | inputs[t], data[t], 275 | filtered_mus[t], filtered_sigmas[t], 276 | temp_p, temp_nn, temp_pn, temp_pn2, temp_pn3, temp_pk, temp_nk) 277 | predict( 278 | filtered_mus[t], filtered_sigmas[t], inputs[t], 279 | A[t], B[t], sigma_states[t], 280 | mu_predict, sigma_predict, 281 | temp_nn) 282 | 283 | return ll, np.asarray(filtered_mus), np.asarray(filtered_sigmas) 284 | 285 | 286 | def filter_and_sample_diagonal( 287 | double[:] mu_init, double[:,:] sigma_init, 288 | double[:,:,:] A, double[:,:,:] B, double[:,:,:] sigma_states, 289 | double[:,:,:] C, double[:,:,:] D, double[:,:] sigma_obs, 290 | double[:,:] inputs, double[:,::1] data): 291 | 292 | # allocate temporaries and internals 293 | cdef int T = C.shape[0], p = C.shape[1], n = C.shape[2] 294 | cdef int t 295 | 296 | cdef double[::1] mu_predict = np.copy(mu_init) 297 | cdef double[:,:] sigma_predict = np.copy(sigma_init) 298 | 299 | cdef double[::1,:] temp_pn = np.empty((p,n), order='F') 300 | cdef double[::1,:] temp_pn2 = np.empty((p,n), order='F') 301 | cdef double[::1,:] temp_pn3 = np.empty((p,n), order='F') 302 | cdef double[::1] temp_p = np.empty((p,), order='F') 303 | cdef double[::1,:] temp_nn = np.empty((n,n), order='F') 304 | cdef double[::1] temp_n = np.empty((n,), order='F') 305 | cdef double[::1,:] temp_pk = np.empty((p,n+1),order='F') 306 | cdef double[::1,:] temp_nk = np.empty((n,n+1),order='F') 307 | 308 | cdef double[:,::1] filtered_mus = np.empty((T,n)) 309 | cdef double[:,:,::1] filtered_sigmas = np.empty((T,n,n)) 310 | 311 | # allocate output and generate randomness 312 | cdef double[:,::1] randseq = np.random.randn(T,n) 313 | cdef double ll = 0. 314 | 315 | # run filter forwards 316 | for t in range(T): 317 | ll += condition_on_diagonal( 318 | mu_predict, sigma_predict, 319 | C[t], D[t], sigma_obs[t], 320 | inputs[t], data[t], 321 | filtered_mus[t], filtered_sigmas[t], 322 | temp_p, temp_nn, temp_pn, temp_pn2, temp_pn3, temp_pk, temp_nk) 323 | predict( 324 | filtered_mus[t], filtered_sigmas[t], inputs[t], 325 | A[t], B[t], sigma_states[t], 326 | mu_predict, sigma_predict, 327 | temp_nn) 328 | 329 | # sample backwards 330 | sample_gaussian(filtered_mus[T-1], filtered_sigmas[T-1], randseq[T-1]) 331 | for t in range(T-2,-1,-1): 332 | condition_on( 333 | filtered_mus[t], filtered_sigmas[t], 334 | A[t], B[t], sigma_states[t], 335 | inputs[t], randseq[t+1], 336 | filtered_mus[t], filtered_sigmas[t], 337 | temp_n, temp_nn, sigma_predict) 338 | sample_gaussian(filtered_mus[t], filtered_sigmas[t], randseq[t]) 339 | 340 | return ll, np.asarray(randseq) 341 | 342 | 343 | ### random walk (A = I, B is diagonal, C = I, D is diagonal) 344 | 345 | def filter_and_sample_randomwalk( 346 | double[::1] mu_init, double[::1] sigmasq_init, double[:,:] sigmasq_states, 347 | double[:,:] sigmasq_obs, double[:,::1] data): 348 | # TODO: the randomwalk code needs to be updated to handle inputs 349 | 350 | # allocate temporaries and internals 351 | cdef int T = data.shape[0], n = data.shape[1] 352 | cdef int t 353 | 354 | cdef double[::1] mu_predict = np.copy(mu_init) 355 | cdef double[::1] sigmasq_predict = np.copy(sigmasq_init) 356 | 357 | cdef double[:,::1] filtered_mus = np.empty((T,n)) 358 | cdef double[:,::1] filtered_sigmasqs = np.empty((T,n)) 359 | 360 | # allocate output and generate randomness 361 | cdef double[:,::1] randseq = np.random.randn(T,n) 362 | cdef double ll = 0. 
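    # With A = I, C = I, and diagonal state/observation noise, the filter
    # factorizes into n independent scalar Kalman filters, so everything below
    # operates elementwise on length-n vectors of means and variances.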
363 | 364 | # run filter forwards 365 | for t in range(T): 366 | ll += condition_on_randomwalk( 367 | n, &mu_predict[0], &sigmasq_predict[0], &sigmasq_obs[t,0], &data[t,0], 368 | &filtered_mus[t,0], &filtered_sigmasqs[t,0]) 369 | predict_randomwalk( 370 | n, &filtered_mus[t,0], &filtered_sigmasqs[t,0], &sigmasq_states[t,0], 371 | &mu_predict[0], &sigmasq_predict[0]) 372 | 373 | # sample backwards 374 | sample_diagonal_gaussian(n, &filtered_mus[T-1,0], &filtered_sigmasqs[T-1,0], &randseq[T-1,0]) 375 | for t in range(T-2,-1,-1): 376 | condition_on_randomwalk( 377 | n, &filtered_mus[t,0], &filtered_sigmasqs[t,0], &sigmasq_states[t,0], &randseq[t+1,0], 378 | &filtered_mus[t,0], &filtered_sigmasqs[t,0]) 379 | sample_diagonal_gaussian(n, &filtered_mus[t,0], &filtered_sigmasqs[t,0], &randseq[t,0]) 380 | 381 | return ll, np.asarray(randseq) 382 | 383 | 384 | ############################ 385 | # distribution-form util # 386 | ############################ 387 | 388 | cdef inline double condition_on( 389 | # prior predictions 390 | double[:] mu_x, double[:,:] sigma_x, 391 | # Observation model 392 | double[:,:] C, double[:,:] D, double[:,:] sigma_obs, 393 | # Data 394 | double[:] u, double[:] y, 395 | # outputs 396 | double[:] mu_cond, double[:,:] sigma_cond, 397 | # temps 398 | double[:] temp_p, double[:,:] temp_pn, double[:,:] temp_pp, 399 | ) nogil: 400 | cdef int p = C.shape[0], n = C.shape[1], d = D.shape[1] 401 | cdef int nn = n*n, pp = p*p 402 | cdef int inc = 1, info = 0 403 | cdef double one = 1., zero = 0., neg1 = -1., ll = 0. 404 | 405 | if y[0] != y[0]: # nan check 406 | dcopy(&n, &mu_x[0], &inc, &mu_cond[0], &inc) 407 | dcopy(&nn, &sigma_x[0,0], &inc, &sigma_cond[0,0], &inc) 408 | return 0. 409 | else: 410 | # NOTE: the C and D arguments are treated as transposed because C and D are 411 | # assumed to be in C order (row-major order) 412 | 413 | # Compute temp_pn = C * sigma_x 414 | # and temp_pp = chol(sigma_obs + C * sigma_x * C.T) = chol(S) 415 | dgemm('T', 'N', &p, &n, &n, &one, &C[0,0], &n, &sigma_x[0,0], &n, &zero, &temp_pn[0,0], &p) 416 | # Now temp_pp = sigma_obs 417 | dcopy(&pp, &sigma_obs[0,0], &inc, &temp_pp[0,0], &inc) 418 | # temp_pp += temp_pn * C = sigma_obs + C * sigma_x * C.T 419 | # call this S, as in (18.38) of Murphy 420 | dgemm('N', 'N', &p, &p, &n, &one, &temp_pn[0,0], &p, &C[0,0], &n, &one, &temp_pp[0,0], &p) 421 | # temp_pp = cholesky(S, lower=True) = L 422 | dpotrf('L', &p, &temp_pp[0,0], &p, &info) 423 | 424 | # Compute the residual -- this is where the inputs come in 425 | dcopy(&p, &y[0], &inc, &temp_p[0], &inc) 426 | # temp_p -= - C * mu_x = y - C * mu_x 427 | dgemv('T', &n, &p, &neg1, &C[0,0], &n, &mu_x[0], &inc, &one, &temp_p[0], &inc) 428 | # temp_p -= - D * u = y - C * mu_x - D * u 429 | if d > 0: 430 | dgemv('T', &d, &p, &neg1, &D[0,0], &d, &u[0], &inc, &one, &temp_p[0], &inc) 431 | # Solve temp_p = temp_pp^{-1} temp_p 432 | # = L^{-1} (y - C * mu_x - D * u) 433 | dtrtrs('L', 'N', 'N', &p, &inc, &temp_pp[0,0], &p, &temp_p[0], &p, &info) 434 | # log likelihood = -1/2 * ||L^{-1} (y - C * mu_x - D * u)||*2 435 | ll = (-1./2) * dnrm2(&p, &temp_p[0], &inc)**2 436 | 437 | # Second solve with cholesky 438 | # temp_p = L.T^{-1} temp_p 439 | # = S^{-1} (y - C * mu_x - D * u) 440 | dtrtrs('L', 'T', 'N', &p, &inc, &temp_pp[0,0], &p, &temp_p[0], &p, &info) 441 | 442 | # Compute the conditional mean 443 | # mu_cond = mu_x + temp_pn * temp_p 444 | # = mu_x + sigma_x * C.T * S^{-1} (y - C * mu_x - D * u) 445 | # Compare this to (18.31) of Murphy 446 | if (&mu_x[0] != 
&mu_cond[0]): 447 | dcopy(&n, &mu_x[0], &inc, &mu_cond[0], &inc) 448 | dgemv('T', &p, &n, &one, &temp_pn[0,0], &p, &temp_p[0], &inc, &one, &mu_cond[0], &inc) 449 | 450 | # Compute the conditional covariance 451 | # sigma_cond = sigma_x - C * sigma_x.T * L.T^{-1} * L^{-1} C.T * sigma_x 452 | # = sigma_x - sigma_x * C.T * S^{-1} * C * sigma_x 453 | # Compare this to (18.32) of Murphy 454 | # 455 | # First, temp_pn = temp_pp^{-1} temp_pn 456 | # = L^{-1} C.T * sigma_x 457 | # Then we square this and subtract from sigma_x 458 | dtrtrs('L', 'N', 'N', &p, &n, &temp_pp[0,0], &p, &temp_pn[0,0], &p, &info) 459 | if (&sigma_x[0,0] != &sigma_cond[0,0]): 460 | dcopy(&nn, &sigma_x[0,0], &inc, &sigma_cond[0,0], &inc) 461 | # TODO this call aliases pointers, should really call dsyrk and copy lower to upper 462 | dgemm('T', 'N', &n, &n, &p, &neg1, &temp_pn[0,0], &p, &temp_pn[0,0], &p, &one, &sigma_cond[0,0], &n) 463 | 464 | # Compute log determinant of the covariance by summing log diagonal of cholesky 465 | ll -= p/2. * log(2.*PI) 466 | for i in range(p): 467 | ll -= log(temp_pp[i,i]) 468 | 469 | return ll 470 | 471 | 472 | cdef inline void predict( 473 | # inputs 474 | double[:] mu, double[:,:] sigma, double[:] u, 475 | double[:,:] A, double[:,:] B, double[:,:] sigma_states, 476 | # outputs 477 | double[:] mu_predict, double[:,:] sigma_predict, 478 | # temps 479 | double[:,:] temp_nn, 480 | ) nogil: 481 | cdef int n = mu.shape[0], d = B.shape[1] 482 | cdef int nn = n*n 483 | cdef int inc = 1 484 | cdef double one = 1., zero = 0. 485 | 486 | # NOTE: the A and B arguments are treated as transposed because A and B are assumed to be 487 | # in C order (row-major order) 488 | 489 | # mu_predict = A * mu 490 | dgemv('T', &n, &n, &one, &A[0,0], &n, &mu[0], &inc, &zero, &mu_predict[0], &inc) 491 | # mu_predict += B * u 492 | if d > 0: 493 | dgemv('T', &d, &n, &one, &B[0,0], &d, &u[0], &inc, &one, &mu_predict[0], &inc) 494 | 495 | # temp_nn = A * sigma 496 | dgemm('T', 'N', &n, &n, &n, &one, &A[0,0], &n, &sigma[0,0], &n, &zero, &temp_nn[0,0], &n) 497 | dcopy(&nn, &sigma_states[0,0], &inc, &sigma_predict[0,0], &inc) 498 | # sigma_pred = sigma_states + A * sigma * A.T 499 | dgemm('N', 'N', &n, &n, &n, &one, &temp_nn[0,0], &n, &A[0,0], &n, &one, &sigma_predict[0,0], &n) 500 | 501 | 502 | cdef inline void sample_gaussian( 503 | # inputs (which get mutated) 504 | double[:] mu, double[:,:] sigma, 505 | # input/output 506 | double[:] randvec, 507 | ) nogil: 508 | cdef int n = mu.shape[0] 509 | cdef int inc = 1, info = 0 510 | cdef double one = 1. 511 | 512 | dpotrf('L', &n, &sigma[0,0], &n, &info) 513 | dtrmv('L', 'N', 'N', &n, &sigma[0,0], &n, &randvec[0], &inc) 514 | daxpy(&n, &one, &mu[0], &inc, &randvec[0], &inc) 515 | 516 | 517 | cdef inline void rts_backward_step( 518 | double[:,:] A, double[:,:] sigma_states, 519 | double[:] filtered_mu, double[:,:] filtered_sigma, # inputs/outputs 520 | double[:] next_predict_mu, double[:,:] next_predict_sigma, # mutated inputs! 
521 | double[:] next_smoothed_mu, double[:,:] next_smoothed_sigma,
522 | double[:,:] temp_nn, double[:,:] temp_nn2, # temps
523 | ) nogil:
524 | # filtered_mu and filtered_sigma are m_{t|t} and P_{t|t}, respectively
525 | # Recall, m_{t+1|t} = A * m_{t|t} + B * u_t,
526 | # and, P_{t+1|t} = A * P_{t|t} * A.T + Q_t
527 | # next_predict_mu and next_predict_sigma are m_{t+1|t} and P_{t+1|t}
528 | # next_smoothed_mu and next_smoothed_sigma are m_{t+1|T} and P_{t+1|T}
529 |
530 | # NOTE: on exit, temp_nn holds the RTS gain, called G_k.T in the notation of
531 | # Thm 8.2 of Sarkka 2013 "Bayesian Filtering and Smoothing"
532 |
533 | cdef int n = A.shape[0]
534 | cdef int nn = n*n
535 | cdef int inc = 1, info = 0
536 | cdef double one = 1., zero = 0., neg1 = -1.
537 |
538 | # NOTE: the A argument is treated as transposed because A is assumed to be in C order
539 | # temp_nn = A * P_{t|t}
540 | dgemm('T', 'N', &n, &n, &n, &one, &A[0,0], &n, &filtered_sigma[0,0], &n, &zero, &temp_nn[0,0], &n)
541 | # TODO: could just call dposv directly instead of dpotrf+dpotrs
542 | # temp_nn2 = P_{t+1|t}
543 | dcopy(&nn, &next_predict_sigma[0,0], &inc, &temp_nn2[0,0], &inc)
544 | # temp_nn2 = chol(P_{t+1|t}, lower=True)
545 | dpotrf('L', &n, &temp_nn2[0,0], &n, &info)
546 | # temp_nn = temp_nn2^{-1} temp_nn
547 | # = (P_{t+1|t})^{-1} A * P_{t|t}
548 | # = G_t^T
549 | dpotrs('L', &n, &n, &temp_nn2[0,0], &n, &temp_nn[0,0], &n, &info)
550 |
551 | # next_predict_mu = m_{t+1|t} - m_{t+1|T} (negated version of notes)
552 | daxpy(&n, &neg1, &next_smoothed_mu[0], &inc, &next_predict_mu[0], &inc)
553 | # filtered_mu = filtered_mu - temp_nn * next_predict_mu
554 | # = m_{t|t} - G_t * (m_{t+1|t} - m_{t+1|T})
555 | # = m_{t|t} + G_t * (m_{t+1|T} - m_{t+1|t})
556 | # = m_{t|T}
557 | dgemv('T', &n, &n, &neg1, &temp_nn[0,0], &n, &next_predict_mu[0], &inc, &one, &filtered_mu[0], &inc)
558 |
559 | # next_predict_sigma = next_predict_sigma - next_smoothed_sigma
560 | # = P_{t+1|t} - P_{t+1|T}
561 | daxpy(&nn, &neg1, &next_smoothed_sigma[0,0], &inc, &next_predict_sigma[0,0], &inc)
562 | # temp_nn2 = -next_predict_sigma * temp_nn
563 | # = (P_{t+1|T} - P_{t+1|t}) * G_t^T
564 | dgemm('N', 'N', &n, &n, &n, &neg1, &next_predict_sigma[0,0], &n, &temp_nn[0,0], &n, &zero, &temp_nn2[0,0], &n)
565 | # filtered_sigma = filtered_sigma + temp_nn * temp_nn2
566 | # = P_{t|t} + G_t (P_{t+1|T} - P_{t+1|t}) G_t^T
567 | # = P_{t|T}
568 | dgemm('T', 'N', &n, &n, &n, &one, &temp_nn[0,0], &n, &temp_nn2[0,0], &n, &one, &filtered_sigma[0,0], &n)
569 |
570 |
571 | cdef inline void set_dynamics_stats(
572 | double[::1] mk, double[::1] mkn, double[:,::1] Pkns,
573 | double[::1,:] GkT,
574 | double[:,::1] ExnxT,
575 | ) nogil:
576 | # mk = m_{t|T}
577 | # mkn = m_{t+1|T}
578 | # Pkns = P_{t+1|T}
579 | # GkT is the transpose of the RTS gain G_t = P_{t|t} A_t^T P_{t+1|t}^{-1}
580 | # ExnxT = E[x_{t+1} x_t^T]
581 | cdef int n = mk.shape[0], inc = 1
582 | cdef double one = 1., zero = 0.
583 | # E_xnxT = GkT.T * Pkns 584 | # = G_t * P_{t+1|T} 585 | # = P_t A_t 586 | # Compare to Sarkka notes, this is the cross covariance, Cov(xn, x) 587 | dgemm('T', 'N', &n, &n, &n, &one, &GkT[0,0], &n, &Pkns[0,0], &n, &zero, &ExnxT[0,0], &n) 588 | # Recall, 589 | # Cov(xn, x) = E[(x_{t+1} - m_{t+1}) (x_t - m_t)^T] 590 | # = E[x_{t+1} x_t^T] - E[m_{t+1} x_t^T] - E[x_{t+1} m_t^T] + m_{t+1} m_t^T 591 | # = E[x_{t+1} x_t^T] - m_{t+1} m_t^T 592 | # Add outer product of means to get E[x_{t+1] x_t^T] 593 | dger(&n, &n, &one, &mk[0], &inc, &mkn[0], &inc, &ExnxT[0,0], &n) 594 | 595 | 596 | ### diagonal emission distributions 597 | 598 | cdef inline double condition_on_diagonal( 599 | double[:] mu_x, double[:,:] sigma_x, 600 | double[:,:] C, double[:,:] D, double[:] sigma_obs, 601 | double[:] u, double[:] y, 602 | double[:] mu_cond, double[:,:] sigma_cond, 603 | double[::1] temp_p, double[::1,:] temp_nn, 604 | double[::1,:] temp_pn, double[::1,:] temp_pn2, double[::1,:] temp_pn3, 605 | double[::1,:] temp_pk, double[::1,:] temp_nk, 606 | ) nogil: 607 | 608 | # see Boyd and Vandenberghe, Convex Optimization, Appendix C.4.3 (p. 679) 609 | # and also https://en.wikipedia.org/wiki/Woodbury_matrix_identity 610 | # and https://en.wikipedia.org/wiki/Matrix_determinant_lemma 611 | 612 | # an extra temp (temp_pn3) and an extra copy_transpose are needed because C 613 | # is not stored in Fortran order as solve_diagonal_plus_lowrank requires 614 | 615 | cdef int p = C.shape[0], n = C.shape[1], d = D.shape[1] 616 | cdef int nn = n*n, pn = p*n 617 | cdef int inc = 1, info = 0, i 618 | cdef double one = 1., zero = 0., neg1 = -1., ll = 0. 619 | 620 | if y[0] != y[0]: # nan check 621 | dcopy(&n, &mu_x[0], &inc, &mu_cond[0], &inc) 622 | dcopy(&nn, &sigma_x[0,0], &inc, &sigma_cond[0,0], &inc) 623 | return 0. 624 | else: 625 | # NOTE: the C arguments are treated as transposed because C is 626 | # assumed to be in C order 627 | 628 | # Compute residual y - C * mu_x - D * u 629 | dcopy(&p, &y[0], &inc, &temp_p[0], &inc) 630 | dgemv('T', &n, &p, &neg1, &C[0,0], &n, &mu_x[0], &inc, &one, &temp_p[0], &inc) 631 | if d > 0: 632 | dgemv('T', &d, &p, &neg1, &D[0,0], &d, &u[0], &inc, &one, &temp_p[0], &inc) 633 | 634 | # Compute conditional mean and variance using low rank plus diagonal code 635 | dgemm('T', 'N', &p, &n, &n, &one, &C[0,0], &n, &sigma_x[0,0], &n, &zero, &temp_pn[0,0], &p) 636 | copy_transpose(n, p, &C[0,0], &temp_pn3[0,0]) 637 | dcopy(&p, &temp_p[0], &inc, &temp_pk[0,0], &inc) 638 | dcopy(&pn, &temp_pn[0,0], &inc, &temp_pk[0,1], &inc) 639 | 640 | ll = -1./2 * solve_diagonal_plus_lowrank( 641 | sigma_obs, temp_pn3, sigma_x, temp_pk, False, 642 | temp_nn, temp_pn2, temp_nk) 643 | 644 | if (&mu_x[0] != &mu_cond[0]): 645 | dcopy(&n, &mu_x[0], &inc, &mu_cond[0], &inc) 646 | dgemv('T', &p, &n, &one, &temp_pn[0,0], &p, &temp_pk[0,0], &inc, &one, &mu_cond[0], &inc) 647 | 648 | if (&sigma_x[0,0] != &sigma_cond[0,0]): 649 | dcopy(&nn, &sigma_x[0,0], &inc, &sigma_cond[0,0], &inc) 650 | dgemm('T', 'N', &n, &n, &p, &neg1, &temp_pn[0,0], &p, &temp_pk[0,1], &p, &one, &sigma_cond[0,0], &n) 651 | 652 | ll -= 1./2 * ddot(&p, &temp_p[0], &inc, &temp_pk[0,0], &inc) 653 | ll -= p/2. 
* log(2*PI) 654 | return ll 655 | 656 | 657 | cdef inline double solve_diagonal_plus_lowrank( 658 | double[:] a, double[:,:] B, double[:,:] C, double[:,:] b, bint C_is_identity, 659 | double[:,:] temp_nn, double[:,:] temp_pn, double[:,:] temp_nk, 660 | ) nogil: 661 | cdef int p = B.shape[0], n = B.shape[1], k = b.shape[1] 662 | cdef int nn = n*n, inc = 1, info = 0, i, j 663 | cdef double one = 1., zero = 0., neg1 = -1., logdet = 0. 664 | 665 | # NOTE: on exit, temp_nn is guaranteed to hold chol(C^{-1} + B' A^{-1} B) 666 | # NOTE: assumes Fortran order for everything 667 | 668 | for j in range(k): 669 | for i in range(p): 670 | b[i,j] /= a[i] 671 | 672 | for j in range(n): 673 | for i in range(p): 674 | temp_pn[i,j] = B[i,j] / a[i] 675 | 676 | dcopy(&nn, &C[0,0], &inc, &temp_nn[0,0], &inc) 677 | if not C_is_identity: 678 | dpotrf('L', &n, &temp_nn[0,0], &n, &info) 679 | for i in range(n): 680 | logdet += 2.*log(temp_nn[i,i]) 681 | dpotri('L', &n, &temp_nn[0,0], &n, &info) 682 | dgemm('T', 'N', &n, &n, &p, &one, &B[0,0], &p, &temp_pn[0,0], &p, &one, &temp_nn[0,0], &n) 683 | dpotrf('L', &n, &temp_nn[0,0], &n, &info) 684 | 685 | dgemm('T', 'N', &n, &k, &p, &one, &B[0,0], &p, &b[0,0], &p, &zero, &temp_nk[0,0], &n) 686 | dpotrs('L', &n, &k, &temp_nn[0,0], &n, &temp_nk[0,0], &n, &info) 687 | 688 | dgemm('N', 'N', &p, &k, &n, &neg1, &temp_pn[0,0], &p, &temp_nk[0,0], &n, &one, &b[0,0], &p) 689 | 690 | for i in range(n): 691 | logdet += 2.*log(temp_nn[i,i]) 692 | for i in range(p): 693 | logdet += log(a[i]) 694 | 695 | return logdet 696 | 697 | 698 | ### identity dynamics and emission distributions (A = I, C = I) 699 | 700 | # NOTE: we have to use raw pointers here because numpy (and hence cython's typed 701 | # memoryview checks) doesn't count arrays with a zero stride as possibly being 702 | # C-contiguous 703 | 704 | cdef inline double condition_on_randomwalk( 705 | int n, 706 | double *mu_x, double *sigmasq_x, 707 | double *sigmasq_obs, double *y, 708 | double *mu_cond, double *sigmasq_cond, 709 | ) nogil: 710 | 711 | # TODO: the randomwalk code needs to be updated to handle inputs 712 | 713 | cdef double ll = -n/2. 
* log(2.*PI), sigmasq_yi 714 | cdef int i 715 | for i in range(n): 716 | sigmasq_yi = sigmasq_x[i] + sigmasq_obs[i] 717 | ll -= 1./2 * log(sigmasq_yi) 718 | ll -= 1./2 * (y[i] - mu_x[i])**2 / sigmasq_yi 719 | mu_cond[i] = mu_x[i] + sigmasq_x[i] / sigmasq_yi * (y[i] - mu_x[i]) 720 | sigmasq_cond[i] = sigmasq_x[i] - sigmasq_x[i]**2 / sigmasq_yi 721 | 722 | return ll 723 | 724 | 725 | cdef inline void predict_randomwalk( 726 | int n, 727 | double *mu, double *sigmasq, double *sigmasq_states, 728 | double *mu_predict, double *sigmasq_predict, 729 | ) nogil: 730 | 731 | cdef int i 732 | for i in range(n): 733 | mu_predict[i] = mu[i] 734 | sigmasq_predict[i] = sigmasq[i] + sigmasq_states[i] 735 | 736 | 737 | cdef inline void sample_diagonal_gaussian( 738 | int n, 739 | double *mu, double *sigmasq, double *randvec, 740 | ) nogil: 741 | 742 | cdef int i 743 | for i in range(n): 744 | randvec[i] = mu[i] + sqrt(sigmasq[i]) * randvec[i] 745 | 746 | 747 | ################### 748 | # test bindings # 749 | ################### 750 | 751 | def test_condition_on_diagonal( 752 | double[:] mu_x, double[:,:] sigma_x, 753 | double[:,:] C, double[:,:] D, double[:] sigma_obs, 754 | double[:] u, double[:] y, 755 | double[:] mu_cond, double[:,:] sigma_cond): 756 | p = y.shape[0] 757 | n = mu_x.shape[0] 758 | k = n+1 759 | temp_p = np.asfortranarray(np.random.randn(p)) 760 | temp_nn = np.asfortranarray(np.random.randn(n,n)) 761 | temp_pn = np.asfortranarray(np.random.randn(p,n)) 762 | temp_pn2 = np.asfortranarray(np.random.randn(p,n)) 763 | temp_pn3 = np.asfortranarray(np.random.randn(p,n)) 764 | temp_pk = np.asfortranarray(np.random.randn(p,k)) 765 | temp_nk = np.asfortranarray(np.random.randn(n,k)) 766 | return condition_on_diagonal( 767 | mu_x, sigma_x, C, D, sigma_obs, u, y, mu_cond, sigma_cond, 768 | temp_p, temp_nn, temp_pn, temp_pn2, temp_pn3, temp_pk, temp_nk) 769 | 770 | 771 | def test_solve_diagonal_plus_lowrank( 772 | double[:] a, double[::1,:] B, double[:,:] C, bint C_is_identity, 773 | double[::1,:] b): 774 | p = B.shape[0] 775 | n = B.shape[1] 776 | k = b.shape[1] 777 | temp_nn = np.asfortranarray(np.random.randn(n,n)) 778 | temp_pn = np.asfortranarray(np.random.randn(p,n)) 779 | temp_nk = np.asfortranarray(np.random.randn(n,k)) 780 | return solve_diagonal_plus_lowrank(a,B,C,b,C_is_identity,temp_nn,temp_pn,temp_nk) 781 | -------------------------------------------------------------------------------- /pylds/lds_messages_interface.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from numpy.lib.stride_tricks import as_strided 4 | from functools import wraps, partial 5 | 6 | ################################ 7 | # distribution-form wrappers # 8 | ################################ 9 | 10 | from pylds.lds_messages import \ 11 | kalman_filter as _kalman_filter, \ 12 | rts_smoother as _rts_smoother, \ 13 | filter_and_sample as _filter_and_sample, \ 14 | kalman_filter_diagonal as _kalman_filter_diagonal, \ 15 | filter_and_sample_diagonal as _filter_and_sample_diagonal, \ 16 | filter_and_sample_randomwalk as _filter_and_sample_randomwalk, \ 17 | E_step as _E_step 18 | 19 | 20 | def _ensure_ndim(X,T,ndim): 21 | X = np.require(X,dtype=np.float64, requirements='C') 22 | assert ndim-1 <= X.ndim <= ndim 23 | if X.ndim == ndim: 24 | assert X.shape[0] == T 25 | return X 26 | else: 27 | return as_strided(X, shape=(T,)+X.shape, strides=(0,)+X.strides) 28 | 29 | 30 | def _argcheck(mu_init, sigma_init, A, B, sigma_states, C, 
D, sigma_obs, inputs, data): 31 | T = data.shape[0] 32 | A, B, sigma_states, C, D, sigma_obs = \ 33 | map(partial(_ensure_ndim, T=T, ndim=3), 34 | [A, B, sigma_states, C, D, sigma_obs]) 35 | # Check that the inputs are C ordered and at least 1d 36 | inputs = np.require(inputs, dtype=np.float64, requirements='C') 37 | 38 | data = np.require(data, dtype=np.float64, requirements='C') 39 | return mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data 40 | 41 | 42 | def _argcheck_diag_sigma_obs(mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data): 43 | T = data.shape[0] 44 | A, B, sigma_states, C, D, = \ 45 | map(partial(_ensure_ndim, T=T, ndim=3), 46 | [A, B, sigma_states, C, D]) 47 | sigma_obs = _ensure_ndim(sigma_obs, T=T, ndim=2) 48 | inputs = np.require(inputs, dtype=np.float64, requirements='C') 49 | data = np.require(data, dtype=np.float64, requirements='C') 50 | return mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data 51 | 52 | 53 | def _argcheck_randomwalk(mu_init, sigma_init, sigmasq_states, sigmasq_obs, data): 54 | T = data.shape[0] 55 | sigmasq_states, sigmasq_obs = \ 56 | map(partial(_ensure_ndim, T=T, ndim=2), 57 | [sigmasq_states, sigmasq_obs]) 58 | data = np.require(data, dtype=np.float64, requirements='C') 59 | return mu_init, sigma_init, sigmasq_states, sigmasq_obs, data 60 | 61 | 62 | def _wrap(func, check): 63 | @wraps(func) 64 | def wrapped(*args, **kwargs): 65 | return func(*check(*args,**kwargs)) 66 | return wrapped 67 | 68 | 69 | kalman_filter = _wrap(_kalman_filter,_argcheck) 70 | rts_smoother = _wrap(_rts_smoother,_argcheck) 71 | filter_and_sample = _wrap(_filter_and_sample,_argcheck) 72 | E_step = _wrap(_E_step,_argcheck) 73 | kalman_filter_diagonal = _wrap(_kalman_filter_diagonal,_argcheck_diag_sigma_obs) 74 | filter_and_sample_diagonal = _wrap(_filter_and_sample_diagonal,_argcheck_diag_sigma_obs) 75 | filter_and_sample_randomwalk = _wrap(_filter_and_sample_randomwalk,_argcheck_randomwalk) 76 | 77 | 78 | ############################### 79 | # information-form wrappers # 80 | ############################### 81 | 82 | from pylds.lds_info_messages import \ 83 | kalman_info_filter as _kalman_info_filter, \ 84 | info_E_step as _info_E_step, \ 85 | info_sample as _info_sample 86 | 87 | 88 | def _info_argcheck(J_init, h_init, log_Z_init, 89 | J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair, 90 | J_node, h_node, log_Z_node): 91 | T = h_node.shape[0] 92 | assert np.isscalar(log_Z_init) 93 | J_node = _ensure_ndim(J_node, T=T, ndim=3) 94 | J_pair_11, J_pair_21, J_pair_22 = \ 95 | map(partial(_ensure_ndim, T=T-1, ndim=3), 96 | [J_pair_11, J_pair_21, J_pair_22]) 97 | h_pair_1, h_pair_2 = \ 98 | map(partial(_ensure_ndim, T=T-1, ndim=2), 99 | [h_pair_1, h_pair_2]) 100 | log_Z_pair = _ensure_ndim(log_Z_pair, T=T-1, ndim=1) 101 | log_Z_node = _ensure_ndim(log_Z_node, T=T, ndim=1) 102 | 103 | h_node = np.require(h_node, dtype=np.float64, requirements='C') 104 | return J_init, h_init, log_Z_init, \ 105 | J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair,\ 106 | J_node, h_node, log_Z_node 107 | 108 | 109 | kalman_info_filter = _wrap(_kalman_info_filter, _info_argcheck) 110 | info_E_step = _wrap(_info_E_step, _info_argcheck) 111 | info_sample = _wrap(_info_sample, _info_argcheck) 112 | -------------------------------------------------------------------------------- /pylds/lds_messages_python.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
division
2 | import numpy as np
3 |
4 |
5 | solve_psd = np.linalg.solve
6 |
7 |
8 | def kf(init_mu, init_sigma,
9 | As, Bs, sigma_statess,
10 | Cs, Ds, sigma_obss,
11 | inputs, emissions):
12 | T, D_latent = emissions.shape[0], As[0].shape[0]
13 |
14 | filtered_mus = np.empty((T,D_latent))
15 | filtered_sigmas = np.empty((T,D_latent,D_latent))
16 |
17 | # messages forwards
18 | prediction_mu, prediction_sigma = init_mu, init_sigma
19 | for t, (A,B,sigma_states,C,D,sigma_obs) in \
20 | enumerate(zip(As, Bs, sigma_statess, Cs, Ds, sigma_obss)):
21 | # condition
22 | filtered_mus[t], filtered_sigmas[t] = \
23 | condition_on(prediction_mu,prediction_sigma,C,D,sigma_obs,inputs[t], emissions[t])
24 |
25 | # predict
26 | prediction_mu, prediction_sigma = \
27 | A.dot(filtered_mus[t]) + B.dot(inputs[t]), A.dot(filtered_sigmas[t]).dot(A.T) + sigma_states
28 |
29 | return filtered_mus, filtered_sigmas
30 |
31 |
32 | def kf_resample_lds(init_mu, init_sigma,
33 | As, Bs, sigma_statess,
34 | Cs, Ds, sigma_obss,
35 | inputs, emissions):
36 | T, D_latent = emissions.shape[0], As[0].shape[0]
37 | x = np.empty((T,D_latent))
38 |
39 | filtered_mus, filtered_sigmas = \
40 | kf(init_mu, init_sigma, As, Bs, sigma_statess, Cs, Ds, sigma_obss, inputs, emissions)
41 |
42 | # sample backwards
43 | # TODO pull rng out of the loop
44 | x[-1] = np.random.multivariate_normal(filtered_mus[-1],filtered_sigmas[-1])
45 | for t in range(T-2,-1,-1):
46 | x[t] = np.random.multivariate_normal(
47 | *condition_on(filtered_mus[t], filtered_sigmas[t], As[t], Bs[t], sigma_statess[t], inputs[t], x[t + 1]))
48 |
49 | return x
50 |
51 |
52 | def condition_on(mu_x, sigma_x, A, B, sigma_obs, u, y):
53 | # mu = mu_x + sigma_xy sigma_yy^{-1} (y - A mu_x - B u)
54 | # sigma = sigma_x - sigma_xy sigma_yy^{-1} sigma_xy'
55 | sigma_xy = sigma_x.dot(A.T)
56 | sigma_yy = A.dot(sigma_x).dot(A.T) + sigma_obs
57 | mu = mu_x + sigma_xy.dot(solve_psd(sigma_yy, y - A.dot(mu_x) - B.dot(u)))
58 | sigma = sigma_x - sigma_xy.dot(solve_psd(sigma_yy,sigma_xy.T))
59 | return mu, symmetrize(sigma)
60 |
61 |
62 | def symmetrize(A):
63 | ret = A+A.T
64 | ret /= 2.
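    # averaging with the transpose enforces exact symmetry: floating-point error
    # in the covariance downdate in condition_on can otherwise leave sigma very
    # slightly asymmetric, which can trip the PSD checks in
    # np.random.multivariate_normal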
65 | return ret 66 | 67 | -------------------------------------------------------------------------------- /pylds/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import copy 3 | import numpy as np 4 | 5 | from pybasicbayes.abstractions import Model, ModelGibbsSampling, \ 6 | ModelEM, ModelMeanField, ModelMeanFieldSVI 7 | 8 | from pybasicbayes.distributions import DiagonalRegression, Gaussian, Regression 9 | from pylds.distributions import PoissonRegression, BernoulliRegression 10 | from pylds.states import LDSStates, LDSStatesCountData, LDSStatesMissingData,\ 11 | LDSStatesZeroInflatedCountData 12 | from pylds.laplace import LaplaceApproxPoissonLDSStates, LaplaceApproxBernoulliLDSStates 13 | from pylds.util import random_rotation 14 | 15 | 16 | class _LDSBase(Model): 17 | 18 | _states_class = LDSStates 19 | 20 | def __init__(self,dynamics_distn,emission_distn): 21 | self.dynamics_distn = dynamics_distn 22 | self.emission_distn = emission_distn 23 | self.states_list = [] 24 | 25 | def add_data(self,data, inputs=None, **kwargs): 26 | assert isinstance(data,np.ndarray) 27 | self.states_list.append(self._states_class(model=self, data=data, inputs=inputs, **kwargs)) 28 | return self 29 | 30 | def log_likelihood(self, data=None, inputs=None, **kwargs): 31 | if data is not None: 32 | assert isinstance(data,(list,np.ndarray)) 33 | if isinstance(data,np.ndarray): 34 | self.add_data(data=data, inputs=inputs, **kwargs) 35 | return self.states_list.pop().log_likelihood() 36 | else: 37 | return sum(self.log_likelihood(d, i) for (d, i) in zip(data, inputs)) 38 | else: 39 | return sum(s.log_likelihood() for s in self.states_list) 40 | 41 | def generate(self, T, keep=True, inputs=None, **kwargs): 42 | s = self._states_class(model=self, T=T, inputs=inputs, 43 | initialize_from_prior=True, **kwargs) 44 | data = self._generate_obs(s, inputs) 45 | if keep: 46 | self.states_list.append(s) 47 | return data, s.gaussian_states 48 | 49 | def _generate_obs(self,s, inputs): 50 | if s.data is None: 51 | inputs = np.zeros((s.T, 0)) if inputs is None else inputs 52 | s.data = self.emission_distn.rvs( 53 | x=np.hstack((s.gaussian_states, inputs)), return_xy=False) 54 | else: 55 | # filling in missing data 56 | raise NotImplementedError 57 | return s.data 58 | 59 | def smooth(self, data, inputs=None, **kwargs): 60 | self.add_data(data, inputs=inputs, **kwargs) 61 | s = self.states_list.pop() 62 | return s.smooth() 63 | 64 | def predict(self, data, Tpred): 65 | # return means and covariances 66 | raise NotImplementedError 67 | 68 | def sample_predictions(self, data, Tpred, inputs_pred=None, inputs=None, states_noise=True, obs_noise=True, **kwargs): 69 | self.add_data(data, inputs=inputs, **kwargs) 70 | s = self.states_list.pop() 71 | return s.sample_predictions(Tpred, inputs=inputs_pred, states_noise=states_noise, obs_noise=obs_noise) 72 | 73 | # convenience properties 74 | 75 | @property 76 | def D_latent(self): 77 | 'latent dimension' 78 | return self.dynamics_distn.D_out 79 | 80 | @property 81 | def D_obs(self): 82 | 'emission dimension' 83 | return self.emission_distn.D_out 84 | 85 | @property 86 | def D_input(self): 87 | 'input dimension' 88 | return self.dynamics_distn.D_in - self.dynamics_distn.D_out 89 | 90 | @property 91 | def mu_init(self): 92 | return np.zeros(self.D_latent) if not hasattr(self, '_mu_init') \ 93 | else self._mu_init 94 | 95 | @mu_init.setter 96 | def mu_init(self,mu_init): 97 | self._mu_init = mu_init 98 | 99 | 
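    # The default sigma_init below is the stationary state covariance, i.e. the
    # solution Sigma of the discrete Lyapunov equation
    #     Sigma = A Sigma A^T + sigma_states.
    # When scipy's solve_discrete_lyapunov is unavailable, the fallback solves the
    # vectorized form  vec(Sigma) = (I - kron(A, A))^{-1} vec(sigma_states).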
@property 100 | def sigma_init(self): 101 | if hasattr(self,'_sigma_init'): 102 | return self._sigma_init 103 | 104 | try: 105 | from scipy.linalg import solve_discrete_lyapunov as dtlyap 106 | return dtlyap(self.A, self.sigma_states) 107 | except ImportError: 108 | return np.linalg.solve( 109 | np.eye(self.D_latent ** 2) - np.kron(self.A, self.A), self.sigma_states.ravel())\ 110 | .reshape(self.D_latent, self.D_latent) 111 | 112 | @sigma_init.setter 113 | def sigma_init(self,sigma_init): 114 | self._sigma_init = sigma_init 115 | 116 | @property 117 | def A(self): 118 | return self.dynamics_distn.A[:, :self.D_latent].copy("C") 119 | 120 | @A.setter 121 | def A(self,A): 122 | self.dynamics_distn.A[:, :self.D_latent] = A 123 | 124 | @property 125 | def B(self): 126 | return self.dynamics_distn.A[:, self.D_latent:].copy("C") 127 | 128 | @B.setter 129 | def B(self, B): 130 | self.dynamics_distn.A[:, self.D_latent:] = B 131 | 132 | @property 133 | def sigma_states(self): 134 | return self.dynamics_distn.sigma 135 | 136 | @sigma_states.setter 137 | def sigma_states(self,sigma_states): 138 | self.dynamics_distn.sigma = sigma_states 139 | 140 | @property 141 | def C(self): 142 | return self.emission_distn.A[:, :self.D_latent].copy("C") 143 | 144 | @C.setter 145 | def C(self,C): 146 | self.emission_distn.A[:, :self.D_latent] = C 147 | 148 | @property 149 | def D(self): 150 | return self.emission_distn.A[:, self.D_latent:].copy("C") 151 | 152 | @D.setter 153 | def D(self, D): 154 | self.emission_distn.A[:, self.D_latent:] = D 155 | 156 | @property 157 | def sigma_obs(self): 158 | return self.emission_distn.sigma 159 | 160 | @sigma_obs.setter 161 | def sigma_obs(self,sigma_obs): 162 | self.emission_distn.sigma = sigma_obs 163 | 164 | @property 165 | def diagonal_noise(self): 166 | return isinstance(self.emission_distn, DiagonalRegression) 167 | 168 | @property 169 | def sigma_obs_flat(self): 170 | return self.emission_distn.sigmasq_flat 171 | 172 | @sigma_obs_flat.setter 173 | def sigma_obs_flat(self, value): 174 | self.emission_distn.sigmasq_flat = value 175 | 176 | @property 177 | def is_stable(self): 178 | return np.max(np.abs(np.linalg.eigvals(self.dynamics_distn.A))) < 1. 
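# The convenience properties above expose blocks of the underlying regression
# parameters: dynamics_distn.A stores [A | B] and emission_distn.A stores [C | D],
# with the first D_latent columns holding the state matrices.
#
# Illustrative Gibbs usage for the mixin defined next (a sketch only, assuming a
# model built with DefaultLDS; cf. examples/gibbs.py in this repo):
#
#     model = DefaultLDS(D_obs, D_latent, D_input)
#     model.add_data(data, inputs=inputs)
#     for _ in range(100):
#         model.resample_model()
#     ll = model.log_likelihood()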
179 | 180 | 181 | class _LDSGibbsSampling(_LDSBase, ModelGibbsSampling): 182 | def copy_sample(self): 183 | model = copy.deepcopy(self) 184 | for states in model.states_list: 185 | states.data = None 186 | return model 187 | 188 | def resample_model(self): 189 | self.resample_parameters() 190 | self.resample_states() 191 | 192 | def resample_states(self): 193 | for s in self.states_list: 194 | s.resample() 195 | 196 | def resample_parameters(self): 197 | self.resample_dynamics_distn() 198 | self.resample_emission_distn() 199 | 200 | def resample_dynamics_distn(self): 201 | self.dynamics_distn.resample( 202 | [np.hstack((s.gaussian_states[:-1],s.inputs[:-1],s.gaussian_states[1:])) 203 | for s in self.states_list]) 204 | 205 | def resample_emission_distn(self): 206 | xys = [(np.hstack((s.gaussian_states, s.inputs)), s.data) for s in self.states_list] 207 | self.emission_distn.resample(data=xys) 208 | 209 | class _LDSMeanField(_LDSBase, ModelMeanField): 210 | def meanfield_coordinate_descent_step(self): 211 | for s in self.states_list: 212 | if not hasattr(s, 'E_emission_stats'): 213 | s.meanfieldupdate() 214 | 215 | self.meanfield_update_parameters() 216 | self.meanfield_update_states() 217 | 218 | return self.vlb() 219 | 220 | def meanfield_update_states(self): 221 | for s in self.states_list: 222 | s.meanfieldupdate() 223 | 224 | def meanfield_update_parameters(self): 225 | self.meanfield_update_dynamics_distn() 226 | self.meanfield_update_emission_distn() 227 | 228 | def meanfield_update_dynamics_distn(self): 229 | self.dynamics_distn.meanfieldupdate( 230 | stats=(sum(s.E_dynamics_stats for s in self.states_list))) 231 | 232 | def meanfield_update_emission_distn(self): 233 | self.emission_distn.meanfieldupdate( 234 | stats=(sum(s.E_emission_stats for s in self.states_list))) 235 | 236 | def resample_from_mf(self): 237 | self.dynamics_distn.resample_from_mf() 238 | self.emission_distn.resample_from_mf() 239 | 240 | def vlb(self): 241 | vlb = 0. 
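        # total VLB = the per-sequence state terms plus the parameter terms
        # contributed by the emission and dynamics factors (each factor's get_vlb)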
242 | vlb += sum(s.get_vlb() for s in self.states_list) 243 | vlb += self.emission_distn.get_vlb() 244 | vlb += self.dynamics_distn.get_vlb() 245 | return vlb 246 | 247 | 248 | class _LDSMeanFieldSVI(_LDSBase, ModelMeanFieldSVI): 249 | def meanfield_sgdstep(self, minibatch, prob, stepsize, masks=None, **kwargs): 250 | states_list = self._get_mb_states_list(minibatch, masks, **kwargs) 251 | for s in states_list: 252 | s.meanfieldupdate() 253 | self._meanfield_sgdstep_parameters(states_list, prob, stepsize) 254 | 255 | def _meanfield_sgdstep_parameters(self, states_list, prob, stepsize): 256 | self._meanfield_sgdstep_dynamics_distn(states_list, prob, stepsize) 257 | self._meanfield_sgdstep_emission_distn(states_list, prob, stepsize) 258 | 259 | def _meanfield_sgdstep_dynamics_distn(self, states_list, prob, stepsize): 260 | self.dynamics_distn.meanfield_sgdstep( 261 | data=None, weights=None, 262 | stats=(sum(s.E_dynamics_stats for s in states_list)), 263 | prob=prob, stepsize=stepsize) 264 | 265 | def _meanfield_sgdstep_emission_distn(self, states_list, prob, stepsize): 266 | self.emission_distn.meanfield_sgdstep( 267 | data=None, weights=None, 268 | stats=(sum(s.E_emission_stats for s in states_list)), 269 | prob=prob, stepsize=stepsize) 270 | 271 | def _get_mb_states_list(self, minibatch, masks, **kwargs): 272 | minibatch = minibatch if isinstance(minibatch,list) else [minibatch] 273 | masks = [None] * len(minibatch) if masks is None else \ 274 | (masks if isinstance(masks, list) else [masks]) 275 | 276 | def get_states(data, mask): 277 | self.add_data(data, mask=mask, **kwargs) 278 | return self.states_list.pop() 279 | 280 | return [get_states(data, mask) for data, mask in zip(minibatch, masks)] 281 | 282 | 283 | class _NonstationaryLDSGibbsSampling(_LDSGibbsSampling): 284 | def resample_model(self): 285 | self.resample_init_dynamics_distn() 286 | super(_NonstationaryLDSGibbsSampling, self).resample_model() 287 | 288 | def resample_init_dynamics_distn(self): 289 | self.init_dynamics_distn.resample( 290 | [s.gaussian_states[0] for s in self.states_list]) 291 | 292 | 293 | class _LDSEM(_LDSBase, ModelEM): 294 | def EM_step(self): 295 | self.E_step() 296 | self.M_step() 297 | 298 | def E_step(self): 299 | for s in self.states_list: 300 | s.E_step() 301 | 302 | def M_step(self): 303 | self.M_step_dynamics_distn() 304 | self.M_step_emission_distn() 305 | 306 | def M_step_dynamics_distn(self): 307 | self.dynamics_distn.max_likelihood( 308 | data=None, 309 | stats=(sum(s.E_dynamics_stats for s in self.states_list))) 310 | 311 | def M_step_emission_distn(self): 312 | self.emission_distn.max_likelihood( 313 | data=None, 314 | stats=(sum(s.E_emission_stats for s in self.states_list))) 315 | 316 | 317 | class _NonstationaryLDSEM(_LDSEM): 318 | def M_Step(self): 319 | self.M_step_init_dynamics_distn() 320 | super(_NonstationaryLDSEM, self).M_step() 321 | 322 | def M_step_init_dynamics_distn(self): 323 | self.init_dynamics_distn.max_likelihood( 324 | stats=(sum(s.E_x1_x1 for s in self.states_list))) 325 | 326 | 327 | ################### 328 | # model classes # 329 | ################### 330 | 331 | class LDS(_LDSGibbsSampling, _LDSMeanField, _LDSEM, _LDSMeanFieldSVI, _LDSBase): 332 | pass 333 | 334 | 335 | class NonstationaryLDS( 336 | _NonstationaryLDSGibbsSampling, 337 | _NonstationaryLDSEM, 338 | _LDSBase): 339 | def __init__(self, init_dynamics_distn, *args, **kwargs): 340 | self.init_dynamics_distn = init_dynamics_distn 341 | super(NonstationaryLDS, self).__init__(*args, **kwargs) 342 | 343 | def 
resample_init_dynamics_distn(self): 344 | self.init_dynamics_distn.resample( 345 | [s.gaussian_states[0] for s in self.states_list]) 346 | 347 | # convenience properties 348 | 349 | @property 350 | def mu_init(self): 351 | return self.init_dynamics_distn.mu 352 | 353 | @mu_init.setter 354 | def mu_init(self, mu_init): 355 | self.init_dynamics_distn.mu = mu_init 356 | 357 | @property 358 | def sigma_init(self): 359 | return self.init_dynamics_distn.sigma 360 | 361 | @sigma_init.setter 362 | def sigma_init(self, sigma_init): 363 | self.init_dynamics_distn.sigma = sigma_init 364 | 365 | 366 | class MissingDataLDS(_LDSGibbsSampling, _LDSBase): 367 | _states_class = LDSStatesMissingData 368 | 369 | def copy_sample(self): 370 | model = copy.deepcopy(self) 371 | for states in model.states_list: 372 | states.data = None 373 | states.mask = None 374 | return model 375 | 376 | def resample_emission_distn(self): 377 | xys = [(np.hstack((s.gaussian_states, s.inputs)), s.data) for s in self.states_list] 378 | mask = [s.mask for s in self.states_list] 379 | self.emission_distn.resample(data=xys, mask=mask) 380 | 381 | 382 | class CountLDS(_LDSGibbsSampling, _LDSBase): 383 | _states_class = LDSStatesCountData 384 | 385 | def copy_sample(self): 386 | model = copy.deepcopy(self) 387 | for states in model.states_list: 388 | states.data = None 389 | states.mask = None 390 | states.omega = None 391 | return model 392 | 393 | def resample_emission_distn(self): 394 | xys = [(np.hstack((s.gaussian_states, s.inputs)), s.data) for s in self.states_list] 395 | mask = [s.mask for s in self.states_list] 396 | omega = [s.omega for s in self.states_list] 397 | self.emission_distn.resample(data=xys, mask=mask, omega=omega) 398 | 399 | 400 | class ZeroInflatedCountLDS(_LDSGibbsSampling, _LDSBase): 401 | _states_class = LDSStatesZeroInflatedCountData 402 | 403 | def __init__(self, rho, *args, **kwargs): 404 | """ 405 | :param rho: Probability of count drawn from model 406 | With pr 1-rho, the emission is deterministically zero 407 | """ 408 | super(ZeroInflatedCountLDS, self).__init__(*args, **kwargs) 409 | self.rho = rho 410 | 411 | def add_data(self,data, inputs=None, mask=None, **kwargs): 412 | self.states_list.append(self._states_class(model=self, data=data, inputs=inputs, mask=mask, **kwargs)) 413 | return self 414 | 415 | def _generate_obs(self,s, inputs): 416 | if s.data is None: 417 | # TODO: Do this sparsely 418 | inputs = np.zeros((s.T, 0)) if inputs is None else inputs 419 | data = self.emission_distn.rvs( 420 | x=np.hstack((s.gaussian_states, inputs)), return_xy=False) 421 | 422 | # Zero out data 423 | zeros = np.random.rand(s.T, self.D_obs) > self.rho 424 | data[zeros] = 0 425 | 426 | from scipy.sparse import csr_matrix 427 | s.data = csr_matrix(data) 428 | 429 | else: 430 | # filling in missing data 431 | raise NotImplementedError 432 | return s.data 433 | 434 | def resample_emission_distn(self): 435 | """ 436 | Now for the expensive part... the data is stored in a sparse row 437 | format, which is good for updating the latent states (since we 438 | primarily rely on dot products with the data, which can be 439 | efficiently performed for CSR matrices). 440 | 441 | However, in order to update the n-th row of the emission matrix, 442 | we need to know which counts are observed in the n-th column of data. 443 | This involves converting the data to a sparse column format, which 444 | can require (time) intensive re-indexing. 
445 | """ 446 | masked_datas = [s.masked_data.tocsc() for s in self.states_list] 447 | xs = [np.hstack((s.gaussian_states, s.inputs))for s in self.states_list] 448 | 449 | for n in range(self.D_obs): 450 | # Get the nonzero values of the nth column 451 | rowns = [md.indices[md.indptr[n]:md.indptr[n+1]] for md in masked_datas] 452 | xns = [x[r] for x,r in zip(xs, rowns)] 453 | yns = [s.masked_data.getcol(n).data for s in self.states_list] 454 | maskns = [np.ones_like(y, dtype=bool) for y in yns] 455 | omegans = [s.omega.getcol(n).data for s in self.states_list] 456 | self.emission_distn._resample_row_of_emission_matrix(n, xns, yns, maskns, omegans) 457 | 458 | 459 | ### Models that support Laplace approximation 460 | class _LaplaceApproxLDSBase(NonstationaryLDS, _NonstationaryLDSEM): 461 | def log_conditional_likelihood(self): 462 | return sum(s.log_conditional_likelihood(s.gaussian_states) 463 | for s in self.states_list) 464 | 465 | def EM_step(self, verbose=False): 466 | self.E_step(verbose=verbose) 467 | self.M_step(verbose=verbose) 468 | 469 | def E_step(self, verbose=False): 470 | for s in self.states_list: 471 | s.E_step(verbose=verbose) 472 | 473 | def M_step(self, verbose=False): 474 | self.M_step_dynamics_distn() 475 | self.M_step_emission_distn(verbose=verbose) 476 | 477 | def M_step_emission_distn(self, verbose=False): 478 | # self.emission_distn.max_likelihood( 479 | # data=[(np.hstack((s.gaussian_states, s.inputs)), s.data) 480 | # for s in self.states_list]) 481 | 482 | self.emission_distn.max_expected_likelihood( 483 | stats=[s.E_emission_stats for s in self.states_list], 484 | verbose=verbose) 485 | 486 | def expected_log_likelihood(self): 487 | return sum([s.expected_log_likelihood() for s in self.states_list]) 488 | 489 | 490 | class LaplaceApproxPoissonLDS(_LaplaceApproxLDSBase): 491 | _states_class = LaplaceApproxPoissonLDSStates 492 | 493 | 494 | class LaplaceApproxBernoulliLDS(_LaplaceApproxLDSBase): 495 | _states_class = LaplaceApproxBernoulliLDSStates 496 | 497 | 498 | ############################## 499 | # convenience constructors # 500 | ############################## 501 | 502 | # TODO make data-dependent default constructors 503 | def DefaultLDS(D_obs, D_latent, D_input=0, 504 | mu_init=None, sigma_init=None, 505 | A=None, B=None, sigma_states=None, 506 | C=None, D=None, sigma_obs=None): 507 | model = LDS( 508 | dynamics_distn=Regression( 509 | nu_0=D_latent + 1, 510 | S_0=D_latent * np.eye(D_latent), 511 | M_0=np.zeros((D_latent, D_latent + D_input)), 512 | K_0=D_latent * np.eye(D_latent + D_input)), 513 | emission_distn=Regression( 514 | nu_0=D_obs + 1, 515 | S_0=D_obs * np.eye(D_obs), 516 | M_0=np.zeros((D_obs, D_latent + D_input)), 517 | K_0=D_obs * np.eye(D_latent + D_input))) 518 | 519 | set_default = \ 520 | lambda prm, val, default: \ 521 | model.__setattr__(prm, val if val is not None else default) 522 | 523 | set_default("mu_init", mu_init, np.zeros(D_latent)) 524 | set_default("sigma_init", sigma_init, np.eye(D_latent)) 525 | 526 | set_default("A", A, 0.99 * random_rotation(D_latent)) 527 | set_default("B", B, 0.1 * np.random.randn(D_latent, D_input)) 528 | set_default("sigma_states", sigma_states, 0.1 * np.eye(D_latent)) 529 | 530 | set_default("C", C, np.random.randn(D_obs, D_latent)) 531 | set_default("D", D, 0.1 * np.random.randn(D_obs, D_input)) 532 | set_default("sigma_obs", sigma_obs, 0.1 * np.eye(D_obs)) 533 | 534 | return model 535 | 536 | 537 | def DefaultPoissonLDS(D_obs, D_latent, D_input=0, 538 | mu_init=None, sigma_init=None, 539 | 
A=None, B=None, sigma_states=None, 540 | C=None, D=None 541 | ): 542 | model = LaplaceApproxPoissonLDS( 543 | init_dynamics_distn= 544 | Gaussian(mu_0=np.zeros(D_latent), sigma_0=np.eye(D_latent), 545 | kappa_0=1.0, nu_0=D_latent + 1), 546 | dynamics_distn=Regression( 547 | nu_0=D_latent + 1, 548 | S_0=D_latent * np.eye(D_latent), 549 | M_0=np.zeros((D_latent, D_latent + D_input)), 550 | K_0=D_latent * np.eye(D_latent + D_input)), 551 | emission_distn= 552 | PoissonRegression(D_obs, D_latent + D_input, verbose=False)) 553 | 554 | set_default = \ 555 | lambda prm, val, default: \ 556 | model.__setattr__(prm, val if val is not None else default) 557 | 558 | set_default("mu_init", mu_init, np.zeros(D_latent)) 559 | set_default("sigma_init", sigma_init, np.eye(D_latent)) 560 | 561 | set_default("A", A, 0.99 * random_rotation(D_latent)) 562 | set_default("B", B, 0.1 * np.random.randn(D_latent, D_input)) 563 | set_default("sigma_states", sigma_states, 0.1 * np.eye(D_latent)) 564 | 565 | set_default("C", C, np.random.randn(D_obs, D_latent)) 566 | set_default("D", D, 0.1 * np.random.randn(D_obs, D_input)) 567 | 568 | return model 569 | 570 | 571 | def DefaultBernoulliLDS(D_obs, D_latent, D_input=0, 572 | mu_init=None, sigma_init=None, 573 | A=None, B=None, sigma_states=None, 574 | C=None, D=None 575 | ): 576 | model = LaplaceApproxBernoulliLDS( 577 | init_dynamics_distn= 578 | Gaussian(mu_0=np.zeros(D_latent), sigma_0=np.eye(D_latent), 579 | kappa_0=1.0, nu_0=D_latent + 1), 580 | dynamics_distn=Regression( 581 | nu_0=D_latent + 1, 582 | S_0=D_latent * np.eye(D_latent), 583 | M_0=np.zeros((D_latent, D_latent + D_input)), 584 | K_0=D_latent * np.eye(D_latent + D_input)), 585 | emission_distn= 586 | BernoulliRegression(D_obs, D_latent + D_input, verbose=False)) 587 | 588 | set_default = \ 589 | lambda prm, val, default: \ 590 | model.__setattr__(prm, val if val is not None else default) 591 | 592 | set_default("mu_init", mu_init, np.zeros(D_latent)) 593 | set_default("sigma_init", sigma_init, np.eye(D_latent)) 594 | 595 | set_default("A", A, 0.99 * random_rotation(D_latent)) 596 | set_default("B", B, 0.1 * np.random.randn(D_latent, D_input)) 597 | set_default("sigma_states", sigma_states, 0.1 * np.eye(D_latent)) 598 | 599 | set_default("C", C, np.random.randn(D_obs, D_latent)) 600 | set_default("D", D, 0.1 * np.random.randn(D_obs, D_input)) 601 | 602 | return model 603 | -------------------------------------------------------------------------------- /pylds/util.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | 3 | from pylds.lds_messages_interface import info_E_step, kalman_info_filter, info_sample 4 | 5 | def random_rotation(n, theta=None): 6 | if theta is None: 7 | # Sample a random, slow rotation 8 | theta = 0.5 * np.pi * np.random.rand() 9 | 10 | if n == 1: 11 | return np.random.rand() * np.eye(1) 12 | 13 | rot = np.array([[np.cos(theta), -np.sin(theta)], 14 | [np.sin(theta), np.cos(theta)]]) 15 | out = np.zeros((n, n)) 16 | out[:2, :2] = rot 17 | q = np.linalg.qr(np.random.randn(n, n))[0] 18 | return q.dot(out).dot(q.T) 19 | 20 | 21 | def symm_block_tridiag_matmul(H_diag, H_upper_diag, v): 22 | """ 23 | Compute matrix-vector product with a symmetric block 24 | tridiagonal matrix H and vector v. 
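    For example, with T = 3 blocks of size D, H has the block layout

        [ D_0    U_0    0   ]
        [ U_0^T  D_1    U_1 ]
        [ 0      U_1^T  D_2 ]

    where D_t = H_diag[t] and U_t = H_upper_diag[t].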
25 | 26 | :param H_diag: block diagonal terms of H 27 | :param H_upper_diag: upper block diagonal terms of H 28 | :param v: vector to multiple 29 | :return: H * v 30 | """ 31 | T, D, _ = H_diag.shape 32 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 33 | assert H_upper_diag.shape == (T-1, D, D) 34 | assert v.shape == (T, D) 35 | 36 | out = np.matmul(H_diag, v[:, :, None])[:, :, 0] 37 | out[:-1] += np.matmul(H_upper_diag, v[1:][:, :, None])[:, :, 0] 38 | out[1:] += np.matmul(np.swapaxes(H_upper_diag, -2, -1), v[:-1][:, :, None])[:, :, 0] 39 | return out 40 | 41 | 42 | def solve_symm_block_tridiag(H_diag, H_upper_diag, v): 43 | """ 44 | use the info smoother to solve a symmetric block tridiagonal system 45 | """ 46 | T, D, _ = H_diag.shape 47 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 48 | assert H_upper_diag.shape == (T - 1, D, D) 49 | assert v.shape == (T, D) 50 | 51 | J_init = J_11 = J_22 = np.zeros((D, D)) 52 | h_init = h_1 = h_2 = np.zeros((D,)) 53 | 54 | J_21 = np.swapaxes(H_upper_diag, -1, -2) 55 | J_node = H_diag 56 | h_node = v 57 | 58 | _, y, _, _ = info_E_step(J_init, h_init, 0, 59 | J_11, J_21, J_22, h_1, h_2, np.zeros((T-1)), 60 | J_node, h_node, np.zeros(T)) 61 | return y 62 | 63 | 64 | def convert_block_tridiag_to_banded(H_diag, H_upper_diag, lower=True): 65 | """ 66 | convert blocks to banded matrix representation required for scipy. 67 | we are using the "lower form." 68 | see https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solveh_banded.html 69 | """ 70 | T, D, _ = H_diag.shape 71 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 72 | assert H_upper_diag.shape == (T - 1, D, D) 73 | H_lower_diag = np.swapaxes(H_upper_diag, -2, -1) 74 | 75 | ab = np.zeros((2 * D, T * D)) 76 | 77 | # Fill in blocks along the diagonal 78 | for d in range(D): 79 | # Get indices of (-d)-th diagonal of H_diag 80 | i = np.arange(d, D) 81 | j = np.arange(0, D - d) 82 | h = np.column_stack((H_diag[:, i, j], np.zeros((T, d)))) 83 | ab[d] = h.ravel() 84 | 85 | # Fill in lower left corner of blocks below the diagonal 86 | for d in range(0, D): 87 | # Get indices of (-d)-th diagonal of H_diag 88 | i = np.arange(d, D) 89 | j = np.arange(0, D - d) 90 | h = np.column_stack((H_lower_diag[:, i, j], np.zeros((T - 1, d)))) 91 | ab[D + d, :D * (T - 1)] = h.ravel() 92 | 93 | # Fill in upper corner of blocks below the diagonal 94 | for d in range(1, D): 95 | # Get indices of (+d)-th diagonal of H_lower_diag 96 | i = np.arange(0, D - d) 97 | j = np.arange(d, D) 98 | h = np.column_stack((np.zeros((T - 1, d)), H_lower_diag[:, i, j])) 99 | ab[D - d, :D * (T - 1)] += h.ravel() 100 | 101 | return ab if lower else transpose_lower_banded_matrix(ab) 102 | 103 | 104 | def transpose_lower_banded_matrix(Lab): 105 | # This is painful 106 | Uab = np.flipud(Lab) 107 | u = Uab.shape[0] - 1 108 | for i in range(1,u+1): 109 | Uab[-(i+1), i:] = Uab[-(i+1), :-i] 110 | Uab[-(i + 1), :i] = 0 111 | return Uab 112 | 113 | 114 | def scipy_solve_symm_block_tridiag(H_diag, H_upper_diag, v, ab=None): 115 | """ 116 | use scipy.linalg.solve_banded to solve a symmetric block tridiagonal system 117 | 118 | see https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solveh_banded.html 119 | """ 120 | from scipy.linalg import solveh_banded 121 | ab = convert_block_tridiag_to_banded(H_diag, H_upper_diag) \ 122 | if ab is None else ab 123 | x = solveh_banded(ab, v.ravel(), lower=True) 124 | return x.reshape(v.shape) 125 | 126 | 127 | def scipy_sample_block_tridiag(H_diag, H_upper_diag, size=1, ab=None, 
z=None): 128 | from scipy.linalg import cholesky_banded, solve_banded 129 | 130 | ab = convert_block_tridiag_to_banded(H_diag, H_upper_diag, lower=False) \ 131 | if ab is None else ab 132 | 133 | Uab = cholesky_banded(ab, lower=False) 134 | z = np.random.randn(ab.shape[1], size) if z is None else z 135 | 136 | # If lower = False, we have (U^T U)^{-1} = U^{-1} U^{-T} = AA^T = Sigma 137 | # where A = U^{-1}. Samples are Az = U^{-1}z = x, or equivalently Ux = z. 138 | return solve_banded((0, Uab.shape[0]-1), Uab, z) 139 | 140 | 141 | def sample_block_tridiag(H_diag, H_upper_diag): 142 | """ 143 | helper function for sampling block tridiag gaussians. 144 | this is only for speed comparison with the solve approach. 145 | """ 146 | T, D, _ = H_diag.shape 147 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 148 | assert H_upper_diag.shape == (T - 1, D, D) 149 | 150 | J_init = J_11 = J_22 = np.zeros((D, D)) 151 | h_init = h_1 = h_2 = np.zeros((D,)) 152 | 153 | J_21 = np.swapaxes(H_upper_diag, -1, -2) 154 | J_node = H_diag 155 | h_node = np.zeros((T,D)) 156 | 157 | y = info_sample(J_init, h_init, 0, 158 | J_11, J_21, J_22, h_1, h_2, np.zeros((T-1)), 159 | J_node, h_node, np.zeros(T)) 160 | return y 161 | 162 | 163 | def logdet_symm_block_tridiag(H_diag, H_upper_diag): 164 | """ 165 | compute the log determinant of a positive definite, 166 | symmetric block tridiag matrix. Use the Kalman 167 | info filter to do so. Specifically, the KF computes 168 | the normalizer: 169 | 170 | log Z = 1/2 h^T J^{-1} h -1/2 log |J| +n/2 log 2 \pi 171 | 172 | We set h=0 to get -1/2 log |J| + n/2 log 2 \pi and from 173 | this we solve for log |J|. 174 | """ 175 | T, D, _ = H_diag.shape 176 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 177 | assert H_upper_diag.shape == (T - 1, D, D) 178 | 179 | J_init = J_11 = J_22 = np.zeros((D, D)) 180 | h_init = h_1 = h_2 = np.zeros((D,)) 181 | log_Z_init = 0 182 | 183 | J_21 = np.swapaxes(H_upper_diag, -1, -2) 184 | log_Z_pair = 0 185 | 186 | J_node = H_diag 187 | h_node = np.zeros((T, D)) 188 | log_Z_node = 0 189 | 190 | logZ, _, _ = kalman_info_filter(J_init, h_init, log_Z_init, 191 | J_11, J_21, J_22, h_1, h_2, log_Z_pair, 192 | J_node, h_node, log_Z_node) 193 | 194 | # logZ = -1/2 log |J| + n/2 log 2 \pi 195 | logdetJ = -2 * (logZ - (T*D) / 2 * np.log(2 * np.pi)) 196 | return logdetJ 197 | 198 | 199 | def compute_symm_block_tridiag_covariances(H_diag, H_upper_diag): 200 | """ 201 | use the info smoother to solve a symmetric block tridiagonal system 202 | """ 203 | T, D, _ = H_diag.shape 204 | assert H_diag.ndim == 3 and H_diag.shape[2] == D 205 | assert H_upper_diag.shape == (T - 1, D, D) 206 | 207 | J_init = J_11 = J_22 = np.zeros((D, D)) 208 | h_init = h_1 = h_2 = np.zeros((D,)) 209 | 210 | J_21 = np.swapaxes(H_upper_diag, -1, -2) 211 | J_node = H_diag 212 | h_node = np.zeros((T, D)) 213 | 214 | _, _, sigmas, E_xt_xtp1 = \ 215 | info_E_step(J_init, h_init, 0, 216 | J_11, J_21, J_22, h_1, h_2, np.zeros((T-1)), 217 | J_node, h_node, np.zeros(T)) 218 | return sigmas, E_xt_xtp1 219 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from setuptools.command.build_ext import build_ext as _build_ext 3 | from setuptools.command.sdist import sdist as _sdist 4 | from distutils.errors import CompileError 5 | from warnings import warn 6 | import os 7 | import sys 8 | from glob import glob 9 | 10 | # generate 
list of .c extension modules based on the .pyx files, which are 11 | # included in the pypi package (but aren't included in the git repo) 12 | extension_pathspec = os.path.join('pylds','*.pyx') # only searches pylds/*.pyx! 13 | paths = [os.path.splitext(fp)[0] for fp in glob(extension_pathspec)] 14 | names = ['.'.join(os.path.split(p)) for p in paths] 15 | ext_modules = [ 16 | Extension(name, sources=[path + '.c'], include_dirs=[os.path.join('deps')]) 17 | for name, path in zip(names,paths)] 18 | 19 | # alternatively, use cython to generate extension modules if we can import it 20 | # (which is required if we're installing from the git repo) 21 | try: 22 | from Cython.Distutils import build_ext as _build_ext 23 | except ImportError: 24 | use_cython = False 25 | else: 26 | use_cython = True 27 | 28 | if use_cython: 29 | from Cython.Build import cythonize 30 | try: 31 | ext_modules = cythonize('**/*.pyx') # recursive globbing! 32 | except: 33 | warn('Failed to generate extension module code from Cython files') 34 | sys.exit(1) 35 | 36 | # if we run the dist command, regenerate the sources from cython 37 | class sdist(_sdist): 38 | def run(self): 39 | from Cython.Build import cythonize 40 | cythonize(os.path.join('pylds','*.pyx')) 41 | _sdist.run(self) 42 | 43 | # the final extension module build step should have numpy headers available 44 | class build_ext(_build_ext): 45 | # see http://stackoverflow.com/q/19919905 for explanation 46 | def finalize_options(self): 47 | _build_ext.finalize_options(self) 48 | __builtins__.__NUMPY_SETUP__ = False 49 | import numpy as np 50 | self.include_dirs.append(np.get_include()) 51 | 52 | setup( 53 | name='pylds', 54 | version='0.0.5', 55 | description="Learning and inference for linear dynamical systems" 56 | "with fast Cython and BLAS/LAPACK implementations", 57 | author='Matthew James Johnson and Scott W Linderman', 58 | author_email='mattjj@csail.mit.edu', 59 | license="MIT", 60 | url='https://github.com/mattjj/pylds', 61 | packages=['pylds'], 62 | install_requires=[ 63 | 'numpy>=1.9.3', 'scipy>=0.16', 'matplotlib', 64 | 'pybasicbayes', 'autograd'], 65 | setup_requires=['future'], 66 | ext_modules=ext_modules, 67 | classifiers=[ 68 | 'Intended Audience :: Science/Research', 69 | 'Programming Language :: Python', 70 | ], 71 | keywords=[ 72 | 'lds', 'linear dynamical system', 'kalman filter', 'kalman', 73 | 'kalman smoother', 'rts smoother'], 74 | platforms="ALL", 75 | cmdclass={'build_ext': build_ext, 'sdist': sdist}) 76 | -------------------------------------------------------------------------------- /tests/test_dense.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from scipy.stats import multivariate_normal 4 | 5 | from pylds.models import DefaultLDS 6 | 7 | 8 | ########## 9 | # util # 10 | ########## 11 | 12 | def cumsum(v,strict=False): 13 | if not strict: 14 | return np.cumsum(v,axis=0) 15 | else: 16 | out = np.zeros_like(v) 17 | out[1:] = np.cumsum(v[:-1],axis=0) 18 | return out 19 | 20 | 21 | def bmat(blocks): 22 | rowsizes = [row[0].shape[0] for row in blocks] 23 | colsizes = [col[0].shape[1] for col in zip(*blocks)] 24 | rowstarts = cumsum(rowsizes,strict=True) 25 | colstarts = cumsum(colsizes,strict=True) 26 | 27 | nrows, ncols = sum(rowsizes), sum(colsizes) 28 | out = np.zeros((nrows,ncols)) 29 | 30 | for i, (rstart, rsz) in enumerate(zip(rowstarts, rowsizes)): 31 | for j, (cstart, csz) in enumerate(zip(colstarts, colsizes)): 32 | 
out[rstart:rstart+rsz,cstart:cstart+csz] = blocks[i][j] 33 | 34 | return out 35 | 36 | 37 | def random_rotation(n,theta): 38 | if n == 1: 39 | return np.random.rand() * np.eye(1) 40 | 41 | rot = np.array([[np.cos(theta), -np.sin(theta)], 42 | [np.sin(theta), np.cos(theta)]]) 43 | out = np.zeros((n,n)) 44 | out[:2,:2] = rot 45 | q = np.linalg.qr(np.random.randn(n,n))[0] 46 | return q.dot(out).dot(q.T) 47 | 48 | 49 | def lds_to_dense_infoparams(model,data,inputs): 50 | T, n = data.shape[0], model.D_latent 51 | 52 | mu_init, sigma_init = model.mu_init, model.sigma_init 53 | A, B, sigma_states = model.A, model.B, model.sigma_states 54 | C, D, sigma_obs = model.C, model.D, model.sigma_obs 55 | ss_inv = np.linalg.inv(sigma_states) 56 | 57 | h = np.zeros((T,n)) 58 | h[0] += np.linalg.solve(sigma_init, mu_init) 59 | # Dynamics 60 | h[1:] += inputs[:-1].dot(B.T).dot(ss_inv) 61 | h[:-1] += -inputs[:-1].dot(B.T).dot(np.linalg.solve(sigma_states, A)) 62 | # Emissions 63 | h += C.T.dot(np.linalg.solve(sigma_obs, data.T)).T 64 | h += -inputs.dot(D.T).dot(np.linalg.solve(sigma_obs, C)) 65 | 66 | J = np.kron(np.eye(T),C.T.dot(np.linalg.solve(sigma_obs,C))) 67 | J[:n,:n] += np.linalg.inv(sigma_init) 68 | pairblock = bmat([[A.T.dot(ss_inv).dot(A), -A.T.dot(ss_inv)], 69 | [-ss_inv.dot(A), ss_inv]]) 70 | for t in range(0,n*(T-1),n): 71 | J[t:t+2*n,t:t+2*n] += pairblock 72 | 73 | return J.reshape(T*n,T*n), h.reshape(T*n) 74 | 75 | 76 | ########### 77 | # tests # 78 | ########### 79 | 80 | def same_means(model, Jh): 81 | J,h = Jh 82 | n, T = model.D_latent, model.states_list[0].T 83 | 84 | dense_mu = np.linalg.solve(J,h).reshape((T,n)) 85 | 86 | model.E_step() 87 | model_mu = model.states_list[0].smoothed_mus 88 | 89 | assert np.allclose(dense_mu,model_mu) 90 | 91 | 92 | def same_marginal_covs(model, Jh): 93 | J, h = Jh 94 | n, T = model.D_latent, model.states_list[0].T 95 | 96 | all_dense_sigmas = np.linalg.inv(J) 97 | dense_sigmas = np.array([all_dense_sigmas[k*n:(k+1)*n,k*n:(k+1)*n] 98 | for k in range(T)]) 99 | 100 | model.E_step() 101 | model_sigmas = model.states_list[0].smoothed_sigmas 102 | 103 | assert np.allclose(dense_sigmas,model_sigmas) 104 | 105 | 106 | def same_pairwise_secondmoments(model, Jh): 107 | J, h = Jh 108 | n, T = model.D_latent, model.states_list[0].T 109 | 110 | all_dense_sigmas = np.linalg.inv(J) 111 | dense_mu = np.linalg.solve(J,h) 112 | blockslices = [slice(k*n,(k+1)*n) for k in range(T)] 113 | dense_Extp1_xtT = \ 114 | sum(all_dense_sigmas[tp1,t] + np.outer(dense_mu[tp1],dense_mu[t]) 115 | for tp1,t in zip(blockslices[1:],blockslices[:-1])) 116 | 117 | model.E_step() 118 | model_Extp1_xtT = model.states_list[0].E_dynamics_stats[1][:n, :n] 119 | 120 | assert np.allclose(dense_Extp1_xtT,model_Extp1_xtT) 121 | 122 | 123 | def same_loglike(model,_): 124 | # NOTE: ignore the posterior (J,h) passed in so we can use the more 125 | # convenient prior info parameters 126 | states = model.states_list[0] 127 | data, inputs = states.data, states.inputs 128 | T = data.shape[0] 129 | 130 | C, model.C = model.C, np.zeros_like(model.C) 131 | D, model.D = model.D, np.zeros_like(model.D) 132 | J,h = lds_to_dense_infoparams(model,data,inputs) 133 | model.C, model.D = C, D 134 | 135 | bigC = np.kron(np.eye(T),C) 136 | bigD = np.kron(np.eye(T),D) 137 | mu_x = np.linalg.solve(J,h) 138 | sigma_x = np.linalg.inv(J) 139 | mu_y = bigC.dot(mu_x) + bigD.dot(inputs.ravel()) 140 | sigma_y = bigC.dot(sigma_x).dot(bigC.T) + np.kron(np.eye(T),model.sigma_obs) 141 | dense_loglike = 
multivariate_normal.logpdf(data.ravel(),mu_y,sigma_y) 142 | 143 | model_loglike = model.log_likelihood() 144 | if not np.isclose(dense_loglike, model_loglike): 145 | print("model - dense: ", model_loglike - dense_loglike) 146 | assert np.isclose(dense_loglike, model_loglike) 147 | 148 | 149 | def random_model(n,p,d,T): 150 | data = np.random.randn(T,p) 151 | inputs = np.random.randn(T,d) 152 | model = DefaultLDS(p,n,d) 153 | model.A = 0.99*random_rotation(n,0.01) 154 | model.B = 0.1*np.random.randn(n,d) 155 | model.C = np.random.randn(p,n) 156 | model.D = 0.1*np.random.randn(p,d) 157 | 158 | J,h = lds_to_dense_infoparams(model,data,inputs) 159 | model.add_data(data, inputs=inputs) 160 | 161 | return model, (J,h) 162 | 163 | 164 | def check_random_model(check): 165 | n, p, d = np.random.randint(2,5), np.random.randint(2,5), np.random.randint(0,3) 166 | T = np.random.randint(10,20) 167 | check(*random_model(n,p,d,T)) 168 | 169 | 170 | def test_means(): 171 | for _ in range(5): 172 | yield check_random_model, same_means 173 | 174 | 175 | def test_marginals_covs(): 176 | for _ in range(5): 177 | yield check_random_model, same_marginal_covs 178 | 179 | 180 | def test_pairwise_secondmoments(): 181 | for _ in range(5): 182 | yield check_random_model, same_pairwise_secondmoments 183 | 184 | 185 | def test_loglike(): 186 | for _ in range(5): 187 | yield check_random_model, same_loglike 188 | -------------------------------------------------------------------------------- /tests/test_diagonal_plus_lowrank.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from numpy.random import rand, randn, randint 4 | from scipy.stats import multivariate_normal 5 | 6 | from pylds.lds_messages import test_solve_diagonal_plus_lowrank, test_condition_on_diagonal 7 | from pylds.lds_messages_interface import filter_and_sample, filter_and_sample_diagonal 8 | from test_infofilter import generate_data, spectral_radius 9 | 10 | 11 | ########## 12 | # util # 13 | ########## 14 | 15 | def rand_psd(n,k=None): 16 | k = k if k else n 17 | out = randn(n,k) 18 | return np.atleast_2d(out.dot(out.T)) 19 | 20 | 21 | def generate_diag_model(n,p,d): 22 | A = randn(n,n) 23 | A /= 2.*spectral_radius(A) # ensures stability 24 | assert spectral_radius(A) < 1. 
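# dividing A by twice its spectral radius leaves rho(A) = 0.5 < 1, so the generated latent dynamics decay rather than blow up over the simulated horizon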
25 | 26 | B = randn(n,d) 27 | 28 | sigma_states = randn(n,n) 29 | sigma_states = sigma_states.dot(sigma_states.T) 30 | 31 | C = randn(p,n) 32 | D = randn(p,d) 33 | 34 | sigma_obs = np.diag(rand(p)**2) 35 | 36 | mu_init = randn(n) 37 | sigma_init = rand_psd(n) 38 | 39 | return A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init 40 | 41 | 42 | ########### 43 | # tests # 44 | ########### 45 | 46 | # test solve_diagonal_plus_lowrank 47 | 48 | def solve_diagonal_plus_lowrank(a,B,C,b): 49 | out = b.copy(order='F') 50 | B = np.asfortranarray(B) 51 | C_is_identity = np.allclose(C,np.eye(C.shape[0])) 52 | logdet = test_solve_diagonal_plus_lowrank(a,B,C,C_is_identity,out) 53 | return logdet, out 54 | 55 | 56 | def check_diagonal_plus_lowrank(a,B,C,b): 57 | solve1 = np.linalg.solve(np.diag(a)+B.dot(C).dot(B.T), b) 58 | logdet1 = np.linalg.slogdet(np.diag(a)+B.dot(C).dot(B.T))[1] 59 | logdet2, solve2 = solve_diagonal_plus_lowrank(a,B,C,b) 60 | 61 | assert np.isclose(logdet1, logdet2) 62 | assert np.allclose(solve1, solve2) 63 | 64 | 65 | def test_cython_diagonal_plus_lowrank(): 66 | for _ in range(5): 67 | n, p, k = randint(1,10), randint(1,10), randint(1,10) 68 | a = rand(p) 69 | B = randn(p,n) 70 | b = randn(p,k) 71 | C = np.eye(n) if rand() < 0.5 else rand_psd(n) 72 | 73 | yield check_diagonal_plus_lowrank, a, B, C, b 74 | 75 | 76 | # test condition_on_diagonal 77 | 78 | def cython_condition_on_diagonal(mu_x, sigma_x, C, D, sigma_obs, u, y): 79 | mu_cond = np.random.randn(*mu_x.shape) 80 | sigma_cond = np.random.randn(*sigma_x.shape) 81 | ll = test_condition_on_diagonal(mu_x, sigma_x, C, D, sigma_obs, u, y, mu_cond, sigma_cond) 82 | return ll, mu_cond, sigma_cond 83 | 84 | 85 | def check_condition_on_diagonal(mu_x, sigma_x, C, D, sigma_obs, u, y): 86 | def condition_on(mu_x, sigma_x, C, D, sigma_obs, u, y): 87 | sigma_xy = sigma_x.dot(C.T) 88 | sigma_yy = C.dot(sigma_x).dot(C.T) + np.diag(sigma_obs) 89 | mu_y = C.dot(mu_x) + D.dot(u) 90 | mu = mu_x + sigma_xy.dot(np.linalg.solve(sigma_yy, y - mu_y)) 91 | sigma = sigma_x - sigma_xy.dot(np.linalg.solve(sigma_yy,sigma_xy.T)) 92 | 93 | ll = multivariate_normal.logpdf(y,mu_y,sigma_yy) 94 | 95 | return ll, mu, sigma 96 | 97 | py_ll, py_mu, py_sigma = condition_on(mu_x, sigma_x, C, D, sigma_obs, u, y) 98 | cy_ll, cy_mu, cy_sigma = cython_condition_on_diagonal(mu_x, sigma_x, C, D, sigma_obs, u, y) 99 | 100 | assert np.allclose(py_sigma, cy_sigma) 101 | assert np.allclose(py_mu, cy_mu) 102 | assert np.isclose(py_ll, cy_ll) 103 | 104 | 105 | def test_cython_condition_on_diagonal(): 106 | for _ in range(1): 107 | n, p, d = randint(1,10), randint(1,10), 1 108 | mu_x = randn(n) 109 | sigma_x = rand_psd(n) 110 | C = randn(p,n) 111 | D = randn(p,d) 112 | sigma_obs = rand(p) 113 | u = randn(d) 114 | y = randn(p) 115 | 116 | yield check_condition_on_diagonal, mu_x, sigma_x, C, D, sigma_obs, u, y 117 | 118 | 119 | # test filter_and_sample 120 | 121 | def check_filter_and_sample(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, inputs, data): 122 | rngstate = np.random.get_state() 123 | ll1, sample1 = filter_and_sample( 124 | mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data) 125 | np.random.set_state(rngstate) 126 | ll2, sample2 = filter_and_sample_diagonal( 127 | mu_init, sigma_init, A, B, sigma_states, C, D, np.diag(sigma_obs), inputs, data) 128 | 129 | assert np.isclose(ll1, ll2) 130 | assert np.allclose(sample1, sample2) 131 | 132 | 133 | # def test_filter_and_sample(): 134 | # for _ in range(5): 135 | # n, p, d, T =
randint(1,5), randint(1,5), randint(0,5), randint(10,20) 136 | # model = generate_diag_model(n,p,d) 137 | # data, inputs = generate_data(*(model + (T,))) 138 | # yield (check_filter_and_sample,) + model + (inputs, data) 139 | 140 | 141 | -------------------------------------------------------------------------------- /tests/test_infofilter.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from numpy.random import randn, randint 4 | 5 | from pylds.lds_messages_interface import kalman_filter, kalman_info_filter, \ 6 | E_step, info_E_step 7 | from pylds.lds_info_messages import info_predict_test 8 | from pylds.states import LDSStates 9 | 10 | 11 | ########## 12 | # util # 13 | ########## 14 | 15 | def blockarray(*args,**kwargs): 16 | return np.array(np.bmat(*args,**kwargs),copy=False) 17 | 18 | 19 | def cumsum(l, strict=False): 20 | if not strict: 21 | return np.cumsum(l) 22 | else: 23 | return np.cumsum(l) - l[0] 24 | 25 | 26 | def rand_psd(n,k=None): 27 | k = k if k else n 28 | out = randn(n,k) 29 | return np.atleast_2d(out.dot(out.T)) 30 | 31 | 32 | def blockdiag(mats): 33 | assert all(m.shape[0] == m.shape[1] for m in mats) 34 | ns = [m.shape[0] for m in mats] 35 | starts, stops, n = cumsum(ns,strict=True), cumsum(ns), sum(ns) 36 | out = np.zeros((n,n)) 37 | for start, stop, mat in zip(starts, stops, mats): 38 | out[start:stop,start:stop] = mat 39 | return out 40 | 41 | 42 | def info_to_mean(J,h): 43 | mu, Sigma = np.linalg.solve(J,h), np.linalg.inv(J) 44 | return mu, Sigma 45 | 46 | 47 | def mean_to_info(mu,Sigma): 48 | J = np.linalg.inv(Sigma) 49 | h = np.linalg.solve(Sigma,mu) 50 | return J, h 51 | 52 | 53 | def spectral_radius(A): 54 | return max(np.abs(np.linalg.eigvals(A))) 55 | 56 | 57 | def generate_model(n, p, d): 58 | A = randn(n,n) 59 | A /= 2.*spectral_radius(A) # ensures stability 60 | assert spectral_radius(A) < 1. 61 | B = randn(n,d) 62 | sigma_states = np.random.randn(n,n) 63 | sigma_states = np.dot(sigma_states, sigma_states.T) 64 | 65 | C = randn(p,n) 66 | D = randn(p,d) 67 | sigma_obs = np.random.randn(p,p) 68 | sigma_obs = np.dot(sigma_obs, sigma_obs.T) 69 | 70 | mu_init = randn(n) 71 | sigma_init = rand_psd(n) 72 | 73 | return A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init 74 | 75 | 76 | def generate_data(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, T): 77 | p, n, d = C.shape[0], C.shape[1], B.shape[1] 78 | x = np.zeros((T+1,n)) 79 | u = np.random.randn(T,d) 80 | out = np.zeros((T,p)) 81 | 82 | Ldyn = np.linalg.cholesky(sigma_states) 83 | Lobs = np.linalg.cholesky(sigma_obs) 84 | 85 | staterandseq = randn(T,n) 86 | emissionrandseq = randn(T,p) 87 | 88 | x[0] = np.random.multivariate_normal(mu_init,sigma_init) 89 | for t in range(T): 90 | x[t+1] = A.dot(x[t]) + B.dot(u[t]) + Ldyn.dot(staterandseq[t]) 91 | out[t] = C.dot(x[t]) + D.dot(u[t]) + Lobs.dot(emissionrandseq[t]) 92 | 93 | return out, u 94 | 95 | 96 | def info_params(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs): 97 | T, n, p = data.shape[0], A.shape[0], C.shape[0] 98 | 99 | J_init = np.linalg.inv(sigma_init) 100 | h_init = np.linalg.solve(sigma_init, mu_init) 101 | log_Z_init = -1./2 * h_init.dot(np.linalg.solve(J_init, h_init)) 102 | log_Z_init += 1./2 * np.linalg.slogdet(J_init)[1] 103 | log_Z_init -= n/2. 
* np.log(2*np.pi) 104 | 105 | J_pair_11 = A.T.dot(np.linalg.solve(sigma_states,A)) 106 | J_pair_21 = -np.linalg.solve(sigma_states,A) 107 | J_pair_22 = np.linalg.inv(sigma_states) 108 | 109 | h_pair_1 = -inputs[:-1].dot(B.T).dot(np.linalg.solve(sigma_states,A)) 110 | h_pair_2 = inputs[:-1].dot(np.linalg.solve(sigma_states, B).T) 111 | 112 | log_Z_pair = -1. / 2 * np.linalg.slogdet(sigma_states)[1] 113 | log_Z_pair -= n / 2. * np.log(2 * np.pi) 114 | hJh_pair = B.T.dot(np.linalg.solve(sigma_states, B)) 115 | log_Z_pair -= 1. / 2 * np.einsum('ij,ti,tj->t', hJh_pair, inputs[:-1], inputs[:-1]) 116 | 117 | J_node = C.T.dot(np.linalg.solve(sigma_obs,C)) 118 | h_node = np.einsum('ik,ij,tj->tk', 119 | C, np.linalg.inv(sigma_obs), 120 | (data - inputs.dot(D.T))) 121 | 122 | log_Z_node = -p / 2. * np.log(2 * np.pi) * np.ones(T) 123 | log_Z_node -= 1. / 2 * np.linalg.slogdet(sigma_obs)[1] 124 | log_Z_node -= 1. / 2 * np.einsum('ij,ti,tj->t', 125 | np.linalg.inv(sigma_obs), 126 | data - inputs.dot(D.T), 127 | data - inputs.dot(D.T)) 128 | 129 | return J_init, h_init, log_Z_init, \ 130 | J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair,\ 131 | J_node, h_node, log_Z_node 132 | 133 | 134 | def dense_infoparams(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs): 135 | p, n = C.shape 136 | T = data.shape[0] 137 | 138 | J_init, h_init, logZ_init, \ 139 | J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair, \ 140 | J_node, h_node, log_Z_node = \ 141 | info_params(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs) 142 | 143 | h = h_node 144 | h[0] += h_init 145 | h[:-1] += h_pair_1[:-1] 146 | h[1:] += h_pair_2[:-1] 147 | h = h.ravel() 148 | 149 | J = np.kron(np.eye(T), J_node) 150 | pairblock = blockarray([[J_pair_11, J_pair_21.T], [J_pair_21, J_pair_22]]) 151 | for t in range(0,n*(T-1),n): 152 | J[t:t+2*n,t:t+2*n] += pairblock 153 | J[:n, :n] += J_init 154 | 155 | assert J.shape == (T*n, T*n) 156 | assert h.shape == (T*n,) 157 | 158 | return J, h 159 | 160 | 161 | ########################## 162 | # testing info_predict # 163 | ########################## 164 | 165 | def py_info_predict(J,h,J11,J21,J22,h1,h2,logZ): 166 | Jnew = J + J11 167 | Jpredict = J22 - J21.dot(np.linalg.solve(Jnew,J21.T)) 168 | hnew = h + h1 169 | hpredict = h2-J21.dot(np.linalg.solve(Jnew,hnew)) 170 | lognorm = -1./2*np.linalg.slogdet(Jnew)[1] + 1./2*hnew.dot(np.linalg.solve(Jnew,hnew)) \ 171 | + J.shape[0]/2.*np.log(2*np.pi) + logZ 172 | return Jpredict, hpredict, lognorm 173 | 174 | 175 | def py_info_predict2(J,h,J11,J21,J22,h1,h2,logZ): 176 | n = J.shape[0] 177 | bigJ = blockarray([[J11, J21.T], [J21, J22]]) + blockdiag([J, np.zeros_like(J)]) 178 | bigh = np.concatenate([h+h1,h2]) 179 | mu, Sigma = info_to_mean(bigJ, bigh) 180 | Jpredict = np.linalg.inv(Sigma[n:,n:]) 181 | hpredict = np.linalg.solve(Sigma[n:,n:],mu[n:]) 182 | return Jpredict, hpredict 183 | 184 | 185 | def cy_info_predict(J,h,J11,J21,J22,h1,h2,logZ): 186 | Jpredict = np.zeros_like(J) 187 | hpredict = np.zeros_like(h) 188 | 189 | lognorm = info_predict_test(J,h,J11,J21,J22,h1,h2,logZ,Jpredict,hpredict) 190 | 191 | return Jpredict, hpredict, lognorm 192 | 193 | 194 | def check_info_predict(J,h,J11,J21,J22,h1,h2,logZ): 195 | py_Jpredict, py_hpredict, py_lognorm = py_info_predict(J,h,J11,J21,J22,h1,h2,logZ) 196 | cy_Jpredict, cy_hpredict, cy_lognorm = cy_info_predict(J,h,J11,J21,J22,h1,h2,logZ) 197 | 198 | assert np.allclose(py_Jpredict, cy_Jpredict) 199 | assert np.allclose(py_hpredict, cy_hpredict) 
200 | assert np.allclose(py_lognorm, cy_lognorm) 201 | 202 | py2_Jpredict, py2_hpredict = py_info_predict2(J,h,J11,J21,J22,h1,h2,logZ) 203 | assert np.allclose(py2_Jpredict, cy_Jpredict) 204 | assert np.allclose(py2_hpredict, cy_hpredict) 205 | 206 | 207 | def test_info_predict(): 208 | for _ in range(5): 209 | n = randint(1,20) 210 | J = rand_psd(n) 211 | h = randn(n) 212 | 213 | bigJ = rand_psd(2*n) 214 | J11, J21, J22 = map(np.copy,[bigJ[:n,:n], bigJ[n:,:n], bigJ[n:,n:]]) 215 | 216 | h1 = randn(n) 217 | h2 = randn(n) 218 | 219 | logZ = randn() 220 | 221 | yield check_info_predict, J, h, J11, J21, J22, h1, h2, logZ 222 | 223 | 224 | 225 | #################################### 226 | # test against distribution form # 227 | #################################### 228 | 229 | def check_filters(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs): 230 | ll, filtered_mus, filtered_sigmas = kalman_filter( 231 | mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data) 232 | 233 | ll2, filtered_Js, filtered_hs = kalman_info_filter( 234 | *info_params(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs)) 235 | 236 | filtered_mus2 = [np.linalg.solve(J,h) for J, h in zip(filtered_Js, filtered_hs)] 237 | 238 | filtered_sigmas2 = [np.linalg.inv(J) for J in filtered_Js] 239 | 240 | assert all(np.allclose(mu1, mu2) 241 | for mu1, mu2 in zip(filtered_mus, filtered_mus2)) 242 | assert all(np.allclose(s1, s2) 243 | for s1, s2 in zip(filtered_sigmas, filtered_sigmas2)) 244 | assert np.isclose(ll, ll2) 245 | 246 | 247 | def test_info_filter(): 248 | for _ in range(1): 249 | n, p, d, T = randint(1,5), randint(1,5), 1, randint(10,20) 250 | model = generate_model(n,p,d) 251 | data, inputs = generate_data(*(model + (T,))) 252 | yield (check_filters,) + model + (data,inputs) 253 | 254 | 255 | def check_info_Estep(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, inputs, data): 256 | ll, smoothed_mus, smoothed_sigmas, ExnxT = E_step( 257 | mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data) 258 | ll2, smoothed_mus2, smoothed_sigmas2, ExnxT2 = info_E_step( 259 | *info_params(A, B, sigma_states, C, D, sigma_obs, mu_init, sigma_init, data, inputs)) 260 | 261 | assert np.isclose(ll,ll2) 262 | assert np.allclose(smoothed_mus, smoothed_mus2) 263 | assert np.allclose(smoothed_sigmas, smoothed_sigmas2) 264 | assert np.allclose(ExnxT, ExnxT2) 265 | 266 | 267 | def test_info_Estep(): 268 | for _ in range(5): 269 | n, p, d, T = randint(1, 5), randint(1, 5), 1, randint(10, 20) 270 | model = generate_model(n, p, d) 271 | data, inputs = generate_data(*(model + (T,))) 272 | yield (check_info_Estep,) + model + (inputs, data) 273 | -------------------------------------------------------------------------------- /tests/test_laplace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pylds.models import DefaultPoissonLDS, DefaultBernoulliLDS 3 | from pylds.laplace import LaplaceApproxPoissonLDSStates, LaplaceApproxBernoulliLDSStates 4 | 5 | from nose.tools import nottest 6 | 7 | 8 | def correct_gradient(states): 9 | x = np.random.randn(states.T, states.D_latent) 10 | g_sparse = states.gradient_log_joint(x) 11 | g_full = states.test_gradient_log_joint(x) 12 | assert np.allclose(g_sparse, g_full) 13 | 14 | 15 | def correct_hessian(states): 16 | x = np.random.randn(states.T, states.D_latent) 17 | H_diag, H_upper_diag = states.sparse_hessian_log_joint(x) 18 | H_full = 
states.test_hessian_log_joint(x) 19 | 20 | for t in range(states.T): 21 | assert np.allclose(H_full[t,:,t,:], H_diag[t]) 22 | if t < states.T - 1: 23 | assert np.allclose(H_full[t,:,t+1,:], H_upper_diag[t]) 24 | assert np.allclose(H_full[t+1,:,t,:], H_upper_diag[t].T) 25 | 26 | 27 | def correct_hessian_vector_product(states): 28 | T, D = states.T, states.D_latent 29 | x = np.random.randn(T, D) 30 | H_full = states.test_hessian_log_joint(x) 31 | v = np.random.randn(T, D) 32 | 33 | hvp1 = H_full.reshape((T*D, T*D)).dot(v.ravel()) 34 | hvp2 = states.hessian_vector_product_log_joint(x, v) 35 | assert np.allclose(hvp1.reshape((T, D)), hvp2) 36 | 37 | 38 | def correct_laplace_approximation_bfgs(states): 39 | xhat = states.laplace_approximation(method="bfgs", verbose=True) 40 | g = states.gradient_log_joint(xhat) 41 | assert np.allclose(g, 0, atol=1e-2) 42 | 43 | def correct_laplace_approximation_newton(states): 44 | xhat = states.laplace_approximation(method="newton", verbose=True) 45 | g = states.gradient_log_joint(xhat) 46 | assert np.allclose(g, 0, atol=1e-2) 47 | 48 | @nottest 49 | def test_laplace_approximation_newton_largescale(): 50 | T = 50000 51 | N = 100 52 | D = 10 53 | D_input = 1 54 | model = DefaultPoissonLDS(N, D, D_input=D_input) 55 | data = np.random.poisson(3.0, size=(T, N)) 56 | inputs = np.random.randn(T, D_input) 57 | states = LaplaceApproxPoissonLDSStates(model, data=data, inputs=inputs) 58 | states.gaussian_states *= 0 59 | 60 | xhat = states.laplace_approximation(method="newton", stepsz=.99, verbose=True) 61 | g = states.gradient_log_joint(xhat) 62 | assert np.allclose(g, 0, atol=1e-2) 63 | 64 | 65 | def check_random_poisson_states(check): 66 | T = np.random.randint(25, 200) 67 | N = np.random.randint(10, 20) 68 | D = np.random.randint(1, 10) 69 | D_input = np.random.randint(0, 2) 70 | 71 | model = DefaultPoissonLDS(N, D, D_input=D_input) 72 | data = np.random.poisson(3.0, size=(T, N)) 73 | inputs = np.random.randn(T, D_input) 74 | states = LaplaceApproxPoissonLDSStates(model, data=data, inputs=inputs) 75 | states.gaussian_states *= 0 76 | 77 | check(states) 78 | 79 | 80 | def check_random_bernoulli_states(check): 81 | T = np.random.randint(25, 200) 82 | N = np.random.randint(10, 20) 83 | D = np.random.randint(1, 10) 84 | D_input = np.random.randint(0, 2) 85 | 86 | model = DefaultBernoulliLDS(N, D, D_input=D_input) 87 | data = np.random.rand(T, N) 88 | inputs = np.random.randn(T, D_input) 89 | states = LaplaceApproxBernoulliLDSStates(model, data=data, inputs=inputs) 90 | states.gaussian_states *= 0 91 | 92 | check(states) 93 | 94 | 95 | ### Poisson tests 96 | def test_poisson_gradients(): 97 | for _ in range(5): 98 | yield check_random_poisson_states, correct_gradient 99 | 100 | 101 | def test_poisson_hessian(): 102 | for _ in range(5): 103 | yield check_random_poisson_states, correct_hessian 104 | 105 | 106 | def test_poisson_hessian_vector_product(): 107 | for _ in range(5): 108 | yield check_random_poisson_states, correct_hessian_vector_product 109 | 110 | 111 | def test_poisson_laplace_approximation_bfgs(): 112 | for _ in range(5): 113 | yield check_random_poisson_states, correct_laplace_approximation_bfgs 114 | 115 | 116 | def test_poisson_laplace_approximation_newton(): 117 | for _ in range(5): 118 | yield check_random_poisson_states, correct_laplace_approximation_newton 119 | 120 | 121 | ### Bernoulli tests 122 | def test_bernoulli_gradients(): 123 | for _ in range(5): 124 | yield check_random_bernoulli_states, correct_gradient 125 | 126 | 127 | def 
test_bernoulli_hessian(): 128 | for _ in range(5): 129 | yield check_random_bernoulli_states, correct_hessian 130 | 131 | 132 | def test_bernoulli_hessian_vector_product(): 133 | for _ in range(5): 134 | yield check_random_bernoulli_states, correct_hessian_vector_product 135 | 136 | 137 | def test_bernoulli_laplace_approximation_bfgs(): 138 | for _ in range(5): 139 | yield check_random_bernoulli_states, correct_laplace_approximation_bfgs 140 | 141 | 142 | def test_bernoulli_laplace_approximation_newton(): 143 | for _ in range(5): 144 | yield check_random_bernoulli_states, correct_laplace_approximation_newton 145 | -------------------------------------------------------------------------------- /tests/test_randomwalk.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from numpy.random import randn, rand, randint 4 | 5 | from pylds.lds_messages_interface import filter_and_sample, filter_and_sample_randomwalk 6 | 7 | 8 | ########## 9 | # util # 10 | ########## 11 | 12 | def generate_model(n): 13 | sigmasq_states = rand(n) 14 | sigmasq_obs = rand(n) 15 | 16 | mu_init = randn(n) 17 | sigmasq_init = rand(n) 18 | 19 | return sigmasq_states, sigmasq_obs, mu_init, sigmasq_init 20 | 21 | 22 | def generate_data(sigmasq_states, sigmasq_obs, mu_init, sigmasq_init, T): 23 | n = sigmasq_states.shape[0] 24 | x = np.zeros((T+1,n)) 25 | out = np.zeros((T,n)) 26 | 27 | staterandseq = randn(T,n) 28 | emissionrandseq = randn(T,n) 29 | 30 | x[0] = mu_init + np.sqrt(sigmasq_init)*randn(n) 31 | for t in range(T): 32 | x[t+1] = x[t] + np.sqrt(sigmasq_states)*staterandseq[t] 33 | out[t] = x[t] + np.sqrt(sigmasq_obs)*emissionrandseq[t] 34 | 35 | return out 36 | 37 | 38 | def dense_sample_states(sigmasq_states, sigmasq_obs, mu_init, sigmasq_init, data): 39 | T, n = data.shape 40 | inputs = np.zeros((T,0)) 41 | 42 | # construct corresponding dense model 43 | A = np.eye(n) 44 | B = np.zeros((n,0)) 45 | sigma_states = np.diag(sigmasq_states) 46 | C = np.eye(n) 47 | D = np.zeros((n,0)) 48 | sigma_obs = np.diag(sigmasq_obs) 49 | sigma_init = np.diag(sigmasq_init) 50 | 51 | return filter_and_sample( 52 | mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data) 53 | 54 | 55 | ##################### 56 | # testing samples # 57 | ##################### 58 | 59 | def check_sample(sigmasq_states, sigmasq_obs, mu_init, sigmasq_init, data): 60 | rngstate = np.random.get_state() 61 | dense_ll, dense_sample = dense_sample_states( 62 | sigmasq_states, sigmasq_obs, mu_init, sigmasq_init, data) 63 | np.random.set_state(rngstate) 64 | rw_ll, rw_sample = filter_and_sample_randomwalk( 65 | mu_init, sigmasq_init, sigmasq_states, sigmasq_obs, data) 66 | 67 | assert np.isclose(dense_ll, rw_ll) 68 | assert np.allclose(dense_sample, rw_sample) 69 | 70 | 71 | ################################## 72 | # test against dense functions # 73 | ################################## 74 | 75 | def test_filter_and_sample(): 76 | for _ in range(5): 77 | n, T = randint(1,10), randint(10,50) 78 | model = generate_model(n) 79 | data = generate_data(*(model + (T,))) 80 | yield (check_sample,) + model + (data,) 81 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pylds.util import symm_block_tridiag_matmul, solve_symm_block_tridiag, \ 4 | logdet_symm_block_tridiag, 
compute_symm_block_tridiag_covariances, \ 5 | convert_block_tridiag_to_banded, scipy_solve_symm_block_tridiag, \ 6 | transpose_lower_banded_matrix, scipy_sample_block_tridiag, sample_block_tridiag 7 | 8 | def random_symm_block_tridiags(n, d): 9 | """ 10 | Create a random matrix of size (n*d, n*d) with 11 | blocks of size d along (-1, 0, 1) diagonals 12 | """ 13 | assert n > 0 and d > 0 14 | 15 | H_diag = np.random.rand(n, d, d) 16 | H_diag = np.matmul(H_diag, np.swapaxes(H_diag, -1, -2)) 17 | 18 | H_upper_diag = np.random.rand(n-1, d, d) 19 | H_upper_diag = np.matmul(H_upper_diag, np.swapaxes(H_upper_diag, -1, -2)) 20 | return H_diag, H_upper_diag 21 | 22 | def symm_block_tridiags_to_dense(H_diag, H_upper_diag): 23 | n, d, _ = H_diag.shape 24 | H = np.zeros((n*d, n*d)) 25 | for i in range(n): 26 | H[i*d:(i+1)*d, i*d:(i+1)*d] = H_diag[i] 27 | 28 | if i < n-1: 29 | H[i*d:(i+1)*d, (i+1)*d:(i+2)*d] = H_upper_diag[i] 30 | H[(i+1)*d:(i+2)*d, i*d:(i+1)*d] = H_upper_diag[i].T 31 | return H 32 | 33 | def test_symm_block_tridiag_matmul(): 34 | n, d = 10, 3 35 | for _ in range(5): 36 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 37 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 38 | v = np.random.randn(n, d) 39 | 40 | out1 = H.dot(v.ravel()).reshape((n, d)) 41 | out2 = symm_block_tridiag_matmul(H_diag, H_upper_diag, v) 42 | assert np.allclose(out1, out2) 43 | 44 | 45 | def test_convert_block_to_banded(): 46 | n, d = 10, 3 47 | for _ in range(5): 48 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 49 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 50 | 51 | # Get the true ab matrix 52 | ab_true = np.zeros((2*d, n*d)) 53 | for j in range(2*d): 54 | ab_true[j, :n*d-j] = np.diag(H, -j) 55 | 56 | ab = convert_block_tridiag_to_banded(H_diag, H_upper_diag) 57 | 58 | for j in range(d): 59 | assert np.allclose(ab_true[j], ab[j]) 60 | 61 | 62 | def test_solve_symm_block_tridiag(): 63 | n, d = 10, 3 64 | for _ in range(5): 65 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 66 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 67 | 68 | # Make sure H is positive definite 69 | min_ev = np.linalg.eigvalsh(H).min() 70 | if min_ev < 0: 71 | for i in range(n): 72 | H_diag[i] += (-min_ev + 1e-8) * np.eye(d) 73 | H += (-min_ev + 1e-8) * np.eye(n * d) 74 | assert np.allclose(H, symm_block_tridiags_to_dense(H_diag, H_upper_diag)) 75 | assert np.all(np.linalg.eigvalsh(H) > 0) 76 | 77 | # Make random vector to solve against 78 | v = np.random.randn(n, d) 79 | 80 | out1 = np.linalg.solve(H, v.ravel()).reshape((n, d)) 81 | out2 = solve_symm_block_tridiag(H_diag, H_upper_diag, v) 82 | out3 = scipy_solve_symm_block_tridiag(H_diag, H_upper_diag, v) 83 | assert np.allclose(out1, out2) 84 | assert np.allclose(out1, out3) 85 | 86 | 87 | def test_logdet_symm_block_tridiag(): 88 | n, d = 10, 3 89 | for _ in range(5): 90 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 91 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 92 | 93 | # Make sure H is positive definite 94 | min_ev = np.linalg.eigvalsh(H).min() 95 | if min_ev < 0: 96 | for i in range(n): 97 | H_diag[i] += (-min_ev + .1) * np.eye(d) 98 | H += (-min_ev + .1) * np.eye(n * d) 99 | assert np.allclose(H, symm_block_tridiags_to_dense(H_diag, H_upper_diag)) 100 | assert np.all(np.linalg.eigvalsh(H) > 0) 101 | 102 | out1 = np.linalg.slogdet(H)[1] 103 | out2 = logdet_symm_block_tridiag(H_diag, H_upper_diag) 104 | assert np.allclose(out1, out2) 105 | 106 | 107 | def test_symm_block_tridiag_covariances(): 108 | 
n, d = 10, 3 109 | for _ in range(5): 110 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 111 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 112 | 113 | # Make sure H is positive definite 114 | min_ev = np.linalg.eigvalsh(H).min() 115 | if min_ev < 0: 116 | for i in range(n): 117 | H_diag[i] += (-min_ev + .1) * np.eye(d) 118 | H += (-min_ev + .1) * np.eye(n * d) 119 | assert np.allclose(H, symm_block_tridiags_to_dense(H_diag, H_upper_diag)) 120 | assert np.all(np.linalg.eigvalsh(H) > 0) 121 | 122 | Sigma = np.linalg.inv(H) 123 | sigmas, E_xtp1_xt = compute_symm_block_tridiag_covariances(H_diag, H_upper_diag) 124 | 125 | for i in range(n): 126 | assert np.allclose(Sigma[i*d:(i+1)*d, i*d:(i+1)*d], sigmas[i]) 127 | 128 | for i in range(n-1): 129 | assert np.allclose(Sigma[(i+1)*d:(i+2)*d, i*d:(i+1)*d], E_xtp1_xt[i]) 130 | 131 | 132 | def test_sample_block_tridiag(): 133 | n, d = 10, 3 134 | for _ in range(5): 135 | H_diag, H_upper_diag = random_symm_block_tridiags(n, d) 136 | H = symm_block_tridiags_to_dense(H_diag, H_upper_diag) 137 | 138 | # Make sure H is positive definite 139 | min_ev = np.linalg.eigvalsh(H).min() 140 | if min_ev < 0: 141 | for i in range(n): 142 | H_diag[i] += (-min_ev + .1) * np.eye(d) 143 | H += (-min_ev + .1) * np.eye(n * d) 144 | assert np.allclose(H, symm_block_tridiags_to_dense(H_diag, H_upper_diag)) 145 | assert np.all(np.linalg.eigvalsh(H) > 0) 146 | 147 | # Cholesky of H 148 | from scipy.linalg import cholesky_banded, solve_banded 149 | L1 = np.linalg.cholesky(H) 150 | Lab = convert_block_tridiag_to_banded(H_diag, H_upper_diag) 151 | L2 = cholesky_banded(Lab, lower=True) 152 | assert np.allclose(np.diag(L1), L2[0]) 153 | for i in range(1, 2*d): 154 | assert np.allclose(np.diag(L1, -i), L2[i, :-i]) 155 | 156 | U1 = L1.T 157 | U2 = transpose_lower_banded_matrix(L2) 158 | Uab = convert_block_tridiag_to_banded(H_diag, H_upper_diag, lower=False) 159 | U3 = cholesky_banded(Uab, lower=False) 160 | assert np.allclose(np.diag(U1), U2[-1]) 161 | assert np.allclose(np.diag(U1), U3[-1]) 162 | for i in range(1, 2 * d): 163 | assert np.allclose(np.diag(U1, i), U2[-(i+1), i:]) 164 | assert np.allclose(np.diag(U1, i), U3[-(i+1), i:]) 165 | 166 | z = np.random.randn(n*d) 167 | x1 = np.linalg.solve(U1, z) 168 | x2 = solve_banded((0, 2*d-1), U2, z) 169 | x3 = scipy_sample_block_tridiag(H_diag, H_upper_diag, z=z) 170 | assert np.allclose(x1, x2) 171 | assert np.allclose(x1, x3) 172 | print("success") 173 | 174 | 175 | 176 | def time_sample_block_tridiag(): 177 | from time import time 178 | n, d, m = 1000, 10, 5 179 | 180 | ds = [5, 10, 25, 50, 100] 181 | ts_scipy = np.zeros_like(ds) 182 | ts_pylds = np.zeros_like(ds) 183 | 184 | for d in ds: 185 | print("timing test: n={} d={}".format(n, d)) 186 | 187 | H_diag = 2 * np.eye(d)[None, :, :].repeat(n, axis=0) 188 | H_upper_diag = np.eye(d)[None, :, :].repeat(n-1, axis=0) 189 | 190 | Uab = convert_block_tridiag_to_banded(H_diag, H_upper_diag, lower=False) 191 | 192 | tic = time() 193 | for _ in range(m): 194 | scipy_sample_block_tridiag(H_diag, H_upper_diag) 195 | print("scipy: {:.4f} sec".format((time() - tic)/m)) 196 | 197 | tic = time() 198 | for _ in range(m): 199 | scipy_sample_block_tridiag(H_diag, H_upper_diag, ab=Uab) 200 | print("scipy (given ab): {:.4f} sec".format((time() - tic)/m)) 201 | 202 | tic = time() 203 | for _ in range(m): 204 | sample_block_tridiag(H_diag, H_upper_diag) 205 | print("message passing: {:.4f} sec".format((time() - tic)/m)) 206 | 207 | 208 | if __name__ == "__main__": 209 | 
test_symm_block_tridiag_matmul() 210 | test_convert_block_to_banded() 211 | test_solve_symm_block_tridiag() 212 | test_logdet_symm_block_tridiag() 213 | test_symm_block_tridiag_covariances() 214 | test_sample_block_tridiag() 215 | time_sample_block_tridiag() 216 | 217 | --------------------------------------------------------------------------------
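The block-tridiagonal helpers in `pylds/util.py` that `tests/test_utils.py` exercises can also be used on their own. The snippet below is a minimal sketch rather than an official example: the sizes `T = 10` and `D = 3` and the specific blocks are arbitrary choices, and only the functions imported from `pylds.util` are part of the package.

```python
import numpy as np

from pylds.util import (symm_block_tridiag_matmul, solve_symm_block_tridiag,
                        scipy_solve_symm_block_tridiag, logdet_symm_block_tridiag)

T, D = 10, 3  # arbitrary example sizes: T blocks, each D x D

# A positive definite block tridiagonal matrix H, stored as its diagonal
# blocks (T, D, D) and its upper off-diagonal blocks (T-1, D, D).
H_diag = np.tile(2.0 * np.eye(D), (T, 1, 1))
H_upper_diag = np.tile(-0.5 * np.eye(D), (T - 1, 1, 1))

v = np.random.randn(T, D)

# Matrix-vector product without forming the dense (T*D, T*D) matrix
Hv = symm_block_tridiag_matmul(H_diag, H_upper_diag, v)

# Solve H x = v with info-form message passing and with scipy's banded solver
x_mp = solve_symm_block_tridiag(H_diag, H_upper_diag, v)
x_sp = scipy_solve_symm_block_tridiag(H_diag, H_upper_diag, v)
assert np.allclose(x_mp, x_sp)

# Check the results against an explicit dense construction of H
H = np.zeros((T * D, T * D))
for t in range(T):
    H[t*D:(t+1)*D, t*D:(t+1)*D] = H_diag[t]
    if t < T - 1:
        H[t*D:(t+1)*D, (t+1)*D:(t+2)*D] = H_upper_diag[t]
        H[(t+1)*D:(t+2)*D, t*D:(t+1)*D] = H_upper_diag[t].T
assert np.allclose(Hv, H.dot(v.ravel()).reshape(T, D))
assert np.isclose(logdet_symm_block_tridiag(H_diag, H_upper_diag),
                  np.linalg.slogdet(H)[1])
```

Because only the `(T, D, D)` and `(T-1, D, D)` block arrays are stored, these routines avoid ever materializing the dense `(T*D, T*D)` matrix that the comparison above builds.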