├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── environment.yml
├── jsl
│ ├── __init__.py
│ ├── demos
│ │ ├── __init__.py
│ │ ├── bootstrap_filter.py
│ │ ├── bootstrap_filter_maneuver.py
│ │ ├── eekf_logistic_regression.py
│ │ ├── ekf_continuous.py
│ │ ├── ekf_mlp.py
│ │ ├── ekf_mlp_anim.ipynb
│ │ ├── ekf_mlp_anim.py
│ │ ├── ekf_vs_eks_spiral.py
│ │ ├── ekf_vs_ukf_spiral.py
│ │ ├── fixed_lag_smoother.ipynb
│ │ ├── hmm_casino.py
│ │ ├── hmm_casino_em_train.py
│ │ ├── hmm_casino_numpy.py
│ │ ├── hmm_casino_sgd_train
│ │ ├── hmm_casino_sgd_train.py
│ │ ├── hmm_lillypad.py
│ │ ├── kalman_sampling_demo.ipynb
│ │ ├── kf_continuous_circle.py
│ │ ├── kf_parallel.py
│ │ ├── kf_spiral.py
│ │ ├── kf_tracking.py
│ │ ├── lds_sampling_demo.py
│ │ ├── linreg_kf.py
│ │ ├── logreg_biclusters.py
│ │ ├── old-init.py
│ │ ├── pendulum_1d.py
│ │ ├── plot_utils.py
│ │ ├── rbpf_maneuver.py
│ │ ├── sis_vs_smc.py
│ │ ├── superimport_test.py
│ │ └── ukf_mlp.py
│ ├── hmm
│ │ ├── __init__.py
│ │ ├── hmm_casino_test.py
│ │ ├── hmm_lib.py
│ │ ├── hmm_lib_test.py
│ │ ├── hmm_logspace_lib.py
│ │ ├── hmm_logspace_lib_test.py
│ │ ├── hmm_numpy_lib.py
│ │ ├── hmm_utils.py
│ │ ├── old
│ │ │ ├── hmm_discrete_em_lib.py
│ │ │ ├── hmm_discrete_lib.py
│ │ │ ├── hmm_discrete_lib_test.py
│ │ │ ├── hmm_discrete_likelihood_test.py
│ │ │ └── hmm_sgd_lib.py
│ │ └── sparse_lib.py
│ ├── lds
│ │ ├── __init__.py
│ │ ├── cont_kalman_filter.py
│ │ ├── kalman_filter.py
│ │ ├── kalman_filter_test.py
│ │ ├── kalman_filter_with_unknown_noise.py
│ │ ├── kalman_sampler.py
│ │ └── mixture_kalman_filter.py
│ └── nlds
│   ├── __init__.py
│   ├── base.py
│   ├── bootstrap_filter.py
│   ├── continuous_extended_kalman_filter.py
│   ├── diagonal_extended_kalman_filter.py
│   ├── extended_kalman_filter.py
│   ├── extended_kalman_smoother.py
│   ├── sequential_monte_carlo.py
│   └── unscented_kalman_filter.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | *.mp4
3 | *.pdf
4 | .DS_Store
5 | .vscode/
6 | hmm_casino_params
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 | .dmypy.json
143 | dmypy.json
144 |
145 | # Pyre type checker
146 | .pyre/
147 |
148 | # pytype static type analyzer
149 | .pytype/
150 |
151 | # Cython debug symbols
152 | cython_debug/
153 |
154 | # PyCharm
155 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
157 | # and can be added to the global gitignore or merged into this file. For a more nuclear
158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
159 | #.idea/
160 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: JSL
6 | message: >-
7 | If you use this software, please cite it using the
8 | metadata from this file.
9 | type: software
10 | authors:
11 | - given-names: Gerardo
12 | family-names: Duran-Martin
13 | - given-names: Kevin
14 | family-names: Murphy
15 | - given-names: Aleyna
16 | family-names: Kara
17 | repository: 'https://github.com/probml/JSL'
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Probabilistic machine learning
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JSL: JAX State-Space models (SSM) Library
2 |
3 |
4 |
5 |
6 |
7 | JSL is a JAX library for Bayesian inference in state space models.
8 | As of 2022-06-28, JSL is **deprecated**. You should use [ssm-jax](https://github.com/probml/ssm-jax).
9 |
10 |
11 | # Installation
12 |
13 | We assume you have already installed [JAX](https://github.com/google/jax#installation) and
14 | [Tensorflow](https://www.tensorflow.org/install),
15 | since the details on how to do this depend on whether you have a CPU, GPU, etc.
16 | (This step is not necessary in Colab.)
17 |
18 | Now install these packages:
19 |
20 | ```
21 | !pip install --upgrade git+https://github.com/google/flax.git
22 | !pip install --upgrade tensorflow-probability
23 | !pip install git+https://github.com/blackjax-devs/blackjax.git
24 | !pip install git+https://github.com/deepmind/distrax.git
25 | ```
26 |
27 | Then install JSL:
28 | ```
29 | !pip install git+https://github.com/probml/jsl
30 | ```
31 | Alternatively, you can clone the repo locally, into say `~/github/JSL`, and then install it as a package, as follows:
32 | ```
33 | !git clone https://github.com/probml/JSL.git
34 | cd JSL
35 | !pip install -e .
36 | ```
37 |
38 | # Running the demos
39 |
40 | You can see how to use the library by looking at some of the demos.
41 | You can run the demos from inside a notebook like this:
42 | ```
43 | %run JSL/jsl/demos/kf_tracking.py
44 | %run JSL/jsl/demos/hmm_casino_em_train.py
45 | ```
46 |
47 | Or from inside an IPython shell like this:
48 | ```
49 | from jsl.demos import kf_tracking
50 | figdict = kf_tracking.main()
51 | ```
52 |
53 | Most of the demos create figures. If you want to save them (in both png and pdf format),
54 | you need to specify the FIGDIR environment variable, like this:
55 | ```
56 | import os
57 | os.environ["FIGDIR"]='/Users/kpmurphy/figures'
58 |
59 | from jsl.demos.plot_utils import savefig
60 | savefig(figdict)
61 | ```
62 |
63 | # Authors
64 |
65 | Gerardo Durán-Martín ([@gerdm](https://github.com/gerdm)), Aleyna Kara([@karalleyna](https://github.com/karalleyna)), Kevin Murphy ([@murphyk](https://github.com/murphyk)), Giles Harper-Donnelly ([@gileshd](https://github.com/gileshd)), Peter Chang ([@petergchang](https://github.com/petergchang)).
66 |
--------------------------------------------------------------------------------
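A hedged, minimal end-to-end sketch of the workflow the README describes (running a demo's `main()` and saving its figures). The figure directory path is only an example, and it assumes JSL and its dependencies are already installed:

```
# Minimal sketch of the README workflow above (illustrative; the FIGDIR path is an example).
import os

os.environ["FIGDIR"] = os.path.expanduser("~/jsl_figures")  # savefig() reads this variable
os.makedirs(os.environ["FIGDIR"], exist_ok=True)

import matplotlib.pyplot as plt
from jsl.demos import kf_tracking
from jsl.demos.plot_utils import savefig

figdict = kf_tracking.main()   # most demos return a dict of named figures
savefig(figdict)               # writes png/pdf copies of each figure into FIGDIR
plt.show()
```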
/environment.yml:
--------------------------------------------------------------------------------
1 | name: jsl
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=4.5=1_gnu
7 | - blas=1.0=mkl
8 | - brotli=1.0.9=he6710b0_2
9 | - ca-certificates=2021.10.26=h06a4308_2
10 | - certifi=2021.10.8=py39h06a4308_0
11 | - cycler=0.11.0=pyhd3eb1b0_0
12 | - dbus=1.13.18=hb2f20db_0
13 | - expat=2.4.1=h2531618_2
14 | - fontconfig=2.13.1=h6c09931_0
15 | - fonttools=4.25.0=pyhd3eb1b0_0
16 | - freetype=2.11.0=h70c0345_0
17 | - giflib=5.2.1=h7b6447c_0
18 | - glib=2.69.1=h5202010_0
19 | - gst-plugins-base=1.14.0=h8213a91_2
20 | - gstreamer=1.14.0=h28cd5cc_2
21 | - icu=58.2=he6710b0_3
22 | - intel-openmp=2021.4.0=h06a4308_3561
23 | - jpeg=9d=h7f8727e_0
24 | - kiwisolver=1.3.1=py39h2531618_0
25 | - lcms2=2.12=h3be6417_0
26 | - ld_impl_linux-64=2.35.1=h7274673_9
27 | - libffi=3.3=he6710b0_2
28 | - libgcc-ng=9.3.0=h5101ec6_17
29 | - libgomp=9.3.0=h5101ec6_17
30 | - libpng=1.6.37=hbc83047_0
31 | - libstdcxx-ng=9.3.0=hd4cf53a_17
32 | - libtiff=4.2.0=h85742a9_0
33 | - libuuid=1.0.3=h7f8727e_2
34 | - libwebp=1.2.0=h89dd481_0
35 | - libwebp-base=1.2.0=h27cfd23_0
36 | - libxcb=1.14=h7b6447c_0
37 | - libxml2=2.9.12=h03d6c58_0
38 | - lz4-c=1.9.3=h295c915_1
39 | - matplotlib=3.5.0=py39h06a4308_0
40 | - matplotlib-base=3.5.0=py39h3ed280b_0
41 | - mkl=2021.4.0=h06a4308_640
42 | - mkl-service=2.4.0=py39h7f8727e_0
43 | - mkl_fft=1.3.1=py39hd3c417c_0
44 | - mkl_random=1.2.2=py39h51133e4_0
45 | - munkres=1.1.4=py_0
46 | - ncurses=6.3=h7f8727e_2
47 | - numpy-base=1.21.2=py39h79a1101_0
48 | - olefile=0.46=pyhd3eb1b0_0
49 | - openssl=1.1.1l=h7f8727e_0
50 | - packaging=21.3=pyhd3eb1b0_0
51 | - pcre=8.45=h295c915_0
52 | - pillow=8.4.0=py39h5aabda8_0
53 | - pip=21.2.4=py39h06a4308_0
54 | - pyparsing=3.0.4=pyhd3eb1b0_0
55 | - pyqt=5.9.2=py39h2531618_6
56 | - python=3.9.7=h12debd9_1
57 | - python-dateutil=2.8.2=pyhd3eb1b0_0
58 | - qt=5.9.7=h5867ecd_1
59 | - readline=8.1=h27cfd23_0
60 | - setuptools=58.0.4=py39h06a4308_0
61 | - sip=4.19.13=py39h2531618_0
62 | - six=1.16.0=pyhd3eb1b0_0
63 | - sqlite=3.36.0=hc218d9a_0
64 | - tk=8.6.11=h1ccaba5_0
65 | - tornado=6.1=py39h27cfd23_0
66 | - tzdata=2021e=hda174b7_0
67 | - wheel=0.37.0=pyhd3eb1b0_1
68 | - xz=5.2.5=h7b6447c_0
69 | - zlib=1.2.11=h7b6447c_3
70 | - zstd=1.4.9=haebb681_0
71 | - pip:
72 | - absl-py==1.0.0
73 | - charset-normalizer==2.0.9
74 | - fire==0.4.0
75 | - flatbuffers==2.0
76 | - idna==3.3
77 | - jax==0.2.26
78 | - jaxlib==0.1.75
79 | - libtpu-nightly==0.1.dev20211208
80 | - numpy==1.21.4
81 | - opt-einsum==3.3.0
82 | - requests==2.26.0
83 | - scipy==1.7.3
84 | - termcolor==1.1.0
85 | - typing-extensions==4.0.1
86 | - urllib3==1.26.7
87 | prefix: /home/gerardoduran/miniconda3/envs/jsl
88 |
--------------------------------------------------------------------------------
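For reference, a pinned environment like the one above can typically be created and activated with standard conda commands (shown here as a hedged example, assuming conda is installed):

```
conda env create -f environment.yml
conda activate jsl
```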
/jsl/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/jsl/demos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/demos/__init__.py
--------------------------------------------------------------------------------
/jsl/demos/bootstrap_filter.py:
--------------------------------------------------------------------------------
1 | # Demo of the bootstrap filter under a
2 | # nonlinear discrete system
3 |
4 | import jax
5 | from jsl.nlds.base import NLDS
6 | from jsl.nlds.bootstrap_filter import filter
7 | import jax.numpy as jnp
8 | import matplotlib.pyplot as plt
9 | from jax import random
10 |
11 |
12 | def plot_samples(sample_state, sample_obs, ax=None):
13 | fig, ax = plt.subplots()
14 | ax.plot(*sample_state.T, label="state space")
15 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+")
16 | ax.scatter(*sample_state[0], c="black", zorder=3)
17 | ax.legend()
18 | ax.set_title("Noisy observations from hidden trajectory")
19 | plt.axis("equal")
20 | return fig
21 |
22 |
23 | def plot_inference(sample_obs, mean_hist):
24 | fig, ax = plt.subplots()
25 | ax.scatter(*sample_obs.T, marker="+", color="tab:green", s=60)
26 | ax.plot(*mean_hist.T, c="tab:orange", label="filtered")
27 | ax.scatter(*mean_hist[0], c="black", zorder=3)
28 | plt.legend()
29 | plt.axis("equal")
30 | return fig
31 |
32 | def main():
33 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
34 | def fx(x): return x
35 |
36 | dt = 0.4
37 | nsteps = 100
38 | # Initial state vector
39 | x0 = jnp.array([1.5, 0.0])
40 | # State noise
41 | Qt = jnp.eye(2) * 0.001
42 | # Observed noise
43 | Rt = jnp.eye(2) * 0.05
44 |
45 | key = random.PRNGKey(314)
46 | model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
47 | sample_state, sample_obs = model.sample(key, x0, nsteps)
48 |
49 | n_particles = 3_000
50 | fz_vec = jax.vmap(fz, in_axes=(0, None))
51 | particle_filter = NLDS(lambda x: fz_vec(x, dt), fx, Qt, Rt)
52 | pf_mean = filter(particle_filter, key, x0, sample_obs, n_particles)
53 |
54 | dict_figures = {}
55 | fig_bootstrap = plot_inference(sample_obs, pf_mean)
56 | dict_figures["nlds2d_bootstrap"] = fig_bootstrap
57 |
58 | fig_data = plot_samples(sample_state, sample_obs)
59 | dict_figures["nlds2d_data"] = fig_data
60 |
61 | return dict_figures
62 |
63 | if __name__ == "__main__":
64 | from jsl.demos.plot_utils import savefig
65 | plt.rcParams["axes.spines.right"] = False
66 | plt.rcParams["axes.spines.top"] = False
67 | dict_figures = main()
68 | savefig(dict_figures)
69 | plt.show()
70 |
--------------------------------------------------------------------------------
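The demo above delegates inference to `jsl.nlds.bootstrap_filter.filter`. As a rough guide to what a bootstrap (particle) filter step involves, here is a schematic propagate/weight/resample update in JAX; all names are illustrative and this is not the library's implementation:

```
# Schematic single step of a bootstrap particle filter (illustrative only).
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def bootstrap_step(key, particles, y, fz, Q, R):
    key_prop, key_resample = jax.random.split(key)
    n_particles = len(particles)
    # 1. Propagate every particle through the nonlinear dynamics and add state noise.
    predicted = jax.vmap(fz)(particles)
    predicted = jax.random.multivariate_normal(key_prop, predicted, Q)
    # 2. Weight each particle by the likelihood of the new observation y.
    log_weights = multivariate_normal.logpdf(y, mean=predicted, cov=R)
    # 3. Resample particles in proportion to their weights.
    idx = jax.random.categorical(key_resample, log_weights, shape=(n_particles,))
    particles = predicted[idx]
    # The filtered mean is the average of the (equally weighted) resampled particles.
    return particles, particles.mean(axis=0)
```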
/jsl/demos/bootstrap_filter_maneuver.py:
--------------------------------------------------------------------------------
1 | # Bootstrap filtering for jump Markov linear systems.
2 | # Compare to the results of the Rao-Blackwellised particle filter
3 | # (see https://github.com/probml/pyprobml/blob/master/scripts/rbpf_maneuver_demo.py)
4 |
5 | # !pip install matplotlib==3.4.2
6 |
7 |
8 | import jax
9 | import numpy as np
10 | import jax.numpy as jnp
11 | import matplotlib.pyplot as plt
12 | import jsl.lds.mixture_kalman_filter as kflib
13 | from jax import random
14 | from functools import partial
15 | from jax.scipy.special import logit
16 | from jax.scipy.stats import multivariate_normal
17 | from jsl.demos.plot_utils import kdeg, style3d
18 |
19 |
20 | def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5):
21 | nsteps = len(mu_hist)
22 | xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max()
23 | xrange = jnp.linspace(xmin, xmax, npoints).reshape(-1, 1)
24 | res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist)
25 | densities = res[..., dim]
26 | for t in range(0, nsteps, skip):
27 | tloc = t * np.ones(npoints)
28 | px = densities[t]
29 | ax.plot(tloc, xrange, px, c="tab:blue", linewidth=1)
30 | ax.set_zlim(0, 1)
31 | style3d(ax, 1.8, 1.2, 0.7, 0.8)
32 | ax.view_init(elevation, azimuth)
33 | ax.set_xlabel(r"$t$", fontsize=13)
34 | ax.set_ylabel(r"$x_{"f"d={dim}"",t}$", fontsize=13)
35 | ax.set_zlabel(r"$p(x_{d, t} \vert y_{1:t})$", fontsize=13)
36 |
37 |
38 | def bootstrap(state, y, A, B, C, Q, transition_matrix, nparticles):
39 | latent_t, state_t, key = state
40 | key_latent, key_state, key_reindex, key_next = random.split(key, 4)
41 |
42 | # Discrete states
43 | latent_t = random.categorical(key_latent, jnp.log(transition_matrix[latent_t]), shape=(nparticles,))
44 | # Continuous states
45 | state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t]
46 | state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q)
47 |
48 | # Compute weights
49 | weights_t = multivariate_normal.pdf(y, mean=state_t, cov=C)
50 | indices_t = random.categorical(key_reindex, jnp.log(weights_t), shape=(nparticles,))
51 |
52 | # Reindex and compute weights
53 | state_t = state_t[indices_t, ...]
54 | latent_t = latent_t[indices_t, ...]
55 | # weights_t = jnp.ones(nparticles) / nparticles
56 |
57 | mu_t = state_t.mean(axis=0)
58 |
59 | return (latent_t, state_t, key_next), (mu_t, latent_t, state_t)
60 |
61 |
62 | def main():
63 | TT = 0.1
64 | A = jnp.array([[1, TT, 0, 0],
65 | [0, 1, 0, 0],
66 | [0, 0, 1, TT],
67 | [0, 0, 0, 1]])
68 |
69 |
70 | B1 = jnp.array([0, 0, 0, 0])
71 | B2 = jnp.array([-1.225, -0.35, 1.225, 0.35])
72 | B3 = jnp.array([1.225, 0.35, -1.225, -0.35])
73 | B = jnp.stack([B1, B2, B3], axis=0)
74 |
75 | Q = 0.2 * jnp.eye(4)
76 | R = 10 * jnp.diag(jnp.array([2, 1, 2, 1]))
77 | C = jnp.eye(4)
78 |
79 | transition_matrix = jnp.array([
80 | [0.9, 0.05, 0.05],
81 | [0.05, 0.9, 0.05],
82 | [0.05, 0.05, 0.9]
83 | ])
84 |
85 | transition_matrix = jnp.array([
86 | [0.8, 0.1, 0.1],
87 | [0.1, 0.8, 0.1],
88 | [0.1, 0.1, 0.8]
89 | ])
90 |
91 | # We sample from a Rao-Blackwell Kalman Filter
92 | params = kflib.RBPFParamsDiscrete(A, B, C, Q, R, transition_matrix)
93 |
94 | nparticles = 1000
95 | nsteps = 100
96 | key = random.PRNGKey(1)
97 | keys = random.split(key, nsteps)
98 |
99 | x0 = (1, random.multivariate_normal(key, jnp.zeros(4), jnp.eye(4)))
100 | draw_state_fixed = partial(kflib.draw_state, params=params)
101 |
102 | # Create target dataset
103 | _, (latent_hist, state_hist, obs_hist) = jax.lax.scan(draw_state_fixed, x0, keys)
104 |
105 |
106 | # ** Filtering process **
107 | nparticles = 5000
108 | key_base = random.PRNGKey(31)
109 | key_mean_init, key_sample, key_state, key_next = random.split(key_base, 4)
110 |
111 | p_init = jnp.array([0.0, 1.0, 0.0])
112 | mu_0 = jnp.zeros((nparticles, 4))
113 | s0 = random.categorical(key_state, logit(p_init), shape=(nparticles,))
114 | init_state = (s0, mu_0, key_base)
115 |
116 | def bootstrap_step(state, y): return bootstrap(state, y, A, B, C, Q, transition_matrix, nparticles)
117 | _, (mu_t_hist, latent_hist, state_hist_particles) = jax.lax.scan(bootstrap_step, init_state, obs_hist)
118 |
119 | estimated_track = mu_t_hist[:, [0, 2]]
120 | particles_track = state_hist_particles[..., [0, 2]]
121 |
122 | bf_mse = ((mu_t_hist - state_hist)[:, [0, 2]] ** 2).mean(axis=0).sum()
123 | color_dict = {0: "tab:green", 1: "tab:red", 2: "tab:blue"}
124 | latent_hist_est = latent_hist.mean(axis=1).round()
125 | color_states_est = [color_dict[state] for state in np.array(latent_hist_est)]
126 |
127 | dict_figures = {}
128 | fig, ax = plt.subplots()
129 | ax.scatter(*estimated_track.T, edgecolors=color_states_est, c="none", s=10)
130 | ax.set_title(f"Bootstrap Filter MSE: {bf_mse:.2f}")
131 | dict_figures["bootstrap-filter-trace"] = fig
132 |
133 | for dim in range(2):
134 | fig = plt.figure()
135 | ax = plt.axes(projection="3d")
136 | plot_3d_belief_state(particles_track, dim, ax, h=1.1, npoints=1000)
137 | ax.autoscale(enable=False, axis='both')
138 | # pml.savefig(f"bootstrap-filter-belief-states-dim{dim}.pdf", pad_inches=0, bbox_inches="tight")
139 | dict_figures[f"bootstrap-filter-belief-states-dim{dim}"] = fig
140 |
141 | return dict_figures
142 |
143 | if __name__ == "__main__":
144 | from jsl.demos.plot_utils import savefig
145 | plt.rcParams["axes.spines.right"] = False
146 | plt.rcParams["axes.spines.top"] = False
147 | dict_figures = main()
148 | savefig(dict_figures, pad_inches=0, bbox_inches="tight")
149 | plt.show()
--------------------------------------------------------------------------------
/jsl/demos/eekf_logistic_regression.py:
--------------------------------------------------------------------------------
1 | # Online learning of a 2d binary logistic regression model p(y=1|x,w) = sigmoid(w'x),
2 | # using the Exponential-family Extended Kalman Filter (EEKF) algorithm
3 | # described in "Online natural gradient as a Kalman filter", Y. Ollivier, 2018.
4 | # https://projecteuclid.org/euclid.ejs/1537257630.
5 |
6 | # The latent state corresponds to the current estimate of the regression weights w.
7 | # The observation model has the form
8 | # p(y(t) | w(t), x(t)) propto Gauss(y(t) | h_t(w(t)), R(t))
9 | # where h_t(w) = sigmoid(w' * x(t)) = p(t) and R(t) = p(t) * (1-p(t))
10 |
11 | import jax.numpy as jnp
12 | import matplotlib.pyplot as plt
13 | from jax import random
14 |
15 | from jsl.nlds.base import NLDS
16 | from jsl.nlds.extended_kalman_filter import filter
17 |
18 | # Import data and baseline solution
19 | from jsl.demos import logreg_biclusters as demo
20 |
21 | figures, data = demo.main()
22 | X = data["X"]
23 | y = data["y"]
24 | Phi = data["Phi"]
25 | Xspace = data["Xspace"]
26 | Phispace = data["Phispace"]
27 | w_laplace = data["w_laplace"]
28 |
29 |
30 | # jax.config.update("jax_platform_name", "cpu")
31 | # jax.config.update("jax_enable_x64", True)
32 |
33 | def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x))
34 |
35 |
36 | def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z))
37 |
38 |
39 | def fz(x): return x
40 |
41 |
42 | def fx(w, x): return sigmoid(w[None, :] @ x)
43 |
44 |
45 | def Rt(w, x): return (sigmoid(w @ x) * (1 - sigmoid(w @ x)))[None, None]
46 |
47 |
48 | def main():
49 | N, M = Phi.shape
50 | n_datapoints, ndims = Phi.shape
51 |
52 | # Predictive domain
53 | xmin, ymin = X.min(axis=0) - 0.1
54 | xmax, ymax = X.max(axis=0) + 0.1
55 | step = 0.1
56 | Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
57 | _, nx, ny = Xspace.shape
58 | Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])
59 |
60 | ### EEKF Approximation
61 | mu_t = jnp.zeros(M)
62 | Pt = jnp.eye(M) * 0.0
63 | P0 = jnp.eye(M) * 2.0
64 |
65 | model = NLDS(fz, fx, Pt, Rt)
66 | (w_eekf, P_eekf), eekf_hist = filter(model, mu_t, y, Phi, P0, return_params=["mean", "cov"])
67 | w_eekf_hist = eekf_hist["mean"]
68 | P_eekf_hist = eekf_hist["cov"]
69 |
70 | ### *** Ploting surface predictive distribution ***
71 | colors = ["black" if el else "white" for el in y]
72 | dict_figures = {}
73 | key = random.PRNGKey(31415)
74 | nsamples = 5000
75 |
76 | # EEKF surface predictive distribution
77 | eekf_samples = random.multivariate_normal(key, w_eekf, P_eekf, (nsamples,))
78 | Z_eekf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, eekf_samples))
79 | Z_eekf = Z_eekf.mean(axis=0)
80 |
81 | fig_eekf, ax = plt.subplots()
82 | title = "EEKF Predictive Distribution"
83 | demo.plot_posterior_predictive(ax, X, Xspace, Z_eekf, title, colors)
84 | dict_figures["logistic_regression_surface_eekf"] = fig_eekf
85 |
86 | ### Plot EEKF and Laplace training history
87 | P_eekf_hist_diag = jnp.diagonal(P_eekf_hist, axis1=1, axis2=2)
88 | # P_laplace_diag = jnp.sqrt(jnp.diagonal(SN))
89 | lcolors = ["black", "tab:blue", "tab:red"]
90 | elements = w_eekf_hist.T, P_eekf_hist_diag.T, w_laplace, lcolors
91 | timesteps = jnp.arange(n_datapoints) + 1
92 |
93 | for k, (wk, Pk, wk_laplace, c) in enumerate(zip(*elements)):
94 | fig_weight_k, ax = plt.subplots()
95 | ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (EEKF)")
96 | ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)
97 |
98 | ax.set_xlim(1, n_datapoints)
99 | ax.legend(framealpha=0.7, loc="upper right")
100 | ax.set_xlabel("number samples")
101 | ax.set_ylabel("weights")
102 | plt.tight_layout()
103 | dict_figures[f"logistic_regression_hist_ekf_w{k}"] = fig_weight_k
104 |
105 | print("EEKF weights")
106 | print(w_eekf, end="\n" * 2)
107 |
108 | return dict_figures
109 |
110 |
111 | if __name__ == "__main__":
112 | from jsl.demos.plot_utils import savefig
113 |
114 | figs = main()
115 | savefig(figs)
116 | plt.show()
117 |
--------------------------------------------------------------------------------
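The header comment above specifies the EEKF observation model h_t(w) = sigmoid(w'x(t)) with R(t) = p(t)(1 - p(t)). For clarity, here is a schematic of the corresponding single-observation update, using the fact that the Jacobian of h_t with respect to w is p(1 - p)x. It is illustrative only, not the `jsl.nlds.extended_kalman_filter.filter` implementation, and it folds in the trivial predict step (identity dynamics, and zero state noise in this demo):

```
# Schematic single EEKF update for the logistic-regression model described above.
import jax
import jax.numpy as jnp


def eekf_update(w, P, x, y):
    p = jax.nn.sigmoid(w @ x)               # h_t(w) = sigmoid(w' x)
    H = (p * (1 - p) * x)[None, :]          # Jacobian of h_t w.r.t. w, shape (1, d)
    R = jnp.atleast_2d(p * (1 - p))         # observation noise R_t = p (1 - p)
    S = H @ P @ H.T + R                     # innovation covariance
    K = P @ H.T @ jnp.linalg.inv(S)         # Kalman gain, shape (d, 1)
    w_new = w + (K * (y - p)).ravel()       # mean update
    P_new = (jnp.eye(w.size) - K @ H) @ P   # covariance update
    return w_new, P_new
```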
/jsl/demos/ekf_continuous.py:
--------------------------------------------------------------------------------
1 | # Example of an Extended Kalman Filter using
2 | # a figure-8 nonlinear dynamical system.
3 | # For further reference and examples see:
4 | # * Section on EKFs in PML vol2 book
5 | # * https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python/blob/master/11-Extended-Kalman-Filters.ipynb
6 | # * Nonlinear Dynamics and Chaos - Steven Strogatz
7 |
8 | from jsl.demos import plot_utils
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import jax.numpy as jnp
12 | from jax import random
13 |
14 | from jsl.nlds.base import NLDS
15 | from jsl.nlds.continuous_extended_kalman_filter import estimate
16 |
17 |
18 | def fz(x):
19 | x, y = x
20 | return jnp.asarray([y, x - x ** 3])
21 |
22 |
23 | def fx(x):
24 | x, y = x
25 | return jnp.asarray([x, y])
26 |
27 |
28 | def main():
29 | dt = 0.01
30 | T = 7
31 | nsamples = 70
32 | x0 = jnp.array([0.5, -0.75])
33 |
34 | # State noise
35 | Qt = jnp.eye(2) * 0.001
36 | # Observed noise
37 | Rt = jnp.eye(2) * 0.01
38 |
39 | key = random.PRNGKey(314)
40 | ekf = NLDS(fz, fx, Qt, Rt)
41 | sample_state, sample_obs, jump = ekf.sample(key, x0, T, nsamples)
42 | mu_hist, V_hist = estimate(ekf, sample_state, sample_obs, jump, dt)
43 |
44 | vmin, vmax, step = -1.5, 1.5 + 0.5, 0.5
45 | X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
46 | X_dot = jnp.apply_along_axis(fz, 0, X)
47 |
48 | dict_figures = {}
49 |
50 | fig, ax = plt.subplots()
51 | ax.plot(*sample_state.T, label="state space")
52 | ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
53 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
54 | ax.legend()
55 | plt.axis("equal")
56 | ax.set_title("State Space")
57 | dict_figures["ekf-state-space"] = fig
58 |
59 | fig, ax = plt.subplots()
60 | ax.plot(*sample_state.T, c="tab:orange", label="EKF estimation")
61 | ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
62 | ax.scatter(*mu_hist[0], c="black", zorder=3)
63 | for mut, Vt in zip(mu_hist[::4], V_hist[::4]):
64 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3)
65 | plt.legend()
66 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
67 | ax.legend()
68 | plt.axis("equal")
69 | ax.set_title("Approximate Space")
70 | dict_figures["ekf-estimated-space"] = fig
71 |
72 | return dict_figures
73 |
74 |
75 | if __name__ == "__main__":
76 | from jsl.demos.plot_utils import savefig
77 |
78 | plt.rcParams["axes.spines.right"] = False
79 | plt.rcParams["axes.spines.top"] = False
80 | dict_figures = main()
81 | savefig(dict_figures)
82 | plt.show()
83 |
--------------------------------------------------------------------------------
/jsl/demos/ekf_mlp.py:
--------------------------------------------------------------------------------
1 | # Demo showcasing the training of an MLP with a single hidden layer using
2 | # Extended Kalman Filtering (EKF).
3 | # In this demo, we consider the latent state to be the weights of an MLP.
4 | # The observed state at time t is the output of the MLP as influenced by the weights
5 | # at time t-1 and the covariate x[t].
6 | # The transition function between latent states is the identity function.
7 | # For more information, see
8 | # * Neural Network Training Using Unscented and Extended Kalman Filter
9 | # https://juniperpublishers.com/raej/RAEJ.MS.ID.555568.php
10 |
11 | import jax
12 | import jax.numpy as jnp
13 | import flax.linen as nn
14 | import matplotlib.pyplot as plt
15 | import jsl.nlds.extended_kalman_filter as ekf_lib
16 | from jax.flatten_util import ravel_pytree
17 | from functools import partial
18 | from typing import Sequence
19 | from jsl.nlds.base import NLDS
20 |
21 | class MLP(nn.Module):
22 | features: Sequence[int]
23 |
24 | @nn.compact
25 | def __call__(self, x):
26 | for feat in self.features[:-1]:
27 | x = nn.relu(nn.Dense(feat)(x))
28 | x = nn.Dense(self.features[-1])(x)
29 | return x
30 |
31 |
32 | def apply(flat_params, x, model, unflatten_fn):
33 | """
34 | Multilayer Perceptron (MLP) with a single hidden unit and
35 | tanh activation function. The input-unit and the
36 | output-unit are both assumed to be unidimensional
37 |
38 | Parameters
39 | ----------
40 | W: array(2 * n_hidden + n_hidden + 1)
41 | Unravelled weights of the MLP
42 | x: array(1,)
43 | Singleton element to evaluate the MLP
44 | n_hidden: int
45 | Number of hidden units
46 |
47 | Returns
48 | -------
49 | * array(1,)
50 | Evaluation of MLP at the specified point
51 | """
52 | params = unflatten_fn(flat_params)
53 | return model.apply(params, x)
54 |
55 |
56 | def sample_observations(key, f, n_obs, xmin, xmax, x_noise=0.1, y_noise=3.0):
57 | key_x, key_y, key_shuffle = jax.random.split(key, 3)
58 | x_noise = jax.random.normal(key_x, (n_obs,)) * x_noise
59 | y_noise = jax.random.normal(key_y, (n_obs,)) * y_noise
60 | x = jnp.linspace(xmin, xmax, n_obs) + x_noise
61 | y = f(x) + y_noise
62 | X = jnp.c_[x, y]
63 |
64 | shuffled_ixs = jax.random.permutation(key_shuffle, jnp.arange(n_obs))
65 | X, y = jnp.array(X[shuffled_ixs, :].T)
66 | return X, y
67 |
68 |
69 | def plot_mlp_prediction(key, xobs, yobs, xtest, fw, w, Sw, ax, n_samples=100):
70 | W_samples = jax.random.multivariate_normal(key, w, Sw, (n_samples,))
71 | sample_yhat = fw(W_samples, xtest[:, None])
72 | for sample in sample_yhat: # sample curves
73 | ax.plot(xtest, sample, c="tab:gray", alpha=0.07)
74 | ax.plot(xtest, sample_yhat.mean(axis=0)) # mean of posterior predictive
75 | ax.scatter(xobs, yobs, s=14, c="none", edgecolor="black", label="observations", alpha=0.5)
76 | ax.set_xlim(xobs.min(), xobs.max())
77 |
78 |
79 | def plot_intermediate_steps(key, ax, fwd_func, intermediate_steps, xtest, mu_hist, Sigma_hist, x, y):
80 | """
81 | Plot the intermediate steps of the training process, all of them in the same plot
82 | but in different subplots.
83 | """
84 | for step, axi in zip(intermediate_steps, ax.flatten()):
85 | W_step, SW_step = mu_hist[step], Sigma_hist[step]
86 | x_step, y_step = x[:step], y[:step]
87 | plot_mlp_prediction(key, x_step, y_step, xtest, fwd_func, W_step, SW_step, axi)
88 | axi.set_title(f"step={step}")
89 | plt.tight_layout()
90 |
91 |
92 | def plot_intermediate_steps_single(key, method, fwd_func, intermediate_steps, xtest, mu_hist, Sigma_hist, x, y):
93 | """
94 | Plot the intermediate steps of the training process, each one in a different plot.
95 | """
96 | figures = {}
97 | for step in intermediate_steps:
98 | W_step, SW_step = mu_hist[step], Sigma_hist[step]
99 | x_step, y_step = x[:step], y[:step]
100 | fig_step, axi = plt.subplots()
101 | plot_mlp_prediction(key, x_step, y_step, xtest, fwd_func, W_step, SW_step, axi)
102 | axi.set_title(f"step={step}")
103 | plt.tight_layout()
104 | figname = f"{method}-mlp-step-{step}"
105 | figures[figname] = fig_step
106 | return figures
107 |
108 |
109 | def f(x):
110 | return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3
111 |
112 |
113 | def fz(W):
114 | return W
115 |
116 |
117 | def main():
118 | key = jax.random.PRNGKey(314)
119 | key_sample_obs, key_weights, key_init = jax.random.split(key, 3)
120 |
121 | all_figures = {}
122 |
123 | # *** MLP configuration ***
124 | n_hidden = 6
125 | n_out = 1
126 | n_in = 1
127 | model = MLP([n_hidden, n_out])
128 |
129 | batch_size = 20
130 | batch = jnp.ones((batch_size, n_in))
131 |
132 | variables = model.init(key_init, batch)
133 | W0, unflatten_fn = ravel_pytree(variables)
134 |
135 |
136 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn)
137 | # vectorised for multiple observations
138 | fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0])
139 | # vectorised for multiple observations and weights
140 | fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None])
141 |
142 | # *** Generating training and test data ***
143 | n_obs = 200
144 | xmin, xmax = -3, 3
145 | sigma_y = 3.0
146 | x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
147 | xtest = jnp.linspace(x.min(), x.max(), n_obs)
148 |
149 | # *** MLP Training with EKF ***
150 | n_params = W0.size
151 | W0 = jax.random.normal(key_weights, (n_params,)) * 1 # initial random guess
152 | Q = jnp.eye(n_params) * 1e-4 # parameters do not change
153 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed
154 | Vinit = jnp.eye(n_params) * 100 # vague prior
155 |
156 | ekf = NLDS(fz, fwd_mlp, Q, R)
157 | (W_ekf, SW_ekf), hist_ekf = ekf_lib.filter(ekf, W0, y[:, None], x[:, None], Vinit, return_params=["mean", "cov"])
158 | ekf_mu_hist, ekf_Sigma_hist = hist_ekf["mean"], hist_ekf["cov"]
159 |
160 | # Plot final performance
161 | fig, ax = plt.subplots()
162 | plot_mlp_prediction(key, x, y, xtest, fwd_mlp_obs_weights, W_ekf, SW_ekf, ax)
163 | ax.set_title("EKF + MLP")
164 | all_figures["ekf-mlp"] = fig
165 |
166 | # Plot intermediate performance
167 | intermediate_steps = [10, 20, 30, 40, 50, 60]
168 | fig, ax = plt.subplots(2, 2)
169 | plot_intermediate_steps(key, ax, fwd_mlp_obs_weights, intermediate_steps, xtest, ekf_mu_hist, ekf_Sigma_hist, x, y)
170 | plt.suptitle("EKF + MLP training")
171 | all_figures["ekf-mlp-intermediate"] = fig
172 | figures_intermediates = plot_intermediate_steps_single(key, "ekf", fwd_mlp_obs_weights,
173 | intermediate_steps, xtest, ekf_mu_hist, ekf_Sigma_hist, x, y)
174 | all_figures = {**all_figures, **figures_intermediates}
175 | return all_figures
176 |
177 |
178 | if __name__ == "__main__":
179 | from jsl.demos.plot_utils import savefig
180 |
181 | plt.rcParams["axes.spines.right"] = False
182 | plt.rcParams["axes.spines.top"] = False
183 | figures = main()
184 | savefig(figures)
185 | plt.show()
186 |
--------------------------------------------------------------------------------
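The demo above treats the flattened MLP weights as the latent state with identity dynamics and hands the full loop to `jsl.nlds.extended_kalman_filter.filter`. As a schematic of what a single EKF weight update looks like, one can linearise the network output around the current weights with `jax.jacrev`; this sketch is illustrative only and not the library code:

```
# Schematic single EKF update on the flattened MLP weights (illustrative only).
import jax
import jax.numpy as jnp


def ekf_mlp_update(w, P, x_t, y_t, fwd_mlp, Q, R):
    # Predict step: identity dynamics on the weights, so only the covariance grows.
    P_pred = P + Q
    # Linearise the network output around the current weight estimate.
    y_pred = fwd_mlp(w, x_t)                 # shape (1,)
    H = jax.jacrev(fwd_mlp)(w, x_t)          # Jacobian w.r.t. the weights, shape (1, n_params)
    # Standard Kalman update with the linearised observation model.
    S = H @ P_pred @ H.T + R
    K = P_pred @ H.T @ jnp.linalg.inv(S)
    w_new = w + K @ (y_t - y_pred)
    P_new = (jnp.eye(w.size) - K @ H) @ P_pred
    return w_new, P_new
```

With `fwd_mlp`, `Q`, `R` and `Vinit` as defined in the demo, applying this update once per observation would mimic the loop that `filter` runs internally.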
/jsl/demos/ekf_mlp_anim.py:
--------------------------------------------------------------------------------
1 | # Example showcasing the learning process of the EKF algorithm.
2 | # This demo builds on the ekf_mlp.py demo, animating the posterior predictive
3 | # as observations arrive, and saves the animation as an mp4 video.
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jax.flatten_util import ravel_pytree
8 | from jax.random import PRNGKey, split, normal, multivariate_normal
9 |
10 | import matplotlib.pyplot as plt
11 | import matplotlib.animation as animation
12 | from functools import partial
13 |
14 | from jsl.demos.ekf_mlp import MLP, apply, sample_observations
15 | from jsl.nlds.base import NLDS
16 | from jsl.nlds.extended_kalman_filter import filter
17 |
18 |
19 | def main(fx, fz, filepath):
20 | key = PRNGKey(314)
21 | key_sample_obs, key_weights, key_init = split(key, 3)
22 |
23 | # *** MLP configuration ***
24 | n_hidden = 6
25 | n_out = 1
26 | n_in = 1
27 | model = MLP([n_hidden, n_out])
28 |
29 | batch_size = 20
30 | batch = jnp.ones((batch_size, n_in))
31 |
32 | variables = model.init(key_init, batch)
33 | W0, unflatten_fn = ravel_pytree(variables)
34 |
35 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn)
36 | # vectorised for multiple observations
37 | fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0])
38 | # vectorised for multiple observations and weights
39 | fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None])
40 |
41 | # *** Generating training and test data ***
42 | n_obs = 200
43 | xmin, xmax = -3, 3
44 | sigma_y = 3.0
45 | x, y = sample_observations(key_sample_obs, fx, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
46 | xtest = jnp.linspace(x.min(), x.max(), n_obs)
47 |
48 | # *** MLP Training with EKF ***
49 | n_params = W0.size
50 | W0 = normal(key_weights, (n_params,)) * 1 # initial random guess
51 | Q = jnp.eye(n_params) * 1e-4 # parameters do not change
52 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed
53 | Vinit = jnp.eye(n_params) * 100 # vague prior
54 |
55 | ekf = NLDS(fz, fwd_mlp, Q, R)
56 | _, ekf_hist = filter(ekf, W0, y[:, None], x[:, None], Vinit, return_params=["mean", "cov"])
57 | ekf_mu_hist, ekf_Sigma_hist = ekf_hist["mean"], ekf_hist["cov"]
58 |
59 | xtest = jnp.linspace(x.min(), x.max(), 200)
60 | fig, ax = plt.subplots()
61 |
62 | def func(i):
63 | plt.cla()
64 | W, SW = ekf_mu_hist[i], ekf_Sigma_hist[i]
65 | W_samples = multivariate_normal(key, W, SW, (100,))
66 | sample_yhat = fwd_mlp_obs_weights(W_samples, xtest[:, None])
67 | for sample in sample_yhat:
68 | ax.plot(xtest, sample, c="tab:gray", alpha=0.07)
69 | ax.plot(xtest, sample_yhat.mean(axis=0))
70 | ax.scatter(x[:i], y[:i], s=14, c="none", edgecolor="black", label="observations")
71 | ax.scatter(x[i], y[i], s=30, c="tab:red")
72 | ax.set_title(f"EKF+MLP ({i + 1:03}/{n_obs})")
73 | ax.set_xlim(x.min(), x.max())
74 | ax.set_ylim(y.min(), y.max())
75 |
76 | return ax
77 |
78 | ani = animation.FuncAnimation(fig, func, frames=n_obs)
79 | ani.save(filepath, dpi=200, bitrate=-1, fps=10)
80 |
81 |
82 | if __name__ == "__main__":
83 | import os
84 |
85 | plt.rcParams["axes.spines.right"] = False
86 | plt.rcParams["axes.spines.top"] = False
87 |
88 | path = os.environ.get("FIGDIR")
89 | path = "." if path is None else path
90 | filepath = os.path.join(path, "samples_hist_ekf.mp4")
91 |
92 | def f(x): return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3
93 | def fz(W): return W
94 | main(f, fz, filepath)
95 |
96 | print(f"Saved animation to {filepath}")
97 |
--------------------------------------------------------------------------------
/jsl/demos/ekf_vs_eks_spiral.py:
--------------------------------------------------------------------------------
1 | # Compare the extended Kalman filter (EKF) with the extended Kalman smoother (EKS)
2 | # on a nonlinear 2d tracking problem
3 | import jax.numpy as jnp
4 | import matplotlib.pyplot as plt
5 | import jsl.nlds.extended_kalman_smoother as eks
6 | from jax import random
7 | from jsl.demos import plot_utils
8 | from jsl.nlds.base import NLDS
9 |
10 |
11 | def plot_data(sample_state, sample_obs):
12 | fig, ax = plt.subplots()
13 | ax.plot(*sample_state.T, label="state space")
14 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+")
15 | ax.scatter(*sample_state[0], c="black", zorder=3)
16 | ax.legend()
17 | ax.set_title("Noisy observations from hidden trajectory")
18 | plt.axis("equal")
19 | return fig, ax
20 |
21 |
22 | def plot_inference(sample_obs, mean_hist, Sigma_hist, label):
23 | fig, ax = plt.subplots()
24 | ax.scatter(*sample_obs.T, marker="+", color="tab:green")
25 | ax.plot(*mean_hist.T, c="tab:orange", label=label)
26 | ax.scatter(*mean_hist[0], c="black", zorder=3)
27 | plt.legend()
28 | collection = [(mut, Vt) for mut, Vt in zip(mean_hist[::10], Sigma_hist[::10])
29 | if Vt[0, 0] > 0 and Vt[1, 1] > 0 and abs(Vt[1, 0] - Vt[0, 1]) < 7e-4]
30 | for mut, Vt in collection:
31 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3)
32 | plt.scatter(*mut, c="black", zorder=3, s=5)
33 | plt.axis("equal")
34 | return fig, ax
35 |
36 |
37 | def main():
38 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
39 |
40 | def fx(x, *args): return x
41 |
42 | dt = 0.1
43 | nsteps = 300
44 | # Initial state vector
45 | x0 = jnp.array([1.5, 0.0])
46 | x0 = jnp.array([1.053, 0.145])
47 | state_size, *_ = x0.shape
48 | # State noise
49 | Qt = jnp.eye(state_size) * 0.001
50 | # Observed noise
51 | Rt = jnp.eye(2) * 0.05
52 |
53 | key = random.PRNGKey(31415)
54 | model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
55 | sample_state, sample_obs = model.sample(key, x0, nsteps)
56 |
57 | # _, ekf_hist = ekf_lib.filter(ekf_model, x0, sample_obs, return_params=["mean", "cov"])
58 | hist = eks.smooth(model, x0, sample_obs, return_params=["mean", "cov"],
59 | return_filter_history=True)
60 | eks_hist = hist["smooth"]
61 | ekf_hist = hist["filter"]
62 |
63 | eks_mean_hist = eks_hist["mean"]
64 | eks_Sigma_hist = eks_hist["cov"]
65 |
66 | ekf_mean_hist = ekf_hist["mean"]
67 | ekf_Sigma_hist = ekf_hist["cov"]
68 |
69 | dict_figures = {}
70 | # nlds2d_data
71 | fig_data, ax = plot_data(sample_state, sample_obs)
72 | dict_figures["nlds2d_data"] = fig_data
73 |
74 | # nlds2d_ekf
75 | fig_ekf, ax = plot_inference(sample_obs, ekf_mean_hist, ekf_Sigma_hist,
76 | label="filtered")
77 | ax.set_title("EKF")
78 | dict_figures["nlds2d_ekf"] = fig_ekf
79 |
80 | # nlds2d_eks
81 | fig_eks, ax = plot_inference(sample_obs, eks_mean_hist, eks_Sigma_hist,
82 | label="smoothed")
83 | ax.set_title("EKS")
84 | dict_figures["nlds2d_eks"] = fig_eks
85 |
86 | return dict_figures
87 |
88 |
89 | if __name__ == "__main__":
90 | from jsl.demos.plot_utils import savefig
91 |
92 | dict_figures = main()
93 | savefig(dict_figures)
94 | plt.show()
95 |
--------------------------------------------------------------------------------
/jsl/demos/ekf_vs_ukf_spiral.py:
--------------------------------------------------------------------------------
1 | # Compare the extended Kalman filter with the unscented Kalman filter on a nonlinear 2d tracking problem
2 | from jax import random
3 |
4 | import matplotlib.pyplot as plt
5 | import jax.numpy as jnp
6 |
7 | import jsl.nlds.extended_kalman_filter as ekf_lib
8 | import jsl.nlds.unscented_kalman_filter as ukf_lib
9 | from jsl.demos import plot_utils
10 | from jsl.nlds.base import NLDS
11 |
12 |
13 | def check_symmetric(a, rtol=1.1):
14 | return jnp.allclose(a, a.T, rtol=rtol)
15 |
16 |
17 | def plot_data(sample_state, sample_obs):
18 | fig, ax = plt.subplots()
19 | ax.plot(*sample_state.T, label="state space")
20 | ax.scatter(*sample_obs.T, s=60, c="tab:green", marker="+")
21 | ax.scatter(*sample_state[0], c="black", zorder=3)
22 | ax.legend()
23 | ax.set_title("Noisy observations from hidden trajectory")
24 | plt.axis("equal")
25 | return fig, ax
26 |
27 |
28 | def plot_inference(sample_obs, mean_hist, Sigma_hist):
29 | fig, ax = plt.subplots()
30 | ax.scatter(*sample_obs.T, marker="+", color="tab:green")
31 | ax.plot(*mean_hist.T, c="tab:orange", label="filtered")
32 | ax.scatter(*mean_hist[0], c="black", zorder=3)
33 | plt.legend()
34 | collection = [(mut, Vt) for mut, Vt in zip(mean_hist[::4], Sigma_hist[::4])
35 | if Vt[0, 0] > 0 and Vt[1, 1] > 0 and abs(Vt[1, 0] - Vt[0, 1]) < 7e-4]
36 | for mut, Vt in collection:
37 | plot_utils.plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3)
38 | plt.axis("equal")
39 | return fig, ax
40 |
41 |
42 | def main():
43 | def fz(x, dt): return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])
44 |
45 | def fx(x, *args): return x
46 |
47 | dt = 0.4
48 | nsteps = 100
49 | # Initial state vector
50 | x0 = jnp.array([1.5, 0.0])
51 | state_size, *_ = x0.shape
52 | # State noise
53 | Qt = jnp.eye(state_size) * 0.001
54 | # Observed noise
55 | Rt = jnp.eye(2) * 0.05
56 | alpha, beta, kappa = 1, 0, 2
57 |
58 | key = random.PRNGKey(31415)
59 | ekf_model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
60 | sample_state, sample_obs = ekf_model.sample(key, x0, nsteps)
61 |
62 | ukf_model = NLDS(lambda x: fz(x, dt), fx, Qt, Rt,
63 | alpha, beta, kappa, state_size)
64 |
65 | _, ekf_hist = ekf_lib.filter(ekf_model, x0, sample_obs, return_params=["mean", "cov"])
66 | ukf_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf_model, x0, sample_obs)
67 |
68 | ekf_mean_hist = ekf_hist["mean"]
69 | ekf_Sigma_hist = ekf_hist["cov"]
70 |
71 | dict_figures = {}
72 | # nlds2d_data
73 | fig_data, ax = plot_data(sample_state, sample_obs)
74 | dict_figures["nlds2d_data"] = fig_data
75 |
76 | # nlds2d_ekf
77 | fig_ekf, ax = plot_inference(sample_obs, ekf_mean_hist, ekf_Sigma_hist)
78 | ax.set_title("EKF")
79 | dict_figures["nlds2d_ekf"] = fig_ekf
80 |
81 | # nlds2d_ukf
82 | fig_ukf, ax = plot_inference(sample_obs, ukf_mean_hist, ukf_Sigma_hist)
83 | ax.set_title("UKF")
84 | dict_figures["nlds2d_ukf"] = fig_ukf
85 |
86 | return dict_figures
87 |
88 |
89 | if __name__ == "__main__":
90 | from jsl.demos.plot_utils import savefig
91 |
92 | dict_figures = main()
93 | savefig(dict_figures)
94 | plt.show()
95 |
--------------------------------------------------------------------------------
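The UKF demo above differs from the EKF only in how the Gaussian belief is pushed through the nonlinearity: instead of linearising, the UKF propagates a small set of sigma points. Below is a schematic of the scaled unscented transform with the same `alpha`, `beta`, `kappa` values used in the demo; it is illustrative only (it omits the additive noise and cross-covariance terms of a full UKF step) and is not the `jsl.nlds.unscented_kalman_filter` implementation:

```
# Schematic scaled unscented transform (illustrative only).
import jax.numpy as jnp


def unscented_transform(mu, Sigma, f, alpha=1.0, beta=0.0, kappa=2.0):
    d = mu.size
    lam = alpha ** 2 * (d + kappa) - d
    # 2d + 1 sigma points around the mean, offset by columns of the matrix square root.
    L = jnp.linalg.cholesky((d + lam) * Sigma)
    sigma_points = jnp.vstack([mu[None, :], mu + L.T, mu - L.T])
    # Weights for the mean and covariance estimates.
    wm = jnp.concatenate([jnp.array([lam / (d + lam)]),
                          jnp.full(2 * d, 1.0 / (2 * (d + lam)))])
    wc = wm.at[0].add(1 - alpha ** 2 + beta)
    # Propagate the sigma points through the nonlinearity and re-estimate the moments.
    fy = jnp.apply_along_axis(f, 1, sigma_points)
    mu_y = wm @ fy
    diff = fy - mu_y
    Sigma_y = (wc[:, None] * diff).T @ diff
    return mu_y, Sigma_y
```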
/jsl/demos/hmm_casino.py:
--------------------------------------------------------------------------------
1 | # Occasionally dishonest casino example [Durbin98, p54]. This script
2 | # exemplifies a Hidden Markov Model (HMM) in which the throw of a die
3 | # may result in the die being biased (towards 6) or unbiased. If the die turns out to
4 | # be biased, the probability of remaining biased is high, and similarly for the unbiased state.
5 | # Assuming we observe the die being thrown n times, the goal is to recover the periods in which
6 | # the die was biased.
7 | # Original matlab code: https://github.com/probml/pmtk3/blob/master/demos/casinoDemo.m
8 |
9 |
10 | # from jsl.hmm.hmm_numpy_lib import (HMMNumpy, hmm_sample_numpy, hmm_plot_graphviz,
11 | # hmm_forwards_backwards_numpy, hmm_viterbi_numpy)
12 |
13 | from jsl.hmm.hmm_utils import hmm_plot_graphviz
14 |
15 | import numpy as np
16 | import jax.numpy as jnp
17 | import matplotlib.pyplot as plt
18 | from jsl.hmm.hmm_lib import HMMJax, hmm_forwards_backwards_jax, hmm_sample_jax, hmm_viterbi_jax
19 | from jax.random import PRNGKey
20 |
21 |
22 | def find_dishonest_intervals(z_hist):
23 | """
24 | Find the spans of timesteps during which the
25 | simulated system is in state 1
26 | Parameters
27 | ----------
28 | z_hist: array(n_samples)
29 | Result of running the system with two
30 | latent states
31 | Returns
32 | -------
33 | list of tuples with span of values
34 | """
35 | spans = []
36 | x_init = 0
37 | for t, _ in enumerate(z_hist[:-1]):
38 | if z_hist[t + 1] == 0 and z_hist[t] == 1:
39 | x_end = t
40 | spans.append((x_init, x_end))
41 | elif z_hist[t + 1] == 1 and z_hist[t] == 0:
42 | x_init = t + 1
43 | return spans
44 |
45 |
46 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
47 | """
48 | Plot the estimated smoothing/filtering/map of a sequence of hidden states.
49 | "Vertical gray bars denote times when the hidden
50 | state corresponded to state 1. Blue lines represent the
51 | posterior probability of being in that state given different subsets
52 | of observed data." See Markov and Hidden Markov models section for more info
53 | Parameters
54 | ----------
55 | inference_values: array(n_samples, state_size)
56 | Result of running the smoothing method
57 | z_hist: array(n_samples)
58 | Latent simulation
59 | ax: matplotlib.axes
60 | state: int
61 | Decide which state to highlight
62 | map_estimate: bool
63 | Whether to plot steps (simple plot if False)
64 | """
65 | n_samples = len(inference_values)
66 | xspan = np.arange(1, n_samples + 1)
67 | spans = find_dishonest_intervals(z_hist)
68 | if map_estimate:
69 | ax.step(xspan, inference_values, where="post")
70 | else:
71 | ax.plot(xspan, inference_values[:, state])
72 |
73 | for span in spans:
74 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
75 | ax.set_xlim(1, n_samples)
76 | # ax.set_ylim(0, 1)
77 | ax.set_ylim(-0.1, 1.1)
78 | ax.set_xlabel("Observation number")
79 |
80 |
81 | def main():
82 | # state transition matrix
83 | A = jnp.array([
84 | [0.95, 0.05],
85 | [0.10, 0.90]
86 | ])
87 |
88 | # observation matrix
89 | B = jnp.array([
90 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die
91 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die
92 | ])
93 |
94 | n_samples = 300
95 | init_state_dist = jnp.array([1, 1]) / 2
96 |
97 | # hmm = HMM(A, B, init_state_dist)
98 | params = HMMJax(A, B, init_state_dist)
99 |
100 | seed = 0
101 | z_hist, x_hist = hmm_sample_jax(params, n_samples, PRNGKey(seed))
102 | # z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314)
103 |
104 | z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
105 | x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]
106 |
107 | print("Printing sample observed/latent...")
108 | print(f"x: {x_hist_str}")
109 | print(f"z: {z_hist_str}")
110 |
111 | # Do inference
112 | # alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist))
113 | alpha, beta, gamma, loglik = hmm_forwards_backwards_jax(params, x_hist, len(x_hist))
114 | print(f"Loglikelihood: {loglik}")
115 |
116 | # z_map = hmm_viterbi_numpy(params, x_hist)
117 | z_map = hmm_viterbi_jax(params, x_hist)
118 |
119 | dict_figures = {}
120 |
121 | # Plot results
122 | fig, ax = plt.subplots()
123 | plot_inference(alpha, z_hist, ax)
124 | ax.set_ylabel("p(loaded)")
125 | ax.set_title("Filtered")
126 | dict_figures["hmm_casino_filter"] = fig
127 |
128 | fig, ax = plt.subplots()
129 | plot_inference(gamma, z_hist, ax)
130 | ax.set_ylabel("p(loaded)")
131 | ax.set_title("Smoothed")
132 | dict_figures["hmm_casino_smooth"] = fig
133 |
134 | fig, ax = plt.subplots()
135 | plot_inference(z_map, z_hist, ax, map_estimate=True)
136 | ax.set_ylabel("MAP state")
137 | ax.set_title("Viterbi")
138 |
139 | dict_figures["hmm_casino_map"] = fig
140 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])]
141 |
142 | #AA = hmm.trans_dist.probs
143 | #assert np.allclose(A, AA)
144 |
145 | dotfile = hmm_plot_graphviz(A, B, states, observations)
146 | dotfile_dict = {"hmm_casino_graphviz": dotfile}
147 |
148 | return dict_figures, dotfile_dict
149 |
150 |
151 | if __name__ == "__main__":
152 | from jsl.demos.plot_utils import savefig, savedotfile
153 | figs, dotfile = main()
154 |
155 | savefig(figs)
156 | savedotfile(dotfile)
157 | plt.show()
--------------------------------------------------------------------------------
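The filtered probabilities plotted above come from `hmm_forwards_backwards_jax`. As a rough guide, here is a schematic of the normalised forward recursion that produces the filtered posterior `alpha` and the log-likelihood; it is illustrative only, not the `jsl.hmm.hmm_lib` implementation, and it assumes the same conventions as the demo, i.e. `A[i, j] = p(z_t = j | z_{t-1} = i)` and `B[k, x] = p(x | z = k)`:

```
# Schematic normalised forward (filtering) recursion for a discrete HMM (illustrative only).
import jax
import jax.numpy as jnp


def hmm_forward_filter(A, B, pi, obs):
    def step(carry, x_t):
        alpha_prev, loglik = carry
        alpha = (alpha_prev @ A) * B[:, x_t]   # predict one step, then weight by the likelihood
        norm = alpha.sum()
        alpha = alpha / norm
        return (alpha, loglik + jnp.log(norm)), alpha

    alpha0 = pi * B[:, obs[0]]
    norm0 = alpha0.sum()
    carry0 = (alpha0 / norm0, jnp.log(norm0))
    (_, loglik), alphas = jax.lax.scan(step, carry0, obs[1:])
    alphas = jnp.vstack([carry0[0][None, :], alphas])
    return alphas, loglik                      # alphas[t, k] = p(z_t = k | x_{1:t})
```

For the casino demo, `hmm_forward_filter(A, B, init_state_dist, x_hist)` should reproduce, up to numerical details, the `alpha` and `loglik` returned by `hmm_forwards_backwards_jax`.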
/jsl/demos/hmm_casino_em_train.py:
--------------------------------------------------------------------------------
1 | """
2 | This demo shows parameter estimation for HMMs via the Baum-Welch algorithm on the occasionally dishonest casino example.
3 | Author : Aleyna Kara (@karalleyna)
4 | """
5 |
6 |
7 | import time
8 | import numpy as np
9 | import jax.numpy as jnp
10 | import matplotlib.pyplot as plt
11 | from jax.random import split, PRNGKey
12 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_em_numpy
13 | from jsl.hmm.hmm_lib import HMMJax, hmm_sample_jax
14 | from jsl.hmm.hmm_lib import init_random_params_jax, hmm_em_jax
15 | from jsl.hmm import hmm_utils
16 |
17 |
18 | def main():
19 | A = jnp.array([
20 | [0.95, 0.05],
21 | [0.10, 0.90]
22 | ])
23 |
24 | # observation matrix
25 | B = jnp.array([
26 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die
27 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die
28 | ])
29 |
30 | pi = jnp.array([1, 1]) / 2
31 |
32 | seed = 100
33 | rng_key = PRNGKey(seed)
34 | rng_key, rng_sample, rng_batch, rng_init = split(rng_key, 4)
35 |
36 | casino = HMMJax(A, B, pi)
37 |
38 | n_obs_seq, batch_size, max_len = 5, 5, 3000
39 |
40 | observations, lens = hmm_utils.hmm_sample_n(casino,
41 | hmm_sample_jax,
42 | n_obs_seq, max_len,
43 | rng_sample)
44 | observations, lens = hmm_utils.pad_sequences(observations, lens)
45 |
46 | n_hidden, n_obs = B.shape
47 | params_jax = init_random_params_jax([n_hidden, n_obs], rng_key=rng_init)
48 | params_numpy = HMMNumpy(np.array(params_jax.trans_mat),
49 | np.array(params_jax.obs_mat),
50 | np.array(params_jax.init_dist))
51 |
52 | num_epochs = 20
53 |
54 | start = time.time()
55 | params_numpy, neg_ll_numpy = hmm_em_numpy(np.array(observations),
56 | np.array(lens),
57 | num_epochs=num_epochs,
58 | init_params=params_numpy)
59 | print(f"Time taken by numpy version of EM : {time.time() - start}s")
60 |
61 | start = time.time()
62 | params_jax, neg_ll_jax = hmm_em_jax(observations,
63 | lens,
64 | num_epochs=num_epochs,
65 | init_params=params_jax)
66 | print(f"Time taken by JAX version of EM : {time.time() - start}s")
67 |
68 | assert jnp.allclose(np.array(neg_ll_jax), np.array(neg_ll_numpy), 4)
69 |
70 | print(f" Negative loglikelihoods : {neg_ll_jax}")
71 |
72 | dict_figures = {}
73 | fig, ax = plt.subplots()
74 | ax.plot(neg_ll_numpy, label="EM numpy")
75 | ax.set_xlabel("Number of Iterations")
76 | dict_figures["em-numpy"] = fig
77 |
78 | fig, ax = plt.subplots()
79 | ax.plot(neg_ll_jax, label="EM JAX")
80 | ax.set_xlabel("Number of Iterations")
81 | dict_figures["em-jax"] = fig
82 |
83 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])]
84 | # dotfile_np = hmm_plot_graphviz(params_numpy, "hmm_casino_train_np", states, observations)
85 | dotfile_np = hmm_utils.hmm_plot_graphviz(params_numpy.trans_mat, params_numpy.obs_mat,
86 | states, observations)
87 |
88 | # dotfile_jax = hmm_plot_graphviz(params_jax, "hmm_casino_train_jax", states, observations)
89 | dotfile_jax = hmm_utils.hmm_plot_graphviz(params_jax.trans_mat, params_jax.obs_mat,
90 | states, observations)
91 |
92 | dotfile_dict = {"graph-numpy": dotfile_np, "graph-jax": dotfile_jax}
93 |
94 | return dict_figures, dotfile_dict
95 |
96 |
97 | if __name__ == "__main__":
98 | from jsl.demos.plot_utils import savefig, savedotfile
99 |
100 | figs, dotfiles = main()
101 | savefig(figs)
102 | savedotfile(dotfiles)
103 | plt.show()
104 |
--------------------------------------------------------------------------------
/jsl/demos/hmm_casino_numpy.py:
--------------------------------------------------------------------------------
1 | # Occasionally dishonest casino example [Durbin98, p54]. This script
2 | # exemplifies a Hidden Markov Model (HMM) in which the throw of a die
3 | # may result in the die being biased (towards 6) or unbiased. If the die turns out to
4 | # be biased, the probability of remaining biased is high, and similarly for the unbiased state.
5 | # Assuming we observe the die being thrown n times, the goal is to recover the periods in which
6 | # the die was biased.
7 | # Original matlab code: https://github.com/probml/pmtk3/blob/master/demos/casinoDemo.m
8 |
9 | from jsl.hmm.hmm_numpy_lib import (HMMNumpy, hmm_sample_numpy,
10 | hmm_forwards_backwards_numpy, hmm_viterbi_numpy)
11 |
12 | from jsl.hmm.hmm_utils import hmm_plot_graphviz
13 |
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 |
17 |
18 | def find_dishonest_intervals(z_hist):
19 | """
20 | Find the spans of timesteps during which the
21 | simulated system is in state 1
22 | Parameters
23 | ----------
24 | z_hist: array(n_samples)
25 | Result of running the system with two
26 | latent states
27 | Returns
28 | -------
29 | list of tuples with span of values
30 | """
31 | spans = []
32 | x_init = 0
33 | for t, _ in enumerate(z_hist[:-1]):
34 | if z_hist[t + 1] == 0 and z_hist[t] == 1:
35 | x_end = t
36 | spans.append((x_init, x_end))
37 | elif z_hist[t + 1] == 1 and z_hist[t] == 0:
38 | x_init = t + 1
39 | return spans
40 |
41 |
42 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
43 | """
44 | Plot the estimated smoothing/filtering/map of a sequence of hidden states.
45 | "Vertical gray bars denote times when the hidden
46 | state corresponded to state 1. Blue lines represent the
47 | posterior probability of being in that state given different subsets
48 | of observed data." See Markov and Hidden Markov models section for more info
49 | Parameters
50 | ----------
51 | inference_values: array(n_samples, state_size)
52 | Result of running the smoothing method
53 | z_hist: array(n_samples)
54 | Latent simulation
55 | ax: matplotlib.axes
56 | state: int
57 | Decide which state to highlight
58 | map_estimate: bool
59 | Whether to plot steps (simple plot if False)
60 | """
61 | n_samples = len(inference_values)
62 | xspan = np.arange(1, n_samples + 1)
63 | spans = find_dishonest_intervals(z_hist)
64 | if map_estimate:
65 | ax.step(xspan, inference_values, where="post")
66 | else:
67 | ax.plot(xspan, inference_values[:, state])
68 |
69 | for span in spans:
70 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
71 | ax.set_xlim(1, n_samples)
72 | # ax.set_ylim(0, 1)
73 | ax.set_ylim(-0.1, 1.1)
74 | ax.set_xlabel("Observation number")
75 |
76 |
77 | def main():
78 | # state transition matrix
79 | A = np.array([
80 | [0.95, 0.05],
81 | [0.10, 0.90]
82 | ])
83 |
84 | # observation matrix
85 | B = np.array([
86 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die
87 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die
88 | ])
89 |
90 | n_samples = 300
91 | init_state_dist = np.array([1, 1]) / 2
92 | params = HMMNumpy(A, B, init_state_dist)
93 | z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314)
94 |
95 | z_hist_str = "".join((z_hist + 1).astype(str))[:60]
96 | x_hist_str = "".join((x_hist + 1).astype(str))[:60]
97 |
98 | print("Printing sample observed/latent...")
99 | print(f"x: {x_hist_str}")
100 | print(f"z: {z_hist_str}")
101 |
102 | # Do inference
103 | alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist))
104 | print(f"Loglikelihood: {loglik}")
105 |
106 | z_map = hmm_viterbi_numpy(params, x_hist)
107 |
108 | dict_figures = {}
109 |
110 | # Plot results
111 | fig, ax = plt.subplots()
112 | plot_inference(alpha, z_hist, ax)
113 | ax.set_ylabel("p(loaded)")
114 | ax.set_title("Filtered")
115 | dict_figures["hmm_casino_filter"] = fig
116 |
117 | fig, ax = plt.subplots()
118 | plot_inference(gamma, z_hist, ax)
119 | ax.set_ylabel("p(loaded)")
120 | ax.set_title("Smoothed")
121 | dict_figures["hmm_casino_smooth"] = fig
122 |
123 | fig, ax = plt.subplots()
124 | plot_inference(z_map, z_hist, ax, map_estimate=True)
125 | ax.set_ylabel("MAP state")
126 | ax.set_title("Viterbi")
127 | dict_figures["hmm_casino_map"] = fig
128 |
129 | file_name = "hmm_casino_params"
130 | states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])]
131 | dotfile = hmm_plot_graphviz(A, B, states, observations)
132 | # dotfile = hmm_plot_graphviz(params, file_name, states, observations)
133 | dotfile_dict = {"hmm_casino_graphviz": dotfile}
134 |
135 | return dict_figures, dotfile_dict
136 |
137 |
138 | if __name__ == "__main__":
139 | from jsl.demos.plot_utils import savefig, savedotfile
140 |
141 | figs, dotfile = main()
142 | savefig(figs)
143 | savedotfile(dotfile)
144 | plt.show()
145 |
--------------------------------------------------------------------------------
/jsl/demos/hmm_casino_sgd_train:
--------------------------------------------------------------------------------
1 | // HMM
2 | digraph {
3 | s0 [label=<State 1 | Obs 1 | 0.99 | Obs 2 | 0.01 | >]
4 | s1 [label=<State 2 | Obs 1 | 0.01 | Obs 2 | 0.99 | >]
5 | s0 -> s0 [label=0.99]
6 | s0 -> s1 [label=0.01]
7 | s1 -> s0 [label=0.01]
8 | s1 -> s1 [label=0.99]
9 | rankdir=LR
10 | }
11 |
--------------------------------------------------------------------------------
/jsl/demos/hmm_casino_sgd_train.py:
--------------------------------------------------------------------------------
1 | """
2 | This demo does MAP estimation of an HMM using a gradient-descent algorithm applied to the log marginal likelihood.
3 | It includes
4 |
5 | 1. Mini Batch Gradient Descent
6 | 2. Full Batch Gradient Descent
7 | 3. Stochastic Gradient Descent
8 |
9 | Author: Aleyna Kara(@karalleyna)
10 | """
11 |
12 |
13 | import matplotlib.pyplot as plt
14 | import jax.numpy as jnp
15 | from jax.example_libraries import optimizers
16 | from jax.random import split, PRNGKey
17 | from jsl.hmm.hmm_lib import fit
18 | from jsl.hmm.hmm_utils import pad_sequences, hmm_sample_n
19 | from jsl.hmm.hmm_lib import HMMJax, hmm_sample_jax
20 | from jsl.hmm.hmm_utils import hmm_plot_graphviz
21 |
22 |
23 | def main():
24 | # state transition matrix
25 | A = jnp.array([
26 | [0.95, 0.05],
27 | [0.10, 0.90]])
28 |
29 | # observation matrix
30 | B = jnp.array([
31 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die
32 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die
33 | ])
34 |
35 | pi = jnp.array([1, 1]) / 2
36 |
37 | casino = HMMJax(A, B, pi)
38 | num_hidden, num_obs = 2, 6
39 |
40 | seed = 0
41 | rng_key = PRNGKey(seed)
42 | rng_key, rng_sample = split(rng_key)
43 |
44 | n_obs_seq, max_len = 4, 5000
45 | num_epochs = 400
46 |
47 | observations, lens = pad_sequences(*hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample))
48 | optimizer = optimizers.momentum(step_size=1e-3, mass=0.95)
49 |
50 | # Mini Batch Gradient Descent
51 | batch_size = 2
52 | params_mbgd, losses_mbgd = fit(observations,
53 | lens,
54 | num_hidden,
55 | num_obs,
56 | batch_size,
57 | optimizer,
58 | rng_key=None,
59 | num_epochs=num_epochs)
60 |
61 | # Full Batch Gradient Descent
62 | batch_size = n_obs_seq
63 | params_fbgd, losses_fbgd = fit(observations,
64 | lens,
65 | num_hidden,
66 | num_obs,
67 | batch_size,
68 | optimizer,
69 | rng_key=None,
70 | num_epochs=num_epochs)
71 |
72 | # Stochastic Gradient Descent
73 | batch_size = 1
74 | params_sgd, losses_sgd = fit(observations,
75 | lens,
76 | num_hidden,
77 | num_obs,
78 | batch_size,
79 | optimizer,
80 | rng_key=None,
81 | num_epochs=num_epochs)
82 |
83 | losses = [losses_sgd, losses_mbgd, losses_fbgd]
84 | titles = ["Stochastic Gradient Descent", "Mini Batch Gradient Descent", "Full Batch Gradient Descent"]
85 |
86 | dict_figures = {}
87 | for loss, title in zip(losses, titles):
88 | filename = title.replace(" ", "_").lower()
89 | fig, ax = plt.subplots()
90 | ax.plot(loss)
91 | ax.set_title(f"{title}")
92 | dict_figures[filename] = fig
93 |     dotfile = hmm_plot_graphviz(params_sgd.trans_mat, params_sgd.obs_mat)
94 | dotfile_dict = {"hmm-casino-dot": dotfile}
95 |
96 | return dict_figures, dotfile_dict
97 |
98 |
99 | if __name__ == "__main__":
100 | from jsl.demos.plot_utils import savefig, savedotfile
101 |
102 | figs, dotfile = main()
103 | savefig(figs)
104 | savedotfile(dotfile)
105 | plt.show()
106 |
--------------------------------------------------------------------------------
/jsl/demos/hmm_lillypad.py:
--------------------------------------------------------------------------------
1 | # Example of an HMM with Gaussian emission in 2D
2 | # For a matlab version, see https://github.com/probml/pmtk3/blob/master/demos/hmmLillypadDemo.m
3 |
4 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna)
5 |
6 |
7 | import logging
8 | import distrax
9 | import numpy as np
10 | import jax.numpy as jnp
11 | import matplotlib.pyplot as plt
12 | import tensorflow_probability as tfp
13 | from distrax import HMM
14 | from jax import vmap
15 | from jax.random import PRNGKey
16 |
17 | logging.getLogger('absl').setLevel(logging.CRITICAL)
18 |
19 |
20 | def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):
21 | """
22 | Plot the trajectory of a 2-dimensional HMM
23 | Parameters
24 | ----------
25 | hmm : HMM
26 | Hidden Markov Model
27 | samples_obs: numpy.ndarray(n_samples, 2)
28 | Observations
29 | samples_state: numpy.ndarray(n_samples, )
30 | Latent state of the system
31 | colors: list(int)
32 | List of colors for each latent state
33 | step: float
34 | Step size
35 | Returns
36 | -------
37 | * matplotlib.axes
38 | * colour of each latent state
39 | """
40 | obs_dist = hmm.obs_dist
41 | color_sample = [colors[i] for i in samples_state]
42 |
43 | xs = jnp.arange(xmin, xmax, step)
44 | ys = jnp.arange(ymin, ymax, step)
45 |
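    # Evaluate the per-state emission density on the (x, y) grid;
    # z has shape (len(xs), len(ys), n_states).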
46 | v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))
47 | z = vmap(v_prob, in_axes=(0, None))(xs, ys)
48 |
49 | grid = np.mgrid[xmin:xmax:step, ymin:ymax:step]
50 |
51 | for k, color in enumerate(colors):
52 | ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3)
53 | ax.text(*(obs_dist.mean()[k] + 0.13), f"$k$={k + 1}", fontsize=13, horizontalalignment="right")
54 |
55 | ax.plot(*samples_obs.T, c="black", alpha=0.3, zorder=1)
56 | ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)
57 |
58 | return ax, color_sample
59 |
60 |
61 | def main():
62 | initial_probs = jnp.array([0.3, 0.2, 0.5])
63 |
64 | # transition matrix
65 | A = jnp.array([
66 | [0.3, 0.4, 0.3],
67 | [0.1, 0.6, 0.3],
68 | [0.2, 0.3, 0.5]
69 | ])
70 |
71 | S1 = jnp.array([
72 | [1.1, 0],
73 | [0, 0.3]
74 | ])
75 |
76 | S2 = jnp.array([
77 | [0.3, -0.5],
78 | [-0.5, 1.3]
79 | ])
80 |
81 | S3 = jnp.array([
82 | [0.8, 0.4],
83 | [0.4, 0.5]
84 | ])
85 |
86 | cov_collection = jnp.array([S1, S2, S3]) / 60
87 | mu_collection = jnp.array([
88 | [0.3, 0.3],
89 | [0.8, 0.5],
90 | [0.3, 0.8]
91 | ])
92 |
93 | hmm = HMM(trans_dist=distrax.Categorical(probs=A),
94 | init_dist=distrax.Categorical(probs=initial_probs),
95 | obs_dist=distrax.as_distribution(
96 | tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(loc=mu_collection,
97 | covariance_matrix=cov_collection)))
98 | n_samples, seed = 50, 10
99 | samples_state, samples_obs = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
100 |
101 | xmin, xmax = 0, 1
102 | ymin, ymax = 0, 1.2
103 | colors = ["tab:green", "tab:blue", "tab:red"]
104 |
105 | dict_figures = {}
106 | fig, ax = plt.subplots()
107 | _, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)
108 | dict_figures["hmm-lillypad-2d"] = fig
109 |
110 | fig, ax = plt.subplots()
111 | ax.step(range(n_samples), samples_state, where="post", c="black", linewidth=1, alpha=0.3)
112 | ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)
113 | dict_figures["hmm-lillypad-step"] = fig
114 |
115 | return dict_figures
116 |
117 |
118 | if __name__ == "__main__":
119 | from jsl.demos.plot_utils import savefig
120 |
121 | plt.rcParams["axes.spines.right"] = False
122 | plt.rcParams["axes.spines.top"] = False
123 | figs = main()
124 | savefig(figs)
125 | plt.show()
126 |
--------------------------------------------------------------------------------
/jsl/demos/kf_continuous_circle.py:
--------------------------------------------------------------------------------
1 | # Example of a Kalman Filter procedure
2 | # on a continuous system with imaginary eigenvalues
3 | # and discrete samples
4 | # For further reference and examples, see:
5 | # * Section on Kalman Filters in PML vol2 book
6 | # * Nonlinear Dynamics and Chaos - Steven Strogatz
7 |
8 | import numpy as np
9 | import jax.numpy as jnp
10 | import matplotlib.pyplot as plt
11 | from jax import random
12 | from jsl.demos.plot_utils import plot_ellipse
13 |
14 | from jsl.lds.kalman_filter import LDS
15 | from jsl.lds.cont_kalman_filter import sample, filter
16 |
17 |
18 | def main():
19 | A = jnp.array([[0, 1], [-1, 0]])
20 | C = jnp.eye(2)
21 |
22 | dt = 0.01
23 | T = 5.5
24 | nsamples = 70
25 | x0 = jnp.array([0.5, -0.75])
26 |
27 | # State noise
28 | Qt = jnp.eye(2) * 0.001
29 | # Observed noise
30 | Rt = jnp.eye(2) * 0.01
31 |
32 | Sigma0 = jnp.eye(2)
33 |
34 | key = random.PRNGKey(314)
35 | lds = LDS(A, C, Qt, Rt, x0, Sigma0)
36 | sample_state, sample_obs, jump = sample(key, lds, x0, T, nsamples)
37 | mu_hist, V_hist, *_ = filter(lds, sample_obs, jump, dt)
38 |
39 | step = 0.1
40 | vmin, vmax = -1.5, 1.5 + step
41 | X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
42 | X_dot = jnp.einsum("ij,jxy->ixy", A, X)
43 |
44 | dict_figures = {}
45 |
46 | fig_state, ax = plt.subplots()
47 | ax.plot(*sample_state.T, label="state space")
48 | ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
49 | ax.scatter(*sample_state[0], c="black", zorder=3)
50 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
51 | ax.legend()
52 | plt.axis("equal")
53 | ax.set_title("State Space")
54 | dict_figures["kf-circle-state"] = fig_state
55 |
56 | fig_filtered, ax = plt.subplots()
57 | ax.plot(*mu_hist.T, c="tab:orange", label="Filtered")
58 | ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
59 | ax.scatter(*mu_hist[0], c="black", zorder=3)
60 | for mut, Vt in zip(mu_hist[::4], V_hist[::4]):
61 | plot_ellipse(Vt, mut, ax, plot_center=False, alpha=0.9, zorder=3)
62 | plt.legend()
63 | field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
64 | ax.legend()
65 | plt.axis("equal")
66 | ax.set_title("Approximate Space")
67 | dict_figures["kf-circle-filtered"] = fig_filtered
68 |
69 | return dict_figures
70 |
71 |
72 | if __name__ == "__main__":
73 | from jsl.demos.plot_utils import savefig
74 |
75 | dict_figures = main()
76 | savefig(dict_figures)
77 | plt.show()
78 |
--------------------------------------------------------------------------------
/jsl/demos/kf_parallel.py:
--------------------------------------------------------------------------------
1 | # Parallel Kalman Filter demo: this script simulates
2 | # 4 missiles as described in the section "state-space models".
3 | # Each of the missiles is then filtered and smoothed in parallel
4 |
5 | import jax.numpy as jnp
6 | from jsl.lds.kalman_filter import LDS, filter, smooth
7 | from jsl.demos.plot_utils import plot_ellipse
8 | import matplotlib.pyplot as plt
9 | from jax import random
10 |
11 |
12 | def sample_filter_smooth(key, lds_model, nsteps, n_samples, noisy_init):
13 | """
14 |     Sample from a linear dynamical system, apply the Kalman filter
15 |     (forward pass), and perform smoothing (backward pass).
16 |
17 | Parameters
18 | ----------
19 |     lds_model: LDS
20 |         Instance of a linear dynamical system with known parameters
21 |
22 | Returns
23 | -------
24 | Dictionary with the following key, values
25 | * (z_hist) array(n_samples, timesteps, state_size):
26 | Simulation of Latent states
27 | * (x_hist) array(n_samples, timesteps, observation_size):
28 | Simulation of observed states
29 | * (mu_hist) array(n_samples, timesteps, state_size):
30 | Filtered means mut
31 | * (Sigma_hist) array(n_samples, timesteps, state_size, state_size)
32 | Filtered covariances Sigmat
33 | * (mu_cond_hist) array(n_samples, timesteps, state_size)
34 | Filtered conditional means mut|t-1
35 | * (Sigma_cond_hist) array(n_samples, timesteps, state_size, state_size)
36 | Filtered conditional covariances Sigmat|t-1
37 | * (mu_hist_smooth) array(n_samples, timesteps, state_size):
38 | Smoothed means mut
39 | * (Sigma_hist_smooth) array(n_samples, timesteps, state_size, state_size)
40 | Smoothed covariances Sigmat
41 | """
42 | z_hist, x_hist = lds_model.sample(key, nsteps, n_samples, noisy_init)
43 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds_model, x_hist)
44 | mu_hist_smooth, Sigma_hist_smooth = smooth(lds_model, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist)
45 |
46 | return {
47 | "z_hist": z_hist,
48 | "x_hist": x_hist,
49 | "mu_hist": mu_hist,
50 | "Sigma_hist": Sigma_hist,
51 | "mu_cond_hist": mu_cond_hist,
52 | "Sigma_cond_hist": Sigma_cond_hist,
53 | "mu_hist_smooth": mu_hist_smooth,
54 | "Sigma_hist_smooth": Sigma_hist_smooth
55 | }
56 |
57 |
58 | def plot_collection(obs, ax, means=None, covs=None, **kwargs):
59 | n_samples, n_steps, _ = obs.shape
60 | for nsim in range(n_samples):
61 | X = obs[nsim]
62 | if means is not None:
63 | mean = means[nsim]
64 | ax.scatter(*mean[0, :2], marker="o", s=20, c="black", zorder=2)
65 | ax.plot(*mean[:, :2].T, marker="o", markersize=2, **kwargs, zorder=1)
66 | if covs is not None:
67 | cov = covs[nsim]
68 | for t in range(1, n_steps, 3):
69 | plot_ellipse(cov[t][:2, :2], mean[t, :2], ax,
70 | plot_center=False, alpha=0.7)
71 | ax.scatter(*X.T, marker="+", s=60)
72 |
73 |
74 | def main():
75 | Δ = 1.0
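    # Constant-velocity dynamics: each position component is advanced by its velocity * Δ per step.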
76 | A = jnp.array([
77 | [1, 0, Δ, 0],
78 | [0, 1, 0, Δ],
79 | [0, 0, 1, 0],
80 | [0, 0, 0, 1]
81 | ])
82 |
83 | C = jnp.array([
84 | [1, 0, 0, 0],
85 | [0, 1, 0, 0]
86 | ]).astype(float)
87 |
88 | state_size, _ = A.shape
89 | observation_size, _ = C.shape
90 |
91 | Q = jnp.eye(state_size) * 0.01
92 | R = jnp.eye(observation_size) * 1.2
93 | # Prior parameter distribution
94 | mu0 = jnp.array([8, 10, 1, 0]).astype(float)
95 | Sigma0 = jnp.eye(state_size) * 0.1
96 |
97 |
98 | key = random.PRNGKey(3141)
99 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0)
100 |
101 | nsteps, n_samples = 15, 4
102 | result = sample_filter_smooth(key, lds_instance, nsteps, n_samples, True)
103 |
104 | dict_figures = {}
105 |
106 | fig_latent, ax = plt.subplots()
107 | plot_collection(result["x_hist"], ax, result["z_hist"], linestyle="--")
108 | ax.set_title("State space")
109 | dict_figures["missiles_latent"] = fig_latent
110 |
111 | fig_filtered, ax = plt.subplots()
112 | plot_collection(result["x_hist"], ax, result["mu_hist"], result["Sigma_hist"])
113 | ax.set_title("Filtered")
114 | dict_figures["missiles_filtered"] = fig_filtered
115 |
116 | fig_smoothed, ax = plt.subplots()
117 | plot_collection(result["x_hist"], ax, result["mu_hist_smooth"], result["Sigma_hist_smooth"])
118 | ax.set_title("Smoothed")
119 | dict_figures["missiles_smoothed"] = fig_smoothed
120 |
121 | return dict_figures
122 |
123 |
124 | if __name__ == "__main__":
125 | from jsl.demos.plot_utils import savefig
126 | figures = main()
127 | savefig(figures)
128 | plt.show()
129 |
--------------------------------------------------------------------------------
/jsl/demos/kf_spiral.py:
--------------------------------------------------------------------------------
1 | # This demo exemplifies the use of the Kalman Filter
2 | # algorithm when the linear dynamical system induced by the
3 | # matrix A has complex eigenvalues, producing a spiralling trajectory
4 |
5 | import jax.numpy as jnp
6 | import matplotlib.pyplot as plt
7 | from jsl.demos.plot_utils import plot_ellipse
8 | from jax import random
9 | from jsl.lds.kalman_filter import LDS, smooth, filter
10 |
11 | def plot_uncertainty_ellipses(means, covs, ax):
12 | timesteps = len(means)
13 | for t in range(timesteps):
14 | plot_ellipse(covs[t], means[t], ax, plot_center=False, alpha=0.7)
15 |
16 | def main():
17 | dx = 1.1
18 | timesteps = 20
19 | key = random.PRNGKey(27182)
20 |
21 | mean_0 = jnp.array([1, 1, 1, 0]).astype(float)
22 | Sigma_0 = jnp.eye(4)
23 | A = jnp.array([
24 | [0.1, 1.1, dx, 0],
25 | [-1, 1, 0, dx],
26 | [0, 0, 0.1, 0],
27 | [0, 0, 0, 0.1]
28 | ])
29 | C = jnp.array([
30 | [1, 0, 0, 0],
31 | [0, 1, 0, 0]
32 | ])
33 | Q = jnp.eye(4) * 0.001
34 | R = jnp.eye(2) * 4
35 |
36 | lds_instance = LDS(A, C, Q, R, mean_0, Sigma_0)
37 | state_hist, obs_hist = lds_instance.sample(key, timesteps)
38 |
39 | res = filter(lds_instance, obs_hist)
40 | mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist = res
41 | mean_hist_smooth, Sigma_hist_smooth = smooth(lds_instance, mean_hist,
42 | Sigma_hist, mean_cond_hist, Sigma_cond_hist)
43 |
44 | dict_figures = {}
45 |
46 | fig_spiral_state, ax = plt.subplots()
47 | ax.plot(*state_hist[:, :2].T, linestyle="--")
48 | ax.scatter(*obs_hist.T, marker="+", s=60)
49 | ax.set_title("State space")
50 | dict_figures["spiral-state"] = fig_spiral_state
51 |
52 |
53 | fig_spiral_filtered, ax = plt.subplots()
54 | ax.plot(*mean_hist[:, :2].T)
55 | ax.scatter(*obs_hist.T, marker="+", s=60)
56 | plot_uncertainty_ellipses(mean_hist[:, :2], Sigma_hist[:, :2, :2], ax)
57 | ax.set_title("Filtered")
58 | dict_figures["spiral-filtered"] = fig_spiral_filtered
59 |
60 | fig_spiral_smoothed, ax = plt.subplots()
61 | ax.plot(*mean_hist_smooth[:, :2].T)
62 | ax.scatter(*obs_hist.T, marker="+", s=60)
63 | plot_uncertainty_ellipses(mean_hist_smooth[:, :2], Sigma_hist_smooth[:, :2, :2], ax)
64 | ax.set_title("Smoothed")
65 | dict_figures["spiral-smoothed"] = fig_spiral_smoothed
66 |
67 | return dict_figures
68 |
69 |
70 | if __name__ == "__main__":
71 | from jsl.demos.plot_utils import savefig
72 | figures = main()
73 | savefig(figures)
74 | plt.show()
75 |
--------------------------------------------------------------------------------
/jsl/demos/kf_tracking.py:
--------------------------------------------------------------------------------
1 | # This script produces an illustration of Kalman filtering and smoothing
2 |
3 | import jax.numpy as jnp
4 | import matplotlib.pyplot as plt
5 | from jax import random
6 | from jsl.demos.plot_utils import plot_ellipse
7 | from jsl.lds.kalman_filter import LDS, smooth, filter
8 |
9 |
10 | def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
11 | """
12 | observed: array(nsteps, 2)
13 | Array of observed values
14 | filtered: array(nsteps, state_size)
15 | Array of latent (hidden) values. We consider only the first
16 | two dimensions of the latent values
17 | cov_hist: array(nsteps, state_size, state_size)
18 | History of the retrieved (filtered) covariance matrices
19 | ax: matplotlib AxesSubplot
20 | """
21 | timesteps, _ = observed.shape
22 | ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
23 | markerfacecolor="none", markeredgewidth=2, markersize=8, label="observed", c="tab:green")
24 | ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
25 | for t in range(0, timesteps, 1):
26 | covn = cov_hist[t][:2, :2]
27 | plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
28 | ax.axis("equal")
29 | ax.legend()
30 |
31 |
32 | def sample_filter_smooth(key, lds_model, timesteps):
33 | """
34 |     Sample from a linear dynamical system, apply the Kalman filter
35 |     (forward pass), and perform smoothing (backward pass).
36 |
37 | Parameters
38 | ----------
39 |     lds_model: LDS
40 |         Instance of a linear dynamical system with known parameters
41 |
42 | Returns
43 | -------
44 | Dictionary with the following key, values
45 | * (z_hist) array(timesteps, state_size):
46 | Simulation of Latent states
47 | * (x_hist) array(timesteps, observation_size):
48 | Simulation of observed states
49 | * (mu_hist) array(timesteps, state_size):
50 | Filtered means mut
51 | * (Sigma_hist) array(timesteps, state_size, state_size)
52 | Filtered covariances Sigmat
53 | * (mu_cond_hist) array(timesteps, state_size)
54 | Filtered conditional means mut|t-1
55 | * (Sigma_cond_hist) array(timesteps, state_size, state_size)
56 | Filtered conditional covariances Sigmat|t-1
57 | * (mu_hist_smooth) array(timesteps, state_size):
58 | Smoothed means mut
59 | * (Sigma_hist_smooth) array(timesteps, state_size, state_size)
60 | Smoothed covariances Sigmat
61 | """
62 | z_hist, x_hist = lds_model.sample(key, timesteps)
63 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds_model, x_hist)
64 | mu_hist_smooth, Sigma_hist_smooth = smooth(lds_model, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist)
65 |
66 | return {
67 | "z_hist": z_hist,
68 | "x_hist": x_hist,
69 | "mu_hist": mu_hist,
70 | "Sigma_hist": Sigma_hist,
71 | "mu_cond_hist": mu_cond_hist,
72 | "Sigma_cond_hist": Sigma_cond_hist,
73 | "mu_hist_smooth": mu_hist_smooth,
74 | "Sigma_hist_smooth": Sigma_hist_smooth
75 | }
76 |
77 |
78 | def main():
79 | key = random.PRNGKey(314)
80 | timesteps = 15
81 | delta = 1.0
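    # Constant-velocity motion model: positions are advanced by velocity * delta,
    # velocities stay (approximately) constant up to the process noise Q.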
82 | A = jnp.array([
83 | [1, 0, delta, 0],
84 | [0, 1, 0, delta],
85 | [0, 0, 1, 0],
86 | [0, 0, 0, 1]
87 | ])
88 |
89 | C = jnp.array([
90 | [1, 0, 0, 0],
91 | [0, 1, 0, 0]
92 | ])
93 |
94 | state_size, _ = A.shape
95 | observation_size, _ = C.shape
96 |
97 | Q = jnp.eye(state_size) * 0.001
98 | R = jnp.eye(observation_size) * 1.0
99 | # Prior parameter distribution
100 | mu0 = jnp.array([8, 10, 1, 0]).astype(float)
101 | Sigma0 = jnp.eye(state_size) * 1.0
102 |
103 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0)
104 | result = sample_filter_smooth(key, lds_instance, timesteps)
105 |
106 | l2_filter = jnp.linalg.norm(result["z_hist"][:, :2] - result["mu_hist"][:, :2], 2)
107 | l2_smooth = jnp.linalg.norm(result["z_hist"][:, :2] - result["mu_hist_smooth"][:, :2], 2)
108 |
109 | print(f"L2-filter: {l2_filter:0.4f}")
110 | print(f"L2-smooth: {l2_smooth:0.4f}")
111 |
112 | dict_figures = {}
113 |
114 | fig_truth, axs = plt.subplots()
115 | axs.plot(result["x_hist"][:, 0], result["x_hist"][:, 1],
116 | marker="o", linewidth=0, markerfacecolor="none",
117 | markeredgewidth=2, markersize=8,
118 | label="observed", c="tab:green")
119 |
120 | axs.plot(result["z_hist"][:, 0], result["z_hist"][:, 1],
121 | linewidth=2, label="truth",
122 | marker="s", markersize=8)
123 | axs.legend()
124 | axs.axis("equal")
125 | dict_figures["kalman-tracking-truth"] = fig_truth
126 |
127 | fig_filtered, axs = plt.subplots()
128 | plot_tracking_values(result["x_hist"], result["mu_hist"], result["Sigma_hist"], "filtered", axs)
129 | dict_figures["kalman-tracking-filtered"] = fig_filtered
130 |
131 | fig_smoothed, axs = plt.subplots()
132 | plot_tracking_values(result["x_hist"], result["mu_hist_smooth"], result["Sigma_hist_smooth"], "smoothed", axs)
133 | dict_figures["kalman-tracking-smoothed"] = fig_smoothed
134 |
135 | return dict_figures
136 |
137 |
138 | if __name__ == "__main__":
139 | from jsl.demos.plot_utils import savefig
140 |
141 | figures = main()
142 | savefig(figures)
143 | plt.show()
144 |
--------------------------------------------------------------------------------
/jsl/demos/lds_sampling_demo.py:
--------------------------------------------------------------------------------
1 | # Compare posterior sampling of LDS latent states: TFP's built-in posterior_sample vs. jsl's smooth_sampler
2 | import jax.numpy as jnp
3 | from jax.random import PRNGKey
4 | import tensorflow_probability as tfp
5 | import tensorflow as tf
6 | tfd = tfp.distributions
7 | import matplotlib.pyplot as plt
8 |
9 | from jsl.lds.kalman_filter import LDS, kalman_filter
10 | from jsl.lds.kalman_sampler import smooth_sampler
11 |
12 |
13 | # Define the 1-d LDS model using the LDS model in the tensorflow_probability package
14 | ndims = 1
15 | step_std = 1.0
16 | noise_std = 5.0
17 | model = tfd.LinearGaussianStateSpaceModel(
18 | num_timesteps=100,
19 | transition_matrix=tf.linalg.LinearOperatorDiag(jnp.array([1.01])),
20 | transition_noise=tfd.MultivariateNormalDiag(
21 | scale_diag=step_std * tf.ones([ndims])),
22 | observation_matrix=tf.linalg.LinearOperatorIdentity(ndims),
23 | observation_noise=tfd.MultivariateNormalDiag(
24 | scale_diag=noise_std * tf.ones([ndims])),
25 | initial_state_prior=tfd.MultivariateNormalDiag(loc=jnp.array([5.0]),
26 | scale_diag=tf.ones([ndims])))
27 |
28 | # Sample from the prior of the LDS
29 | y = model.sample()
30 | # Posterior sampling of the state variable using the built-in method of the LDS model (for the sake of comparison)
31 | smps = model.posterior_sample(y, sample_shape=50)
32 | s_tf = jnp.array(smps[:,:,0])
33 |
34 | # Define the same model as an LDS object defined in the kalman_filter file
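# The parameters below match the TFP model above: Q = step_std**2 = 1 and R = noise_std**2 = 25.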
35 | A = jnp.eye(1) * 1.01
36 | C = jnp.eye(1)
37 | Q = jnp.eye(1)
38 | R = jnp.eye(1) * 25.0
39 | mu0 = jnp.array([5.0])
40 | Sigma0 = jnp.eye(1)
41 | model_lds = LDS(A, C, Q, R, mu0, Sigma0)
42 | # Run the Kalman filter algorithm first
43 | mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = kalman_filter(model_lds, jnp.array(y))
44 | # Sample backwards using the smoothing posterior
45 | smooth_sample = smooth_sampler(model_lds, PRNGKey(0), mu_hist, Sigma_hist, n_samples=50)
46 |
47 |
48 | # Plot the observation and posterior samples of state variables
49 | plt.plot(y, color='red', label='Observation') # Observation
50 | plt.plot(s_tf.T, alpha=0.12, color='blue') # Samples using TF built in function
51 | plt.plot(smooth_sample[:,:,0].T, alpha=0.12, color='green') # Kalman smoother backwards sampler.
52 | plt.legend()
53 | plt.show()
54 |
--------------------------------------------------------------------------------
/jsl/demos/linreg_kf.py:
--------------------------------------------------------------------------------
1 | # Online Bayesian linear regression in 1d using Kalman Filter
2 | # Based on: https://github.com/probml/pmtk3/blob/master/demos/linregOnlineDemoKalman.m
3 |
4 | # The latent state corresponds to the current estimate of the regression weights w.
5 | # The observation model has the form
6 | # p(y(t) | w(t), x(t)) = Gauss( C(t) * w(t), R(t))
7 | # where C(t) = X(t,:) is the observation matrix for step t.
8 | # The dynamics model has the form
9 | # p(w(t) | w(t-1)) = Gauss(A * w(t-1), Q)
10 | # where Q>0 allows for parameter drift.
11 | # We show that the result is equivalent to batch (offline) Bayesian inference.
12 |
13 | import jax.numpy as jnp
14 | import matplotlib.pyplot as plt
15 | from numpy.linalg import inv
16 | from jsl.lds.kalman_filter import LDS, kalman_filter
17 |
18 |
19 | def kf_linreg(X, y, R, mu0, Sigma0, F, Q):
20 | """
21 | Online estimation of a linear regression
22 | using Kalman Filters
23 | Parameters
24 | ----------
25 | X: array(n_obs, dimension)
26 | Matrix of features
27 | y: array(n_obs,)
28 | Array of observations
29 |     R: array(1, 1)
30 |         Known observation-noise variance
31 | mu0: array(dimension)
32 | Prior mean
33 |     Sigma0: array(dimension, dimension)
34 | Prior covariance matrix
35 | Returns
36 | -------
37 | * array(n_obs, dimension)
38 | Online estimation of parameters
39 | * array(n_obs, dimension, dimension)
40 | Online estimation of uncertainty
41 | """
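    # Time-varying observation model: at step t the "matrix" C(t) is the 1 x d row of features X[t],
    # so the Kalman update reduces to an online Bayesian linear-regression update.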
42 | C = lambda t: X[t][None, ...]
43 | lds = LDS(F, C, Q, R, mu0, Sigma0)
44 |
45 | mu_hist, Sigma_hist, _, _ = kalman_filter(lds, y)
46 | return mu_hist, Sigma_hist
47 |
48 |
49 | def posterior_lreg(X, y, R, mu0, Sigma0):
50 | """
51 | Compute mean and covariance matrix of a
52 | Bayesian Linear regression
53 | Parameters
54 | ----------
55 | X: array(n_obs, dimension)
56 | Matrix of features
57 | y: array(n_obs,)
58 | Array of observations
59 | R: float
60 | Known variance
61 | mu0: array(dimension)
62 | Prior mean
63 |     Sigma0: array(dimension, dimension)
64 | Prior covariance matrix
65 | Returns
66 | -------
67 | * array(dimension)
68 | Posterior mean
69 |     * array(dimension, dimension)
70 |         Posterior precision (inverse covariance) matrix
71 | """
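    # Standard conjugate update for Bayesian linear regression with known noise variance R:
    #   Sigma_n^{-1} = Sigma_0^{-1} + X^T X / R,   m_n = Sigma_n (Sigma_0^{-1} mu_0 + X^T y / R)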
72 | Sn_bayes_inv = inv(Sigma0) + X.T @ X / R.item()
73 | b = inv(Sigma0) @ mu0 + X.T @ y / R.item()
74 | mn_bayes = jnp.linalg.solve(Sn_bayes_inv, b)
75 |
76 | return mn_bayes, Sn_bayes_inv
77 |
78 | def main():
79 | n_obs = 21
80 | timesteps = jnp.arange(n_obs)
81 | x = jnp.linspace(0, 20, n_obs)
82 | X = jnp.c_[jnp.ones(n_obs), x]
83 | F = jnp.eye(2)
84 | mu0 = jnp.zeros(2)
85 | Sigma0 = jnp.eye(2) * 10.
86 |
87 | Q, R = 0, 1
88 | Q, R = jnp.asarray([[Q]]), jnp.asarray([[R]])
89 | # Data from original matlab example
90 | y = jnp.array([2.4865, -0.3033, -4.0531, -4.3359, -6.1742, -5.604, -3.5069, -2.3257, -4.6377,
91 | -0.2327, -1.9858, 1.0284, -2.264, -0.4508, 1.1672, 6.6524, 4.1452, 5.2677, 6.3403, 9.6264, 14.7842])
92 |
93 | # Online estimation
94 | mu_hist, Sigma_hist = kf_linreg(X, y, R, mu0, Sigma0, F, Q)
95 | kf_var = Sigma_hist[-1, [0, 1], [0, 1]]
96 | w0_hist, w1_hist = mu_hist.T
97 | w0_err, w1_err = jnp.sqrt(Sigma_hist[:, [0, 1], [0, 1]].T)
98 |
99 | # Offline estimation
100 | (w0_post, w1_post), inv_Sigma_post = posterior_lreg(X, y, R, mu0, Sigma0)
101 | Sigma_post = inv(inv_Sigma_post)
102 | w0_std, w1_std = jnp.sqrt(Sigma_post[[0, 1], [0, 1]])
103 |
104 | dict_figures = {}
105 |
106 | fig, ax = plt.subplots()
107 | ax.errorbar(timesteps, w0_hist, w0_err, fmt="-o", label="$w_0$", color="black", fillstyle="none")
108 | ax.errorbar(timesteps, w1_hist, w1_err, fmt="-o", label="$w_1$", color="tab:red")
109 |
110 | ax.axhline(y=w0_post, c="black", label="$w_0$ batch")
111 | ax.axhline(y=w1_post, c="tab:red", linestyle="--", label="$w_1$ batch")
112 |
113 | ax.fill_between(timesteps, w0_post - w0_std, w0_post + w0_std, color="black", alpha=0.4)
114 | ax.fill_between(timesteps, w1_post - w1_std, w1_post + w1_std, color="tab:red", alpha=0.4)
115 |
116 | plt.legend()
117 | ax.set_xlabel("time")
118 | ax.set_ylabel("weights")
119 | ax.set_ylim(-8, 4)
120 | ax.set_xlim(-0.5, n_obs)
121 | dict_figures["linreg_online_kalman"] = fig
122 | return dict_figures
123 |
124 |
125 | if __name__ == "__main__":
126 | from jsl.demos.plot_utils import savefig
127 | plt.rcParams["axes.spines.right"] = False
128 | plt.rcParams["axes.spines.top"] = False
129 | dict_figures = main()
130 | savefig(dict_figures)
131 | plt.show()
--------------------------------------------------------------------------------
/jsl/demos/logreg_biclusters.py:
--------------------------------------------------------------------------------
1 | # Bayesian logistic regression in 2d for a two-class problem
2 | # We compare MCMC to Laplace
3 |
4 | # Dependencies:
5 | # * !pip install git+https://github.com/blackjax-devs/blackjax.git
6 |
7 | import jax
8 | import numpy as np
9 | import jax.numpy as jnp
10 | import matplotlib.pyplot as plt
11 | from blackjax import rmh
12 | from jax import random
13 | from functools import partial
14 | from jax.scipy.optimize import minimize
15 | from jax.scipy.stats import norm
16 |
17 |
18 | def sigmoid(x): return jnp.exp(x) / (1 + jnp.exp(x))
19 | def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z))
20 |
21 |
22 | def plot_posterior_predictive(ax, X, Xspace, Zspace, title, colors, cmap="viridis"):
23 | ax.contourf(*Xspace, Zspace, cmap=cmap, levels=20)
24 | ax.scatter(*X.T, c=colors, edgecolors="gray", s=80)
25 | ax.set_title(title)
26 | ax.axis("off")
27 | plt.tight_layout()
28 |
29 |
30 | def inference_loop(rng_key, kernel, initial_state, num_samples):
31 | def one_step(state, rng_key):
32 | state, _ = kernel(rng_key, state)
33 | return state, state
34 |
35 | keys = jax.random.split(rng_key, num_samples)
36 | _, states = jax.lax.scan(one_step, initial_state, keys)
37 |
38 | return states
39 |
40 |
41 | def E_base(w, Phi, y, alpha):
42 | # Energy is the log joint
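    # The prior term corresponds to a zero-mean Gaussian prior w ~ N(0, alpha^{-1} I), up to a constant.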
43 | an = Phi @ w
44 | log_an = log_sigmoid(an)
45 | log_likelihood_term = y * log_an + (1 - y) * jnp.log(1 - sigmoid(an))
46 | log_prior_term = -(alpha * w @ w / 2)
47 | return log_prior_term + log_likelihood_term.sum()
48 |
49 |
50 | def mcmc_logistic_posterior_sample(key, Phi, y, alpha=1.0, init_noise=1.0,
51 | n_samples=5_000, burnin=500, sigma_mcmc=0.8):
52 | """
53 | Sample from the posterior distribution of the weights
54 | of a 2d binary logistic regression model p(y=1|x,w) = sigmoid(w'x),
55 | using the random walk Metropolis-Hastings algorithm.
56 | """
57 | _, ndims = Phi.shape
58 | key, key_init = random.split(key)
59 | w0 = random.multivariate_normal(key, jnp.zeros(ndims), jnp.eye(ndims) * init_noise)
60 | energy = partial(E_base, Phi=Phi, y=y, alpha=alpha)
61 | mcmc_kernel = rmh(energy, sigma=jnp.ones(ndims) * sigma_mcmc)
62 | initial_state = mcmc_kernel.init(w0)
63 |
64 | states = inference_loop(key_init, mcmc_kernel.step, initial_state, n_samples)
65 | chains = states.position[burnin:, :]
66 | return chains
67 |
68 |
69 | def laplace_posterior(key, Phi, y, alpha=1.0, init_noise=1.0):
70 | N, M = Phi.shape
71 | w0 = random.multivariate_normal(key, jnp.zeros(M), jnp.eye(M) * init_noise)
72 | E = lambda w: -E_base(w, Phi, y, alpha) / len(y)
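    # MAP estimate of the weights by minimising the (scaled) negative log joint;
    # SN below is the Hessian of that objective at the mode.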
73 | res = minimize(E, w0, method="BFGS")
74 | w_laplace = res.x
75 | SN = jax.hessian(E)(w_laplace)
76 | return w_laplace, SN
77 |
78 |
79 | def make_dataset(seed=135):
80 | np.random.seed(seed)
81 | N = 30
82 | mu1 = np.hstack((np.ones((N, 1)), 5 * np.ones((N, 1))))
83 | mu2 = np.hstack((-5 * np.ones((N, 1)), np.ones((N, 1))))
84 | class1_std = 1
85 | class2_std = 1.1
86 | X_1 = np.add(class1_std * np.random.randn(N, 2), mu1)
87 | X_2 = np.add(2 * class2_std * np.random.randn(N, 2), mu2)
88 | X = np.vstack((X_1, X_2))
89 | y = np.vstack((np.ones((N, 1)), np.zeros((N, 1))))
90 | return X, y.ravel()
91 |
92 |
93 | def main():
94 | key = random.PRNGKey(314)
95 | ## Data generating process
96 | X, y = make_dataset()
97 | n_datapoints = len(y)
98 |
99 | Phi = jnp.c_[jnp.ones(n_datapoints)[:, None], X]
100 | N, M = Phi.shape
101 |
102 | colors = ["black" if el else "white" for el in y]
103 |
104 | # Predictive domain
105 | xmin, ymin = X.min(axis=0) - 0.1
106 | xmax, ymax = X.max(axis=0) + 0.1
107 | step = 0.1
108 | Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
109 | _, nx, ny = Xspace.shape
110 | Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])
111 |
112 |
113 | ## Laplace
114 | alpha = 2.0
115 | w_laplace, SN = laplace_posterior(key, Phi, y, alpha=alpha)
116 |
117 | ### MCMC Approximation
118 | chains = mcmc_logistic_posterior_sample(key, Phi, y, alpha=alpha)
119 | Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains))
120 | Z_mcmc = Z_mcmc.mean(axis=0)
121 |
122 |     ### *** Plotting surface predictive distribution ***
123 | colors = ["black" if el else "white" for el in y]
124 | dict_figures = {}
125 | key = random.PRNGKey(31415)
126 | nsamples = 5000
127 |
128 | # Laplace surface predictive distribution
129 | laplace_samples = random.multivariate_normal(key, w_laplace, SN, (nsamples,))
130 | Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples))
131 | Z_laplace = Z_laplace.mean(axis=0)
132 |
133 | fig_laplace, ax = plt.subplots()
134 | title = "Laplace Predictive distribution"
135 | plot_posterior_predictive(ax, X, Xspace, Z_laplace, title, colors)
136 | dict_figures["logistic_regression_surface_laplace"] = fig_laplace
137 |
138 | # MCMC surface predictive distribution
139 | fig_mcmc, ax = plt.subplots()
140 | title = "MCMC Predictive distribution"
141 | plot_posterior_predictive(ax, X, Xspace, Z_mcmc, title, colors)
142 | dict_figures["logistic_regression_surface_mcmc"] = fig_mcmc
143 |
144 |
145 | # *** Plotting posterior marginals of weights ***
146 | for i in range(M):
147 | fig_weights_marginals, ax = plt.subplots()
148 | mean_laplace, std_laplace = w_laplace[i], jnp.sqrt(SN[i, i])
149 | mean_mcmc, std_mcmc = chains[:, i].mean(), chains[:, i].std()
150 |
151 | x = jnp.linspace(mean_laplace - 4 * std_laplace, mean_laplace + 4 * std_laplace, 500)
152 | ax.plot(x, norm.pdf(x, mean_laplace, std_laplace), label="posterior (Laplace)", linestyle="dotted")
153 | ax.plot(x, norm.pdf(x, mean_mcmc, std_mcmc), label="posterior (MCMC)", linestyle="dashed")
154 | ax.legend()
155 | ax.set_title(f"Posterior marginals of weights ({i})")
156 | dict_figures[f"logistic_regression_weights_marginals_w{i}"] = fig_weights_marginals
157 |
158 |
159 | print("MCMC weights")
160 | w_mcmc = chains.mean(axis=0)
161 | print(w_mcmc, end="\n"*2)
162 |
163 | print("Laplace weights")
164 | print(w_laplace, end="\n"*2)
165 |
166 | dict_data = {
167 | "X": X,
168 | "y": y,
169 | "Xspace": Xspace,
170 | "Phi": Phi,
171 | "Phispace": Phispace,
172 | "w_laplace": w_laplace,
173 | "cov_laplace": SN
174 | }
175 |
176 | return dict_figures, dict_data
177 |
178 |
179 | if __name__ == "__main__":
180 | from jsl.demos.plot_utils import savefig
181 | figs, data = main()
182 | savefig(figs)
183 | plt.show()
184 |
--------------------------------------------------------------------------------
/jsl/demos/old-init.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "sis_vs_smc",
3 | "ekf_vs_ukf",
4 | "ukf_mlp",
5 | "ekf_mlp",
6 | "eekf_logistic_regression",
7 | "ekf_mlp_anim",
8 | "linreg_kf",
9 | "kf_parallel",
10 | "kf_continuous_circle",
11 | "kf_spiral",
12 | "kf_tracking",
13 | "bootstrap_filter",
14 | "pendulum_1d",
15 | "ekf_continuous"
16 | ]
17 |
--------------------------------------------------------------------------------
/jsl/demos/pendulum_1d.py:
--------------------------------------------------------------------------------
1 | # Example of a 1D pendulum problem filtered with the Extended Kalman Filter,
2 | # the Unscented Kalman Filter, and the Particle Filter (bootstrap filter).
3 | # Additionally, we test the filters when each observation has a 50%
4 | # probability of being replaced by a uniform(-2, 2) outlier.
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import matplotlib.pyplot as plt
9 | from jax import random
10 | from jax.ops import index_update
11 |
12 | from jsl.nlds.base import NLDS
13 | import jsl.nlds.extended_kalman_filter as ekf_lib
14 | import jsl.nlds.unscented_kalman_filter as ukf_lib
15 | import jsl.nlds.bootstrap_filter as b_lib
16 |
17 |
18 | def plot_filter_true(ax, time, estimate, obs, ground_truth, label, colors="tab:blue"):
19 | ax.plot(time, estimate, c="black", label=label)
20 | ax.scatter(time, obs, s=10, c="none", edgecolors=colors)
21 | ax.plot(time, ground_truth, c="gray", linewidth=8, alpha=0.5,
22 | label="true angle")
23 | ax.legend()
24 |
25 |
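# Euler-discretised pendulum dynamics: x = (angle, angular velocity),
# with angle' = velocity and velocity' = -g * sin(angle).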
26 | def fz(x, g=0.1, dt=0.1):
27 | x1, x2 = x[0], x[1]
28 | x1_new = x1 + x2 * dt
29 | x2_new = x2 - g * jnp.sin(x1) * dt
30 | return jnp.asarray([x1_new, x2_new])
31 |
32 |
33 | def fx(x, *args):
34 | return jnp.sin(jnp.asarray([x[0]]))
35 |
36 |
37 | def main():
38 | # *** Define initial configuration ***
39 | g = 10
40 | dt = 0.015
41 | qc = 0.06
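    # Process-noise covariance obtained by discretising continuous-time white noise
    # with spectral density qc over a step of length dt.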
42 | Q = jnp.array([
43 | [qc * dt ** 3 / 3, qc * dt ** 2 / 2],
44 | [qc * dt ** 2 / 2, qc * dt]
45 | ])
46 |
47 | fx_vmap = jax.vmap(fx)
48 | fz_vec = jax.vmap(lambda x: fz(x, g=g, dt=dt))
49 |
50 | nsteps = 200
51 | Rt = jnp.eye(1) * 0.02
52 | x0 = jnp.array([1.5, 0.0]).astype(float)
53 | time = jnp.arange(0, nsteps * dt, dt)
54 |
55 | key = random.PRNGKey(3141)
56 | key_samples, key_pf, key_noisy = random.split(key, 3)
57 | model = NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt)
58 | sample_state, sample_obs = model.sample(key, x0, nsteps)
59 |
60 | # *** Pertubed data ***
61 | key_noisy, key_values = random.split(key_noisy)
62 | sample_obs_noise = sample_obs.copy()
63 | samples_map = random.bernoulli(key_noisy, 0.5, (nsteps,))
64 | replacement_values = random.uniform(key_values, (samples_map.sum(),), minval=-2, maxval=2)
65 | sample_obs_noise = index_update(sample_obs_noise.ravel(), samples_map, replacement_values)
66 | colors = ["tab:red" if samp else "tab:blue" for samp in samples_map]
67 |
68 | # *** Perform filtering ****
69 | alpha, beta, kappa = 1, 0, 2
70 | state_size = 2
71 | Vinit = jnp.eye(state_size)
72 | ukf = NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt, alpha, beta, kappa, state_size)
73 | particle_filter = NLDS(fz_vec, fx_vmap, Q, Rt)
74 |
75 | print("Filtering data...")
76 | _, ekf_hist = ekf_lib.filter(model, x0, sample_obs, return_params=["mean", "cov"])
77 | ekf_mean_hist, ekf_Sigma_hist = ekf_hist["mean"], ekf_hist["cov"]
78 | ukf_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, x0, sample_obs)
79 | pf_mean_hist = b_lib.filter(particle_filter, key_pf, x0, sample_obs, nsamples=4_000, Vinit=Vinit)
80 |
81 | print("Filtering outlier data...")
82 | _, ekf_perturbed_hist = ekf_lib.filter(model, x0, sample_obs_noise, return_params=["mean", "cov"])
83 | ekf_perturbed_mean_hist, ekf_Sigma_hist = ekf_perturbed_hist["mean"], ekf_perturbed_hist["cov"]
84 | ukf_perturbed_mean_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, x0, sample_obs_noise)
85 | pf_perturbed_mean_hist = b_lib.filter(particle_filter, key_pf, x0, sample_obs_noise, nsamples=2_000)
86 |
87 | ekf_estimate = fx_vmap(ekf_mean_hist)
88 | ukf_estimate = fx_vmap(ukf_mean_hist)
89 | pf_estimate = fx_vmap(pf_mean_hist)
90 |
91 | ekf_perturbed_estimate = fx_vmap(ekf_perturbed_mean_hist)
92 | ukf_perturbed_estimate = fx_vmap(ukf_perturbed_mean_hist)
93 | pf_perturbed_estimate = fx_vmap(pf_perturbed_mean_hist)
94 | ground_truth = fx_vmap(sample_state)
95 |
96 | dict_figures = {}
97 | # *** Plot results ***
98 | fig, ax = plt.subplots()
99 | plot_filter_true(ax, time, ekf_estimate, sample_obs, ground_truth, "Extended KF")
100 | dict_figures["pendulum_ekf_1d_demo"] = fig
101 |
102 | fig, ax = plt.subplots()
103 | plot_filter_true(ax, time, ukf_estimate, sample_obs, ground_truth, "Unscented KF")
104 | dict_figures["pendulum_ukf_1d_demo"] = fig
105 |
106 | fig, ax = plt.subplots()
107 | plot_filter_true(ax, time, pf_estimate, sample_obs, ground_truth, "Bootstrap PF")
108 | dict_figures["pendulum_pf_1d_demo"] = fig
109 |
110 | fig, ax = plt.subplots()
111 | plot_filter_true(ax, time, pf_perturbed_estimate, sample_obs_noise,
112 | ground_truth, "Bootstrap PF (noisy)", colors=colors)
113 | dict_figures["pendulum_pf_noisy_1d_demo"] = fig
114 |
115 | fig, ax = plt.subplots()
116 | plot_filter_true(ax, time, ekf_perturbed_estimate, sample_obs_noise,
117 | ground_truth, "Extended KF (noisy)", colors=colors)
118 | dict_figures["pendulum_ekf_noisy_1d_demo"] = fig
119 |
120 | fig, ax = plt.subplots()
121 | plot_filter_true(ax, time, ukf_perturbed_estimate, sample_obs_noise,
122 | ground_truth, "Unscented KF (noisy)", colors=colors)
123 | dict_figures["pendulum_ukf_noisy_1d_demo"] = fig
124 |
125 | return dict_figures
126 |
127 |
128 | if __name__ == "__main__":
129 | from jsl.demos.plot_utils import savefig
130 |
131 | plt.rcParams["axes.spines.right"] = False
132 | plt.rcParams["axes.spines.top"] = False
133 | figures = main()
134 | savefig(figures)
135 | plt.show()
136 |
--------------------------------------------------------------------------------
/jsl/demos/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from numpy import linalg
5 | from matplotlib.patches import Ellipse, transforms
6 | from mpl_toolkits.mplot3d import Axes3D
7 |
8 |
9 | # https://matplotlib.org/devdocs/gallery/statistics/confidence_ellipse.html
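# Draws an n_std confidence ellipse for a 2x2 covariance Sigma centred at mu:
# a unit circle is shaped using the Pearson correlation, then scaled by the
# per-axis standard deviations (times n_std) and translated to the mean.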
10 | def plot_ellipse(Sigma, mu, ax, n_std=3.0, facecolor='none', edgecolor='k', plot_center=True, **kwargs):
11 | cov = Sigma
12 | pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
13 |
14 | ell_radius_x = np.sqrt(1 + pearson)
15 | ell_radius_y = np.sqrt(1 - pearson)
16 | ellipse = Ellipse((0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2,
17 | facecolor=facecolor, edgecolor=edgecolor, **kwargs)
18 |
19 | scale_x = np.sqrt(cov[0, 0]) * n_std
20 | mean_x = mu[0]
21 |
22 | scale_y = np.sqrt(cov[1, 1]) * n_std
23 | mean_y = mu[1]
24 |
25 | transf = (transforms.Affine2D()
26 | .rotate_deg(45)
27 | .scale(scale_x, scale_y)
28 | .translate(mean_x, mean_y))
29 |
30 | ellipse.set_transform(transf + ax.transData)
31 |
32 | if plot_center:
33 | ax.plot(mean_x, mean_y, '.')
34 | return ax.add_patch(ellipse)
35 |
36 |
37 | def savedotfile(dotfiles):
38 | if "FIGDIR" in os.environ:
39 | figdir = os.environ["FIGDIR"]
40 | for name, dot in dotfiles.items():
41 | fname_full = os.path.join(figdir, name)
42 | dot.render(fname_full)
43 | print(f"saving dot file to {fname_full}")
44 |
45 |
46 | def savefig(figures, *args, **kwargs):
47 | if "FIGDIR" in os.environ:
48 | figdir = os.environ["FIGDIR"]
49 | for name, figure in figures.items():
50 | fname_full = os.path.join(figdir, name)
51 | print(f"saving image to {fname_full}")
52 | figure.savefig(f"{fname_full}.pdf", *args, **kwargs)
53 | figure.savefig(f"{fname_full}.png", *args, **kwargs)
54 |
55 |
56 | def scale_3d(ax, x_scale, y_scale, z_scale, factor):
57 | scale=np.diag([x_scale, y_scale, z_scale, 1.0])
58 | scale=scale*(1.0/scale.max())
59 | scale[3,3]=factor
60 | def short_proj():
61 | return np.dot(Axes3D.get_proj(ax), scale)
62 | return short_proj
63 |
64 | def style3d(ax, x_scale, y_scale, z_scale, factor=0.62):
65 | plt.gca().patch.set_facecolor('white')
66 | ax.w_xaxis.set_pane_color((0, 0, 0, 0))
67 | ax.w_yaxis.set_pane_color((0, 0, 0, 0))
68 | ax.w_zaxis.set_pane_color((0, 0, 0, 0))
69 | ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor)
70 |
71 |
72 | def kdeg(x, X, h):
73 | """
74 | KDE under a gaussian kernel
75 |
76 | Parameters
77 | ----------
78 | x: array(eval, D)
79 | X: array(obs, D)
80 | h: float
81 |
82 | Returns
83 | -------
84 | array(eval):
85 | KDE around the observed values
86 | """
87 | N, D = X.shape
88 | nden, _ = x.shape
89 |
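    # Squared distances between every evaluation point and every observation,
    # followed by an isotropic Gaussian kernel sum (normalisation as in the 1-d Gaussian kernel).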
90 | Xhat = X.reshape(D, 1, N)
91 | xhat = x.reshape(D, nden, 1)
92 | u = xhat - Xhat
93 | u = linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h ** 2)
94 | px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi))
95 | return px
96 |
--------------------------------------------------------------------------------
/jsl/demos/rbpf_maneuver.py:
--------------------------------------------------------------------------------
1 | # Rao-Blackwellised particle filtering for jump Markov linear systems
2 | # Based on: https://github.com/probml/pmtk3/blob/master/demos/rbpfManeuverDemo.m
3 |
4 | # !pip install matplotlib==3.4.2
5 |
6 | import jax
7 | import numpy as np
8 | import jax.numpy as jnp
9 | import seaborn as sns
10 | import matplotlib.pyplot as plt
11 | import jsl.lds.mixture_kalman_filter as kflib
12 | from jax import random
13 | from functools import partial
14 | from sklearn.preprocessing import OneHotEncoder
15 | from jax.scipy.special import logit
16 | from jsl.demos.plot_utils import kdeg, style3d
17 |
18 |
19 |
20 | def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5):
21 | nsteps = len(mu_hist)
22 | xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max()
23 | xrange = jnp.linspace(xmin, xmax, npoints).reshape(-1, 1)
24 | res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist)
25 | densities = res[..., dim]
26 | for t in range(0, nsteps, skip):
27 | tloc = t * np.ones(npoints)
28 | px = densities[t]
29 | ax.plot(tloc, xrange, px, c="tab:blue", linewidth=1)
30 | ax.set_zlim(0, 1)
31 | style3d(ax, 1.8, 1.2, 0.7, 0.8)
32 | ax.view_init(elevation, azimuth)
33 | ax.set_xlabel(r"$t$", fontsize=13)
34 | ax.set_ylabel(r"$x_{"f"d={dim}"",t}$", fontsize=13)
35 | ax.set_zlabel(r"$p(x_{d, t} \vert y_{1:t})$", fontsize=13)
36 |
37 |
38 | def main():
39 | TT = 0.1
40 | A = jnp.array([[1, TT, 0, 0],
41 | [0, 1, 0, 0],
42 | [0, 0, 1, TT],
43 | [0, 0, 0, 1]])
44 |
45 |
46 | B1 = jnp.array([0, 0, 0, 0])
47 | B2 = jnp.array([-1.225, -0.35, 1.225, 0.35])
48 | B3 = jnp.array([1.225, 0.35, -1.225, -0.35])
49 | B = jnp.stack([B1, B2, B3], axis=0)
50 |
51 | Q = 0.2 * jnp.eye(4)
52 | R = 10 * jnp.diag(jnp.array([2, 1, 2, 1]))
53 | C = jnp.eye(4)
54 |
55 |
56 | transition_matrix = jnp.array([
57 | [0.8, 0.1, 0.1],
58 | [0.1, 0.8, 0.1],
59 | [0.1, 0.1, 0.8]
60 | ])
61 |
62 | params = kflib.RBPFParamsDiscrete(A, B, C, Q, R, transition_matrix)
63 |
64 | nparticles = 1000
65 | nsteps = 100
66 | key = random.PRNGKey(1)
67 | keys = random.split(key, nsteps)
68 |
69 | x0 = (1, random.multivariate_normal(key, jnp.zeros(4), jnp.eye(4)))
70 | draw_state_fixed = partial(kflib.draw_state, params=params)
71 |
72 | # Create target dataset
73 | _, (latent_hist, state_hist, obs_hist) = jax.lax.scan(draw_state_fixed, x0, keys)
74 |
75 | # Perform filtering
76 | key_base = random.PRNGKey(31)
77 | key_mean_init, key_sample, key_state, key_next = random.split(key_base, 4)
78 | p_init = jnp.array([0.0, 1.0, 0.0])
79 |
80 | # Initial filter configuration
81 | mu_0 = 0.01 * random.normal(key_mean_init, (nparticles, 4))
82 | Sigma_0 = jnp.zeros((nparticles, 4,4))
83 | s0 = random.categorical(key_state, logit(p_init), shape=(nparticles,))
84 | weights_0 = jnp.ones(nparticles) / nparticles
85 | init_config = (key_next, mu_0, Sigma_0, weights_0, s0)
86 |
87 | rbpf_optimal_part = partial(kflib.rbpf_optimal, params=params, nparticles=nparticles)
88 | _, (mu_hist, Sigma_hist, weights_hist, s_hist, Ptk) = jax.lax.scan(rbpf_optimal_part, init_config, obs_hist)
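    # Posterior mean of the continuous state: average the per-particle Kalman means
    # weighted by the particle weights at each time step.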
89 | mu_hist_post_mean = jnp.einsum("ts,tsm->tm", weights_hist, mu_hist)
90 |
91 |
92 | dict_figures = {}
93 | # Plot target dataset
94 | color_dict = {0: "tab:green", 1: "tab:red", 2: "tab:blue"}
95 | fig, ax = plt.subplots()
96 | color_states_org = [color_dict[state] for state in np.array(latent_hist)]
97 | ax.scatter(*state_hist[:, [0, 2]].T, c="none", edgecolors=color_states_org, s=10)
98 | ax.scatter(*obs_hist[:, [0, 2]].T, s=5, c="black", alpha=0.6)
99 | ax.set_title("Data")
100 | dict_figures["rbpf-maneuver-data"] = fig
101 |
102 | # Plot filtered dataset
103 | fig, ax = plt.subplots()
104 | rbpf_mse = ((mu_hist_post_mean - state_hist)[:, [0, 2]] ** 2).mean(axis=0).sum()
105 | latent_hist_est = Ptk.mean(axis=1).argmax(axis=1)
106 | color_states_est = [color_dict[state] for state in np.array(latent_hist_est)]
107 | ax.scatter(*mu_hist_post_mean[:, [0, 2]].T, c="none", edgecolors=color_states_est, s=10)
108 | ax.set_title(f"RBPF MSE: {rbpf_mse:.2f}")
109 | dict_figures["rbpf-maneuver-trace"] = fig
110 |
111 | # Plot belief state of discrete system
112 | p_terms = Ptk.mean(axis=1)
113 | rbpf_error_rate = (latent_hist != p_terms.argmax(axis=1)).mean()
114 | fig, ax = plt.subplots(figsize=(2.5, 5))
115 | sns.heatmap(p_terms, cmap="viridis", cbar=False)
116 | plt.title(f"RBPF, error rate: {rbpf_error_rate:0.3}")
117 | dict_figures["rbpf-maneuver-discrete-belief"] = fig
118 |
119 | # Plot ground truth and MAP estimate
120 | ohe = OneHotEncoder(sparse=False)
121 | latent_hmap = ohe.fit_transform(latent_hist[:, None])
122 | latent_hmap_est = ohe.fit_transform(p_terms.argmax(axis=1)[:, None])
123 |
124 | fig, ax = plt.subplots(figsize=(2.5, 5))
125 | sns.heatmap(latent_hmap, cmap="viridis", cbar=False, ax=ax)
126 | ax.set_title("Data")
127 |     dict_figures["rbpf-maneuver-discrete-ground-truth"] = fig
128 |
129 | fig, ax = plt.subplots(figsize=(2.5, 5))
130 | sns.heatmap(latent_hmap_est, cmap="viridis", cbar=False, ax=ax)
131 | ax.set_title(f"MAP (error rate: {rbpf_error_rate:0.4f})")
132 | dict_figures["rbpf-maneuver-discrete-map"] = fig
133 |
134 | # Plot belief for state space
135 | dims = [0, 2]
136 | for dim in dims:
137 | fig = plt.figure()
138 | ax = plt.axes(projection="3d")
139 | plot_3d_belief_state(mu_hist, dim, ax, h=1.1)
140 | # pml.savefig(f"rbpf-maneuver-belief-states-dim{dim}.pdf", pad_inches=0, bbox_inches="tight")
141 |         dict_figures[f"rbpf-maneuver-belief-states-dim{dim}"] = fig
142 |
143 | return dict_figures
144 |
145 |
146 | if __name__ == "__main__":
147 | from jsl.demos.plot_utils import savefig
148 | plt.rcParams["axes.spines.right"] = False
149 | plt.rcParams["axes.spines.top"] = False
150 | dict_figures = main()
151 | savefig(dict_figures, pad_inches=0, bbox_inches="tight")
152 | plt.show()
153 |
--------------------------------------------------------------------------------
/jsl/demos/sis_vs_smc.py:
--------------------------------------------------------------------------------
1 | # This demo compares sequential importance sampling (SIS) to
2 | # sequential Monte Carlo (SMC) in the case of a non-Markovian
3 | # Gaussian sequence model.
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import matplotlib.pyplot as plt
8 |
9 | from jsl.nlds.sequential_monte_carlo import NonMarkovianSequenceModel
10 |
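# Trace a particle's ancestry backwards through the resampling indices,
# recovering the genealogy (path) that ends in `final_state`.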
11 | def find_path(ix_path, final_state):
12 | curr_state = final_state
13 | path = [curr_state]
14 | for i in range(1, 7):
15 | curr_state, _ = ix_path[:, -i, curr_state]
16 | path.append(curr_state)
17 | path = path[::-1]
18 | return path
19 |
20 |
21 | def plot_sis_weights(hist, n_steps, spacing=1.5, max_size=0.3):
22 | """
23 | Plot the evolution of weights in the sequential importance sampling (SIS) algorithm.
24 |
25 | Parameters
26 | ----------
27 |     hist: dict
28 |         Sampling history; hist["weights"] holds the weights at each time step.
29 | n_steps: int
30 | Number of steps to plot.
31 | spacing: float
32 | Spacing between particles.
33 | max_size: float
34 | Maximum size of the particles.
35 | """
36 | fig, ax = plt.subplots(figsize=(8, 6))
37 | ax.set_aspect(1)
38 | weights_subset = hist["weights"][:n_steps]
39 | for col, weights_row in enumerate(weights_subset):
40 | norm_cst = weights_row.sum()
41 | radii = weights_row / norm_cst * max_size
42 | for row, rad in enumerate(radii):
43 | if col != n_steps - 1:
44 | plt.arrow(spacing * (col + 0.25), row, 0.6, 0, width=0.05,
45 | edgecolor="white", facecolor="tab:gray")
46 | circle = plt.Circle((spacing * col, row), rad, color="tab:red")
47 | ax.add_artist(circle)
48 |
49 | plt.xlim(-1, n_steps * spacing)
50 | plt.xlabel("Iteration (t)")
51 | plt.ylabel("Particle index (i)")
52 |
53 | xticks_pos = jnp.arange(0, n_steps * spacing - 1, 2)
54 | xticks_lab = jnp.arange(1, n_steps + 1)
55 | plt.xticks(xticks_pos, xticks_lab)
56 |
57 | return fig, ax
58 |
59 |
60 | def plot_smc_weights(hist, n_steps, spacing=1.5, max_size=0.3):
61 | """
62 | Plot the evolution of weights in the sequential Monte Carlo (SMC) algorithm.
63 |
64 | Parameters
65 | ----------
66 |     hist: dict
67 |         Sampling history; hist["weights"] and hist["indices"] hold the weights and resampling indices at each step.
68 | n_steps: int
69 | Number of steps to plot.
70 | spacing: float
71 | Spacing between particles.
72 | max_size: float
73 | Maximum size of the particles.
74 |
75 | Returns
76 | -------
77 | fig: matplotlib.figure.Figure
78 | Figure containing the plot.
79 | """
80 | fig, ax = plt.subplots(figsize=(8, 6))
81 | ax.set_aspect(1)
82 |
83 | weights_subset = hist["weights"][:n_steps]
84 | # sampled indices represent the "position" of weights at the next time step
85 | ix_subset = hist["indices"][:n_steps][1:]
86 |
87 | for it, (weights_row, p_target) in enumerate(zip(weights_subset, ix_subset)):
88 | norm_cst = weights_row.sum()
89 | radii = weights_row / norm_cst * max_size
90 |
91 | for particle_ix, (rad, target_ix) in enumerate(zip(radii, p_target)):
92 | if it != n_steps - 2:
93 | diff = particle_ix - target_ix
94 | plt.arrow(spacing * (it + 0.15), target_ix, 1.3, diff, width=0.05,
95 | edgecolor="white", facecolor="tab:gray", length_includes_head=True)
96 | circle = plt.Circle((spacing * it, particle_ix), rad, color="tab:blue")
97 | ax.add_artist(circle)
98 |
99 | plt.xlim(-1, n_steps * spacing - 2)
100 | plt.xlabel("Iteration (t)")
101 | plt.ylabel("Particle index (i)")
102 |
103 | xticks_pos = jnp.arange(0, n_steps * spacing - 2, 2)
104 | xticks_lab = jnp.arange(1, n_steps)
105 | plt.xticks(xticks_pos, xticks_lab)
106 |
107 |     # ylims = ax.axes.get_ylim()  # to-do: grab this value for the SMC particle-descendants plot
108 |
109 | return fig, ax
110 |
111 |
112 | def plot_smc_weights_unique(hist, n_steps, spacing=1.5, max_size=0.3):
113 | """
114 | We plot the evolution of particles that have been consistently resampled and form our final
115 | approximation of the target distribution using sequential Monte Carlo (SMC).
116 |
117 | Parameters
118 | ----------
119 |     hist: dict
120 |         Sampling history; hist["weights"] holds the weights and hist["indices"] the resampled ancestor indices at each time step.
121 | n_steps: int
122 | Number of steps to plot.
123 | spacing: float
124 | Spacing between particles.
125 | max_size: float
126 | Maximum size of the particles.
127 |
128 |
129 | Returns
130 | -------
131 | fig: matplotlib.figure.Figure
132 | Figure containing the plot.
133 | """
134 | weights_subset = hist["weights"][:n_steps]
135 | # sampled indices represent the "position" of weights at the next time step
136 | ix_subset = hist["indices"][:n_steps][1:]
137 | ix_path = ix_subset[:n_steps - 2]
138 | ix_map = jnp.repeat(jnp.arange(5)[None, :], 6, axis=0)
139 | ix_path = jnp.stack([ix_path, ix_map], axis=0)
140 |
141 | fig, ax = plt.subplots(figsize=(8, 6))
142 | ax.set_aspect(1)
143 |
144 | for final_state in range(5):
145 | path = find_path(ix_path, final_state)
146 | path_beg, path_end = path[:-1], path[1:]
147 | for it, (beg, end) in enumerate(zip(path_beg, path_end)):
148 | diff = end - beg
149 | plt.arrow(spacing * (it + 0.15), beg, 1.3, diff, width=0.05,
150 | edgecolor="white", facecolor="tab:gray", alpha=1.0, length_includes_head=True)
151 |
152 | for it, weights_row in enumerate(weights_subset[:-1]):
153 | norm_cst = weights_row.sum()
154 | radii = weights_row / norm_cst * max_size
155 |
156 | for particle_ix, rad in enumerate(radii):
157 | circle = plt.Circle((spacing * it, particle_ix), rad, color="tab:blue")
158 | ax.add_artist(circle)
159 |
160 | plt.xlim(-1, n_steps * spacing - 2)
161 | plt.xlabel("Iteration (t)")
162 | plt.ylabel("Particle index (i)")
163 |
164 | xticks_pos = jnp.arange(0, n_steps * spacing - 2, 2)
165 | xticks_lab = jnp.arange(1, n_steps)
166 |
167 | plt.xticks(xticks_pos, xticks_lab)
168 |
169 | return fig, ax
170 |
171 |
172 | def main():
173 | params = {
174 | "phi": 0.9,
175 | "q": 1.0,
176 | "beta": 0.5,
177 | "r": 1.0,
178 | }
179 |
180 | key = jax.random.PRNGKey(314)
181 |     key_sample, key_sis, key_smc = jax.random.split(key, 3)
182 | seq_model = NonMarkovianSequenceModel(**params)
183 | hist_target = seq_model.sample(key_sample, 100)
184 | observations = hist_target["y"]
185 |
186 | res_sis = seq_model.sequential_importance_sample(key_sis, observations, n_particles=5)
187 |     res_smc = seq_model.sequential_monte_carlo(key_smc, observations, n_particles=5)
188 |
189 | # Plot SMC particle evolution
190 | n_steps = 6 + 2
191 | spacing = 2
192 |
193 | dict_figures = {}
194 |
195 | fig, ax = plot_sis_weights(res_sis, n_steps=7, spacing=spacing)
196 | plt.tight_layout()
197 | dict_figures["sis_weights"] = fig
198 |
199 | fig, ax = plot_smc_weights(res_smc, n_steps=n_steps, spacing=spacing)
200 | ylims = ax.axes.get_ylim()
201 | plt.tight_layout()
202 | dict_figures["smc_weights"] = fig
203 |
204 | fig, ax = plot_smc_weights_unique(res_smc, n_steps=n_steps, spacing=spacing)
205 | ax.set_ylim(*ylims)
206 | plt.tight_layout()
207 | dict_figures["smc_weights_unique"] = fig
208 |
209 | return dict_figures
210 |
211 |
212 | if __name__ == "__main__":
213 | from jsl.demos.plot_utils import savefig
214 | figures = main()
215 | savefig(figures)
216 | plt.show()
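
The contrast the figures above visualise can be reduced to a few lines: SIS keeps multiplying each particle's weight by its likelihood and never resets, so the mass collapses onto a single particle, while SMC resamples ancestor indices after every step and restarts from fresh weights. The sketch below illustrates this on a toy one-dimensional Gaussian random-walk model; it is only an illustration, it does not call NonMarkovianSequenceModel or any other JSL API, and all names in it (toy_sis_vs_smc, q, r) are invented for the example.

import jax
import jax.numpy as jnp

def toy_sis_vs_smc(key, observations, n_particles=5, q=1.0, r=1.0):
    # Toy latent model: x_t = x_{t-1} + N(0, q), y_t = x_t + N(0, r)
    particles_sis = jnp.zeros(n_particles)
    particles_smc = jnp.zeros(n_particles)
    w_sis = jnp.ones(n_particles) / n_particles
    sis_weights, smc_weights, smc_ancestors = [], [], []

    for t, y in enumerate(observations):
        key, k_prop, k_res = jax.random.split(key, 3)
        noise = jnp.sqrt(q) * jax.random.normal(k_prop, (n_particles,))

        # SIS: weights accumulate multiplicatively and are never reset
        particles_sis = particles_sis + noise
        w_sis = w_sis * jnp.exp(-0.5 * (y - particles_sis) ** 2 / r)
        w_sis = w_sis / w_sis.sum()
        sis_weights.append(w_sis)

        # SMC: compute fresh weights, then resample ancestor indices
        particles_smc = particles_smc + noise
        w_smc = jnp.exp(-0.5 * (y - particles_smc) ** 2 / r)
        w_smc = w_smc / w_smc.sum()
        idx = jax.random.choice(k_res, n_particles, (n_particles,), p=w_smc)
        particles_smc = particles_smc[idx]
        smc_weights.append(w_smc)
        smc_ancestors.append(idx)

    return sis_weights, smc_weights, smc_ancestors

After a handful of steps the SIS weights typically degenerate onto a single particle, which is exactly the behaviour the sis_weights figure produced above shows.
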
--------------------------------------------------------------------------------
/jsl/demos/superimport_test.py:
--------------------------------------------------------------------------------
1 | import superimport #https://github.com/probml/superimport
2 |
3 | import arviz as az
4 |
5 | from itertools import chain
6 | import jax
7 | import jax.numpy as jnp
8 | import matplotlib.pyplot as plt
9 | import blackjax.rmh as rmh
10 | from jax import random
11 | from functools import partial
12 | from jax.scipy.optimize import minimize
13 | from sklearn.datasets import make_biclusters
14 | from ..nlds.extended_kalman_filter import ExtendedKalmanFilter
15 | from jax.scipy.stats import norm
16 |
17 | print('hello world')
--------------------------------------------------------------------------------
/jsl/demos/ukf_mlp.py:
--------------------------------------------------------------------------------
1 | # Demo showcasing the training of an MLP with a single hidden layer using
2 | # Unscented Kalman Filtering (UKF).
3 | # In this demo, we consider the latent state to be the weights of an MLP.
4 | # The observed state at time t is the output of the MLP as influenced by the weights
5 | # at time t-1 and the covariate x[t].
6 | # The transition function between latent states is the identity function.
7 | # For more information, see
8 | # * UKF-based training algorithm for feed-forward neural networks with
9 | # application to XOR classification problem
10 | # https://ieeexplore.ieee.org/document/6234549
11 |
12 | import jax.numpy as jnp
13 | from jax import vmap
14 | from jax.random import PRNGKey, split, normal
15 | from jax.flatten_util import ravel_pytree
16 |
17 | import matplotlib.pyplot as plt
18 | from functools import partial
19 |
20 | import jsl.nlds.unscented_kalman_filter as ukf_lib
21 | from jsl.nlds.base import NLDS
22 | from jsl.demos.ekf_mlp import MLP, sample_observations, apply
23 | from jsl.demos.ekf_mlp import plot_mlp_prediction, plot_intermediate_steps, plot_intermediate_steps_single
24 |
25 |
26 | def f(x):
27 | return x - 10 * jnp.cos(x) * jnp.sin(x) + x ** 3
28 |
29 |
30 | def fz(W):
31 | return W
32 |
33 |
34 | def main():
35 | key = PRNGKey(314)
36 | key_sample_obs, key_weights, key_init = split(key, 3)
37 |
38 | all_figures = {}
39 |
40 | # *** MLP configuration ***
41 | n_hidden = 6
42 | n_out = 1
43 | n_in = 1
44 | model = MLP([n_hidden, n_out])
45 |
46 | batch_size = 20
47 | batch = jnp.ones((batch_size, n_in))
48 |
49 | variables = model.init(key_init, batch)
50 | W0, unflatten_fn = ravel_pytree(variables)
51 |
52 | fwd_mlp = partial(apply, model=model, unflatten_fn=unflatten_fn)
53 | # vectorised for multiple observations
54 | fwd_mlp_obs = vmap(fwd_mlp, in_axes=[None, 0])
55 | # vectorised for multiple weights
56 | fwd_mlp_weights = vmap(fwd_mlp, in_axes=[1, None])
57 | # vectorised for multiple observations and weights
58 | fwd_mlp_obs_weights = vmap(fwd_mlp_obs, in_axes=[0, None])
59 |
60 | # *** Generating training and test data ***
61 | n_obs = 200
62 | xmin, xmax = -3, 3
63 | sigma_y = 3.0
64 | x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
65 | xtest = jnp.linspace(x.min(), x.max(), n_obs)
66 |
67 | # *** MLP Training with UKF ***
68 | n_params = W0.size
69 | W0 = normal(key_weights, (n_params,)) * 1 # initial random guess
70 |     Q = jnp.eye(n_params) * 1e-4  # parameters assumed to drift only slightly
71 | R = jnp.eye(1) * sigma_y ** 2 # observation noise is fixed
72 |
73 | Vinit = jnp.eye(n_params) * 5 # vague prior
74 | alpha, beta, kappa = 0.01, 2.0, 3.0 - n_params
75 | ukf = NLDS(fz, lambda w, x: fwd_mlp_weights(w, x).T, Q, R, alpha, beta, kappa, n_params)
76 | ukf_mu_hist, ukf_Sigma_hist = ukf_lib.filter(ukf, W0, y, x[:, None], Vinit)
77 | step = -1
78 | W_ukf, SW_ukf = ukf_mu_hist[step], ukf_Sigma_hist[step]
79 |
80 | fig, ax = plt.subplots()
81 | plot_mlp_prediction(key, x, y, xtest, fwd_mlp_obs_weights, W_ukf, SW_ukf, ax)
82 | ax.set_title("UKF + MLP")
83 | all_figures["ukf-mlp"] = fig
84 |
85 | fig, ax = plt.subplots(2, 2)
86 | intermediate_steps = [10, 20, 30, 40, 50, 60]
87 | plot_intermediate_steps(key, ax, fwd_mlp_obs_weights, intermediate_steps, xtest, ukf_mu_hist, ukf_Sigma_hist, x,
88 | y)
89 | plt.suptitle("UKF + MLP training")
90 | all_figures["ukf-mlp-intermediate"] = fig
91 | figures_intermediate = plot_intermediate_steps_single(key, "ukf", fwd_mlp_obs_weights,
92 | intermediate_steps, xtest, ukf_mu_hist, ukf_Sigma_hist, x,
93 | y)
94 | all_figures = {**all_figures, **figures_intermediate}
95 |
96 | return all_figures
97 |
98 |
99 | if __name__ == "__main__":
100 | from jsl.demos.plot_utils import savefig
101 |
102 | plt.rcParams["axes.spines.right"] = False
103 | plt.rcParams["axes.spines.top"] = False
104 | figures = main()
105 | savefig(figures)
106 | plt.show()
107 |
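
The alpha, beta and kappa arguments passed to NLDS above control how far the unscented transform spreads its sigma points around the current mean (kappa = 3 - n_params follows the common n + kappa = 3 heuristic). As a rough reference only, and not a quote of the internals of jsl.nlds.unscented_kalman_filter, a standard scaled sigma-point construction looks like this:

import jax.numpy as jnp

def scaled_sigma_points(mu, Sigma, alpha=0.01, beta=2.0, kappa=0.0):
    # Standard scaled unscented transform: 2n + 1 sigma points with mean/covariance weights.
    n = mu.size
    lam = alpha ** 2 * (n + kappa) - n
    L = jnp.linalg.cholesky((n + lam) * Sigma)        # columns are scaled square-root directions
    points = jnp.vstack([mu[None, :], mu[None, :] + L.T, mu[None, :] - L.T])
    wm = jnp.full(2 * n + 1, 1.0 / (2 * (n + lam))).at[0].set(lam / (n + lam))
    wc = wm.at[0].set(lam / (n + lam) + 1.0 - alpha ** 2 + beta)
    return points, wm, wc
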
--------------------------------------------------------------------------------
/jsl/hmm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/hmm/__init__.py
--------------------------------------------------------------------------------
/jsl/hmm/hmm_casino_test.py:
--------------------------------------------------------------------------------
1 | '''
2 | Simple sanity check for all four Hidden Markov Models' implementations.
3 | '''
4 | import jax.numpy as jnp
5 | from jax.random import PRNGKey
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | import distrax
11 | from distrax import HMM
12 |
13 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_viterbi_numpy
14 |
15 | from jsl.hmm.hmm_lib import HMMJax, hmm_viterbi_jax, hmm_forwards_backwards_jax
16 |
17 | import jsl.hmm.hmm_logspace_lib as hmm_logspace_lib
18 |
19 |
20 | def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
21 | """
22 | Plot the estimated smoothing/filtering/map of a sequence of hidden states.
23 | "Vertical gray bars denote times when the hidden
24 | state corresponded to state 1. Blue lines represent the
25 | posterior probability of being in that state given different subsets
26 | of observed data." See Markov and Hidden Markov models section for more info
27 | Parameters
28 | ----------
29 | inference_values: array(n_samples, state_size)
30 |         Result of running the smoothing method
31 | z_hist: array(n_samples)
32 | Latent simulation
33 | ax: matplotlib.axes
34 | state: int
35 | Decide which state to highlight
36 | map_estimate: bool
37 | Whether to plot steps (simple plot if False)
38 | """
39 | n_samples = len(inference_values)
40 | xspan = np.arange(1, n_samples + 1)
41 | spans = find_dishonest_intervals(z_hist)
42 | if map_estimate:
43 | ax.step(xspan, inference_values, where="post")
44 | else:
45 | ax.plot(xspan, inference_values[:, state])
46 |
47 | for span in spans:
48 | ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
49 | ax.set_xlim(1, n_samples)
50 | # ax.set_ylim(0, 1)
51 | ax.set_ylim(-0.1, 1.1)
52 | ax.set_xlabel("Observation number")
53 |
54 |
55 | def find_dishonest_intervals(z_hist):
56 | """
57 |     Find the spans of timesteps during which the
58 |     simulated system is in state 1
59 | Parameters
60 | ----------
61 | z_hist: array(n_samples)
62 | Result of running the system with two
63 | latent states
64 | Returns
65 | -------
66 | list of tuples with span of values
67 | """
68 | spans = []
69 | x_init = 0
70 | for t, _ in enumerate(z_hist[:-1]):
71 | if z_hist[t + 1] == 0 and z_hist[t] == 1:
72 | x_end = t
73 | spans.append((x_init, x_end))
74 | elif z_hist[t + 1] == 1 and z_hist[t] == 0:
75 | x_init = t + 1
76 | return spans
77 |
78 |
79 | # state transition matrix
80 | A = jnp.array([
81 | [0.95, 0.05],
82 | [0.10, 0.90]
83 | ])
84 |
85 | # observation matrix
86 | B = jnp.array([
87 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], # fair die
88 | [1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 5 / 10] # loaded die
89 | ])
90 |
91 | n_samples = 300
92 | init_state_dist = jnp.array([1, 1]) / 2
93 | hmm_numpy = HMMNumpy(np.array(A), np.array(B), np.array(init_state_dist))
94 | hmm_jax = HMMJax(A, B, init_state_dist)
95 | hmm = HMM(trans_dist=distrax.Categorical(probs=A),
96 | init_dist=distrax.Categorical(probs=init_state_dist),
97 | obs_dist=distrax.Categorical(probs=B))
98 | hmm_log = hmm_logspace_lib.HMM(trans_dist=distrax.Categorical(probs=A),
99 | init_dist=distrax.Categorical(probs=init_state_dist),
100 | obs_dist=distrax.Categorical(probs=B))
101 |
102 | seed = 314
103 | z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
104 |
105 | z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
106 | x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]
107 |
108 | print("Printing sample observed/latent...")
109 | print(f"x: {x_hist_str}")
110 | print(f"z: {z_hist_str}")
111 |
112 | # Do inference
113 | alpha_numpy, _, gamma_numpy, loglik_numpy = hmm_forwards_backwards_numpy(hmm_numpy,
114 | np.array(x_hist),
115 | len(x_hist))
116 | alpha_jax, _, gamma_jax, loglik_jax = hmm_forwards_backwards_jax(hmm_jax,
117 | x_hist,
118 | len(x_hist))
119 |
120 | alpha_log, _, gamma_log, loglik_log = hmm_logspace_lib.hmm_forwards_backwards_log(hmm_log,
121 | x_hist,
122 | len(x_hist))
123 | alpha, beta, gamma, loglik = hmm.forward_backward(x_hist)
124 |
125 | assert np.allclose(alpha_numpy, alpha)
126 | assert np.allclose(alpha_jax, alpha)
127 | assert np.allclose(jnp.exp(alpha_log), alpha)
128 |
129 | assert np.allclose(gamma_numpy, gamma)
130 | assert np.allclose(gamma_jax, gamma)
131 | assert np.allclose(jnp.exp(gamma_log), gamma)
132 |
133 | print(f"Loglikelihood(Distrax): {loglik}")
134 | print(f"Loglikelihood(Numpy): {loglik_numpy}")
135 | print(f"Loglikelihood(Jax): {loglik_jax}")
136 | print(f"Loglikelihood(Jax log-space): {loglik_log}")
137 |
138 | z_map_numpy = hmm_viterbi_numpy(hmm_numpy, np.array(x_hist))
139 | z_map_jax = hmm_viterbi_jax(hmm_jax, x_hist)
140 | z_map_log = hmm_logspace_lib.hmm_viterbi_log(hmm_log, x_hist)
141 | z_map = hmm.viterbi(x_hist)
142 |
143 | assert np.allclose(z_map_numpy, z_map)
144 | assert np.allclose(z_map_jax, z_map)
145 | assert np.allclose(z_map_log, z_map)
146 |
147 | # Plot results
148 | fig, ax = plt.subplots()
149 | plot_inference(gamma_numpy, z_hist, ax)
150 | ax.set_ylabel("p(loaded)")
151 | ax.set_title("Smoothed")
152 | plt.savefig("hmm_casino_smooth_numpy.png")
153 | plt.show()
154 |
155 | fig, ax = plt.subplots()
156 | plot_inference(z_map_numpy, z_hist, ax, map_estimate=True)
157 | ax.set_ylabel("MAP state")
158 | ax.set_title("Viterbi")
159 | plt.savefig("hmm_casino_map_numpy.png")
160 | plt.show()
161 |
162 | # Plot results
163 | fig, ax = plt.subplots()
164 | plot_inference(gamma, z_hist, ax)
165 | ax.set_ylabel("p(loaded)")
166 | ax.set_title("Smoothed")
167 | # plt.savefig("hmm_casino_smooth_distrax.png")
168 | plt.show()
169 |
170 | fig, ax = plt.subplots()
171 | plot_inference(z_map, z_hist, ax, map_estimate=True)
172 | ax.set_ylabel("MAP state")
173 | ax.set_title("Viterbi")
174 | # plt.savefig("hmm_casino_map_distrax.png")
175 | plt.show()
176 |
177 | # Plot results
178 | fig, ax = plt.subplots()
179 | plot_inference(gamma_jax, z_hist, ax)
180 | ax.set_ylabel("p(loaded)")
181 | ax.set_title("Smoothed")
182 | # plt.savefig("hmm_casino_smooth_jax.png")
183 | plt.show()
184 |
185 | fig, ax = plt.subplots()
186 | plot_inference(z_map_jax, z_hist, ax, map_estimate=True)
187 | ax.set_ylabel("MAP state")
188 | ax.set_title("Viterbi")
189 | # plt.savefig("hmm_casino_map_jax.png")
190 | plt.show()
191 |
192 | # Plot results
193 | fig, ax = plt.subplots()
194 | plot_inference(jnp.exp(gamma_log), z_hist, ax)
195 | ax.set_ylabel("p(loaded)")
196 | ax.set_title("Smoothed")
197 | # plt.savefig("hmm_casino_smooth_log.png")
198 | plt.show()
199 |
200 | fig, ax = plt.subplots()
201 | plot_inference(z_map_log, z_hist, ax, map_estimate=True)
202 | ax.set_ylabel("MAP state")
203 | ax.set_title("Viterbi")
204 | # plt.savefig("hmm_casino_map_log.png")
205 | plt.show()
206 |
--------------------------------------------------------------------------------
/jsl/hmm/hmm_lib_test.py:
--------------------------------------------------------------------------------
1 | '''
2 | This demo compares the JAX, NumPy and distrax versions of the forwards-backwards algorithm in terms of speed.
3 | It also checks whether they give the same results.
4 | Author : Aleyna Kara (@karalleyna)
5 | '''
6 |
7 | import jax.numpy as jnp
8 | from jax import vmap, nn
9 | from jax.random import split, PRNGKey, uniform, normal
10 |
11 | import distrax
12 | from distrax import HMM
13 |
14 | import chex
15 |
16 | import numpy as np
17 | import time
18 |
19 | from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_loglikelihood_numpy
20 | from jsl.hmm.hmm_lib import HMMJax, hmm_viterbi_jax
21 | from jsl.hmm.hmm_lib import hmm_sample_jax, hmm_forwards_backwards_jax, hmm_loglikelihood_jax
22 | from jsl.hmm.hmm_lib import normalize, fixed_lag_smoother
23 | import jsl.hmm.hmm_utils as hmm_utils
24 |
25 |
26 | from tensorflow_probability.substrates import jax as tfp
27 |
28 | tfd = tfp.distributions
29 |
30 | #######
31 | # Test log likelihood
32 |
33 | def loglikelihood_numpy(params_numpy, batches, lens):
34 | return np.array([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)])
35 |
36 | def loglikelihood_jax(params_jax, batches, lens):
37 | return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)[:,:, 0]
38 |
39 |
40 | def test_all_hmm_models():
41 | # state transition matrix
42 | A = jnp.array([
43 | [0.95, 0.05],
44 | [0.10, 0.90]
45 | ])
46 |
47 | # observation matrix
48 | B = jnp.array([
49 | [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
50 | [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
51 | ])
52 |
53 | pi = jnp.array([1, 1]) / 2
54 |
55 | params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi))
56 | params_jax = HMMJax(A, B, pi)
57 |
58 | seed = 0
59 | rng_key = PRNGKey(seed)
60 | rng_key, rng_sample = split(rng_key)
61 |
62 | n_obs_seq, batch_size, max_len = 15, 5, 10
63 |
64 | observations, lens = hmm_utils.hmm_sample_n(params_jax,
65 | hmm_sample_jax,
66 | n_obs_seq, max_len,
67 | rng_sample)
68 |
69 | observations, lens = hmm_utils.pad_sequences(observations, lens)
70 |
71 | rng_key, rng_batch = split(rng_key)
72 | batches, lens = hmm_utils.hmm_sample_minibatches(observations,
73 | lens,
74 | batch_size,
75 | rng_batch)
76 |
77 | ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens))
78 | ll_jax = loglikelihood_jax(params_jax, batches, lens)
79 | assert np.allclose(ll_numpy, ll_jax, atol=4)
80 |
81 |
82 | def test_inference():
83 | seed = 0
84 | rng_key = PRNGKey(seed)
85 | rng_key, key_A, key_B = split(rng_key, 3)
86 |
87 | # state transition matrix
88 | n_hidden, n_obs = 100, 10
89 | A = uniform(key_A, (n_hidden, n_hidden))
90 |     A = A / jnp.sum(A, axis=1).reshape((-1, 1))
91 |
92 | # observation matrix
93 | B = uniform(key_B, (n_hidden, n_obs))
94 | B = B / jnp.sum(B, axis=1).reshape((-1, 1))
95 |
96 | n_samples = 1000
97 | init_state_dist = jnp.ones(n_hidden) / n_hidden
98 |
99 | seed = 0
100 | rng_key = PRNGKey(seed)
101 |
102 | params_numpy = HMMNumpy(A, B, init_state_dist)
103 | params_jax = HMMJax(A, B, init_state_dist)
104 | hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
105 | obs_dist=distrax.Categorical(probs=B),
106 | init_dist=distrax.Categorical(probs=init_state_dist))
107 |
108 | z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)
109 |
110 | start = time.time()
111 | alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist))
112 | print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s')
113 |
114 | start = time.time()
115 | alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), len(x_hist))
116 | print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s')
117 |
118 | start = time.time()
119 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist),
120 | length=len(x_hist))
121 |
122 | print(f'Time taken by HMM distrax : {time.time()-start}s')
123 |
124 | assert np.allclose(alphas_np, alphas_jax)
125 | assert np.allclose(loglikelihood_np, loglikelihood_jax)
126 | assert np.allclose(gammas_np, gammas_jax)
127 |
128 | assert np.allclose(alphas, alphas_jax, atol=8)
129 | assert np.allclose(loglikelihood, loglikelihood_jax, atol=8)
130 | assert np.allclose(gammas, gammas_jax, atol=8)
131 |
132 |
133 | def _make_models(init_probs, trans_probs, obs_probs, length):
134 | """Build distrax HMM and equivalent TFP HMM."""
135 |
136 | dx_model = HMMJax(
137 | trans_probs,
138 | obs_probs,
139 | init_probs
140 | )
141 |
142 | tfp_model = tfd.HiddenMarkovModel(
143 | initial_distribution=tfd.Categorical(probs=init_probs),
144 | transition_distribution=tfd.Categorical(probs=trans_probs),
145 | observation_distribution=tfd.Categorical(probs=obs_probs),
146 | num_steps=length,
147 | )
148 |
149 | return dx_model, tfp_model
150 |
151 |
152 | def test_sample(length, num_states):
153 | params_fn = obs_dist_name_and_params_fn
154 |
155 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
156 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)
157 |
158 | model, tfp_model = _make_models(init_probs,
159 | trans_mat,
160 | params_fn(num_states),
161 | length)
162 |
163 | states, obs = hmm_sample_jax(model, length, PRNGKey(0))
164 | tfp_obs = tfp_model.sample(seed=PRNGKey(0))
165 |
166 | chex.assert_shape(states, (length,))
167 | chex.assert_equal_shape([obs, tfp_obs])
168 |
169 |
170 | def test_forward_backward(length, num_states):
171 | params_fn = obs_dist_name_and_params_fn
172 |
173 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
174 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)
175 |
176 | model, tfp_model = _make_models(init_probs,
177 | trans_mat,
178 | params_fn(num_states),
179 | length)
180 |
181 | _, observations = hmm_sample_jax(model, length, PRNGKey(42))
182 |
183 | alphas, betas, marginals, log_prob = hmm_forwards_backwards_jax(model,
184 | observations)
185 |
186 | tfp_marginal_logits = tfp_model.posterior_marginals(observations).logits
187 | tfp_marginals = nn.softmax(tfp_marginal_logits)
188 |
189 | chex.assert_shape(alphas, (length, num_states))
190 | chex.assert_shape(betas, (length, num_states))
191 | chex.assert_shape(marginals, (length, num_states))
192 | chex.assert_shape(log_prob, (1,))
193 | np.testing.assert_array_almost_equal(marginals, tfp_marginals, decimal=4)
194 |
195 |
196 | def test_viterbi(length, num_states):
197 | params_fn = obs_dist_name_and_params_fn
198 |
199 | init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
200 | trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)
201 |
202 | model, tfp_model = _make_models(init_probs,
203 | trans_mat,
204 | params_fn(num_states),
205 | length)
206 |
207 | _, observations = hmm_sample_jax(model, length, PRNGKey(42))
208 | most_likely_states = hmm_viterbi_jax(model, observations)
209 | tfp_mode = tfp_model.posterior_mode(observations)
210 | chex.assert_shape(most_likely_states, (length,))
211 | assert np.allclose(most_likely_states, tfp_mode)
212 |
213 | '''
214 | ########
215 | #Test Fixed Lag Smoother
216 |
217 | # helper function
218 | def get_fls_result(params, data, win_len, act=None):
219 | assert data.size > 2, "Complete observation set must be of size at least 2"
220 | prior, obs_mat = params.init_dist, params.obs_mat
221 | n_states = obs_mat.shape[0]
222 | alpha, _ = normalize(prior * obs_mat[:, data[0]])
223 | bmatrix = jnp.eye(n_states)[None, :]
224 | for obs in data[1:]:
225 | alpha, bmatrix, gamma = fixed_lag_smoother(params, win_len, alpha, bmatrix, obs)
226 | return alpha, gamma
227 |
228 | *_, gammas_fls = get_fls_result(params_jax, jnp.array(x_hist), jnp.array(x_hist).size)
229 |
230 | assert np.allclose(gammas_fls, gammas_jax)
231 | '''
232 | obs_dist_name_and_params_fn = lambda n: nn.softmax(normal(PRNGKey(0), (n, 7)), axis=-1)
233 |
234 | ### Tests
235 | test_all_hmm_models()
236 | test_inference()
237 |
238 | for length, num_states in zip([1, 3], (2, 23)):
239 | test_viterbi(length, num_states)
240 | test_forward_backward(length, num_states)
241 | test_sample(length, num_states)
242 |
243 |
244 |
--------------------------------------------------------------------------------
/jsl/hmm/hmm_logspace_lib_test.py:
--------------------------------------------------------------------------------
1 | '''
2 | This demo compares the log-space version of the hidden Markov model for discrete observations with the general hidden Markov model
3 | in terms of speed. It also checks whether the inference algorithms give the same results.
4 | Author : Aleyna Kara (@karalleyna)
5 | '''
6 |
7 | import time
8 | import distrax
9 | import jax.numpy as jnp
10 | from jax.random import PRNGKey, split, uniform
11 | import numpy as np
12 | from hmm_logspace_lib import HMM, hmm_forwards_backwards_log, hmm_viterbi_log, hmm_sample_log
13 |
14 | seed = 0
15 | rng_key = PRNGKey(seed)
16 | rng_key, key_A, key_B = split(rng_key, 3)
17 |
18 | # state transition matrix
19 | n_hidden, n_obs = 100, 10
20 | A = uniform(key_A, (n_hidden, n_hidden))
21 | A = A / jnp.sum(A, axis=1).reshape((-1, 1))
22 |
23 | # observation matrix
24 | B = uniform(key_B, (n_hidden, n_obs))
25 | B = B / jnp.sum(B, axis=1).reshape((-1, 1))
26 |
27 | n_samples = 1000
28 | init_state_dist = jnp.ones(n_hidden) / n_hidden
29 |
30 | seed = 0
31 | rng_key = PRNGKey(seed)
32 |
33 | hmm = HMM(trans_dist=distrax.Categorical(probs=A),
34 | obs_dist=distrax.Categorical(probs=B),
35 | init_dist=distrax.Categorical(probs=init_state_dist))
36 |
37 | hmm_distrax = distrax.HMM(trans_dist=distrax.Categorical(probs=A),
38 | obs_dist=distrax.Categorical(probs=B),
39 | init_dist=distrax.Categorical(probs=init_state_dist))
40 |
41 | z_hist, x_hist = hmm_sample_log(hmm, n_samples, rng_key)
42 |
43 | start = time.time()
44 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(x_hist, len(x_hist))
45 | print(f'Time taken by Forwards Backwards function of HMM general: {time.time() - start}s')
46 | print(f'Loglikelihood found by HMM general: {loglikelihood}')
47 |
48 | start = time.time()
49 | alphas_log, _, gammas_log, loglikelihood_log = hmm_forwards_backwards_log(hmm, x_hist, len(x_hist))
50 | print(f'Time taken by Forwards Backwards function of HMM Log Space Version: {time.time() - start}s')
51 | print(f'Loglikelihood found by HMM General Log Space Version: {loglikelihood_log}')
52 |
53 | assert np.allclose(jnp.log(alphas), alphas_log, 8)
54 | assert np.allclose(loglikelihood, loglikelihood_log)
55 | assert np.allclose(jnp.log(gammas), gammas_log, 8)
56 |
57 | # Test for the hmm_viterbi_log. This test is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm_test.py
58 | loc = jnp.array([0.0, 1.0, 2.0, 3.0])
59 | scale = jnp.array(0.25)
60 | initial = jnp.array([0.25, 0.25, 0.25, 0.25])
61 | trans = jnp.array([[0.9, 0.1, 0.0, 0.0],
62 | [0.1, 0.8, 0.1, 0.0],
63 | [0.0, 0.1, 0.8, 0.1],
64 | [0.0, 0.0, 0.1, 0.9]])
65 |
66 | observations = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 3.0, 2.9, 2.8, 2.7, 2.6])
67 |
68 | model = HMM(
69 | init_dist=distrax.Categorical(probs=initial),
70 | trans_dist=distrax.Categorical(probs=trans),
71 | obs_dist=distrax.Normal(loc, scale))
72 |
73 | inferred_states = hmm_viterbi_log(model, observations)
74 | expected_states = [0, 0, 0, 0, 1, 2, 3, 3, 3, 3]
75 |
76 | assert np.allclose(inferred_states, expected_states)
77 |
78 | length = 7
79 | inferred_states = hmm_viterbi_log(model, observations, length)
80 | expected_states = [0, 0, 0, 0, 1, 2, 3, -1, -1, -1]
81 | assert np.allclose(inferred_states, expected_states)
82 |
--------------------------------------------------------------------------------
/jsl/hmm/hmm_utils.py:
--------------------------------------------------------------------------------
1 | # Common functions that can be used for any hidden markov model type.
2 | # Author: Aleyna Kara(@karalleyna)
3 |
4 | import jax.numpy as jnp
5 | import matplotlib.pyplot as plt
6 | from jax import vmap, jit
7 | from jax.random import split, randint, PRNGKey, permutation
8 | from functools import partial
9 | # !pip install graphviz
10 | from graphviz import Digraph
11 |
12 |
13 | @partial(jit, static_argnums=(2,))
14 | def hmm_sample_minibatches(observations, valid_lens, batch_size, rng_key):
15 | '''
16 |     Creates minibatches consisting of random permutations of the
17 | given observation sequences
18 |
19 | Parameters
20 | ----------
21 | observations : array(N, seq_len)
22 | All observation sequences
23 |
24 | valid_lens : array(N, seq_len)
25 | Consists of the valid length of each observation sequence
26 |
27 | batch_size : int
28 | The number of observation sequences that will be included in
29 | each minibatch
30 |
31 | rng_key : array
32 | Random key of shape (2,) and dtype uint32
33 |
34 | Returns
35 | -------
36 | * array(num_batches, batch_size, max_len)
37 | Minibatches
38 | '''
39 | num_train = len(observations)
40 | perm = permutation(rng_key, num_train)
41 |
42 | def create_mini_batch(batch_idx):
43 | return observations[batch_idx], valid_lens[batch_idx]
44 |
45 | num_batches = num_train // batch_size
46 | batch_indices = perm.reshape((num_batches, -1))
47 | minibatches = vmap(create_mini_batch)(batch_indices)
48 | return minibatches
49 |
50 |
51 | @partial(jit, static_argnums=(1, 2, 3))
52 | def hmm_sample_n(params, sample_fn, n, max_len, rng_key):
53 | '''
54 | Generates n observation sequences from the given Hidden Markov Model
55 |
56 | Parameters
57 | ----------
58 | params : HMMNumpy or HMMJax
59 | Hidden Markov Model
60 |
61 | sample_fn :
62 | The sample function of the given hidden markov model
63 |
64 | n : int
65 | The total number of observation sequences
66 |
67 | max_len : int
68 | The upper bound of the length of each observation sequence. Note that the valid length of the observation
69 | sequence is less than or equal to the upper bound.
70 |
71 | rng_key : array
72 | Random key of shape (2,) and dtype uint32
73 |
74 | Returns
75 | -------
76 | * array(n, max_len)
77 | Observation sequences
78 | '''
79 |
80 | def sample_(params, n_samples, key):
81 | return sample_fn(params, n_samples, key)[1]
82 |
83 | rng_key, rng_lens = split(rng_key)
84 | lens = randint(rng_lens, (n,), minval=1, maxval=max_len + 1)
85 | keys = split(rng_key, n)
86 | observations = vmap(sample_, in_axes=(None, None, 0))(params, max_len, keys)
87 | return observations, lens
88 |
89 |
90 | @jit
91 | def pad_sequences(observations, valid_lens, pad_val=0):
92 | '''
93 |     Pads each observation sequence so that entries beyond its valid length are replaced with pad_val
94 |
95 | Parameters
96 | ----------
97 |
98 | observations : array(N, seq_len)
99 | All observation sequences
100 |
101 | valid_lens : array(N, seq_len)
102 | Consists of the valid length of each observation sequence
103 |
104 | pad_val : int
105 |         Value that the invalid observable events of the observation sequence will be replaced with
106 |
107 | Returns
108 | -------
109 |     * (array(N, max_len), array(N,))
110 |         Padded observation sequences together with their valid lengths
111 | '''
112 |
113 | def pad(seq, len):
114 | idx = jnp.arange(1, seq.shape[0] + 1)
115 | return jnp.where(idx <= len, seq, pad_val)
116 |
117 | ragged_dataset = vmap(pad, in_axes=(0, 0))(observations, valid_lens), valid_lens
118 | return ragged_dataset
119 |
120 |
121 | def hmm_plot_graphviz(trans_mat, obs_mat, states=[], observations=[]):
122 | """
123 |     Visualizes HMM transition matrix and observation matrix using graphviz.
124 |
125 | Parameters
126 | ----------
127 | trans_mat, obs_mat, init_dist: arrays
128 |
129 | states: List(num_hidden)
130 | Names of hidden states
131 |
132 | observations: List(num_obs)
133 | Names of observable events
134 |
135 | Returns
136 | -------
137 | dot object, that can be displayed in colab
138 | """
139 |
140 | n_states, n_obs = obs_mat.shape
141 |
142 | dot = Digraph(comment='HMM')
143 | if not states:
144 | states = [f'State {i + 1}' for i in range(n_states)]
145 | if not observations:
146 | observations = [f'Obs {i + 1}' for i in range(n_obs)]
147 |
148 | # Creates hidden state nodes
149 | for i, name in enumerate(states):
150 |         table = [f'<TR><TD>{observations[j]}</TD><TD>{"%.2f" % prob}</TD></TR>' for j, prob in
151 |                  enumerate(obs_mat[i])]
152 |         label = f'''<<TABLE><TR><TD COLSPAN="2">{name}</TD></TR>{"".join(table)}</TABLE>>'''
153 | dot.node(f's{i}', label=label)
154 |
155 | # Writes transition probabilities
156 | for i in range(n_states):
157 | for j in range(n_states):
158 | dot.edge(f's{i}', f's{j}', label=str('%.2f' % trans_mat[i, j]))
159 | dot.attr(rankdir='LR')
160 | # dot.render(file_name, view=True)
161 | return dot
162 |
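
A short usage sketch for hmm_plot_graphviz, reusing the casino transition and emission matrices that appear in the tests elsewhere in this package; the state names, observation names and output file name are illustrative choices.

import jax.numpy as jnp
from jsl.hmm.hmm_utils import hmm_plot_graphviz

A = jnp.array([[0.95, 0.05],
               [0.10, 0.90]])                       # transition matrix
B = jnp.array([[1/6] * 6,                           # fair die
               [1/10] * 5 + [5/10]])                # loaded die

dot = hmm_plot_graphviz(A, B,
                        states=["Fair", "Loaded"],
                        observations=[str(i) for i in range(1, 7)])
dot.render("hmm_casino_graph", format="png")        # or display `dot` directly in a notebook
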
--------------------------------------------------------------------------------
/jsl/hmm/old/hmm_discrete_lib_test.py:
--------------------------------------------------------------------------------
1 | '''
2 | This demo compares the JAX, NumPy and distrax versions of the forwards-backwards algorithm in terms of speed.
3 | It also checks whether they give the same results.
4 | Author : Aleyna Kara (@karalleyna)
5 | '''
6 |
7 | import superimport
8 |
9 | import time
10 |
11 | import jax.numpy as jnp
12 | from jax.random import PRNGKey, split, uniform
13 | import numpy as np
14 |
15 | from hmm_discrete_lib import HMMJax, HMMNumpy
16 | from hmm_discrete_lib import hmm_sample_jax, hmm_forwards_backwards_jax, hmm_forwards_backwards_numpy
17 |
18 | import distrax
19 | from distrax import HMM
20 |
21 | seed = 0
22 | rng_key = PRNGKey(seed)
23 | rng_key, key_A, key_B = split(rng_key, 3)
24 |
25 | # state transition matrix
26 | n_hidden, n_obs = 100, 10
27 | A = uniform(key_A, (n_hidden, n_hidden))
28 | A = A / jnp.sum(A, axis=1).reshape((-1, 1))
29 |
30 | # observation matrix
31 | B = uniform(key_B, (n_hidden, n_obs))
32 | B = B / jnp.sum(B, axis=1).reshape((-1, 1))
33 |
34 | n_samples = 1000
35 | init_state_dist = jnp.ones(n_hidden) / n_hidden
36 |
37 | seed = 0
38 | rng_key = PRNGKey(seed)
39 |
40 | params_numpy = HMMNumpy(A, B, init_state_dist)
41 | params_jax = HMMJax(A, B, init_state_dist)
42 | hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
43 | obs_dist=distrax.Categorical(probs=B),
44 | init_dist=distrax.Categorical(probs=init_state_dist))
45 |
46 | z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)
47 |
48 | start = time.time()
49 | alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist))
50 | print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s')
51 |
52 | start = time.time()
53 | alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), len(x_hist))
54 | print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s')
55 |
56 | start = time.time()
57 | alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist),
58 | length=len(x_hist))
59 |
60 | print(f'Time taken by HMM distrax : {time.time()-start}s')
61 |
62 | assert np.allclose(alphas_np, alphas_jax)
63 | assert np.allclose(loglikelihood_np, loglikelihood_jax)
64 | assert np.allclose(gammas_np, gammas_jax)
65 |
66 | assert np.allclose(alphas, alphas_jax, 8)
67 | assert np.allclose(loglikelihood, loglikelihood_jax)
68 | assert np.allclose(gammas, gammas_jax, 8)
--------------------------------------------------------------------------------
/jsl/hmm/old/hmm_discrete_likelihood_test.py:
--------------------------------------------------------------------------------
1 | '''
2 | This demo shows how to create variable size dataset and then it creates mini-batches from this dataset so that it
3 | calculates the likelihood of each observation sequence in every batch using vmap.
4 | Author : Aleyna Kara (@karalleyna)
5 | '''
6 |
7 |
8 | from jax import vmap, jit
9 | from jax.random import split, randint, PRNGKey
10 | import jax.numpy as jnp
11 |
12 | import hmm_utils
13 | from hmm_discrete_lib import hmm_sample_jax, hmm_loglikelihood_numpy, hmm_loglikelihood_jax
14 | from hmm_discrete_lib import HMMNumpy, HMMJax
15 | import numpy as np
16 |
17 | def loglikelihood_numpy(params_numpy, batches, lens):
18 | return np.vstack([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)])
19 |
20 | def loglikelihood_jax(params_jax, batches, lens):
21 | return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)
22 |
23 | # state transition matrix
24 | A = jnp.array([
25 | [0.95, 0.05],
26 | [0.10, 0.90]
27 | ])
28 |
29 | # observation matrix
30 | B = jnp.array([
31 | [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
32 | [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
33 | ])
34 |
35 | pi = jnp.array([1, 1]) / 2
36 |
37 | params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi))
38 | params_jax = HMMJax(A, B, pi)
39 |
40 | seed = 0
41 | rng_key = PRNGKey(seed)
42 | rng_key, rng_sample = split(rng_key)
43 |
44 | n_obs_seq, batch_size, max_len = 15, 5, 10
45 |
46 | observations, lens = hmm_utils.hmm_sample_n(params_jax,
47 | hmm_sample_jax,
48 | n_obs_seq, max_len,
49 | rng_sample)
50 |
51 | observations, lens = hmm_utils.pad_sequences(observations, lens)
52 |
53 | rng_key, rng_batch = split(rng_key)
54 | batches, lens = hmm_utils.hmm_sample_minibatches(observations,
55 | lens,
56 | batch_size,
57 | rng_batch)
58 |
59 | ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens))
60 | ll_jax = loglikelihood_jax(params_jax, batches, lens)
61 |
62 | assert np.allclose(ll_numpy, ll_jax)
63 | print(f'Loglikelihood {ll_numpy}')
--------------------------------------------------------------------------------
/jsl/hmm/old/hmm_sgd_lib.py:
--------------------------------------------------------------------------------
1 | # Trains hidden markov models with discrete observations using Gradient Descent in a stateless way.
2 | # Author : Aleyna Kara(@karalleyna)
3 |
4 |
5 | import jax
6 | import itertools
7 | from jax import jit
8 | from jax.nn import softmax
9 | from jax.random import PRNGKey, split, normal
10 | from jsl.hmm.hmm_utils import hmm_sample_minibatches
11 | from jsl.hmm.old.hmm_discrete_lib import HMMJax, hmm_loglikelihood_jax
12 |
13 | opt_init, opt_update, get_params = None, None, None
14 |
15 | def init_random_params(sizes, rng_key):
16 | """
16 |     Initializes the components of HMM from a normal distribution
18 |
19 | Parameters
20 | ----------
21 | sizes: List
22 | Consists of number of hidden states and observable events, respectively
23 |
24 | rng_key : array
25 | Random key of shape (2,) and dtype uint32
26 |
27 | Returns
28 | -------
29 | * array(num_hidden, num_hidden)
30 | Transition probability matrix
31 |
32 | * array(num_hidden, num_obs)
33 | Emission probability matrix
34 |
35 | * array(1, num_hidden)
36 | Initial distribution probabilities
37 | """
38 | num_hidden, num_obs = sizes
39 | rng_key, rng_a, rng_b, rng_pi = split(rng_key, 4)
40 | return HMMJax(normal(rng_a, (num_hidden, num_hidden)),
41 | normal(rng_b, (num_hidden, num_obs)),
42 | normal(rng_pi, (num_hidden,)))
43 |
44 |
45 | @jit
46 | def loss_fn(params, batch, lens):
47 | """
48 | Objective function of hidden markov models for discrete observations. It returns the mean of the negative
49 | loglikelihood of the sequence of observations
50 |
51 | Parameters
52 | ----------
53 | params : HMMJax
54 | Hidden Markov Model
55 |
56 | batch: array(N, max_len)
57 | Minibatch consisting of observation sequences
58 |
59 | lens : array(N, seq_len)
60 | Consists of the valid length of each observation sequence in the minibatch
61 |
62 | Returns
63 | -------
64 | * float
65 | The mean negative loglikelihood of the minibatch
66 | """
67 | params_soft = HMMJax(softmax(params.trans_mat, axis=1),
68 | softmax(params.obs_mat, axis=1),
69 | softmax(params.init_dist))
70 | return -hmm_loglikelihood_jax(params_soft, batch, lens).mean()
71 |
72 |
73 | @jit
74 | def update(i, opt_state, batch, lens):
75 | """
76 |     Performs one optimizer update step on the HMM parameters using the gradient of the mean negative
77 |     loglikelihood of the given minibatch of observation sequences
78 |
79 | Parameters
80 | ----------
81 | i : int
82 | Specifies the current iteration
83 |
84 | opt_state : OptimizerState
85 |
86 | batch: array(N, max_len)
87 | Minibatch consisting of observation sequences
88 |
89 | lens : array(N, seq_len)
90 | Consists of the valid length of each observation sequence in the minibatch
91 |
92 | Returns
93 | -------
94 | * OptimizerState
95 |
96 | * float
97 | The mean negative loglikelihood of the minibatch, i.e. loss value for the current iteration.
98 | """
99 | params = get_params(opt_state)
100 | loss, grads = jax.value_and_grad(loss_fn)(params, batch, lens)
101 | return opt_update(i, grads, opt_state), loss
102 |
103 |
104 | def fit(observations, lens, num_hidden, num_obs, batch_size, optimizer, rng_key=None, num_epochs=1):
105 | """
106 | Trains the HMM model with the given number of hidden states and observations via any optimizer.
107 |
108 | Parameters
109 | ----------
110 | observations: array(N, seq_len)
111 | All observation sequences
112 |
113 | lens : array(N, seq_len)
114 | Consists of the valid length of each observation sequence
115 |
116 | num_hidden : int
117 | The number of hidden state
118 |
119 | num_obs : int
120 | The number of observable events
121 |
122 | batch_size : int
123 | The number of observation sequences that will be included in each minibatch
124 |
125 | optimizer : jax.experimental.optimizers.Optimizer
126 | Optimizer that is used during training
127 |
128 | num_epochs : int
129 | The total number of iterations
130 |
131 | Returns
132 | -------
133 | * HMMJax
134 | Hidden Markov Model
135 |
136 | * array
137 | Consists of training losses
138 | """
139 | global opt_init, opt_update, get_params
140 |
141 | if rng_key is None:
142 | rng_key = PRNGKey(0)
143 |
144 | rng_init, rng_iter = split(rng_key)
145 | params = init_random_params([num_hidden, num_obs], rng_init)
146 | opt_init, opt_update, get_params = optimizer
147 | opt_state = opt_init(params)
148 | itercount = itertools.count()
149 |
150 | def epoch_step(opt_state, key):
151 |
152 | def train_step(opt_state, params):
153 | batch, length = params
154 | opt_state, loss = update(next(itercount), opt_state, batch, length)
155 | return opt_state, loss
156 |
157 | batches, valid_lens = hmm_sample_minibatches(observations, lens, batch_size, key)
158 | params = (batches, valid_lens)
159 | opt_state, losses = jax.lax.scan(train_step, opt_state, params)
160 | return opt_state, losses.mean()
161 |
162 | epochs = split(rng_iter, num_epochs)
163 | opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs)
164 |
165 | losses = losses.flatten()
166 |
167 | params = get_params(opt_state)
168 | params = HMMJax(softmax(params.trans_mat, axis=1),
169 | softmax(params.obs_mat, axis=1),
170 | softmax(params.init_dist))
171 | return params, losses
172 |
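
A usage sketch for fit above. It assumes the optimizer triple comes from jax.example_libraries.optimizers (the current home of the jax.experimental.optimizers API named in the docstring), that this module is importable as jsl.hmm.old.hmm_sgd_lib, and it reuses the sampling helpers from hmm_utils; the hyperparameter values are arbitrary.

import jax.numpy as jnp
from jax.example_libraries.optimizers import adam
from jax.random import PRNGKey, split

from jsl.hmm.hmm_lib import HMMJax, hmm_sample_jax
import jsl.hmm.hmm_utils as hmm_utils
from jsl.hmm.old.hmm_sgd_lib import fit   # import path assumed; `fit` is the function defined above

key_sample, key_fit = split(PRNGKey(0))

# Sample padded training sequences from a known 2-state, 6-symbol HMM
true_params = HMMJax(jnp.array([[0.95, 0.05], [0.10, 0.90]]),
                     jnp.array([[1/6] * 6, [1/10] * 5 + [5/10]]),
                     jnp.array([0.5, 0.5]))
observations, lens = hmm_utils.hmm_sample_n(true_params, hmm_sample_jax, 20, 50, key_sample)
observations, lens = hmm_utils.pad_sequences(observations, lens)

params, losses = fit(observations, lens, num_hidden=2, num_obs=6, batch_size=5,
                     optimizer=adam(1e-2), rng_key=key_fit, num_epochs=100)
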
--------------------------------------------------------------------------------
/jsl/hmm/sparse_lib.py:
--------------------------------------------------------------------------------
1 | """
2 | jax.experimental.sparse-compatible Hidden Markov Model (HMM)
3 | """
4 | import jax
5 | import chex
6 | from functools import partial
7 | from typing import Callable, Tuple, Dict
8 |
9 |
10 | def alpha_step(alpha_prev, y, local_evidence_multiple, transition_matrix):
11 | local_evidence = local_evidence_multiple(y)
12 | alpha_next = local_evidence * (transition_matrix.T @ alpha_prev)
13 | normalisation_cst = alpha_next.sum()
14 | alpha_next = alpha_next / normalisation_cst
15 |
16 | carry = {
17 | "alpha": alpha_next,
18 | "cst": normalisation_cst
19 | }
20 |
21 | return alpha_next, carry
22 |
23 |
24 | def beta_step(beta_next, y, local_evidence_multiple, transition_matrix):
25 | norm_cst = beta_next.sum()
26 | local_evidence = local_evidence_multiple(y)
27 | beta_prev = transition_matrix @ (local_evidence * beta_next)
28 | beta_prev = beta_prev / norm_cst
29 |
30 | carry = {
31 | "beta": beta_prev,
32 | "cst": norm_cst
33 | }
34 |
35 | return beta_prev, carry
36 |
37 |
38 | def alpha_forward(obs: chex.Array,
39 | local_evidence: Callable[[chex.Array], chex.Array],
40 | transition_matrix: chex.Array,
41 | alpha_init: chex.Array) -> Tuple[chex.Array, chex.Array]:
42 | """
43 | Compute the alpha-forward (forward-filter) pass for a given observation sequence.
44 | """
45 | alpha_step_part = partial(alpha_step,
46 | local_evidence_multiple=local_evidence,
47 | transition_matrix=transition_matrix)
48 | alpha_last, alpha_hist = jax.lax.scan(alpha_step_part, alpha_init, obs)
49 | return alpha_last, alpha_hist
50 |
51 |
52 | def beta_backward(obs: chex.Array,
53 | local_evidence: Callable[[chex.Array], chex.Array],
54 | transition_matrix: chex.Array,
55 | alpha_last: chex.Array) -> Tuple[chex.Array, chex.Array]:
56 | """
57 |     Compute the backward (beta) smoothing pass for a
58 | given observation sequence.
59 | """
60 | beta_step_part = partial(beta_step,
61 | local_evidence_multiple=local_evidence,
62 | transition_matrix=transition_matrix)
63 | beta_first, beta_hist = jax.lax.scan(beta_step_part, alpha_last, obs, reverse=True)
64 | return beta_first, beta_hist
65 |
66 |
67 | def forward_backward(obs: chex.Array,
68 | local_evidence: Callable[[chex.Array], chex.Array],
69 | transition_matrix: chex.Array,
70 | alpha_init: chex.Array) -> Dict[str, chex.Array]:
71 | """
72 | Compute the forward-filter and backward-smoother pass for a
73 | given observation sequence.
74 | """
75 | alpha_last, alpha_hist = alpha_forward(obs, local_evidence, transition_matrix, alpha_init)
76 | _, beta_hist = beta_backward(obs, local_evidence, transition_matrix, alpha_last)
77 |
78 | filter_hist = alpha_hist["alpha"]
79 | smooth_hist = filter_hist * beta_hist["beta"]
80 | smooth_hist = smooth_hist / smooth_hist.sum(axis=1, keepdims=True)
81 |
82 | res = {
83 | "filter": filter_hist,
84 | "smooth": smooth_hist
85 | }
86 |
87 | return res
88 |
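
A small usage sketch for forward_backward above, wiring a discrete emission matrix into the local_evidence callable. The casino-style matrices and the short observation sequence are illustrative; note that obs here starts at t = 1, since alpha_init is taken to be the already-filtered distribution at t = 0.

import jax.numpy as jnp
from jsl.hmm.sparse_lib import forward_backward

# Casino-style HMM: two hidden states, six observable symbols
trans = jnp.array([[0.95, 0.05],
                   [0.10, 0.90]])
obs_mat = jnp.array([[1/6] * 6,
                     [1/10] * 5 + [5/10]])

# local_evidence(y) returns p(y | z) for every hidden state z
local_evidence = lambda y: obs_mat[:, y]

obs = jnp.array([0, 5, 5, 2, 5, 5, 1])            # a short observation sequence
alpha_init = jnp.array([0.5, 0.5]) * local_evidence(obs[0])
alpha_init = alpha_init / alpha_init.sum()        # filtered distribution at t = 0

res = forward_backward(obs[1:], local_evidence, trans, alpha_init)
print(res["filter"].shape, res["smooth"].shape)   # (6, 2) (6, 2)
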
--------------------------------------------------------------------------------
/jsl/lds/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/jsl/lds/cont_kalman_filter.py:
--------------------------------------------------------------------------------
1 | # Implementation of the Kalman Filter for
2 | # continuous time series
3 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna)
4 |
5 | import jax.numpy as jnp
6 | from jax import random, lax
7 | from jax.scipy.linalg import solve
8 |
9 | import chex
10 | from math import ceil
11 |
12 | from jsl.lds.kalman_filter import LDS
13 |
14 |
15 | def _rk2(x0, M, nsteps, dt):
16 | """
17 | class-independent second-order Runge-Kutta method for linear systems
18 |
19 | Parameters
20 | ----------
21 | x0: array(state_size, )
22 | Initial state of the system
23 | M: array(state_size, K)
24 | Evolution matrix
25 | nsteps: int
26 | Total number of steps to integrate
27 | dt: float
28 | integration step size
29 |
30 | Returns
31 | -------
32 | array(nsteps, state_size)
33 | Integration history
34 | """
35 | def f(x): return M @ x
36 | input_dim, *_ = x0.shape
37 |
38 | def step(xt, t):
39 | k1 = f(xt)
40 | k2 = f(xt + dt * k1)
41 | xt = xt + dt * (k1 + k2) / 2
42 | return xt, xt
43 |
44 | steps = jnp.arange(nsteps)
45 | _, simulation = lax.scan(step, x0, steps)
46 |
47 | simulation = jnp.vstack([x0, simulation])
48 | return simulation
49 |
50 | def sample(key: chex.PRNGKey,
51 | params: LDS,
52 | x0: chex.Array,
53 | T: float,
54 | nsamples: int,
55 | dt: float=0.01,
56 | noisy: bool=False):
57 | """
58 |     Sample from the continuous-time linear dynamical system. First, we integrate
59 | up to time T, then we obtain nsamples equally-spaced points. Finally,
60 | we transform the latent space to obtain the observations
61 |
62 | Parameters
63 | ----------
64 | params: LDS
65 | Linear Dynamical System object
66 | key: jax.random.PRNGKey
67 | x0: array(state_size)
68 | Initial state of simulation
69 | T: float
70 | Final time of integration
71 | nsamples: int
72 | Number of observations to take from the total integration
73 | dt: float
74 | integration step size
75 | noisy: bool
76 | Whether to (naively) add noise to the state space
77 |
78 | Returns
79 | -------
80 | * array(nsamples, state_size)
81 | State-space values
82 | * array(nsamples, obs_size)
83 | Observed-space values
84 | * int
85 | Number of observations skipped between one
86 | datapoint and the next
87 | """
88 | nsteps = ceil(T / dt)
89 | jump_size = ceil(nsteps / nsamples)
90 | correction = nsamples - ceil(nsteps / jump_size)
91 | nsteps += correction * jump_size
92 |
93 | key_state, key_obs = random.split(key)
94 | obs_size, state_size = params.C.shape
95 | state_noise = random.multivariate_normal(key_state, jnp.zeros(state_size), params.Q, (nsteps,))
96 | obs_noise = random.multivariate_normal(key_obs, jnp.zeros(obs_size), params.R, (nsteps,))
97 | simulation = _rk2(x0, params.A, nsteps, dt)
98 |
99 | if noisy:
100 | simulation = simulation + state_noise
101 |
102 | sample_state = simulation[::jump_size]
103 | sample_obs = jnp.einsum("ij,si->si", params.C, sample_state) + obs_noise[:len(sample_state)]
104 |
105 | return sample_state, sample_obs, jump_size
106 |
107 | def filter(params: LDS,
108 | x_hist: chex.Array,
109 | jump_size: chex.Array,
110 | dt: chex.Array):
111 | """
112 |     Compute the online version of the Kalman filter, i.e.,
113 | the one-step-ahead prediction for the hidden state or the
114 | time update step
115 |
116 | Parameters
117 | ----------
118 | x_hist: array(timesteps, observation_size)
119 |
120 | Returns
121 | -------
122 | * array(timesteps, state_size):
123 | Filtered means mut
124 | * array(timesteps, state_size, state_size)
125 | Filtered covariances Sigmat
126 | * array(timesteps, state_size)
127 | Filtered conditional means mut|t-1
128 | * array(timesteps, state_size, state_size)
129 | Filtered conditional covariances Sigmat|t-1
130 | """
131 | obs_size, state_size = params.C.shape
132 |
133 | I = jnp.eye(state_size)
134 | timesteps, *_ = x_hist.shape
135 | mu_hist = jnp.zeros((timesteps, state_size))
136 | Sigma_hist = jnp.zeros((timesteps, state_size, state_size))
137 | Sigma_cond_hist = jnp.zeros((timesteps, state_size, state_size))
138 | mu_cond_hist = jnp.zeros((timesteps, state_size))
139 |
140 | # Initial configuration
141 | A, Q, C, R = params.A, params.Q, params.C, params.R
142 | mu, Sigma = params.mu, params.Sigma
143 |
144 | temp = C @ Sigma @ C.T + R
145 | K1 = solve(temp, C @ Sigma, sym_pos=True).T
146 | mu1 = mu + K1 @ (x_hist[0] - C @ mu)
147 | Sigma1 = (I - K1 @ C) @ Sigma
148 |
149 |
150 | def rk_integration_step(state, carry):
151 | # Runge-kutta integration step
152 | mu, Sigma = state
153 | k1 = A @ mu
154 | k2 = A @ (mu + dt * k1)
155 | mu = mu + dt * (k1 + k2) / 2
156 |
157 | k1 = A @ Sigma @ A.T + Q
158 | k2 = A @ (Sigma + dt * k1) @ A.T + Q
159 | Sigma = Sigma + dt * (k1 + k2) / 2
160 |
161 | return (mu, Sigma), None
162 |
163 | def step(state, x):
164 | mun, Sigman = state
165 | initial_state = (mun, Sigman)
166 | (mun, Sigman), _ = lax.scan(rk_integration_step, initial_state, jnp.arange(jump_size))
167 |
168 | Sigman_cond = jnp.ones_like(Sigman) * Sigman
169 | St = C @ Sigman_cond @ C.T + R
170 | Kn = solve(St, C @ Sigman_cond, sym_pos=True).T
171 |
172 | mu_update = jnp.ones_like(mun) * mun
173 | x_update = C @ mun
174 | mun = mu_update + Kn @ (x - x_update)
175 | Sigman = (I - Kn @ C) @ Sigman_cond
176 |
177 | return (mun, Sigman), (mun, Sigman, mu_update, Sigman_cond)
178 |
179 | initial_state = (mu1, Sigma1)
180 | _, (mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist) = lax.scan(step, initial_state, x_hist[1:])
181 |
182 | mu_hist = jnp.vstack([mu1[None, ...], mu_hist])
183 | Sigma_hist = jnp.vstack([Sigma1[None, ...], Sigma_hist])
184 | mu_cond_hist = jnp.vstack([params.mu[None, ...], mu_cond_hist])
185 | Sigma_cond_hist = jnp.vstack([params.Sigma[None, ...], Sigma_cond_hist])
186 |
187 | return mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist
188 |
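
A quick sanity sketch of the _rk2 integrator above on a system with a known closed-form solution; the decay rate, step size and step count are arbitrary choices for the example.

import jax.numpy as jnp
from jsl.lds.cont_kalman_filter import _rk2

# dx/dt = M x with M = -I has the exact solution x(t) = x0 * exp(-t)
M = -jnp.eye(1)
x0 = jnp.array([1.0])
dt, nsteps = 0.01, 100

traj = _rk2(x0, M, nsteps, dt)                       # shape (nsteps + 1, 1)
exact = jnp.exp(-dt * jnp.arange(nsteps + 1))
print(jnp.max(jnp.abs(traj[:, 0] - exact)))          # second-order accurate: error ~ O(dt^2)
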
--------------------------------------------------------------------------------
/jsl/lds/kalman_filter_test.py:
--------------------------------------------------------------------------------
1 | from jax import random
2 | from jax import numpy as jnp
3 | import numpy as np
4 | import tensorflow_probability.substrates.jax.distributions as tfd
5 |
6 | from jsl.lds.kalman_filter import LDS, kalman_filter, kalman_smoother
7 |
8 |
9 | def lds_jsl_to_tfp(num_timesteps, lds):
10 | """Convert a JSL `LDS` object into a tfp `LinearGaussianStateSpaceModel`.
11 |
12 | Args:
13 | num_timesteps: int, number of timesteps.
14 | lds: LDS object.
15 | """
16 | dynamics_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.Q)
17 | emission_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.R)
18 | initial_dist = tfd.MultivariateNormalFullCovariance(lds.mu, lds.Sigma)
19 |
20 | tfp_lgssm = tfd.LinearGaussianStateSpaceModel(
21 | num_timesteps,
22 | lds.A, dynamics_noise_dist,
23 | lds.C, emission_noise_dist,
24 | initial_dist,
25 | )
26 |
27 | return tfp_lgssm
28 |
29 |
30 | def test_kalman_filter():
31 | key = random.PRNGKey(314)
32 | num_timesteps = 15
33 | delta = 1.0
34 |
35 | ### LDS Parameters ###
36 | state_size = 2
37 | observation_size = 2
38 | A = jnp.eye(state_size)
39 | C = jnp.eye(state_size)
40 |
41 | transition_noise_scale = 1.0
42 | observation_noise_scale = 1.0
43 | Q = jnp.eye(state_size) * transition_noise_scale
44 | R = jnp.eye(observation_size) * observation_noise_scale
45 |
46 | ### Prior distribution params ###
47 | mu0 = jnp.array([8, 10]).astype(float)
48 | Sigma0 = jnp.eye(state_size) * 1.0
49 |
50 | ### Sample data ###
51 | lds_instance = LDS(A, C, Q, R, mu0, Sigma0)
52 | z_hist, x_hist = lds_instance.sample(key, num_timesteps)
53 |
54 | filter_output = kalman_filter(lds_instance, x_hist)
55 | JSL_filtered_means, JSL_filtered_covs, *_ = filter_output
56 | JSL_smoothed_means, JSL_smoothed_covs = kalman_smoother(lds_instance, *filter_output)
57 |
58 | tfp_lgssm = lds_jsl_to_tfp(num_timesteps, lds_instance)
59 | _, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(x_hist)
60 | tfp_smoothed_means, tfp_smoothed_covs = tfp_lgssm.posterior_marginals(x_hist)
61 |
62 | assert np.allclose(JSL_filtered_means, tfp_filtered_means, rtol=1e-2)
63 | assert np.allclose(JSL_filtered_covs, tfp_filtered_covs, rtol=1e-2)
64 | assert np.allclose(JSL_smoothed_means, tfp_smoothed_means, rtol=1e-2)
65 | assert np.allclose(JSL_smoothed_covs, tfp_smoothed_covs, rtol=1e-2)
66 |
--------------------------------------------------------------------------------
/jsl/lds/kalman_filter_with_unknown_noise.py:
--------------------------------------------------------------------------------
1 | # Jax implementation of a Linear Dynamical System where the observation noise is not known.
2 | # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna)
3 |
4 | import chex
5 |
6 | import jax.numpy as jnp
7 | from jax import lax, vmap, tree_map
8 |
9 | from dataclasses import dataclass
10 | from functools import partial
11 | from typing import Union, Callable
12 |
13 | from tensorflow_probability.substrates import jax as tfp
14 |
15 | tfd = tfp.distributions
16 |
17 |
18 | @dataclass
19 | class LDS:
20 | """
21 | Kalman filtering for a linear Gaussian state space model with scalar observations,
22 | where all the parameters are known except for the observation noise variance.
23 | The model has the following form:
24 | p(state(0)) = Gauss(mu0, Sigma0)
25 | p(state(t) | state(t-1)) = Gauss(A * state(t-1), Q)
26 | p(obs(t) | state(t)) = Gauss(C * state(t), r)
27 | The value of r is jointly inferred together with the latent states to produce the posterior
28 | p(state(t) , r | obs(1:t)) = Gauss(mu(t), Sigma(t) * r) * Ga(1/r | nu(t)/2, nu(t)*tau(t)/2)
29 | where 1/r is the observation precision. For details on this algorithm, see sec 4.5 of
30 | "Bayesian forecasting and dynamic models", West and Harrison, 1997.
31 | https://www2.stat.duke.edu/~mw/West&HarrisonBook/
32 |
33 | Parameters
34 | ----------
35 | A: array(state_size, state_size)
36 | Transition matrix
37 | C: array(observation_size, state_size)
38 | Observation matrix
39 | Q: array(state_size, state_size)
40 | Transition covariance matrix
41 | mu: array(state_size)
42 | Mean of initial configuration
 43 |     Sigma: array(state_size, state_size) or 0
 44 |         Covariance of initial configuration (if set to zero, the initial state is fully determined by mu)
 45 |     v, tau: float
 46 |         Prior hyperparameters of the unknown noise precision, 1/r ~ Ga(v/2, v*tau/2)
 47 |     """
48 | A: chex.Array
49 | C: Union[chex.Array, Callable]
50 | Q: chex.Array
51 | R: chex.Array
52 | mu: chex.Array
53 | Sigma: chex.Array
54 | v: chex.Array
55 | tau: chex.Array
56 |
57 |
58 | def kalman_filter(params: LDS, x_hist: chex.Array,
59 | return_history: bool = True):
60 | """
 61 |     Compute the online version of the Kalman filter, i.e.,
 62 |     the one-step-ahead prediction for the hidden state (time update)
 63 |     followed by the measurement update, for each observation
64 |
65 | Parameters
66 | ----------
67 | params: LDS
68 | Linear Dynamical System object
 69 |     x_hist: tuple (array(timesteps, dim), array(timesteps,)) of covariates and scalar responses
70 | return_history: bool
71 |
 72 |     Returns
 73 |     -------
 74 |     * array(timesteps, state_size):
 75 |         History of filtered means mu(t)
 76 |     * array(timesteps, state_size, state_size)
 77 |         History of filtered covariances Sigma(t)
 78 |     If return_history is False, only the final filtered mean and
 79 |     covariance are returned:
 80 |     * array(state_size)
 81 |     * array(state_size, state_size)
 82 |     """
83 | A, Q, R = params.A, params.Q, params.R
84 | state_size, _ = A.shape
85 |
86 | def kalman_step(state, obs):
87 | mu, Sigma, v, tau = state
88 | covariates, response = obs
89 |
90 | mu_cond = jnp.matmul(A, mu, precision=lax.Precision.HIGHEST)
 91 |         Sigmat_cond = jnp.matmul(jnp.matmul(A, Sigma, precision=lax.Precision.HIGHEST), A.T,
 92 |                                  precision=lax.Precision.HIGHEST) + Q
93 |
94 | e_k = response - covariates.T @ mu_cond
95 | s_k = covariates.T @ Sigmat_cond @ covariates + 1
96 | Kt = (Sigmat_cond @ covariates) / s_k
97 |
98 | mu = mu + e_k * Kt
99 | Sigma = Sigmat_cond - jnp.outer(Kt, Kt) * s_k
100 |
101 | v_update = v + 1
102 | tau = (v * tau + (e_k * e_k) / s_k) / v_update
103 |
104 | return (mu, Sigma, v_update, tau), (mu, Sigma)
105 |
106 | mu0, Sigma0 = params.mu, params.Sigma
107 |     initial_state = (mu0, Sigma0, params.v, params.tau)
108 | (mu, Sigma, _, _), history = lax.scan(kalman_step, initial_state, x_hist)
109 | if return_history:
110 | return history
111 | return mu, Sigma
112 |
113 |
114 | def filter(params: LDS, x_hist: chex.Array,
115 | return_history: bool = True):
116 | """
117 |     Compute the online version of the Kalman filter, i.e.,
118 |     the one-step-ahead prediction for the hidden state (time update)
119 |     followed by the measurement update, for each observation.
120 |     Note that x_hist can optionally carry an extra leading axis;
121 |     this corresponds to different samples of the same underlying
122 |     Linear Dynamical System.
123 | Parameters
124 | ----------
125 | params: LDS
126 | Linear Dynamical System object
127 | x_hist: array(n_samples?, timesteps, observation_size)
128 | Returns
129 | -------
130 |     * array(n_samples?, timesteps, state_size):
131 |         History of filtered means mu(t)
132 |     * array(n_samples?, timesteps, state_size, state_size)
133 |         History of filtered covariances Sigma(t)
134 |     If return_history is False, only the final filtered mean and
135 |     covariance are returned for each sample:
136 |     * array(n_samples?, state_size)
137 |     * array(n_samples?, state_size, state_size)
138 | """
139 | has_one_sim = False
140 | if x_hist.ndim == 2:
141 | x_hist = x_hist[None, ...]
142 | has_one_sim = True
143 | kalman_map = vmap(partial(kalman_filter, return_history=return_history), (None, 0))
144 | outputs = kalman_map(params, x_hist)
145 | if has_one_sim and return_history:
146 | return tree_map(lambda x: x[0, ...], outputs)
147 | return outputs
148 |
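Usage sketch (illustrative; not a file in the repository): online Bayesian linear regression with unknown observation noise, driven by the module above. The scan in kalman_filter unpacks each element of x_hist into a (covariates, response) pair, so a tuple of arrays (X, y) is passed below; all model sizes and numeric values are made up.

import jax.numpy as jnp
from jax import random
from jsl.lds.kalman_filter_with_unknown_noise import LDS, kalman_filter

key_x, key_y = random.split(random.PRNGKey(0))
n_obs, dim = 50, 2
true_w = jnp.array([1.0, -2.0])
X = random.normal(key_x, (n_obs, dim))                      # covariates
y = X @ true_w + 0.5 * random.normal(key_y, (n_obs,))       # scalar responses

params = LDS(A=jnp.eye(dim),                                # static regression weights
             C=None, R=None,                                # unused by this module
             Q=jnp.zeros((dim, dim)),
             mu=jnp.zeros(dim), Sigma=10.0 * jnp.eye(dim),  # prior on the weights
             v=jnp.array(0.0), tau=jnp.array(1.0))          # prior on the noise precision

mu_hist, Sigma_hist = kalman_filter(params, (X, y), return_history=True)
print(mu_hist[-1])   # posterior mean of the weights after the last observation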
--------------------------------------------------------------------------------
/jsl/lds/kalman_sampler.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax.random import multivariate_normal, PRNGKey
3 | from jax.scipy.linalg import solve, cholesky
4 | from jax import lax
5 |
6 | from .kalman_filter import LDS
7 |
8 |
9 |
10 | def smooth_sampler(params: LDS,
11 | key: PRNGKey,
12 | mu_hist: jnp.array,
13 | Sigma_hist: jnp.array,
14 | n_samples: jnp.array = 1):
15 | """
16 | Backwards sample from the smoothing distribution
17 | Parameters
18 | ----------
19 | params: LDS
20 | Linear Dynamical System object
21 | key: jax.random.PRNGKey
22 | Seed of state noises
23 | mu_hist: array(timesteps, state_size):
24 | Filtered means mut
25 | Sigma_hist: array(timesteps, state_size, state_size)
26 | Filtered covariances Sigmat
27 | n_samples: int
28 | Number of posterior samples (optional)
29 | Returns
30 | -------
31 | * array(n_samples, timesteps, state_size):
32 | Posterior samples
33 | """
34 | state_size, _ = params.get_trans_mat_of(0).shape
35 | I = jnp.eye(state_size)
36 | timesteps = len(mu_hist)
37 | # Generate all state noise terms
38 | zeros_state = jnp.zeros(state_size)
39 | system_noise = multivariate_normal(key, zeros_state, I, (timesteps, n_samples))
40 | state_T = mu_hist[-1] + system_noise[-1] @ Sigma_hist[-1, ...].T
41 |
42 | def smooth_sample_step(state, inps):
43 | system_noise_t, mutt, Sigmatt , t = inps
44 | A = params.get_trans_mat_of(t)
45 | et = state - mutt @ A.T
46 | St = A @ Sigmatt @ A.T + params.get_system_noise_of(t)
47 | Kt = solve(St, A @ Sigmatt, sym_pos=True).T
48 | mu_t = mutt + et @ Kt.T
49 | Sigma_t = (I - Kt @ A) @ Sigmatt
50 | Sigma_root = cholesky(Sigma_t)
51 | state_new = mu_t + system_noise_t @ Sigma_root.T
52 | return state_new, state_new
53 |
54 | inps = (system_noise[-2::-1, ...], mu_hist[-2::-1, ...], Sigma_hist[-2::-1, ...], jnp.arange(1, timesteps)[::-1])
55 | _, state_sample_smooth = lax.scan(smooth_sample_step, state_T, inps)
56 |
57 | state_sample_smooth = jnp.concatenate([state_sample_smooth[::-1, ...], state_T[None, ...]], axis=0)
58 | state_sample_smooth = jnp.swapaxes(state_sample_smooth, 0, 1)
59 |
60 | if n_samples == 1:
61 | state_sample_smooth = state_sample_smooth[:, 0, :]
62 | return state_sample_smooth
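Usage sketch (illustrative; not a file in the repository): drawing joint posterior trajectory samples with smooth_sampler. It assumes the LDS and kalman_filter API from jsl/lds/kalman_filter.py, as exercised in the test file above; all parameter values are made up.

import jax.numpy as jnp
from jax import random
from jsl.lds.kalman_filter import LDS, kalman_filter
from jsl.lds.kalman_sampler import smooth_sampler

state_size = 2
A, C = jnp.eye(state_size), jnp.eye(state_size)
Q, R = 0.1 * jnp.eye(state_size), 0.5 * jnp.eye(state_size)
mu0, Sigma0 = jnp.zeros(state_size), jnp.eye(state_size)

lds = LDS(A, C, Q, R, mu0, Sigma0)
key_sim, key_sample = random.split(random.PRNGKey(0))
z_hist, x_hist = lds.sample(key_sim, 30)                 # latent states and observations

mu_hist, Sigma_hist, *_ = kalman_filter(lds, x_hist)     # filtered means and covariances
samples = smooth_sampler(lds, key_sample, mu_hist, Sigma_hist, n_samples=5)
# samples has shape (5, 30, state_size): five draws from p(z_{1:T} | x_{1:T})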
--------------------------------------------------------------------------------
/jsl/lds/mixture_kalman_filter.py:
--------------------------------------------------------------------------------
1 | # Mixture Kalman Filter library. Also known as the
2 | # Rao-Blackwell Particle Filter.
3 |
4 | # Author: Gerardo Durán-Martín (@gerdm)
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jax import random
9 | from jax.scipy.special import logit
10 | from dataclasses import dataclass
11 |
12 |
13 | @dataclass
14 | class RBPFParamsDiscrete:
15 | """
16 | Rao-Blackwell Particle Filtering (RBPF) parameters for
17 | a system with discrete latent-space.
18 | We assume that the system evolves as
19 | z_next = A * z_old + B(u_old) + noise1_next
20 | x_next = C * z_next + noise2_next
21 | u_next ~ transition_matrix(u_old)
22 |
23 | where
24 | noise1_next ~ N(0, Q)
25 | noise2_next ~ N(0, R)
26 | """
27 | A: jnp.array
28 | B: jnp.array
29 | C: jnp.array
30 | Q: jnp.array
31 | R: jnp.array
32 | transition_matrix: jnp.array
33 |
34 |
35 | def draw_state(val, key, params):
36 | """
37 | Simulate one step of a system that evolves as
 38 |     A z_{t-1} + B_k + eps,
39 | where eps ~ N(0, Q).
40 |
41 | Parameters
42 | ----------
43 | val: tuple (int, jnp.array)
44 | (latent value of system, state value of system).
 45 |     params: RBPFParamsDiscrete
46 | key: PRNGKey
47 | """
48 | latent_old, state_old = val
49 | probabilities = params.transition_matrix[latent_old, :]
50 | logits = logit(probabilities)
51 | latent_new = random.categorical(key, logits)
52 |
53 | key_latent, key_obs = random.split(key)
54 | state_new = params.A @ state_old + params.B[latent_new, :]
55 | state_new = random.multivariate_normal(key_latent, state_new, params.Q)
56 | obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R)
57 |
58 | return (latent_new, state_new), (latent_new, state_new, obs_new)
59 |
60 |
61 | def kf_update(mu_t, Sigma_t, k, xt, params):
62 | I = jnp.eye(len(mu_t))
63 | mu_t_cond = params.A @ mu_t + params.B[k]
64 | Sigma_t_cond = params.A @ Sigma_t @ params.A.T + params.Q
65 | xt_cond = params.C @ mu_t_cond
66 | St = params.C @ Sigma_t_cond @ params.C.T + params.R
67 |
68 | Kt = Sigma_t_cond @ params.C.T @ jnp.linalg.inv(St)
69 |
70 | # Estimation update
71 | mu_t = mu_t_cond + Kt @ (xt - xt_cond)
72 | Sigma_t = (I - Kt @ params.C) @ Sigma_t_cond
73 |
74 | # Normalisation constant
75 | mean_norm = params.C @ mu_t_cond
76 | cov_norm = params.C @ Sigma_t_cond @ params.C.T + params.R
77 | Ltk = jax.scipy.stats.multivariate_normal.pdf(xt, mean_norm, cov_norm)
78 |
79 | return mu_t, Sigma_t, Ltk
80 |
81 |
82 | def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params):
83 | log_p_next = logit(params.transition_matrix[st])
84 | k = random.categorical(key, log_p_next)
85 | mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params)
86 | weight_t = weight_t * Ltk
87 |
88 | return mu_t, Sigma_t, weight_t, Ltk
89 |
90 |
91 | kf_update_vmap = jax.vmap(kf_update, in_axes=(None, None, 0, None, None), out_axes=0)
92 |
93 |
94 | def rbpf_step_optimal(key, weight_t, st, mu_t, Sigma_t, xt, params):
95 | k = jnp.arange(len(params.transition_matrix))
96 | mu_tk, Sigma_tk, Ltk = kf_update_vmap(mu_t, Sigma_t, k, xt, params)
97 |
98 | proposal = Ltk * params.transition_matrix[st]
99 |
100 | weight_tk = weight_t * proposal.sum()
101 | proposal = proposal / proposal.sum()
102 |
103 | return mu_tk, Sigma_tk, weight_tk, proposal
104 |
105 |
106 | # vectorised RBPF step
107 | rbpf_step_vec = jax.vmap(rbpf_step, in_axes=(0, 0, 0, 0, 0, None, None))
108 | # vectorised RBPF step using the optimal proposal
109 | rbpf_step_optimal_vec = jax.vmap(rbpf_step_optimal, in_axes=(0, 0, 0, 0, 0, None, None))
110 |
111 |
112 | def rbpf(current_config, xt, params, nparticles=100):
113 | """
114 | Rao-Blackwell Particle Filter using prior as proposal
115 | """
116 | key, mu_t, Sigma_t, weights_t, st = current_config
117 |
118 | key_sample, key_state, key_next, key_reindex = random.split(key, 4)
119 | keys = random.split(key_sample, nparticles)
120 |
121 | st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
122 | mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params)
123 | weights_t = weights_t / weights_t.sum()
124 |
125 | indices = jnp.arange(nparticles)
126 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True)
127 | st = st[pi]
128 | mu_t = mu_t[pi, ...]
129 | Sigma_t = Sigma_t[pi, ...]
130 | weights_t = jnp.ones(nparticles) / nparticles
131 |
132 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, Ltk)
133 |
134 |
135 | def rbpf_optimal(current_config, xt, params, nparticles=100):
136 | """
137 | Rao-Blackwell Particle Filter using optimal proposal
138 | """
139 | key, mu_t, Sigma_t, weights_t, st = current_config
140 |
141 | key_sample, key_state, key_next, key_reindex = random.split(key, 4)
142 | keys = random.split(key_sample, nparticles)
143 |
144 | st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
145 | mu_t, Sigma_t, weights_t, proposal = rbpf_step_optimal_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params)
146 |
147 | indices = jnp.arange(nparticles)
148 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True)
149 |
150 | # Obtain optimal proposal distribution
151 | proposal_samp = proposal[pi, :]
152 | st = random.categorical(key, logit(proposal_samp))
153 |
154 | mu_t = mu_t[pi, st, ...]
155 | Sigma_t = Sigma_t[pi, st, ...]
156 |
157 | weights_t = jnp.ones(nparticles) / nparticles
158 |
159 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp)
160 |
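Usage sketch (illustrative; not a file in the repository): running the prior-proposal RBPF above over a sequence of observations with lax.scan. The switching-LDS parameters, the particle initialisation, and the fake observations are all made up for demonstration.

import jax
import jax.numpy as jnp
from functools import partial
from jax import random
from jsl.lds.mixture_kalman_filter import RBPFParamsDiscrete, rbpf

nparticles, state_dim, obs_dim, nmodes = 100, 2, 2, 3
params = RBPFParamsDiscrete(
    A=0.9 * jnp.eye(state_dim),
    B=jnp.zeros((nmodes, state_dim)),                    # per-mode offsets
    C=jnp.eye(obs_dim),
    Q=0.1 * jnp.eye(state_dim),
    R=0.1 * jnp.eye(obs_dim),
    transition_matrix=jnp.full((nmodes, nmodes), 1.0 / nmodes),
)

key_init, key_obs, key_run = random.split(random.PRNGKey(0), 3)
x_hist = random.normal(key_obs, (20, obs_dim))           # fake observation sequence

init_config = (
    key_run,
    jnp.zeros((nparticles, state_dim)),                  # per-particle means
    jnp.tile(jnp.eye(state_dim), (nparticles, 1, 1)),    # per-particle covariances
    jnp.ones(nparticles) / nparticles,                   # particle weights
    random.categorical(key_init, jnp.zeros(nmodes), shape=(nparticles,)),  # discrete modes
)

rbpf_partial = partial(rbpf, params=params, nparticles=nparticles)
_, (mu_hist, Sigma_hist, weights_hist, modes_hist, Ltk_hist) = jax.lax.scan(
    rbpf_partial, init_config, x_hist)
state_estimates = mu_hist.mean(axis=1)   # per-step posterior mean (weights are uniform after resampling)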
--------------------------------------------------------------------------------
/jsl/nlds/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/JSL/649c1fa9709c9195d80f06d7a367112d67dc60c9/jsl/nlds/__init__.py
--------------------------------------------------------------------------------
/jsl/nlds/base.py:
--------------------------------------------------------------------------------
1 | # Library of nonlinear dynamical systems
2 | # Usage: Every discrete xKF class inherits from NLDS.
3 | # There are two ways to use this library in the discrete case:
4 | # 1) Explicitly initialize a discrete NLDS object with the desired parameters,
5 | # then pass it onto the xKF class of your choice.
6 | # 2) Initialize the xKF object with the desired NLDS parameters using
7 | # the .from_base constructor.
8 | # Way 1 is preferable whenever you want to use the same NLDS for multiple
9 | # filtering processes. Way 2 is preferred whenever you want to use a single NLDS
10 | # for a single filtering process
11 |
12 | # Author: Gerardo Durán-Martín (@gerdm)
13 |
14 | import jax
15 | from jax.random import split, multivariate_normal
16 |
17 | import chex
18 |
19 | from dataclasses import dataclass
20 | from typing import Callable
21 |
22 |
23 | @dataclass
24 | class NLDS:
25 | """
26 | Base class for the nonlinear dynamical systems' module
27 |
28 | Parameters
29 | ----------
30 | fz: function
31 | Nonlinear state transition function
32 | fx: function
33 | Nonlinear observation function
34 | Q: array(state_size, state_size) or function
35 | Nonlinear state transition noise covariance function
36 | R: array(obs_size, obs_size) or function
37 | Nonlinear observation noise covariance function
38 | """
39 | fz: Callable
40 | fx: Callable
41 | Q: chex.Array
42 | R: chex.Array
43 | alpha: float = 0.
44 | beta: float = 0.
45 | kappa: float = 0.
46 | d: int = 0
47 |
48 | def Qz(self, z, *args):
49 | if callable(self.Q):
50 | return self.Q(z, *args)
51 | else:
52 | return self.Q
53 |
54 | def Rx(self, x, *args):
55 | if callable(self.R):
56 | return self.R(x, *args)
57 | else:
58 | return self.R
59 |
60 | def __sample_step(self, input_vals, obs):
61 | key, state_t = input_vals
62 | key_system, key_obs, key = split(key, 3)
63 |
64 | state_t = multivariate_normal(key_system, self.fz(state_t), self.Qz(state_t))
65 | obs_t = multivariate_normal(key_obs, self.fx(state_t, *obs), self.Rx(state_t, *obs))
66 |
67 | return (key, state_t), (state_t, obs_t)
68 |
69 | def sample(self, key, x0, nsteps, obs=None):
70 | """
71 | Sample discrete elements of a nonlinear system
72 | Parameters
73 | ----------
74 | key: jax.random.PRNGKey
75 | x0: array(state_size)
76 | Initial state of simulation
77 | nsteps: int
78 | Total number of steps to sample from the system
79 | obs: None, tuple of arrays
80 | Observed values to pass to fx and R
81 | Returns
82 | -------
83 | * array(nsamples, state_size)
84 | State-space values
85 | * array(nsamples, obs_size)
86 | Observed-space values
87 | """
88 | obs = () if obs is None else obs
89 | state_t = x0.copy()
90 | obs_t = self.fx(state_t)
91 |
92 | self.state_size, *_ = state_t.shape
 93 |         self.obs_size, *_ = obs_t.shape
94 |
95 | init_state = (key, state_t)
96 | _, hist = jax.lax.scan(self.__sample_step, init_state, obs, length=nsteps)
97 |
98 | return hist
99 |
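Usage sketch (illustrative; not a file in the repository): "way 1" above, i.e. constructing an NLDS explicitly and sampling a trajectory from it. The discretised pendulum below and all constants are made up.

import jax.numpy as jnp
from jax import random
from jsl.nlds.base import NLDS

dt = 0.1

def fz(z):
    # discretised pendulum dynamics; z = (angle, angular velocity)
    return jnp.array([z[0] + dt * z[1], z[1] - dt * jnp.sin(z[0])])

def fx(z):
    # observe the sine of the angle only
    return jnp.atleast_1d(jnp.sin(z[0]))

Q = 0.01 * jnp.eye(2)      # state noise covariance
R = 0.02 * jnp.eye(1)      # observation noise covariance
model = NLDS(fz, fx, Q, R)

key = random.PRNGKey(0)
z0 = jnp.array([1.0, 0.0])
state_hist, obs_hist = model.sample(key, z0, nsteps=100)   # shapes (100, 2) and (100, 1)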
--------------------------------------------------------------------------------
/jsl/nlds/bootstrap_filter.py:
--------------------------------------------------------------------------------
1 | """
 2 | Implementation of the Bootstrap Filter for discrete-time systems
3 | **This implementation considers the case of multivariate normals**
4 |
5 |
6 | """
7 | import jax.numpy as jnp
8 | from jax import random, lax
9 |
10 | import chex
11 |
12 | from jax.scipy import stats
13 | from jsl.nlds.base import NLDS
14 |
15 |
16 | # TODO: Extend to general case
17 | def filter(params: NLDS,
18 | key: chex.PRNGKey,
19 | init_state: chex.Array,
20 | sample_obs: chex.Array,
21 | nsamples: int = 2000,
22 | Vinit: chex.Array = None):
23 | """
24 | init_state: array(state_size,)
25 | Initial state estimate
26 | sample_obs: array(nsamples, obs_size)
27 | Samples of the observations
28 | """
29 | m, *_ = init_state.shape
30 |
31 | fx, fz = params.fx, params.fz
32 | Q, R = params.Qz, params.Rx
33 |
34 | key, key_init = random.split(key, 2)
35 | V = Q(init_state) if Vinit is None else Vinit
36 | zt_rvs = random.multivariate_normal(key_init, init_state, V, shape=(nsamples,))
37 |
38 | init_state = (zt_rvs, key)
39 |
40 | def __filter_step(state, obs_t):
41 | indices = jnp.arange(nsamples)
42 | zt_rvs, key_t = state
43 |
44 | key_t, key_reindex, key_next = random.split(key_t, 3)
45 | # 1. Draw new points from the dynamic model
46 | zt_rvs = random.multivariate_normal(key_t, fz(zt_rvs), Q(zt_rvs))
47 |
48 | # 2. Calculate unnormalised weights
49 | xt_rvs = fx(zt_rvs)
50 | weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, R(zt_rvs, obs_t))
51 |
52 | # 3. Resampling
53 | pi = random.choice(key_reindex, indices,
54 | p=weights_t, shape=(nsamples,))
55 | zt_rvs = zt_rvs[pi, ...]
56 | weights_t = jnp.ones(nsamples) / nsamples
57 |
58 | # 4. Compute latent-state estimate,
59 | # Set next covariance state matrix
60 | mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t)
61 |
62 | return (zt_rvs, key_next), mu_t
63 |
64 | _, mu_hist = lax.scan(__filter_step, init_state, sample_obs)
65 |
66 | return mu_hist
67 |
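Usage sketch (illustrative; not a file in the repository): running the bootstrap filter above on data simulated from the same NLDS. Note that fz and fx are applied to batched particle arrays of shape (nsamples, state_size), so they must broadcast over the particle axis, as the elementwise functions below do; all constants are made up.

import jax.numpy as jnp
from jax import random
from jsl.nlds.base import NLDS
from jsl.nlds import bootstrap_filter

def fz(z):
    return 0.9 * z + 0.1 * jnp.sin(z)    # mildly nonlinear dynamics, broadcasts over particles

def fx(z):
    return z                              # identity observation model

Q = 0.05 * jnp.eye(2)
R = 0.10 * jnp.eye(2)
model = NLDS(fz, fx, Q, R)

key_sim, key_filter = random.split(random.PRNGKey(0))
z0 = jnp.array([1.0, -1.0])
_, sample_obs = model.sample(key_sim, z0, nsteps=50)

mu_hist = bootstrap_filter.filter(model, key_filter, z0, sample_obs, nsamples=1000)
# mu_hist has shape (50, 2): posterior-mean state estimate at each step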
--------------------------------------------------------------------------------
/jsl/nlds/continuous_extended_kalman_filter.py:
--------------------------------------------------------------------------------
1 | """
2 | Extended Kalman Filter for a nonlinear continuous time
3 | dynamical system with observations in discrete time.
4 | """
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jax import lax, jacrev
9 |
10 | import chex
11 |
12 | from math import ceil
13 |
14 | from jsl.nlds.base import NLDS
15 |
16 |
17 | def _rk2(x0, f, nsteps, dt):
18 | """
19 | class-independent second-order Runge-Kutta method
20 |
21 | Parameters
22 | ----------
23 | x0: array(state_size, )
24 | Initial state of the system
25 | f: function
26 | Function to integrate. Must return jax.numpy
27 | array of size state_size
28 | nsteps: int
29 | Total number of steps to integrate
30 | dt: float
31 | integration step size
32 |
33 | Returns
34 | -------
35 | array(nsteps, state_size)
36 | Integration history
37 | """
38 | input_dim, *_ = x0.shape
39 |
40 | def step(xt, t):
41 | k1 = f(xt)
42 | k2 = f(xt + dt * k1)
43 | xt = xt + dt * (k1 + k2) / 2
44 | return xt, xt
45 |
46 | steps = jnp.arange(nsteps)
47 | _, simulation = lax.scan(step, x0, steps)
48 |
49 | simulation = jnp.vstack([x0, simulation])
50 | return simulation
51 |
52 |
53 | def sample(key: chex.PRNGKey,
54 | params: NLDS,
55 | x0: chex.Array,
56 | T: float,
57 | nsamples: int,
58 | dt: float = 0.01,
59 | noisy: bool = False):
60 | """
61 | Run the Extended Kalman Filter algorithm. First, we integrate
62 | up to time T, then we obtain nsamples equally-spaced points. Finally,
63 | we transform the latent space to obtain the observations
64 |
65 | Parameters
66 | ----------
67 | key: jax.random.PRNGKey
68 | Initial seed
69 | x0: array(state_size)
70 | Initial state of simulation
71 | T: float
72 | Final time of integration
73 | nsamples: int
74 | Number of observations to take from the total integration
75 | dt: float
76 | integration step size
77 | noisy: bool
78 | Whether to (naively) add noise to the state space
79 |
80 | Returns
81 | -------
82 | * array(nsamples, state_size)
83 | State-space values
84 | * array(nsamples, obs_size)
85 | Observed-space values
86 | * int
87 | Number of observations skipped between one
88 | datapoint and the next
89 | """
90 |
91 | fz, fx = params.fz, params.fx
 92 |     Q, R = params.Q, params.R  # noise covariances are plain arrays in this module
93 |
94 | state_size, _ = Q.shape
95 | obs_size, _ = R.shape
96 |
97 | nsteps = ceil(T / dt)
98 | jump_size = ceil(nsteps / nsamples)
99 | correction = nsamples - ceil(nsteps / jump_size)
100 | nsteps += correction * jump_size
101 |
102 | key_state, key_obs = jax.random.split(key)
103 | state_noise = jax.random.multivariate_normal(key_state, jnp.zeros(state_size), Q, (nsteps,))
104 | obs_noise = jax.random.multivariate_normal(key_obs, jnp.zeros(obs_size), R, (nsteps,))
105 | simulation = _rk2(x0, fz, nsteps, dt)
106 |
107 | if noisy:
108 | simulation = simulation + jnp.sqrt(dt) * state_noise
109 |
110 | sample_state = simulation[::jump_size]
111 | sample_obs = jnp.apply_along_axis(fx, 1, sample_state) + obs_noise[:len(sample_state)]
112 |
113 | return sample_state, sample_obs, jump_size
114 |
115 |
116 | def _Vt_dot(V, G, Q):
117 | return G @ V @ G.T + Q
118 |
119 |
120 | def estimate(params: NLDS,
121 | sample_state: chex.Array,
122 | sample_obs: chex.Array,
123 | jump_size: int,
124 | dt: float,
125 | return_history: bool = True):
126 | """
127 | Run the Extended Kalman Filter algorithm over a set of observed samples.
128 |
129 | Parameters
130 | ----------
131 | sample_state: array(nsamples, state_size)
132 | sample_obs: array(nsamples, obs_size)
133 | jump_size: int
134 | dt: float
135 | return_history: bool
136 |
137 | Returns
138 | -------
139 | * array(nsamples, state_size)
140 | History of filtered mean terms
141 | * array(nsamples, state_size, state_size)
142 | History of filtered covariance terms
143 | """
144 |
145 | fz, fx = params.fz, params.fx
146 |     Q, R = params.Q, params.R  # noise covariances are plain arrays in this module
147 |
148 | Dfz = jacrev(fz)
149 | Dfx = jacrev(fx)
150 |
151 | state_size, _ = Q.shape
152 | obs_size, _ = R.shape
153 |
154 | I = jnp.eye(state_size)
155 | Vt = R.copy()
156 | mu_t = sample_state[0]
157 |
158 | def jump_step(state, t):
159 | mu_t, Vt = state
160 | k1 = fz(mu_t)
161 | k2 = fz(mu_t + dt * k1)
162 | mu_t = mu_t + dt * (k1 + k2) / 2
163 |
164 | Gt = Dfz(mu_t)
165 | k1 = _Vt_dot(Vt, Gt, Q)
166 | k2 = _Vt_dot(Vt + dt * k1, Gt, Q)
167 | Vt = Vt + dt * (k1 + k2) / 2
168 | return (mu_t, Vt), None
169 |
170 | def step(state, obs):
171 | jumps = jnp.arange(jump_size)
172 | (mu, V), _ = lax.scan(jump_step, state, jumps)
173 |
174 | mu_t_cond = mu
175 | Vt_cond = V
176 | Ht = Dfx(mu_t_cond)
177 |
178 | Kt = Vt_cond @ Ht.T @ jnp.linalg.inv(Ht @ Vt_cond @ Ht.T + R)
179 | mu = mu_t_cond + Kt @ (obs - fx(mu_t_cond))
180 | V = (I - Kt @ Ht) @ Vt_cond
181 | return (mu, V), (mu, V)
182 |
183 | initial_state = (mu_t.copy(), Vt.copy())
184 | (mu, V), (mu_hist, V_hist) = lax.scan(step, initial_state, sample_obs[1:])
185 |
186 | if return_history:
187 | mu_hist = jnp.vstack([mu_t, mu_hist])
188 |         V_hist = jnp.vstack([Vt[None, ...], V_hist])
189 | return mu_hist, V_hist
190 |
191 | return mu, V
192 |
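Usage sketch (illustrative; not a file in the repository): the continuous-discrete EKF above on a simple harmonic oscillator, using sample() to simulate and estimate() to filter. All constants are made up.

import jax.numpy as jnp
from jax import random
from jsl.nlds.base import NLDS
from jsl.nlds import continuous_extended_kalman_filter as cekf

def fz(x):
    # continuous-time dynamics dx/dt = f(x): a harmonic oscillator
    return jnp.array([x[1], -x[0]])

def fx(x):
    return x                      # observe the full state

Q = 0.01 * jnp.eye(2)             # process noise covariance
R = 0.05 * jnp.eye(2)             # observation noise covariance
model = NLDS(fz, fx, Q, R)

key = random.PRNGKey(0)
x0 = jnp.array([1.0, 0.0])
sample_state, sample_obs, jump_size = cekf.sample(key, model, x0, T=10.0,
                                                  nsamples=100, dt=0.01, noisy=False)
mu_hist, V_hist = cekf.estimate(model, sample_state, sample_obs, jump_size, dt=0.01)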
--------------------------------------------------------------------------------
/jsl/nlds/diagonal_extended_kalman_filter.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of the Diagonal Extended Kalman Filter for a nonlinear
3 | dynamical system with discrete observations. Also known as the
4 | Node-decoupled Extended Kalman Filter (NDEKF)
5 | """
6 |
7 | import jax.numpy as jnp
8 | from jax import jacrev, lax
9 |
10 | import chex
11 | from typing import Tuple
12 |
13 | from .base import NLDS
14 |
15 |
16 | def filter(params: NLDS,
17 | init_state: chex.Array,
18 | sample_obs: chex.Array,
19 | observations: Tuple = None,
20 | Vinit: chex.Array = None,
21 | return_history: bool = True):
22 | """
23 | Run the Extended Kalman Filter algorithm over a set of observed samples.
24 | Parameters
25 | ----------
26 | init_state: array(state_size)
27 | sample_obs: array(nsamples, obs_size)
28 | Returns
29 | -------
30 | * array(nsamples, state_size)
31 | History of filtered mean terms
32 | * array(nsamples, state_size, state_size)
33 | History of filtered covariance terms
34 | """
35 | state_size, *_ = init_state.shape
36 |
37 | fz, fx = params.fz, params.fx
38 | Q, R = params.Qz, params.Rx
39 | Dfx = jacrev(fx)
40 |
41 | Vt = Q(init_state) if Vinit is None else Vinit
42 |
43 | t = 0
44 | state = (init_state, Vt, t)
45 | observations = (observations,) if type(observations) is not tuple else observations
46 | xs = (sample_obs, observations)
47 |
48 | def filter_step(state: Tuple[chex.Array, chex.Array],
49 | xs: Tuple[chex.Array, int]):
50 | """
51 | Run the Extended Kalman filter algorithm for a single step
 52 |         Parameters
 53 |         ----------
54 | state: tuple
55 | Mean, covariance at time t-1
56 | xs: tuple
57 | Target value and observations at time t
58 | """
59 | mu_t, Vt, t = state
60 | xt, obs = xs
61 |
62 | mu_t_cond = fz(mu_t)
63 | Ht = Dfx(mu_t_cond, *obs)
64 |
65 | Rt = R(mu_t_cond, *obs)
66 | xt_hat = fx(mu_t_cond, *obs)
67 | xi = xt - xt_hat
68 | A = jnp.linalg.inv(Rt + jnp.einsum("id,jd,d->ij", Ht, Ht, Vt))
69 | mu_t = mu_t_cond + jnp.einsum("s,is,ij,j->s", Vt, Ht, A, xi)
70 | Vt = Vt - jnp.einsum("s,is,ij,is,s->s", Vt, Ht, A, Ht, Vt) + Q(mu_t, t)
71 |
72 | return (mu_t, Vt, t + 1), (mu_t, None)
73 |
74 | (mu_t, Vt, _), mu_t_hist = lax.scan(filter_step, state, xs)
75 |
76 | if return_history:
77 | return (mu_t, Vt), mu_t_hist
78 |
79 | return (mu_t, Vt), None
80 |
--------------------------------------------------------------------------------
/jsl/nlds/extended_kalman_filter.py:
--------------------------------------------------------------------------------
1 | import chex
2 | from jax import jacrev, lax
3 | import jax.numpy as jnp
4 | from typing import Dict, List, Tuple, Callable
5 | from functools import partial
6 | from .base import NLDS
7 |
8 | def filter_step(state: Tuple[chex.Array, chex.Array, int],
9 | xs: Tuple[chex.Array, chex.Array],
10 | params: NLDS,
11 | Dfx: Callable,
12 | Dfz: Callable,
13 | eps: float,
14 | return_params: Dict
15 | ) -> Tuple[Tuple[chex.Array, chex.Array, int], Dict]:
16 | """
17 | Run a single step of the extended Kalman filter (EKF) algorithm.
18 |
19 | Parameters
20 | ---------
21 | state: tuple
22 | Mean, covariance at time t-1
23 | xs: tuple
24 | Target value and covariates at time t
25 | params: NLDS
26 | Nonlinear dynamical system parameters
27 | Dfx: Callable
28 | Jacobian of the observation function
29 | Dfz: Callable
30 | Jacobian of the state transition function
31 | eps: float
32 | Small number to prevent singular matrix
33 | return_params: list
 34 |         Elements to store and return at each step (e.g. "mean", "cov", "obs_hat")
35 |
36 | Returns
37 | -------
38 | * tuple
39 | 1. Mean, covariance, and time at time t
40 | 2. History of filtered mean terms (if requested)
41 | """
42 | mu_t, Vt, t = state
43 | obs, inputs = xs
44 |
45 | state_size, *_ = mu_t.shape
46 | I = jnp.eye(state_size)
47 | Gt = Dfz(mu_t)
48 | mu_t_cond = params.fz(mu_t)
49 | Vt_cond = Gt @ Vt @ Gt.T + params.Qz(mu_t, t)
50 | Ht = Dfx(mu_t_cond, *inputs)
51 |
52 | Rt = params.Rx(mu_t_cond, *inputs)
53 | num_inputs, *_ = Rt.shape
54 |
55 | obs_hat = params.fx(mu_t_cond, *inputs)
56 | Mt = Ht @ Vt_cond @ Ht.T + Rt + eps * jnp.eye(num_inputs)
57 | Kt = Vt_cond @ Ht.T @ jnp.linalg.inv(Mt)
58 | mu_t = mu_t_cond + Kt @ (obs - obs_hat)
59 | Vt = (I - Kt @ Ht) @ Vt_cond @ (I - Kt @ Ht).T + Kt @ Rt @ Kt.T
60 |
61 | carry = {"mean": mu_t, "cov": Vt, "obs_hat": obs_hat}
62 | carry = {key: val for key, val in carry.items() if key in return_params}
63 | return (mu_t, Vt, t + 1), carry
64 |
65 |
66 | def filter(params: NLDS,
67 | init_state: chex.Array,
68 | observations: chex.Array,
69 | covariates: chex.Array = None,
70 | Vinit: chex.Array = None,
71 | return_params: List = None,
72 | eps: float = 0.001,
73 | return_history: bool = True):
74 | """
75 | Run the Extended Kalman Filter algorithm over a set of observed samples.
76 |
77 | Parameters
78 | ----------
79 | init_state: array(state_size)
80 | observations: array(nsamples, obs_size)
81 | covariates: array(nsamples, feature_size) or None
82 | optional covariates to pass to the observation function
83 | Vinit: array(state_size, state_size) or None
84 | Initial state covariance matrix
85 | return_params: list
86 | Parameters to carry from the filter step. Possible values are:
 87 |         "mean", "cov", "obs_hat"
88 | return_history: bool
89 | Whether to return the history of mu and sigma obtained at each step
90 |
91 | Returns
92 | -------
93 | * array(nsamples, state_size)
94 | History of filtered mean terms
95 | * array(nsamples, state_size, state_size)
96 | History of filtered covariance terms
97 | """
98 | state_size, *_ = init_state.shape
99 |
100 | fz, fx = params.fz, params.fx
101 | Q, R = params.Qz, params.Rx
102 |
103 | Dfz = jacrev(fz)
104 | Dfx = jacrev(fx)
105 |
106 | Vt = Q(init_state) if Vinit is None else Vinit
107 |
108 | t = 0
109 | state = (init_state, Vt, t)
110 | covariates = (covariates,) if type(covariates) is not tuple else covariates
111 | xs = (observations, covariates)
112 |
113 | return_params = [] if return_params is None else return_params
114 |
115 | filter_step_pass = partial(filter_step, params=params, Dfx=Dfx, Dfz=Dfz,
116 | eps=eps, return_params=return_params)
117 | (mu_t, Vt, _), hist_elements = lax.scan(filter_step_pass, state, xs)
118 |
119 | if return_history:
120 | return (mu_t, Vt), hist_elements
121 |
122 | return (mu_t, Vt), None
123 |
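Usage sketch (illustrative; not a file in the repository): the EKF above on the same discretised pendulum used in the NLDS sampling sketch. Because filter() forwards the (possibly None) covariates to the observation function, fx below accepts and ignores extra positional arguments; all constants are made up.

import jax.numpy as jnp
from jax import random
from jsl.nlds.base import NLDS
from jsl.nlds import extended_kalman_filter as ekf

dt = 0.1

def fz(z):
    return jnp.array([z[0] + dt * z[1], z[1] - dt * jnp.sin(z[0])])

def fx(z, *args):
    return jnp.atleast_1d(jnp.sin(z[0]))   # observe the sine of the angle

Q = 0.01 * jnp.eye(2)
R = 0.02 * jnp.eye(1)
model = NLDS(fz, fx, Q, R)

key = random.PRNGKey(0)
z0 = jnp.array([1.0, 0.0])
_, obs_hist = model.sample(key, z0, nsteps=100)

(mu_last, V_last), hist = ekf.filter(model, z0, obs_hist, return_params=["mean", "cov"])
filtered_means, filtered_covs = hist["mean"], hist["cov"]   # shapes (100, 2) and (100, 2, 2)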
--------------------------------------------------------------------------------
/jsl/nlds/extended_kalman_smoother.py:
--------------------------------------------------------------------------------
1 | # Extended Rauch-Tung-Striebel smoother or Extended Kalman Smoother (EKS)
2 | import jax
3 | import chex
4 | import jax.numpy as jnp
5 | from .base import NLDS
6 | from functools import partial
7 | from typing import Dict, List, Tuple, Callable
8 | from jsl.nlds import extended_kalman_filter as ekf
9 |
10 |
11 | def smooth_step(state: Tuple[chex.Array, chex.Array, int],
12 | xs: Tuple[chex.Array, chex.Array],
13 | params: NLDS,
14 | Dfz: Callable,
15 | eps: float,
16 | return_params: Dict
17 | ) -> Tuple[Tuple[chex.Array, chex.Array, int], Dict]:
18 | mean_next, cov_next, t = state
19 | mean_kf, cov_kf = xs
20 |
21 | mean_next_hat = params.fz(mean_kf)
22 | cov_next_hat = Dfz(mean_kf) @ cov_kf @ Dfz(mean_kf).T + params.Qz(mean_kf, t)
23 | cov_next_hat_eps = cov_next_hat + eps * jnp.eye(mean_next_hat.shape[0])
24 | kalman_gain = jnp.linalg.solve(cov_next_hat_eps, Dfz(mean_kf).T) @ cov_kf
25 |
26 | mean_prev = mean_kf + kalman_gain @ (mean_next - mean_next_hat)
27 | cov_prev = cov_kf + kalman_gain @ (cov_next - cov_next_hat) @ kalman_gain.T
28 |
29 | prev_state = (mean_prev, cov_prev, t-1)
30 | carry = {"mean": mean_prev, "cov": cov_prev}
31 | carry = {key: val for key, val in carry.items() if key in return_params}
32 |
33 | return prev_state, carry
34 |
35 |
36 | def smooth(params: NLDS,
37 | init_state: chex.Array,
38 | observations: chex.Array,
39 | covariates: chex.Array = None,
40 | Vinit: chex.Array = None,
41 | return_params: List = None,
42 | eps: float = 0.001,
43 | return_filter_history: bool = False,
44 | ) -> Dict[str, Dict[str, chex.Array]]:
45 |
46 | kf_params = ["mean", "cov"]
47 | Dfz = jax.jacrev(params.fz)
48 | _, hist_filter = ekf.filter(params, init_state, observations, covariates, Vinit,
49 | return_params=kf_params, eps=eps, return_history=True)
50 | kf_hist_mean, kf_hist_cov = hist_filter["mean"], hist_filter["cov"]
51 | kf_last_mean, kf_hist_mean = kf_hist_mean[-1], kf_hist_mean[:-1]
52 | kf_last_cov, kf_hist_cov = kf_hist_cov[-1], kf_hist_cov[:-1]
53 |
 54 |     return_params = kf_params if return_params is None else return_params
 55 |     smooth_step_partial = partial(smooth_step, params=params, Dfz=Dfz, eps=eps, return_params=return_params)
56 |
57 | init_state = (kf_last_mean, kf_last_cov, len(kf_hist_mean) - 1)
58 | xs = (kf_hist_mean, kf_hist_cov)
59 | _, hist_smooth = jax.lax.scan(smooth_step_partial, init_state, xs, reverse=True)
60 |
61 | hist = {
62 | "smooth": hist_smooth,
63 | "filter": hist_filter if return_filter_history else None
64 | }
65 |
66 | return hist
67 |
--------------------------------------------------------------------------------
/jsl/nlds/sequential_monte_carlo.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.stats import norm
4 |
5 | class NonMarkovianSequenceModel:
6 | """
7 | Non-Markovian Gaussian Sequence Model
8 | """
9 | def __init__(self, phi, beta, q, r):
10 | """
11 | Parameters
12 | ----------
13 | phi: float
14 | Multiplicative effect in latent-space
15 | beta: float
 16 |             Decay rate in observed-space
17 | q: float
18 | Variance in latent-space
19 | r: float
20 | Variance in observed-space
21 | """
22 | self.phi = phi
23 | self.beta = beta
24 | self.q = q
25 | self.r = r
26 |
27 | @staticmethod
28 | def _obtain_weights(log_weights):
29 | weights = jnp.exp(log_weights - jax.nn.logsumexp(log_weights))
30 | return weights
31 |
32 | def sample_latent_step(self, key, x_prev):
33 | x_next = jax.random.normal(key) * jnp.sqrt(self.q) + self.phi * x_prev
34 | return x_next
35 |
36 | def sample_observed_step(self, key, mu, x_curr):
37 | mu_next = self.beta * mu + x_curr
38 | y_curr = jax.random.normal(key) * jnp.sqrt(self.r) + mu_next
39 | return y_curr, mu_next
40 |
41 | def sample_step(self, key, x_prev, mu_prev):
42 | key_latent, key_obs = jax.random.split(key)
43 | x_curr = self.sample_latent_step(key_latent, x_prev)
44 | y_curr, mu = self.sample_observed_step(key_obs, mu_prev, x_curr)
45 |
46 | carry_vals = {"x": x_curr, "y": y_curr}
47 | return (x_curr, mu), carry_vals
48 |
49 | def sample_single(self, key, nsteps):
50 | """
51 | Sample a single path from the non-Markovian Gaussian state-space model.
52 |
53 | Parameters
54 | ----------
55 | key: jax.random.PRNGKey
56 | Initial seed
57 |
58 | """
59 | key_init, key_simul = jax.random.split(key)
60 | x_init = jax.random.normal(key_init) * jnp.sqrt(self.q)
61 | mu_init = 0
62 |
63 | keys = jax.random.split(key_simul, nsteps)
64 | carry_init = (x_init, mu_init)
65 | _, hist = jax.lax.scan(lambda carry, key: self.sample_step(key, *carry), carry_init, keys)
66 | return hist
67 |
68 | def sample(self, key, nsteps, nsims=1):
69 | """
70 | Sample from a non-Markovian Gaussian state-space model.
71 |
72 | Parameters
73 | ----------
74 | key: jax.random.PRNGKey
75 | Initial key to perform the simulation.
76 | nsteps: int
77 | Total number of steps to sample.
78 | nsims: int
79 | Number of paths to sample.
80 | """
81 | key_simulations = jax.random.split(key, nsims)
82 | sample_vmap = jax.vmap(self.sample_single, (0, None))
83 |
84 | simulations = sample_vmap(key_simulations, nsteps)
85 |
86 | # convert to one-dimensional array if only one simulation is
87 | # required
88 | if nsims == 1:
89 | for key, values in simulations.items():
90 | simulations[key] = values.ravel()
91 |
92 | return simulations
93 |
94 | def _sis_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs):
95 | """
96 | Compute one step of the sequential-importance-sampling algorithm
97 | at time t.
98 |
99 | Parameters
100 | ----------
101 | key: jax.random.PRNGKey
102 | key to sample particle.
103 | mu_prev: array(n_particles)
104 |             Term carrying the accumulated past values.
105 |         xparticles_prev: array(n_particles)
106 | Samples / particles from the latent space at t-1
107 | yobs: float
108 | Observation at time t.
109 | """
110 |
111 | key_particles = jax.random.split(key, len(xparticles_prev))
112 |
113 | # 1. Sample from proposal
114 | xparticles = jax.vmap(self.sample_latent_step)(key_particles, xparticles_prev)
115 | # 2. Evaluate unnormalised weights
116 | # 2.1 Compute new mean
117 | mu = self.beta * mu_prev + xparticles
118 | # 2.2 Compute log-unnormalised weights
119 |
120 |         log_weights = log_weights_prev + norm.logpdf(yobs, loc=mu, scale=jnp.sqrt(self.r))  # observation noise has variance r
121 | dict_carry = {
122 | "log_weights": log_weights,
123 | "particles": xparticles,
124 | }
125 |
126 | return (log_weights, mu, xparticles), dict_carry
127 |
128 | def sequential_importance_sample(self, key, observations, n_particles=10):
129 | """
130 | Apply sequential importance sampling (SIS) to a series of observations. Sampling
131 | considers the transition distribution as the proposal.
132 |
133 | Parameters
134 | ----------
135 | key: jax.random.PRNGKey
136 | Initial key.
137 | observations: array(n_observations)
138 | one-array of observed values.
139 | n_particles: int (default: 10)
140 | Total number of particles to consider in the SIS filter.
141 | """
142 | T = len(observations)
143 | key, key_init_particles = jax.random.split(key)
144 | keys = jax.random.split(key, T)
145 |
146 |         init_log_weights = jnp.zeros(n_particles)  # equiv. ∀n.wn=1.0
147 |         init_mu = jnp.zeros(n_particles)
148 | init_xparticles = jax.random.normal(key_init_particles, shape=(n_particles,)) * jnp.sqrt(self.q)
149 |
150 | carry_init = (init_log_weights, init_mu, init_xparticles)
151 | xs_tuple = (keys, observations)
152 | _, dict_hist = jax.lax.scan(lambda carry, xs: self._sis_step(xs[0], *carry, xs[1]), carry_init, xs_tuple)
153 | dict_hist["weights"] = jnp.exp(dict_hist["log_weights"] - jax.nn.logsumexp(dict_hist["log_weights"], axis=1, keepdims=True))
154 |
155 | return dict_hist
156 |
157 | def _smc_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs):
158 | n_particles = len(xparticles_prev)
159 | key, key_particles = jax.random.split(key)
160 | key_particles = jax.random.split(key_particles, n_particles)
161 |
162 | # 1. Resample particles
163 | weights = self._obtain_weights(log_weights_prev)
164 | ix_sampled = jax.random.choice(key, n_particles, p=weights, shape=(n_particles,))
165 | xparticles_prev_sampled = xparticles_prev[ix_sampled]
166 | mu_prev_sampled = mu_prev[ix_sampled]
167 | # 2. Propagate particles
168 | xparticles = jax.vmap(self.sample_latent_step)(key_particles, xparticles_prev_sampled)
169 | # 3. Concatenate
170 | mu = self.beta * mu_prev_sampled + xparticles
171 |
172 | # ToDo: return dictionary of log_weights and sampled indices
173 |         log_weights = norm.logpdf(yobs, loc=mu, scale=jnp.sqrt(self.r))  # observation noise has variance r
174 | dict_carry = {
175 | "log_weights": log_weights,
176 | "indices": ix_sampled,
177 | "particles": xparticles,
178 | }
179 |         return (log_weights, mu, xparticles), dict_carry  # carry the propagated particles forward
180 |
181 | def sequential_monte_carlo(self, key, observations, n_particles=10):
182 | """
183 |         Apply sequential Monte Carlo (SMC), a.k.a. sequential importance resampling (SIR),
184 |         a.k.a. sequential importance sampling and resampling (SISR).
185 | """
186 | T = len(observations)
187 | key, key_particle_init = jax.random.split(key)
188 | keys = jax.random.split(key, T)
189 |
190 | init_xparticles = jax.random.normal(key_particle_init, shape=(n_particles,)) * jnp.sqrt(self.q)
191 | init_log_weights = jnp.zeros(n_particles) # equiv. ∀n.wn=1.0
192 | init_mu = jnp.zeros(n_particles)
193 |
194 | carry_init = (init_log_weights, init_mu, init_xparticles)
195 | xs_tuple = (keys, observations)
196 | _, dict_hist = jax.lax.scan(lambda carry, xs: self._smc_step(xs[0], *carry, xs[1]), carry_init, xs_tuple)
197 | # transform log-unnormalised weights to weights
198 | dict_hist["weights"] = jnp.exp(dict_hist["log_weights"] - jax.nn.logsumexp(dict_hist["log_weights"], axis=1, keepdims=True))
199 |
200 | return dict_hist
201 |
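Usage sketch (illustrative; not a file in the repository): sampling a path from the non-Markovian sequence model above and comparing SIS with SMC on the simulated observations; the particle-averaged estimates at the end are simple diagnostics. All parameter values are made up.

import jax
from jsl.nlds.sequential_monte_carlo import NonMarkovianSequenceModel

model = NonMarkovianSequenceModel(phi=0.9, beta=0.5, q=1.0, r=1.0)
key_sample, key_sis, key_smc = jax.random.split(jax.random.PRNGKey(0), 3)

sims = model.sample(key_sample, nsteps=100)      # dict with keys "x" and "y"
observations = sims["y"]

sis_hist = model.sequential_importance_sample(key_sis, observations, n_particles=500)
smc_hist = model.sequential_monte_carlo(key_smc, observations, n_particles=500)

# weighted particle averages as point estimates of the latent path
sis_mean = (sis_hist["weights"] * sis_hist["particles"]).sum(axis=1)
smc_mean = (smc_hist["weights"] * smc_hist["particles"]).sum(axis=1)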
--------------------------------------------------------------------------------
/jsl/nlds/unscented_kalman_filter.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of the Unscented Kalman Filter for discrete time systems
3 | """
4 |
5 | import jax.numpy as jnp
6 | from jax.lax import scan
7 |
8 | import chex
9 | from typing import List
10 |
11 | from .base import NLDS
12 |
13 |
14 | def sqrtm(M):
15 | """
 16 |     Compute the matrix square root of a Hermitian
 17 |     matrix M, i.e., R such that RR = M
18 |
19 | Parameters
20 | ----------
21 | M: array(m, m)
22 | Hermitian matrix
23 |
24 | Returns
25 | -------
26 | array(m, m): square-root matrix
27 | """
28 | evals, evecs = jnp.linalg.eigh(M)
29 | R = evecs @ jnp.sqrt(jnp.diag(evals)) @ jnp.linalg.inv(evecs)
30 | return R
31 |
32 |
33 | def filter(params: NLDS,
34 | init_state: chex.Array,
35 | sample_obs: chex.Array,
36 | observations: List = None,
37 | Vinit: chex.Array = None,
38 | return_history: bool = True):
39 | """
40 | Run the Unscented Kalman Filter algorithm over a set of observed samples.
41 | Parameters
42 | ----------
43 | sample_obs: array(nsamples, obs_size)
44 | return_history: bool
45 | Whether to return the history of mu and Sigma values.
46 | Returns
47 | -------
48 | * array(nsamples, state_size)
49 | History of filtered mean terms
50 | * array(nsamples, state_size, state_size)
51 | History of filtered covariance terms
52 | """
53 | alpha = params.alpha
54 | beta = params.beta
55 | kappa = params.kappa
56 | d = params.d
57 |
58 | fx, fz = params.fx, params.fz
59 | Q, R = params.Qz, params.Rx
60 |
61 | lmbda = alpha ** 2 * (d + kappa) - d
62 | gamma = jnp.sqrt(d + lmbda)
63 |
64 | wm_vec = jnp.array([1 / (2 * (d + lmbda)) if i > 0
65 | else lmbda / (d + lmbda)
66 | for i in range(2 * d + 1)])
67 | wc_vec = jnp.array([1 / (2 * (d + lmbda)) if i > 0
68 | else lmbda / (d + lmbda) + (1 - alpha ** 2 + beta)
69 | for i in range(2 * d + 1)])
70 | nsteps, *_ = sample_obs.shape
71 | initial_mu_t = init_state
72 | initial_Sigma_t = Q(init_state) if Vinit is None else Vinit
73 |
74 | if observations is None:
75 | observations = iter([()] * nsteps)
76 | else:
77 | observations = iter([(obs,) for obs in observations])
78 |
79 | def filter_step(params, sample_observation):
80 | mu_t, Sigma_t = params
81 | observation = next(observations)
82 |
83 | # TO-DO: use jax.scipy.linalg.sqrtm when it gets added to lib
84 | comp1 = mu_t[:, None] + gamma * sqrtm(Sigma_t)
85 | comp2 = mu_t[:, None] - gamma * sqrtm(Sigma_t)
86 | # sigma_points = jnp.c_[mu_t, comp1, comp2]
87 | sigma_points = jnp.concatenate((mu_t[:, None], comp1, comp2), axis=1)
88 |
89 | z_bar = fz(sigma_points)
90 | mu_bar = z_bar @ wm_vec
91 | Sigma_bar = (z_bar - mu_bar[:, None])
92 | Sigma_bar = jnp.einsum("i,ji,ki->jk", wc_vec, Sigma_bar, Sigma_bar) + Q(mu_t)
93 |
94 | Sigma_bar_half = sqrtm(Sigma_bar)
95 | comp1 = mu_bar[:, None] + gamma * Sigma_bar_half
96 | comp2 = mu_bar[:, None] - gamma * Sigma_bar_half
97 | # sigma_points = jnp.c_[mu_bar, comp1, comp2]
98 | sigma_points = jnp.concatenate((mu_bar[:, None], comp1, comp2), axis=1)
99 |
100 | x_bar = fx(sigma_points, *observation)
101 | x_hat = x_bar @ wm_vec
102 | St = x_bar - x_hat[:, None]
103 | St = jnp.einsum("i,ji,ki->jk", wc_vec, St, St) + R(mu_t, *observation)
104 |
105 | mu_hat_component = z_bar - mu_bar[:, None]
106 | x_hat_component = x_bar - x_hat[:, None]
107 | Sigma_bar_y = jnp.einsum("i,ji,ki->jk", wc_vec, mu_hat_component, x_hat_component)
108 | Kt = Sigma_bar_y @ jnp.linalg.inv(St)
109 |
110 | mu_t = mu_bar + Kt @ (sample_observation - x_hat)
111 | Sigma_t = Sigma_bar - Kt @ St @ Kt.T
112 |
113 | return (mu_t, Sigma_t), (mu_t, Sigma_t)
114 |
115 | (mu, Sigma), (mu_hist, Sigma_hist) = scan(filter_step, (initial_mu_t, initial_Sigma_t), sample_obs[1:])
116 |
117 | mu_hist = jnp.vstack([initial_mu_t[None, ...], mu_hist])
118 | Sigma_hist = jnp.vstack([initial_Sigma_t[None, ...], Sigma_hist])
119 |
120 | if return_history:
121 | return mu_hist, Sigma_hist
122 | return mu, Sigma
123 |
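Usage sketch (illustrative; not a file in the repository): the UKF above on a simple decoupled nonlinear system. Note that fz and fx are applied to a (d, 2d+1) matrix whose columns are sigma points, so they must act column-wise, as the elementwise functions below do. The alpha/beta/kappa/d values are conventional choices, not taken from the repository.

import jax.numpy as jnp
from jax import random
from jsl.nlds.base import NLDS
from jsl.nlds import unscented_kalman_filter as ukf

def fz(z):
    return 0.9 * z + 0.1 * jnp.sin(z)   # elementwise nonlinear dynamics

def fx(z):
    return z                             # identity observation model

state_size = 2
Q = 0.05 * jnp.eye(state_size)
R = 0.10 * jnp.eye(state_size)
model = NLDS(fz, fx, Q, R, alpha=1.0, beta=0.0, kappa=2.0, d=state_size)

key = random.PRNGKey(0)
z0 = jnp.array([1.0, -1.0])
_, sample_obs = model.sample(key, z0, nsteps=50)

mu_hist, Sigma_hist = ukf.filter(model, z0, sample_obs)   # shapes (50, 2) and (50, 2, 2)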
--------------------------------------------------------------------------------
/jsl/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name="snb", version="0.0.1", install_requires=["gym"] )
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="jsl",
5 | packages=find_packages(),
6 | install_requires=[
7 | "chex",
8 | "dataclasses",
9 | "jaxlib",
10 | "jax",
11 | "matplotlib",
12 | "tensorflow_probability"
13 | ]
14 | )
15 |
--------------------------------------------------------------------------------