├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── environment.yml ├── jsl ├── __init__.py ├── demos │ ├── __init__.py │ ├── bootstrap_filter.py │ ├── bootstrap_filter_maneuver.py │ ├── eekf_logistic_regression.py │ ├── ekf_continuous.py │ ├── ekf_mlp.py │ ├── ekf_mlp_anim.ipynb │ ├── ekf_mlp_anim.py │ ├── ekf_vs_eks_spiral.py │ ├── ekf_vs_ukf_spiral.py │ ├── fixed_lag_smoother.ipynb │ ├── hmm_casino.py │ ├── hmm_casino_em_train.py │ ├── hmm_casino_numpy.py │ ├── hmm_casino_sgd_train │ ├── hmm_casino_sgd_train.py │ ├── hmm_lillypad.py │ ├── kalman_sampling_demo.ipynb │ ├── kf_continuous_circle.py │ ├── kf_parallel.py │ ├── kf_spiral.py │ ├── kf_tracking.py │ ├── lds_sampling_demo.py │ ├── linreg_kf.py │ ├── logreg_biclusters.py │ ├── old-init.py │ ├── pendulum_1d.py │ ├── plot_utils.py │ ├── rbpf_maneuver.py │ ├── sis_vs_smc.py │ ├── superimport_test.py │ └── ukf_mlp.py ├── hmm │ ├── __init__.py │ ├── hmm_casino_test.py │ ├── hmm_lib.py │ ├── hmm_lib_test.py │ ├── hmm_logspace_lib.py │ ├── hmm_logspace_lib_test.py │ ├── hmm_numpy_lib.py │ ├── hmm_utils.py │ ├── old │ │ ├── hmm_discrete_em_lib.py │ │ ├── hmm_discrete_lib.py │ │ ├── hmm_discrete_lib_test.py │ │ ├── hmm_discrete_likelihood_test.py │ │ └── hmm_sgd_lib.py │ └── sparse_lib.py ├── lds │ ├── __init__.py │ ├── cont_kalman_filter.py │ ├── kalman_filter.py │ ├── kalman_filter_test.py │ ├── kalman_filter_with_unknown_noise.py │ ├── kalman_sampler.py │ └── mixture_kalman_filter.py ├── nlds │ ├── __init__.py │ ├── base.py │ ├── bootstrap_filter.py │ ├── continuous_extended_kalman_filter.py │ ├── diagonal_extended_kalman_filter.py │ ├── extended_kalman_filter.py │ ├── extended_kalman_smoother.py │ ├── sequential_monte_carlo.py │ └── unscented_kalman_filter.py └── setup.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | *.mp4 3 | *.pdf 4 | .DS_Store 5 | .vscode/ 6 | hmm_casino_params 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | # PyCharm 155 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 157 | # and can be added to the global gitignore or merged into this file. For a more nuclear 158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 159 | #.idea/ 160 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: JSL 6 | message: >- 7 | If you use this software, please cite it using the 8 | metadata from this file. 
9 | type: software 10 | authors: 11 | - given-names: Gerardo 12 | family-names: Duran-Martin 13 | - given-names: Kevin 14 | family-names: Murphy 15 | - given-names: Aleyna 16 | family-names: Kara 17 | repository: 'https://github.com/probml/JSL' 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JSL: JAX State-Space models (SSM) Library 2 | 3 |

4 | [image] 5 |

6 | 7 | JSL is a JAX library for Bayesian inference in state space models. 8 | As of 2022-06-28, JSL is **deprecated**. You should use [ssm-jax](https://github.com/probml/ssm-jax). 9 | 10 | 11 | # Installation 12 | 13 | We assume you have already installed [JAX](https://github.com/google/jax#installation) and 14 | [Tensorflow](https://www.tensorflow.org/install), 15 | since the details on how to do this depend on whether you have a CPU, GPU, etc. 16 | (This step is not necessary in Colab.) 17 | 18 | Now install these packages: 19 | 20 | ``` 21 | !pip install --upgrade git+https://github.com/google/flax.git 22 | !pip install --upgrade tensorflow-probability 23 | !pip install git+https://github.com/blackjax-devs/blackjax.git 24 | !pip install git+https://github.com/deepmind/distrax.git 25 | ``` 26 | 27 | Then install JSL: 28 | ``` 29 | !pip install git+https://github.com/probml/jsl 30 | ``` 31 | Alternatively, you can clone the repo locally, into say `~/github/JSL`, and then install it as a package, as follows: 32 | ``` 33 | !git clone https://github.com/probml/JSL.git 34 | cd JSL 35 | !pip install -e . 36 | ``` 37 | 38 | # Running the demos 39 | 40 | You can see how to use the library by looking at some of the demos. 41 | You can run the demos from inside a notebook like this 42 | ``` 43 | %run JSL/jsl/demos/kf_tracking.py 44 | %run JSL/jsl/demos/hmm_casino_em_train.py 45 | ``` 46 | 47 | Or from inside an ipython shell like this 48 | ``` 49 | from jsl.demos import kf_tracking 50 | figdict = kf_tracking.main() 51 | ``` 52 | 53 | Most of the demos create figures. If you want to save them (in both png and pdf format), 54 | you need to specify the FIGDIR environment variable, like this: 55 | ``` 56 | import os 57 | os.environ["FIGDIR"]='/Users/kpmurphy/figures' 58 | 59 | from jsl.demos.plot_utils import savefig 60 | savefig(figdict) 61 | ``` 62 | 63 | # Authors 64 | 65 | Gerardo Durán-Martín ([@gerdm](https://github.com/gerdm)), Aleyna Kara([@karalleyna](https://github.com/karalleyna)), Kevin Murphy ([@murphyk](https://github.com/murphyk)), Giles Harper-Donnelly ([@gileshd](https://github.com/gileshd)), Peter Chang ([@petergchang](https://github.com/petergchang)). 
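
# Minimal library usage example

The snippet below is an illustrative sketch distilled from the demo code in this repo
(see `jsl/demos/ekf_vs_ukf_spiral.py` and `jsl/demos/bootstrap_filter.py`). Since JSL is
deprecated, treat the exact signatures as indicative rather than guaranteed. It builds a
nonlinear state-space model, samples a trajectory from it, and runs an extended Kalman filter:

```
import jax.numpy as jnp
from jax import random
import jsl.nlds.extended_kalman_filter as ekf_lib
from jsl.nlds.base import NLDS

dt = 0.4
def fz(x): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])  # state transition
def fx(x, *args): return x                                            # observation function

Q = jnp.eye(2) * 0.001  # state noise covariance
R = jnp.eye(2) * 0.05   # observation noise covariance
model = NLDS(fz, fx, Q, R)

x0 = jnp.array([1.5, 0.0])
states, obs = model.sample(random.PRNGKey(0), x0, 100)  # simulate hidden states and observations
_, hist = ekf_lib.filter(model, x0, obs, return_params=["mean", "cov"])
filtered_means, filtered_covs = hist["mean"], hist["cov"]
```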
66 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: jsl 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - blas=1.0=mkl 8 | - brotli=1.0.9=he6710b0_2 9 | - ca-certificates=2021.10.26=h06a4308_2 10 | - certifi=2021.10.8=py39h06a4308_0 11 | - cycler=0.11.0=pyhd3eb1b0_0 12 | - dbus=1.13.18=hb2f20db_0 13 | - expat=2.4.1=h2531618_2 14 | - fontconfig=2.13.1=h6c09931_0 15 | - fonttools=4.25.0=pyhd3eb1b0_0 16 | - freetype=2.11.0=h70c0345_0 17 | - giflib=5.2.1=h7b6447c_0 18 | - glib=2.69.1=h5202010_0 19 | - gst-plugins-base=1.14.0=h8213a91_2 20 | - gstreamer=1.14.0=h28cd5cc_2 21 | - icu=58.2=he6710b0_3 22 | - intel-openmp=2021.4.0=h06a4308_3561 23 | - jpeg=9d=h7f8727e_0 24 | - kiwisolver=1.3.1=py39h2531618_0 25 | - lcms2=2.12=h3be6417_0 26 | - ld_impl_linux-64=2.35.1=h7274673_9 27 | - libffi=3.3=he6710b0_2 28 | - libgcc-ng=9.3.0=h5101ec6_17 29 | - libgomp=9.3.0=h5101ec6_17 30 | - libpng=1.6.37=hbc83047_0 31 | - libstdcxx-ng=9.3.0=hd4cf53a_17 32 | - libtiff=4.2.0=h85742a9_0 33 | - libuuid=1.0.3=h7f8727e_2 34 | - libwebp=1.2.0=h89dd481_0 35 | - libwebp-base=1.2.0=h27cfd23_0 36 | - libxcb=1.14=h7b6447c_0 37 | - libxml2=2.9.12=h03d6c58_0 38 | - lz4-c=1.9.3=h295c915_1 39 | - matplotlib=3.5.0=py39h06a4308_0 40 | - matplotlib-base=3.5.0=py39h3ed280b_0 41 | - mkl=2021.4.0=h06a4308_640 42 | - mkl-service=2.4.0=py39h7f8727e_0 43 | - mkl_fft=1.3.1=py39hd3c417c_0 44 | - mkl_random=1.2.2=py39h51133e4_0 45 | - munkres=1.1.4=py_0 46 | - ncurses=6.3=h7f8727e_2 47 | - numpy-base=1.21.2=py39h79a1101_0 48 | - olefile=0.46=pyhd3eb1b0_0 49 | - openssl=1.1.1l=h7f8727e_0 50 | - packaging=21.3=pyhd3eb1b0_0 51 | - pcre=8.45=h295c915_0 52 | - pillow=8.4.0=py39h5aabda8_0 53 | - pip=21.2.4=py39h06a4308_0 54 | - pyparsing=3.0.4=pyhd3eb1b0_0 55 | - pyqt=5.9.2=py39h2531618_6 56 | - python=3.9.7=h12debd9_1 57 | - python-dateutil=2.8.2=pyhd3eb1b0_0 58 | - qt=5.9.7=h5867ecd_1 59 | - readline=8.1=h27cfd23_0 60 | - setuptools=58.0.4=py39h06a4308_0 61 | - sip=4.19.13=py39h2531618_0 62 | - six=1.16.0=pyhd3eb1b0_0 63 | - sqlite=3.36.0=hc218d9a_0 64 | - tk=8.6.11=h1ccaba5_0 65 | - tornado=6.1=py39h27cfd23_0 66 | - tzdata=2021e=hda174b7_0 67 | - wheel=0.37.0=pyhd3eb1b0_1 68 | - xz=5.2.5=h7b6447c_0 69 | - zlib=1.2.11=h7b6447c_3 70 | - zstd=1.4.9=haebb681_0 71 | - pip: 72 | - absl-py==1.0.0 73 | - charset-normalizer==2.0.9 74 | - fire==0.4.0 75 | - flatbuffers==2.0 76 | - idna==3.3 77 | - jax==0.2.26 78 | - jaxlib==0.1.75 79 | - libtpu-nightly==0.1.dev20211208 80 | - numpy==1.21.4 81 | - opt-einsum==3.3.0 82 | - requests==2.26.0 83 | - scipy==1.7.3 84 | - termcolor==1.1.0 85 | - typing-extensions==4.0.1 86 | - urllib3==1.26.7 87 | prefix: /home/gerardoduran/miniconda3/envs/jsl 88 | -------------------------------------------------------------------------------- /jsl/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /jsl/demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/demos/__init__.py -------------------------------------------------------------------------------- /jsl/demos/bootstrap_filter.py: 
-------------------------------------------------------------------------------- 1 | # Demo of the bootstrap filter under a 2 | # nonlinear discrete system 3 | 4 | import jax 5 | from jsl.nlds.base import NLDS 6 | from jsl.nlds.bootstrap_filter import filter 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | from jax import random 10 | 11 | 12 | def plot_samples(sample_state, sample_obs, ax=None): 13 | fig, ax = plt.subplots() 14 | ax.plot(*sample_state.T, label="state space") 15 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+") 16 | ax.scatter(*sample_state[0], c="black", zorder=3) 17 | ax.legend() 18 | ax.set_title("Noisy observations from hidden trajectory") 19 | plt.axis("equal") 20 | return fig 21 | 22 | 23 | def plot_inference(sample_obs, mean_hist): 24 | fig, ax = plt.subplots() 25 | ax.scatter(*sample_obs.T, marker="+", color="tab:green", s=60) 26 | ax.plot(*mean_hist.T, c="tab:orange", label="filtered") 27 | ax.scatter(*mean_hist[0], c="black", zorder=3) 28 | plt.legend() 29 | plt.axis("equal") 30 | return fig 31 | 32 | def main(): 33 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]) 34 | def fx(x): return x 35 | 36 | dt = 0.4 37 | nsteps = 100 38 | # Initial state vector 39 | x0 = jnp.array([1.5, 0.0]) 40 | # State noise 41 | Qt = jnp.eye(2) * 0.001 42 | # Observed noise 43 | Rt = jnp.eye(2) * 0.05 44 | 45 | key = random.PRNGKey(314) 46 | model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt) 47 | sample_state, sample_obs = model.sample(key, x0, nsteps) 48 | 49 | n_particles = 3_000 50 | fz_vec = jax.vmap(fz, in_axes=(0, None)) 51 | particle_filter = NLDS(lambda x: fz_vec(x, dt), fx, Qt, Rt) 52 | pf_mean = filter(particle_filter, key, x0, sample_obs, n_particles) 53 | 54 | dict_figures = {} 55 | fig_boostrap = plot_inference(sample_obs, pf_mean) 56 | dict_figures["nlds2d_bootstrap"] = fig_boostrap 57 | 58 | fig_data = plot_samples(sample_state, sample_obs) 59 | dict_figures["nlds2d_data"] = fig_data 60 | 61 | return dict_figures 62 | 63 | if __name__ == "__main__": 64 | from jsl.demos.plot_utils import savefig 65 | plt.rcParams["axes.spines.right"] = False 66 | plt.rcParams["axes.spines.top"] = False 67 | dict_figures = main() 68 | savefig(dict_figures) 69 | plt.show() 70 | -------------------------------------------------------------------------------- /jsl/demos/bootstrap_filter_maneuver.py: -------------------------------------------------------------------------------- 1 | # Bootstrap filtering for jump markov linear systems 2 | # Compare to result in the Rao-Blackwelli particle filter 3 | # (see https://github.com/probml/pyprobml/blob/master/scripts/rbpf_maneuver_demo.py) 4 | 5 | # !pip install matplotlib==3.4.2 6 | 7 | 8 | import jax 9 | import numpy as np 10 | import jax.numpy as jnp 11 | import matplotlib.pyplot as plt 12 | import jsl.lds.mixture_kalman_filter as kflib 13 | from jax import random 14 | from functools import partial 15 | from jax.scipy.special import logit 16 | from jax.scipy.stats import multivariate_normal 17 | from jsl.demos.plot_utils import kdeg, style3d 18 | 19 | 20 | def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5): 21 | nsteps = len(mu_hist) 22 | xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max() 23 | xrange = jnp.linspace(xmin, xmax, npoints).reshape(-1, 1) 24 | res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist) 25 | densities = res[..., dim] 26 | for t in range(0, nsteps, skip): 27 | tloc = t * np.ones(npoints) 28 | 
px = densities[t] 29 | ax.plot(tloc, xrange, px, c="tab:blue", linewidth=1) 30 | ax.set_zlim(0, 1) 31 | style3d(ax, 1.8, 1.2, 0.7, 0.8) 32 | ax.view_init(elevation, azimuth) 33 | ax.set_xlabel(r"$t$", fontsize=13) 34 | ax.set_ylabel(r"$x_{"f"d={dim}"",t}$", fontsize=13) 35 | ax.set_zlabel(r"$p(x_{d, t} \vert y_{1:t})$", fontsize=13) 36 | 37 | 38 | def bootstrap(state, y, A, B, C, Q, transition_matrix, nparticles): 39 | latent_t, state_t, key = state 40 | key_latent, key_state, key_reindex, key_next = random.split(key, 4) 41 | 42 | # Discrete states 43 | latent_t = random.categorical(key_latent, jnp.log(transition_matrix[latent_t]), shape=(nparticles,)) 44 | # Continous states 45 | state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t] 46 | state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q) 47 | 48 | # Compute weights 49 | weights_t = multivariate_normal.pdf(y, mean=state_t, cov=C) 50 | indices_t = random.categorical(key_reindex, jnp.log(weights_t), shape=(nparticles,)) 51 | 52 | # Reindex and compute weights 53 | state_t = state_t[indices_t, ...] 54 | latent_t = latent_t[indices_t, ...] 55 | # weights_t = jnp.ones(nparticles) / nparticles 56 | 57 | mu_t = state_t.mean(axis=0) 58 | 59 | return (latent_t, state_t, key_next), (mu_t, latent_t, state_t) 60 | 61 | 62 | def main(): 63 | TT = 0.1 64 | A = jnp.array([[1, TT, 0, 0], 65 | [0, 1, 0, 0], 66 | [0, 0, 1, TT], 67 | [0, 0, 0, 1]]) 68 | 69 | 70 | B1 = jnp.array([0, 0, 0, 0]) 71 | B2 = jnp.array([-1.225, -0.35, 1.225, 0.35]) 72 | B3 = jnp.array([1.225, 0.35, -1.225, -0.35]) 73 | B = jnp.stack([B1, B2, B3], axis=0) 74 | 75 | Q = 0.2 * jnp.eye(4) 76 | R = 10 * jnp.diag(jnp.array([2, 1, 2, 1])) 77 | C = jnp.eye(4) 78 | 79 | transition_matrix = jnp.array([ 80 | [0.9, 0.05, 0.05], 81 | [0.05, 0.9, 0.05], 82 | [0.05, 0.05, 0.9] 83 | ]) 84 | 85 | transition_matrix = jnp.array([ 86 | [0.8, 0.1, 0.1], 87 | [0.1, 0.8, 0.1], 88 | [0.1, 0.1, 0.8] 89 | ]) 90 | 91 | # We sample from a Rao-Blackwell Kalman Filter 92 | params = kflib.RBPFParamsDiscrete(A, B, C, Q, R, transition_matrix) 93 | 94 | nparticles = 1000 95 | nsteps = 100 96 | key = random.PRNGKey(1) 97 | keys = random.split(key, nsteps) 98 | 99 | x0 = (1, random.multivariate_normal(key, jnp.zeros(4), jnp.eye(4))) 100 | draw_state_fixed = partial(kflib.draw_state, params=params) 101 | 102 | # Create target dataset 103 | _, (latent_hist, state_hist, obs_hist) = jax.lax.scan(draw_state_fixed, x0, keys) 104 | 105 | 106 | # ** Filtering process ** 107 | nparticles = 5000 108 | key_base = random.PRNGKey(31) 109 | key_mean_init, key_sample, key_state, key_next = random.split(key_base, 4) 110 | 111 | p_init = jnp.array([0.0, 1.0, 0.0]) 112 | mu_0 = jnp.zeros((nparticles, 4)) 113 | s0 = random.categorical(key_state, logit(p_init), shape=(nparticles,)) 114 | init_state = (s0, mu_0, key_base) 115 | 116 | def bootstrap_step(state, y): return bootstrap(state, y, A, B, C, Q, transition_matrix, nparticles) 117 | _, (mu_t_hist, latent_hist, state_hist_particles) = jax.lax.scan(bootstrap_step, init_state, obs_hist) 118 | 119 | estimated_track = mu_t_hist[:, [0, 2]] 120 | particles_track = state_hist_particles[..., [0, 2]] 121 | 122 | bf_mse = ((mu_t_hist - state_hist)[:, [0, 2]] ** 2).mean(axis=0).sum() 123 | color_dict = {0: "tab:green", 1: "tab:red", 2: "tab:blue"} 124 | latent_hist_est = latent_hist.mean(axis=1).round() 125 | color_states_est = [color_dict[state] for state in np.array(latent_hist_est)] 126 | 127 | dict_figures = {} 128 | fig, ax = plt.subplots() 129 | 
ax.scatter(*estimated_track.T, edgecolors=color_states_est, c="none", s=10) 130 | ax.set_title(f"Bootstrap Filter MSE: {bf_mse:.2f}") 131 | dict_figures["bootstrap-filter-trace"] = fig 132 | 133 | for dim in range(2): 134 | fig = plt.figure() 135 | ax = plt.axes(projection="3d") 136 | plot_3d_belief_state(particles_track, dim, ax, h=1.1, npoints=1000) 137 | ax.autoscale(enable=False, axis='both') 138 | # pml.savefig(f"bootstrap-filter-belief-states-dim{dim}.pdf", pad_inches=0, bbox_inches="tight") 139 | dict_figures[f"bootstrap-filter-belief-states-dim{dim}"] = fig 140 | 141 | return dict_figures 142 | 143 | if __name__ == "__main__": 144 | from jsl.demos.plot_utils import savefig 145 | plt.rcParams["axes.spines.right"] = False 146 | plt.rcParams["axes.spines.top"] = False 147 | dict_figures = main() 148 | savefig(dict_figures, pad_inches=0, bbox_inches="tight") 149 | plt.show() -------------------------------------------------------------------------------- /jsl/demos/eekf_logistic_regression.py: -------------------------------------------------------------------------------- 1 | # Online learning of a 2d binary logistic regression model p(y=1|x,w) = sigmoid(w'x), 2 | # using the Exponential-family Extended Kalman Filter (EEKF) algorithm 3 | # described in "Online natural gradient as a Kalman filter", Y. Ollivier, 2018. 4 | # https://projecteuclid.org/euclid.ejs/1537257630. 5 | 6 | # The latent state corresponds to the current estimate of the regression weights w. 7 | # The observation model has the form 8 | # p(y(t) | w(t), x(t)) propto Gauss(y(t) | h_t(w(t)), R(t)) 9 | # where h_t(w) = sigmoid(w' * x(t)) = p(t) and R(t) = p(t) * (1-p(t)) 10 | 11 | import jax.numpy as jnp 12 | import matplotlib.pyplot as plt 13 | from jax import random 14 | 15 | from jsl.nlds.base import NLDS 16 | from jsl.nlds.extended_kalman_filter import filter 17 | 18 | # Import data and baseline solution 19 | from jsl.demos import logreg_biclusters as demo 20 | 21 | figures, data = demo.main() 22 | X = data["X"] 23 | y = data["y"] 24 | Phi = data["Phi"] 25 | Xspace = data["Xspace"] 26 | Phispace = data["Phispace"] 27 | w_laplace = data["w_laplace"] 28 | 29 | 30 | # jax.config.update("jax_platform_name", "cpu") 31 | # jax.config.update("jax_enable_x64", True) 32 | 33 | def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x)) 34 | 35 | 36 | def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z)) 37 | 38 | 39 | def fz(x): return x 40 | 41 | 42 | def fx(w, x): return sigmoid(w[None, :] @ x) 43 | 44 | 45 | def Rt(w, x): return (sigmoid(w @ x) * (1 - sigmoid(w @ x)))[None, None] 46 | 47 | 48 | def main(): 49 | N, M = Phi.shape 50 | n_datapoints, ndims = Phi.shape 51 | 52 | # Predictive domain 53 | xmin, ymin = X.min(axis=0) - 0.1 54 | xmax, ymax = X.max(axis=0) + 0.1 55 | step = 0.1 56 | Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step] 57 | _, nx, ny = Xspace.shape 58 | Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace]) 59 | 60 | ### EEKF Approximation 61 | mu_t = jnp.zeros(M) 62 | Pt = jnp.eye(M) * 0.0 63 | P0 = jnp.eye(M) * 2.0 64 | 65 | model = NLDS(fz, fx, Pt, Rt) 66 | (w_eekf, P_eekf), eekf_hist = filter(model, mu_t, y, Phi, P0, return_params=["mean", "cov"]) 67 | w_eekf_hist = eekf_hist["mean"] 68 | P_eekf_hist = eekf_hist["cov"] 69 | 70 | ### *** Ploting surface predictive distribution *** 71 | colors = ["black" if el else "white" for el in y] 72 | dict_figures = {} 73 | key = random.PRNGKey(31415) 74 | nsamples = 5000 75 | 76 | # EEKF surface predictive distribution 77 | eekf_samples = 
random.multivariate_normal(key, w_eekf, P_eekf, (nsamples,)) 78 | Z_eekf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, eekf_samples)) 79 | Z_eekf = Z_eekf.mean(axis=0) 80 | 81 | fig_eekf, ax = plt.subplots() 82 | title = "EEKF Predictive Distribution" 83 | demo.plot_posterior_predictive(ax, X, Xspace, Z_eekf, title, colors) 84 | dict_figures["logistic_regression_surface_eekf"] = fig_eekf 85 | 86 | ### Plot EEKF and Laplace training history 87 | P_eekf_hist_diag = jnp.diagonal(P_eekf_hist, axis1=1, axis2=2) 88 | # P_laplace_diag = jnp.sqrt(jnp.diagonal(SN)) 89 | lcolors = ["black", "tab:blue", "tab:red"] 90 | elements = w_eekf_hist.T, P_eekf_hist_diag.T, w_laplace, lcolors 91 | timesteps = jnp.arange(n_datapoints) + 1 92 | 93 | for k, (wk, Pk, wk_laplace, c) in enumerate(zip(*elements)): 94 | fig_weight_k, ax = plt.subplots() 95 | ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (EEKF)") 96 | ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3) 97 | 98 | ax.set_xlim(1, n_datapoints) 99 | ax.legend(framealpha=0.7, loc="upper right") 100 | ax.set_xlabel("number samples") 101 | ax.set_ylabel("weights") 102 | plt.tight_layout() 103 | dict_figures[f"logistic_regression_hist_ekf_w{k}"] = fig_weight_k 104 | 105 | print("EEKF weights") 106 | print(w_eekf, end="\n" * 2) 107 | 108 | return dict_figures 109 | 110 | 111 | if __name__ == "__main__": 112 | from jsl.demos.plot_utils import savefig 113 | 114 | figs = main() 115 | savefig(figs) 116 | plt.show() 117 | -------------------------------------------------------------------------------- /jsl/demos/ekf_continuous.py: -------------------------------------------------------------------------------- 1 | # Example of an Extended Kalman Filter using 2 | # a figure-8 nonlinear dynamical system. 
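# (Added note, not in the original demo: concretely, the dynamics in `fz` below follow the ODE
#      dx/dt = y,    dy/dt = x - x**3,
#  an undamped double-well (Duffing-type) oscillator whose phase-space separatrix
#  forms the figure-8 shape referred to above.)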
3 | # For futher reference and examples see: 4 | # * Section on EKFs in PML vol2 book 5 | # * https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python/blob/master/11-Extended-Kalman-Filters.ipynb 6 | # * Nonlinear Dynamics and Chaos - Steven Strogatz 7 | 8 | from jsl.demos import plot_utils 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import jax.numpy as jnp 12 | from jax import random 13 | 14 | from jsl.nlds.base import NLDS 15 | from jsl.nlds.continuous_extended_kalman_filter import estimate 16 | 17 | 18 | def fz(x): 19 | x, y = x 20 | return jnp.asarray([y, x - x ** 3]) 21 | 22 | 23 | def fx(x): 24 | x, y = x 25 | return jnp.asarray([x, y]) 26 | 27 | 28 | def main(): 29 | dt = 0.01 30 | T = 7 31 | nsamples = 70 32 | x0 = jnp.array([0.5, -0.75]) 33 | 34 | # State noise 35 | Qt = jnp.eye(2) * 0.001 36 | # Observed noise 37 | Rt = jnp.eye(2) * 0.01 38 | 39 | key = random.PRNGKey(314) 40 | ekf = NLDS(fz, fx, Qt, Rt) 41 | sample_state, sample_obs, jump = ekf.sample(key, x0, T, nsamples) 42 | mu_hist, V_hist = estimate(ekf, sample_state, sample_obs, jump, dt) 43 | 44 | vmin, vmax, step = -1.5, 1.5 + 0.5, 0.5 45 | X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1] 46 | X_dot = jnp.apply_along_axis(fz, 0, X) 47 | 48 | dict_figures = {} 49 | 50 | fig, ax = plt.subplots() 51 | ax.plot(*sample_state.T, label="state space") 52 | ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations") 53 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa") 54 | ax.legend() 55 | plt.axis("equal") 56 | ax.set_title("State Space") 57 | dict_figures["ekf-state-space"] = fig 58 | 59 | fig, ax = plt.subplots() 60 | ax.plot(*sample_state.T, c="tab:orange", label="EKF estimation") 61 | ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations") 62 | ax.scatter(*mu_hist[0], c="black", zorder=3) 63 | for mut, Vt in zip(mu_hist[::4], V_hist[::4]): 64 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3) 65 | plt.legend() 66 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa") 67 | ax.legend() 68 | plt.axis("equal") 69 | ax.set_title("Approximate Space") 70 | dict_figures["ekf-estimated-space"] = fig 71 | 72 | return dict_figures 73 | 74 | 75 | if __name__ == "__main__": 76 | from jsl.demos.plot_utils import savefig 77 | 78 | plt.rcParams["axes.spines.right"] = False 79 | plt.rcParams["axes.spines.top"] = False 80 | dict_figures = main() 81 | savefig(dict_figures) 82 | plt.show() 83 | -------------------------------------------------------------------------------- /jsl/demos/ekf_mlp.py: -------------------------------------------------------------------------------- 1 | # Demo showcasing the training of an MLP with a single hidden layer using 2 | # Extended Kalman Filtering (EKF). 3 | # In this demo, we consider the latent state to be the weights of an MLP. 4 | # The observed state at time t is the output of the MLP as influenced by the weights 5 | # at time t-1 and the covariate x[t]. 6 | # The transition function between latent states is the identity function. 
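# (Added sketch, not in the original demo, using this file's notation: the assumed state-space model is
#      w[t] = w[t-1]                                   (identity transition on the flattened MLP weights)
#      y[t] = MLP(w[t-1], x[t]) + eps[t],   eps[t] ~ N(0, sigma_y**2)
#  so the EKF linearizes the MLP output around the current weight mean at each step,
#  which is what the `ekf_lib.filter` call inside `main()` below carries out.)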
7 | # For more information, see 8 | # * Neural Network Training Using Unscented and Extended Kalman Filter 9 | # https://juniperpublishers.com/raej/RAEJ.MS.ID.555568.php 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import flax.linen as nn 14 | import matplotlib.pyplot as plt 15 | import jsl.nlds.extended_kalman_filter as ekf_lib 16 | from jax.flatten_util import ravel_pytree 17 | from functools import partial 18 | from typing import Sequence 19 | from jsl.nlds.base import NLDS 20 | 21 | class MLP(nn.Module): 22 | features: Sequence[int] 23 | 24 | @nn.compact 25 | def __call__(self, x): 26 | for feat in self.features[:-1]: 27 | x = nn.relu(nn.Dense(feat)(x)) 28 | x = nn.Dense(self.features[-1])(x) 29 | return x 30 | 31 | 32 | def apply(flat_params, x, model, unflatten_fn): 33 | """ 34 | Multilayer Perceptron (MLP) with a single hidden unit and 35 | tanh activation function. The input-unit and the 36 | output-unit are both assumed to be unidimensional 37 | 38 | Parameters 39 | ---------- 40 | W: array(2 * n_hidden + n_hidden + 1) 41 | Unravelled weights of the MLP 42 | x: array(1,) 43 | Singleton element to evaluate the MLP 44 | n_hidden: int 45 | Number of hidden units 46 | 47 | Returns 48 | ------- 49 | * array(1,) 50 | Evaluation of MLP at the specified point 51 | """ 52 | params = unflatten_fn(flat_params) 53 | return model.apply(params, x) 54 | 55 | 56 | def sample_observations(key, f, n_obs, xmin, xmax, x_noise=0.1, y_noise=3.0): 57 | key_x, key_y, key_shuffle = jax.random.split(key, 3) 58 | x_noise = jax.random.normal(key_x, (n_obs,)) * x_noise 59 | y_noise = jax.random.normal(key_y, (n_obs,)) * y_noise 60 | x = jnp.linspace(xmin, xmax, n_obs) + x_noise 61 | y = f(x) + y_noise 62 | X = jnp.c_[x, y] 63 | 64 | shuffled_ixs = jax.random.permutation(key_shuffle, jnp.arange(n_obs)) 65 | X, y = jnp.array(X[shuffled_ixs, :].T) 66 | return X, y 67 | 68 | 69 | def plot_mlp_prediction(key, xobs, yobs, xtest, fw, w, Sw, ax, n_samples=100): 70 | W_samples = jax.random.multivariate_normal(key, w, Sw, (n_samples,)) 71 | sample_yhat = fw(W_samples, xtest[:, None]) 72 | for sample in sample_yhat: # sample curves 73 | ax.plot(xtest, sample, c="tab:gray", alpha=0.07) 74 | ax.plot(xtest, sample_yhat.mean(axis=0)) # mean of posterior predictive 75 | ax.scatter(xobs, yobs, s=14, c="none", edgecolor="black", label="observations", alpha=0.5) 76 | ax.set_xlim(xobs.min(), xobs.max()) 77 | 78 | 79 | def plot_intermediate_steps(key, ax, fwd_func, intermediate_steps, xtest, mu_hist, Sigma_hist, x, y): 80 | """ 81 | Plot the intermediate steps of the training process, all of them in the same plot 82 | but in different subplots. 83 | """ 84 | for step, axi in zip(intermediate_steps, ax.flatten()): 85 | W_step, SW_step = mu_hist[step], Sigma_hist[step] 86 | x_step, y_step = x[:step], y[:step] 87 | plot_mlp_prediction(key, x_step, y_step, xtest, fwd_func, W_step, SW_step, axi) 88 | axi.set_title(f"step={step}") 89 | plt.tight_layout() 90 | 91 | 92 | def plot_intermediate_steps_single(key, method, fwd_func, intermediate_steps, xtest, mu_hist, Sigma_hist, x, y): 93 | """ 94 | Plot the intermediate steps of the training process, each one in a different plot. 
95 | """ 96 | figures = {} 97 | for step in intermediate_steps: 98 | W_step, SW_step = mu_hist[step], Sigma_hist[step] 99 | x_step, y_step = x[:step], y[:step] 100 | fig_step, axi = plt.subplots() 101 | plot_mlp_prediction(key, x_step, y_step, xtest, fwd_func, W_step, SW_step, axi) 102 | axi.set_title(f"step={step}") 103 | plt.tight_layout() 104 | figname = f"{method}-mlp-step-{step}" 105 | figures[figname] = fig_step 106 | return figures 107 | 108 | 109 | def f(x): 110 | return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3 111 | 112 | 113 | def fz(W): 114 | return W 115 | 116 | 117 | def main(): 118 | key = jax.random.PRNGKey(314) 119 | key_sample_obs, key_weights, key_init = jax.random.split(key, 3) 120 | 121 | all_figures = {} 122 | 123 | # *** MLP configuration *** 124 | n_hidden = 6 125 | n_out = 1 126 | n_in = 1 127 | model = MLP([n_hidden, n_out]) 128 | 129 | batch_size = 20 130 | batch = jnp.ones((batch_size, n_in)) 131 | 132 | variables = model.init(key_init, batch) 133 | W0, unflatten_fn = ravel_pytree(variables) 134 | 135 | 136 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn) 137 | # vectorised for multiple observations 138 | fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0]) 139 | # vectorised for multiple observations and weights 140 | fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None]) 141 | 142 | # *** Generating training and test data *** 143 | n_obs = 200 144 | xmin, xmax = -3, 3 145 | sigma_y = 3.0 146 | x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y) 147 | xtest = jnp.linspace(x.min(), x.max(), n_obs) 148 | 149 | # *** MLP Training with EKF *** 150 | n_params = W0.size 151 | W0 = jax.random.normal(key_weights, (n_params,)) * 1 # initial random guess 152 | Q = jnp.eye(n_params) * 1e-4 # parameters do not change 153 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed 154 | Vinit = jnp.eye(n_params) * 100 # vague prior 155 | 156 | ekf = NLDS(fz, fwd_mlp, Q, R) 157 | (W_ekf, SW_ekf), hist_ekf = ekf_lib.filter(ekf, W0, y[:, None], x[:, None], Vinit, return_params=["mean", "cov"]) 158 | ekf_mu_hist, ekf_Sigma_hist = hist_ekf["mean"], hist_ekf["cov"] 159 | 160 | # Plot final performance 161 | fig, ax = plt.subplots() 162 | plot_mlp_prediction(key, x, y, xtest, fwd_mlp_obs_weights, W_ekf, SW_ekf, ax) 163 | ax.set_title("EKF + MLP") 164 | all_figures["ekf-mlp"] = fig 165 | 166 | # Plot intermediate performance 167 | intermediate_steps = [10, 20, 30, 40, 50, 60] 168 | fig, ax = plt.subplots(2, 2) 169 | plot_intermediate_steps(key, ax, fwd_mlp_obs_weights, intermediate_steps, xtest, ekf_mu_hist, ekf_Sigma_hist, x, y) 170 | plt.suptitle("EKF + MLP training") 171 | all_figures["ekf-mlp-intermediate"] = fig 172 | figures_intermediates = plot_intermediate_steps_single(key, "ekf", fwd_mlp_obs_weights, 173 | intermediate_steps, xtest, ekf_mu_hist, ekf_Sigma_hist, x, y) 174 | all_figures = {**all_figures, **figures_intermediates} 175 | return all_figures 176 | 177 | 178 | if __name__ == "__main__": 179 | from jsl.demos.plot_utils import savefig 180 | 181 | plt.rcParams["axes.spines.right"] = False 182 | plt.rcParams["axes.spines.top"] = False 183 | figures = main() 184 | savefig(figures) 185 | plt.show() 186 | -------------------------------------------------------------------------------- /jsl/demos/ekf_mlp_anim.py: -------------------------------------------------------------------------------- 1 | # Example showcasing the learning process of the EKF algorithm. 
2 | # This demo is based on the ekf_mlp_anim_demo.py demo. 3 | # The animation script produces this video. 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax.flatten_util import ravel_pytree 8 | from jax.random import PRNGKey, split, normal, multivariate_normal 9 | 10 | import matplotlib.pyplot as plt 11 | import matplotlib.animation as animation 12 | from functools import partial 13 | 14 | from jsl.demos.ekf_mlp import MLP, apply, sample_observations 15 | from jsl.nlds.base import NLDS 16 | from jsl.nlds.extended_kalman_filter import filter 17 | 18 | 19 | def main(fx, fz, filepath): 20 | key = PRNGKey(314) 21 | key_sample_obs, key_weights, key_init = split(key, 3) 22 | 23 | # *** MLP configuration *** 24 | n_hidden = 6 25 | n_out = 1 26 | n_in = 1 27 | model = MLP([n_hidden, n_out]) 28 | 29 | batch_size = 20 30 | batch = jnp.ones((batch_size, n_in)) 31 | 32 | variables = model.init(key_init, batch) 33 | W0, unflatten_fn = ravel_pytree(variables) 34 | 35 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn) 36 | # vectorised for multiple observations 37 | fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0]) 38 | # vectorised for multiple observations and weights 39 | fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None]) 40 | 41 | # *** Generating training and test data *** 42 | n_obs = 200 43 | xmin, xmax = -3, 3 44 | sigma_y = 3.0 45 | x, y = sample_observations(key_sample_obs, fx, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y) 46 | xtest = jnp.linspace(x.min(), x.max(), n_obs) 47 | 48 | # *** MLP Training with EKF *** 49 | n_params = W0.size 50 | W0 = normal(key_weights, (n_params,)) * 1 # initial random guess 51 | Q = jnp.eye(n_params) * 1e-4 # parameters do not change 52 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed 53 | Vinit = jnp.eye(n_params) * 100 # vague prior 54 | 55 | ekf = NLDS(fz, fwd_mlp, Q, R) 56 | _, ekf_hist = filter(ekf, W0, y[:, None], x[:, None], Vinit, return_params=["mean", "cov"]) 57 | ekf_mu_hist, ekf_Sigma_hist = ekf_hist["mean"], ekf_hist["cov"] 58 | 59 | xtest = jnp.linspace(x.min(), x.max(), 200) 60 | fig, ax = plt.subplots() 61 | 62 | def func(i): 63 | plt.cla() 64 | W, SW = ekf_mu_hist[i], ekf_Sigma_hist[i] 65 | W_samples = multivariate_normal(key, W, SW, (100,)) 66 | sample_yhat = fwd_mlp_obs_weights(W_samples, xtest[:, None]) 67 | for sample in sample_yhat: 68 | ax.plot(xtest, sample, c="tab:gray", alpha=0.07) 69 | ax.plot(xtest, sample_yhat.mean(axis=0)) 70 | ax.scatter(x[:i], y[:i], s=14, c="none", edgecolor="black", label="observations") 71 | ax.scatter(x[i], y[i], s=30, c="tab:red") 72 | ax.set_title(f"EKF+MLP ({i + 1:03}/{n_obs})") 73 | ax.set_xlim(x.min(), x.max()) 74 | ax.set_ylim(y.min(), y.max()) 75 | 76 | return ax 77 | 78 | ani = animation.FuncAnimation(fig, func, frames=n_obs) 79 | ani.save(filepath, dpi=200, bitrate=-1, fps=10) 80 | 81 | 82 | if __name__ == "__main__": 83 | import os 84 | 85 | plt.rcParams["axes.spines.right"] = False 86 | plt.rcParams["axes.spines.top"] = False 87 | 88 | path = os.environ.get("FIGDIR") 89 | path = "." 
if path is None else path 90 | filepath = os.path.join(path, "samples_hist_ekf.mp4") 91 | 92 | def f(x): return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3 93 | def fz(W): return W 94 | main(f, fz, filepath) 95 | 96 | print(f"Saved animation to {filepath}") 97 | -------------------------------------------------------------------------------- /jsl/demos/ekf_vs_eks_spiral.py: -------------------------------------------------------------------------------- 1 | # Compare extended Kalman filter with unscented kalman filter 2 | # on a nonlinear 2d tracking problem 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import jsl.nlds.extended_kalman_smoother as eks 6 | from jax import random 7 | from jsl.demos import plot_utils 8 | from jsl.nlds.base import NLDS 9 | 10 | 11 | def plot_data(sample_state, sample_obs): 12 | fig, ax = plt.subplots() 13 | ax.plot(*sample_state.T, label="state space") 14 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+") 15 | ax.scatter(*sample_state[0], c="black", zorder=3) 16 | ax.legend() 17 | ax.set_title("Noisy observations from hidden trajectory") 18 | plt.axis("equal") 19 | return fig, ax 20 | 21 | 22 | def plot_inference(sample_obs, mean_hist, Sigma_hist, label): 23 | fig, ax = plt.subplots() 24 | ax.scatter(*sample_obs.T, marker="+", color="tab:green") 25 | ax.plot(*mean_hist.T, c="tab:orange", label=label) 26 | ax.scatter(*mean_hist[0], c="black", zorder=3) 27 | plt.legend() 28 | collection = [(mut, Vt) for mut, Vt in zip(mean_hist[::10], Sigma_hist[::10]) 29 | if Vt[0, 0] > 0 and Vt[1, 1] > 0 and abs(Vt[1, 0] - Vt[0, 1]) < 7e-4] 30 | for mut, Vt in collection: 31 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3) 32 | plt.scatter(*mut, c="black", zorder=3, s=5) 33 | plt.axis("equal") 34 | return fig, ax 35 | 36 | 37 | def main(): 38 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]) 39 | 40 | def fx(x, *args): return x 41 | 42 | dt = 0.1 43 | nsteps = 300 44 | # Initial state vector 45 | x0 = jnp.array([1.5, 0.0]) 46 | x0 = jnp.array([1.053, 0.145]) 47 | state_size, *_ = x0.shape 48 | # State noise 49 | Qt = jnp.eye(state_size) * 0.001 50 | # Observed noise 51 | Rt = jnp.eye(2) * 0.05 52 | 53 | key = random.PRNGKey(31415) 54 | model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt) 55 | sample_state, sample_obs = model.sample(key, x0, nsteps) 56 | 57 | # _, ekf_hist = ekf_lib.filter(ekf_model, x0, sample_obs, return_params=["mean", "cov"]) 58 | hist = eks.smooth(model, x0, sample_obs, return_params=["mean", "cov"], 59 | return_filter_history=True) 60 | eks_hist = hist["smooth"] 61 | ekf_hist = hist["filter"] 62 | 63 | eks_mean_hist = eks_hist["mean"] 64 | eks_Sigma_hist = eks_hist["cov"] 65 | 66 | ekf_mean_hist = ekf_hist["mean"] 67 | ekf_Sigma_hist = ekf_hist["cov"] 68 | 69 | dict_figures = {} 70 | # nlds2d_data 71 | fig_data, ax = plot_data(sample_state, sample_obs) 72 | dict_figures["nlds2d_data"] = fig_data 73 | 74 | # nlds2d_ekf 75 | fig_ekf, ax = plot_inference(sample_obs, ekf_mean_hist, ekf_Sigma_hist, 76 | label="filtered") 77 | ax.set_title("EKF") 78 | dict_figures["nlds2d_ekf"] = fig_ekf 79 | 80 | # nlds2d_eks 81 | fig_eks, ax = plot_inference(sample_obs, eks_mean_hist, eks_Sigma_hist, 82 | label="smoothed") 83 | ax.set_title("EKS") 84 | dict_figures["nlds2d_eks"] = fig_eks 85 | 86 | return dict_figures 87 | 88 | 89 | if __name__ == "__main__": 90 | from jsl.demos.plot_utils import savefig 91 | 92 | dict_figures = main() 93 | savefig(dict_figures) 94 | plt.show() 95 | 
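# (Added illustrative sketch, not part of the original demos: one simple way to quantify the
#  filtering-vs-smoothing difference in ekf_vs_eks_spiral.py above is to compare the mean squared
#  error of the filtered and smoothed means against the sampled latent trajectory. The code below
#  mirrors main() in that file; the NLDS and eks.smooth signatures are assumed from it.)
import jax.numpy as jnp
from jax import random
import jsl.nlds.extended_kalman_smoother as eks
from jsl.nlds.base import NLDS

def fz(x, dt=0.1): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
def fx(x, *args): return x

x0 = jnp.array([1.053, 0.145])
model = NLDS(fz, fx, jnp.eye(2) * 0.001, jnp.eye(2) * 0.05)  # state noise Qt, observation noise Rt
states, obs = model.sample(random.PRNGKey(31415), x0, 300)

hist = eks.smooth(model, x0, obs, return_params=["mean", "cov"], return_filter_history=True)
ekf_mse = jnp.mean((hist["filter"]["mean"] - states) ** 2)   # filtered posterior means
eks_mse = jnp.mean((hist["smooth"]["mean"] - states) ** 2)   # smoothed posterior means
print(f"EKF (filtered) MSE: {ekf_mse:.4f} | EKS (smoothed) MSE: {eks_mse:.4f}")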
-------------------------------------------------------------------------------- /jsl/demos/ekf_vs_ukf_spiral.py: -------------------------------------------------------------------------------- 1 | # Compare extended Kalman filter with unscented kalman filter on a nonlinear 2d tracking problem 2 | from jax import random 3 | 4 | import matplotlib.pyplot as plt 5 | import jax.numpy as jnp 6 | 7 | import jsl.nlds.extended_kalman_filter as ekf_lib 8 | import jsl.nlds.unscented_kalman_filter as ukf_lib 9 | from jsl.demos import plot_utils 10 | from jsl.nlds.base import NLDS 11 | 12 | 13 | def check_symmetric(a, rtol=1.1): 14 | return jnp.allclose(a, a.T, rtol=rtol) 15 | 16 | 17 | def plot_data(sample_state, sample_obs): 18 | fig, ax = plt.subplots() 19 | ax.plot(*sample_state.T, label="state space") 20 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+") 21 | ax.scatter(*sample_state[0], c="black", zorder=3) 22 | ax.legend() 23 | ax.set_title("Noisy observations from hidden trajectory") 24 | plt.axis("equal") 25 | return fig, ax 26 | 27 | 28 | def plot_inference(sample_obs, mean_hist, Sigma_hist): 29 | fig, ax = plt.subplots() 30 | ax.scatter(*sample_obs.T, marker="+", color="tab:green") 31 | ax.plot(*mean_hist.T, c="tab:orange", label="filtered") 32 | ax.scatter(*mean_hist[0], c="black", zorder=3) 33 | plt.legend() 34 | collection = [(mut, Vt) for mut, Vt in zip(mean_hist[::4], Sigma_hist[::4]) 35 | if Vt[0, 0] > 0 and Vt[1, 1] > 0 and abs(Vt[1, 0] - Vt[0, 1]) < 7e-4] 36 | for mut, Vt in collection: 37 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3) 38 | plt.axis("equal") 39 | return fig, ax 40 | 41 | 42 | def main(): 43 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]) 44 | 45 | def fx(x, *args): return x 46 | 47 | dt = 0.4 48 | nsteps = 100 49 | # Initial state vector 50 | x0 = jnp.array([1.5, 0.0]) 51 | state_size, *_ = x0.shape 52 | # State noise 53 | Qt = jnp.eye(state_size) * 0.001 54 | # Observed noise 55 | Rt = jnp.eye(2) * 0.05 56 | alpha, beta, kappa = 1, 0, 2 57 | 58 | key = random.PRNGKey(31415) 59 | ekf_model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt) 60 | sample_state, sample_obs = ekf_model.sample(key, x0, nsteps) 61 | 62 | ukf_model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt, 63 | alpha, beta, kappa, state_size) 64 | 65 | _, ekf_hist = ekf_lib.filter(ekf_model, x0, sample_obs, return_params=["mean", "cov"]) 66 | ukf_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf_model, x0, sample_obs) 67 | 68 | ekf_mean_hist = ekf_hist["mean"] 69 | ekf_Sigma_hist = ekf_hist["cov"] 70 | 71 | dict_figures = {} 72 | # nlds2d_data 73 | fig_data, ax = plot_data(sample_state, sample_obs) 74 | dict_figures["nlds2d_data"] = fig_data 75 | 76 | # nlds2d_ekf 77 | fig_ekf, ax = plot_inference(sample_obs, ekf_mean_hist, ekf_Sigma_hist) 78 | ax.set_title("EKF") 79 | dict_figures["nlds2d_ekf"] = fig_ekf 80 | 81 | # nlds2d_ukf 82 | fig_ukf, ax = plot_inference(sample_obs, ukf_mean_hist, ukf_Sigma_hist) 83 | ax.set_title("UKF") 84 | dict_figures["nlds2d_ukf"] = fig_ukf 85 | 86 | return dict_figures 87 | 88 | 89 | if __name__ == "__main__": 90 | from jsl.demos.plot_utils import savefig 91 | 92 | dict_figures = main() 93 | savefig(dict_figures) 94 | plt.show() 95 | -------------------------------------------------------------------------------- /jsl/demos/hmm_casino.py: -------------------------------------------------------------------------------- 1 | # Occasionally dishonest casino example [Durbin98, p54]. 
This script 2 | # exemplifies a Hidden Markov Model (HMM) in which the throw of a die 3 | # may result in the die being biased (towards 6) or unbiased. If the dice turns out to 4 | # be biased, the probability of remaining biased is high, and similarly for the unbiased state. 5 | # Assuming we observe the die being thrown n times the goal is to recover the periods in which 6 | # the die was biased. 7 | # Original matlab code: https://github.com/probml/pmtk3/blob/master/demos/casinoDemo.m 8 | 9 | 10 | # from jsl.hmm.hmm_numpy_lib import (HMMNumpy, hmm_sample_numpy, hmm_plot_graphviz, 11 | # hmm_forwards_backwards_numpy, hmm_viterbi_numpy) 12 | 13 | from jsl.hmm.hmm_utils import hmm_plot_graphviz 14 | 15 | import numpy as np 16 | import jax.numpy as jnp 17 | import matplotlib.pyplot as plt 18 | from jsl.hmm.hmm_lib import HMMJax, hmm_forwards_backwards_jax, hmm_sample_jax, hmm_viterbi_jax 19 | from jax.random import PRNGKey 20 | 21 | 22 | def find_dishonest_intervals(z_hist): 23 | """ 24 | Find the span of timesteps that the 25 | simulated systems turns to be in state 1 26 | Parameters 27 | ---------- 28 | z_hist: array(n_samples) 29 | Result of running the system with two 30 | latent states 31 | Returns 32 | ------- 33 | list of tuples with span of values 34 | """ 35 | spans = [] 36 | x_init = 0 37 | for t, _ in enumerate(z_hist[:-1]): 38 | if z_hist[t + 1] == 0 and z_hist[t] == 1: 39 | x_end = t 40 | spans.append((x_init, x_end)) 41 | elif z_hist[t + 1] == 1 and z_hist[t] == 0: 42 | x_init = t + 1 43 | return spans 44 | 45 | 46 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False): 47 | """ 48 | Plot the estimated smoothing/filtering/map of a sequence of hidden states. 49 | "Vertical gray bars denote times when the hidden 50 | state corresponded to state 1. Blue lines represent the 51 | posterior probability of being in that state given different subsets 52 | of observed data." 
See Markov and Hidden Markov models section for more info 53 | Parameters 54 | ---------- 55 | inference_values: array(n_samples, state_size) 56 | Result of runnig smoothing method 57 | z_hist: array(n_samples) 58 | Latent simulation 59 | ax: matplotlib.axes 60 | state: int 61 | Decide which state to highlight 62 | map_estimate: bool 63 | Whether to plot steps (simple plot if False) 64 | """ 65 | n_samples = len(inference_values) 66 | xspan = np.arange(1, n_samples + 1) 67 | spans = find_dishonest_intervals(z_hist) 68 | if map_estimate: 69 | ax.step(xspan, inference_values, where="post") 70 | else: 71 | ax.plot(xspan, inference_values[:, state]) 72 | 73 | for span in spans: 74 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none") 75 | ax.set_xlim(1, n_samples) 76 | # ax.set_ylim(0, 1) 77 | ax.set_ylim(-0.1, 1.1) 78 | ax.set_xlabel("Observation number") 79 | 80 | 81 | def main(): 82 | # state transition matrix 83 | A = jnp.array([ 84 | [0.95, 0.05], 85 | [0.10, 0.90] 86 | ]) 87 | 88 | # observation matrix 89 | B = jnp.array([ 90 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die 91 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die 92 | ]) 93 | 94 | n_samples = 300 95 | init_state_dist = jnp.array([1, 1]) / 2 96 | 97 | # hmm = HMM(A, B, init_state_dist) 98 | params = HMMJax(A, B, init_state_dist) 99 | 100 | seed = 0 101 | z_hist, x_hist = hmm_sample_jax(params, n_samples, PRNGKey(seed)) 102 | # z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314) 103 | 104 | z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60] 105 | x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60] 106 | 107 | print("Printing sample observed/latent...") 108 | print(f"x: {x_hist_str}") 109 | print(f"z: {z_hist_str}") 110 | 111 | # Do inference 112 | # alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist)) 113 | alpha, beta, gamma, loglik = hmm_forwards_backwards_jax(params, x_hist, len(x_hist)) 114 | print(f"Loglikelihood: {loglik}") 115 | 116 | # z_map = hmm_viterbi_numpy(params, x_hist) 117 | z_map = hmm_viterbi_jax(params, x_hist) 118 | 119 | dict_figures = {} 120 | 121 | # Plot results 122 | fig, ax = plt.subplots() 123 | plot_inference(alpha, z_hist, ax) 124 | ax.set_ylabel("p(loaded)") 125 | ax.set_title("Filtered") 126 | dict_figures["hmm_casino_filter"] = fig 127 | 128 | fig, ax = plt.subplots() 129 | plot_inference(gamma, z_hist, ax) 130 | ax.set_ylabel("p(loaded)") 131 | ax.set_title("Smoothed") 132 | dict_figures["hmm_casino_smooth"] = fig 133 | 134 | fig, ax = plt.subplots() 135 | plot_inference(z_map, z_hist, ax, map_estimate=True) 136 | ax.set_ylabel("MAP state") 137 | ax.set_title("Viterbi") 138 | 139 | dict_figures["hmm_casino_map"] = fig 140 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])] 141 | 142 | #AA = hmm.trans_dist.probs 143 | #assert np.allclose(A, AA) 144 | 145 | dotfile = hmm_plot_graphviz(A, B, states, observations) 146 | dotfile_dict = {"hmm_casino_graphviz": dotfile} 147 | 148 | return dict_figures, dotfile_dict 149 | 150 | 151 | if __name__ == "__main__": 152 | from jsl.demos.plot_utils import savefig, savedotfile 153 | figs, dotfile = main() 154 | 155 | savefig(figs) 156 | savedotfile(dotfile) 157 | plt.show() -------------------------------------------------------------------------------- /jsl/demos/hmm_casino_em_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This demo shows the parameter estimations of 
HMMs via Baulm-Welch algorithm on the occasionally dishonest casino example. 3 | Author : Aleyna Kara(@karalleyna) 4 | """ 5 | 6 | 7 | import time 8 | import numpy as np 9 | import jax.numpy as jnp 10 | import matplotlib.pyplot as plt 11 | from jax.random import split, PRNGKey 12 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_em_numpy 13 | from jsl.hmm.hmm_lib import HMMJax, hmm_sample_jax 14 | from jsl.hmm.hmm_lib import init_random_params_jax, hmm_em_jax 15 | from jsl.hmm import hmm_utils 16 | 17 | 18 | def main(): 19 | A = jnp.array([ 20 | [0.95, 0.05], 21 | [0.10, 0.90] 22 | ]) 23 | 24 | # observation matrix 25 | B = jnp.array([ 26 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die 27 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die 28 | ]) 29 | 30 | pi = jnp.array([1, 1]) / 2 31 | 32 | seed = 100 33 | rng_key = PRNGKey(seed) 34 | rng_key, rng_sample, rng_batch, rng_init = split(rng_key, 4) 35 | 36 | casino = HMMJax(A, B, pi) 37 | 38 | n_obs_seq, batch_size, max_len = 5, 5, 3000 39 | 40 | observations, lens = hmm_utils.hmm_sample_n(casino, 41 | hmm_sample_jax, 42 | n_obs_seq, max_len, 43 | rng_sample) 44 | observations, lens = hmm_utils.pad_sequences(observations, lens) 45 | 46 | n_hidden, n_obs = B.shape 47 | params_jax = init_random_params_jax([n_hidden, n_obs], rng_key=rng_init) 48 | params_numpy = HMMNumpy(np.array(params_jax.trans_mat), 49 | np.array(params_jax.obs_mat), 50 | np.array(params_jax.init_dist)) 51 | 52 | num_epochs = 20 53 | 54 | start = time.time() 55 | params_numpy, neg_ll_numpy = hmm_em_numpy(np.array(observations), 56 | np.array(lens), 57 | num_epochs=num_epochs, 58 | init_params=params_numpy) 59 | print(f"Time taken by numpy version of EM : {time.time() - start}s") 60 | 61 | start = time.time() 62 | params_jax, neg_ll_jax = hmm_em_jax(observations, 63 | lens, 64 | num_epochs=num_epochs, 65 | init_params=params_jax) 66 | print(f"Time taken by JAX version of EM : {time.time() - start}s") 67 | 68 | assert jnp.allclose(np.array(neg_ll_jax), np.array(neg_ll_numpy), 4) 69 | 70 | print(f" Negative loglikelihoods : {neg_ll_jax}") 71 | 72 | dict_figures = {} 73 | fig, ax = plt.subplots() 74 | ax.plot(neg_ll_numpy, label="EM numpy") 75 | ax.set_label("Number of Iterations") 76 | dict_figures["em-numpy"] = fig 77 | 78 | fig, ax = plt.subplots() 79 | ax.plot(neg_ll_jax, label="EM JAX") 80 | ax.set_xlabel("Number of Iterations") 81 | dict_figures["em-jax"] = fig 82 | 83 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])] 84 | # dotfile_np = hmm_plot_graphviz(params_numpy, "hmm_casino_train_np", states, observations) 85 | dotfile_np = hmm_utils.hmm_plot_graphviz(params_numpy.trans_mat, params_numpy.obs_mat, 86 | states, observations) 87 | 88 | # dotfile_jax = hmm_plot_graphviz(params_jax, "hmm_casino_train_jax", states, observations) 89 | dotfile_jax = hmm_utils.hmm_plot_graphviz(params_jax.trans_mat, params_jax.obs_mat, 90 | states, observations) 91 | 92 | dotfile_dict = {"graph-numpy": dotfile_np, "graph-jax": dotfile_jax} 93 | 94 | return dict_figures, dotfile_dict 95 | 96 | 97 | if __name__ == "__main__": 98 | from jsl.demos.plot_utils import savefig, savedotfile 99 | 100 | figs, dotfiles = main() 101 | savefig(figs) 102 | savedotfile(dotfiles) 103 | plt.show() 104 | -------------------------------------------------------------------------------- /jsl/demos/hmm_casino_numpy.py: -------------------------------------------------------------------------------- 1 | # Occasionally dishonest casino example 
[Durbin98, p54]. This script 2 | # exemplifies a Hidden Markov Model (HMM) in which the throw of a die 3 | # may result in the die being biased (towards 6) or unbiased. If the dice turns out to 4 | # be biased, the probability of remaining biased is high, and similarly for the unbiased state. 5 | # Assuming we observe the die being thrown n times the goal is to recover the periods in which 6 | # the die was biased. 7 | # Original matlab code: https://github.com/probml/pmtk3/blob/master/demos/casinoDemo.m 8 | 9 | from jsl.hmm.hmm_numpy_lib import (HMMNumpy, hmm_sample_numpy, 10 | hmm_forwards_backwards_numpy, hmm_viterbi_numpy) 11 | 12 | from jsl.hmm.hmm_utils import hmm_plot_graphviz 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def find_dishonest_intervals(z_hist): 19 | """ 20 | Find the span of timesteps that the 21 | simulated systems turns to be in state 1 22 | Parameters 23 | ---------- 24 | z_hist: array(n_samples) 25 | Result of running the system with two 26 | latent states 27 | Returns 28 | ------- 29 | list of tuples with span of values 30 | """ 31 | spans = [] 32 | x_init = 0 33 | for t, _ in enumerate(z_hist[:-1]): 34 | if z_hist[t + 1] == 0 and z_hist[t] == 1: 35 | x_end = t 36 | spans.append((x_init, x_end)) 37 | elif z_hist[t + 1] == 1 and z_hist[t] == 0: 38 | x_init = t + 1 39 | return spans 40 | 41 | 42 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False): 43 | """ 44 | Plot the estimated smoothing/filtering/map of a sequence of hidden states. 45 | "Vertical gray bars denote times when the hidden 46 | state corresponded to state 1. Blue lines represent the 47 | posterior probability of being in that state given different subsets 48 | of observed data." See Markov and Hidden Markov models section for more info 49 | Parameters 50 | ---------- 51 | inference_values: array(n_samples, state_size) 52 | Result of runnig smoothing method 53 | z_hist: array(n_samples) 54 | Latent simulation 55 | ax: matplotlib.axes 56 | state: int 57 | Decide which state to highlight 58 | map_estimate: bool 59 | Whether to plot steps (simple plot if False) 60 | """ 61 | n_samples = len(inference_values) 62 | xspan = np.arange(1, n_samples + 1) 63 | spans = find_dishonest_intervals(z_hist) 64 | if map_estimate: 65 | ax.step(xspan, inference_values, where="post") 66 | else: 67 | ax.plot(xspan, inference_values[:, state]) 68 | 69 | for span in spans: 70 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none") 71 | ax.set_xlim(1, n_samples) 72 | # ax.set_ylim(0, 1) 73 | ax.set_ylim(-0.1, 1.1) 74 | ax.set_xlabel("Observation number") 75 | 76 | 77 | def main(): 78 | # state transition matrix 79 | A = np.array([ 80 | [0.95, 0.05], 81 | [0.10, 0.90] 82 | ]) 83 | 84 | # observation matrix 85 | B = np.array([ 86 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die 87 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die 88 | ]) 89 | 90 | n_samples = 300 91 | init_state_dist = np.array([1, 1]) / 2 92 | params = HMMNumpy(A, B, init_state_dist) 93 | z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314) 94 | 95 | z_hist_str = "".join((z_hist + 1).astype(str))[:60] 96 | x_hist_str = "".join((x_hist + 1).astype(str))[:60] 97 | 98 | print("Printing sample observed/latent...") 99 | print(f"x: {x_hist_str}") 100 | print(f"z: {z_hist_str}") 101 | 102 | # Do inference 103 | alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist)) 104 | print(f"Loglikelihood: {loglik}") 105 | 106 | z_map = 
hmm_viterbi_numpy(params, x_hist) 107 | 108 | dict_figures = {} 109 | 110 | # Plot results 111 | fig, ax = plt.subplots() 112 | plot_inference(alpha, z_hist, ax) 113 | ax.set_ylabel("p(loaded)") 114 | ax.set_title("Filtered") 115 | dict_figures["hmm_casino_filter"] = fig 116 | 117 | fig, ax = plt.subplots() 118 | plot_inference(gamma, z_hist, ax) 119 | ax.set_ylabel("p(loaded)") 120 | ax.set_title("Smoothed") 121 | dict_figures["hmm_casino_smooth"] = fig 122 | 123 | fig, ax = plt.subplots() 124 | plot_inference(z_map, z_hist, ax, map_estimate=True) 125 | ax.set_ylabel("MAP state") 126 | ax.set_title("Viterbi") 127 | dict_figures["hmm_casino_map"] = fig 128 | 129 | file_name = "hmm_casino_params" 130 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])] 131 | dotfile = hmm_plot_graphviz(A, B, states, observations) 132 | # dotfile = hmm_plot_graphviz(params, file_name, states, observations) 133 | dotfile_dict = {"hmm_casino_graphviz": dotfile} 134 | 135 | return dict_figures, dotfile_dict 136 | 137 | 138 | if __name__ == "__main__": 139 | from jsl.demos.plot_utils import savefig, savedotfile 140 | 141 | figs, dotfile = main() 142 | savefig(figs) 143 | savedotfile(dotfile) 144 | plt.show() 145 | -------------------------------------------------------------------------------- /jsl/demos/hmm_casino_sgd_train: -------------------------------------------------------------------------------- 1 | // HMM 2 | digraph { 3 | s0 [label=<
State 1
Obs 1: 0.99
Obs 2: 0.01
>] 4 | s1 [label=<
State 2
Obs 1: 0.01
Obs 2: 0.99
>] 5 | s0 -> s0 [label=0.99] 6 | s0 -> s1 [label=0.01] 7 | s1 -> s0 [label=0.01] 8 | s1 -> s1 [label=0.99] 9 | rankdir=LR 10 | } 11 | -------------------------------------------------------------------------------- /jsl/demos/hmm_casino_sgd_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This demo does MAP estimation of an HMM using a gradient-descent algorithm applied to the log marginal likelihood. 3 | It includes 4 | 5 | 1. Mini Batch Gradient Descent 6 | 2. Full Batch Gradient Descent 7 | 3. Stochastic Gradient Descent 8 | 9 | Author: Aleyna Kara(@karalleyna) 10 | """ 11 | 12 | 13 | import matplotlib.pyplot as plt 14 | import jax.numpy as jnp 15 | from jax.example_libraries import optimizers 16 | from jax.random import split, PRNGKey 17 | from jsl.hmm.hmm_lib import fit 18 | from jsl.hmm.hmm_utils import pad_sequences, hmm_sample_n 19 | from jsl.hmm.hmm_lib import HMMJax, hmm_sample_jax 20 | from jsl.hmm.hmm_utils import hmm_plot_graphviz 21 | 22 | 23 | def main(): 24 | # state transition matrix 25 | A = jnp.array([ 26 | [0.95, 0.05], 27 | [0.10, 0.90]]) 28 | 29 | # observation matrix 30 | B = jnp.array([ 31 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die 32 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die 33 | ]) 34 | 35 | pi = jnp.array([1, 1]) / 2 36 | 37 | casino = HMMJax(A, B, pi) 38 | num_hidden, num_obs = 2, 6 39 | 40 | seed = 0 41 | rng_key = PRNGKey(seed) 42 | rng_key, rng_sample = split(rng_key) 43 | 44 | n_obs_seq, max_len = 4, 5000 45 | num_epochs = 400 46 | 47 | observations, lens = pad_sequences(*hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample)) 48 | optimizer = optimizers.momentum(step_size=1e-3, mass=0.95) 49 | 50 | # Mini Batch Gradient Descent 51 | batch_size = 2 52 | params_mbgd, losses_mbgd = fit(observations, 53 | lens, 54 | num_hidden, 55 | num_obs, 56 | batch_size, 57 | optimizer, 58 | rng_key=None, 59 | num_epochs=num_epochs) 60 | 61 | # Full Batch Gradient Descent 62 | batch_size = n_obs_seq 63 | params_fbgd, losses_fbgd = fit(observations, 64 | lens, 65 | num_hidden, 66 | num_obs, 67 | batch_size, 68 | optimizer, 69 | rng_key=None, 70 | num_epochs=num_epochs) 71 | 72 | # Stochastic Gradient Descent 73 | batch_size = 1 74 | params_sgd, losses_sgd = fit(observations, 75 | lens, 76 | num_hidden, 77 | num_obs, 78 | batch_size, 79 | optimizer, 80 | rng_key=None, 81 | num_epochs=num_epochs) 82 | 83 | losses = [losses_sgd, losses_mbgd, losses_fbgd] 84 | titles = ["Stochastic Gradient Descent", "Mini Batch Gradient Descent", "Full Batch Gradient Descent"] 85 | 86 | dict_figures = {} 87 | for loss, title in zip(losses, titles): 88 | filename = title.replace(" ", "_").lower() 89 | fig, ax = plt.subplots() 90 | ax.plot(loss) 91 | ax.set_title(f"{title}") 92 | dict_figures[filename] = fig 93 | dotfile = hmm_plot_graphviz(params_sgd.trans_mat, params_sgd.obs_mat) 94 | dotfile_dict = {"hmm-casino-dot": dotfile} 95 | 96 | return dict_figures, dotfile_dict 97 | 98 | 99 | if __name__ == "__main__": 100 | from jsl.demos.plot_utils import savefig, savedotfile 101 | 102 | figs, dotfile = main() 103 | savefig(figs) 104 | savedotfile(dotfile) 105 | plt.show() 106 | -------------------------------------------------------------------------------- /jsl/demos/hmm_lillypad.py: -------------------------------------------------------------------------------- 1 | # Example of an HMM with Gaussian emission in 2D 2 | # For a matlab version, see
https://github.com/probml/pmtk3/blob/master/demos/hmmLillypadDemo.m 3 | 4 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna) 5 | 6 | 7 | import logging 8 | import distrax 9 | import numpy as np 10 | import jax.numpy as jnp 11 | import matplotlib.pyplot as plt 12 | import tensorflow_probability as tfp 13 | from distrax import HMM 14 | from jax import vmap 15 | from jax.random import PRNGKey 16 | 17 | logging.getLogger('absl').setLevel(logging.CRITICAL) 18 | 19 | 20 | def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2): 21 | """ 22 | Plot the trajectory of a 2-dimensional HMM 23 | Parameters 24 | ---------- 25 | hmm : HMM 26 | Hidden Markov Model 27 | samples_obs: numpy.ndarray(n_samples, 2) 28 | Observations 29 | samples_state: numpy.ndarray(n_samples, ) 30 | Latent state of the system 31 | colors: list(int) 32 | List of colors for each latent state 33 | step: float 34 | Step size 35 | Returns 36 | ------- 37 | * matplotlib.axes 38 | * colour of each latent state 39 | """ 40 | obs_dist = hmm.obs_dist 41 | color_sample = [colors[i] for i in samples_state] 42 | 43 | xs = jnp.arange(xmin, xmax, step) 44 | ys = jnp.arange(ymin, ymax, step) 45 | 46 | v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0)) 47 | z = vmap(v_prob, in_axes=(0, None))(xs, ys) 48 | 49 | grid = np.mgrid[xmin:xmax:step, ymin:ymax:step] 50 | 51 | for k, color in enumerate(colors): 52 | ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3) 53 | ax.text(*(obs_dist.mean()[k] + 0.13), f"$k$={k + 1}", fontsize=13, horizontalalignment="right") 54 | 55 | ax.plot(*samples_obs.T, c="black", alpha=0.3, zorder=1) 56 | ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8) 57 | 58 | return ax, color_sample 59 | 60 | 61 | def main(): 62 | initial_probs = jnp.array([0.3, 0.2, 0.5]) 63 | 64 | # transition matrix 65 | A = jnp.array([ 66 | [0.3, 0.4, 0.3], 67 | [0.1, 0.6, 0.3], 68 | [0.2, 0.3, 0.5] 69 | ]) 70 | 71 | S1 = jnp.array([ 72 | [1.1, 0], 73 | [0, 0.3] 74 | ]) 75 | 76 | S2 = jnp.array([ 77 | [0.3, -0.5], 78 | [-0.5, 1.3] 79 | ]) 80 | 81 | S3 = jnp.array([ 82 | [0.8, 0.4], 83 | [0.4, 0.5] 84 | ]) 85 | 86 | cov_collection = jnp.array([S1, S2, S3]) / 60 87 | mu_collection = jnp.array([ 88 | [0.3, 0.3], 89 | [0.8, 0.5], 90 | [0.3, 0.8] 91 | ]) 92 | 93 | hmm = HMM(trans_dist=distrax.Categorical(probs=A), 94 | init_dist=distrax.Categorical(probs=initial_probs), 95 | obs_dist=distrax.as_distribution( 96 | tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(loc=mu_collection, 97 | covariance_matrix=cov_collection))) 98 | n_samples, seed = 50, 10 99 | samples_state, samples_obs = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples) 100 | 101 | xmin, xmax = 0, 1 102 | ymin, ymax = 0, 1.2 103 | colors = ["tab:green", "tab:blue", "tab:red"] 104 | 105 | dict_figures = {} 106 | fig, ax = plt.subplots() 107 | _, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax) 108 | dict_figures["hmm-lillypad-2d"] = fig 109 | 110 | fig, ax = plt.subplots() 111 | ax.step(range(n_samples), samples_state, where="post", c="black", linewidth=1, alpha=0.3) 112 | ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3) 113 | dict_figures["hmm-lillypad-step"] = fig 114 | 115 | return dict_figures 116 | 117 | 118 | if __name__ == "__main__": 119 | from jsl.demos.plot_utils import savefig 120 | 121 | plt.rcParams["axes.spines.right"] = False 122 | plt.rcParams["axes.spines.top"] = False 
123 | figs = main() 124 | savefig(figs) 125 | plt.show() 126 | -------------------------------------------------------------------------------- /jsl/demos/kf_continuous_circle.py: -------------------------------------------------------------------------------- 1 | # Example of a Kalman Filter procedure 2 | # on a continuous system with imaginary eigenvalues 3 | # and discrete samples 4 | # For futher reference and examples see: 5 | # * Section on Kalman Filters in PML vol2 book 6 | # * Nonlinear Dynamics and Chaos - Steven Strogatz 7 | 8 | import numpy as np 9 | import jax.numpy as jnp 10 | import matplotlib.pyplot as plt 11 | from jax import random 12 | from jsl.demos.plot_utils import plot_ellipse 13 | 14 | from jsl.lds.kalman_filter import LDS 15 | from jsl.lds.cont_kalman_filter import sample, filter 16 | 17 | 18 | def main(): 19 | A = jnp.array([[0, 1], [-1, 0]]) 20 | C = jnp.eye(2) 21 | 22 | dt = 0.01 23 | T = 5.5 24 | nsamples = 70 25 | x0 = jnp.array([0.5, -0.75]) 26 | 27 | # State noise 28 | Qt = jnp.eye(2) * 0.001 29 | # Observed noise 30 | Rt = jnp.eye(2) * 0.01 31 | 32 | Sigma0 = jnp.eye(2) 33 | 34 | key = random.PRNGKey(314) 35 | lds = LDS(A, C, Qt, Rt, x0, Sigma0) 36 | sample_state, sample_obs, jump = sample(key, lds, x0, T, nsamples) 37 | mu_hist, V_hist, *_ = filter(lds, sample_obs, jump, dt) 38 | 39 | step = 0.1 40 | vmin, vmax = -1.5, 1.5 + step 41 | X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1] 42 | X_dot = jnp.einsum("ij,jxy->ixy", A, X) 43 | 44 | dict_figures = {} 45 | 46 | fig_state, ax = plt.subplots() 47 | ax.plot(*sample_state.T, label="state space") 48 | ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations") 49 | ax.scatter(*sample_state[0], c="black", zorder=3) 50 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa") 51 | ax.legend() 52 | plt.axis("equal") 53 | ax.set_title("State Space") 54 | dict_figures["kf-circle-state"] = fig_state 55 | 56 | fig_filtered, ax = plt.subplots() 57 | ax.plot(*mu_hist.T, c="tab:orange", label="Filtered") 58 | ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations") 59 | ax.scatter(*mu_hist[0], c="black", zorder=3) 60 | for mut, Vt in zip(mu_hist[::4], V_hist[::4]): 61 | plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3) 62 | plt.legend() 63 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa") 64 | ax.legend() 65 | plt.axis("equal") 66 | ax.set_title("Approximate Space") 67 | dict_figures["kf-circle-filtered"] = fig_filtered 68 | 69 | return dict_figures 70 | 71 | 72 | if __name__ == "__main__": 73 | from jsl.demos.plot_utils import savefig 74 | 75 | dict_figures = main() 76 | savefig(dict_figures) 77 | plt.show() 78 | -------------------------------------------------------------------------------- /jsl/demos/kf_parallel.py: -------------------------------------------------------------------------------- 1 | # Parallel Kalman Filter demo: this script simulates 2 | # 4 missiles as described in the section "state-space models". 3 | # Each of the missiles is then filtered and smoothed in parallel 4 | 5 | import jax.numpy as jnp 6 | from jsl.lds.kalman_filter import LDS, filter, smooth 7 | from jsl.demos.plot_utils import plot_ellipse 8 | import matplotlib.pyplot as plt 9 | from jax import random 10 | 11 | 12 | def sample_filter_smooth(key, lds_model, nsteps, n_samples, noisy_init): 13 | """ 14 | Sample from a linear dynamical system, apply the kalman filter 15 | (forward pass), and performs smoothing. 
16 | 17 | Parameters 18 | ---------- 19 | lds: LinearDynamicalSystem 20 | Instance of a linear dynamical system with known parameters 21 | 22 | Returns 23 | ------- 24 | Dictionary with the following key, values 25 | * (z_hist) array(n_samples, timesteps, state_size): 26 | Simulation of Latent states 27 | * (x_hist) array(n_samples, timesteps, observation_size): 28 | Simulation of observed states 29 | * (mu_hist) array(n_samples, timesteps, state_size): 30 | Filtered means mut 31 | * (Sigma_hist) array(n_samples, timesteps, state_size, state_size) 32 | Filtered covariances Sigmat 33 | * (mu_cond_hist) array(n_samples, timesteps, state_size) 34 | Filtered conditional means mut|t-1 35 | * (Sigma_cond_hist) array(n_samples, timesteps, state_size, state_size) 36 | Filtered conditional covariances Sigmat|t-1 37 | * (mu_hist_smooth) array(n_samples, timesteps, state_size): 38 | Smoothed means mut 39 | * (Sigma_hist_smooth) array(n_samples, timesteps, state_size, state_size) 40 | Smoothed covariances Sigmat 41 | """ 42 | z_hist, x_hist = lds_model.sample(key, nsteps, n_samples, noisy_init) 43 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds_model, x_hist) 44 | mu_hist_smooth, Sigma_hist_smooth = smooth(lds_model, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist) 45 | 46 | return { 47 | "z_hist": z_hist, 48 | "x_hist": x_hist, 49 | "mu_hist": mu_hist, 50 | "Sigma_hist": Sigma_hist, 51 | "mu_cond_hist": mu_cond_hist, 52 | "Sigma_cond_hist": Sigma_cond_hist, 53 | "mu_hist_smooth": mu_hist_smooth, 54 | "Sigma_hist_smooth": Sigma_hist_smooth 55 | } 56 | 57 | 58 | def plot_collection(obs, ax, means=None, covs=None, **kwargs): 59 | n_samples, n_steps, _ = obs.shape 60 | for nsim in range(n_samples): 61 | X = obs[nsim] 62 | if means is not None: 63 | mean = means[nsim] 64 | ax.scatter(*mean[0, :2], marker="o", s=20, c="black", zorder=2) 65 | ax.plot(*mean[:, :2].T, marker="o", markersize=2, **kwargs, zorder=1) 66 | if covs is not None: 67 | cov = covs[nsim] 68 | for t in range(1, n_steps, 3): 69 | plot_ellipse(cov[t][:2, :2], mean[t, :2], ax, 70 | plot_center=False, alpha=0.7) 71 | ax.scatter(*X.T, marker="+", s=60) 72 | 73 | 74 | def main(): 75 | Δ = 1.0 76 | A = jnp.array([ 77 | [1, 0, Δ, 0], 78 | [0, 1, 0, Δ], 79 | [0, 0, 1, 0], 80 | [0, 0, 0, 1] 81 | ]) 82 | 83 | C = jnp.array([ 84 | [1, 0, 0, 0], 85 | [0, 1, 0, 0] 86 | ]).astype(float) 87 | 88 | state_size, _ = A.shape 89 | observation_size, _ = C.shape 90 | 91 | Q = jnp.eye(state_size) * 0.01 92 | R = jnp.eye(observation_size) * 1.2 93 | # Prior parameter distribution 94 | mu0 = jnp.array([8, 10, 1, 0]).astype(float) 95 | Sigma0 = jnp.eye(state_size) * 0.1 96 | 97 | 98 | key = random.PRNGKey(3141) 99 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0) 100 | 101 | nsteps, n_samples = 15, 4 102 | result = sample_filter_smooth(key, lds_instance, nsteps, n_samples, True) 103 | 104 | dict_figures = {} 105 | 106 | fig_latent, ax = plt.subplots() 107 | plot_collection(result["x_hist"], ax, result["z_hist"], linestyle="--") 108 | ax.set_title("State space") 109 | dict_figures["missiles_latent"] = fig_latent 110 | 111 | fig_filtered, ax = plt.subplots() 112 | plot_collection(result["x_hist"], ax, result["mu_hist"], result["Sigma_hist"]) 113 | ax.set_title("Filtered") 114 | dict_figures["missiles_filtered"] = fig_filtered 115 | 116 | fig_smoothed, ax = plt.subplots() 117 | plot_collection(result["x_hist"], ax, result["mu_hist_smooth"], result["Sigma_hist_smooth"]) 118 | ax.set_title("Smoothed") 119 | dict_figures["missiles_smoothed"] = 
fig_smoothed 120 | 121 | return dict_figures 122 | 123 | 124 | if __name__ == "__main__": 125 | from jsl.demos.plot_utils import savefig 126 | figures = main() 127 | savefig(figures) 128 | plt.show() 129 | -------------------------------------------------------------------------------- /jsl/demos/kf_spiral.py: -------------------------------------------------------------------------------- 1 | # This demo exemplifies the use of the Kalman Filter 2 | # algorithm when the linear dynamical system induced by the 3 | # matrix A has imaginary eigenvalues 4 | 5 | import jax.numpy as jnp 6 | import matplotlib.pyplot as plt 7 | from jsl.demos.plot_utils import plot_ellipse 8 | from jax import random 9 | from jsl.lds.kalman_filter import LDS, smooth, filter 10 | 11 | def plot_uncertainty_ellipses(means, covs, ax): 12 | timesteps = len(means) 13 | for t in range(timesteps): 14 | plot_ellipse(covs[t], means[t], ax, plot_center=False, alpha=0.7) 15 | 16 | def main(): 17 | dx = 1.1 18 | timesteps = 20 19 | key = random.PRNGKey(27182) 20 | 21 | mean_0 = jnp.array([1, 1, 1, 0]).astype(float) 22 | Sigma_0 = jnp.eye(4) 23 | A = jnp.array([ 24 | [0.1, 1.1, dx, 0], 25 | [-1, 1, 0, dx], 26 | [0, 0, 0.1, 0], 27 | [0, 0, 0, 0.1] 28 | ]) 29 | C = jnp.array([ 30 | [1, 0, 0, 0], 31 | [0, 1, 0, 0] 32 | ]) 33 | Q = jnp.eye(4) * 0.001 34 | R = jnp.eye(2) * 4 35 | 36 | lds_instance = LDS(A, C, Q, R, mean_0, Sigma_0) 37 | state_hist, obs_hist = lds_instance.sample(key, timesteps) 38 | 39 | res = filter(lds_instance, obs_hist) 40 | mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist = res 41 | mean_hist_smooth, Sigma_hist_smooth = smooth(lds_instance, mean_hist, 42 | Sigma_hist, mean_cond_hist, Sigma_cond_hist) 43 | 44 | dict_figures = {} 45 | 46 | fig_spiral_state, ax = plt.subplots() 47 | ax.plot(*state_hist[:, :2].T, linestyle="--") 48 | ax.scatter(*obs_hist.T, marker="+", s=60) 49 | ax.set_title("State space") 50 | dict_figures["spiral-state"] = fig_spiral_state 51 | 52 | 53 | fig_spiral_filtered, ax = plt.subplots() 54 | ax.plot(*mean_hist[:, :2].T) 55 | ax.scatter(*obs_hist.T, marker="+", s=60) 56 | plot_uncertainty_ellipses(mean_hist[:, :2], Sigma_hist[:, :2, :2], ax) 57 | ax.set_title("Filtered") 58 | dict_figures["spiral-filtered"] = fig_spiral_filtered 59 | 60 | fig_spiral_smoothed, ax = plt.subplots() 61 | ax.plot(*mean_hist_smooth[:, :2].T) 62 | ax.scatter(*obs_hist.T, marker="+", s=60) 63 | plot_uncertainty_ellipses(mean_hist_smooth[:, :2], Sigma_hist_smooth[:, :2, :2], ax) 64 | ax.set_title("Smoothed") 65 | dict_figures["spiral-smoothed"] = fig_spiral_smoothed 66 | 67 | return dict_figures 68 | 69 | 70 | if __name__ == "__main__": 71 | from jsl.demos.plot_utils import savefig 72 | figures = main() 73 | savefig(figures) 74 | plt.show() 75 | -------------------------------------------------------------------------------- /jsl/demos/kf_tracking.py: -------------------------------------------------------------------------------- 1 | # This script produces an illustration of Kalman filtering and smoothing 2 | 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | from jax import random 6 | from jsl.demos.plot_utils import plot_ellipse 7 | from jsl.lds.kalman_filter import LDS, smooth, filter 8 | 9 | 10 | def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax): 11 | """ 12 | observed: array(nsteps, 2) 13 | Array of observed values 14 | filtered: array(nsteps, state_size) 15 | Array of latent (hidden) values. 
We consider only the first 16 | two dimensions of the latent values 17 | cov_hist: array(nsteps, state_size, state_size) 18 | History of the retrieved (filtered) covariance matrices 19 | ax: matplotlib AxesSubplot 20 | """ 21 | timesteps, _ = observed.shape 22 | ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0, 23 | markerfacecolor="none", markeredgewidth=2, markersize=8, label="observed", c="tab:green") 24 | ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2) 25 | for t in range(0, timesteps, 1): 26 | covn = cov_hist[t][:2, :2] 27 | plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False) 28 | ax.axis("equal") 29 | ax.legend() 30 | 31 | 32 | def sample_filter_smooth(key, lds_model, timesteps): 33 | """ 34 | Sample from a linear dynamical system, apply the kalman filter 35 | (forward pass), and performs smoothing. 36 | 37 | Parameters 38 | ---------- 39 | lds: LinearDynamicalSystem 40 | Instance of a linear dynamical system with known parameters 41 | 42 | Returns 43 | ------- 44 | Dictionary with the following key, values 45 | * (z_hist) array(timesteps, state_size): 46 | Simulation of Latent states 47 | * (x_hist) array(timesteps, observation_size): 48 | Simulation of observed states 49 | * (mu_hist) array(timesteps, state_size): 50 | Filtered means mut 51 | * (Sigma_hist) array(timesteps, state_size, state_size) 52 | Filtered covariances Sigmat 53 | * (mu_cond_hist) array(timesteps, state_size) 54 | Filtered conditional means mut|t-1 55 | * (Sigma_cond_hist) array(timesteps, state_size, state_size) 56 | Filtered conditional covariances Sigmat|t-1 57 | * (mu_hist_smooth) array(timesteps, state_size): 58 | Smoothed means mut 59 | * (Sigma_hist_smooth) array(timesteps, state_size, state_size) 60 | Smoothed covariances Sigmat 61 | """ 62 | z_hist, x_hist = lds_model.sample(key, timesteps) 63 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds_model, x_hist) 64 | mu_hist_smooth, Sigma_hist_smooth = smooth(lds_model, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist) 65 | 66 | return { 67 | "z_hist": z_hist, 68 | "x_hist": x_hist, 69 | "mu_hist": mu_hist, 70 | "Sigma_hist": Sigma_hist, 71 | "mu_cond_hist": mu_cond_hist, 72 | "Sigma_cond_hist": Sigma_cond_hist, 73 | "mu_hist_smooth": mu_hist_smooth, 74 | "Sigma_hist_smooth": Sigma_hist_smooth 75 | } 76 | 77 | 78 | def main(): 79 | key = random.PRNGKey(314) 80 | timesteps = 15 81 | delta = 1.0 82 | A = jnp.array([ 83 | [1, 0, delta, 0], 84 | [0, 1, 0, delta], 85 | [0, 0, 1, 0], 86 | [0, 0, 0, 1] 87 | ]) 88 | 89 | C = jnp.array([ 90 | [1, 0, 0, 0], 91 | [0, 1, 0, 0] 92 | ]) 93 | 94 | state_size, _ = A.shape 95 | observation_size, _ = C.shape 96 | 97 | Q = jnp.eye(state_size) * 0.001 98 | R = jnp.eye(observation_size) * 1.0 99 | # Prior parameter distribution 100 | mu0 = jnp.array([8, 10, 1, 0]).astype(float) 101 | Sigma0 = jnp.eye(state_size) * 1.0 102 | 103 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0) 104 | result = sample_filter_smooth(key, lds_instance, timesteps) 105 | 106 | l2_filter = jnp.linalg.norm(result["z_hist"][:, :2] - result["mu_hist"][:, :2], 2) 107 | l2_smooth = jnp.linalg.norm(result["z_hist"][:, :2] - result["mu_hist_smooth"][:, :2], 2) 108 | 109 | print(f"L2-filter: {l2_filter:0.4f}") 110 | print(f"L2-smooth: {l2_smooth:0.4f}") 111 | 112 | dict_figures = {} 113 | 114 | fig_truth, axs = plt.subplots() 115 | axs.plot(result["x_hist"][:, 0], result["x_hist"][:, 1], 116 | marker="o", linewidth=0, markerfacecolor="none", 117 | markeredgewidth=2, 
markersize=8, 118 | label="observed", c="tab:green") 119 | 120 | axs.plot(result["z_hist"][:, 0], result["z_hist"][:, 1], 121 | linewidth=2, label="truth", 122 | marker="s", markersize=8) 123 | axs.legend() 124 | axs.axis("equal") 125 | dict_figures["kalman-tracking-truth"] = fig_truth 126 | 127 | fig_filtered, axs = plt.subplots() 128 | plot_tracking_values(result["x_hist"], result["mu_hist"], result["Sigma_hist"], "filtered", axs) 129 | dict_figures["kalman-tracking-filtered"] = fig_filtered 130 | 131 | fig_smoothed, axs = plt.subplots() 132 | plot_tracking_values(result["x_hist"], result["mu_hist_smooth"], result["Sigma_hist_smooth"], "smoothed", axs) 133 | dict_figures["kalman-tracking-smoothed"] = fig_smoothed 134 | 135 | return dict_figures 136 | 137 | 138 | if __name__ == "__main__": 139 | from jsl.demos.plot_utils import savefig 140 | 141 | figures = main() 142 | savefig(figures) 143 | plt.show() 144 | -------------------------------------------------------------------------------- /jsl/demos/lds_sampling_demo.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import jax.numpy as jnp 3 | from jax.random import PRNGKey 4 | import tensorflow_probability as tfp 5 | import tensorflow as tf 6 | tfd = tfp.distributions 7 | import matplotlib.pyplot as plt 8 | 9 | from jsl.lds.kalman_filter import LDS, kalman_filter 10 | from jsl.lds.kalman_sampler import smooth_sampler 11 | 12 | 13 | # Define the 1-d LDS model using the LDS model in the packate tensorflow_probability 14 | ndims = 1 15 | step_std = 1.0 16 | noise_std = 5.0 17 | model = tfd.LinearGaussianStateSpaceModel( 18 | num_timesteps=100, 19 | transition_matrix=tf.linalg.LinearOperatorDiag(jnp.array([1.01])), 20 | transition_noise=tfd.MultivariateNormalDiag( 21 | scale_diag=step_std * tf.ones([ndims])), 22 | observation_matrix=tf.linalg.LinearOperatorIdentity(ndims), 23 | observation_noise=tfd.MultivariateNormalDiag( 24 | scale_diag=noise_std * tf.ones([ndims])), 25 | initial_state_prior=tfd.MultivariateNormalDiag(loc=jnp.array([5.0]), 26 | scale_diag=tf.ones([ndims]))) 27 | 28 | # Sample from the prior of the LDS 29 | y = model.sample() 30 | # Posterior sampling of the state variable using the built in method of the LDS model (for the sake of comparison) 31 | smps = model.posterior_sample(y, sample_shape=50) 32 | s_tf = jnp.array(smps[:,:,0]) 33 | 34 | # Define the same model as an LDS object defined in the kalman_filter file 35 | A = jnp.eye(1) * 1.01 36 | C = jnp.eye(1) 37 | Q = jnp.eye(1) 38 | R = jnp.eye(1) * 25.0 39 | mu0 = jnp.array([5.0]) 40 | Sigma0 = jnp.eye(1) 41 | model_lds = LDS(A, C, Q, R, mu0, Sigma0) 42 | # Run the Kalman filter algorithm first 43 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = kalman_filter(model_lds, jnp.array(y)) 44 | # Sample backwards using the smoothing posterior 45 | smooth_sample = smooth_sampler(model_lds, PRNGKey(0), mu_hist, Sigma_hist, n_samples=50) 46 | 47 | 48 | # Plot the observation and posterior samples of state variables 49 | plt.plot(y, color='red', label='Observation') # Observation 50 | plt.plot(s_tf.T, alpha=0.12, color='blue') # Samples using TF built in function 51 | plt.plot(smooth_sample[:,:,0].T, alpha=0.12, color='green') # Kalman smoother backwards sampler. 
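# Both samplers target the posterior of the same 1-d LDS (the TFP model above and the JSL
# LDS are parameterised identically), so the blue (TFP) and green (JSL) sample bundles
# should roughly overlap around the red observation sequence.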
52 | plt.legend() 53 | plt.show() 54 | -------------------------------------------------------------------------------- /jsl/demos/linreg_kf.py: -------------------------------------------------------------------------------- 1 | # Online Bayesian linear regression in 1d using Kalman Filter 2 | # Based on: https://github.com/probml/pmtk3/blob/master/demos/linregOnlineDemoKalman.m 3 | 4 | # The latent state corresponds to the current estimate of the regression weights w. 5 | # The observation model has the form 6 | # p(y(t) | w(t), x(t)) = Gauss( C(t) * w(t), R(t)) 7 | # where C(t) = X(t,:) is the observation matrix for step t. 8 | # The dynamics model has the form 9 | # p(w(t) | w(t-1)) = Gauss(A * w(t-1), Q) 10 | # where Q>0 allows for parameter drift. 11 | # We show that the result is equivalent to batch (offline) Bayesian inference. 12 | 13 | import jax.numpy as jnp 14 | import matplotlib.pyplot as plt 15 | from numpy.linalg import inv 16 | from jsl.lds.kalman_filter import LDS, kalman_filter 17 | 18 | 19 | def kf_linreg(X, y, R, mu0, Sigma0, F, Q): 20 | """ 21 | Online estimation of a linear regression 22 | using Kalman Filters 23 | Parameters 24 | ---------- 25 | X: array(n_obs, dimension) 26 | Matrix of features 27 | y: array(n_obs,) 28 | Array of observations 29 | Q: float 30 | Known variance 31 | mu0: array(dimension) 32 | Prior mean 33 | Sigma0: array(dimesion, dimension) 34 | Prior covariance matrix 35 | Returns 36 | ------- 37 | * array(n_obs, dimension) 38 | Online estimation of parameters 39 | * array(n_obs, dimension, dimension) 40 | Online estimation of uncertainty 41 | """ 42 | C = lambda t: X[t][None, ...] 43 | lds = LDS(F, C, Q, R, mu0, Sigma0) 44 | 45 | mu_hist, Sigma_hist, _, _ = kalman_filter(lds, y) 46 | return mu_hist, Sigma_hist 47 | 48 | 49 | def posterior_lreg(X, y, R, mu0, Sigma0): 50 | """ 51 | Compute mean and covariance matrix of a 52 | Bayesian Linear regression 53 | Parameters 54 | ---------- 55 | X: array(n_obs, dimension) 56 | Matrix of features 57 | y: array(n_obs,) 58 | Array of observations 59 | R: float 60 | Known variance 61 | mu0: array(dimension) 62 | Prior mean 63 | Sigma0: array(dimesion, dimension) 64 | Prior covariance matrix 65 | Returns 66 | ------- 67 | * array(dimension) 68 | Posterior mean 69 | * array(n_obs, dimension, dimension) 70 | Posterior covariance matrix 71 | """ 72 | Sn_bayes_inv = inv(Sigma0) + X.T @ X / R.item() 73 | b = inv(Sigma0) @ mu0 + X.T @ y / R.item() 74 | mn_bayes = jnp.linalg.solve(Sn_bayes_inv, b) 75 | 76 | return mn_bayes, Sn_bayes_inv 77 | 78 | def main(): 79 | n_obs = 21 80 | timesteps = jnp.arange(n_obs) 81 | x = jnp.linspace(0, 20, n_obs) 82 | X = jnp.c_[jnp.ones(n_obs), x] 83 | F = jnp.eye(2) 84 | mu0 = jnp.zeros(2) 85 | Sigma0 = jnp.eye(2) * 10. 
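    # Setting Q = 0 below encodes the assumption that the regression weights do not drift
    # over time; this is what makes the online Kalman-filter estimate coincide with the
    # batch posterior computed by posterior_lreg.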
86 | 87 | Q, R = 0, 1 88 | Q, R = jnp.asarray([[Q]]), jnp.asarray([[R]]) 89 | # Data from original matlab example 90 | y = jnp.array([2.4865, -0.3033, -4.0531, -4.3359, -6.1742, -5.604, -3.5069, -2.3257, -4.6377, 91 | -0.2327, -1.9858, 1.0284, -2.264, -0.4508, 1.1672, 6.6524, 4.1452, 5.2677, 6.3403, 9.6264, 14.7842]) 92 | 93 | # Online estimation 94 | mu_hist, Sigma_hist = kf_linreg(X, y, R, mu0, Sigma0, F, Q) 95 | kf_var = Sigma_hist[-1, [0, 1], [0, 1]] 96 | w0_hist, w1_hist = mu_hist.T 97 | w0_err, w1_err = jnp.sqrt(Sigma_hist[:, [0, 1], [0, 1]].T) 98 | 99 | # Offline estimation 100 | (w0_post, w1_post), inv_Sigma_post = posterior_lreg(X, y, R, mu0, Sigma0) 101 | Sigma_post = inv(inv_Sigma_post) 102 | w0_std, w1_std = jnp.sqrt(Sigma_post[[0, 1], [0, 1]]) 103 | 104 | dict_figures = {} 105 | 106 | fig, ax = plt.subplots() 107 | ax.errorbar(timesteps, w0_hist, w0_err, fmt="-o", label="$w_0$", color="black", fillstyle="none") 108 | ax.errorbar(timesteps, w1_hist, w1_err, fmt="-o", label="$w_1$", color="tab:red") 109 | 110 | ax.axhline(y=w0_post, c="black", label="$w_0$ batch") 111 | ax.axhline(y=w1_post, c="tab:red", linestyle="--", label="$w_1$ batch") 112 | 113 | ax.fill_between(timesteps, w0_post - w0_std, w0_post + w0_std, color="black", alpha=0.4) 114 | ax.fill_between(timesteps, w1_post - w1_std, w1_post + w1_std, color="tab:red", alpha=0.4) 115 | 116 | plt.legend() 117 | ax.set_xlabel("time") 118 | ax.set_ylabel("weights") 119 | ax.set_ylim(-8, 4) 120 | ax.set_xlim(-0.5, n_obs) 121 | dict_figures["linreg_online_kalman"] = fig 122 | return dict_figures 123 | 124 | 125 | if __name__ == "__main__": 126 | from jsl.demos.plot_utils import savefig 127 | plt.rcParams["axes.spines.right"] = False 128 | plt.rcParams["axes.spines.top"] = False 129 | dict_figures = main() 130 | savefig(dict_figures) 131 | plt.show() -------------------------------------------------------------------------------- /jsl/demos/logreg_biclusters.py: -------------------------------------------------------------------------------- 1 | # Bayesian logistic regression in 2d for 2 class problem 2 | # We compare MCMC to Laplace 3 | 4 | # Dependencies: 5 | # * !pip install git+https://github.com/blackjax-devs/blackjax.git 6 | 7 | import jax 8 | import numpy as np 9 | import jax.numpy as jnp 10 | import matplotlib.pyplot as plt 11 | from blackjax import rmh 12 | from jax import random 13 | from functools import partial 14 | from jax.scipy.optimize import minimize 15 | from jax.scipy.stats import norm 16 | 17 | 18 | def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x)) 19 | def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z)) 20 | 21 | 22 | def plot_posterior_predictive(ax, X, Xspace, Zspace, title, colors, cmap="viridis"): 23 | ax.contourf(*Xspace, Zspace, cmap=cmap, levels=20) 24 | ax.scatter(*X.T, c=colors, edgecolors="gray", s=80) 25 | ax.set_title(title) 26 | ax.axis("off") 27 | plt.tight_layout() 28 | 29 | 30 | def inference_loop(rng_key, kernel, initial_state, num_samples): 31 | def one_step(state, rng_key): 32 | state, _ = kernel(rng_key, state) 33 | return state, state 34 | 35 | keys = jax.random.split(rng_key, num_samples) 36 | _, states = jax.lax.scan(one_step, initial_state, keys) 37 | 38 | return states 39 | 40 | 41 | def E_base(w, Phi, y, alpha): 42 | # Energy is the log joint 43 | an = Phi @ w 44 | log_an = log_sigmoid(an) 45 | log_likelihood_term = y * log_an + (1 - y) * jnp.log(1 - sigmoid(an)) 46 | log_prior_term = -(alpha * w @ w / 2) 47 | return log_prior_term + log_likelihood_term.sum() 48 | 49 | 50 | def 
mcmc_logistic_posterior_sample(key, Phi, y, alpha=1.0, init_noise=1.0, 51 | n_samples=5_000, burnin=500, sigma_mcmc=0.8): 52 | """ 53 | Sample from the posterior distribution of the weights 54 | of a 2d binary logistic regression model p(y=1|x,w) = sigmoid(w'x), 55 | using the random walk Metropolis-Hastings algorithm. 56 | """ 57 | _, ndims = Phi.shape 58 | key, key_init = random.split(key) 59 | w0 = random.multivariate_normal(key, jnp.zeros(ndims), jnp.eye(ndims) * init_noise) 60 | energy = partial(E_base, Phi=Phi, y=y, alpha=alpha) 61 | mcmc_kernel = rmh(energy, sigma=jnp.ones(ndims) * sigma_mcmc) 62 | initial_state = mcmc_kernel.init(w0) 63 | 64 | states = inference_loop(key_init, mcmc_kernel.step, initial_state, n_samples) 65 | chains = states.position[burnin:, :] 66 | return chains 67 | 68 | 69 | def laplace_posterior(key, Phi, y, alpha=1.0, init_noise=1.0): 70 | N, M = Phi.shape 71 | w0 = random.multivariate_normal(key, jnp.zeros(M), jnp.eye(M) * init_noise) 72 | E = lambda w: -E_base(w, Phi, y, alpha) / len(y) 73 | res = minimize(E, w0, method="BFGS") 74 | w_laplace = res.x 75 | SN = jax.hessian(E)(w_laplace) 76 | return w_laplace, SN 77 | 78 | 79 | def make_dataset(seed=135): 80 | np.random.seed(seed) 81 | N = 30 82 | mu1 = np.hstack((np.ones((N, 1)), 5 * np.ones((N, 1)))) 83 | mu2 = np.hstack((-5 * np.ones((N, 1)), np.ones((N, 1)))) 84 | class1_std = 1 85 | class2_std = 1.1 86 | X_1 = np.add(class1_std * np.random.randn(N, 2), mu1) 87 | X_2 = np.add(2 * class2_std * np.random.randn(N, 2), mu2) 88 | X = np.vstack((X_1, X_2)) 89 | y = np.vstack((np.ones((N, 1)), np.zeros((N, 1)))) 90 | return X, y.ravel() 91 | 92 | 93 | def main(): 94 | key = random.PRNGKey(314) 95 | ## Data generating process 96 | X, y = make_dataset() 97 | n_datapoints = len(y) 98 | 99 | Phi = jnp.c_[jnp.ones(n_datapoints)[:, None], X] 100 | N, M = Phi.shape 101 | 102 | colors = ["black" if el else "white" for el in y] 103 | 104 | # Predictive domain 105 | xmin, ymin = X.min(axis=0) - 0.1 106 | xmax, ymax = X.max(axis=0) + 0.1 107 | step = 0.1 108 | Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step] 109 | _, nx, ny = Xspace.shape 110 | Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace]) 111 | 112 | 113 | ## Laplace 114 | alpha = 2.0 115 | w_laplace, SN = laplace_posterior(key, Phi, y, alpha=alpha) 116 | 117 | ### MCMC Approximation 118 | chains = mcmc_logistic_posterior_sample(key, Phi, y, alpha=alpha) 119 | Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains)) 120 | Z_mcmc = Z_mcmc.mean(axis=0) 121 | 122 | ### *** Ploting surface predictive distribution *** 123 | colors = ["black" if el else "white" for el in y] 124 | dict_figures = {} 125 | key = random.PRNGKey(31415) 126 | nsamples = 5000 127 | 128 | # Laplace surface predictive distribution 129 | laplace_samples = random.multivariate_normal(key, w_laplace, SN, (nsamples,)) 130 | Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples)) 131 | Z_laplace = Z_laplace.mean(axis=0) 132 | 133 | fig_laplace, ax = plt.subplots() 134 | title = "Laplace Predictive distribution" 135 | plot_posterior_predictive(ax, X, Xspace, Z_laplace, title, colors) 136 | dict_figures["logistic_regression_surface_laplace"] = fig_laplace 137 | 138 | # MCMC surface predictive distribution 139 | fig_mcmc, ax = plt.subplots() 140 | title = "MCMC Predictive distribution" 141 | plot_posterior_predictive(ax, X, Xspace, Z_mcmc, title, colors) 142 | dict_figures["logistic_regression_surface_mcmc"] = fig_mcmc 143 | 144 | 145 | # *** Plotting posterior marginals of 
weights *** 146 | for i in range(M): 147 | fig_weights_marginals, ax = plt.subplots() 148 | mean_laplace, std_laplace = w_laplace[i], jnp.sqrt(SN[i, i]) 149 | mean_mcmc, std_mcmc = chains[:, i].mean(), chains[:, i].std() 150 | 151 | x = jnp.linspace(mean_laplace - 4 * std_laplace, mean_laplace + 4 * std_laplace, 500) 152 | ax.plot(x, norm.pdf(x, mean_laplace, std_laplace), label="posterior (Laplace)", linestyle="dotted") 153 | ax.plot(x, norm.pdf(x, mean_mcmc, std_mcmc), label="posterior (MCMC)", linestyle="dashed") 154 | ax.legend() 155 | ax.set_title(f"Posterior marginals of weights ({i})") 156 | dict_figures[f"logistic_regression_weights_marginals_w{i}"] = fig_weights_marginals 157 | 158 | 159 | print("MCMC weights") 160 | w_mcmc = chains.mean(axis=0) 161 | print(w_mcmc, end="\n"*2) 162 | 163 | print("Laplace weights") 164 | print(w_laplace, end="\n"*2) 165 | 166 | dict_data = { 167 | "X": X, 168 | "y": y, 169 | "Xspace": Xspace, 170 | "Phi": Phi, 171 | "Phispace": Phispace, 172 | "w_laplace": w_laplace, 173 | "cov_laplace": SN 174 | } 175 | 176 | return dict_figures, dict_data 177 | 178 | 179 | if __name__ == "__main__": 180 | from jsl.demos.plot_utils import savefig 181 | figs, data = main() 182 | savefig(figs) 183 | plt.show() 184 | -------------------------------------------------------------------------------- /jsl/demos/old-init.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "sis_vs_smc", 3 | "ekf_vs_ukf", 4 | "ukf_mlp", 5 | "ekf_mlp", 6 | "eekf_logistic_regression", 7 | "ekf_mlp_anim", 8 | "linreg_kf", 9 | "kf_parallel", 10 | "kf_continuous_circle", 11 | "kf_spiral", 12 | "kf_tracking", 13 | "bootstrap_filter", 14 | "pendulum_1d", 15 | "ekf_continuous" 16 | ] 17 | -------------------------------------------------------------------------------- /jsl/demos/pendulum_1d.py: -------------------------------------------------------------------------------- 1 | # Example of a 1D pendulum problem applied to the Extended Kalman Filter, 2 | # the Unscented Kalman Filter, and the Particle Filter (boostrap filter) 3 | # Additionally, we test the particle filter when the observations have a 40% 4 | # probability of being perturbed by a uniform(-2, 2) distribution 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | from jax import random 10 | from jax.ops import index_update 11 | 12 | from jsl.nlds.base import NLDS 13 | import jsl.nlds.extended_kalman_filter as ekf_lib 14 | import jsl.nlds.unscented_kalman_filter as ukf_lib 15 | import jsl.nlds.bootstrap_filter as b_lib 16 | 17 | 18 | def plot_filter_true(ax, time, estimate, obs, ground_truth, label, colors="tab:blue"): 19 | ax.plot(time, estimate, c="black", label=label) 20 | ax.scatter(time, obs, s=10, c="none", edgecolors=colors) 21 | ax.plot(time, ground_truth, c="gray", linewidth=8, alpha=0.5, 22 | label="true angle") 23 | ax.legend() 24 | 25 | 26 | def fz(x, g=0.1, dt=0.1): 27 | x1, x2 = x[0], x[1] 28 | x1_new = x1 + x2 * dt 29 | x2_new = x2 - g * jnp.sin(x1) * dt 30 | return jnp.asarray([x1_new, x2_new]) 31 | 32 | 33 | def fx(x, *args): 34 | return jnp.sin(jnp.asarray([x[0]])) 35 | 36 | 37 | def main(): 38 | # *** Define initial configuration *** 39 | g = 10 40 | dt = 0.015 41 | qc = 0.06 42 | Q = jnp.array([ 43 | [qc * dt ** 3 / 3, qc * dt ** 2 / 2], 44 | [qc * dt ** 2 / 2, qc * dt] 45 | ]) 46 | 47 | fx_vmap = jax.vmap(fx) 48 | fz_vec = jax.vmap(lambda x: fz(x, g=g, dt=dt)) 49 | 50 | nsteps = 200 51 | Rt = jnp.eye(1) * 0.02 52 | x0 = 
jnp.array([1.5, 0.0]).astype(float) 53 | time = jnp.arange(0, nsteps * dt, dt) 54 | 55 | key = random.PRNGKey(3141) 56 | key_samples, key_pf, key_noisy = random.split(key, 3) 57 | model = NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt) 58 | sample_state, sample_obs = model.sample(key, x0, nsteps) 59 | 60 | # *** Pertubed data *** 61 | key_noisy, key_values = random.split(key_noisy) 62 | sample_obs_noise = sample_obs.copy() 63 | samples_map = random.bernoulli(key_noisy, 0.5, (nsteps,)) 64 | replacement_values = random.uniform(key_values, (samples_map.sum(),), minval=-2, maxval=2) 65 | sample_obs_noise = index_update(sample_obs_noise.ravel(), samples_map, replacement_values) 66 | colors = ["tab:red" if samp else "tab:blue" for samp in samples_map] 67 | 68 | # *** Perform filtering **** 69 | alpha, beta, kappa = 1, 0, 2 70 | state_size = 2 71 | Vinit = jnp.eye(state_size) 72 | ukf = NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt, alpha, beta, kappa, state_size) 73 | particle_filter = NLDS(fz_vec, fx_vmap, Q, Rt) 74 | 75 | print("Filtering data...") 76 | _, ekf_hist = ekf_lib.filter(model, x0, sample_obs, return_params=["mean", "cov"]) 77 | ekf_mean_hist, ekf_Sigma_hist = ekf_hist["mean"], ekf_hist["cov"] 78 | ukf_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, x0, sample_obs) 79 | pf_mean_hist = b_lib.filter(particle_filter, key_pf, x0, sample_obs, nsamples=4_000, Vinit=Vinit) 80 | 81 | print("Filtering outlier data...") 82 | _, ekf_perturbed_hist = ekf_lib.filter(model, x0, sample_obs_noise, return_params=["mean", "cov"]) 83 | ekf_perturbed_mean_hist, ekf_Sigma_hist = ekf_perturbed_hist["mean"], ekf_perturbed_hist["cov"] 84 | ukf_perturbed_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, x0, sample_obs_noise) 85 | pf_perturbed_mean_hist = b_lib.filter(particle_filter, key_pf, x0, sample_obs_noise, nsamples=2_000) 86 | 87 | ekf_estimate = fx_vmap(ekf_mean_hist) 88 | ukf_estimate = fx_vmap(ukf_mean_hist) 89 | pf_estimate = fx_vmap(pf_mean_hist) 90 | 91 | ekf_perturbed_estimate = fx_vmap(ekf_perturbed_mean_hist) 92 | ukf_perturbed_estimate = fx_vmap(ukf_perturbed_mean_hist) 93 | pf_perturbed_estimate = fx_vmap(pf_perturbed_mean_hist) 94 | ground_truth = fx_vmap(sample_state) 95 | 96 | dict_figures = {} 97 | # *** Plot results *** 98 | fig, ax = plt.subplots() 99 | plot_filter_true(ax, time, ekf_estimate, sample_obs, ground_truth, "Extended KF") 100 | dict_figures["pendulum_ekf_1d_demo"] = fig 101 | 102 | fig, ax = plt.subplots() 103 | plot_filter_true(ax, time, ukf_estimate, sample_obs, ground_truth, "Unscented KF") 104 | dict_figures["pendulum_ukf_1d_demo"] = fig 105 | 106 | fig, ax = plt.subplots() 107 | plot_filter_true(ax, time, pf_estimate, sample_obs, ground_truth, "Bootstrap PF") 108 | dict_figures["pendulum_pf_1d_demo"] = fig 109 | 110 | fig, ax = plt.subplots() 111 | plot_filter_true(ax, time, pf_perturbed_estimate, sample_obs_noise, 112 | ground_truth, "Bootstrap PF (noisy)", colors=colors) 113 | dict_figures["pendulum_pf_noisy_1d_demo"] = fig 114 | 115 | fig, ax = plt.subplots() 116 | plot_filter_true(ax, time, ekf_perturbed_estimate, sample_obs_noise, 117 | ground_truth, "Extended KF (noisy)", colors=colors) 118 | dict_figures["pendulum_ekf_noisy_1d_demo"] = fig 119 | 120 | fig, ax = plt.subplots() 121 | plot_filter_true(ax, time, ukf_perturbed_estimate, sample_obs_noise, 122 | ground_truth, "Unscented KF (noisy)", colors=colors) 123 | dict_figures["pendulum_ukf_noisy_1d_demo"] = fig 124 | 125 | return dict_figures 126 | 127 | 128 | if __name__ == "__main__": 129 | from 
jsl.demos.plot_utils import savefig 130 | 131 | plt.rcParams["axes.spines.right"] = False 132 | plt.rcParams["axes.spines.top"] = False 133 | figures = main() 134 | savefig(figures) 135 | plt.show() 136 | -------------------------------------------------------------------------------- /jsl/demos/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from numpy import linalg 5 | from matplotlib.patches import Ellipse, transforms 6 | from mpl_toolkits.mplot3d import Axes3D 7 | 8 | 9 | # https://matplotlib.org/devdocs/gallery/statistics/confidence_ellipse.html 10 | def plot_ellipse(Sigma, mu, ax, n_std=3.0, facecolor='none', edgecolor='k', plot_center='true', **kwargs): 11 | cov = Sigma 12 | pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1]) 13 | 14 | ell_radius_x = np.sqrt(1 + pearson) 15 | ell_radius_y = np.sqrt(1 - pearson) 16 | ellipse = Ellipse((0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2, 17 | facecolor=facecolor, edgecolor=edgecolor, **kwargs) 18 | 19 | scale_x = np.sqrt(cov[0, 0]) * n_std 20 | mean_x = mu[0] 21 | 22 | scale_y = np.sqrt(cov[1, 1]) * n_std 23 | mean_y = mu[1] 24 | 25 | transf = (transforms.Affine2D() 26 | .rotate_deg(45) 27 | .scale(scale_x, scale_y) 28 | .translate(mean_x, mean_y)) 29 | 30 | ellipse.set_transform(transf + ax.transData) 31 | 32 | if plot_center: 33 | ax.plot(mean_x, mean_y, '.') 34 | return ax.add_patch(ellipse) 35 | 36 | 37 | def savedotfile(dotfiles): 38 | if "FIGDIR" in os.environ: 39 | figdir = os.environ["FIGDIR"] 40 | for name, dot in dotfiles.items(): 41 | fname_full = os.path.join(figdir, name) 42 | dot.render(fname_full) 43 | print(f"saving dot file to {fname_full}") 44 | 45 | 46 | def savefig(figures, *args, **kwargs): 47 | if "FIGDIR" in os.environ: 48 | figdir = os.environ["FIGDIR"] 49 | for name, figure in figures.items(): 50 | fname_full = os.path.join(figdir, name) 51 | print(f"saving image to {fname_full}") 52 | figure.savefig(f"{fname_full}.pdf", *args, **kwargs) 53 | figure.savefig(f"{fname_full}.png", *args, **kwargs) 54 | 55 | 56 | def scale_3d(ax, x_scale, y_scale, z_scale, factor): 57 | scale=np.diag([x_scale, y_scale, z_scale, 1.0]) 58 | scale=scale*(1.0/scale.max()) 59 | scale[3,3]=factor 60 | def short_proj(): 61 | return np.dot(Axes3D.get_proj(ax), scale) 62 | return short_proj 63 | 64 | def style3d(ax, x_scale, y_scale, z_scale, factor=0.62): 65 | plt.gca().patch.set_facecolor('white') 66 | ax.w_xaxis.set_pane_color((0, 0, 0, 0)) 67 | ax.w_yaxis.set_pane_color((0, 0, 0, 0)) 68 | ax.w_zaxis.set_pane_color((0, 0, 0, 0)) 69 | ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor) 70 | 71 | 72 | def kdeg(x, X, h): 73 | """ 74 | KDE under a gaussian kernel 75 | 76 | Parameters 77 | ---------- 78 | x: array(eval, D) 79 | X: array(obs, D) 80 | h: float 81 | 82 | Returns 83 | ------- 84 | array(eval): 85 | KDE around the observed values 86 | """ 87 | N, D = X.shape 88 | nden, _ = x.shape 89 | 90 | Xhat = X.reshape(D, 1, N) 91 | xhat = x.reshape(D, nden, 1) 92 | u = xhat - Xhat 93 | u = linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h ** 2) 94 | px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi)) 95 | return px 96 | -------------------------------------------------------------------------------- /jsl/demos/rbpf_maneuver.py: -------------------------------------------------------------------------------- 1 | # Rao-Blackwellised particle filtering for jump markov linear systems 2 | # Based on: 
https://github.com/probml/pmtk3/blob/master/demos/rbpfManeuverDemo.m 3 | 4 | # !pip install matplotlib==3.4.2 5 | 6 | import jax 7 | import numpy as np 8 | import jax.numpy as jnp 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | import jsl.lds.mixture_kalman_filter as kflib 12 | from jax import random 13 | from functools import partial 14 | from sklearn.preprocessing import OneHotEncoder 15 | from jax.scipy.special import logit 16 | from jsl.demos.plot_utils import kdeg, style3d 17 | 18 | 19 | 20 | def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5): 21 | nsteps = len(mu_hist) 22 | xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max() 23 | xrange = jnp.linspace(xmin, xmax, npoints).reshape(-1, 1) 24 | res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist) 25 | densities = res[..., dim] 26 | for t in range(0, nsteps, skip): 27 | tloc = t * np.ones(npoints) 28 | px = densities[t] 29 | ax.plot(tloc, xrange, px, c="tab:blue", linewidth=1) 30 | ax.set_zlim(0, 1) 31 | style3d(ax, 1.8, 1.2, 0.7, 0.8) 32 | ax.view_init(elevation, azimuth) 33 | ax.set_xlabel(r"$t$", fontsize=13) 34 | ax.set_ylabel(r"$x_{"f"d={dim}"",t}$", fontsize=13) 35 | ax.set_zlabel(r"$p(x_{d, t} \vert y_{1:t})$", fontsize=13) 36 | 37 | 38 | def main(): 39 | TT = 0.1 40 | A = jnp.array([[1, TT, 0, 0], 41 | [0, 1, 0, 0], 42 | [0, 0, 1, TT], 43 | [0, 0, 0, 1]]) 44 | 45 | 46 | B1 = jnp.array([0, 0, 0, 0]) 47 | B2 = jnp.array([-1.225, -0.35, 1.225, 0.35]) 48 | B3 = jnp.array([1.225, 0.35, -1.225, -0.35]) 49 | B = jnp.stack([B1, B2, B3], axis=0) 50 | 51 | Q = 0.2 * jnp.eye(4) 52 | R = 10 * jnp.diag(jnp.array([2, 1, 2, 1])) 53 | C = jnp.eye(4) 54 | 55 | 56 | transition_matrix = jnp.array([ 57 | [0.8, 0.1, 0.1], 58 | [0.1, 0.8, 0.1], 59 | [0.1, 0.1, 0.8] 60 | ]) 61 | 62 | params = kflib.RBPFParamsDiscrete(A, B, C, Q, R, transition_matrix) 63 | 64 | nparticles = 1000 65 | nsteps = 100 66 | key = random.PRNGKey(1) 67 | keys = random.split(key, nsteps) 68 | 69 | x0 = (1, random.multivariate_normal(key, jnp.zeros(4), jnp.eye(4))) 70 | draw_state_fixed = partial(kflib.draw_state, params=params) 71 | 72 | # Create target dataset 73 | _, (latent_hist, state_hist, obs_hist) = jax.lax.scan(draw_state_fixed, x0, keys) 74 | 75 | # Perform filtering 76 | key_base = random.PRNGKey(31) 77 | key_mean_init, key_sample, key_state, key_next = random.split(key_base, 4) 78 | p_init = jnp.array([0.0, 1.0, 0.0]) 79 | 80 | # Initial filter configuration 81 | mu_0 = 0.01 * random.normal(key_mean_init, (nparticles, 4)) 82 | Sigma_0 = jnp.zeros((nparticles, 4,4)) 83 | s0 = random.categorical(key_state, logit(p_init), shape=(nparticles,)) 84 | weights_0 = jnp.ones(nparticles) / nparticles 85 | init_config = (key_next, mu_0, Sigma_0, weights_0, s0) 86 | 87 | rbpf_optimal_part = partial(kflib.rbpf_optimal, params=params, nparticles=nparticles) 88 | _, (mu_hist, Sigma_hist, weights_hist, s_hist, Ptk) = jax.lax.scan(rbpf_optimal_part, init_config, obs_hist) 89 | mu_hist_post_mean = jnp.einsum("ts,tsm->tm", weights_hist, mu_hist) 90 | 91 | 92 | dict_figures = {} 93 | # Plot target dataset 94 | color_dict = {0: "tab:green", 1: "tab:red", 2: "tab:blue"} 95 | fig, ax = plt.subplots() 96 | color_states_org = [color_dict[state] for state in np.array(latent_hist)] 97 | ax.scatter(*state_hist[:, [0, 2]].T, c="none", edgecolors=color_states_org, s=10) 98 | ax.scatter(*obs_hist[:, [0, 2]].T, s=5, c="black", alpha=0.6) 99 | ax.set_title("Data") 100 | 
dict_figures["rbpf-maneuver-data"] = fig 101 | 102 | # Plot filtered dataset 103 | fig, ax = plt.subplots() 104 | rbpf_mse = ((mu_hist_post_mean - state_hist)[:, [0, 2]] ** 2).mean(axis=0).sum() 105 | latent_hist_est = Ptk.mean(axis=1).argmax(axis=1) 106 | color_states_est = [color_dict[state] for state in np.array(latent_hist_est)] 107 | ax.scatter(*mu_hist_post_mean[:, [0, 2]].T, c="none", edgecolors=color_states_est, s=10) 108 | ax.set_title(f"RBPF MSE: {rbpf_mse:.2f}") 109 | dict_figures["rbpf-maneuver-trace"] = fig 110 | 111 | # Plot belief state of discrete system 112 | p_terms = Ptk.mean(axis=1) 113 | rbpf_error_rate = (latent_hist != p_terms.argmax(axis=1)).mean() 114 | fig, ax = plt.subplots(figsize=(2.5, 5)) 115 | sns.heatmap(p_terms, cmap="viridis", cbar=False) 116 | plt.title(f"RBPF, error rate: {rbpf_error_rate:0.3}") 117 | dict_figures["rbpf-maneuver-discrete-belief"] = fig 118 | 119 | # Plot ground truth and MAP estimate 120 | ohe = OneHotEncoder(sparse=False) 121 | latent_hmap = ohe.fit_transform(latent_hist[:, None]) 122 | latent_hmap_est = ohe.fit_transform(p_terms.argmax(axis=1)[:, None]) 123 | 124 | fig, ax = plt.subplots(figsize=(2.5, 5)) 125 | sns.heatmap(latent_hmap, cmap="viridis", cbar=False, ax=ax) 126 | ax.set_title("Data") 127 | dict_figures["rbpf-maneuver-discrete-ground-truth.pdf"] = fig 128 | 129 | fig, ax = plt.subplots(figsize=(2.5, 5)) 130 | sns.heatmap(latent_hmap_est, cmap="viridis", cbar=False, ax=ax) 131 | ax.set_title(f"MAP (error rate: {rbpf_error_rate:0.4f})") 132 | dict_figures["rbpf-maneuver-discrete-map"] = fig 133 | 134 | # Plot belief for state space 135 | dims = [0, 2] 136 | for dim in dims: 137 | fig = plt.figure() 138 | ax = plt.axes(projection="3d") 139 | plot_3d_belief_state(mu_hist, dim, ax, h=1.1) 140 | # pml.savefig(f"rbpf-maneuver-belief-states-dim{dim}.pdf", pad_inches=0, bbox_inches="tight") 141 | dict_figures[f"rbpf-maneuver-belief-states-dim{dim}.pdf"] = fig 142 | 143 | return dict_figures 144 | 145 | 146 | if __name__ == "__main__": 147 | from jsl.demos.plot_utils import savefig 148 | plt.rcParams["axes.spines.right"] = False 149 | plt.rcParams["axes.spines.top"] = False 150 | dict_figures = main() 151 | savefig(dict_figures, pad_inches=0, bbox_inches="tight") 152 | plt.show() 153 | -------------------------------------------------------------------------------- /jsl/demos/sis_vs_smc.py: -------------------------------------------------------------------------------- 1 | # This demo compares sequential importance sampling (SIS) to 2 | # sequential Monte Carlo (SMC) in the case of a non-markovian 3 | # Gaussian sequence model. 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | 9 | from jsl.nlds.sequential_monte_carlo import NonMarkovianSequenceModel 10 | 11 | def find_path(ix_path, final_state): 12 | curr_state = final_state 13 | path = [curr_state] 14 | for i in range(1, 7): 15 | curr_state, _ = ix_path[:, -i, curr_state] 16 | path.append(curr_state) 17 | path = path[::-1] 18 | return path 19 | 20 | 21 | def plot_sis_weights(hist, n_steps, spacing=1.5, max_size=0.3): 22 | """ 23 | Plot the evolution of weights in the sequential importance sampling (SIS) algorithm. 24 | 25 | Parameters 26 | ---------- 27 | weights: array(n_particles, n_steps) 28 | Weights at each time step. 29 | n_steps: int 30 | Number of steps to plot. 31 | spacing: float 32 | Spacing between particles. 33 | max_size: float 34 | Maximum size of the particles. 
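    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure containing the plot.
    ax: matplotlib.axes.Axes
        Axes on which the particles are drawn.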
35 | """ 36 | fig, ax = plt.subplots(figsize=(8, 6)) 37 | ax.set_aspect(1) 38 | weights_subset = hist["weights"][:n_steps] 39 | for col, weights_row in enumerate(weights_subset): 40 | norm_cst = weights_row.sum() 41 | radii = weights_row / norm_cst * max_size 42 | for row, rad in enumerate(radii): 43 | if col != n_steps - 1: 44 | plt.arrow(spacing * (col + 0.25), row, 0.6, 0, width=0.05, 45 | edgecolor="white", facecolor="tab:gray") 46 | circle = plt.Circle((spacing * col, row), rad, color="tab:red") 47 | ax.add_artist(circle) 48 | 49 | plt.xlim(-1, n_steps * spacing) 50 | plt.xlabel("Iteration (t)") 51 | plt.ylabel("Particle index (i)") 52 | 53 | xticks_pos = jnp.arange(0, n_steps * spacing - 1, 2) 54 | xticks_lab = jnp.arange(1, n_steps + 1) 55 | plt.xticks(xticks_pos, xticks_lab) 56 | 57 | return fig, ax 58 | 59 | 60 | def plot_smc_weights(hist, n_steps, spacing=1.5, max_size=0.3): 61 | """ 62 | Plot the evolution of weights in the sequential Monte Carlo (SMC) algorithm. 63 | 64 | Parameters 65 | ---------- 66 | weights: array(n_particles, n_steps) 67 | Weights at each time step. 68 | n_steps: int 69 | Number of steps to plot. 70 | spacing: float 71 | Spacing between particles. 72 | max_size: float 73 | Maximum size of the particles. 74 | 75 | Returns 76 | ------- 77 | fig: matplotlib.figure.Figure 78 | Figure containing the plot. 79 | """ 80 | fig, ax = plt.subplots(figsize=(8, 6)) 81 | ax.set_aspect(1) 82 | 83 | weights_subset = hist["weights"][:n_steps] 84 | # sampled indices represent the "position" of weights at the next time step 85 | ix_subset = hist["indices"][:n_steps][1:] 86 | 87 | for it, (weights_row, p_target) in enumerate(zip(weights_subset, ix_subset)): 88 | norm_cst = weights_row.sum() 89 | radii = weights_row / norm_cst * max_size 90 | 91 | for particle_ix, (rad, target_ix) in enumerate(zip(radii, p_target)): 92 | if it != n_steps - 2: 93 | diff = particle_ix - target_ix 94 | plt.arrow(spacing * (it + 0.15), target_ix, 1.3, diff, width=0.05, 95 | edgecolor="white", facecolor="tab:gray", length_includes_head=True) 96 | circle = plt.Circle((spacing * it, particle_ix), rad, color="tab:blue") 97 | ax.add_artist(circle) 98 | 99 | plt.xlim(-1, n_steps * spacing - 2) 100 | plt.xlabel("Iteration (t)") 101 | plt.ylabel("Particle index (i)") 102 | 103 | xticks_pos = jnp.arange(0, n_steps * spacing - 2, 2) 104 | xticks_lab = jnp.arange(1, n_steps) 105 | plt.xticks(xticks_pos, xticks_lab) 106 | 107 | # ylims = ax.axes.get_ylim() # to-do: grab this value for SCM-particle descendents' plot 108 | 109 | return fig, ax 110 | 111 | 112 | def plot_smc_weights_unique(hist, n_steps, spacing=1.5, max_size=0.3): 113 | """ 114 | We plot the evolution of particles that have been consistently resampled and form our final 115 | approximation of the target distribution using sequential Monte Carlo (SMC). 116 | 117 | Parameters 118 | ---------- 119 | weights: array(n_particles, n_steps) 120 | Weights at each time step. 121 | n_steps: int 122 | Number of steps to plot. 123 | spacing: float 124 | Spacing between particles. 125 | max_size: float 126 | Maximum size of the particles. 127 | 128 | 129 | Returns 130 | ------- 131 | fig: matplotlib.figure.Figure 132 | Figure containing the plot. 
133 | """ 134 | weights_subset = hist["weights"][:n_steps] 135 | # sampled indices represent the "position" of weights at the next time step 136 | ix_subset = hist["indices"][:n_steps][1:] 137 | ix_path = ix_subset[:n_steps - 2] 138 | ix_map = jnp.repeat(jnp.arange(5)[None, :], 6, axis=0) 139 | ix_path = jnp.stack([ix_path, ix_map], axis=0) 140 | 141 | fig, ax = plt.subplots(figsize=(8, 6)) 142 | ax.set_aspect(1) 143 | 144 | for final_state in range(5): 145 | path = find_path(ix_path, final_state) 146 | path_beg, path_end = path[:-1], path[1:] 147 | for it, (beg, end) in enumerate(zip(path_beg, path_end)): 148 | diff = end - beg 149 | plt.arrow(spacing * (it + 0.15), beg, 1.3, diff, width=0.05, 150 | edgecolor="white", facecolor="tab:gray", alpha=1.0, length_includes_head=True) 151 | 152 | for it, weights_row in enumerate(weights_subset[:-1]): 153 | norm_cst = weights_row.sum() 154 | radii = weights_row / norm_cst * max_size 155 | 156 | for particle_ix, rad in enumerate(radii): 157 | circle = plt.Circle((spacing * it, particle_ix), rad, color="tab:blue") 158 | ax.add_artist(circle) 159 | 160 | plt.xlim(-1, n_steps * spacing - 2) 161 | plt.xlabel("Iteration (t)") 162 | plt.ylabel("Particle index (i)") 163 | 164 | xticks_pos = jnp.arange(0, n_steps * spacing - 2, 2) 165 | xticks_lab = jnp.arange(1, n_steps) 166 | 167 | plt.xticks(xticks_pos, xticks_lab) 168 | 169 | return fig, ax 170 | 171 | 172 | def main(): 173 | params = { 174 | "phi": 0.9, 175 | "q": 1.0, 176 | "beta": 0.5, 177 | "r": 1.0, 178 | } 179 | 180 | key = jax.random.PRNGKey(314) 181 | key_sample, key_sis, key_scm = jax.random.split(key, 3) 182 | seq_model = NonMarkovianSequenceModel(**params) 183 | hist_target = seq_model.sample(key_sample, 100) 184 | observations = hist_target["y"] 185 | 186 | res_sis = seq_model.sequential_importance_sample(key_sis, observations, n_particles=5) 187 | res_smc = seq_model.sequential_monte_carlo(key_scm, observations, n_particles=5) 188 | 189 | # Plot SMC particle evolution 190 | n_steps = 6 + 2 191 | spacing = 2 192 | 193 | dict_figures = {} 194 | 195 | fig, ax = plot_sis_weights(res_sis, n_steps=7, spacing=spacing) 196 | plt.tight_layout() 197 | dict_figures["sis_weights"] = fig 198 | 199 | fig, ax = plot_smc_weights(res_smc, n_steps=n_steps, spacing=spacing) 200 | ylims = ax.axes.get_ylim() 201 | plt.tight_layout() 202 | dict_figures["smc_weights"] = fig 203 | 204 | fig, ax = plot_smc_weights_unique(res_smc, n_steps=n_steps, spacing=spacing) 205 | ax.set_ylim(*ylims) 206 | plt.tight_layout() 207 | dict_figures["smc_weights_unique"] = fig 208 | 209 | return dict_figures 210 | 211 | 212 | if __name__ == "__main__": 213 | from jsl.demos.plot_utils import savefig 214 | figures = main() 215 | savefig(figures) 216 | plt.show() -------------------------------------------------------------------------------- /jsl/demos/superimport_test.py: -------------------------------------------------------------------------------- 1 | import superimport #https://github.com/probml/superimport 2 | 3 | import arviz as az 4 | 5 | from itertools import chain 6 | import jax 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | import blackjax.rmh as rmh 10 | from jax import random 11 | from functools import partial 12 | from jax.scipy.optimize import minimize 13 | from sklearn.datasets import make_biclusters 14 | from ..nlds.extended_kalman_filter import ExtendedKalmanFilter 15 | from jax.scipy.stats import norm 16 | 17 | print('hello world') 
-------------------------------------------------------------------------------- /jsl/demos/ukf_mlp.py: -------------------------------------------------------------------------------- 1 | # Demo showcasing the training of an MLP with a single hidden layer using 2 | # Unscented Kalman Filtering (UKF). 3 | # In this demo, we consider the latent state to be the weights of an MLP. 4 | # The observed state at time t is the output of the MLP as influenced by the weights 5 | # at time t-1 and the covariate x[t]. 6 | # The transition function between latent states is the identity function. 7 | # For more information, see 8 | # * UKF-based training algorithm for feed-forward neural networks with 9 | # application to XOR classification problem 10 | # https://ieeexplore.ieee.org/document/6234549 11 | 12 | import jax.numpy as jnp 13 | from jax import vmap 14 | from jax.random import PRNGKey, split, normal 15 | from jax.flatten_util import ravel_pytree 16 | 17 | import matplotlib.pyplot as plt 18 | from functools import partial 19 | 20 | import jsl.nlds.unscented_kalman_filter as ukf_lib 21 | from jsl.nlds.base import NLDS 22 | from jsl.demos.ekf_mlp import MLP, sample_observations, apply 23 | from jsl.demos.ekf_mlp import plot_mlp_prediction, plot_intermediate_steps, plot_intermediate_steps_single 24 | 25 | 26 | def f(x): 27 | return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3 28 | 29 | 30 | def fz(W): 31 | return W 32 | 33 | 34 | def main(): 35 | key = PRNGKey(314) 36 | key_sample_obs, key_weights, key_init = split(key, 3) 37 | 38 | all_figures = {} 39 | 40 | # *** MLP configuration *** 41 | n_hidden = 6 42 | n_out = 1 43 | n_in = 1 44 | model = MLP([n_hidden, n_out]) 45 | 46 | batch_size = 20 47 | batch = jnp.ones((batch_size, n_in)) 48 | 49 | variables = model.init(key_init, batch) 50 | W0, unflatten_fn = ravel_pytree(variables) 51 | 52 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn) 53 | # vectorised for multiple observations 54 | fwd_mlp_obs = vmap(fwd_mlp, in_axes=[None, 0]) 55 | # vectorised for multiple weights 56 | fwd_mlp_weights = vmap(fwd_mlp, in_axes=[1, None]) 57 | # vectorised for multiple observations and weights 58 | fwd_mlp_obs_weights = vmap(fwd_mlp_obs, in_axes=[0, None]) 59 | 60 | # *** Generating training and test data *** 61 | n_obs = 200 62 | xmin, xmax = -3, 3 63 | sigma_y = 3.0 64 | x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y) 65 | xtest = jnp.linspace(x.min(), x.max(), n_obs) 66 | 67 | # *** MLP Training with UKF *** 68 | n_params = W0.size 69 | W0 = normal(key_weights, (n_params,)) * 1 # initial random guess 70 | Q = jnp.eye(n_params) * 1e-4 # parameters do not change 71 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed 72 | 73 | Vinit = jnp.eye(n_params) * 5 # vague prior 74 | alpha, beta, kappa = 0.01, 2.0, 3.0 - n_params 75 | ukf = NLDS(fz, lambda w, x: fwd_mlp_weights(w, x).T, Q, R, alpha, beta, kappa, n_params) 76 | ukf_mu_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, W0, y, x[:, None], Vinit) 77 | step = -1 78 | W_ukf, SW_ukf = ukf_mu_hist[step], ukf_Sigma_hist[step] 79 | 80 | fig, ax = plt.subplots() 81 | plot_mlp_prediction(key, x, y, xtest, fwd_mlp_obs_weights, W_ukf, SW_ukf, ax) 82 | ax.set_title("UKF + MLP") 83 | all_figures["ukf-mlp"] = fig 84 | 85 | fig, ax = plt.subplots(2, 2) 86 | intermediate_steps = [10, 20, 30, 40, 50, 60] 87 | plot_intermediate_steps(key, ax, fwd_mlp_obs_weights, intermediate_steps, xtest, ukf_mu_hist, ukf_Sigma_hist, x, 88 | y) 89 | plt.suptitle("UKF + MLP 
training") 90 | all_figures["ukf-mlp-intermediate"] = fig 91 | figures_intermediate = plot_intermediate_steps_single(key, "ukf", fwd_mlp_obs_weights, 92 | intermediate_steps, xtest, ukf_mu_hist, ukf_Sigma_hist, x, 93 | y) 94 | all_figures = {**all_figures, **figures_intermediate} 95 | 96 | return all_figures 97 | 98 | 99 | if __name__ == "__main__": 100 | from jsl.demos.plot_utils import savefig 101 | 102 | plt.rcParams["axes.spines.right"] = False 103 | plt.rcParams["axes.spines.top"] = False 104 | figures = main() 105 | savefig(figures) 106 | plt.show() 107 | -------------------------------------------------------------------------------- /jsl/hmm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/hmm/__init__.py -------------------------------------------------------------------------------- /jsl/hmm/hmm_casino_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Simple sanity check for all four Hidden Markov Models' implementations. 3 | ''' 4 | import jax.numpy as jnp 5 | from jax.random import PRNGKey 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | import distrax 11 | from distrax import HMM 12 | 13 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_viterbi_numpy 14 | 15 | from jsl.hmm.hmm_lib import HMMJax, hmm_viterbi_jax, hmm_forwards_backwards_jax 16 | 17 | import jsl.hmm.hmm_logspace_lib as hmm_logspace_lib 18 | 19 | 20 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False): 21 | """ 22 | Plot the estimated smoothing/filtering/map of a sequence of hidden states. 23 | "Vertical gray bars denote times when the hidden 24 | state corresponded to state 1. Blue lines represent the 25 | posterior probability of being in that state given different subsets 26 | of observed data." 
See Markov and Hidden Markov models section for more info 27 | Parameters 28 | ---------- 29 | inference_values: array(n_samples, state_size) 30 | Result of runnig smoothing method 31 | z_hist: array(n_samples) 32 | Latent simulation 33 | ax: matplotlib.axes 34 | state: int 35 | Decide which state to highlight 36 | map_estimate: bool 37 | Whether to plot steps (simple plot if False) 38 | """ 39 | n_samples = len(inference_values) 40 | xspan = np.arange(1, n_samples + 1) 41 | spans = find_dishonest_intervals(z_hist) 42 | if map_estimate: 43 | ax.step(xspan, inference_values, where="post") 44 | else: 45 | ax.plot(xspan, inference_values[:, state]) 46 | 47 | for span in spans: 48 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none") 49 | ax.set_xlim(1, n_samples) 50 | # ax.set_ylim(0, 1) 51 | ax.set_ylim(-0.1, 1.1) 52 | ax.set_xlabel("Observation number") 53 | 54 | 55 | def find_dishonest_intervals(z_hist): 56 | """ 57 | Find the span of timesteps that the 58 | simulated systems turns to be in state 1 59 | Parameters 60 | ---------- 61 | z_hist: array(n_samples) 62 | Result of running the system with two 63 | latent states 64 | Returns 65 | ------- 66 | list of tuples with span of values 67 | """ 68 | spans = [] 69 | x_init = 0 70 | for t, _ in enumerate(z_hist[:-1]): 71 | if z_hist[t + 1] == 0 and z_hist[t] == 1: 72 | x_end = t 73 | spans.append((x_init, x_end)) 74 | elif z_hist[t + 1] == 1 and z_hist[t] == 0: 75 | x_init = t + 1 76 | return spans 77 | 78 | 79 | # state transition matrix 80 | A = jnp.array([ 81 | [0.95, 0.05], 82 | [0.10, 0.90] 83 | ]) 84 | 85 | # observation matrix 86 | B = jnp.array([ 87 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die 88 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die 89 | ]) 90 | 91 | n_samples = 300 92 | init_state_dist = jnp.array([1, 1]) / 2 93 | hmm_numpy = HMMNumpy(np.array(A), np.array(B), np.array(init_state_dist)) 94 | hmm_jax = HMMJax(A, B, init_state_dist) 95 | hmm = HMM(trans_dist=distrax.Categorical(probs=A), 96 | init_dist=distrax.Categorical(probs=init_state_dist), 97 | obs_dist=distrax.Categorical(probs=B)) 98 | hmm_log = hmm_logspace_lib.HMM(trans_dist=distrax.Categorical(probs=A), 99 | init_dist=distrax.Categorical(probs=init_state_dist), 100 | obs_dist=distrax.Categorical(probs=B)) 101 | 102 | seed = 314 103 | z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples) 104 | 105 | z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60] 106 | x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60] 107 | 108 | print("Printing sample observed/latent...") 109 | print(f"x: {x_hist_str}") 110 | print(f"z: {z_hist_str}") 111 | 112 | # Do inference 113 | alpha_numpy, _, gamma_numpy, loglik_numpy = hmm_forwards_backwards_numpy(hmm_numpy, 114 | np.array(x_hist), 115 | len(x_hist)) 116 | alpha_jax, _, gamma_jax, loglik_jax = hmm_forwards_backwards_jax(hmm_jax, 117 | x_hist, 118 | len(x_hist)) 119 | 120 | alpha_log, _, gamma_log, loglik_log = hmm_logspace_lib.hmm_forwards_backwards_log(hmm_log, 121 | x_hist, 122 | len(x_hist)) 123 | alpha, beta, gamma, loglik = hmm.forward_backward(x_hist) 124 | 125 | assert np.allclose(alpha_numpy, alpha) 126 | assert np.allclose(alpha_jax, alpha) 127 | assert np.allclose(jnp.exp(alpha_log), alpha) 128 | 129 | assert np.allclose(gamma_numpy, gamma) 130 | assert np.allclose(gamma_jax, gamma) 131 | assert np.allclose(jnp.exp(gamma_log), gamma) 132 | 133 | print(f"Loglikelihood(Distrax): {loglik}") 134 | print(f"Loglikelihood(Numpy): {loglik_numpy}") 135 | 
print(f"Loglikelihood(Jax): {loglik_jax}") 136 | print(f"Loglikelihood(Jax): {loglik_log}") 137 | 138 | z_map_numpy = hmm_viterbi_numpy(hmm_numpy, np.array(x_hist)) 139 | z_map_jax = hmm_viterbi_jax(hmm_jax, x_hist) 140 | z_map_log = hmm_logspace_lib.hmm_viterbi_log(hmm_log, x_hist) 141 | z_map = hmm.viterbi(x_hist) 142 | 143 | assert np.allclose(z_map_numpy, z_map) 144 | assert np.allclose(z_map_jax, z_map) 145 | assert np.allclose(z_map_log, z_map) 146 | 147 | # Plot results 148 | fig, ax = plt.subplots() 149 | plot_inference(gamma_numpy, z_hist, ax) 150 | ax.set_ylabel("p(loaded)") 151 | ax.set_title("Smoothed") 152 | plt.savefig("hmm_casino_smooth_numpy.png") 153 | plt.show() 154 | 155 | fig, ax = plt.subplots() 156 | plot_inference(z_map_numpy, z_hist, ax, map_estimate=True) 157 | ax.set_ylabel("MAP state") 158 | ax.set_title("Viterbi") 159 | plt.savefig("hmm_casino_map_numpy.png") 160 | plt.show() 161 | 162 | # Plot results 163 | fig, ax = plt.subplots() 164 | plot_inference(gamma, z_hist, ax) 165 | ax.set_ylabel("p(loaded)") 166 | ax.set_title("Smoothed") 167 | # plt.savefig("hmm_casino_smooth_distrax.png") 168 | plt.show() 169 | 170 | fig, ax = plt.subplots() 171 | plot_inference(z_map, z_hist, ax, map_estimate=True) 172 | ax.set_ylabel("MAP state") 173 | ax.set_title("Viterbi") 174 | # plt.savefig("hmm_casino_map_distrax.png") 175 | plt.show() 176 | 177 | # Plot results 178 | fig, ax = plt.subplots() 179 | plot_inference(gamma_jax, z_hist, ax) 180 | ax.set_ylabel("p(loaded)") 181 | ax.set_title("Smoothed") 182 | # plt.savefig("hmm_casino_smooth_jax.png") 183 | plt.show() 184 | 185 | fig, ax = plt.subplots() 186 | plot_inference(z_map_jax, z_hist, ax, map_estimate=True) 187 | ax.set_ylabel("MAP state") 188 | ax.set_title("Viterbi") 189 | # plt.savefig("hmm_casino_map_jax.png") 190 | plt.show() 191 | 192 | # Plot results 193 | fig, ax = plt.subplots() 194 | plot_inference(jnp.exp(gamma_log), z_hist, ax) 195 | ax.set_ylabel("p(loaded)") 196 | ax.set_title("Smoothed") 197 | # plt.savefig("hmm_casino_smooth_log.png") 198 | plt.show() 199 | 200 | fig, ax = plt.subplots() 201 | plot_inference(z_map_log, z_hist, ax, map_estimate=True) 202 | ax.set_ylabel("MAP state") 203 | ax.set_title("Viterbi") 204 | # plt.savefig("hmm_casino_map_log.png") 205 | plt.show() 206 | -------------------------------------------------------------------------------- /jsl/hmm/hmm_lib_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This demo compares the Jax, Numpy and Distrax version of forwards-backwards algorithm in terms of the speed. 3 | Also, checks whether or not they give the same result. 
4 | Author : Aleyna Kara (@karalleyna) 5 | ''' 6 | 7 | import jax.numpy as jnp 8 | from jax import vmap, nn 9 | from jax.random import split, PRNGKey, uniform, normal 10 | 11 | import distrax 12 | from distrax import HMM 13 | 14 | import chex 15 | 16 | import numpy as np 17 | import time 18 | 19 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_loglikelihood_numpy 20 | from jsl.hmm.hmm_lib import HMMJax, hmm_viterbi_jax 21 | from jsl.hmm.hmm_lib import hmm_sample_jax, hmm_forwards_backwards_jax, hmm_loglikelihood_jax 22 | from jsl.hmm.hmm_lib import normalize, fixed_lag_smoother 23 | import jsl.hmm.hmm_utils as hmm_utils 24 | 25 | 26 | from tensorflow_probability.substrates import jax as tfp 27 | 28 | tfd = tfp.distributions 29 | 30 | ####### 31 | # Test log likelihood 32 | 33 | def loglikelihood_numpy(params_numpy, batches, lens): 34 | return np.array([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)]) 35 | 36 | def loglikelihood_jax(params_jax, batches, lens): 37 | return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)[:,:, 0] 38 | 39 | 40 | def test_all_hmm_models(): 41 | # state transition matrix 42 | A = jnp.array([ 43 | [0.95, 0.05], 44 | [0.10, 0.90] 45 | ]) 46 | 47 | # observation matrix 48 | B = jnp.array([ 49 | [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die 50 | [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die 51 | ]) 52 | 53 | pi = jnp.array([1, 1]) / 2 54 | 55 | params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi)) 56 | params_jax = HMMJax(A, B, pi) 57 | 58 | seed = 0 59 | rng_key = PRNGKey(seed) 60 | rng_key, rng_sample = split(rng_key) 61 | 62 | n_obs_seq, batch_size, max_len = 15, 5, 10 63 | 64 | observations, lens = hmm_utils.hmm_sample_n(params_jax, 65 | hmm_sample_jax, 66 | n_obs_seq, max_len, 67 | rng_sample) 68 | 69 | observations, lens = hmm_utils.pad_sequences(observations, lens) 70 | 71 | rng_key, rng_batch = split(rng_key) 72 | batches, lens = hmm_utils.hmm_sample_minibatches(observations, 73 | lens, 74 | batch_size, 75 | rng_batch) 76 | 77 | ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens)) 78 | ll_jax = loglikelihood_jax(params_jax, batches, lens) 79 | assert np.allclose(ll_numpy, ll_jax, atol=4) 80 | 81 | 82 | def test_inference(): 83 | seed = 0 84 | rng_key = PRNGKey(seed) 85 | rng_key, key_A, key_B = split(rng_key, 3) 86 | 87 | # state transition matrix 88 | n_hidden, n_obs = 100, 10 89 | A = uniform(key_A, (n_hidden, n_hidden)) 90 | A = A / jnp.sum(A, axis=1) 91 | 92 | # observation matrix 93 | B = uniform(key_B, (n_hidden, n_obs)) 94 | B = B / jnp.sum(B, axis=1).reshape((-1, 1)) 95 | 96 | n_samples = 1000 97 | init_state_dist = jnp.ones(n_hidden) / n_hidden 98 | 99 | seed = 0 100 | rng_key = PRNGKey(seed) 101 | 102 | params_numpy = HMMNumpy(A, B, init_state_dist) 103 | params_jax = HMMJax(A, B, init_state_dist) 104 | hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A), 105 | obs_dist=distrax.Categorical(probs=B), 106 | init_dist=distrax.Categorical(probs=init_state_dist)) 107 | 108 | z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key) 109 | 110 | start = time.time() 111 | alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist)) 112 | print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s') 113 | 114 | start = time.time() 115 | alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), 
len(x_hist)) 116 | print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s') 117 | 118 | start = time.time() 119 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist), 120 | length=len(x_hist)) 121 | 122 | print(f'Time taken by HMM distrax : {time.time()-start}s') 123 | 124 | assert np.allclose(alphas_np, alphas_jax) 125 | assert np.allclose(loglikelihood_np, loglikelihood_jax) 126 | assert np.allclose(gammas_np, gammas_jax) 127 | 128 | assert np.allclose(alphas, alphas_jax, atol=8) 129 | assert np.allclose(loglikelihood, loglikelihood_jax, atol=8) 130 | assert np.allclose(gammas, gammas_jax, atol=8) 131 | 132 | 133 | def _make_models(init_probs, trans_probs, obs_probs, length): 134 | """Build distrax HMM and equivalent TFP HMM.""" 135 | 136 | dx_model = HMMJax( 137 | trans_probs, 138 | obs_probs, 139 | init_probs 140 | ) 141 | 142 | tfp_model = tfd.HiddenMarkovModel( 143 | initial_distribution=tfd.Categorical(probs=init_probs), 144 | transition_distribution=tfd.Categorical(probs=trans_probs), 145 | observation_distribution=tfd.Categorical(probs=obs_probs), 146 | num_steps=length, 147 | ) 148 | 149 | return dx_model, tfp_model 150 | 151 | 152 | def test_sample(length, num_states): 153 | params_fn = obs_dist_name_and_params_fn 154 | 155 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1) 156 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1) 157 | 158 | model, tfp_model = _make_models(init_probs, 159 | trans_mat, 160 | params_fn(num_states), 161 | length) 162 | 163 | states, obs = hmm_sample_jax(model, length, PRNGKey(0)) 164 | tfp_obs = tfp_model.sample(seed=PRNGKey(0)) 165 | 166 | chex.assert_shape(states, (length,)) 167 | chex.assert_equal_shape([obs, tfp_obs]) 168 | 169 | 170 | def test_forward_backward(length, num_states): 171 | params_fn = obs_dist_name_and_params_fn 172 | 173 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1) 174 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1) 175 | 176 | model, tfp_model = _make_models(init_probs, 177 | trans_mat, 178 | params_fn(num_states), 179 | length) 180 | 181 | _, observations = hmm_sample_jax(model, length, PRNGKey(42)) 182 | 183 | alphas, betas, marginals, log_prob = hmm_forwards_backwards_jax(model, 184 | observations) 185 | 186 | tfp_marginal_logits = tfp_model.posterior_marginals(observations).logits 187 | tfp_marginals = nn.softmax(tfp_marginal_logits) 188 | 189 | chex.assert_shape(alphas, (length, num_states)) 190 | chex.assert_shape(betas, (length, num_states)) 191 | chex.assert_shape(marginals, (length, num_states)) 192 | chex.assert_shape(log_prob, (1,)) 193 | np.testing.assert_array_almost_equal(marginals, tfp_marginals, decimal=4) 194 | 195 | 196 | def test_viterbi(length, num_states): 197 | params_fn = obs_dist_name_and_params_fn 198 | 199 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1) 200 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1) 201 | 202 | model, tfp_model = _make_models(init_probs, 203 | trans_mat, 204 | params_fn(num_states), 205 | length) 206 | 207 | _, observations = hmm_sample_jax(model, length, PRNGKey(42)) 208 | most_likely_states = hmm_viterbi_jax(model, observations) 209 | tfp_mode = tfp_model.posterior_mode(observations) 210 | chex.assert_shape(most_likely_states, (length,)) 211 | assert np.allclose(most_likely_states, tfp_mode) 212 | 213 | ''' 214 | ######## 215 | #Test Fixed Lag Smoother 216 | 217 | 
# helper function 218 | def get_fls_result(params, data, win_len, act=None): 219 | assert data.size > 2, "Complete observation set must be of size at least 2" 220 | prior, obs_mat = params.init_dist, params.obs_mat 221 | n_states = obs_mat.shape[0] 222 | alpha, _ = normalize(prior * obs_mat[:, data[0]]) 223 | bmatrix = jnp.eye(n_states)[None, :] 224 | for obs in data[1:]: 225 | alpha, bmatrix, gamma = fixed_lag_smoother(params, win_len, alpha, bmatrix, obs) 226 | return alpha, gamma 227 | 228 | *_, gammas_fls = get_fls_result(params_jax, jnp.array(x_hist), jnp.array(x_hist).size) 229 | 230 | assert np.allclose(gammas_fls, gammas_jax) 231 | ''' 232 | obs_dist_name_and_params_fn = lambda n: nn.softmax(normal(PRNGKey(0), (n, 7)), axis=-1) 233 | 234 | ### Tests 235 | test_all_hmm_models() 236 | test_inference() 237 | 238 | for length, num_states in zip([1, 3], (2, 23)): 239 | test_viterbi(length, num_states) 240 | test_forward_backward(length, num_states) 241 | test_sample(length, num_states) 242 | 243 | 244 | -------------------------------------------------------------------------------- /jsl/hmm/hmm_logspace_lib_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This demo compares the log space version of Hidden Markov Model for discrete observations and general hidden markov model 3 | in terms of the speed. It also checks whether or not the inference algorithms give the same result. 4 | Author : Aleyna Kara (@karalleyna) 5 | ''' 6 | 7 | import time 8 | import distrax 9 | import jax.numpy as jnp 10 | from jax.random import PRNGKey, split, uniform 11 | import numpy as np 12 | from hmm_logspace_lib import HMM, hmm_forwards_backwards_log, hmm_viterbi_log, hmm_sample_log 13 | 14 | seed = 0 15 | rng_key = PRNGKey(seed) 16 | rng_key, key_A, key_B = split(rng_key, 3) 17 | 18 | # state transition matrix 19 | n_hidden, n_obs = 100, 10 20 | A = uniform(key_A, (n_hidden, n_hidden)) 21 | A = A / jnp.sum(A, axis=1) 22 | 23 | # observation matrix 24 | B = uniform(key_B, (n_hidden, n_obs)) 25 | B = B / jnp.sum(B, axis=1).reshape((-1, 1)) 26 | 27 | n_samples = 1000 28 | init_state_dist = jnp.ones(n_hidden) / n_hidden 29 | 30 | seed = 0 31 | rng_key = PRNGKey(seed) 32 | 33 | hmm = HMM(trans_dist=distrax.Categorical(probs=A), 34 | obs_dist=distrax.Categorical(probs=B), 35 | init_dist=distrax.Categorical(probs=init_state_dist)) 36 | 37 | hmm_distrax = distrax.HMM(trans_dist=distrax.Categorical(probs=A), 38 | obs_dist=distrax.Categorical(probs=B), 39 | init_dist=distrax.Categorical(probs=init_state_dist)) 40 | 41 | z_hist, x_hist = hmm_sample_log(hmm, n_samples, rng_key) 42 | 43 | start = time.time() 44 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(x_hist, len(x_hist)) 45 | print(f'Time taken by Forwards Backwards function of HMM general: {time.time() - start}s') 46 | print(f'Loglikelihood found by HMM general: {loglikelihood}') 47 | 48 | start = time.time() 49 | alphas_log, _, gammas_log, loglikelihood_log = hmm_forwards_backwards_log(hmm, x_hist, len(x_hist)) 50 | print(f'Time taken by Forwards Backwards function of HMM Log Space Version: {time.time() - start}s') 51 | print(f'Loglikelihood found by HMM General Log Space Version: {loglikelihood_log}') 52 | 53 | assert np.allclose(jnp.log(alphas), alphas_log, 8) 54 | assert np.allclose(loglikelihood, loglikelihood_log) 55 | assert np.allclose(jnp.log(gammas), gammas_log, 8) 56 | 57 | # Test for the hmm_viterbi_log. 
This test is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm_test.py 58 | loc = jnp.array([0.0, 1.0, 2.0, 3.0]) 59 | scale = jnp.array(0.25) 60 | initial = jnp.array([0.25, 0.25, 0.25, 0.25]) 61 | trans = jnp.array([[0.9, 0.1, 0.0, 0.0], 62 | [0.1, 0.8, 0.1, 0.0], 63 | [0.0, 0.1, 0.8, 0.1], 64 | [0.0, 0.0, 0.1, 0.9]]) 65 | 66 | observations = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 3.0, 2.9, 2.8, 2.7, 2.6]) 67 | 68 | model = HMM( 69 | init_dist=distrax.Categorical(probs=initial), 70 | trans_dist=distrax.Categorical(probs=trans), 71 | obs_dist=distrax.Normal(loc, scale)) 72 | 73 | inferred_states = hmm_viterbi_log(model, observations) 74 | expected_states = [0, 0, 0, 0, 1, 2, 3, 3, 3, 3] 75 | 76 | assert np.allclose(inferred_states, expected_states) 77 | 78 | length = 7 79 | inferred_states = hmm_viterbi_log(model, observations, length) 80 | expected_states = [0, 0, 0, 0, 1, 2, 3, -1, -1, -1] 81 | assert np.allclose(inferred_states, expected_states) 82 | -------------------------------------------------------------------------------- /jsl/hmm/hmm_utils.py: -------------------------------------------------------------------------------- 1 | # Common functions that can be used for any hidden markov model type. 2 | # Author: Aleyna Kara(@karalleyna) 3 | 4 | import jax.numpy as jnp 5 | import matplotlib.pyplot as plt 6 | from jax import vmap, jit 7 | from jax.random import split, randint, PRNGKey, permutation 8 | from functools import partial 9 | # !pip install graphviz 10 | from graphviz import Digraph 11 | 12 | 13 | @partial(jit, static_argnums=(2,)) 14 | def hmm_sample_minibatches(observations, valid_lens, batch_size, rng_key): 15 | ''' 16 | Creates minibatches consists of the random permutations of the 17 | given observation sequences 18 | 19 | Parameters 20 | ---------- 21 | observations : array(N, seq_len) 22 | All observation sequences 23 | 24 | valid_lens : array(N, seq_len) 25 | Consists of the valid length of each observation sequence 26 | 27 | batch_size : int 28 | The number of observation sequences that will be included in 29 | each minibatch 30 | 31 | rng_key : array 32 | Random key of shape (2,) and dtype uint32 33 | 34 | Returns 35 | ------- 36 | * array(num_batches, batch_size, max_len) 37 | Minibatches 38 | ''' 39 | num_train = len(observations) 40 | perm = permutation(rng_key, num_train) 41 | 42 | def create_mini_batch(batch_idx): 43 | return observations[batch_idx], valid_lens[batch_idx] 44 | 45 | num_batches = num_train // batch_size 46 | batch_indices = perm.reshape((num_batches, -1)) 47 | minibatches = vmap(create_mini_batch)(batch_indices) 48 | return minibatches 49 | 50 | 51 | @partial(jit, static_argnums=(1, 2, 3)) 52 | def hmm_sample_n(params, sample_fn, n, max_len, rng_key): 53 | ''' 54 | Generates n observation sequences from the given Hidden Markov Model 55 | 56 | Parameters 57 | ---------- 58 | params : HMMNumpy or HMMJax 59 | Hidden Markov Model 60 | 61 | sample_fn : 62 | The sample function of the given hidden markov model 63 | 64 | n : int 65 | The total number of observation sequences 66 | 67 | max_len : int 68 | The upper bound of the length of each observation sequence. Note that the valid length of the observation 69 | sequence is less than or equal to the upper bound. 
70 | 71 | rng_key : array 72 | Random key of shape (2,) and dtype uint32 73 | 74 | Returns 75 | ------- 76 | * array(n, max_len) 77 | Observation sequences 78 | ''' 79 | 80 | def sample_(params, n_samples, key): 81 | return sample_fn(params, n_samples, key)[1] 82 | 83 | rng_key, rng_lens = split(rng_key) 84 | lens = randint(rng_lens, (n,), minval=1, maxval=max_len + 1) 85 | keys = split(rng_key, n) 86 | observations = vmap(sample_, in_axes=(None, None, 0))(params, max_len, keys) 87 | return observations, lens 88 | 89 | 90 | @jit 91 | def pad_sequences(observations, valid_lens, pad_val=0): 92 | ''' 93 | Generates n observation sequences from the given Hidden Markov Model 94 | 95 | Parameters 96 | ---------- 97 | 98 | observations : array(N, seq_len) 99 | All observation sequences 100 | 101 | valid_lens : array(N, seq_len) 102 | Consists of the valid length of each observation sequence 103 | 104 | pad_val : int 105 | Value that the invalid observable events of the observation sequence will be replaced 106 | 107 | Returns 108 | ------- 109 | * array(n, max_len) 110 | Ragged dataset 111 | ''' 112 | 113 | def pad(seq, len): 114 | idx = jnp.arange(1, seq.shape[0] + 1) 115 | return jnp.where(idx <= len, seq, pad_val) 116 | 117 | ragged_dataset = vmap(pad, in_axes=(0, 0))(observations, valid_lens), valid_lens 118 | return ragged_dataset 119 | 120 | 121 | def hmm_plot_graphviz(trans_mat, obs_mat, states=[], observations=[]): 122 | """ 123 | Visualizes HMM transition matrix and observation matrix using graphhiz. 124 | 125 | Parameters 126 | ---------- 127 | trans_mat, obs_mat, init_dist: arrays 128 | 129 | states: List(num_hidden) 130 | Names of hidden states 131 | 132 | observations: List(num_obs) 133 | Names of observable events 134 | 135 | Returns 136 | ------- 137 | dot object, that can be displayed in colab 138 | """ 139 | 140 | n_states, n_obs = obs_mat.shape 141 | 142 | dot = Digraph(comment='HMM') 143 | if not states: 144 | states = [f'State {i + 1}' for i in range(n_states)] 145 | if not observations: 146 | observations = [f'Obs {i + 1}' for i in range(n_obs)] 147 | 148 | # Creates hidden state nodes 149 | for i, name in enumerate(states): 150 | table = [f'{observations[j]}{"%.2f" % prob}' for j, prob in 151 | enumerate(obs_mat[i])] 152 | label = f'''<{''.join(table)}
{name}
>''' 153 | dot.node(f's{i}', label=label) 154 | 155 | # Writes transition probabilities 156 | for i in range(n_states): 157 | for j in range(n_states): 158 | dot.edge(f's{i}', f's{j}', label=str('%.2f' % trans_mat[i, j])) 159 | dot.attr(rankdir='LR') 160 | # dot.render(file_name, view=True) 161 | return dot 162 | -------------------------------------------------------------------------------- /jsl/hmm/old/hmm_discrete_lib_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This demo compares the Jax, Numpy and Distrax version of forwards-backwards algorithm in terms of the speed. 3 | Also, checks whether or not they give the same result. 4 | Author : Aleyna Kara (@karalleyna) 5 | ''' 6 | 7 | import superimport 8 | 9 | import time 10 | 11 | import jax.numpy as jnp 12 | from jax.random import PRNGKey, split, uniform 13 | import numpy as np 14 | 15 | from hmm_discrete_lib import HMMJax, HMMNumpy 16 | from hmm_discrete_lib import hmm_sample_jax, hmm_forwards_backwards_jax, hmm_forwards_backwards_numpy 17 | 18 | import distrax 19 | from distrax import HMM 20 | 21 | seed = 0 22 | rng_key = PRNGKey(seed) 23 | rng_key, key_A, key_B = split(rng_key, 3) 24 | 25 | # state transition matrix 26 | n_hidden, n_obs = 100, 10 27 | A = uniform(key_A, (n_hidden, n_hidden)) 28 | A = A / jnp.sum(A, axis=1) 29 | 30 | # observation matrix 31 | B = uniform(key_B, (n_hidden, n_obs)) 32 | B = B / jnp.sum(B, axis=1).reshape((-1, 1)) 33 | 34 | n_samples = 1000 35 | init_state_dist = jnp.ones(n_hidden) / n_hidden 36 | 37 | seed = 0 38 | rng_key = PRNGKey(seed) 39 | 40 | params_numpy = HMMNumpy(A, B, init_state_dist) 41 | params_jax = HMMJax(A, B, init_state_dist) 42 | hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A), 43 | obs_dist=distrax.Categorical(probs=B), 44 | init_dist=distrax.Categorical(probs=init_state_dist)) 45 | 46 | z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key) 47 | 48 | start = time.time() 49 | alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist)) 50 | print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s') 51 | 52 | start = time.time() 53 | alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), len(x_hist)) 54 | print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s') 55 | 56 | start = time.time() 57 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist), 58 | length=len(x_hist)) 59 | 60 | print(f'Time taken by HMM distrax : {time.time()-start}s') 61 | 62 | assert np.allclose(alphas_np, alphas_jax) 63 | assert np.allclose(loglikelihood_np, loglikelihood_jax) 64 | assert np.allclose(gammas_np, gammas_jax) 65 | 66 | assert np.allclose(alphas, alphas_jax, 8) 67 | assert np.allclose(loglikelihood, loglikelihood_jax) 68 | assert np.allclose(gammas, gammas_jax, 8) -------------------------------------------------------------------------------- /jsl/hmm/old/hmm_discrete_likelihood_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This demo shows how to create variable size dataset and then it creates mini-batches from this dataset so that it 3 | calculates the likelihood of each observation sequence in every batch using vmap. 
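The batched likelihood itself is a single vmap over minibatches, sketched as:

    # params is shared across batches (in_axes=None); batches and lens are mapped over
    # their leading axis
    ll = vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)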
4 | Author : Aleyna Kara (@karalleyna) 5 | ''' 6 | 7 | 8 | from jax import vmap, jit 9 | from jax.random import split, randint, PRNGKey 10 | import jax.numpy as jnp 11 | 12 | import hmm_utils 13 | from hmm_discrete_lib import hmm_sample_jax, hmm_loglikelihood_numpy, hmm_loglikelihood_jax 14 | from hmm_discrete_lib import HMMNumpy, HMMJax 15 | import numpy as np 16 | 17 | def loglikelihood_numpy(params_numpy, batches, lens): 18 | return np.vstack([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)]) 19 | 20 | def loglikelihood_jax(params_jax, batches, lens): 21 | return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens) 22 | 23 | # state transition matrix 24 | A = jnp.array([ 25 | [0.95, 0.05], 26 | [0.10, 0.90] 27 | ]) 28 | 29 | # observation matrix 30 | B = jnp.array([ 31 | [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die 32 | [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die 33 | ]) 34 | 35 | pi = jnp.array([1, 1]) / 2 36 | 37 | params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi)) 38 | params_jax = HMMJax(A, B, pi) 39 | 40 | seed = 0 41 | rng_key = PRNGKey(seed) 42 | rng_key, rng_sample = split(rng_key) 43 | 44 | n_obs_seq, batch_size, max_len = 15, 5, 10 45 | 46 | observations, lens = hmm_utils.hmm_sample_n(params_jax, 47 | hmm_sample_jax, 48 | n_obs_seq, max_len, 49 | rng_sample) 50 | 51 | observations, lens = hmm_utils.pad_sequences(observations, lens) 52 | 53 | rng_key, rng_batch = split(rng_key) 54 | batches, lens = hmm_utils.hmm_sample_minibatches(observations, 55 | lens, 56 | batch_size, 57 | rng_batch) 58 | 59 | ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens)) 60 | ll_jax = loglikelihood_jax(params_jax, batches, lens) 61 | 62 | assert np.allclose(ll_numpy, ll_jax) 63 | print(f'Loglikelihood {ll_numpy}') -------------------------------------------------------------------------------- /jsl/hmm/old/hmm_sgd_lib.py: -------------------------------------------------------------------------------- 1 | # Trains hidden markov models with discrete observations using Gradient Descent in a stateless way. 2 | # Author : Aleyna Kara(@karalleyna) 3 | 4 | 5 | import jax 6 | import itertools 7 | from jax import jit 8 | from jax.nn import softmax 9 | from jax.random import PRNGKey, split, normal 10 | from jsl.hmm.hmm_utils import hmm_sample_minibatches 11 | from jsl.hmm.hmm_discrete_lib import HMMJax, hmm_loglikelihood_jax 12 | 13 | opt_init, opt_update, get_params = None, None, None 14 | 15 | def init_random_params(sizes, rng_key): 16 | """ 17 | Initializes the components of HMM from normal distibution 18 | 19 | Parameters 20 | ---------- 21 | sizes: List 22 | Consists of number of hidden states and observable events, respectively 23 | 24 | rng_key : array 25 | Random key of shape (2,) and dtype uint32 26 | 27 | Returns 28 | ------- 29 | * array(num_hidden, num_hidden) 30 | Transition probability matrix 31 | 32 | * array(num_hidden, num_obs) 33 | Emission probability matrix 34 | 35 | * array(1, num_hidden) 36 | Initial distribution probabilities 37 | """ 38 | num_hidden, num_obs = sizes 39 | rng_key, rng_a, rng_b, rng_pi = split(rng_key, 4) 40 | return HMMJax(normal(rng_a, (num_hidden, num_hidden)), 41 | normal(rng_b, (num_hidden, num_obs)), 42 | normal(rng_pi, (num_hidden,))) 43 | 44 | 45 | @jit 46 | def loss_fn(params, batch, lens): 47 | """ 48 | Objective function of hidden markov models for discrete observations. 
It returns the mean of the negative 49 | loglikelihood of the sequence of observations 50 | 51 | Parameters 52 | ---------- 53 | params : HMMJax 54 | Hidden Markov Model 55 | 56 | batch: array(N, max_len) 57 | Minibatch consisting of observation sequences 58 | 59 | lens : array(N, seq_len) 60 | Consists of the valid length of each observation sequence in the minibatch 61 | 62 | Returns 63 | ------- 64 | * float 65 | The mean negative loglikelihood of the minibatch 66 | """ 67 | params_soft = HMMJax(softmax(params.trans_mat, axis=1), 68 | softmax(params.obs_mat, axis=1), 69 | softmax(params.init_dist)) 70 | return -hmm_loglikelihood_jax(params_soft, batch, lens).mean() 71 | 72 | 73 | @jit 74 | def update(i, opt_state, batch, lens): 75 | """ 76 | Objective function of hidden markov models for discrete observations. It returns the mean of the negative 77 | loglikelihood of the sequence of observations 78 | 79 | Parameters 80 | ---------- 81 | i : int 82 | Specifies the current iteration 83 | 84 | opt_state : OptimizerState 85 | 86 | batch: array(N, max_len) 87 | Minibatch consisting of observation sequences 88 | 89 | lens : array(N, seq_len) 90 | Consists of the valid length of each observation sequence in the minibatch 91 | 92 | Returns 93 | ------- 94 | * OptimizerState 95 | 96 | * float 97 | The mean negative loglikelihood of the minibatch, i.e. loss value for the current iteration. 98 | """ 99 | params = get_params(opt_state) 100 | loss, grads = jax.value_and_grad(loss_fn)(params, batch, lens) 101 | return opt_update(i, grads, opt_state), loss 102 | 103 | 104 | def fit(observations, lens, num_hidden, num_obs, batch_size, optimizer,rng_key=None, num_epochs=1): 105 | """ 106 | Trains the HMM model with the given number of hidden states and observations via any optimizer. 
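    A hypothetical call for the two-state casino model (the optimizer triple comes from the
    legacy jax.experimental.optimizers API that this module unpacks; newer JAX releases
    expose the same interface under jax.example_libraries.optimizers):

        from jax.experimental import optimizers

        params, losses = fit(observations, lens,
                             num_hidden=2, num_obs=6,
                             batch_size=5,
                             optimizer=optimizers.adam(1e-2),
                             num_epochs=100)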
107 | 108 | Parameters 109 | ---------- 110 | observations: array(N, seq_len) 111 | All observation sequences 112 | 113 | lens : array(N, seq_len) 114 | Consists of the valid length of each observation sequence 115 | 116 | num_hidden : int 117 | The number of hidden state 118 | 119 | num_obs : int 120 | The number of observable events 121 | 122 | batch_size : int 123 | The number of observation sequences that will be included in each minibatch 124 | 125 | optimizer : jax.experimental.optimizers.Optimizer 126 | Optimizer that is used during training 127 | 128 | num_epochs : int 129 | The total number of iterations 130 | 131 | Returns 132 | ------- 133 | * HMMJax 134 | Hidden Markov Model 135 | 136 | * array 137 | Consists of training losses 138 | """ 139 | global opt_init, opt_update, get_params 140 | 141 | if rng_key is None: 142 | rng_key = PRNGKey(0) 143 | 144 | rng_init, rng_iter = split(rng_key) 145 | params = init_random_params([num_hidden, num_obs], rng_init) 146 | opt_init, opt_update, get_params = optimizer 147 | opt_state = opt_init(params) 148 | itercount = itertools.count() 149 | 150 | def epoch_step(opt_state, key): 151 | 152 | def train_step(opt_state, params): 153 | batch, length = params 154 | opt_state, loss = update(next(itercount), opt_state, batch, length) 155 | return opt_state, loss 156 | 157 | batches, valid_lens = hmm_sample_minibatches(observations, lens, batch_size, key) 158 | params = (batches, valid_lens) 159 | opt_state, losses = jax.lax.scan(train_step, opt_state, params) 160 | return opt_state, losses.mean() 161 | 162 | epochs = split(rng_iter, num_epochs) 163 | opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs) 164 | 165 | losses = losses.flatten() 166 | 167 | params = get_params(opt_state) 168 | params = HMMJax(softmax(params.trans_mat, axis=1), 169 | softmax(params.obs_mat, axis=1), 170 | softmax(params.init_dist)) 171 | return params, losses 172 | -------------------------------------------------------------------------------- /jsl/hmm/sparse_lib.py: -------------------------------------------------------------------------------- 1 | """ 2 | jax.experimental.sparse-compatible Hidden Markov Model (HMM) 3 | """ 4 | import jax 5 | import chex 6 | from functools import partial 7 | from typing import Callable, Tuple, Dict 8 | 9 | 10 | def alpha_step(alpha_prev, y, local_evidence_multiple, transition_matrix): 11 | local_evidence = local_evidence_multiple(y) 12 | alpha_next = local_evidence * (transition_matrix.T @ alpha_prev) 13 | normalisation_cst = alpha_next.sum() 14 | alpha_next = alpha_next / normalisation_cst 15 | 16 | carry = { 17 | "alpha": alpha_next, 18 | "cst": normalisation_cst 19 | } 20 | 21 | return alpha_next, carry 22 | 23 | 24 | def beta_step(beta_next, y, local_evidence_multiple, transition_matrix): 25 | norm_cst = beta_next.sum() 26 | local_evidence = local_evidence_multiple(y) 27 | beta_prev = transition_matrix @ (local_evidence * beta_next) 28 | beta_prev = beta_prev / norm_cst 29 | 30 | carry = { 31 | "beta": beta_prev, 32 | "cst": norm_cst 33 | } 34 | 35 | return beta_prev, carry 36 | 37 | 38 | def alpha_forward(obs: chex.Array, 39 | local_evidence: Callable[[chex.Array], chex.Array], 40 | transition_matrix: chex.Array, 41 | alpha_init: chex.Array) -> Tuple[chex.Array, chex.Array]: 42 | """ 43 | Compute the alpha-forward (forward-filter) pass for a given observation sequence. 
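    A small illustrative sketch for a two-state discrete HMM (the transition matrix,
    emission matrix and observations below are made up):

        import jax.numpy as jnp

        T = jnp.array([[0.9, 0.1],
                       [0.2, 0.8]])            # transition_matrix
        B = jnp.array([[0.8, 0.2],
                       [0.3, 0.7]])            # per-state emission probabilities
        local_evidence = lambda y: B[:, y]     # p(y | state) for every state
        obs = jnp.array([0, 1, 1, 0])
        alpha_init = jnp.array([0.5, 0.5])     # prior over the two states

        alpha_last, hist = alpha_forward(obs, local_evidence, T, alpha_init)
        # hist["alpha"] has shape (len(obs), 2): one normalised filtering vector per step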
44 | """ 45 | alpha_step_part = partial(alpha_step, 46 | local_evidence_multiple=local_evidence, 47 | transition_matrix=transition_matrix) 48 | alpha_last, alpha_hist = jax.lax.scan(alpha_step_part, alpha_init, obs) 49 | return alpha_last, alpha_hist 50 | 51 | 52 | def beta_backward(obs: chex.Array, 53 | local_evidence: Callable[[chex.Array], chex.Array], 54 | transition_matrix: chex.Array, 55 | alpha_last: chex.Array) -> Tuple[chex.Array, chex.Array]: 56 | """ 57 | Compute the forward-filter and backward-smoother pass for a 58 | given observation sequence. 59 | """ 60 | beta_step_part = partial(beta_step, 61 | local_evidence_multiple=local_evidence, 62 | transition_matrix=transition_matrix) 63 | beta_first, beta_hist = jax.lax.scan(beta_step_part, alpha_last, obs, reverse=True) 64 | return beta_first, beta_hist 65 | 66 | 67 | def forward_backward(obs: chex.Array, 68 | local_evidence: Callable[[chex.Array], chex.Array], 69 | transition_matrix: chex.Array, 70 | alpha_init: chex.Array) -> Dict[str, chex.Array]: 71 | """ 72 | Compute the forward-filter and backward-smoother pass for a 73 | given observation sequence. 74 | """ 75 | alpha_last, alpha_hist = alpha_forward(obs, local_evidence, transition_matrix, alpha_init) 76 | _, beta_hist = beta_backward(obs, local_evidence, transition_matrix, alpha_last) 77 | 78 | filter_hist = alpha_hist["alpha"] 79 | smooth_hist = filter_hist * beta_hist["beta"] 80 | smooth_hist = smooth_hist / smooth_hist.sum(axis=1, keepdims=True) 81 | 82 | res = { 83 | "filter": filter_hist, 84 | "smooth": smooth_hist 85 | } 86 | 87 | return res 88 | -------------------------------------------------------------------------------- /jsl/lds/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /jsl/lds/cont_kalman_filter.py: -------------------------------------------------------------------------------- 1 | # Implementation of the Kalman Filter for 2 | # continuous time series 3 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna) 4 | 5 | import jax.numpy as jnp 6 | from jax import random, lax 7 | from jax.scipy.linalg import solve 8 | 9 | import chex 10 | from math import ceil 11 | 12 | from jsl.lds.kalman_filter import LDS 13 | 14 | 15 | def _rk2(x0, M, nsteps, dt): 16 | """ 17 | class-independent second-order Runge-Kutta method for linear systems 18 | 19 | Parameters 20 | ---------- 21 | x0: array(state_size, ) 22 | Initial state of the system 23 | M: array(state_size, K) 24 | Evolution matrix 25 | nsteps: int 26 | Total number of steps to integrate 27 | dt: float 28 | integration step size 29 | 30 | Returns 31 | ------- 32 | array(nsteps, state_size) 33 | Integration history 34 | """ 35 | def f(x): return M @ x 36 | input_dim, *_ = x0.shape 37 | 38 | def step(xt, t): 39 | k1 = f(xt) 40 | k2 = f(xt + dt * k1) 41 | xt = xt + dt * (k1 + k2) / 2 42 | return xt, xt 43 | 44 | steps = jnp.arange(nsteps) 45 | _, simulation = lax.scan(step, x0, steps) 46 | 47 | simulation = jnp.vstack([x0, simulation]) 48 | return simulation 49 | 50 | def sample(key: chex.PRNGKey, 51 | params: LDS, 52 | x0: chex.Array, 53 | T: float, 54 | nsamples: int, 55 | dt: float=0.01, 56 | noisy: bool=False): 57 | """ 58 | Run the Kalman Filter algorithm. First, we integrate 59 | up to time T, then we obtain nsamples equally-spaced points. 
Finally, 60 | we transform the latent space to obtain the observations 61 | 62 | Parameters 63 | ---------- 64 | params: LDS 65 | Linear Dynamical System object 66 | key: jax.random.PRNGKey 67 | x0: array(state_size) 68 | Initial state of simulation 69 | T: float 70 | Final time of integration 71 | nsamples: int 72 | Number of observations to take from the total integration 73 | dt: float 74 | integration step size 75 | noisy: bool 76 | Whether to (naively) add noise to the state space 77 | 78 | Returns 79 | ------- 80 | * array(nsamples, state_size) 81 | State-space values 82 | * array(nsamples, obs_size) 83 | Observed-space values 84 | * int 85 | Number of observations skipped between one 86 | datapoint and the next 87 | """ 88 | nsteps = ceil(T / dt) 89 | jump_size = ceil(nsteps / nsamples) 90 | correction = nsamples - ceil(nsteps / jump_size) 91 | nsteps += correction * jump_size 92 | 93 | key_state, key_obs = random.split(key) 94 | obs_size, state_size = params.C.shape 95 | state_noise = random.multivariate_normal(key_state, jnp.zeros(state_size), params.Q, (nsteps,)) 96 | obs_noise = random.multivariate_normal(key_obs, jnp.zeros(obs_size), params.R, (nsteps,)) 97 | simulation = _rk2(x0, params.A, nsteps, dt) 98 | 99 | if noisy: 100 | simulation = simulation + state_noise 101 | 102 | sample_state = simulation[::jump_size] 103 | sample_obs = jnp.einsum("ij,si->si", params.C, sample_state) + obs_noise[:len(sample_state)] 104 | 105 | return sample_state, sample_obs, jump_size 106 | 107 | def filter(params: LDS, 108 | x_hist: chex.Array, 109 | jump_size: chex.Array, 110 | dt: chex.Array): 111 | """ 112 | Compute the online version of the Kalman-Filter, i.e, 113 | the one-step-ahead prediction for the hidden state or the 114 | time update step 115 | 116 | Parameters 117 | ---------- 118 | x_hist: array(timesteps, observation_size) 119 | 120 | Returns 121 | ------- 122 | * array(timesteps, state_size): 123 | Filtered means mut 124 | * array(timesteps, state_size, state_size) 125 | Filtered covariances Sigmat 126 | * array(timesteps, state_size) 127 | Filtered conditional means mut|t-1 128 | * array(timesteps, state_size, state_size) 129 | Filtered conditional covariances Sigmat|t-1 130 | """ 131 | obs_size, state_size = params.C.shape 132 | 133 | I = jnp.eye(state_size) 134 | timesteps, *_ = x_hist.shape 135 | mu_hist = jnp.zeros((timesteps, state_size)) 136 | Sigma_hist = jnp.zeros((timesteps, state_size, state_size)) 137 | Sigma_cond_hist = jnp.zeros((timesteps, state_size, state_size)) 138 | mu_cond_hist = jnp.zeros((timesteps, state_size)) 139 | 140 | # Initial configuration 141 | A, Q, C, R = params.A, params.Q, params.C, params.R 142 | mu, Sigma = params.mu, params.Sigma 143 | 144 | temp = C @ Sigma @ C.T + R 145 | K1 = solve(temp, C @ Sigma, sym_pos=True).T 146 | mu1 = mu + K1 @ (x_hist[0] - C @ mu) 147 | Sigma1 = (I - K1 @ C) @ Sigma 148 | 149 | 150 | def rk_integration_step(state, carry): 151 | # Runge-kutta integration step 152 | mu, Sigma = state 153 | k1 = A @ mu 154 | k2 = A @ (mu + dt * k1) 155 | mu = mu + dt * (k1 + k2) / 2 156 | 157 | k1 = A @ Sigma @ A.T + Q 158 | k2 = A @ (Sigma + dt * k1) @ A.T + Q 159 | Sigma = Sigma + dt * (k1 + k2) / 2 160 | 161 | return (mu, Sigma), None 162 | 163 | def step(state, x): 164 | mun, Sigman = state 165 | initial_state = (mun, Sigman) 166 | (mun, Sigman), _ = lax.scan(rk_integration_step, initial_state, jnp.arange(jump_size)) 167 | 168 | Sigman_cond = jnp.ones_like(Sigman) * Sigman 169 | St = C @ Sigman_cond @ C.T + R 170 | Kn = solve(St, 
C @ Sigman_cond, sym_pos=True).T 171 | 172 | mu_update = jnp.ones_like(mun) * mun 173 | x_update = C @ mun 174 | mun = mu_update + Kn @ (x - x_update) 175 | Sigman = (I - Kn @ C) @ Sigman_cond 176 | 177 | return (mun, Sigman), (mun, Sigman, mu_update, Sigman_cond) 178 | 179 | initial_state = (mu1, Sigma1) 180 | _, (mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist) = lax.scan(step, initial_state, x_hist[1:]) 181 | 182 | mu_hist = jnp.vstack([mu1[None, ...], mu_hist]) 183 | Sigma_hist = jnp.vstack([Sigma1[None, ...], Sigma_hist]) 184 | mu_cond_hist = jnp.vstack([params.mu[None, ...], mu_cond_hist]) 185 | Sigma_cond_hist = jnp.vstack([params.Sigma[None, ...], Sigma_cond_hist]) 186 | 187 | return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist 188 | -------------------------------------------------------------------------------- /jsl/lds/kalman_filter_test.py: -------------------------------------------------------------------------------- 1 | from jax import random 2 | from jax import numpy as jnp 3 | import numpy as np 4 | import tensorflow_probability.substrates.jax.distributions as tfd 5 | 6 | from jsl.lds.kalman_filter import LDS, kalman_filter, kalman_smoother 7 | 8 | 9 | def lds_jsl_to_tfp(num_timesteps, lds): 10 | """Convert a JSL `LDS` object into a tfp `LinearGaussianStateSpaceModel`. 11 | 12 | Args: 13 | num_timesteps: int, number of timesteps. 14 | lds: LDS object. 15 | """ 16 | dynamics_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.Q) 17 | emission_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.R) 18 | initial_dist = tfd.MultivariateNormalFullCovariance(lds.mu, lds.Sigma) 19 | 20 | tfp_lgssm = tfd.LinearGaussianStateSpaceModel( 21 | num_timesteps, 22 | lds.A, dynamics_noise_dist, 23 | lds.C, emission_noise_dist, 24 | initial_dist, 25 | ) 26 | 27 | return tfp_lgssm 28 | 29 | 30 | def test_kalman_filter(): 31 | key = random.PRNGKey(314) 32 | num_timesteps = 15 33 | delta = 1.0 34 | 35 | ### LDS Parameters ### 36 | state_size = 2 37 | observation_size = 2 38 | A = jnp.eye(state_size) 39 | C = jnp.eye(state_size) 40 | 41 | transition_noise_scale = 1.0 42 | observation_noise_scale = 1.0 43 | Q = jnp.eye(state_size) * transition_noise_scale 44 | R = jnp.eye(observation_size) * observation_noise_scale 45 | 46 | ### Prior distribution params ### 47 | mu0 = jnp.array([8, 10]).astype(float) 48 | Sigma0 = jnp.eye(state_size) * 1.0 49 | 50 | ### Sample data ### 51 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0) 52 | z_hist, x_hist = lds_instance.sample(key, num_timesteps) 53 | 54 | filter_output = kalman_filter(lds_instance, x_hist) 55 | JSL_filtered_means, JSL_filtered_covs, *_ = filter_output 56 | JSL_smoothed_means, JSL_smoothed_covs = kalman_smoother(lds_instance, *filter_output) 57 | 58 | tfp_lgssm = lds_jsl_to_tfp(num_timesteps, lds_instance) 59 | _, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(x_hist) 60 | tfp_smoothed_means, tfp_smoothed_covs = tfp_lgssm.posterior_marginals(x_hist) 61 | 62 | assert np.allclose(JSL_filtered_means, tfp_filtered_means, rtol=1e-2) 63 | assert np.allclose(JSL_filtered_covs, tfp_filtered_covs, rtol=1e-2) 64 | assert np.allclose(JSL_smoothed_means, tfp_smoothed_means, rtol=1e-2) 65 | assert np.allclose(JSL_smoothed_covs, tfp_smoothed_covs, rtol=1e-2) 66 | -------------------------------------------------------------------------------- /jsl/lds/kalman_filter_with_unknown_noise.py: -------------------------------------------------------------------------------- 1 | # Jax 
implementation of a Linear Dynamical System where the observation noise is not known. 2 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna) 3 | 4 | import chex 5 | 6 | import jax.numpy as jnp 7 | from jax import lax, vmap, tree_map 8 | 9 | from dataclasses import dataclass 10 | from functools import partial 11 | from typing import Union, Callable 12 | 13 | from tensorflow_probability.substrates import jax as tfp 14 | 15 | tfd = tfp.distributions 16 | 17 | 18 | @dataclass 19 | class LDS: 20 | """ 21 | Kalman filtering for a linear Gaussian state space model with scalar observations, 22 | where all the parameters are known except for the observation noise variance. 23 | The model has the following form: 24 | p(state(0)) = Gauss(mu0, Sigma0) 25 | p(state(t) | state(t-1)) = Gauss(A * state(t-1), Q) 26 | p(obs(t) | state(t)) = Gauss(C * state(t), r) 27 | The value of r is jointly inferred together with the latent states to produce the posterior 28 | p(state(t) , r | obs(1:t)) = Gauss(mu(t), Sigma(t) * r) * Ga(1/r | nu(t)/2, nu(t)*tau(t)/2) 29 | where 1/r is the observation precision. For details on this algorithm, see sec 4.5 of 30 | "Bayesian forecasting and dynamic models", West and Harrison, 1997. 31 | https://www2.stat.duke.edu/~mw/West&HarrisonBook/ 32 | 33 | Parameters 34 | ---------- 35 | A: array(state_size, state_size) 36 | Transition matrix 37 | C: array(observation_size, state_size) 38 | Observation matrix 39 | Q: array(state_size, state_size) 40 | Transition covariance matrix 41 | mu: array(state_size) 42 | Mean of initial configuration 43 | Sigma: array(state_size, state_size) or 0 44 | Covariance of initial configuration. If value is set 45 | to zero, the initial state will be completely determined 46 | by mu0 47 | """ 48 | A: chex.Array 49 | C: Union[chex.Array, Callable] 50 | Q: chex.Array 51 | R: chex.Array 52 | mu: chex.Array 53 | Sigma: chex.Array 54 | v: chex.Array 55 | tau: chex.Array 56 | 57 | 58 | def kalman_filter(params: LDS, x_hist: chex.Array, 59 | return_history: bool = True): 60 | """ 61 | Compute the online version of the Kalman-Filter, i.e, 62 | the one-step-ahead prediction for the hidden state or the 63 | time update step 64 | 65 | Parameters 66 | ---------- 67 | params: LDS 68 | Linear Dynamical System object 69 | x_hist: array(timesteps, observation_size) 70 | return_history: bool 71 | 72 | Returns 73 | ------- 74 | * array(timesteps, state_size): 75 | Filtered means mut 76 | * array(timesteps, state_size, state_size) 77 | Filtered covariances Sigmat 78 | * array(timesteps, state_size) 79 | Filtered conditional means mut|t-1 80 | * array(timesteps, state_size, state_size) 81 | Filtered conditional covariances Sigmat|t-1 82 | """ 83 | A, Q, R = params.A, params.Q, params.R 84 | state_size, _ = A.shape 85 | 86 | def kalman_step(state, obs): 87 | mu, Sigma, v, tau = state 88 | covariates, response = obs 89 | 90 | mu_cond = jnp.matmul(A, mu, precision=lax.Precision.HIGHEST) 91 | Sigmat_cond = jnp.matmul(jnp.matmul(A, Sigma, precision=lax.Precision.HIGHEST), A, 92 | precision=lax.Precision.HIGHEST) + Q 93 | 94 | e_k = response - covariates.T @ mu_cond 95 | s_k = covariates.T @ Sigmat_cond @ covariates + 1 96 | Kt = (Sigmat_cond @ covariates) / s_k 97 | 98 | mu = mu + e_k * Kt 99 | Sigma = Sigmat_cond - jnp.outer(Kt, Kt) * s_k 100 | 101 | v_update = v + 1 102 | tau = (v * tau + (e_k * e_k) / s_k) / v_update 103 | 104 | return (mu, Sigma, v_update, tau), (mu, Sigma) 105 | 106 | mu0, Sigma0 = params.mu, params.Sigma 107 | initial_state = (mu0, Sigma0, 
0) 108 | (mu, Sigma, _, _), history = lax.scan(kalman_step, initial_state, x_hist) 109 | if return_history: 110 | return history 111 | return mu, Sigma 112 | 113 | 114 | def filter(params: LDS, x_hist: chex.Array, 115 | return_history: bool = True): 116 | """ 117 | Compute the online version of the Kalman-Filter, i.e, 118 | the one-step-ahead prediction for the hidden state or the 119 | time update step. 120 | Note that x_hist can optionally be of dimensionality two, 121 | This corresponds to different samples of the same underlying 122 | Linear Dynamical System 123 | Parameters 124 | ---------- 125 | params: LDS 126 | Linear Dynamical System object 127 | x_hist: array(n_samples?, timesteps, observation_size) 128 | Returns 129 | ------- 130 | * array(n_samples?, timesteps, state_size): 131 | Filtered means mut 132 | * array(n_samples?, timesteps, state_size, state_size) 133 | Filtered covariances Sigmat 134 | * array(n_samples?, timesteps, state_size) 135 | Filtered conditional means mut|t-1 136 | * array(n_samples?, timesteps, state_size, state_size) 137 | Filtered conditional covariances Sigmat|t-1 138 | """ 139 | has_one_sim = False 140 | if x_hist.ndim == 2: 141 | x_hist = x_hist[None, ...] 142 | has_one_sim = True 143 | kalman_map = vmap(partial(kalman_filter, return_history=return_history), (None, 0)) 144 | outputs = kalman_map(params, x_hist) 145 | if has_one_sim and return_history: 146 | return tree_map(lambda x: x[0, ...], outputs) 147 | return outputs 148 | -------------------------------------------------------------------------------- /jsl/lds/kalman_sampler.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax.random import multivariate_normal, PRNGKey 3 | from jax.scipy.linalg import solve, cholesky 4 | from jax import lax 5 | 6 | from .kalman_filter import LDS 7 | 8 | 9 | 10 | def smooth_sampler(params: LDS, 11 | key: PRNGKey, 12 | mu_hist: jnp.array, 13 | Sigma_hist: jnp.array, 14 | n_samples: jnp.array = 1): 15 | """ 16 | Backwards sample from the smoothing distribution 17 | Parameters 18 | ---------- 19 | params: LDS 20 | Linear Dynamical System object 21 | key: jax.random.PRNGKey 22 | Seed of state noises 23 | mu_hist: array(timesteps, state_size): 24 | Filtered means mut 25 | Sigma_hist: array(timesteps, state_size, state_size) 26 | Filtered covariances Sigmat 27 | n_samples: int 28 | Number of posterior samples (optional) 29 | Returns 30 | ------- 31 | * array(n_samples, timesteps, state_size): 32 | Posterior samples 33 | """ 34 | state_size, _ = params.get_trans_mat_of(0).shape 35 | I = jnp.eye(state_size) 36 | timesteps = len(mu_hist) 37 | # Generate all state noise terms 38 | zeros_state = jnp.zeros(state_size) 39 | system_noise = multivariate_normal(key, zeros_state, I, (timesteps, n_samples)) 40 | state_T = mu_hist[-1] + system_noise[-1] @ Sigma_hist[-1, ...].T 41 | 42 | def smooth_sample_step(state, inps): 43 | system_noise_t, mutt, Sigmatt , t = inps 44 | A = params.get_trans_mat_of(t) 45 | et = state - mutt @ A.T 46 | St = A @ Sigmatt @ A.T + params.get_system_noise_of(t) 47 | Kt = solve(St, A @ Sigmatt, sym_pos=True).T 48 | mu_t = mutt + et @ Kt.T 49 | Sigma_t = (I - Kt @ A) @ Sigmatt 50 | Sigma_root = cholesky(Sigma_t) 51 | state_new = mu_t + system_noise_t @ Sigma_root.T 52 | return state_new, state_new 53 | 54 | inps = (system_noise[-2::-1, ...], mu_hist[-2::-1, ...], Sigma_hist[-2::-1, ...], jnp.arange(1, timesteps)[::-1]) 55 | _, state_sample_smooth = lax.scan(smooth_sample_step, 
state_T, inps) 56 | 57 | state_sample_smooth = jnp.concatenate([state_sample_smooth[::-1, ...], state_T[None, ...]], axis=0) 58 | state_sample_smooth = jnp.swapaxes(state_sample_smooth, 0, 1) 59 | 60 | if n_samples == 1: 61 | state_sample_smooth = state_sample_smooth[:, 0, :] 62 | return state_sample_smooth -------------------------------------------------------------------------------- /jsl/lds/mixture_kalman_filter.py: -------------------------------------------------------------------------------- 1 | # Mixture Kalman Filter library. Also known as the 2 | # Rao-Blackwell Particle Filter. 3 | 4 | # Author: Gerardo Durán-Martín (@gerdm) 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax import random 9 | from jax.scipy.special import logit 10 | from dataclasses import dataclass 11 | 12 | 13 | @dataclass 14 | class RBPFParamsDiscrete: 15 | """ 16 | Rao-Blackwell Particle Filtering (RBPF) parameters for 17 | a system with discrete latent-space. 18 | We assume that the system evolves as 19 | z_next = A * z_old + B(u_old) + noise1_next 20 | x_next = C * z_next + noise2_next 21 | u_next ~ transition_matrix(u_old) 22 | 23 | where 24 | noise1_next ~ N(0, Q) 25 | noise2_next ~ N(0, R) 26 | """ 27 | A: jnp.array 28 | B: jnp.array 29 | C: jnp.array 30 | Q: jnp.array 31 | R: jnp.array 32 | transition_matrix: jnp.array 33 | 34 | 35 | def draw_state(val, key, params): 36 | """ 37 | Simulate one step of a system that evolves as 38 | A z_{t-1} + Bk + eps, 39 | where eps ~ N(0, Q). 40 | 41 | Parameters 42 | ---------- 43 | val: tuple (int, jnp.array) 44 | (latent value of system, state value of system). 45 | params: PRBPFParamsDiscrete 46 | key: PRNGKey 47 | """ 48 | latent_old, state_old = val 49 | probabilities = params.transition_matrix[latent_old, :] 50 | logits = logit(probabilities) 51 | latent_new = random.categorical(key, logits) 52 | 53 | key_latent, key_obs = random.split(key) 54 | state_new = params.A @ state_old + params.B[latent_new, :] 55 | state_new = random.multivariate_normal(key_latent, state_new, params.Q) 56 | obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R) 57 | 58 | return (latent_new, state_new), (latent_new, state_new, obs_new) 59 | 60 | 61 | def kf_update(mu_t, Sigma_t, k, xt, params): 62 | I = jnp.eye(len(mu_t)) 63 | mu_t_cond = params.A @ mu_t + params.B[k] 64 | Sigma_t_cond = params.A @ Sigma_t @ params.A.T + params.Q 65 | xt_cond = params.C @ mu_t_cond 66 | St = params.C @ Sigma_t_cond @ params.C.T + params.R 67 | 68 | Kt = Sigma_t_cond @ params.C.T @ jnp.linalg.inv(St) 69 | 70 | # Estimation update 71 | mu_t = mu_t_cond + Kt @ (xt - xt_cond) 72 | Sigma_t = (I - Kt @ params.C) @ Sigma_t_cond 73 | 74 | # Normalisation constant 75 | mean_norm = params.C @ mu_t_cond 76 | cov_norm = params.C @ Sigma_t_cond @ params.C.T + params.R 77 | Ltk = jax.scipy.stats.multivariate_normal.pdf(xt, mean_norm, cov_norm) 78 | 79 | return mu_t, Sigma_t, Ltk 80 | 81 | 82 | def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params): 83 | log_p_next = logit(params.transition_matrix[st]) 84 | k = random.categorical(key, log_p_next) 85 | mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params) 86 | weight_t = weight_t * Ltk 87 | 88 | return mu_t, Sigma_t, weight_t, Ltk 89 | 90 | 91 | kf_update_vmap = jax.vmap(kf_update, in_axes=(None, None, 0, None, None), out_axes=0) 92 | 93 | 94 | def rbpf_step_optimal(key, weight_t, st, mu_t, Sigma_t, xt, params): 95 | k = jnp.arange(len(params.transition_matrix)) 96 | mu_tk, Sigma_tk, Ltk = kf_update_vmap(mu_t, Sigma_t, k, 
xt, params) 97 | 98 | proposal = Ltk * params.transition_matrix[st] 99 | 100 | weight_tk = weight_t * proposal.sum() 101 | proposal = proposal / proposal.sum() 102 | 103 | return mu_tk, Sigma_tk, weight_tk, proposal 104 | 105 | 106 | # vectorised RBPF step 107 | rbpf_step_vec = jax.vmap(rbpf_step, in_axes=(0, 0, 0, 0, 0, None, None)) 108 | # vectorisedRBPF Step optimal 109 | rbpf_step_optimal_vec = jax.vmap(rbpf_step_optimal, in_axes=(0, 0, 0, 0, 0, None, None)) 110 | 111 | 112 | def rbpf(current_config, xt, params, nparticles=100): 113 | """ 114 | Rao-Blackwell Particle Filter using prior as proposal 115 | """ 116 | key, mu_t, Sigma_t, weights_t, st = current_config 117 | 118 | key_sample, key_state, key_next, key_reindex = random.split(key, 4) 119 | keys = random.split(key_sample, nparticles) 120 | 121 | st = random.categorical(key_state, logit(params.transition_matrix[st, :])) 122 | mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params) 123 | weights_t = weights_t / weights_t.sum() 124 | 125 | indices = jnp.arange(nparticles) 126 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True) 127 | st = st[pi] 128 | mu_t = mu_t[pi, ...] 129 | Sigma_t = Sigma_t[pi, ...] 130 | weights_t = jnp.ones(nparticles) / nparticles 131 | 132 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, Ltk) 133 | 134 | 135 | def rbpf_optimal(current_config, xt, params, nparticles=100): 136 | """ 137 | Rao-Blackwell Particle Filter using optimal proposal 138 | """ 139 | key, mu_t, Sigma_t, weights_t, st = current_config 140 | 141 | key_sample, key_state, key_next, key_reindex = random.split(key, 4) 142 | keys = random.split(key_sample, nparticles) 143 | 144 | st = random.categorical(key_state, logit(params.transition_matrix[st, :])) 145 | mu_t, Sigma_t, weights_t, proposal = rbpf_step_optimal_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params) 146 | 147 | indices = jnp.arange(nparticles) 148 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True) 149 | 150 | # Obtain optimal proposal distribution 151 | proposal_samp = proposal[pi, :] 152 | st = random.categorical(key, logit(proposal_samp)) 153 | 154 | mu_t = mu_t[pi, st, ...] 155 | Sigma_t = Sigma_t[pi, st, ...] 156 | 157 | weights_t = jnp.ones(nparticles) / nparticles 158 | 159 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp) 160 | -------------------------------------------------------------------------------- /jsl/nlds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/nlds/__init__.py -------------------------------------------------------------------------------- /jsl/nlds/base.py: -------------------------------------------------------------------------------- 1 | # Library of nonlinear dynamical systems 2 | # Usage: Every discrete xKF class inherits from NLDS. 3 | # There are two ways to use this library in the discrete case: 4 | # 1) Explicitly initialize a discrete NLDS object with the desired parameters, 5 | # then pass it onto the xKF class of your choice. 6 | # 2) Initialize the xKF object with the desired NLDS parameters using 7 | # the .from_base constructor. 8 | # Way 1 is preferable whenever you want to use the same NLDS for multiple 9 | # filtering processes. 
Way 2 is preferred whenever you want to use a single NLDS 10 | # for a single filtering process 11 | 12 | # Author: Gerardo Durán-Martín (@gerdm) 13 | 14 | import jax 15 | from jax.random import split, multivariate_normal 16 | 17 | import chex 18 | 19 | from dataclasses import dataclass 20 | from typing import Callable 21 | 22 | 23 | @dataclass 24 | class NLDS: 25 | """ 26 | Base class for the nonlinear dynamical systems' module 27 | 28 | Parameters 29 | ---------- 30 | fz: function 31 | Nonlinear state transition function 32 | fx: function 33 | Nonlinear observation function 34 | Q: array(state_size, state_size) or function 35 | Nonlinear state transition noise covariance function 36 | R: array(obs_size, obs_size) or function 37 | Nonlinear observation noise covariance function 38 | """ 39 | fz: Callable 40 | fx: Callable 41 | Q: chex.Array 42 | R: chex.Array 43 | alpha: float = 0. 44 | beta: float = 0. 45 | kappa: float = 0. 46 | d: int = 0 47 | 48 | def Qz(self, z, *args): 49 | if callable(self.Q): 50 | return self.Q(z, *args) 51 | else: 52 | return self.Q 53 | 54 | def Rx(self, x, *args): 55 | if callable(self.R): 56 | return self.R(x, *args) 57 | else: 58 | return self.R 59 | 60 | def __sample_step(self, input_vals, obs): 61 | key, state_t = input_vals 62 | key_system, key_obs, key = split(key, 3) 63 | 64 | state_t = multivariate_normal(key_system, self.fz(state_t), self.Qz(state_t)) 65 | obs_t = multivariate_normal(key_obs, self.fx(state_t, *obs), self.Rx(state_t, *obs)) 66 | 67 | return (key, state_t), (state_t, obs_t) 68 | 69 | def sample(self, key, x0, nsteps, obs=None): 70 | """ 71 | Sample discrete elements of a nonlinear system 72 | Parameters 73 | ---------- 74 | key: jax.random.PRNGKey 75 | x0: array(state_size) 76 | Initial state of simulation 77 | nsteps: int 78 | Total number of steps to sample from the system 79 | obs: None, tuple of arrays 80 | Observed values to pass to fx and R 81 | Returns 82 | ------- 83 | * array(nsamples, state_size) 84 | State-space values 85 | * array(nsamples, obs_size) 86 | Observed-space values 87 | """ 88 | obs = () if obs is None else obs 89 | state_t = x0.copy() 90 | obs_t = self.fx(state_t) 91 | 92 | self.state_size, *_ = state_t.shape 93 | self.obs_t, *_ = obs_t.shape 94 | 95 | init_state = (key, state_t) 96 | _, hist = jax.lax.scan(self.__sample_step, init_state, obs, length=nsteps) 97 | 98 | return hist 99 | -------------------------------------------------------------------------------- /jsl/nlds/bootstrap_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Bootrstrap Filter for discrete time systems 3 | **This implementation considers the case of multivariate normals** 4 | 5 | 6 | """ 7 | import jax.numpy as jnp 8 | from jax import random, lax 9 | 10 | import chex 11 | 12 | from jax.scipy import stats 13 | from jsl.nlds.base import NLDS 14 | 15 | 16 | # TODO: Extend to general case 17 | def filter(params: NLDS, 18 | key: chex.PRNGKey, 19 | init_state: chex.Array, 20 | sample_obs: chex.Array, 21 | nsamples: int = 2000, 22 | Vinit: chex.Array = None): 23 | """ 24 | init_state: array(state_size,) 25 | Initial state estimate 26 | sample_obs: array(nsamples, obs_size) 27 | Samples of the observations 28 | """ 29 | m, *_ = init_state.shape 30 | 31 | fx, fz = params.fx, params.fz 32 | Q, R = params.Qz, params.Rx 33 | 34 | key, key_init = random.split(key, 2) 35 | V = Q(init_state) if Vinit is None else Vinit 36 | zt_rvs = random.multivariate_normal(key_init, 
init_state, V, shape=(nsamples,)) 37 | 38 | init_state = (zt_rvs, key) 39 | 40 | def __filter_step(state, obs_t): 41 | indices = jnp.arange(nsamples) 42 | zt_rvs, key_t = state 43 | 44 | key_t, key_reindex, key_next = random.split(key_t, 3) 45 | # 1. Draw new points from the dynamic model 46 | zt_rvs = random.multivariate_normal(key_t, fz(zt_rvs), Q(zt_rvs)) 47 | 48 | # 2. Calculate unnormalised weights 49 | xt_rvs = fx(zt_rvs) 50 | weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, R(zt_rvs, obs_t)) 51 | 52 | # 3. Resampling 53 | pi = random.choice(key_reindex, indices, 54 | p=weights_t, shape=(nsamples,)) 55 | zt_rvs = zt_rvs[pi, ...] 56 | weights_t = jnp.ones(nsamples) / nsamples 57 | 58 | # 4. Compute latent-state estimate, 59 | # Set next covariance state matrix 60 | mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t) 61 | 62 | return (zt_rvs, key_next), mu_t 63 | 64 | _, mu_hist = lax.scan(__filter_step, init_state, sample_obs) 65 | 66 | return mu_hist 67 | -------------------------------------------------------------------------------- /jsl/nlds/continuous_extended_kalman_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extended Kalman Filter for a nonlinear continuous time 3 | dynamical system with observations in discrete time. 4 | """ 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax import lax, jacrev 9 | 10 | import chex 11 | 12 | from math import ceil 13 | 14 | from jsl.nlds.base import NLDS 15 | 16 | 17 | def _rk2(x0, f, nsteps, dt): 18 | """ 19 | class-independent second-order Runge-Kutta method 20 | 21 | Parameters 22 | ---------- 23 | x0: array(state_size, ) 24 | Initial state of the system 25 | f: function 26 | Function to integrate. Must return jax.numpy 27 | array of size state_size 28 | nsteps: int 29 | Total number of steps to integrate 30 | dt: float 31 | integration step size 32 | 33 | Returns 34 | ------- 35 | array(nsteps, state_size) 36 | Integration history 37 | """ 38 | input_dim, *_ = x0.shape 39 | 40 | def step(xt, t): 41 | k1 = f(xt) 42 | k2 = f(xt + dt * k1) 43 | xt = xt + dt * (k1 + k2) / 2 44 | return xt, xt 45 | 46 | steps = jnp.arange(nsteps) 47 | _, simulation = lax.scan(step, x0, steps) 48 | 49 | simulation = jnp.vstack([x0, simulation]) 50 | return simulation 51 | 52 | 53 | def sample(key: chex.PRNGKey, 54 | params: NLDS, 55 | x0: chex.Array, 56 | T: float, 57 | nsamples: int, 58 | dt: float = 0.01, 59 | noisy: bool = False): 60 | """ 61 | Run the Extended Kalman Filter algorithm. First, we integrate 62 | up to time T, then we obtain nsamples equally-spaced points. 
Finally, 63 | we transform the latent space to obtain the observations 64 | 65 | Parameters 66 | ---------- 67 | key: jax.random.PRNGKey 68 | Initial seed 69 | x0: array(state_size) 70 | Initial state of simulation 71 | T: float 72 | Final time of integration 73 | nsamples: int 74 | Number of observations to take from the total integration 75 | dt: float 76 | integration step size 77 | noisy: bool 78 | Whether to (naively) add noise to the state space 79 | 80 | Returns 81 | ------- 82 | * array(nsamples, state_size) 83 | State-space values 84 | * array(nsamples, obs_size) 85 | Observed-space values 86 | * int 87 | Number of observations skipped between one 88 | datapoint and the next 89 | """ 90 | 91 | fz, fx = params.fz, params.fx 92 | Q, R = params.Qz, params.Rx 93 | 94 | state_size, _ = Q.shape 95 | obs_size, _ = R.shape 96 | 97 | nsteps = ceil(T / dt) 98 | jump_size = ceil(nsteps / nsamples) 99 | correction = nsamples - ceil(nsteps / jump_size) 100 | nsteps += correction * jump_size 101 | 102 | key_state, key_obs = jax.random.split(key) 103 | state_noise = jax.random.multivariate_normal(key_state, jnp.zeros(state_size), Q, (nsteps,)) 104 | obs_noise = jax.random.multivariate_normal(key_obs, jnp.zeros(obs_size), R, (nsteps,)) 105 | simulation = _rk2(x0, fz, nsteps, dt) 106 | 107 | if noisy: 108 | simulation = simulation + jnp.sqrt(dt) * state_noise 109 | 110 | sample_state = simulation[::jump_size] 111 | sample_obs = jnp.apply_along_axis(fx, 1, sample_state) + obs_noise[:len(sample_state)] 112 | 113 | return sample_state, sample_obs, jump_size 114 | 115 | 116 | def _Vt_dot(V, G, Q): 117 | return G @ V @ G.T + Q 118 | 119 | 120 | def estimate(params: NLDS, 121 | sample_state: chex.Array, 122 | sample_obs: chex.Array, 123 | jump_size: int, 124 | dt: float, 125 | return_history: bool = True): 126 | """ 127 | Run the Extended Kalman Filter algorithm over a set of observed samples. 
128 | 129 | Parameters 130 | ---------- 131 | sample_state: array(nsamples, state_size) 132 | sample_obs: array(nsamples, obs_size) 133 | jump_size: int 134 | dt: float 135 | return_history: bool 136 | 137 | Returns 138 | ------- 139 | * array(nsamples, state_size) 140 | History of filtered mean terms 141 | * array(nsamples, state_size, state_size) 142 | History of filtered covariance terms 143 | """ 144 | 145 | fz, fx = params.fz, params.fx 146 | Q, R = params.Qz, params.Rx 147 | 148 | Dfz = jacrev(fz) 149 | Dfx = jacrev(fx) 150 | 151 | state_size, _ = Q.shape 152 | obs_size, _ = R.shape 153 | 154 | I = jnp.eye(state_size) 155 | Vt = R.copy() 156 | mu_t = sample_state[0] 157 | 158 | def jump_step(state, t): 159 | mu_t, Vt = state 160 | k1 = fz(mu_t) 161 | k2 = fz(mu_t + dt * k1) 162 | mu_t = mu_t + dt * (k1 + k2) / 2 163 | 164 | Gt = Dfz(mu_t) 165 | k1 = _Vt_dot(Vt, Gt, Q) 166 | k2 = _Vt_dot(Vt + dt * k1, Gt, Q) 167 | Vt = Vt + dt * (k1 + k2) / 2 168 | return (mu_t, Vt), None 169 | 170 | def step(state, obs): 171 | jumps = jnp.arange(jump_size) 172 | (mu, V), _ = lax.scan(jump_step, state, jumps) 173 | 174 | mu_t_cond = mu 175 | Vt_cond = V 176 | Ht = Dfx(mu_t_cond) 177 | 178 | Kt = Vt_cond @ Ht.T @ jnp.linalg.inv(Ht @ Vt_cond @ Ht.T + R) 179 | mu = mu_t_cond + Kt @ (obs - fx(mu_t_cond)) 180 | V = (I - Kt @ Ht) @ Vt_cond 181 | return (mu, V), (mu, V) 182 | 183 | initial_state = (mu_t.copy(), Vt.copy()) 184 | (mu, V), (mu_hist, V_hist) = lax.scan(step, initial_state, sample_obs[1:]) 185 | 186 | if return_history: 187 | mu_hist = jnp.vstack([mu_t, mu_hist]) 188 | V_hist = jnp.vstack([Vt, V_hist]) 189 | return mu_hist, V_hist 190 | 191 | return mu, V 192 | -------------------------------------------------------------------------------- /jsl/nlds/diagonal_extended_kalman_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Diagonal Extended Kalman Filter for a nonlinear 3 | dynamical system with discrete observations. Also known as the 4 | Node-decoupled Extended Kalman Filter (NDEKF) 5 | """ 6 | 7 | import jax.numpy as jnp 8 | from jax import jacrev, lax 9 | 10 | import chex 11 | from typing import Tuple 12 | 13 | from .base import NLDS 14 | 15 | 16 | def filter(params: NLDS, 17 | init_state: chex.Array, 18 | sample_obs: chex.Array, 19 | observations: Tuple = None, 20 | Vinit: chex.Array = None, 21 | return_history: bool = True): 22 | """ 23 | Run the Extended Kalman Filter algorithm over a set of observed samples. 
24 | Parameters 25 | ---------- 26 | init_state: array(state_size) 27 | sample_obs: array(nsamples, obs_size) 28 | Returns 29 | ------- 30 | * array(nsamples, state_size) 31 | History of filtered mean terms 32 | * array(nsamples, state_size, state_size) 33 | History of filtered covariance terms 34 | """ 35 | state_size, *_ = init_state.shape 36 | 37 | fz, fx = params.fz, params.fx 38 | Q, R = params.Qz, params.Rx 39 | Dfx = jacrev(fx) 40 | 41 | Vt = Q(init_state) if Vinit is None else Vinit 42 | 43 | t = 0 44 | state = (init_state, Vt, t) 45 | observations = (observations,) if type(observations) is not tuple else observations 46 | xs = (sample_obs, observations) 47 | 48 | def filter_step(state: Tuple[chex.Array, chex.Array], 49 | xs: Tuple[chex.Array, int]): 50 | """ 51 | Run the Extended Kalman filter algorithm for a single step 52 | Paramters 53 | --------- 54 | state: tuple 55 | Mean, covariance at time t-1 56 | xs: tuple 57 | Target value and observations at time t 58 | """ 59 | mu_t, Vt, t = state 60 | xt, obs = xs 61 | 62 | mu_t_cond = fz(mu_t) 63 | Ht = Dfx(mu_t_cond, *obs) 64 | 65 | Rt = R(mu_t_cond, *obs) 66 | xt_hat = fx(mu_t_cond, *obs) 67 | xi = xt - xt_hat 68 | A = jnp.linalg.inv(Rt + jnp.einsum("id,jd,d->ij", Ht, Ht, Vt)) 69 | mu_t = mu_t_cond + jnp.einsum("s,is,ij,j->s", Vt, Ht, A, xi) 70 | Vt = Vt - jnp.einsum("s,is,ij,is,s->s", Vt, Ht, A, Ht, Vt) + Q(mu_t, t) 71 | 72 | return (mu_t, Vt, t + 1), (mu_t, None) 73 | 74 | (mu_t, Vt, _), mu_t_hist = lax.scan(filter_step, state, xs) 75 | 76 | if return_history: 77 | return (mu_t, Vt), mu_t_hist 78 | 79 | return (mu_t, Vt), None 80 | -------------------------------------------------------------------------------- /jsl/nlds/extended_kalman_filter.py: -------------------------------------------------------------------------------- 1 | import chex 2 | from jax import jacrev, lax 3 | import jax.numpy as jnp 4 | from typing import Dict, List, Tuple, Callable 5 | from functools import partial 6 | from .base import NLDS 7 | 8 | def filter_step(state: Tuple[chex.Array, chex.Array, int], 9 | xs: Tuple[chex.Array, chex.Array], 10 | params: NLDS, 11 | Dfx: Callable, 12 | Dfz: Callable, 13 | eps: float, 14 | return_params: Dict 15 | ) -> Tuple[Tuple[chex.Array, chex.Array, int], Dict]: 16 | """ 17 | Run a single step of the extended Kalman filter (EKF) algorithm. 18 | 19 | Parameters 20 | --------- 21 | state: tuple 22 | Mean, covariance at time t-1 23 | xs: tuple 24 | Target value and covariates at time t 25 | params: NLDS 26 | Nonlinear dynamical system parameters 27 | Dfx: Callable 28 | Jacobian of the observation function 29 | Dfz: Callable 30 | Jacobian of the state transition function 31 | eps: float 32 | Small number to prevent singular matrix 33 | return_params: list 34 | Fix elements to carry 35 | 36 | Returns 37 | ------- 38 | * tuple 39 | 1. Mean, covariance, and time at time t 40 | 2. 
History of filtered mean terms (if requested) 41 | """ 42 | mu_t, Vt, t = state 43 | obs, inputs = xs 44 | 45 | state_size, *_ = mu_t.shape 46 | I = jnp.eye(state_size) 47 | Gt = Dfz(mu_t) 48 | mu_t_cond = params.fz(mu_t) 49 | Vt_cond = Gt @ Vt @ Gt.T + params.Qz(mu_t, t) 50 | Ht = Dfx(mu_t_cond, *inputs) 51 | 52 | Rt = params.Rx(mu_t_cond, *inputs) 53 | num_inputs, *_ = Rt.shape 54 | 55 | obs_hat = params.fx(mu_t_cond, *inputs) 56 | Mt = Ht @ Vt_cond @ Ht.T + Rt + eps * jnp.eye(num_inputs) 57 | Kt = Vt_cond @ Ht.T @ jnp.linalg.inv(Mt) 58 | mu_t = mu_t_cond + Kt @ (obs - obs_hat) 59 | Vt = (I - Kt @ Ht) @ Vt_cond @ (I - Kt @ Ht).T + Kt @ Rt @ Kt.T 60 | 61 | carry = {"mean": mu_t, "cov": Vt, "obs_hat": obs_hat} 62 | carry = {key: val for key, val in carry.items() if key in return_params} 63 | return (mu_t, Vt, t + 1), carry 64 | 65 | 66 | def filter(params: NLDS, 67 | init_state: chex.Array, 68 | observations: chex.Array, 69 | covariates: chex.Array = None, 70 | Vinit: chex.Array = None, 71 | return_params: List = None, 72 | eps: float = 0.001, 73 | return_history: bool = True): 74 | """ 75 | Run the Extended Kalman Filter algorithm over a set of observed samples. 76 | 77 | Parameters 78 | ---------- 79 | init_state: array(state_size) 80 | observations: array(nsamples, obs_size) 81 | covariates: array(nsamples, feature_size) or None 82 | optional covariates to pass to the observation function 83 | Vinit: array(state_size, state_size) or None 84 | Initial state covariance matrix 85 | return_params: list 86 | Parameters to carry from the filter step. Possible values are: 87 | "mean", "cov" 88 | return_history: bool 89 | Whether to return the history of mu and sigma obtained at each step 90 | 91 | Returns 92 | ------- 93 | * array(nsamples, state_size) 94 | History of filtered mean terms 95 | * array(nsamples, state_size, state_size) 96 | History of filtered covariance terms 97 | """ 98 | state_size, *_ = init_state.shape 99 | 100 | fz, fx = params.fz, params.fx 101 | Q, R = params.Qz, params.Rx 102 | 103 | Dfz = jacrev(fz) 104 | Dfx = jacrev(fx) 105 | 106 | Vt = Q(init_state) if Vinit is None else Vinit 107 | 108 | t = 0 109 | state = (init_state, Vt, t) 110 | covariates = (covariates,) if type(covariates) is not tuple else covariates 111 | xs = (observations, covariates) 112 | 113 | return_params = [] if return_params is None else return_params 114 | 115 | filter_step_pass = partial(filter_step, params=params, Dfx=Dfx, Dfz=Dfz, 116 | eps=eps, return_params=return_params) 117 | (mu_t, Vt, _), hist_elements = lax.scan(filter_step_pass, state, xs) 118 | 119 | if return_history: 120 | return (mu_t, Vt), hist_elements 121 | 122 | return (mu_t, Vt), None 123 | -------------------------------------------------------------------------------- /jsl/nlds/extended_kalman_smoother.py: -------------------------------------------------------------------------------- 1 | # Extended Rauch-Tung-Striebel smoother or Extended Kalman Smoother (EKS) 2 | import jax 3 | import chex 4 | import jax.numpy as jnp 5 | from .base import NLDS 6 | from functools import partial 7 | from typing import Dict, List, Tuple, Callable 8 | from jsl.nlds import extended_kalman_filter as ekf 9 | 10 | 11 | def smooth_step(state: Tuple[chex.Array, chex.Array, int], 12 | xs: Tuple[chex.Array, chex.Array], 13 | params: NLDS, 14 | Dfz: Callable, 15 | eps: float, 16 | return_params: Dict 17 | ) -> Tuple[Tuple[chex.Array, chex.Array, int], Dict]: 18 | mean_next, cov_next, t = state 19 | mean_kf, cov_kf = xs 20 | 21 | mean_next_hat = 
params.fz(mean_kf) 22 | cov_next_hat = Dfz(mean_kf) @ cov_kf @ Dfz(mean_kf).T + params.Qz(mean_kf, t) 23 | cov_next_hat_eps = cov_next_hat + eps * jnp.eye(mean_next_hat.shape[0]) 24 | kalman_gain = jnp.linalg.solve(cov_next_hat_eps, Dfz(mean_kf).T) @ cov_kf 25 | 26 | mean_prev = mean_kf + kalman_gain @ (mean_next - mean_next_hat) 27 | cov_prev = cov_kf + kalman_gain @ (cov_next - cov_next_hat) @ kalman_gain.T 28 | 29 | prev_state = (mean_prev, cov_prev, t-1) 30 | carry = {"mean": mean_prev, "cov": cov_prev} 31 | carry = {key: val for key, val in carry.items() if key in return_params} 32 | 33 | return prev_state, carry 34 | 35 | 36 | def smooth(params: NLDS, 37 | init_state: chex.Array, 38 | observations: chex.Array, 39 | covariates: chex.Array = None, 40 | Vinit: chex.Array = None, 41 | return_params: List = None, 42 | eps: float = 0.001, 43 | return_filter_history: bool = False, 44 | ) -> Dict[str, Dict[str, chex.Array]]: 45 | 46 | kf_params = ["mean", "cov"] 47 | Dfz = jax.jacrev(params.fz) 48 | _, hist_filter = ekf.filter(params, init_state, observations, covariates, Vinit, 49 | return_params=kf_params, eps=eps, return_history=True) 50 | kf_hist_mean, kf_hist_cov = hist_filter["mean"], hist_filter["cov"] 51 | kf_last_mean, kf_hist_mean = kf_hist_mean[-1], kf_hist_mean[:-1] 52 | kf_last_cov, kf_hist_cov = kf_hist_cov[-1], kf_hist_cov[:-1] 53 | 54 | smooth_step_partial = partial(smooth_step, params=params, Dfz=Dfz, 55 | eps=eps, return_params=return_params) 56 | 57 | init_state = (kf_last_mean, kf_last_cov, len(kf_hist_mean) - 1) 58 | xs = (kf_hist_mean, kf_hist_cov) 59 | _, hist_smooth = jax.lax.scan(smooth_step_partial, init_state, xs, reverse=True) 60 | 61 | hist = { 62 | "smooth": hist_smooth, 63 | "filter": hist_filter if return_filter_history else None 64 | } 65 | 66 | return hist 67 | -------------------------------------------------------------------------------- /jsl/nlds/sequential_monte_carlo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.scipy.stats import norm 4 | 5 | class NonMarkovianSequenceModel: 6 | """ 7 | Non-Markovian Gaussian Sequence Model 8 | """ 9 | def __init__(self, phi, beta, q, r): 10 | """ 11 | Parameters 12 | ---------- 13 | phi: float 14 | Multiplicative effect in latent-space 15 | beta: float 16 | Decay relate in observed-space 17 | q: float 18 | Variance in latent-space 19 | r: float 20 | Variance in observed-space 21 | """ 22 | self.phi = phi 23 | self.beta = beta 24 | self.q = q 25 | self.r = r 26 | 27 | @staticmethod 28 | def _obtain_weights(log_weights): 29 | weights = jnp.exp(log_weights - jax.nn.logsumexp(log_weights)) 30 | return weights 31 | 32 | def sample_latent_step(self, key, x_prev): 33 | x_next = jax.random.normal(key) * jnp.sqrt(self.q) + self.phi * x_prev 34 | return x_next 35 | 36 | def sample_observed_step(self, key, mu, x_curr): 37 | mu_next = self.beta * mu + x_curr 38 | y_curr = jax.random.normal(key) * jnp.sqrt(self.r) + mu_next 39 | return y_curr, mu_next 40 | 41 | def sample_step(self, key, x_prev, mu_prev): 42 | key_latent, key_obs = jax.random.split(key) 43 | x_curr = self.sample_latent_step(key_latent, x_prev) 44 | y_curr, mu = self.sample_observed_step(key_obs, mu_prev, x_curr) 45 | 46 | carry_vals = {"x": x_curr, "y": y_curr} 47 | return (x_curr, mu), carry_vals 48 | 49 | def sample_single(self, key, nsteps): 50 | """ 51 | Sample a single path from the non-Markovian Gaussian state-space model. 
52 | 53 | Parameters 54 | ---------- 55 | key: jax.random.PRNGKey 56 | Initial seed 57 | 58 | """ 59 | key_init, key_simul = jax.random.split(key) 60 | x_init = jax.random.normal(key_init) * jnp.sqrt(self.q) 61 | mu_init = 0 62 | 63 | keys = jax.random.split(key_simul, nsteps) 64 | carry_init = (x_init, mu_init) 65 | _, hist = jax.lax.scan(lambda carry, key: self.sample_step(key, *carry), carry_init, keys) 66 | return hist 67 | 68 | def sample(self, key, nsteps, nsims=1): 69 | """ 70 | Sample from a non-Markovian Gaussian state-space model. 71 | 72 | Parameters 73 | ---------- 74 | key: jax.random.PRNGKey 75 | Initial key to perform the simulation. 76 | nsteps: int 77 | Total number of steps to sample. 78 | nsims: int 79 | Number of paths to sample. 80 | """ 81 | key_simulations = jax.random.split(key, nsims) 82 | sample_vmap = jax.vmap(self.sample_single, (0, None)) 83 | 84 | simulations = sample_vmap(key_simulations, nsteps) 85 | 86 | # convert to one-dimensional array if only one simulation is 87 | # required 88 | if nsims == 1: 89 | for key, values in simulations.items(): 90 | simulations[key] = values.ravel() 91 | 92 | return simulations 93 | 94 | def _sis_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs): 95 | """ 96 | Compute one step of the sequential-importance-sampling algorithm 97 | at time t. 98 | 99 | Parameters 100 | ---------- 101 | key: jax.random.PRNGKey 102 | key to sample particle. 103 | mu_prev: array(n_particles) 104 | Term carrying past cumulate values. 105 | xsamp_prev: array(n_particles) 106 | Samples / particles from the latent space at t-1 107 | yobs: float 108 | Observation at time t. 109 | """ 110 | 111 | key_particles = jax.random.split(key, len(xparticles_prev)) 112 | 113 | # 1. Sample from proposal 114 | xparticles = jax.vmap(self.sample_latent_step)(key_particles, xparticles_prev) 115 | # 2. Evaluate unnormalised weights 116 | # 2.1 Compute new mean 117 | mu = self.beta * mu_prev + xparticles 118 | # 2.2 Compute log-unnormalised weights 119 | 120 | log_weights = log_weights_prev + norm.logpdf(yobs, loc=mu, scale=jnp.sqrt(self.q)) 121 | dict_carry = { 122 | "log_weights": log_weights, 123 | "particles": xparticles, 124 | } 125 | 126 | return (log_weights, mu, xparticles), dict_carry 127 | 128 | def sequential_importance_sample(self, key, observations, n_particles=10): 129 | """ 130 | Apply sequential importance sampling (SIS) to a series of observations. Sampling 131 | considers the transition distribution as the proposal. 132 | 133 | Parameters 134 | ---------- 135 | key: jax.random.PRNGKey 136 | Initial key. 137 | observations: array(n_observations) 138 | one-array of observed values. 139 | n_particles: int (default: 10) 140 | Total number of particles to consider in the SIS filter. 141 | """ 142 | T = len(observations) 143 | key, key_init_particles = jax.random.split(key) 144 | keys = jax.random.split(key, T) 145 | 146 | init_log_weights = jnp.zeros(n_particles) 147 | init_mu = jnp.zeros(n_particles) # equiv. 
∀n.wn=1.0 148 | init_xparticles = jax.random.normal(key_init_particles, shape=(n_particles,)) * jnp.sqrt(self.q) 149 | 150 | carry_init = (init_log_weights, init_mu, init_xparticles) 151 | xs_tuple = (keys, observations) 152 | _, dict_hist = jax.lax.scan(lambda carry, xs: self._sis_step(xs[0], *carry, xs[1]), carry_init, xs_tuple) 153 | dict_hist["weights"] = jnp.exp(dict_hist["log_weights"] - jax.nn.logsumexp(dict_hist["log_weights"], axis=1, keepdims=True)) 154 | 155 | return dict_hist 156 | 157 | def _smc_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs): 158 | n_particles = len(xparticles_prev) 159 | key, key_particles = jax.random.split(key) 160 | key_particles = jax.random.split(key_particles, n_particles) 161 | 162 | # 1. Resample particles 163 | weights = self._obtain_weights(log_weights_prev) 164 | ix_sampled = jax.random.choice(key, n_particles, p=weights, shape=(n_particles,)) 165 | xparticles_prev_sampled = xparticles_prev[ix_sampled] 166 | mu_prev_sampled = mu_prev[ix_sampled] 167 | # 2. Propagate particles 168 | xparticles = jax.vmap(self.sample_latent_step)(key_particles, xparticles_prev_sampled) 169 | # 3. Concatenate 170 | mu = self.beta * mu_prev_sampled + xparticles 171 | 172 | # ToDo: return dictionary of log_weights and sampled indices 173 | log_weights = norm.logpdf(yobs, loc=mu, scale=jnp.sqrt(self.q)) 174 | dict_carry = { 175 | "log_weights": log_weights, 176 | "indices": ix_sampled, 177 | "particles": xparticles, 178 | } 179 | return (log_weights, mu, xparticles_prev_sampled), dict_carry 180 | 181 | def sequential_monte_carlo(self, key, observations, n_particles=10): 182 | """ 183 | Apply sequential Monte Carlo (SCM), a.k.a sequential importance resampling (SIR), 184 | a.k.a sequential importance sampling and resampling(SISR). 185 | """ 186 | T = len(observations) 187 | key, key_particle_init = jax.random.split(key) 188 | keys = jax.random.split(key, T) 189 | 190 | init_xparticles = jax.random.normal(key_particle_init, shape=(n_particles,)) * jnp.sqrt(self.q) 191 | init_log_weights = jnp.zeros(n_particles) # equiv. ∀n.wn=1.0 192 | init_mu = jnp.zeros(n_particles) 193 | 194 | carry_init = (init_log_weights, init_mu, init_xparticles) 195 | xs_tuple = (keys, observations) 196 | _, dict_hist = jax.lax.scan(lambda carry, xs: self._smc_step(xs[0], *carry, xs[1]), carry_init, xs_tuple) 197 | # transform log-unnormalised weights to weights 198 | dict_hist["weights"] = jnp.exp(dict_hist["log_weights"] - jax.nn.logsumexp(dict_hist["log_weights"], axis=1, keepdims=True)) 199 | 200 | return dict_hist 201 | -------------------------------------------------------------------------------- /jsl/nlds/unscented_kalman_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Unscented Kalman Filter for discrete time systems 3 | """ 4 | 5 | import jax.numpy as jnp 6 | from jax.lax import scan 7 | 8 | import chex 9 | from typing import List 10 | 11 | from .base import NLDS 12 | 13 | 14 | def sqrtm(M): 15 | """ 16 | Compute the matrix square-root of a hermitian 17 | matrix M. 
i,e, R such that RR = M 18 | 19 | Parameters 20 | ---------- 21 | M: array(m, m) 22 | Hermitian matrix 23 | 24 | Returns 25 | ------- 26 | array(m, m): square-root matrix 27 | """ 28 | evals, evecs = jnp.linalg.eigh(M) 29 | R = evecs @ jnp.sqrt(jnp.diag(evals)) @ jnp.linalg.inv(evecs) 30 | return R 31 | 32 | 33 | def filter(params: NLDS, 34 | init_state: chex.Array, 35 | sample_obs: chex.Array, 36 | observations: List = None, 37 | Vinit: chex.Array = None, 38 | return_history: bool = True): 39 | """ 40 | Run the Unscented Kalman Filter algorithm over a set of observed samples. 41 | Parameters 42 | ---------- 43 | sample_obs: array(nsamples, obs_size) 44 | return_history: bool 45 | Whether to return the history of mu and Sigma values. 46 | Returns 47 | ------- 48 | * array(nsamples, state_size) 49 | History of filtered mean terms 50 | * array(nsamples, state_size, state_size) 51 | History of filtered covariance terms 52 | """ 53 | alpha = params.alpha 54 | beta = params.beta 55 | kappa = params.kappa 56 | d = params.d 57 | 58 | fx, fz = params.fx, params.fz 59 | Q, R = params.Qz, params.Rx 60 | 61 | lmbda = alpha ** 2 * (d + kappa) - d 62 | gamma = jnp.sqrt(d + lmbda) 63 | 64 | wm_vec = jnp.array([1 / (2 * (d + lmbda)) if i > 0 65 | else lmbda / (d + lmbda) 66 | for i in range(2 * d + 1)]) 67 | wc_vec = jnp.array([1 / (2 * (d + lmbda)) if i > 0 68 | else lmbda / (d + lmbda) + (1 - alpha ** 2 + beta) 69 | for i in range(2 * d + 1)]) 70 | nsteps, *_ = sample_obs.shape 71 | initial_mu_t = init_state 72 | initial_Sigma_t = Q(init_state) if Vinit is None else Vinit 73 | 74 | if observations is None: 75 | observations = iter([()] * nsteps) 76 | else: 77 | observations = iter([(obs,) for obs in observations]) 78 | 79 | def filter_step(params, sample_observation): 80 | mu_t, Sigma_t = params 81 | observation = next(observations) 82 | 83 | # TO-DO: use jax.scipy.linalg.sqrtm when it gets added to lib 84 | comp1 = mu_t[:, None] + gamma * sqrtm(Sigma_t) 85 | comp2 = mu_t[:, None] - gamma * sqrtm(Sigma_t) 86 | # sigma_points = jnp.c_[mu_t, comp1, comp2] 87 | sigma_points = jnp.concatenate((mu_t[:, None], comp1, comp2), axis=1) 88 | 89 | z_bar = fz(sigma_points) 90 | mu_bar = z_bar @ wm_vec 91 | Sigma_bar = (z_bar - mu_bar[:, None]) 92 | Sigma_bar = jnp.einsum("i,ji,ki->jk", wc_vec, Sigma_bar, Sigma_bar) + Q(mu_t) 93 | 94 | Sigma_bar_half = sqrtm(Sigma_bar) 95 | comp1 = mu_bar[:, None] + gamma * Sigma_bar_half 96 | comp2 = mu_bar[:, None] - gamma * Sigma_bar_half 97 | # sigma_points = jnp.c_[mu_bar, comp1, comp2] 98 | sigma_points = jnp.concatenate((mu_bar[:, None], comp1, comp2), axis=1) 99 | 100 | x_bar = fx(sigma_points, *observation) 101 | x_hat = x_bar @ wm_vec 102 | St = x_bar - x_hat[:, None] 103 | St = jnp.einsum("i,ji,ki->jk", wc_vec, St, St) + R(mu_t, *observation) 104 | 105 | mu_hat_component = z_bar - mu_bar[:, None] 106 | x_hat_component = x_bar - x_hat[:, None] 107 | Sigma_bar_y = jnp.einsum("i,ji,ki->jk", wc_vec, mu_hat_component, x_hat_component) 108 | Kt = Sigma_bar_y @ jnp.linalg.inv(St) 109 | 110 | mu_t = mu_bar + Kt @ (sample_observation - x_hat) 111 | Sigma_t = Sigma_bar - Kt @ St @ Kt.T 112 | 113 | return (mu_t, Sigma_t), (mu_t, Sigma_t) 114 | 115 | (mu, Sigma), (mu_hist, Sigma_hist) = scan(filter_step, (initial_mu_t, initial_Sigma_t), sample_obs[1:]) 116 | 117 | mu_hist = jnp.vstack([initial_mu_t[None, ...], mu_hist]) 118 | Sigma_hist = jnp.vstack([initial_Sigma_t[None, ...], Sigma_hist]) 119 | 120 | if return_history: 121 | return mu_hist, Sigma_hist 122 | return mu, Sigma 123 | 
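A minimal usage sketch for the unscented filter above (not part of the repository): it wires the NLDS container from jsl/nlds/base.py to jsl/nlds/unscented_kalman_filter.filter. The toy sinusoidal transition, scaled-identity observation model, noise covariances and unscented-transform hyper-parameters below are illustrative assumptions; only the NLDS fields and the filter signature are taken from the code above.

import jax
import jax.numpy as jnp

from jsl.nlds.base import NLDS
from jsl.nlds import unscented_kalman_filter as ukf

key = jax.random.PRNGKey(0)
d = 2                                  # latent-state dimension (must match NLDS.d)
fz = lambda z: jnp.sin(z)              # toy transition; acts column-wise on the (d, 2d+1) sigma points
fx = lambda z: 0.5 * z                 # toy observation function
Q = 0.1 * jnp.eye(d)                   # transition-noise covariance
R = 0.5 * jnp.eye(d)                   # observation-noise covariance

params = NLDS(fz, fx, Q, R, alpha=1.0, beta=2.0, kappa=2.0, d=d)

# Simulate a short trajectory with the base-class sampler, then filter it.
x0 = jnp.zeros(d)
state_hist, obs_hist = params.sample(key, x0, nsteps=20)
mu_hist, Sigma_hist = ukf.filter(params, x0, obs_hist, return_history=True)
# mu_hist: (20, d) filtered means; Sigma_hist: (20, d, d) filtered covariances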
-------------------------------------------------------------------------------- /jsl/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="snb", version="0.0.1", install_requires=["gym"] ) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="jsl", 5 | packages=find_packages(), 6 | install_requires=[ 7 | "chex", 8 | "dataclasses", 9 | "jaxlib", 10 | "jax", 11 | "matplotlib", 12 | "tensorflow_probability" 13 | ] 14 | ) 15 | --------------------------------------------------------------------------------
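A quick-start sketch (not part of the repository): the top-level setup.py above suggests the package can be installed in editable mode with "pip install -e ." from the repository root, after which the jsl modules above are importable. The sequence-model parameters below are illustrative assumptions; the class and method signatures come from jsl/nlds/sequential_monte_carlo.py.

import jax

from jsl.nlds.sequential_monte_carlo import NonMarkovianSequenceModel

key_sim, key_smc = jax.random.split(jax.random.PRNGKey(314))

# phi, beta, q, r chosen arbitrarily for illustration
model = NonMarkovianSequenceModel(phi=0.9, beta=0.5, q=1.0, r=1.0)

sims = model.sample(key_sim, nsteps=100)   # dict with latent path "x" and observations "y"
hist = model.sequential_monte_carlo(key_smc, sims["y"], n_particles=50)
# hist["particles"] and hist["weights"] both have shape (timesteps, n_particles);
# hist["weights"] holds the per-step normalised importance weights.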