├── .gitattributes ├── .github └── workflows │ ├── build_documentation.yml │ ├── interrogate.yml │ ├── publish-to-pypi.yml │ ├── publish_documentation.yml │ └── run_tests.yml ├── .gitignore ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── README.md ├── api.rst ├── conf.py ├── figures │ ├── .gitignore │ ├── LDS-UZY.png │ ├── SSM-simplified.png │ ├── casino.png │ ├── hmmDgmPlatesY.png │ ├── lgssm_parallel_smoothing_timing.png │ ├── pendulum.png │ ├── rlsDgm.png │ └── rlsDgmW.png ├── index.rst ├── make.bat ├── notebooks │ ├── generalized_gaussian_ssm │ │ ├── cmgf_logistic_regression_demo.ipynb │ │ ├── cmgf_mlp_classification_demo.ipynb │ │ └── cmgf_poisson_demo.ipynb │ ├── hmm │ │ ├── autoregressive_hmm.ipynb │ │ ├── casino_hmm_inference.ipynb │ │ ├── casino_hmm_learning.ipynb │ │ ├── custom_hmm.ipynb │ │ └── gaussian_hmm.ipynb │ ├── linear_gaussian_ssm │ │ ├── kf_linreg.ipynb │ │ ├── kf_tracking.ipynb │ │ ├── lgssm_hmc.ipynb │ │ ├── lgssm_learning.ipynb │ │ └── lgssm_parallel_inference.ipynb │ ├── nonlinear_gaussian_ssm │ │ ├── ekf_mlp.ipynb │ │ ├── ekf_ukf_pendulum.ipynb │ │ └── ekf_ukf_spiral.ipynb │ └── slds │ │ └── rbpf_maneuver.ipynb └── types.md ├── dynamax ├── __init__.py ├── _version.py ├── generalized_gaussian_ssm │ ├── README.md │ ├── __init__.py │ ├── dekf │ │ └── README.md │ ├── demos │ │ ├── __init__.py │ │ ├── cmgf_logreg_estimator.py │ │ ├── cmgf_multiclass_logreg_demo.ipynb │ │ └── dirichlet_kalman-filter_demo.ipynb │ ├── inference.py │ ├── inference_test.py │ ├── models.py │ └── models_test.py ├── hidden_markov_model │ ├── README.md │ ├── __init__.py │ ├── demos │ │ ├── bach_chorales_hmm.ipynb │ │ ├── bernoulli_hmm_example.ipynb │ │ ├── categorical_glm_hmm_demo.py │ │ ├── fixed_lag_smoother_hmm.ipynb │ │ ├── low_rank_gaussian_hmm.ipynb │ │ ├── multinomial_hmm.ipynb │ │ ├── parallel_message_passing.ipynb │ │ ├── poisson_hmm_changepoint.ipynb │ │ ├── poisson_hmm_earthquakes.py │ │ ├── poisson_hmm_neurons.ipynb │ │ └── switching_linear_regression.ipynb │ ├── inference.py │ ├── inference_test.py │ ├── models │ │ ├── __init__.py │ │ ├── abstractions.py │ │ ├── arhmm.py │ │ ├── bernoulli_hmm.py │ │ ├── categorical_glm_hmm.py │ │ ├── categorical_hmm.py │ │ ├── gamma_hmm.py │ │ ├── gaussian_hmm.py │ │ ├── gmm_hmm.py │ │ ├── initial.py │ │ ├── linreg_hmm.py │ │ ├── logreg_hmm.py │ │ ├── multinomial_hmm.py │ │ ├── poisson_hmm.py │ │ ├── test_models.py │ │ └── transitions.py │ └── parallel_inference.py ├── linear_gaussian_ssm │ ├── README.md │ ├── __init__.py │ ├── demos │ │ ├── __init__.py │ │ └── lgssm_blocked_gibbs.ipynb │ ├── inference.py │ ├── inference_test.py │ ├── info_inference.py │ ├── info_inference_test.py │ ├── models.py │ ├── models_test.py │ ├── parallel_inference.py │ └── parallel_inference_test.py ├── nonlinear_gaussian_ssm │ ├── README.md │ ├── __init__.py │ ├── inference_ekf.py │ ├── inference_ekf_test.py │ ├── inference_test_utils.py │ ├── inference_ukf.py │ ├── inference_ukf_test.py │ ├── models.py │ └── sarkka_lib.py ├── parameters.py ├── parameters_test.py ├── slds │ ├── __init__.py │ ├── inference.py │ ├── inference_test.py │ ├── mixture_kalman_filter_demo.py │ └── models.py ├── ssm.py ├── types.py ├── utils │ ├── __init__.py │ ├── bijectors.py │ ├── distributions.py │ ├── distributions_test.py │ ├── optimize.py │ ├── plotting.py │ ├── test_optimize.py │ ├── utils.py │ └── utils_test.py └── warnings.py ├── logo ├── dynamax.ai ├── dynamax.png ├── logo.gif ├── make_logo.ipynb └── mask.png ├── paper ├── 
paper.bib ├── paper.md └── paper.pdf ├── pyproject.toml ├── setup.py └── versioneer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | dynamax/_version.py export-subst 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build the documentation 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout the branch 15 | uses: actions/checkout@v2.3.1 16 | with: 17 | persist-credentials: false 18 | 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.11 23 | 24 | - name: Build the documentation with Sphinx 25 | run: | 26 | pip install -e '.[dev]' 27 | sphinx-build -b html docs docs/build/html -------------------------------------------------------------------------------- /.github/workflows/interrogate.yml: -------------------------------------------------------------------------------- 1 | name: Docstrings 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | Workflow: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Clone the reference repository 16 | uses: actions/checkout@v3.5.2 17 | 18 | - name: Set up Python 3.10.6 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.10.6' 22 | 23 | - name: Install dependencies 24 | run: | 25 | pip install -e '.[dev]' 26 | 27 | - name: Run docstring checks 28 | run: interrogate -vv dynamax --fail-under 66 29 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python distribution to PyPI and TestPyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | release: 7 | types: 8 | - created 9 | 10 | jobs: 11 | build: 12 | name: Build distribution 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | persist-credentials: false 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.x" 23 | - name: Install pypa/build 24 | run: >- 25 | python3 -m 26 | pip install 27 | build 28 | --user 29 | - name: Build a binary wheel and a source tarball 30 | run: python3 -m build 31 | - name: Store the distribution packages 32 | uses: actions/upload-artifact@v4 33 | with: 34 | name: python-package-distributions 35 | path: dist/ 36 | 37 | publish-to-pypi: 38 | name: >- 39 | Publish Python 🐍 distribution 📦 to PyPI 40 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 41 | needs: 42 | - build 43 | runs-on: ubuntu-latest 44 | environment: 45 | name: pypi 46 | url: https://pypi.org/p/dynamax # Replace with your PyPI project name 47 | permissions: 48 | id-token: write # IMPORTANT: mandatory for trusted publishing 49 | 50 | steps: 51 | - name: Download all the dists 52 | uses: actions/download-artifact@v4 53 | with: 54 | name: python-package-distributions 55 | path: dist/ 56 | - name: Publish distribution 📦 to PyPI 57 | uses: pypa/gh-action-pypi-publish@release/v1 58 | 59 | publish-to-testpypi: 60 | name: Publish Python 🐍 distribution 📦 to TestPyPI 61 | needs: 62 | - build 63 | runs-on: ubuntu-latest 64 | 65 | environment: 
66 | name: testpypi 67 | url: https://test.pypi.org/p/dynamax 68 | 69 | permissions: 70 | id-token: write # IMPORTANT: mandatory for trusted publishing 71 | 72 | steps: 73 | - name: Download all the dists 74 | uses: actions/download-artifact@v4 75 | with: 76 | name: python-package-distributions 77 | path: dist/ 78 | - name: Publish distribution 📦 to TestPyPI 79 | uses: pypa/gh-action-pypi-publish@release/v1 80 | with: 81 | repository-url: https://test.pypi.org/legacy/ 82 | -------------------------------------------------------------------------------- /.github/workflows/publish_documentation.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Publish the documentation 3 | 4 | on: 5 | workflow_dispatch: 6 | 7 | push: 8 | branches: 9 | - main 10 | 11 | release: 12 | types: [published] 13 | 14 | jobs: 15 | publish: 16 | name: Publish 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout the branch 20 | uses: actions/checkout@v2.3.1 21 | with: 22 | persist-credentials: false 23 | 24 | - name: Set up Python 3.11 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: 3.11 28 | 29 | - name: Build the documentation with Sphinx 30 | run: | 31 | pip install -e '.[dev]' 32 | sphinx-build -b html docs docs/build/html 33 | 34 | - name: Publish the documentation 35 | uses: JamesIves/github-pages-deploy-action@3.6.2 36 | with: 37 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 38 | BRANCH: gh-pages 39 | FOLDER: docs/build/html 40 | CLEAN: true -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | Workflow: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | # Select the Python versions to test against 18 | os: ["ubuntu-latest", "macos-latest"] 19 | python-version: ["3.10", "3.11", "3.12"] 20 | name: ${{ matrix.os }} with Python ${{ matrix.python-version }} 21 | steps: 22 | - name: Clone the reference repository 23 | uses: actions/checkout@v3.5.2 24 | 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install dependencies 31 | run: | 32 | pip install -e '.[test]' 33 | 34 | - name: Run tests 35 | run: pytest --cov=./ 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | docs/_build/ 3 | build/ 4 | dist/ 5 | 6 | .coverage 7 | 8 | *egg-info 9 | *.ipynb_checkpoints 10 | 11 | # ignore figures unless manually added 12 | *.png 13 | *.jpg 14 | *-dot 15 | *.DS_Store 16 | .vscode/ 17 | .venv -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.2.0" 2 | authors: 3 | - family-names: Linderman 4 | given-names: Scott W. 
5 | orcid: "https://orcid.org/0000-0002-3878-9073" 6 | - family-names: Chang 7 | given-names: Peter 8 | - family-names: Harper-Donnelly 9 | given-names: Giles 10 | - family-names: Kara 11 | given-names: Aleyna 12 | - family-names: Li 13 | given-names: Xinglong 14 | - family-names: Duran-Martin 15 | given-names: Gerardo 16 | - family-names: Murphy 17 | given-names: Kevin 18 | contact: 19 | - family-names: Linderman 20 | given-names: Scott W. 21 | orcid: "https://orcid.org/0000-0002-3878-9073" 22 | - family-names: Murphy 23 | given-names: Kevin 24 | doi: 10.6084/m9.figshare.28665131 25 | message: If you use this software, please cite our article in the 26 | Journal of Open Source Software. 27 | preferred-citation: 28 | authors: 29 | - family-names: Linderman 30 | given-names: Scott W. 31 | orcid: "https://orcid.org/0000-0002-3878-9073" 32 | - family-names: Chang 33 | given-names: Peter 34 | - family-names: Harper-Donnelly 35 | given-names: Giles 36 | - family-names: Kara 37 | given-names: Aleyna 38 | - family-names: Li 39 | given-names: Xinglong 40 | - family-names: Duran-Martin 41 | given-names: Gerardo 42 | - family-names: Murphy 43 | given-names: Kevin 44 | date-published: 2025-04-03 45 | doi: 10.21105/joss.07069 46 | issn: 2475-9066 47 | issue: 108 48 | journal: Journal of Open Source Software 49 | publisher: 50 | name: Open Journals 51 | start: 7069 52 | title: "Dynamax: A Python package for probabilistic state space 53 | modeling with JAX" 54 | type: article 55 | url: "https://joss.theoj.org/papers/10.21105/joss.07069" 56 | volume: 10 57 | title: "Dynamax: A Python package for probabilistic state space modeling 58 | with JAX" -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome! 4 | 5 | ## How to contribute 6 | 7 | First fork the library on GitHub. 8 | 9 | Then clone and install the library in development mode: 10 | 11 | ```bash 12 | git clone https://github.com/your-username-here/dynamax.git 13 | cd dynamax 14 | pip install -e '.[dev]' 15 | ``` 16 | 17 | Now make your changes. Make sure to include additional tests if necessary. 18 | 19 | Next verify the tests all pass: 20 | 21 | ```bash 22 | pip install pytest 23 | pytest 24 | ``` 25 | 26 | Then commit your changes and push back to your fork of the repository: 27 | 28 | ```bash 29 | git add 30 | git commit -m "" 31 | git push 32 | ``` 33 | 34 | Finally, open a pull request on GitHub! 35 | 36 | ## What to contribute 37 | 38 | Please see this [list of open issues](https://github.com/probml/dynamax/issues), 39 | especially ones tagges as "help wanted". 40 | 41 | 42 | 43 | ## Contributor License Agreement 44 | 45 | Contributions to this project means that the contributors agree to releasing the contributions under the MIT license. 
46 | 47 | 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | include dynamax/_version.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to DYNAMAX! 2 | 3 | ![Logo](https://raw.githubusercontent.com/probml/dynamax/main/logo/logo.gif) 4 | 5 | ![Test Status](https://github.com/probml/dynamax/actions/workflows/run_tests.yml/badge.svg?branch=main) 6 | ![Docstrings](https://github.com/probml/dynamax/actions/workflows/interrogate.yml/badge.svg) 7 | [![DOI](https://joss.theoj.org/papers/10.21105/joss.07069/status.svg)](https://doi.org/10.21105/joss.07069) 8 | 9 | Dynamax is a library for probabilistic state space models (SSMs) written 10 | in [JAX](https://github.com/google/jax). It has code for inference 11 | (state estimation) and learning (parameter estimation) in a variety of 12 | SSMs, including: 13 | 14 | - Hidden Markov Models (HMMs) 15 | - Linear Gaussian State Space Models (aka Linear Dynamical Systems) 16 | - Nonlinear Gaussian State Space Models 17 | - Generalized Gaussian State Space Models (with non-Gaussian emission 18 | models) 19 | 20 | The library consists of a set of core, functionally pure, low-level 21 | inference algorithms, as well as a set of model classes which provide a 22 | more user-friendly, object-oriented interface. It is compatible with 23 | other libraries in the JAX ecosystem, such as 24 | [optax](https://github.com/deepmind/optax) (used for estimating 25 | parameters using stochastic gradient descent), and 26 | [Blackjax](https://github.com/blackjax-devs/blackjax) (used for 27 | computing the parameter posterior using Hamiltonian Monte Carlo (HMC) or 28 | sequential Monte Carlo (SMC)). 29 | 30 | ## Documentation 31 | 32 | For tutorials and API documentation, see: https://probml.github.io/dynamax/. 33 | 34 | For an extension of dynamax that supports structural time series models, 35 | see https://github.com/probml/sts-jax. 
36 | 37 | For an illustration of how to use dynamax inside of [bayeux](https://jax-ml.github.io/bayeux/) to perform Bayesian inference 38 | for the parameters of an SSM, see https://jax-ml.github.io/bayeux/examples/dynamax_and_bayeux/. 39 | 40 | ## Installation and Testing 41 | 42 | To install the latest release of dynamax from PyPI: 43 | 44 | ``` {.console} 45 | pip install dynamax # Install dynamax and core dependencies, or 46 | pip install dynamax[notebooks] # Install with demo notebook dependencies 47 | ``` 48 | 49 | To install the latest development branch: 50 | 51 | ``` {.console} 52 | pip install git+https://github.com/probml/dynamax.git 53 | ``` 54 | 55 | Finally, if you're a developer, you can install dynamax along with the 56 | test and documentation dependencies with: 57 | 58 | ``` {.console} 59 | git clone git@github.com:probml/dynamax.git 60 | cd dynamax 61 | pip install -e '.[dev]' 62 | ``` 63 | 64 | To run the tests: 65 | 66 | ``` {.console} 67 | pytest dynamax # Run all tests 68 | pytest dynamax/hidden_markov_model/inference_test.py # Run a specific test 69 | pytest -k lgssm # Run tests with lgssm in the name 70 | ``` 71 | 72 | ## What are state space models? 73 | 74 | A state space model or SSM is a partially observed Markov model, in 75 | which the hidden state, $z_t$, evolves over time according to a Markov 76 | process, possibly conditional on external inputs / controls / 77 | covariates, $u_t$, and generates an observation, $y_t$. This is 78 | illustrated in the graphical model below. 79 | 80 | 

81 | <p align="center"> 82 | <img src="https://raw.githubusercontent.com/probml/dynamax/main/docs/figures/LDS-UZY.png" alt="SSM graphical model"> </p>
83 | 84 | The corresponding joint distribution has the following form (in dynamax, 85 | we restrict attention to discrete time systems): 86 | 87 | $$p(y_{1:T}, z_{1:T} \mid u_{1:T}) = p(z_1 \mid u_1) \prod_{t=2}^T p(z_t \mid z_{t-1}, u_t) \prod_{t=1}^T p(y_t \mid z_t, u_t)$$ 88 | 89 | Here $p(z_t | z_{t-1}, u_t)$ is called the transition or dynamics model, 90 | and $p(y_t | z_{t}, u_t)$ is called the observation or emission model. 91 | In both cases, the inputs $u_t$ are optional; furthermore, the 92 | observation model may have auto-regressive dependencies, in which case 93 | we write $p(y_t | z_{t}, u_t, y_{1:t-1})$. 94 | 95 | We assume that we see the observations $y_{1:T}$, and want to infer the 96 | hidden states, either using online filtering (i.e., computing 97 | $p(z_t|y_{1:t})$ ) or offline smoothing (i.e., computing 98 | $p(z_t|y_{1:T})$ ). We may also be interested in predicting future 99 | states, $p(z_{t+h}|y_{1:t})$, or future observations, 100 | $p(y_{t+h}|y_{1:t})$, where h is the forecast horizon. (Note that by 101 | using a hidden state to represent the past observations, the model can 102 | have \"infinite\" memory, unlike a standard auto-regressive model.) All 103 | of these computations can be done efficiently using our library, as we 104 | discuss below. In addition, we can estimate the parameters of the 105 | transition and emission models, as we discuss below. 106 | 107 | More information can be found in these books: 108 | 109 | > - \"Machine Learning: Advanced Topics\", K. Murphy, MIT Press 2023. 110 | > Available at . 111 | > - \"Bayesian Filtering and Smoothing, Second Edition\", S. Särkkä and L. Svensson, Cambridge 112 | > University Press, 2023. Available at 113 | > 114 | 115 | ## Example usage 116 | 117 | Dynamax includes classes for many kinds of SSM. You can use these models 118 | to simulate data, and you can fit the models using standard learning 119 | algorithms like expectation-maximization (EM) and stochastic gradient 120 | descent (SGD). Below we illustrate the high level (object-oriented) API 121 | for the case of an HMM with Gaussian emissions. (See [this 122 | notebook](https://github.com/probml/dynamax/blob/main/docs/notebooks/hmm/gaussian_hmm.ipynb) 123 | for a runnable version of this code.) 124 | 125 | ```python 126 | import jax.numpy as jnp 127 | import jax.random as jr 128 | import matplotlib.pyplot as plt 129 | from dynamax.hidden_markov_model import GaussianHMM 130 | 131 | key1, key2, key3 = jr.split(jr.PRNGKey(0), 3) 132 | num_states = 3 133 | emission_dim = 2 134 | num_timesteps = 1000 135 | 136 | # Make a Gaussian HMM and sample data from it 137 | hmm = GaussianHMM(num_states, emission_dim) 138 | true_params, _ = hmm.initialize(key1) 139 | true_states, emissions = hmm.sample(true_params, key2, num_timesteps) 140 | 141 | # Make a new Gaussian HMM and fit it with EM 142 | params, props = hmm.initialize(key3, method="kmeans", emissions=emissions) 143 | params, lls = hmm.fit_em(params, props, emissions, num_iters=20) 144 | 145 | # Plot the marginal log probs across EM iterations 146 | plt.plot(lls) 147 | plt.xlabel("EM iterations") 148 | plt.ylabel("marginal log prob.") 149 | 150 | # Use fitted model for posterior inference 151 | post = hmm.smoother(params, emissions) 152 | print(post.smoothed_probs.shape) # (1000, 3) 153 | ``` 154 | 155 | JAX allows you to easily vectorize these operations with `vmap`. 156 | For example, you can sample and fit to a batch of emissions as shown below. 
157 | 158 | ```python 159 | from functools import partial 160 | from jax import vmap 161 | 162 | num_seq = 200 163 | batch_true_states, batch_emissions = \ 164 | vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))( 165 | jr.split(key2, num_seq)) 166 | print(batch_true_states.shape, batch_emissions.shape) # (200,1000) and (200,1000,2) 167 | 168 | # Make a new Gaussian HMM and fit it with EM 169 | params, props = hmm.initialize(key3, method="kmeans", emissions=batch_emissions) 170 | params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20) 171 | ``` 172 | 173 | These examples demonstrate the dynamax models, but we can also call the low-level 174 | inference code directly. 175 | 176 | ## Contributing 177 | 178 | Please see [this page](https://github.com/probml/dynamax/blob/main/CONTRIBUTING.md) for details 179 | on how to contribute. 180 | 181 | ## About 182 | Core team: Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, Kevin Murphy. 183 | 184 | Other contributors: Adrien Corenflos, Elizabeth DuPre, Gerardo Duran-Martin, Colin Schlager, Libby Zhang and other people [listed here](https://github.com/probml/dynamax/graphs/contributors) 185 | 186 | MIT License. 2022 187 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | To build the sphinx documentation, run 4 | ``` 5 | cd docs/ 6 | make html 7 | ``` 8 | This will find all the jupyter notebooks, run them, collect the output, and incorporate it into the documentation. 9 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | State Space Model (Base class) 2 | =============================== 3 | 4 | .. autoclass:: dynamax.ssm.SSM 5 | :members: 6 | 7 | Parameters 8 | ---------- 9 | 10 | Parameters and their associated properties are stored as :class:`jax.DeviceArray` 11 | and :class:`dynamax.parameters.ParameterProperties`, respectively. They are bundled together into a 12 | :class:`dynamax.parameters.ParameterSet` and a :class:`dynamax.parameters.PropertySet`, which are simply 13 | aliases for immutable datastructures (in our case, :class:`NamedTuple`). 14 | 15 | .. autoclass:: dynamax.parameters.ParameterSet 16 | .. autoclass:: dynamax.parameters.PropertySet 17 | .. autoclass:: dynamax.parameters.ParameterProperties 18 | 19 | Hidden Markov Model 20 | =================== 21 | 22 | Abstract classes 23 | ------------------ 24 | 25 | .. 
autoclass:: dynamax.hidden_markov_model.HMM 26 | :show-inheritance: 27 | :members: 28 | 29 | .. autoclass:: dynamax.hidden_markov_model.HMMInitialState 30 | :members: 31 | 32 | .. autoclass:: dynamax.hidden_markov_model.HMMTransitions 33 | :members: 34 | 35 | .. autoclass:: dynamax.hidden_markov_model.HMMEmissions 36 | :members: 37 | 38 | High-level models 39 | ----------------- 40 | 41 | The HMM implementations below cover common emission distributions and, 42 | if the emissions are exponential family distributions, the models implement 43 | closed form EM updates. For HMMs with emissions outside the non-exponential family, 44 | these models default to a generic M-step implemented in :class:`HMMEmissions`. 45 | 46 | Unless otherwise specified, these models have standard initial distributions and 47 | transition distributions with conjugate, Bayesian priors on their parameters. 48 | 49 | **Initial distribution:** 50 | 51 | $$p(z_1 \mid \pi_1) = \mathrm{Cat}(z_1 \mid \pi_1)$$ 52 | $$p(\pi_1) = \mathrm{Dir}(\pi_1 \mid \alpha 1_K)$$ 53 | 54 | where $\alpha$ is the prior concentration on the initial distribution $\pi_1$. 55 | 56 | **Transition distribution:** 57 | 58 | $$p(z_t \mid z_{t-1}, \theta) = \mathrm{Cat}(z_t \mid A_{z_{t-1}})$$ 59 | $$p(A) = \prod_{k=1}^K \mathrm{Dir}(A_k \mid \beta 1_K + \kappa e_k)$$ 60 | 61 | where $\beta$ is the prior concentration on the rows of the transition matrix $A$ 62 | and $\kappa$ is the `stickiness`, which biases the prior toward transition matrices 63 | with larger values along the diagonal. 64 | 65 | These hyperparameters can be specified in the HMM constructors, and they 66 | default to weak priors without any stickiness. 67 | 68 | 69 | .. autoclass:: dynamax.hidden_markov_model.BernoulliHMM 70 | :show-inheritance: 71 | :members: initialize 72 | 73 | .. autoclass:: dynamax.hidden_markov_model.CategoricalHMM 74 | :show-inheritance: 75 | :members: initialize 76 | 77 | .. autoclass:: dynamax.hidden_markov_model.GammaHMM 78 | :show-inheritance: 79 | :members: initialize 80 | 81 | .. autoclass:: dynamax.hidden_markov_model.GaussianHMM 82 | :show-inheritance: 83 | :members: initialize 84 | 85 | .. autoclass:: dynamax.hidden_markov_model.DiagonalGaussianHMM 86 | :show-inheritance: 87 | :members: initialize 88 | 89 | .. autoclass:: dynamax.hidden_markov_model.SphericalGaussianHMM 90 | :show-inheritance: 91 | :members: initialize 92 | 93 | .. autoclass:: dynamax.hidden_markov_model.SharedCovarianceGaussianHMM 94 | :show-inheritance: 95 | :members: initialize 96 | 97 | .. autoclass:: dynamax.hidden_markov_model.LowRankGaussianHMM 98 | :show-inheritance: 99 | :members: initialize 100 | 101 | .. autoclass:: dynamax.hidden_markov_model.MultinomialHMM 102 | :show-inheritance: 103 | :members: initialize 104 | 105 | .. autoclass:: dynamax.hidden_markov_model.PoissonHMM 106 | :show-inheritance: 107 | :members: initialize 108 | 109 | .. autoclass:: dynamax.hidden_markov_model.GaussianMixtureHMM 110 | :show-inheritance: 111 | :members: initialize 112 | 113 | .. autoclass:: dynamax.hidden_markov_model.DiagonalGaussianMixtureHMM 114 | :show-inheritance: 115 | :members: initialize 116 | 117 | .. autoclass:: dynamax.hidden_markov_model.LinearRegressionHMM 118 | :show-inheritance: 119 | :members: initialize 120 | 121 | .. autoclass:: dynamax.hidden_markov_model.LogisticRegressionHMM 122 | :show-inheritance: 123 | :members: initialize 124 | 125 | .. 
autoclass:: dynamax.hidden_markov_model.CategoricalRegressionHMM 126 | :show-inheritance: 127 | :members: initialize 128 | 129 | .. autoclass:: dynamax.hidden_markov_model.LinearAutoregressiveHMM 130 | :show-inheritance: 131 | :members: initialize, sample, compute_inputs 132 | 133 | Low-level inference 134 | ------------------- 135 | 136 | .. autoclass:: dynamax.hidden_markov_model.HMMPosterior 137 | .. autoclass:: dynamax.hidden_markov_model.HMMPosteriorFiltered 138 | 139 | .. autofunction:: dynamax.hidden_markov_model.hmm_filter 140 | .. autofunction:: dynamax.hidden_markov_model.hmm_smoother 141 | .. autofunction:: dynamax.hidden_markov_model.hmm_two_filter_smoother 142 | .. autofunction:: dynamax.hidden_markov_model.hmm_fixed_lag_smoother 143 | .. autofunction:: dynamax.hidden_markov_model.hmm_posterior_mode 144 | .. autofunction:: dynamax.hidden_markov_model.hmm_posterior_sample 145 | .. autofunction:: dynamax.hidden_markov_model.parallel_hmm_filter 146 | .. autofunction:: dynamax.hidden_markov_model.parallel_hmm_smoother 147 | 148 | Types 149 | ----- 150 | 151 | .. autoclass:: dynamax.hidden_markov_model.HMMParameterSet 152 | .. autoclass:: dynamax.hidden_markov_model.HMMPropertySet 153 | 154 | 155 | Linear Gaussian SSM 156 | ==================== 157 | 158 | High-level class 159 | ---------------- 160 | 161 | .. autoclass:: dynamax.linear_gaussian_ssm.LinearGaussianSSM 162 | :members: 163 | 164 | Low-level inference 165 | ------------------- 166 | 167 | .. autofunction:: dynamax.linear_gaussian_ssm.lgssm_filter 168 | .. autofunction:: dynamax.linear_gaussian_ssm.lgssm_smoother 169 | .. autofunction:: dynamax.linear_gaussian_ssm.lgssm_posterior_sample 170 | 171 | Types 172 | ----- 173 | 174 | .. autoclass:: dynamax.linear_gaussian_ssm.ParamsLGSSM 175 | .. autoclass:: dynamax.linear_gaussian_ssm.ParamsLGSSMInitial 176 | .. autoclass:: dynamax.linear_gaussian_ssm.ParamsLGSSMDynamics 177 | .. autoclass:: dynamax.linear_gaussian_ssm.ParamsLGSSMEmissions 178 | 179 | .. autoclass:: dynamax.linear_gaussian_ssm.PosteriorGSSMFiltered 180 | .. autoclass:: dynamax.linear_gaussian_ssm.PosteriorGSSMSmoothed 181 | 182 | Nonlinear Gaussian GSSM 183 | ======================== 184 | 185 | 186 | High-level class 187 | ---------------- 188 | 189 | .. autoclass:: dynamax.nonlinear_gaussian_ssm.NonlinearGaussianSSM 190 | :members: 191 | 192 | Low-level inference 193 | ------------------- 194 | 195 | .. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_filter 196 | .. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_smoother 197 | 198 | .. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_filter 199 | .. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_smoother 200 | 201 | Types 202 | ----- 203 | 204 | .. autoclass:: dynamax.nonlinear_gaussian_ssm.ParamsNLGSSM 205 | 206 | 207 | Generalized Gaussian GSSM 208 | ========================== 209 | 210 | High-level class 211 | ---------------- 212 | 213 | .. autoclass:: dynamax.generalized_gaussian_ssm.GeneralizedGaussianSSM 214 | :members: 215 | 216 | Low-level inference 217 | ------------------- 218 | 219 | .. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_filter 220 | .. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_smoother 221 | 222 | Types 223 | ----- 224 | 225 | .. autoclass:: dynamax.generalized_gaussian_ssm.ParamsGGSSM 226 | 227 | Utilities 228 | ========= 229 | 230 | .. 
autofunction:: dynamax.utils.utils.find_permutation 231 | 232 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'dynamax' 21 | copyright = '2022, Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, and Kevin Murphy' 22 | author = 'Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, and Kevin Murphy' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | extensions = [ 28 | "sphinx.ext.autodoc", 29 | "sphinx.ext.autosummary", 30 | "sphinx.ext.napoleon", 31 | "sphinx.ext.intersphinx", 32 | "sphinx.ext.viewcode", 33 | "sphinx_math_dollar", 34 | "sphinx.ext.mathjax", 35 | "myst_nb", 36 | ] 37 | 38 | intersphinx_mapping = { 39 | "python": ("https://docs.python.org/3", None), 40 | "numpy": ("https://numpy.org/doc/stable", None), 41 | "jax": ("https://jax.readthedocs.io/en/latest", None), 42 | } 43 | 44 | source_suffix = { 45 | '.rst': 'restructuredtext', 46 | '.myst': 'myst-nb', 47 | '.ipynb': 'myst-nb' 48 | } 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # List of patterns, relative to source directory, that match files and 54 | # directories to ignore when looking for source files. 55 | # This pattern also affects html_static_path and html_extra_path. 56 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'notebooks/slds/rbpf_maneuver.ipynb'] 57 | 58 | nb_execution_allow_errors = False 59 | 60 | # Myst-NB 61 | myst_enable_extensions = [ 62 | "dollarmath", 63 | "amsmath", 64 | "deflist", 65 | "colon_fence", 66 | # "html_admonition", 67 | # "html_image", 68 | # "smartquotes", 69 | # "replacements", 70 | # "linkify", 71 | # "substitution", 72 | ] 73 | nb_execution_timeout = 600 74 | nb_execution_mode = "cache" 75 | 76 | # -- Options for HTML output ------------------------------------------------- 77 | 78 | # The theme to use for HTML and HTML Help pages. See the documentation for 79 | # a list of builtin themes. 80 | # 81 | html_title = "" 82 | html_logo = "../logo/logo.gif" 83 | html_theme = 'sphinx_book_theme' 84 | html_theme_options = { 85 | 'repository_url': 'https://github.com/probml/dynamax', 86 | "use_repository_button": True, 87 | "use_download_button": False, 88 | 'repository_branch': 'main', 89 | "path_to_docs": 'docs', 90 | 'launch_buttons': { 91 | 'colab_url': 'https://colab.research.google.com', 92 | 'binderhub_url': 'https://mybinder.org' 93 | }, 94 | } 95 | 96 | # Add any paths that contain custom static files (such as style sheets) here, 97 | # relative to this directory. 
They are copied after the builtin static files, 98 | # so a file named "default.css" will overwrite the builtin "default.css". 99 | html_static_path = ['_static'] 100 | 101 | 102 | autosummary_generate = True 103 | autodoc_typehints = "description" 104 | add_module_names = False 105 | autodoc_member_order = "bysource" -------------------------------------------------------------------------------- /docs/figures/.gitignore: -------------------------------------------------------------------------------- 1 | # allow png files by undoing parent restriction 2 | !*.png 3 | -------------------------------------------------------------------------------- /docs/figures/LDS-UZY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/LDS-UZY.png -------------------------------------------------------------------------------- /docs/figures/SSM-simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/SSM-simplified.png -------------------------------------------------------------------------------- /docs/figures/casino.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/casino.png -------------------------------------------------------------------------------- /docs/figures/hmmDgmPlatesY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/hmmDgmPlatesY.png -------------------------------------------------------------------------------- /docs/figures/lgssm_parallel_smoothing_timing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/lgssm_parallel_smoothing_timing.png -------------------------------------------------------------------------------- /docs/figures/pendulum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/pendulum.png -------------------------------------------------------------------------------- /docs/figures/rlsDgm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/rlsDgm.png -------------------------------------------------------------------------------- /docs/figures/rlsDgmW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/docs/figures/rlsDgmW.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/types.md: -------------------------------------------------------------------------------- 1 | # Terminology for types 2 | 3 | Dynamax uses [jaxtyping](https://github.com/google/jaxtyping), that provides type declarations 4 | for JAX arrays. These declarations can be checked at run time by using other libraries, 5 | such as [beartype](https://github.com/beartype/beartype) 6 | or [typeguard](https://github.com/agronholm/typeguard). However, currently the run-time checking is 7 | disabled, so the declarations are just for documentation purposes. 8 | 9 | Following [diffrax](https://docs.kidger.site/diffrax/api/type_terminology/), 10 | our API documentation uses a few convenient shorthands for some types. 11 | 12 | - `Scalar` refers to either an `int`, `float`, or a JAX array with shape `()`. 13 | - `PyTree` refers to any PyTree. 14 | - `Array` refers to a JAX array. 15 | 16 | --- 17 | 18 | In addition shapes and dtypes of `Array`s are annotated: 19 | 20 | - `Array["dim1", "dim2"]` refers to a JAX array with shape `(dim1, dim2)`, and so on for other shapes. 21 | - If a dimension is named in this way, then it should match up and be of equal size to the equally-named dimensions of all other arrays passed at the same time. 22 | - `Array[()]` refers to an array with shape `()`. 23 | - `...` refers to an arbitrary number of dimensions, e.g. `Array["times", ...]`. 24 | - `Array[bool]` refers to a JAX array with Boolean dtype. (And so on for other dtypes.) 25 | - These are combined via e.g. `Array["dim1", "dim2", bool]`. 26 | - Some arguments may have different shapes, for example, a transition matrix may be constant or time-varying. 27 | For this, we use `Union` types. 28 | 29 | 30 | -------------------------------------------------------------------------------- /dynamax/__init__.py: -------------------------------------------------------------------------------- 1 | from . import _version 2 | __version__ = _version.get_versions()['version'] 3 | 4 | # Catch expected warnings from TFP 5 | import dynamax.warnings 6 | 7 | # Default to float32 matrix multiplication on TPUs and GPUs 8 | import jax 9 | jax.config.update('jax_default_matmul_precision', 'float32') 10 | 11 | from . import _version 12 | __version__ = _version.get_versions()['version'] 13 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/README.md: -------------------------------------------------------------------------------- 1 | # Generalized Gaussian State Space Models 2 | 3 | A GG-SSM is an SSM with nonlinear Gaussian dynamics, and nonlinear observations, 4 | where the emission distribution can be non-Gaussian (e.g., Poisson, categorical). 
5 | To support approximate inference in this model, we can use conditional moments 6 | Gaussian filtering, which is a form of the generalized Gaussian filter where 7 | the observation model is represented in terms of conditional first and second-order 8 | moments, i.e., E[Y|z] and Cov[Y|z]. 9 | 10 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM, GeneralizedGaussianSSM 2 | from dynamax.generalized_gaussian_ssm.inference import EKFIntegrals, UKFIntegrals, GHKFIntegrals 3 | from dynamax.generalized_gaussian_ssm.inference import conditional_moments_gaussian_filter 4 | from dynamax.generalized_gaussian_ssm.inference import conditional_moments_gaussian_smoother 5 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/dekf/README.md: -------------------------------------------------------------------------------- 1 | The diagonal EKF code (used in our paper https://openreview.net/pdf?id=asgeEt25kk) 2 | has moved to https://github.com/probml/dynamax 3 | 4 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/dynamax/generalized_gaussian_ssm/demos/__init__.py -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/demos/cmgf_logreg_estimator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo of a conditional moments Gaussian filter (CMGF) estimator for 3 | online estimation of a logistic regression. 4 | """ 5 | import jax 6 | from jax import numpy as jnp 7 | from sklearn.preprocessing import OneHotEncoder 8 | from sklearn.base import BaseEstimator, ClassifierMixin 9 | 10 | from dynamax.generalized_gaussian_ssm import conditional_moments_gaussian_filter, EKFIntegrals, ParamsGGSSM 11 | 12 | def fill_diagonal(A, elts): 13 | """ 14 | Fill the diagonal of a matrix with elements from a vector. 15 | """ 16 | # Taken from https://github.com/google/jax/issues/2680 17 | elts = jnp.ravel(elts) 18 | i, j = jnp.diag_indices(min(A.shape[-2:])) 19 | return A.at[..., i, j].set(elts) 20 | 21 | class CMGFEstimator(BaseEstimator, ClassifierMixin): 22 | """ 23 | Conditional moments Gaussian filter (CMGF) estimator for online estimation of a logistic regression. 24 | """ 25 | def __init__(self, mean=None, cov=None): 26 | self.mean = mean 27 | self.cov = cov 28 | 29 | def fit(self, X, y): 30 | """ 31 | Fit the model to the data in online fashion using CMGF. 32 | """ 33 | X_bias = jnp.concatenate([jnp.ones((len(X), 1)), X], axis=1) 34 | # Encode output as one-hot-encoded vectors with first column dropped, 35 | # i.e., [0, ..., 0] corresponds to the 1st class 36 | # This is done to prevent the "Dummy Variable Trap". 
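# Concretely, with K classes the encoded targets have K - 1 columns and class 0 maps to the all-zeros row,
# mirroring emission_mean_function below, which returns softmax probabilities for classes 1..K-1 (the [1:] slice).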
37 | enc = OneHotEncoder(drop='first') 38 | y_oh = jnp.array(enc.fit_transform(y.reshape(-1, 1)).toarray()) 39 | input_dim = X_bias.shape[-1] 40 | num_classes = y_oh.shape[-1] + 1 41 | self.classes_ = jnp.arange(num_classes) 42 | weight_dim = input_dim * num_classes 43 | 44 | initial_mean, initial_covariance = jnp.zeros(weight_dim), jnp.eye(weight_dim) 45 | dynamics_function = lambda w, x: w 46 | dynamics_covariance = jnp.zeros((weight_dim, weight_dim)) 47 | emission_mean_function = lambda w, x: jax.nn.softmax(x @ w.reshape(input_dim, -1))[1:] 48 | 49 | def emission_var_function(w, x): 50 | """Compute the variance of the emission distribution.""" 51 | ps = jnp.atleast_2d(emission_mean_function(w, x)) 52 | return fill_diagonal(ps.T @ -ps, ps * (1-ps)) 53 | 54 | cmgf_params = ParamsGGSSM( 55 | initial_mean = initial_mean, 56 | initial_covariance = initial_covariance, 57 | dynamics_function = dynamics_function, 58 | dynamics_covariance = dynamics_covariance, 59 | emission_mean_function = emission_mean_function, 60 | emission_cov_function = emission_var_function 61 | ) 62 | post = conditional_moments_gaussian_filter(cmgf_params, EKFIntegrals(), y_oh, inputs = X_bias) 63 | post_means, post_covs = post.filtered_means, post.filtered_covariances 64 | self.mean, self.cov = post_means[-1], post_covs[-1] 65 | return self 66 | 67 | def predict(self, X): 68 | """Predict the outputs for a new input""" 69 | return jnp.argmax(self.predict_proba(X), axis=1) 70 | 71 | def predict_proba(self, X): 72 | """Predict the class probabilities for a new input""" 73 | X = jnp.array(X) 74 | X_bias = jnp.concatenate([jnp.ones((len(X), 1)), X], axis=1) 75 | return jax.nn.softmax(X_bias @ self.mean.reshape(X_bias.shape[-1], -1)) 76 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/inference_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for inference in the generalized Gaussian SSM. 
3 | """ 4 | import jax.numpy as jnp 5 | 6 | from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM 7 | from dynamax.generalized_gaussian_ssm.inference import conditional_moments_gaussian_smoother, EKFIntegrals, UKFIntegrals 8 | from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother 9 | from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother, UKFHyperParams 10 | from dynamax.nonlinear_gaussian_ssm.inference_test_utils import random_nlgssm_args 11 | from dynamax.utils.utils import has_tpu 12 | from functools import partial 13 | 14 | if has_tpu(): 15 | allclose = partial(jnp.allclose, atol=1e-1) 16 | else: 17 | allclose = partial(jnp.allclose, atol=1e-3) 18 | 19 | 20 | def ekf(key=0, num_timesteps=15): 21 | """ 22 | Test EKF as a GGF 23 | """ 24 | nlgssm_args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 25 | 26 | # Run EKF from dynamax.ekf 27 | ekf_post = extended_kalman_smoother(nlgssm_args, emissions) 28 | # Run EKF as a GGF 29 | ekf_params = ParamsGGSSM( 30 | initial_mean=nlgssm_args.initial_mean, 31 | initial_covariance=nlgssm_args.initial_covariance, 32 | dynamics_function=nlgssm_args.dynamics_function, 33 | dynamics_covariance=nlgssm_args.dynamics_covariance, 34 | emission_mean_function=nlgssm_args.emission_function, 35 | emission_cov_function=lambda x: nlgssm_args.emission_covariance, 36 | ) 37 | ggf_post = conditional_moments_gaussian_smoother(ekf_params, EKFIntegrals(), emissions) 38 | 39 | # Compare filter and smoother results 40 | assert allclose(ekf_post.marginal_loglik, ggf_post.marginal_loglik) 41 | assert allclose(ekf_post.filtered_means, ggf_post.filtered_means) 42 | assert allclose(ekf_post.filtered_covariances, ggf_post.filtered_covariances) 43 | assert allclose(ekf_post.smoothed_means, ggf_post.smoothed_means) 44 | assert allclose(ekf_post.smoothed_covariances, ggf_post.smoothed_covariances) 45 | 46 | 47 | def test_ukf(key=1, num_timesteps=15): 48 | """ 49 | Test UKF as a GGF 50 | """ 51 | nlgssm_args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 52 | hyperparams = UKFHyperParams() 53 | 54 | # Run UKF from dynamax.ukf 55 | ukf_post = unscented_kalman_smoother(nlgssm_args, emissions, hyperparams) 56 | # Run UKF as GGF 57 | ukf_params = ParamsGGSSM( 58 | initial_mean=nlgssm_args.initial_mean, 59 | initial_covariance=nlgssm_args.initial_covariance, 60 | dynamics_function=nlgssm_args.dynamics_function, 61 | dynamics_covariance=nlgssm_args.dynamics_covariance, 62 | emission_mean_function=nlgssm_args.emission_function, 63 | emission_cov_function=lambda x: nlgssm_args.emission_covariance, 64 | ) 65 | ggf_post = conditional_moments_gaussian_smoother(ukf_params, UKFIntegrals(), emissions) 66 | 67 | # Compare filter and smoother results 68 | # c1, c2 = ukf_post.filtered_covariances, ggf_post.filtered_covariances 69 | # print(c1[0], '\n\n', c2[0]) 70 | assert allclose(ukf_post.marginal_loglik, ggf_post.marginal_loglik) 71 | assert allclose(ukf_post.filtered_means, ggf_post.filtered_means) 72 | assert allclose(ukf_post.filtered_covariances, ggf_post.filtered_covariances) 73 | assert allclose(ukf_post.smoothed_means, ggf_post.smoothed_means) 74 | assert allclose(ukf_post.smoothed_covariances, ggf_post.smoothed_covariances) -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generalized Gaussian State Space 
Models. 3 | """ 4 | from jaxtyping import Array, Float 5 | import tensorflow_probability.substrates.jax as tfp 6 | from typing import NamedTuple, Optional, Union, Callable 7 | 8 | tfd = tfp.distributions 9 | tfb = tfp.bijectors 10 | MVN = tfd.MultivariateNormalFullCovariance 11 | 12 | from dynamax.ssm import SSM 13 | from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState 14 | from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission 15 | 16 | FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]] 17 | FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]] 18 | 19 | # emission distribution takes a mean vector and covariance matrix and returns a distribution 20 | EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution] 21 | 22 | 23 | class ParamsGGSSM(NamedTuple): 24 | """ 25 | Container for Generalized Gaussian SSM parameters. 26 | Specifically, it defines the following model: 27 | 28 | $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$ 29 | $$p(y_t | z_t) = q(y_t | h(z_t, u_t), R(z_t, u_t))$$ 30 | $$p(z_1) = N(z_1 | m, S)$$ 31 | 32 | This differs from NLGSSM in by allowing a general emission model. 33 | If you have no inputs, the dynamics and emission functions do not to take $u_t$ as an argument. 34 | 35 | :param initial_mean: $m$ 36 | :param initial_covariance: $S$ 37 | :param dynamics_function: $f$. This has the signature $f: Z * U -> Y$ or $h: Z -> Y$. 38 | :param dynamics_covariance: $Q$ 39 | :param emission_mean_function: $h$. This has the signature $h: Z * U -> Z$ or $h: Z -> Z$. 40 | :param emission_cov_function: $R$. This has the signature $R: Z * U -> Z*Z$ or $R: Z -> Z*Z$. 41 | :param emission_dist: the observation distribution $q$. This is a callable that takes the predicted 42 | mean and covariance of Y, and returns a tfp distribution object: $q: Z * (Z*Z) -> Dist(Y)$. 43 | 44 | 45 | """ 46 | 47 | initial_mean: Float[Array, " state_dim"] 48 | initial_covariance: Float[Array, "state_dim state_dim"] 49 | dynamics_function: Union[FnStateToState, FnStateAndInputToState] 50 | dynamics_covariance: Float[Array, "state_dim state_dim"] 51 | 52 | emission_mean_function: Union[FnStateToEmission, FnStateAndInputToEmission] 53 | emission_cov_function: Union[FnStateToEmission2, FnStateAndInputToEmission2] 54 | emission_dist: EmissionDistFn = lambda mean, cov: MVN(loc=mean, covariance_matrix=cov) 55 | 56 | 57 | class GeneralizedGaussianSSM(SSM): 58 | """ 59 | Generalized Gaussian State Space Model. 60 | 61 | The model is defined as follows 62 | 63 | $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$ 64 | $$p(y_t | z_t) = q(y_t | h(z_t, u_t), R(z_t, u_t))$$ 65 | $$p(z_1) = N(z_1 | m, S)$$ 66 | 67 | where the model parameters are 68 | 69 | * $z_t$ = hidden variables of size `state_dim`, 70 | * $y_t$ = observed variables of size `emission_dim` 71 | * $u_t$ = input covariates of size `input_dim` (defaults to 0). 72 | * $f$ = dynamics (transition) function 73 | * $h$ = emission (observation) function 74 | * $Q$ = covariance matrix of dynamics (system) noise 75 | * $R$ = covariance function for emission (observation) noise 76 | * $m$ = mean of initial state 77 | * $S$ = covariance matrix of initial state 78 | 79 | The parameters of the model are stored in a separate object of type :class:`ParamsGGSSM`. 
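A minimal usage sketch (the two-dimensional state, the AR(1)-style dynamics, and the univariate Poisson emission below are illustrative choices, not defaults)::

    import jax.numpy as jnp
    import jax.random as jr
    import tensorflow_probability.substrates.jax as tfp
    from dynamax.generalized_gaussian_ssm import (
        GeneralizedGaussianSSM, ParamsGGSSM, EKFIntegrals,
        conditional_moments_gaussian_filter)

    state_dim, emission_dim = 2, 1
    C = jnp.ones((emission_dim, state_dim))                # illustrative emission weights
    params = ParamsGGSSM(
        initial_mean=jnp.zeros(state_dim),                 # m
        initial_covariance=jnp.eye(state_dim),             # S
        dynamics_function=lambda z: 0.99 * z,              # f
        dynamics_covariance=0.01 * jnp.eye(state_dim),     # Q
        emission_mean_function=lambda z: jnp.exp(C @ z),   # h
        emission_cov_function=lambda z: jnp.exp(C @ z),    # R (Poisson: variance = mean)
        emission_dist=lambda mu, Sigma: tfp.distributions.Poisson(log_rate=jnp.log(mu)))

    model = GeneralizedGaussianSSM(state_dim, emission_dim)
    states, emissions = model.sample(params, jr.PRNGKey(0), num_timesteps=100)
    posterior = conditional_moments_gaussian_filter(params, EKFIntegrals(), emissions)
    print(posterior.marginal_loglik)

The corresponding smoother is :func:`conditional_moments_gaussian_smoother`.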
80 | 81 | For example usage, see https://github.com/probml/dynamax/blob/main/dynamax/generalized_gaussian_ssm/models_test.py. 82 | 83 | """ 84 | 85 | def __init__(self, state_dim, emission_dim, input_dim=0): 86 | self.state_dim = state_dim 87 | self.emission_dim = emission_dim 88 | self.input_dim = 0 89 | 90 | @property 91 | def emission_shape(self): 92 | """Shape of the emissions""" 93 | return (self.emission_dim,) 94 | 95 | @property 96 | def covariates_shape(self): 97 | """Shape of the covariates""" 98 | return (self.input_dim,) if self.input_dim > 0 else None 99 | 100 | def initial_distribution(self, 101 | params: ParamsGGSSM, 102 | inputs: Optional[Float[Array, " input_dim"]]=None) \ 103 | -> tfd.Distribution: 104 | """Returns the initial distribution of the state.""" 105 | return MVN(params.initial_mean, params.initial_covariance) 106 | 107 | def transition_distribution(self, 108 | params: ParamsGGSSM, 109 | state: Float[Array, " state_dim"], 110 | inputs: Optional[Float[Array, " input_dim"]]=None 111 | ) -> tfd.Distribution: 112 | """Returns the transition distribution of the state.""" 113 | f = params.dynamics_function 114 | if inputs is None: 115 | mean = f(state) 116 | else: 117 | mean = f(state, inputs) 118 | return MVN(mean, params.dynamics_covariance) 119 | 120 | def emission_distribution(self, 121 | params: ParamsGGSSM, 122 | state: Float[Array, " state_dim"], 123 | inputs: Optional[Float[Array, " input_dim"]]=None) \ 124 | -> tfd.Distribution: 125 | """Returns the emission distribution of the state.""" 126 | h = params.emission_mean_function 127 | R = params.emission_cov_function 128 | if inputs is None: 129 | mean = h(state) 130 | cov = R(state) 131 | else: 132 | mean = h(state, inputs) 133 | cov = R(state, inputs) 134 | return params.emission_dist(mean, cov) 135 | -------------------------------------------------------------------------------- /dynamax/generalized_gaussian_ssm/models_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the GeneralizedGaussianSSM class. 3 | """ 4 | import pytest 5 | import jax.random as jr 6 | import jax.numpy as jnp 7 | import tensorflow_probability.substrates.jax as tfp 8 | 9 | tfd = tfp.distributions 10 | 11 | from dynamax.generalized_gaussian_ssm import GeneralizedGaussianSSM, ParamsGGSSM 12 | from dynamax.generalized_gaussian_ssm.inference import conditional_moments_gaussian_filter, EKFIntegrals 13 | 14 | NUM_TIMESTEPS = 100 15 | 16 | CONFIGS = [ 17 | (jr.PRNGKey(0), dict(state_dim=3, emission_dim=5)), 18 | (jr.PRNGKey(1), dict(state_dim=5, emission_dim=5)), 19 | (jr.PRNGKey(2), dict(state_dim=10, emission_dim=7)), 20 | ] 21 | 22 | @pytest.mark.parametrize(["key", "kwargs"], CONFIGS) 23 | def test_poisson_emission(key, kwargs): 24 | """ 25 | Test that the marginal log-likelihood under Poisson emission is higher than 26 | when we treat the count emissions as Gaussian random variables. 
27 | """ 28 | keys = jr.split(key, 3) 29 | state_dim = kwargs['state_dim'] 30 | emission_dim = 1 # Univariate Poisson 31 | poisson_weights = jr.normal(keys[0], shape=(emission_dim, state_dim)) / jnp.sqrt(state_dim) 32 | model = GeneralizedGaussianSSM(state_dim, emission_dim) 33 | 34 | # Define model parameters with Poisson emission 35 | pois_params = ParamsGGSSM( 36 | initial_mean=jr.normal(keys[1], (state_dim,)), 37 | initial_covariance=jnp.eye(state_dim), 38 | dynamics_function=lambda z: 0.99 * z, 39 | dynamics_covariance=0.001*jnp.eye(state_dim), 40 | emission_mean_function=lambda z: jnp.exp(poisson_weights @ z), 41 | emission_cov_function=lambda z: jnp.exp(poisson_weights @ z), 42 | emission_dist=lambda mu, Sigma: tfd.Poisson(log_rate=jnp.log(mu)) 43 | ) 44 | _, emissions = model.sample(pois_params, keys[2], num_timesteps=NUM_TIMESTEPS) 45 | 46 | # Define model parameters with default Gaussian emission 47 | gaussian_params = ParamsGGSSM( 48 | initial_mean=jr.normal(keys[1], (state_dim,)), 49 | initial_covariance=jnp.eye(state_dim), 50 | dynamics_function=lambda z: 0.99 * z, 51 | dynamics_covariance=0.001*jnp.eye(state_dim), 52 | emission_mean_function=lambda z: jnp.exp(poisson_weights @ z), 53 | emission_cov_function=lambda z: jnp.exp(poisson_weights @ z) 54 | ) 55 | 56 | # Compute the marginal log-likelihood under the Poisson emission model 57 | pois_marginal_lls = conditional_moments_gaussian_filter(pois_params, EKFIntegrals(), emissions).marginal_loglik 58 | 59 | # Compute the marginal log-likelihood under the Gaussian emission model 60 | gaussian_marginal_lls = conditional_moments_gaussian_filter(gaussian_params, EKFIntegrals(), emissions).marginal_loglik 61 | 62 | # Check that the marginal log-likelihood under the Poisson emission model is higher 63 | assert pois_marginal_lls > gaussian_marginal_lls 64 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/README.md: 1 | # Hidden Markov Models 2 | 3 | We support HMMs with a variety of observation models, including categorical, Gaussian, Poisson, etc. 4 | We support exact posterior inference using the forward-backward algorithm, and parameter estimation using EM and SGD.
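As a quick, hedged sketch of the intended workflow (the `GaussianHMM` class and the `initialize` / `sample` / `fit_em` / `most_likely_states` calls below match the usage in `dynamax/hidden_markov_model/models/test_models.py`; the dimensions and PRNG keys are arbitrary):

```python
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

key1, key2 = jr.split(jr.PRNGKey(0))
hmm = GaussianHMM(num_states=3, emission_dim=2)

# Sample synthetic data from randomly initialized parameters.
true_params, _ = hmm.initialize(key1)
states, emissions = hmm.sample(true_params, key2, num_timesteps=100)

# Fit a fresh model with EM and decode the most likely state sequence (Viterbi).
params, props = hmm.initialize(jr.PRNGKey(1))
params, log_probs = hmm.fit_em(params, props, emissions, num_iters=50)
most_likely_states = hmm.most_likely_states(params, emissions)
```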
5 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/__init__.py: -------------------------------------------------------------------------------- 1 | from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMInitialState, HMMTransitions, HMMParameterSet, HMMPropertySet 2 | from dynamax.hidden_markov_model.models.arhmm import LinearAutoregressiveHMM 3 | from dynamax.hidden_markov_model.models.bernoulli_hmm import BernoulliHMM 4 | from dynamax.hidden_markov_model.models.categorical_glm_hmm import CategoricalRegressionHMM 5 | from dynamax.hidden_markov_model.models.categorical_hmm import CategoricalHMM 6 | from dynamax.hidden_markov_model.models.gamma_hmm import GammaHMM 7 | from dynamax.hidden_markov_model.models.gaussian_hmm import GaussianHMM, DiagonalGaussianHMM, SphericalGaussianHMM, SharedCovarianceGaussianHMM, LowRankGaussianHMM 8 | from dynamax.hidden_markov_model.models.gmm_hmm import GaussianMixtureHMM, DiagonalGaussianMixtureHMM 9 | from dynamax.hidden_markov_model.models.linreg_hmm import LinearRegressionHMM 10 | from dynamax.hidden_markov_model.models.logreg_hmm import LogisticRegressionHMM 11 | from dynamax.hidden_markov_model.models.multinomial_hmm import MultinomialHMM 12 | from dynamax.hidden_markov_model.models.poisson_hmm import PoissonHMM 13 | 14 | from dynamax.hidden_markov_model.inference import HMMPosterior 15 | from dynamax.hidden_markov_model.inference import HMMPosteriorFiltered 16 | from dynamax.hidden_markov_model.inference import hmm_filter 17 | from dynamax.hidden_markov_model.inference import hmm_backward_filter 18 | from dynamax.hidden_markov_model.inference import hmm_two_filter_smoother 19 | from dynamax.hidden_markov_model.inference import hmm_smoother 20 | from dynamax.hidden_markov_model.inference import hmm_posterior_mode 21 | from dynamax.hidden_markov_model.inference import hmm_posterior_sample 22 | from dynamax.hidden_markov_model.inference import hmm_fixed_lag_smoother 23 | from dynamax.hidden_markov_model.inference import compute_transition_probs 24 | 25 | from dynamax.hidden_markov_model.parallel_inference import hmm_filter as parallel_hmm_filter 26 | from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother 27 | from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample as parallel_hmm_posterior_sample -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to use the CategoricalRegressionHMM class. 3 | """ 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | from jax import vmap 7 | import matplotlib.pyplot as plt 8 | 9 | from dynamax.hidden_markov_model import CategoricalRegressionHMM 10 | 11 | if __name__ == "__main__": 12 | key1, key2, key3, key4 = jr.split(jr.PRNGKey(0), 4) 13 | 14 | num_states = 2 15 | num_classes = 3 16 | feature_dim = 10 17 | num_timesteps = 20000 18 | 19 | hmm = CategoricalRegressionHMM(num_states, num_classes, feature_dim) 20 | transition_matrix = jnp.array([[0.95, 0.05], 21 | [0.05, 0.95]]) 22 | true_params, _ = hmm.initialize(key=key1, transition_matrix=transition_matrix) 23 | 24 | inputs = jr.normal(key2, (num_timesteps, feature_dim)) 25 | states, emissions = hmm.sample(true_params, key3, num_timesteps, inputs=inputs) 26 | 27 | # Try fitting it! 
28 | test_hmm = CategoricalRegressionHMM(num_states, num_classes, feature_dim) 29 | params, props = test_hmm.initialize(key=key4) 30 | params, lps = test_hmm.fit_em(params, props, emissions, inputs=inputs, num_iters=100) 31 | 32 | # Plot the data and predictions 33 | # Compute the most likely states 34 | most_likely_states = test_hmm.most_likely_states(params, emissions, inputs=inputs) 35 | 36 | # Predict the emissions given the true states 37 | As = params["emissions"]["weights"][most_likely_states] 38 | bs = params["emissions"]["biases"][most_likely_states] 39 | predictions = vmap(lambda x, A, b: A @ x + b)(inputs, As, bs) 40 | predictions = jnp.argmax(predictions, axis=1) 41 | 42 | offsets = 3 * jnp.arange(num_classes) 43 | plt.imshow(most_likely_states[None, :], 44 | extent=(0, num_timesteps, -3, 3 * num_classes), 45 | aspect="auto", 46 | cmap="Greys", 47 | alpha=0.5) 48 | plt.plot(emissions) 49 | plt.plot(predictions, ':k') 50 | plt.xlim(0, num_timesteps) 51 | plt.ylim(-0.25, 2.25) 52 | plt.xlabel("time") 53 | plt.xlim(0, 100) 54 | 55 | plt.figure() 56 | plt.plot(lps) 57 | plt.axhline(hmm.marginal_log_prob(true_params, emissions, inputs), color='k', ls=':') 58 | plt.xlabel("EM iteration") 59 | plt.ylabel("log joint probability") 60 | 61 | plt.figure() 62 | plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])), 63 | aspect="auto", interpolation='none', cmap="Greys") 64 | plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"]) 65 | plt.xlabel("time") 66 | plt.xlim(0, 500) 67 | 68 | 69 | print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs)) 70 | print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs)) 71 | 72 | plt.show() 73 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/demos/poisson_hmm_earthquakes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Using a Hidden Markov Model with Poisson Emissions to Understand Earthquakes 3 | 4 | Based on 5 | https://github.com/hmmlearn/hmmlearn/blob/main/examples/plot_poisson_hmm.py 6 | https://hmmlearn.readthedocs.io/en/latest/auto_examples/plot_poisson_hmm.html 7 | 8 | """ 9 | import jax.numpy as jnp 10 | import jax.random as jr 11 | import matplotlib.pyplot as plt 12 | from dynamax.hidden_markov_model.models.poisson_hmm import PoissonHMM 13 | 14 | # earthquake data from http://earthquake.usgs.gov/ 15 | EARTHQUAKES = jnp.array( 16 | [13, 14, 8, 10, 16, 26, 32, 27, 18, 32, 36, 24, 22, 23, 22, 18, 25, 21, 21, 14, 8, 11, 14, 17 | 23, 18, 17, 19, 20, 22, 19, 13, 26, 13, 14, 22, 24, 21, 22, 26, 21, 23, 24, 27, 41, 31, 27, 18 | 35, 26, 28, 36, 39, 21, 17, 22, 17, 19, 15, 34, 10, 15, 22, 18, 15, 20, 15, 22, 19, 16, 30, 19 | 27, 29, 23, 20, 16, 21, 21, 25, 16, 18, 15, 18, 14, 10, 20 | 15, 8, 15, 6, 11, 8, 7, 18, 16, 13, 12, 13, 20, 15, 16, 12, 18, 15, 16, 13, 15, 16, 11, 11]) 21 | 22 | 23 | 24 | def main(test_mode=False, num_iters=20, num_repeats=10, min_states=2, max_states=4): 25 | """ 26 | Fit a Poisson Hidden Markov Model to earthquake data. 27 | """ 28 | emission_dim = 1 29 | emissions = EARTHQUAKES.reshape(-1, emission_dim) 30 | 31 | # Now, fit a Poisson Hidden Markov Model to the data. 
32 | scores = list() 33 | models = list() 34 | model_params = list() 35 | 36 | for num_states in range(min_states, max_states+1): 37 | for idx in range(num_repeats): # ten different random starting states 38 | key = jr.PRNGKey(idx) 39 | key1, key2 = jr.split(key, 2) 40 | 41 | model = PoissonHMM(num_states, emission_dim) 42 | params, param_props = model.initialize(key1) 43 | params["emissions"]["rates"] = jr.uniform(key2, (num_states, emission_dim), minval=10.0, maxval=35.0) 44 | 45 | params, losses = model.fit_em(params, param_props, emissions[None, ...], num_iters=num_iters) 46 | models.append(model) 47 | model_params.append(params) 48 | scores.append(model.marginal_log_prob(params, emissions)) 49 | print(f"Score: {scores[-1]}") 50 | 51 | # get the best model 52 | model = models[jnp.argmax(jnp.array(scores))] 53 | params = model_params[jnp.argmax(jnp.array(scores))] 54 | print(f"The best model had a score of {max(scores)} and " 55 | f"{model.num_states} components") 56 | 57 | # use the Viterbi algorithm to predict the most likely sequence of states 58 | # given the model 59 | states = model.most_likely_states(params, emissions) 60 | 61 | if not test_mode: 62 | # Let's plot the rates from our most likely series of states of 63 | # earthquake activity with the earthquake data. As we can see, the 64 | # model with the maximum likelihood had different states which may reflect 65 | # times of varying earthquake danger. 66 | 67 | # plot model states over time 68 | fig, ax = plt.subplots() 69 | ax.plot(params["emissions"]["rates"][states], ".-", ms=6, mfc="orange") 70 | ax.plot(emissions.ravel()) 71 | ax.set_title("States compared to generated") 72 | ax.set_xlabel("State") 73 | 74 | # Fortunately, 2006 ended with a period of relative tectonic stability, and, 75 | # if we look at our transition matrix, we can see that the off-diagonal terms 76 | # are small, meaning that the state transitions are rare and it's unlikely that 77 | # there will be high earthquake danger in the near future. 78 | fig, ax = plt.subplots() 79 | ax.imshow(params["transitions"]["transition_matrix"], aspect="auto", cmap="spring") 80 | ax.set_title("Transition Matrix") 81 | ax.set_xlabel("State To") 82 | ax.set_ylabel("State From") 83 | 84 | plt.show() 85 | 86 | # Run the demo 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/dynamax/hidden_markov_model/models/__init__.py -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/models/initial.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the implementation of the initial distribution of a hidden Markov model. 
3 | """ 4 | from typing import Any, cast, NamedTuple, Optional, Tuple, Union 5 | import jax.numpy as jnp 6 | import jax.random as jr 7 | from jaxtyping import Float, Array 8 | import tensorflow_probability.substrates.jax.distributions as tfd 9 | import tensorflow_probability.substrates.jax.bijectors as tfb 10 | from dynamax.hidden_markov_model.inference import HMMPosterior 11 | from dynamax.hidden_markov_model.models.abstractions import HMMInitialState 12 | from dynamax.parameters import ParameterProperties 13 | from dynamax.types import Scalar 14 | 15 | 16 | class ParamsStandardHMMInitialState(NamedTuple): 17 | """Named tuple for the parameters of the standard HMM initial distribution.""" 18 | probs: Union[Float[Array, " state_dim"], ParameterProperties] 19 | 20 | 21 | class StandardHMMInitialState(HMMInitialState): 22 | """Abstract class for HMM initial distributions. 23 | """ 24 | def __init__(self, 25 | num_states: int, 26 | initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1): 27 | """ 28 | Args: 29 | initial_probabilities[k]: prob(hidden(1)=k) 30 | """ 31 | self.num_states = num_states 32 | self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states) 33 | 34 | def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tfd.Distribution: 35 | """Return the distribution object of the initial distribution.""" 36 | return tfd.Categorical(probs=params.probs) 37 | 38 | def initialize( 39 | self, 40 | key: Optional[Array]=None, 41 | method="prior", 42 | initial_probs: Optional[Float[Array, " num_states"]]=None 43 | ) -> Tuple[ParamsStandardHMMInitialState, ParamsStandardHMMInitialState]: 44 | """Initialize the model parameters and their corresponding properties. 45 | 46 | Args: 47 | key (_type_, optional): _description_. Defaults to None. 48 | method (str, optional): _description_. Defaults to "prior". 49 | initial_probs (_type_, optional): _description_. Defaults to None. 
50 | 51 | Returns: 52 | _type_: _description_ 53 | """ 54 | # Initialize the initial probabilities 55 | if initial_probs is None: 56 | if key is None: 57 | raise ValueError("key must be provided if initial_probs is not provided.") 58 | else: 59 | this_key, key = jr.split(key) 60 | initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key) 61 | 62 | # Package the results into dictionaries 63 | params = ParamsStandardHMMInitialState(probs=initial_probs) 64 | props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered())) 65 | return params, props 66 | 67 | def log_prior(self, params: ParamsStandardHMMInitialState) -> Scalar: 68 | """Compute the log prior of the parameters.""" 69 | return tfd.Dirichlet(self.initial_probs_concentration).log_prob(params.probs) 70 | 71 | def _compute_initial_probs( 72 | self, params: ParamsStandardHMMInitialState, inputs=None 73 | ) -> Float[Array, " num_states"]: 74 | """Compute the initial probabilities.""" 75 | return params.probs 76 | 77 | def collect_suff_stats(self, params, posterior: HMMPosterior, inputs=None) -> Float[Array, " num_states"]: 78 | """Collect the sufficient statistics for the initial distribution.""" 79 | return posterior.smoothed_probs[0] 80 | 81 | def initialize_m_step_state(self, params, props) -> None: 82 | """Initialize the state for the M-step.""" 83 | return None 84 | 85 | def m_step( 86 | self, 87 | params: ParamsStandardHMMInitialState, 88 | props: ParamsStandardHMMInitialState, 89 | batch_stats: Float[Array, "batch num_states"], 90 | m_step_state: Any 91 | ) -> Tuple[ParamsStandardHMMInitialState, Any]: 92 | """Perform the M-step of the EM algorithm.""" 93 | if props.probs.trainable: 94 | if self.num_states == 1: 95 | probs = jnp.array([1.0]) 96 | else: 97 | expected_initial_counts = batch_stats.sum(axis=0) 98 | probs = tfd.Dirichlet(self.initial_probs_concentration + expected_initial_counts).mode() 99 | params = params._replace(probs=probs) 100 | return params, m_step_state 101 | 102 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/models/test_models.py: -------------------------------------------------------------------------------- 1 | """Tests for the HMM models.""" 2 | 3 | import dynamax.hidden_markov_model as models 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import pytest 7 | 8 | from jax import vmap 9 | from dynamax.utils.utils import monotonically_increasing 10 | 11 | NUM_TIMESTEPS = 50 12 | 13 | CONFIGS = [ 14 | (models.BernoulliHMM, dict(num_states=4, emission_dim=3), None), 15 | (models.CategoricalHMM, dict(num_states=4, emission_dim=3, num_classes=5), None), 16 | (models.CategoricalRegressionHMM, dict(num_states=4, num_classes=3, input_dim=5), jnp.ones((NUM_TIMESTEPS, 5))), 17 | (models.GammaHMM, dict(num_states=4), None), 18 | (models.GaussianHMM, dict(num_states=4, emission_dim=3, emission_prior_concentration=1.0, emission_prior_scale=1.0), None), 19 | (models.DiagonalGaussianHMM, dict(num_states=4, emission_dim=3), None), 20 | (models.SphericalGaussianHMM, dict(num_states=4, emission_dim=3), None), 21 | (models.SharedCovarianceGaussianHMM, dict(num_states=4, emission_dim=3), None), 22 | (models.LowRankGaussianHMM, dict(num_states=4, emission_dim=3, emission_rank=1), None), 23 | (models.GaussianMixtureHMM, dict(num_states=4, num_components=2, emission_dim=3, emission_prior_mean_concentration=1.0), None), 24 | (models.DiagonalGaussianMixtureHMM, dict(num_states=4, 
num_components=2, emission_dim=3, emission_prior_mean_concentration=1.0), None), 25 | (models.LinearRegressionHMM, dict(num_states=3, emission_dim=3, input_dim=5), jr.normal(jr.PRNGKey(0),(NUM_TIMESTEPS, 5))), 26 | (models.LogisticRegressionHMM, dict(num_states=4, input_dim=5), jnp.ones((NUM_TIMESTEPS, 5))), 27 | (models.MultinomialHMM, dict(num_states=4, emission_dim=3, num_classes=5, num_trials=10), None), 28 | (models.PoissonHMM, dict(num_states=4, emission_dim=3), None), 29 | ] 30 | 31 | 32 | @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) 33 | def test_sample_and_fit(cls, kwargs, inputs): 34 | """Test that we can sample from and fit a model.""" 35 | hmm = cls(**kwargs) 36 | key1, key2 = jr.split(jr.PRNGKey(42)) 37 | params, param_props = hmm.initialize(key1) 38 | states, emissions = hmm.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) 39 | fitted_params, lps = hmm.fit_em(params, param_props, emissions, inputs=inputs, num_iters=10) 40 | assert monotonically_increasing(lps, atol=1e-2, rtol=1e-2) 41 | fitted_params, lps = hmm.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=10) 42 | 43 | 44 | ## A few model-specific tests 45 | def test_categorical_hmm_viterbi(): 46 | """Test that the Viterbi algorithm works for a simple CategoricalHMM.""" 47 | # From http://en.wikipedia.org/wiki/Viterbi_algorithm: 48 | hmm = models.CategoricalHMM(num_states=2, emission_dim=1, num_classes=3) 49 | params, props = hmm.initialize( 50 | jr.PRNGKey(0), 51 | initial_probs=jnp.array([0.6, 0.4]), 52 | transition_matrix=jnp.array([[0.7, 0.3], [0.4, 0.6]]), 53 | emission_probs=jnp.array([[0.1, 0.4, 0.5], [0.6, 0.3, 0.1]]).reshape(2, 1, 3)) 54 | 55 | emissions = jnp.arange(3).reshape(3, 1) 56 | state_sequence = hmm.most_likely_states(params, emissions) 57 | assert jnp.allclose(jnp.squeeze(state_sequence), jnp.array([1, 0, 0])) 58 | 59 | 60 | def test_gmm_hmm_vs_gmm_diag_hmm(key=jr.PRNGKey(0), num_states=4, num_components=3, emission_dim=2): 61 | """Test that a GaussianMixtureHMM and DiagonalGaussianMixtureHMM are equivalent.""" 62 | key1, key2, key3 = jr.split(key, 3) 63 | diag_hmm = models.DiagonalGaussianMixtureHMM(num_states, num_components, emission_dim) 64 | diag_params, _ = diag_hmm.initialize(key1) 65 | 66 | full_hmm = models.GaussianMixtureHMM(num_states, num_components, emission_dim) 67 | emission_covariances = vmap(lambda ss: vmap(lambda s: jnp.diag(s**2))(ss))(diag_params.emissions.scale_diags) 68 | full_params, _ = full_hmm.initialize(key2, 69 | initial_probs=diag_params.initial.probs, 70 | transition_matrix=diag_params.transitions.transition_matrix, 71 | emission_weights=diag_params.emissions.weights, 72 | emission_means=diag_params.emissions.means, 73 | emission_covariances=emission_covariances) 74 | 75 | states_diag, emissions_diag = diag_hmm.sample(diag_params, key3, NUM_TIMESTEPS) 76 | states_full, emissions_full = full_hmm.sample(full_params, key3, NUM_TIMESTEPS) 77 | assert jnp.allclose(emissions_full, emissions_diag) 78 | assert jnp.allclose(states_full, states_diag) 79 | 80 | posterior_diag = diag_hmm.smoother(diag_params, emissions_diag) 81 | posterior_full = full_hmm.smoother(full_params, emissions_full) 82 | 83 | assert jnp.allclose(posterior_diag.marginal_loglik, posterior_full.marginal_loglik) 84 | assert jnp.allclose(posterior_diag.filtered_probs, posterior_full.filtered_probs) 85 | assert jnp.allclose(posterior_diag.predicted_probs, posterior_full.predicted_probs) 86 | assert jnp.allclose(posterior_diag.smoothed_probs, 
posterior_full.smoothed_probs) 87 | assert jnp.allclose(posterior_diag.initial_probs, posterior_full.initial_probs) 88 | 89 | states_diag = diag_hmm.most_likely_states(diag_params, emissions_diag) 90 | states_full = full_hmm.most_likely_states(full_params, emissions_full) 91 | assert jnp.allclose(states_full, states_diag) 92 | 93 | 94 | def test_sample_and_fit_arhmm(): 95 | """Test that we can sample from and fit a LinearAutoregressiveHMM.""" 96 | arhmm = models.LinearAutoregressiveHMM(num_states=4, emission_dim=2, num_lags=1) 97 | #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) 98 | key1, key2 = jr.split(jr.PRNGKey(42)) 99 | params, param_props = arhmm.initialize(key1) 100 | states, emissions = arhmm.sample(params, key2, num_timesteps=NUM_TIMESTEPS) 101 | inputs = arhmm.compute_inputs(emissions) 102 | fitted_params, lps = arhmm.fit_em(params, param_props, emissions, inputs=inputs, num_iters=10) 103 | assert monotonically_increasing(lps, atol=1e-2, rtol=1e-2) 104 | fitted_params, lps = arhmm.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=10) 105 | 106 | 107 | # @pytest.mark.skip(reason="this would introduce a torch dependency") 108 | # def test_hmm_fit_stochastic_em(num_iters=100): 109 | # """Evaluate stochastic em fit with respect to exact em fit.""" 110 | 111 | # # ------------------------------------------------------------- 112 | # def _collate(batch): 113 | # """Merges a list of samples to form a batch of tensors.""" 114 | # if isinstance(batch[0], jnp.ndarray): 115 | # return jnp.stack(batch) 116 | # elif isinstance(batch[0], (tuple,list)): 117 | # transposed = zip(*batch) 118 | # return [_collate(samples) for samples in transposed] 119 | # else: 120 | # return jnp.array(batch) 121 | 122 | 123 | # from torch.utils.data import DataLoader 124 | # class ArrayLoader(DataLoader): 125 | # """Generates an iterable over the given array, with option to reshuffle. 126 | 127 | # Args: 128 | # dataset: Any object implementing __len__ and __getitem__ 129 | # batch_size (int): Number of samples to load per batch 130 | # shuffle (bool): If True, reshuffle data at every epoch 131 | # drop_last (bool): If true, drop last incomplete batch if dataset size is 132 | # not divisible by batch size, drop last incomplete batch. Else, keep 133 | # (smaller) last batch. 
134 | # """ 135 | # def __init__(self, dataset, batch_size=1, shuffle=True, drop_last=True): 136 | 137 | # super(self.__class__, self).__init__(dataset, 138 | # batch_size=batch_size, 139 | # shuffle=shuffle, 140 | # collate_fn=_collate, 141 | # drop_last=drop_last 142 | # ) 143 | # # Generate data and construct dataloader 144 | # true_hmm, _, batch_emissions = make_rnd_model_and_data(num_batches=8) 145 | # emissions_generator = ArrayLoader(batch_emissions, batch_size=2, shuffle=True) 146 | 147 | # refr_hmm = GaussianHMM.random_initialization(jr.PRNGKey(1), 2 * true_hmm.num_states, true_hmm.num_obs) 148 | # test_hmm = GaussianHMM.random_initialization(jr.PRNGKey(1), 2 * true_hmm.num_states, true_hmm.num_obs) 149 | 150 | # refr_lps = refr_hmm.fit_em(batch_emissions, num_iters) 151 | 152 | # total_emissions = len(batch_emissions.reshape(-1, true_hmm.num_obs)) 153 | # test_lps = test_hmm.fit_stochastic_em( 154 | # emissions_generator, total_emissions, num_epochs=num_iters, 155 | # ) 156 | 157 | # # ------------------------------------------------------------------------- 158 | # # we expect lps to likely differ by quite a bit, but should be in the same order 159 | # print(f'test log prob {test_lps[-1]:.2f} refrence lp {refr_lps[-1]:.2f}') 160 | # assert jnp.allclose(test_lps[-1], refr_lps[-1], atol=100) 161 | 162 | # refr_mu = refr_hmm.emission_means.value 163 | # test_mu = test_hmm.emission_means.value 164 | 165 | # assert jnp.alltrue(test_mu.shape == (10, 2)) 166 | # assert jnp.allclose(jnp.linalg.norm(test_mu-refr_mu, axis=-1), 0., atol=2) 167 | 168 | # refr_cov = refr_hmm.emission_covariance_matrices.value 169 | # test_cov = test_hmm.emission_covariance_matrices.value 170 | # assert jnp.alltrue(test_cov.shape == (10, 2, 2)) 171 | # assert jnp.allclose(jnp.linalg.norm(test_cov-refr_cov, axis=-1), 0., atol=1) 172 | -------------------------------------------------------------------------------- /dynamax/hidden_markov_model/models/transitions.py: -------------------------------------------------------------------------------- 1 | """Module for HMM transition models.""" 2 | import jax.numpy as jnp 3 | import tensorflow_probability.substrates.jax.distributions as tfd 4 | import tensorflow_probability.substrates.jax.bijectors as tfb 5 | 6 | from dynamax.hidden_markov_model.models.abstractions import HMMTransitions 7 | from dynamax.hidden_markov_model.inference import HMMPosterior 8 | from dynamax.parameters import ParameterProperties 9 | from dynamax.types import IntScalar, Scalar 10 | 11 | from jaxtyping import Float, Array 12 | from typing import Any, cast, NamedTuple, Optional, Tuple, Union 13 | 14 | 15 | class ParamsStandardHMMTransitions(NamedTuple): 16 | """Named tuple for the parameters of the StandardHMMTransitions model.""" 17 | transition_matrix: Union[Float[Array, "state_dim state_dim"], ParameterProperties] 18 | 19 | 20 | class StandardHMMTransitions(HMMTransitions): 21 | r"""Standard model for HMM transitions. 22 | 23 | We place a Dirichlet prior over the rows of the transition matrix $A$, 24 | 25 | $$A_k \sim \mathrm{Dir}(\beta 1_K + \kappa e_k)$$ 26 | 27 | where 28 | 29 | * $1_K$ denotes a length-$K$ vector of ones, 30 | * $e_k$ denotes the one-hot vector with a 1 in the $k$-th position, 31 | * $\beta \in \mathbb{R}_+$ is the concentration, and 32 | * $\kappa \in \mathbb{R}_+$ is the `stickiness`. 
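As a concrete illustration (the numbers are arbitrary, not defaults): with $K=3$ states, concentration $\beta=1.0$, and stickiness $\kappa=10.0$, the prior on row $k$ is $\mathrm{Dir}([1,1,1] + 10 e_k)$, e.g. $\mathrm{Dir}([11, 1, 1])$ for $k=1$, so the prior mode places most transition mass on the self-transition. This matches how ``__init__`` below builds ``self.concentration`` as ``concentration * ones((K, K)) + stickiness * eye(K)``.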
33 | 34 | """ 35 | def __init__( 36 | self, 37 | num_states: int, 38 | concentration: Union[Scalar, Float[Array, "num_states num_states"]]=1.1, 39 | stickiness: Union[Scalar, Float[Array, " num_states"]]=0.0 40 | ): 41 | """ 42 | Args: 43 | transition_matrix[j,k]: prob(hidden(t) = k | hidden(t-1)j) 44 | """ 45 | self.num_states = num_states 46 | self.concentration = \ 47 | concentration * jnp.ones((num_states, num_states)) + \ 48 | stickiness * jnp.eye(num_states) 49 | 50 | def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, inputs=None): 51 | """Return the distribution over the next state given the current state.""" 52 | return tfd.Categorical(probs=params.transition_matrix[state]) 53 | 54 | def initialize( 55 | self, 56 | key: Optional[Array]=None, 57 | method="prior", 58 | transition_matrix: Optional[Float[Array, "num_states num_states"]]=None 59 | ) -> Tuple[ParamsStandardHMMTransitions, ParamsStandardHMMTransitions]: 60 | """Initialize the model parameters and their corresponding properties. 61 | 62 | Args: 63 | key (_type_, optional): _description_. Defaults to None. 64 | method (str, optional): _description_. Defaults to "prior". 65 | transition_matrix (_type_, optional): _description_. Defaults to None. 66 | 67 | Returns: 68 | _type_: _description_ 69 | """ 70 | if transition_matrix is None: 71 | if key is None: 72 | raise ValueError("key must be provided if transition_matrix is not provided.") 73 | else: 74 | transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key) 75 | transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample) 76 | 77 | # Package the results into dictionaries 78 | params = ParamsStandardHMMTransitions(transition_matrix=transition_matrix) 79 | props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered())) 80 | return params, props 81 | 82 | def log_prior(self, params: ParamsStandardHMMTransitions) -> Scalar: 83 | """Compute the log prior probability of the parameters.""" 84 | return tfd.Dirichlet(self.concentration).log_prob(params.transition_matrix).sum() 85 | 86 | def _compute_transition_matrices( 87 | self, params: ParamsStandardHMMTransitions, inputs=None 88 | ) -> Float[Array, "num_states num_states"]: 89 | """Compute the transition matrices.""" 90 | return params.transition_matrix 91 | 92 | def collect_suff_stats( 93 | self, 94 | params, 95 | posterior: HMMPosterior, 96 | inputs=None 97 | ) -> Union[Float[Array, "num_states num_states"], 98 | Float[Array, "num_timesteps_minus_1 num_states num_states"]]: 99 | """Collect the sufficient statistics for the model.""" 100 | return posterior.trans_probs 101 | 102 | def initialize_m_step_state(self, params, props): 103 | """Initialize the state for the M-step.""" 104 | return None 105 | 106 | def m_step( 107 | self, 108 | params: ParamsStandardHMMTransitions, 109 | props: ParamsStandardHMMTransitions, 110 | batch_stats: Float[Array, "batch num_states num_states"], 111 | m_step_state: Any 112 | ) -> Tuple[ParamsStandardHMMTransitions, Any]: 113 | """Perform the M-step of the EM algorithm.""" 114 | if props.transition_matrix.trainable: 115 | if self.num_states == 1: 116 | transition_matrix = jnp.array([[1.0]]) 117 | else: 118 | expected_trans_counts = batch_stats.sum(axis=0) 119 | transition_matrix = tfd.Dirichlet(self.concentration + expected_trans_counts).mode() 120 | params = params._replace(transition_matrix=transition_matrix) 121 | return params, m_step_state 122 | 
-------------------------------------------------------------------------------- /dynamax/hidden_markov_model/parallel_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parallel implementations of the forward filtering and backward smoothing algorithms 3 | """ 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | from jax import lax, vmap, value_and_grad 7 | from jaxtyping import Array, Float, Int 8 | from typing import NamedTuple, Tuple 9 | 10 | from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered 11 | from dynamax.types import Scalar 12 | 13 | #---------------------------------------------------------------------------# 14 | # Filtering # 15 | #---------------------------------------------------------------------------# 16 | 17 | class FilterMessage(NamedTuple): 18 | r"""Filtering associative scan elements. 19 | 20 | Attributes: 21 | A: $p(z_j \mid z_i)$ 22 | log_b: $\log P(y_{i+1}, ..., y_j \mid z_i)$ 23 | """ 24 | A: Float[Array, "num_timesteps num_states num_states"] 25 | log_b: Float[Array, "num_timesteps num_states"] 26 | 27 | 28 | def _condition_on(A : Float[Array, "num_states num_states"], 29 | ll : Float[Array, " num_states"], 30 | axis : int=-1) -> \ 31 | Tuple[Float[Array, "num_states num_states"], Float[Array, "num_states"]]: 32 | """ Update the message by conditioning on new observations. 33 | """ 34 | ll_max = ll.max(axis=axis) 35 | A_cond = A * jnp.exp(ll - ll_max) 36 | norm = A_cond.sum(axis=axis) 37 | A_cond /= jnp.expand_dims(norm, axis=axis) 38 | return A_cond, jnp.log(norm) + ll_max 39 | 40 | 41 | def hmm_filter(initial_probs: Float[Array, " num_states"], 42 | transition_matrix: Float[Array, "num_states num_states"], 43 | log_likelihoods: Float[Array, "num_timesteps num_states"] 44 | ) -> HMMPosteriorFiltered: 45 | r"""Parallel implementation of the forward filtering algorithm with `jax.lax.associative_scan`. 46 | 47 | *Note: for this function, the transition matrix must be fixed. We may add support 48 | for nonstationary transition matrices in a future release.* 49 | 50 | Args: 51 | initial_distribution: $p(z_1 \mid u_1, \theta)$ 52 | transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ 53 | log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$. 54 | 55 | Returns: 56 | filtered posterior distribution 57 | 58 | """ 59 | T, K = log_likelihoods.shape 60 | 61 | @vmap 62 | def marginalize(m_ij, m_jk): 63 | """ 64 | Compute the message from time i to time k by marginalizing out 65 | the hidden state at time j. 
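In equations (a restatement of the code below, using the message convention defined above): with ``lognorm[u] = log(sum_v A_ij[u, v] * exp(log_b_jk[v]))``, the combined message is ``A_ik[u, w] = sum_v A_ij[u, v] * exp(log_b_jk[v]) * A_jk[v, w] / exp(lognorm[u])`` and ``log_b_ik[u] = log_b_ij[u] + lognorm[u]``.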
66 | """ 67 | A_ij_cond, lognorm = _condition_on(m_ij.A, m_jk.log_b) 68 | A_ik = A_ij_cond @ m_jk.A 69 | log_b_ik = m_ij.log_b + lognorm 70 | return FilterMessage(A=A_ik, log_b=log_b_ik) 71 | 72 | 73 | # Initialize the messages 74 | A0, log_b0 = _condition_on(initial_probs, log_likelihoods[0]) 75 | A0 *= jnp.ones((K, K)) 76 | log_b0 *= jnp.ones(K) 77 | A1T, log_b1T = vmap(_condition_on, in_axes=(None, 0))(transition_matrix, log_likelihoods[1:]) 78 | initial_messages = FilterMessage( 79 | A=jnp.concatenate([A0[None, :, :], A1T]), 80 | log_b=jnp.vstack([log_b0, log_b1T]) 81 | ) 82 | 83 | # Run the associative scan 84 | partial_messages = lax.associative_scan(marginalize, initial_messages) 85 | 86 | # Extract the marginal log likelihood and filtered probabilities 87 | marginal_loglik = partial_messages.log_b[-1,0] 88 | filtered_probs = partial_messages.A[:, 0, :] 89 | 90 | # Compute the predicted probabilities 91 | predicted_probs = jnp.vstack([initial_probs, filtered_probs[:-1] @ transition_matrix]) 92 | 93 | # Package into a posterior object 94 | return HMMPosteriorFiltered(marginal_loglik=marginal_loglik, 95 | filtered_probs=filtered_probs, 96 | predicted_probs=predicted_probs) 97 | 98 | 99 | #---------------------------------------------------------------------------# 100 | # Smoothing # 101 | #---------------------------------------------------------------------------# 102 | 103 | 104 | def hmm_smoother(initial_probs: Float[Array, " num_states"], 105 | transition_matrix: Float[Array, "num_states num_states"], 106 | log_likelihoods: Float[Array, "num_timesteps num_states"] 107 | ) -> HMMPosterior: 108 | r"""Parallel implementation of HMM smoothing with `jax.lax.associative_scan`. 109 | 110 | **Notes:** 111 | 112 | * This implementation uses the automatic differentiation of the HMM log normalizer rather than an explicit implementation of the backward message passing. 113 | * The transition matrix must be fixed. We may add support for nonstationary transition matrices in a future release. 114 | 115 | Args: 116 | initial_distribution: $p(z_1 \mid u_1, \theta)$ 117 | transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ 118 | log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$. 119 | 120 | Returns: 121 | smoothed posterior distribution 122 | 123 | """ 124 | def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods): 125 | """Compute the log normalizer of the HMM model.""" 126 | post = hmm_filter(jnp.exp(log_initial_probs), 127 | jnp.exp(log_transition_matrix), 128 | log_likelihoods) 129 | return post.marginal_loglik, post 130 | 131 | f = value_and_grad(log_normalizer, has_aux=True, argnums=(1, 2)) 132 | (marginal_loglik, fwd_post), (trans_probs, smoothed_probs) = \ 133 | f(jnp.log(initial_probs), jnp.log(transition_matrix), log_likelihoods) 134 | 135 | return HMMPosterior( 136 | marginal_loglik=marginal_loglik, 137 | filtered_probs=fwd_post.filtered_probs, 138 | predicted_probs=fwd_post.predicted_probs, 139 | initial_probs=smoothed_probs[0], 140 | smoothed_probs=smoothed_probs, 141 | trans_probs=trans_probs 142 | ) 143 | 144 | 145 | #---------------------------------------------------------------------------# 146 | # Sampling # 147 | #---------------------------------------------------------------------------# 148 | r"""Associative scan elements $E_ij$ are vectors specifying a sample:: 149 | 150 | $z_j ~ p(z_j \mid z_i)$ 151 | 152 | for each possible value of $z_i$. 
153 | """ 154 | 155 | def _initialize_sampling_messages(key, transition_matrix, filtered_probs): 156 | """Preprocess filtering output to construct input for sampling assocative scan.""" 157 | 158 | T, K = filtered_probs.shape 159 | keys = jr.split(key, T) 160 | 161 | def _last_message(key, probs): 162 | """Sample the last hidden state.""" 163 | state = jr.choice(key, K, p=probs) 164 | return jnp.repeat(state, K) 165 | 166 | @vmap 167 | def _generic_message(key, probs): 168 | """Sample a hidden state given the previous state.""" 169 | smoothed_probs = probs * transition_matrix.T 170 | smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1) 171 | return vmap(lambda p: jr.choice(key, K, p=p))(smoothed_probs) 172 | 173 | En = _last_message(keys[-1], filtered_probs[-1]) 174 | Et = _generic_message(keys[:-1], filtered_probs[:-1]) 175 | return jnp.concatenate([Et, En[None]]) 176 | 177 | 178 | def hmm_posterior_sample(key: Array, 179 | initial_distribution: Float[Array, " num_states"], 180 | transition_matrix: Float[Array, "num_states num_states"], 181 | log_likelihoods: Float[Array, "num_timesteps num_states"] 182 | ) -> Tuple[Scalar, Int[Array, " num_timesteps"]]: 183 | r"""Sample a sequence of hidden states from the posterior. 184 | 185 | Args: 186 | key: random number generator 187 | initial_distribution: $p(z_1 \mid u_1, \theta)$ 188 | transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ 189 | log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$. 190 | 191 | Returns: 192 | log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$ 193 | states: sequence of hidden states $z_{1:T}$ 194 | """ 195 | # Run the HMM filter 196 | post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods) 197 | log_normalizer = post.marginal_loglik 198 | filtered_probs = post.filtered_probs 199 | 200 | @vmap 201 | def _operator(E_jk, E_ij): 202 | """Sample a hidden state given the previous state.""" 203 | return jnp.take(E_ij, E_jk) 204 | 205 | initial_messages = _initialize_sampling_messages(key, transition_matrix, filtered_probs) 206 | final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) 207 | states = final_messages[:,0] 208 | return log_normalizer, states 209 | -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/README.md: -------------------------------------------------------------------------------- 1 | # Linear Gaussian State Space Models 2 | 3 | We provide a model class definition, plus code for state estimation using Kalman filtering and smoothing, and code for parameter estimation. 
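As a brief, hedged sketch of the high-level API (the class and method names below match those used in `dynamax/linear_gaussian_ssm/models_test.py` and `inference_test.py`; the dimensions and PRNG keys are arbitrary):

```python
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

key1, key2 = jr.split(jr.PRNGKey(0))
lgssm = LinearGaussianSSM(state_dim=2, emission_dim=10)

# Sample synthetic data, then run Kalman smoothing and EM parameter estimation.
params, props = lgssm.initialize(key1)
states, emissions = lgssm.sample(params, key2, num_timesteps=100)
posterior = lgssm.smoother(params, emissions)    # filtered and smoothed means/covariances
fitted_params, log_probs = lgssm.fit_em(params, props, emissions, num_iters=10)
```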
4 | -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM 2 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSMInitial 3 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSMDynamics 4 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSMEmissions 5 | from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered 6 | from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMSmoothed 7 | from dynamax.linear_gaussian_ssm.inference import lgssm_filter 8 | from dynamax.linear_gaussian_ssm.inference import lgssm_smoother 9 | from dynamax.linear_gaussian_ssm.inference import lgssm_posterior_sample 10 | from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample 11 | 12 | from dynamax.linear_gaussian_ssm.info_inference import ParamsLGSSMInfo 13 | from dynamax.linear_gaussian_ssm.info_inference import PosteriorGSSMInfoFiltered 14 | from dynamax.linear_gaussian_ssm.info_inference import PosteriorGSSMInfoSmoothed 15 | from dynamax.linear_gaussian_ssm.info_inference import lgssm_info_filter 16 | from dynamax.linear_gaussian_ssm.info_inference import lgssm_info_smoother 17 | 18 | from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter 19 | from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother 20 | from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample 21 | 22 | from dynamax.linear_gaussian_ssm.models import LinearGaussianConjugateSSM, LinearGaussianSSM 23 | -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/dynamax/linear_gaussian_ssm/demos/__init__.py -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/inference_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the inference methods of the Linear Gaussian State Space Model. 3 | """ 4 | import pytest 5 | 6 | import jax.scipy.linalg as jla 7 | import jax.numpy as jnp 8 | import tensorflow_probability.substrates.jax.distributions as tfd 9 | 10 | from functools import partial 11 | from dynamax.linear_gaussian_ssm import LinearGaussianSSM 12 | from dynamax.utils.utils import has_tpu 13 | from jax import vmap 14 | from jax import random as jr 15 | 16 | # Use different tolerance threshold for TPU 17 | if has_tpu(): 18 | allclose = partial(jnp.allclose, atol=1e-1) 19 | else: 20 | allclose = partial(jnp.allclose, atol=1e-4) 21 | 22 | 23 | def flatten_diagonal_emission_cov(params): 24 | """Flatten a diagonal emission covariance matrix into a vector. 25 | 26 | Args: 27 | params: LGSSMParams object. 28 | 29 | Returns: 30 | params: LGSSMParams object with flattened diagonal emission covariance. 
31 | """ 32 | R = params.emissions.cov 33 | 34 | if R.ndim == 2: 35 | R_diag = jnp.diag(R) 36 | R_full = jnp.diag(R_diag) 37 | else: 38 | R_diag = vmap(jnp.diag)(R) 39 | R_full = vmap(jnp.diag)(R_diag) 40 | 41 | assert allclose(R, R_full), "R is not diagonal" 42 | 43 | emission_params_diag = params.emissions._replace(cov=R_diag) 44 | params = params._replace(emissions=emission_params_diag) 45 | return params 46 | 47 | 48 | def joint_posterior_mvn(params, emissions): 49 | """Construct the joint posterior MVN of a LGSSM, by inverting the joint precision matrix which 50 | has a known block tridiagonal form. 51 | 52 | Args: 53 | params: LGSSMParams object. 54 | emissions: Emission data. 55 | 56 | Returns: 57 | means: jnp.ndarray, shape (num_timesteps, state_dim), the joint posterior means. 58 | Sigma_diag_blocks: jnp.ndarray, shape (num_timesteps, state_dim, state_dim), the joint posterior covariance diagonal blocks. 59 | """ 60 | Q = params.dynamics.cov 61 | R = params.emissions.cov 62 | F = params.dynamics.weights 63 | H = params.emissions.weights 64 | mu0 = params.initial.mean 65 | Sigma0 = params.initial.cov 66 | num_timesteps = emissions.shape[0] 67 | state_dim = params.dynamics.weights.shape[0] 68 | emission_dim = params.emissions.weights.shape[0] 69 | Qinv = jnp.linalg.inv(Q) 70 | Rinv = jnp.linalg.inv(R) 71 | Sigma0inv = jnp.linalg.inv(Sigma0) 72 | 73 | # Construct the big precision matrix (block tridiagonal) 74 | # set up small blocks 75 | Omega1 = F.T @ Qinv @ F + H.T @ Rinv @ H + Sigma0inv 76 | Omegat = F.T @ Qinv @ F + H.T @ Rinv @ H + Qinv 77 | OmegaT = Qinv + H.T @ Rinv @ H 78 | OmegaC = - F.T @ Qinv 79 | 80 | # construct big block diagonal matrix 81 | blocks = [Omega1] + [Omegat] * (num_timesteps-2) + [OmegaT] 82 | Omega_diag = jla.block_diag(*blocks) 83 | 84 | # construct big block super/sub-diagonal matrices and sum 85 | aux = jnp.empty((0, state_dim), int) 86 | blocks = [OmegaC] * (num_timesteps-1) 87 | Omega_superdiag = jla.block_diag(aux, *blocks, aux.T) 88 | Omega_subdiag = Omega_superdiag.T 89 | Omega = Omega_diag + Omega_superdiag + Omega_subdiag 90 | 91 | # Compute the joint covariance matrix 92 | # diagonal blocks are the smoothed covariances (marginals of the full joint) 93 | Sigma = jnp.linalg.inv(Omega) 94 | covs = jnp.array([Sigma[i:i+state_dim, i:i+state_dim] for i in range(0, num_timesteps*state_dim, state_dim)]) 95 | 96 | # Compute the means (these are the smoothing means) 97 | # they are the solution to the big linear system Omega @ means = rhs 98 | padded = jnp.pad(Sigma0inv @ mu0, (0, (num_timesteps-1)*state_dim ), constant_values=0).reshape(num_timesteps * state_dim, 1) 99 | rhs = jla.block_diag(*[H.T @ Rinv] * num_timesteps) @ emissions.reshape((num_timesteps*emission_dim, 1)) + padded 100 | means = Sigma @ rhs 101 | means = means.reshape((num_timesteps, state_dim)) 102 | 103 | return means, covs 104 | 105 | 106 | def lgssm_dynamax_to_tfp(num_timesteps, params): 107 | """Create a Tensorflow Probability `LinearGaussianStateSpaceModel` object 108 | from an dynamax `LinearGaussianSSM`. 109 | 110 | Args: 111 | num_timesteps: int, the number of timesteps. 112 | lgssm: LinearGaussianSSM or LGSSMParams object. 
113 | """ 114 | dynamics_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=params.dynamics.cov) 115 | emission_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=params.emissions.cov) 116 | initial_dist = tfd.MultivariateNormalFullCovariance(params.initial.mean, params.initial.cov) 117 | 118 | tfp_lgssm = tfd.LinearGaussianStateSpaceModel( 119 | num_timesteps, 120 | params.dynamics.weights, 121 | dynamics_noise_dist, 122 | params.emissions.weights, 123 | emission_noise_dist, 124 | initial_dist, 125 | ) 126 | 127 | return tfp_lgssm 128 | 129 | 130 | class TestFilteringAndSmoothing: 131 | """ 132 | Tests for the filtering and smoothing methods of the Linear Gaussian State Space 133 | """ 134 | key = jr.PRNGKey(0) 135 | num_timesteps = 15 136 | num_samples = 1000 137 | state_dim = 4 138 | emission_dim = 2 139 | 140 | k1, k2, k3 = jr.split(key, 3) 141 | 142 | # Construct an LGSSM with simple dynamics and emissions 143 | mu0 = jnp.array([0.0, 0.0, 0.0, 0.0]) 144 | Sigma0 = jnp.eye(state_dim) * 0.1 145 | F = jnp.array([[1, 0, 1, 0], 146 | [0, 1, 0, 1], 147 | [0, 0, 1, 0], 148 | [0, 0, 0, 1]], dtype=jnp.float32) 149 | Q = jnp.eye(state_dim) * 0.001 150 | H = jnp.array([[1.0, 0, 0, 0], 151 | [0, 1.0, 0, 0]]) 152 | R = jnp.eye(emission_dim) * 1.0 153 | 154 | lgssm = LinearGaussianSSM(state_dim, emission_dim) 155 | params, _ = lgssm.initialize(k1, 156 | initial_mean=mu0, 157 | initial_covariance=Sigma0, 158 | dynamics_weights=F, 159 | dynamics_covariance=Q, 160 | emission_weights=H, 161 | emission_covariance=R) 162 | 163 | # Sample random emissions 164 | _, emissions = lgssm.sample(params, k2, num_timesteps) 165 | 166 | # Run the smoother with the full covariance parameterization 167 | posterior = lgssm.smoother(params, emissions) 168 | 169 | # Run the smoother with the diagonal covariance parameterization 170 | params_diag = flatten_diagonal_emission_cov(params) 171 | ssm_posterior_diag = lgssm.smoother(params_diag, emissions) 172 | 173 | # Sample from the posterior distribution 174 | posterior_sample = partial(lgssm.posterior_sample, 175 | params=params, 176 | emissions=emissions) 177 | samples = vmap(posterior_sample)(jr.split(k3, num_samples)) 178 | 179 | def test_smoother_vs_tfp(self): 180 | """Test that the dynamax and TFP implementations of the Kalman filter are consistent.""" 181 | tfp_lgssm = lgssm_dynamax_to_tfp(self.num_timesteps, self.params) 182 | tfp_lls, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(self.emissions) 183 | tfp_smoothed_means, tfp_smoothed_covs = tfp_lgssm.posterior_marginals(self.emissions) 184 | 185 | assert allclose(self.posterior.filtered_means, tfp_filtered_means) 186 | assert allclose(self.posterior.filtered_covariances, tfp_filtered_covs) 187 | assert allclose(self.posterior.smoothed_means, tfp_smoothed_means) 188 | assert allclose(self.posterior.smoothed_covariances, tfp_smoothed_covs) 189 | assert allclose(self.posterior.marginal_loglik, tfp_lls.sum()) 190 | 191 | # Compare posterior with diagonal emission covariance 192 | assert allclose(self.ssm_posterior_diag.filtered_means, tfp_filtered_means) 193 | assert allclose(self.ssm_posterior_diag.filtered_covariances, tfp_filtered_covs) 194 | assert allclose(self.ssm_posterior_diag.smoothed_means, tfp_smoothed_means) 195 | assert allclose(self.ssm_posterior_diag.smoothed_covariances, tfp_smoothed_covs) 196 | assert allclose(self.ssm_posterior_diag.marginal_loglik, tfp_lls.sum()) 197 | 198 | 199 | def test_kalman_vs_joint(self): 200 | """Test that the dynamax 
and joint posterior methods are consistent.""" 201 | joint_means, joint_covs = joint_posterior_mvn(self.params, self.emissions) 202 | 203 | assert allclose(self.posterior.smoothed_means, joint_means) 204 | assert allclose(self.posterior.smoothed_covariances, joint_covs) 205 | assert allclose(self.ssm_posterior_diag.smoothed_means, joint_means) 206 | assert allclose(self.ssm_posterior_diag.smoothed_covariances, joint_covs) 207 | 208 | 209 | def test_posterior_samples(self): 210 | """Test that posterior samples match the mean of the smoother""" 211 | monte_carlo_var = vmap(jnp.diag)(self.posterior.smoothed_covariances) / self.num_samples 212 | assert jnp.all(abs(jnp.mean(self.samples, axis=0) - self.posterior.smoothed_means) < 6 * jnp.sqrt(monte_carlo_var)) 213 | -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/info_inference_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for information form inference in linear Gaussian SSMs. 3 | """ 4 | import jax.numpy as jnp 5 | 6 | from functools import partial 7 | from jax import random as jr 8 | 9 | from dynamax.linear_gaussian_ssm.inference import lgssm_smoother, lgssm_filter 10 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions 11 | from dynamax.linear_gaussian_ssm.info_inference import lgssm_info_filter, lgssm_info_smoother, info_to_moment_form 12 | from dynamax.linear_gaussian_ssm.info_inference import ParamsLGSSMInfo 13 | from dynamax.utils.utils import has_tpu 14 | 15 | # Use lower tolerance for TPU tests. 16 | if has_tpu(): 17 | allclose = partial(jnp.allclose, atol=1e-1) 18 | else: 19 | allclose = partial(jnp.allclose, atol=1e-4) 20 | 21 | 22 | def build_lgssm_moment_and_info_form(): 23 | """Construct example LinearGaussianSSM and equivalent LGSSMInfoParams 24 | object for testing. 25 | """ 26 | 27 | delta = 1.0 28 | F = jnp.array([[1.0, 0, delta, 0], [0, 1.0, 0, delta], [0, 0, 1.0, 0], [0, 0, 0, 1.0]]) 29 | 30 | H = jnp.array([[1.0, 0, 0, 0], [0, 1.0, 0, 0]]) 31 | 32 | state_size, _ = F.shape 33 | observation_size, _ = H.shape 34 | 35 | Q = jnp.eye(state_size) * 0.001 36 | Q_prec = jnp.linalg.inv(Q) 37 | R = jnp.eye(observation_size) * 1.0 38 | R_prec = jnp.linalg.inv(R) 39 | 40 | input_size = 1 41 | B = jnp.array([1.0, 0.5, -0.05, -0.01]).reshape((state_size, input_size)) 42 | b = jnp.ones((state_size,)) * 0.01 43 | D = jnp.ones((observation_size, input_size)) 44 | d = jnp.ones((observation_size,)) * 0.02 45 | 46 | # Prior parameter distribution 47 | mu0 = jnp.array([8.0, 10.0, 1.0, 0.0]) 48 | Sigma0 = jnp.eye(state_size) * 0.1 49 | Lambda0 = jnp.linalg.inv(Sigma0) 50 | 51 | # Construct LGSSM 52 | lgssm = ParamsLGSSM( 53 | initial=ParamsLGSSMInitial(mean=mu0,cov=Sigma0), 54 | dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), 55 | emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) 56 | ) 57 | 58 | lgssm_info = ParamsLGSSMInfo( 59 | initial_mean=mu0, 60 | initial_precision=Lambda0, 61 | dynamics_weights=F, 62 | dynamics_precision=Q_prec, 63 | dynamics_input_weights=B, 64 | dynamics_bias=b, 65 | emission_weights=H, 66 | emission_precision=R_prec, 67 | emission_input_weights=D, 68 | emission_bias=d, 69 | ) 70 | 71 | return lgssm, lgssm_info 72 | 73 | 74 | class TestInfoFilteringAndSmoothing: 75 | """Test information form filtering and smoothing by comparing it to moment 76 | form. 
77 | """ 78 | 79 | lgssm, lgssm_info = build_lgssm_moment_and_info_form() 80 | 81 | # Sample data from model. 82 | key = jr.PRNGKey(0) 83 | num_timesteps = 15 84 | input_size = lgssm.dynamics.input_weights.shape[1] 85 | inputs = jnp.zeros((num_timesteps, input_size)) 86 | 87 | y = jr.normal(key, (num_timesteps, 2)) 88 | 89 | lgssm_moment_posterior = lgssm_smoother(lgssm, y, inputs) 90 | lgssm_info_posterior = lgssm_info_smoother(lgssm_info, y, inputs) 91 | 92 | info_filtered_means, info_filtered_covs = info_to_moment_form( 93 | lgssm_info_posterior.filtered_etas, lgssm_info_posterior.filtered_precisions 94 | ) 95 | info_smoothed_means, info_smoothed_covs = info_to_moment_form( 96 | lgssm_info_posterior.smoothed_etas, lgssm_info_posterior.smoothed_precisions 97 | ) 98 | 99 | def test_filtered_means(self): 100 | """Test filtered means.""" 101 | assert allclose(self.info_filtered_means, self.lgssm_moment_posterior.filtered_means) 102 | 103 | def test_filtered_covs(self): 104 | """Test filtered covariances.""" 105 | assert allclose(self.info_filtered_covs, self.lgssm_moment_posterior.filtered_covariances) 106 | 107 | def test_smoothed_means(self): 108 | """Test smoothed means.""" 109 | assert allclose(self.info_smoothed_means, self.lgssm_moment_posterior.smoothed_means) 110 | 111 | def test_smoothed_covs(self): 112 | """Test smoothed covariances.""" 113 | assert allclose(self.info_smoothed_covs, self.lgssm_moment_posterior.smoothed_covariances) 114 | 115 | def test_marginal_loglik(self): 116 | """Test marginal log likelihood.""" 117 | assert allclose(self.lgssm_info_posterior.marginal_loglik, self.lgssm_moment_posterior.marginal_loglik) 118 | 119 | 120 | class TestInfoKFLinReg: 121 | """Test non-stationary emission matrix in information filter. 122 | 123 | Compare to moment form filter using the example in 124 | `lgssm/demos/kf_linreg.py` 125 | """ 126 | 127 | n_obs = 21 128 | x = jnp.linspace(0, 20, n_obs) 129 | X = jnp.column_stack((jnp.ones_like(x), x)) # Design matrix. (N,2) 130 | state_dim = X.shape[1] # 2 131 | emission_dim = 1 132 | F = jnp.eye(2) 133 | Q = jnp.zeros((2, 2)) # No parameter drift. 134 | Q_prec = jnp.diag(jnp.repeat(1e32, 2)) # Can't use infinite precision. 
135 | obs_var = 1.0 136 | R = jnp.ones((1, 1)) * obs_var 137 | R_prec = jnp.linalg.inv(R) 138 | mu0 = jnp.zeros(2) 139 | Sigma0 = jnp.eye(2) * 10.0 140 | Lambda0 = jnp.linalg.inv(Sigma0) 141 | 142 | # Data from original matlab example 143 | y = jnp.array([ 2.4865, -0.3033, -4.0531, -4.3359, -6.1742, -5.604 , 144 | -3.5069, -2.3257, -4.6377, -0.2327, -1.9858, 1.0284, 145 | -2.264 , -0.4508, 1.1672, 6.6524, 4.1452, 5.2677, 146 | 6.3403, 9.6264, 14.7842]) 147 | inputs = jnp.zeros((len(y), 1)) 148 | input_dim = inputs.shape[1] 149 | 150 | lgssm_moment = ParamsLGSSM( 151 | initial=ParamsLGSSMInitial(mean=mu0,cov=Sigma0), 152 | dynamics=ParamsLGSSMDynamics(weights=F, bias=jnp.zeros(state_dim), input_weights=jnp.zeros((state_dim, input_dim)), cov=Q), 153 | emissions=ParamsLGSSMEmissions(weights=X[:, None, :], bias=jnp.zeros(emission_dim), input_weights=jnp.zeros((emission_dim, input_dim)), cov=R) 154 | ) 155 | 156 | lgssm_info = ParamsLGSSMInfo( 157 | initial_mean=mu0, 158 | initial_precision=Lambda0, 159 | dynamics_weights=F, 160 | dynamics_input_weights=jnp.zeros((mu0.shape[0], 1)), # no inputs 161 | dynamics_bias=jnp.zeros(1), 162 | dynamics_precision=Q_prec, 163 | emission_weights=X[:, None, :], 164 | emission_input_weights=jnp.zeros(1), 165 | emission_bias=jnp.zeros(1), 166 | emission_precision=R_prec, 167 | ) 168 | 169 | lgssm_moment_posterior = lgssm_filter(lgssm_moment, y[:, None], inputs) 170 | lgssm_info_posterior = lgssm_info_filter(lgssm_info, y[:, None], inputs) 171 | 172 | info_filtered_means, info_filtered_covs = info_to_moment_form( 173 | lgssm_info_posterior.filtered_etas, lgssm_info_posterior.filtered_precisions 174 | ) 175 | 176 | def test_filtered_means(self): 177 | """Test filtered means.""" 178 | assert allclose(self.info_filtered_means, self.lgssm_moment_posterior.filtered_means) 179 | 180 | def test_filtered_covs(self): 181 | """Test filtered covariances.""" 182 | assert allclose(self.info_filtered_covs, self.lgssm_moment_posterior.filtered_covariances) 183 | -------------------------------------------------------------------------------- /dynamax/linear_gaussian_ssm/models_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the linear Gaussian SSM models. 3 | """ 4 | from functools import partial 5 | from itertools import count 6 | 7 | import pytest 8 | from jax import vmap 9 | import jax.numpy as jnp 10 | import jax.random as jr 11 | 12 | from dynamax.linear_gaussian_ssm import LinearGaussianSSM 13 | from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM 14 | from dynamax.utils.utils import monotonically_increasing 15 | 16 | NUM_TIMESTEPS = 100 17 | 18 | CONFIGS = [ 19 | (LinearGaussianSSM, dict(state_dim=2, emission_dim=10), None), 20 | (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10), None), 21 | ] 22 | 23 | @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) 24 | def test_sample_and_fit(cls, kwargs, inputs): 25 | """ 26 | Test that the model can sample and fit the data. 
27 | """ 28 | model = cls(**kwargs) 29 | #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) 30 | key1, key2 = jr.split(jr.PRNGKey(0)) 31 | params, param_props = model.initialize(key1) 32 | states, emissions = model.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) 33 | fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3) 34 | assert monotonically_increasing(lps) 35 | fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3) 36 | 37 | def test_fit_blocked_gibbs_batched(): 38 | """ 39 | Test that the blocked Gibbs sampler works for multiple observations. 40 | """ 41 | state_dim = 2 42 | emission_dim = 3 43 | num_timesteps = 4 44 | m_samples = 5 45 | keys = map(jr.PRNGKey, count()) 46 | m_keys = jr.split(next(keys), num=m_samples) 47 | 48 | model = LinearGaussianConjugateSSM(state_dim, emission_dim) 49 | params, _ = model.initialize(next(keys)) 50 | _, y_obs = vmap(partial(model.sample, params, num_timesteps=num_timesteps))(m_keys) 51 | 52 | model.fit_blocked_gibbs(next(keys), params, sample_size=6, emissions=y_obs) -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/README.md: -------------------------------------------------------------------------------- 1 | # Nonlinear Gaussian State Space Models 2 | 3 | We provide a model definition for SSMs with nonlinear dynamics and/or nonlinear observations; 4 | we assume additive Gaussian noise. 5 | A variety of deterministic inference algorithsms (e.g., EKF, UKF, GGF) are supported. 6 | Parameter estimation is not yet supported, since the nonlinear functions are treated as blackboxes; 7 | however, the user can roll their own learning code. 8 | 9 | -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM, NonlinearGaussianSSM 2 | from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_filter 3 | from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother 4 | from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_posterior_sample 5 | from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_filter 6 | from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother 7 | from dynamax.nonlinear_gaussian_ssm.inference_ukf import UKFHyperParams -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the extended Kalman filter and smoother. 
3 | """ 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | 7 | from functools import partial 8 | from dynamax.linear_gaussian_ssm import lgssm_filter, lgssm_smoother, lgssm_posterior_sample 9 | from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_filter, extended_kalman_smoother, extended_kalman_posterior_sample 10 | from dynamax.nonlinear_gaussian_ssm.inference_test_utils import lgssm_to_nlgssm, random_lgssm_args, random_nlgssm_args 11 | from dynamax.nonlinear_gaussian_ssm.sarkka_lib import ekf, eks 12 | from dynamax.utils.utils import has_tpu 13 | from jax import vmap 14 | 15 | if has_tpu(): 16 | # TPU has very poor numerical stability 17 | allclose = partial(jnp.allclose, atol=1e-1) 18 | else: 19 | allclose = partial(jnp.allclose, atol=1e-4) 20 | 21 | 22 | def test_extended_kalman_filter_linear(key=0, num_timesteps=15): 23 | """ 24 | Test that the extended Kalman filter produces the correct filtered moments 25 | in the linear Gaussian case. 26 | """ 27 | args, _, emissions = random_lgssm_args(key=key, num_timesteps=num_timesteps) 28 | 29 | # Run standard Kalman filter 30 | kf_post = lgssm_filter(args, emissions) 31 | # Run extended Kalman filter 32 | ekf_post = extended_kalman_filter(lgssm_to_nlgssm(args), emissions) 33 | 34 | # Compare filter results 35 | assert allclose(kf_post.marginal_loglik, ekf_post.marginal_loglik) 36 | assert allclose(kf_post.filtered_means, ekf_post.filtered_means) 37 | assert allclose(kf_post.filtered_covariances, ekf_post.filtered_covariances) 38 | 39 | 40 | def test_extended_kalman_filter_nonlinear(key=42, num_timesteps=15): 41 | """ 42 | Test that the extended Kalman filter produces the correct filtered moments 43 | by comparing it to the sarkka-jax library. 44 | """ 45 | args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 46 | 47 | # Run EKF from sarkka-jax library 48 | means_ext, covs_ext = ekf(*args, emissions) 49 | # Run EKF from dynamax 50 | ekf_post = extended_kalman_filter(args, emissions) 51 | 52 | # Compare filter results 53 | assert allclose(means_ext, ekf_post.filtered_means) 54 | assert allclose(covs_ext, ekf_post.filtered_covariances) 55 | 56 | 57 | def test_extended_kalman_smoother_linear(key=0, num_timesteps=15): 58 | """ 59 | Test that the extended Kalman smoother produces the correct smoothed moments 60 | in the linear Gaussian case. 61 | """ 62 | args, _, emissions = random_lgssm_args(key=key, num_timesteps=num_timesteps) 63 | 64 | # Run standard Kalman smoother 65 | kf_post = lgssm_smoother(args, emissions) 66 | # Run extended Kalman filter 67 | ekf_post = extended_kalman_smoother(lgssm_to_nlgssm(args), emissions) 68 | 69 | # Compare smoother results 70 | assert allclose(kf_post.smoothed_means, ekf_post.smoothed_means) 71 | assert allclose(kf_post.smoothed_covariances, ekf_post.smoothed_covariances) 72 | 73 | 74 | def extended_kalman_smoother_nonlinear(key=0, num_timesteps=15): 75 | """ 76 | Test that the extended Kalman smoother produces the correct smoothed moments 77 | by comparing it to the sarkka-jax library. 
78 | """ 79 | args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 80 | 81 | # Run EK smoother from sarkka-jax library 82 | means_ext, covs_ext = eks(*args, emissions) 83 | # Run EK smoother from dynamax 84 | ekf_post = extended_kalman_smoother(args, emissions) 85 | 86 | # Compare filter results 87 | assert allclose(means_ext, ekf_post.smoothed_means) 88 | assert allclose(covs_ext, ekf_post.smoothed_covariances) 89 | 90 | 91 | def test_extended_kalman_sampler_linear(key=0, num_timesteps=15): 92 | """ 93 | Test that the extended Kalman sampler produces samples with the correct mean 94 | in the linear Gaussian case. 95 | """ 96 | args, _, emissions = random_lgssm_args(key=key, num_timesteps=num_timesteps) 97 | new_key = jr.split(jr.PRNGKey(key))[1] 98 | 99 | # Run standard Kalman sampler 100 | kf_sample = lgssm_posterior_sample(new_key, args, emissions) 101 | # Run extended Kalman sampler 102 | ekf_sample = extended_kalman_posterior_sample(new_key, lgssm_to_nlgssm(args), emissions) 103 | 104 | # Compare samples 105 | assert allclose(kf_sample, ekf_sample) 106 | 107 | 108 | def test_extended_kalman_sampler_nonlinear(key=0, num_timesteps=15, sample_size=50000): 109 | """ 110 | Test that the extended Kalman sampler produces samples with the correct mean. 111 | """ 112 | # note: empirical covariance needs a large sample_size to converge 113 | 114 | args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 115 | 116 | # Run EK smoother from dynamax 117 | ekf_post = extended_kalman_smoother(args, emissions) 118 | 119 | # Run extended Kalman sampler 120 | sampler = vmap(extended_kalman_posterior_sample, in_axes=(0,None,None)) 121 | keys = jr.split(jr.PRNGKey(key), sample_size) 122 | ekf_samples = sampler(keys, args, emissions) 123 | 124 | # Compare sample moments to smoother output 125 | # Use the posterior variance to compute the variance of the Monte Carlo estimate, 126 | # and check that the differences are within 6 standard deviations. 127 | post_variance = vmap(jnp.diag)(ekf_post.smoothed_covariances) 128 | threshold = 6 * jnp.sqrt(post_variance / sample_size) 129 | empirical_means = ekf_samples.mean(0) 130 | assert jnp.all(abs(empirical_means - ekf_post.smoothed_means) < threshold) 131 | -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/inference_test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains utility functions that are used to test EKF, UKF and GGF inference, 3 | by comparing the results to the sarkka-lib codebase on some toy problems. 
4 | """ 5 | import jax.random as jr 6 | import jax.numpy as jnp 7 | 8 | from jaxtyping import Array, Float 9 | from typing import Tuple, Union 10 | import tensorflow_probability.substrates.jax as tfp 11 | 12 | from dynamax.linear_gaussian_ssm import LinearGaussianSSM 13 | from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions 14 | from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM, NonlinearGaussianSSM 15 | from dynamax.parameters import ParameterProperties 16 | from dynamax.ssm import SSM 17 | from dynamax.utils.bijectors import RealToPSDBijector 18 | from dynamax.types import PRNGKeyT 19 | 20 | 21 | tfd = tfp.distributions 22 | tfb = tfp.bijectors 23 | MVN = tfd.MultivariateNormalFullCovariance 24 | 25 | 26 | def lgssm_to_nlgssm(params: ParamsLGSSM) -> ParamsNLGSSM: 27 | """Generates NonLinearGaussianSSM params from LinearGaussianSSM params""" 28 | nlgssm_params = ParamsNLGSSM( 29 | initial_mean=params.initial.mean, 30 | initial_covariance=params.initial.cov, 31 | dynamics_function=lambda x: params.dynamics.weights @ x + params.dynamics.bias, 32 | dynamics_covariance=params.dynamics.cov, 33 | emission_function=lambda x: params.emissions.weights @ x + params.emissions.bias, 34 | emission_covariance=params.emissions.cov, 35 | ) 36 | return nlgssm_params 37 | 38 | 39 | def random_lgssm_args(key: Union[int, PRNGKeyT] = 0, 40 | num_timesteps: int = 15, 41 | state_dim: int = 4, 42 | emission_dim: int = 2) -> \ 43 | Tuple[ParamsLGSSM, Float[Array, "ntime state_dim"], 44 | Float[Array, "ntime emission_dim"]]: 45 | """ 46 | Generates random LGSSM parameters, states and emissions. 47 | """ 48 | if isinstance(key, int): 49 | key = jr.PRNGKey(key) 50 | *keys, subkey = jr.split(key, 9) 51 | 52 | # Generate random parameters 53 | initial_mean = jr.normal(keys[0], (state_dim,)) 54 | initial_covariance = jnp.eye(state_dim) * jr.uniform(keys[1]) 55 | dynamics_covariance = jnp.eye(state_dim) * jr.uniform(keys[2]) 56 | emission_covariance = jnp.eye(emission_dim) * jr.uniform(keys[3]) 57 | 58 | params = ParamsLGSSM( 59 | initial=ParamsLGSSMInitial( 60 | mean=initial_mean, 61 | cov=initial_covariance 62 | ), 63 | dynamics=ParamsLGSSMDynamics( 64 | weights=jr.normal(keys[4], (state_dim, state_dim)), 65 | bias=jr.normal(keys[5], (state_dim,)), 66 | input_weights=jnp.zeros((state_dim, 0)), 67 | cov=dynamics_covariance 68 | ), 69 | emissions=ParamsLGSSMEmissions( 70 | weights=jr.normal(keys[6], (emission_dim, state_dim)), 71 | bias=jr.normal(keys[7], (emission_dim,)), 72 | input_weights=jnp.zeros((emission_dim, 0)), 73 | cov=emission_covariance, 74 | ) 75 | ) 76 | 77 | # Generate random samples 78 | model = LinearGaussianSSM(state_dim, emission_dim) 79 | key, subkey = jr.split(subkey, 2) 80 | states, emissions = model.sample(params, key, num_timesteps) 81 | return params, states, emissions 82 | 83 | 84 | def to_poly(state, degree): 85 | """ 86 | Returns the polynomial features of the state up to the given degree. 87 | """ 88 | return jnp.concatenate([state**d for d in jnp.arange(degree+1)]) 89 | 90 | def make_nlgssm_params(state_dim, 91 | emission_dim, 92 | dynamics_degree=1, 93 | emission_degree=1, 94 | key=jr.PRNGKey(0)): 95 | """ 96 | Generates random NLGSSM parameters. 
97 | """ 98 | dynamics_weights = jr.normal(key, (state_dim, state_dim * (dynamics_degree+1))) 99 | f = lambda z: jnp.sin(dynamics_weights @ to_poly(z, dynamics_degree)) 100 | emission_weights = jr.normal(key, (emission_dim, state_dim * (emission_degree+1))) 101 | h = lambda z: jnp.cos(emission_weights @ to_poly(z, emission_degree)) 102 | params = ParamsNLGSSM( 103 | initial_mean = 0.2 * jnp.ones(state_dim), 104 | initial_covariance = jnp.eye(state_dim), 105 | dynamics_function = f, 106 | dynamics_covariance = jnp.eye(state_dim), 107 | emission_function = h, 108 | emission_covariance = jnp.eye(emission_dim) 109 | ) 110 | return params 111 | 112 | class SimpleNonlinearSSM(SSM): 113 | """ 114 | Simple nonlinear SSM with sinusoidal dynamics and cosine emissions. 115 | """ 116 | def __init__(self, state_dim, emission_dim, dynamics_degree=1, emission_degree=1): 117 | self.state_dim = state_dim 118 | self.emission_dim = emission_dim 119 | self.dynamics_degree = dynamics_degree 120 | self.emission_degree = emission_degree 121 | 122 | @property 123 | def emission_shape(self): 124 | """Returns the shape of the emission distribution.""" 125 | return (self.emission_dim,) 126 | 127 | def initial_distribution(self, params, covariates=None): 128 | """Returns the initial distribution.""" 129 | return MVN(params["initial"]["mean"], params["initial"]["cov"]) 130 | 131 | def transition_distribution(self, params, state, covariates=None): 132 | """Returns the nonlinear dynamics function.""" 133 | x = to_poly(state, self.dynamics_degree) 134 | mean = jnp.sin(params["dynamics"]["weights"] @ x) 135 | return MVN(mean, params["dynamics"]["cov"]) 136 | 137 | def emission_distribution(self, params, state, covariates=None): 138 | """Returns the nonlinear emission function.""" 139 | x = to_poly(state, self.emission_degree) 140 | mean = jnp.cos(params["emissions"]["weights"] @ x) 141 | return MVN(mean, params["emissions"]["cov"]) 142 | 143 | def initialize(self, key): 144 | """Initializes the parameters.""" 145 | key1, key2 = jr.split(key) 146 | params = dict( 147 | initial=dict(mean=0.2 * jnp.ones(self.state_dim), cov=jnp.eye(self.state_dim)), 148 | dynamics=dict(weights=jr.normal(key1, (self.state_dim, self.state_dim * (self.dynamics_degree+1))), 149 | cov=jnp.eye(self.state_dim)), 150 | emissions=dict(weights=jr.normal(key2, (self.emission_dim, self.state_dim * (self.emission_degree+1))), 151 | cov=jnp.eye(self.emission_dim)), 152 | ) 153 | 154 | param_props = dict( 155 | initial=dict(mean=ParameterProperties(), 156 | cov=ParameterProperties(constrainer=RealToPSDBijector())), 157 | dynamics=dict(weights=ParameterProperties(), 158 | cov=ParameterProperties(constrainer=RealToPSDBijector())), 159 | emissions=dict(weights=ParameterProperties(), 160 | cov=ParameterProperties(constrainer=RealToPSDBijector())), 161 | ) 162 | return params, param_props 163 | 164 | def _make_inference_args(self, params): 165 | """Returns the inference arguments.""" 166 | f = lambda state: jnp.sin(params["dynamics"]["weights"] @ to_poly(state, self.dynamics_degree)) 167 | h = lambda state: jnp.cos(params["emissions"]["weights"] @ to_poly(state, self.emission_degree)) 168 | return ParamsNLGSSM( 169 | initial_mean=params["initial"]["mean"], 170 | initial_covariance=params["initial"]["cov"], 171 | dynamics_function=f, 172 | dynamics_covariance=params["dynamics"]["cov"], 173 | emission_function=h, 174 | emission_covariance=params["emissions"]["cov"]) 175 | 176 | 177 | def random_nlgssm_args(key=0, num_timesteps=15, state_dim=4, emission_dim=2): 
178 | """ 179 | Generates random NLGSSM parameters, states and emissions. 180 | """ 181 | if isinstance(key, int): 182 | key = jr.PRNGKey(key) 183 | init_key, sample_key = jr.split(key, 2) 184 | params = make_nlgssm_params(state_dim, emission_dim, key=init_key) 185 | model = NonlinearGaussianSSM(state_dim, emission_dim) 186 | states, emissions = model.sample(params, sample_key, num_timesteps) 187 | return params, states, emissions 188 | 189 | -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the unscented Kalman filter and smoother. 3 | """ 4 | import jax.numpy as jnp 5 | 6 | from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother, UKFHyperParams 7 | from dynamax.nonlinear_gaussian_ssm.sarkka_lib import ukf, uks 8 | from dynamax.nonlinear_gaussian_ssm.inference_test_utils import random_nlgssm_args 9 | from dynamax.utils.utils import has_tpu 10 | from functools import partial 11 | 12 | if has_tpu(): 13 | allclose = partial(jnp.allclose, atol=1e-1) 14 | else: 15 | allclose = partial(jnp.allclose, atol=1e-4) 16 | 17 | def test_ukf_nonlinear(key=0, num_timesteps=15): 18 | """ 19 | Test that the unscented Kalman filter produces the correct filtered and smoothed moments 20 | by comparing it to the sarkka-jax library. 21 | """ 22 | nlgssm_args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps) 23 | hyperparams = UKFHyperParams() 24 | 25 | # Run UKF from sarkka-jax library 26 | means_ukf, covs_ukf = ukf(*nlgssm_args, *hyperparams, emissions) 27 | # Run UKS from sarkka-jax library 28 | means_uks, covs_uks = uks(*nlgssm_args, *hyperparams, emissions) 29 | # Run UKS from dynamax 30 | uks_post = unscented_kalman_smoother(nlgssm_args, emissions, hyperparams) 31 | 32 | # Compare filter results 33 | assert allclose(means_ukf, uks_post.filtered_means) 34 | assert allclose(covs_ukf, uks_post.filtered_covariances) 35 | assert allclose(means_uks, uks_post.smoothed_means) 36 | assert allclose(covs_uks, uks_post.smoothed_covariances) -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Nonlinear Gaussian State Space Model objects. 3 | """ 4 | import tensorflow_probability.substrates.jax as tfp 5 | import tensorflow_probability.substrates.jax.distributions as tfd 6 | 7 | from dynamax.ssm import SSM 8 | from jaxtyping import Array, Float 9 | from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN 10 | from typing import NamedTuple, Optional, Union, Callable 11 | 12 | 13 | FnStateToState = Callable[ [Float[Array, " state_dim"]], Float[Array, " state_dim"]] 14 | FnStateAndInputToState = Callable[ [Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, " state_dim"]] 15 | FnStateToEmission = Callable[ [Float[Array, " state_dim"]], Float[Array, " emission_dim"]] 16 | FnStateAndInputToEmission = Callable[ [Float[Array, " state_dim"], Float[Array, " input_dim"] ], Float[Array, " emission_dim"]] 17 | 18 | 19 | class ParamsNLGSSM(NamedTuple): 20 | """Parameters for a NLGSSM model. 
21 | 22 | $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$ 23 | $$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$ 24 | $$p(z_1) = N(z_1 | m, S)$$ 25 | 26 | If you have no inputs, the dynamics and emission functions do not need to take $u_t$ as an argument. 27 | 28 | :param dynamics_function: $f$ 29 | :param dynamics_covariance: $Q$ 30 | :param emission_function: $h$ 31 | :param emission_covariance: $R$ 32 | :param initial_mean: $m$ 33 | :param initial_covariance: $S$ 34 | 35 | """ 36 | 37 | initial_mean: Float[Array, " state_dim"] 38 | initial_covariance: Float[Array, "state_dim state_dim"] 39 | dynamics_function: Union[FnStateToState, FnStateAndInputToState] 40 | dynamics_covariance: Float[Array, "state_dim state_dim"] 41 | emission_function: Union[FnStateToEmission, FnStateAndInputToEmission] 42 | emission_covariance: Float[Array, "emission_dim emission_dim"] 43 | 44 | 45 | class NonlinearGaussianSSM(SSM): 46 | """ 47 | Nonlinear Gaussian State Space Model. 48 | 49 | The model is defined as follows: 50 | 51 | $$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$ 52 | $$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$ 53 | $$p(z_1) = N(z_1 | m, S)$$ 54 | 55 | where the model parameters are 56 | 57 | * $z_t$ = hidden variables of size `state_dim`, 58 | * $y_t$ = observed variables of size `emission_dim` 59 | * $u_t$ = input covariates of size `input_dim` (defaults to 0). 60 | * $f$ = dynamics (transition) function 61 | * $h$ = emission (observation) function 62 | * $Q$ = covariance matrix of dynamics (system) noise 63 | * $R$ = covariance matrix for emission (observation) noise 64 | * $m$ = mean of initial state 65 | * $S$ = covariance matrix of initial state 66 | 67 | 68 | These parameters of the model are stored in a separate object of type :class:`ParamsNLGSSM`.
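To make the usage concrete, here is a minimal sketch (the dimensions and the particular nonlinear functions below are illustrative, not part of the library) of building a `ParamsNLGSSM` and running the EKF routines exported from this package:

```python
import jax.numpy as jnp
import jax.random as jr
from dynamax.nonlinear_gaussian_ssm import (
    ParamsNLGSSM, NonlinearGaussianSSM, extended_kalman_filter, extended_kalman_smoother)

state_dim, emission_dim = 2, 1
params = ParamsNLGSSM(
    initial_mean=jnp.zeros(state_dim),
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=lambda z: z + 0.1 * jnp.sin(z),   # f (no inputs)
    dynamics_covariance=0.1 * jnp.eye(state_dim),        # Q
    emission_function=lambda z: z[:1] ** 2,              # h (no inputs)
    emission_covariance=0.5 * jnp.eye(emission_dim),     # R
)

# Sample synthetic data from the model, then run the EKF and EKS.
model = NonlinearGaussianSSM(state_dim, emission_dim)
states, emissions = model.sample(params, jr.PRNGKey(0), num_timesteps=100)
filtered = extended_kalman_filter(params, emissions)    # filtered_means, filtered_covariances, marginal_loglik
smoothed = extended_kalman_smoother(params, emissions)  # also smoothed_means, smoothed_covariances
```

Since parameter estimation is not built in, one option is to parameterize $f$ or $h$, differentiate `extended_kalman_filter(params, emissions).marginal_loglik` with `jax.grad`, and optimize with a library such as optax; this is left to the user.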
69 | """ 70 | def __init__(self, state_dim: int, emission_dim: int, input_dim: int = 0): 71 | self.state_dim = state_dim 72 | self.emission_dim = emission_dim 73 | self.input_dim = 0 74 | 75 | @property 76 | def emission_shape(self): 77 | """Returns the shape of the emission distribution.""" 78 | return (self.emission_dim,) 79 | 80 | @property 81 | def inputs_shape(self): 82 | """Returns the shape of the input distribution.""" 83 | return (self.input_dim,) if self.input_dim > 0 else None 84 | 85 | def initial_distribution(self, 86 | params: ParamsNLGSSM, 87 | inputs: Optional[Float[Array, " input_dim"]] = None) \ 88 | -> tfd.Distribution: 89 | """Returns the initial distribution.""" 90 | return MVN(params.initial_mean, params.initial_covariance) 91 | 92 | def transition_distribution(self, 93 | params: ParamsNLGSSM, 94 | state: Float[Array, " state_dim"], 95 | inputs: Optional[Float[Array, " input_dim"]] = None) \ 96 | -> tfd.Distribution: 97 | """Returns the nonlinear dynamics distribution.""" 98 | f = params.dynamics_function 99 | if inputs is None: 100 | mean = f(state) 101 | else: 102 | mean = f(state, inputs) 103 | return MVN(mean, params.dynamics_covariance) 104 | 105 | def emission_distribution(self, 106 | params: ParamsNLGSSM, 107 | state: Float[Array, " state_dim"], 108 | inputs: Float[Array, " input_dim"] = None) \ 109 | -> tfd.Distribution: 110 | """Returns the nonlinear emission distribution.""" 111 | h = params.emission_function 112 | if inputs is None: 113 | mean = h(state) 114 | else: 115 | mean = h(state, inputs) 116 | return MVN(mean, params.emission_covariance) 117 | -------------------------------------------------------------------------------- /dynamax/nonlinear_gaussian_ssm/sarkka_lib.py: -------------------------------------------------------------------------------- 1 | """ 2 | External implementations of nlgssm algorithms to use for unit test. 
3 | Taken from https://github.com/petergchang/sarkka-jax 4 | Based on Simo Särkkä (2013), “Bayesian Filtering and Smoothing,” 5 | Available: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf 6 | """ 7 | 8 | import jax.numpy as jnp 9 | from jax import vmap 10 | from jax import lax 11 | from jax import jacfwd 12 | 13 | 14 | def ekf(m_0, P_0, f, Q, h, R, Y): 15 | """ 16 | First-order additive EKF (Sarkka Algorithm 5.4) 17 | """ 18 | num_timesteps = len(Y) 19 | # Compute Jacobians 20 | F, H = jacfwd(f), jacfwd(h) 21 | 22 | def _step(carry, t): 23 | """One step of EKF""" 24 | m_k, P_k = carry 25 | 26 | # Update 27 | v = Y[t] - h(m_k) 28 | S = jnp.atleast_2d(H(m_k) @ P_k @ H(m_k).T + R) 29 | K = P_k @ H(m_k).T @ jnp.linalg.inv(S) 30 | m_post = m_k + K @ v 31 | P_post = P_k - K @ S @ K.T 32 | 33 | # Prediction step 34 | m_pred = f(m_post) 35 | P_pred = F(m_post) @ P_post @ F(m_post).T + Q 36 | 37 | return (m_pred, P_pred), (m_post, P_post) 38 | 39 | carry = (m_0, P_0) 40 | _, (ms, Ps) = lax.scan(_step, carry, jnp.arange(num_timesteps)) 41 | return ms, Ps 42 | 43 | 44 | def eks(m_0, P_0, f, Q, h, R, Y): 45 | """ 46 | First-order additive EK smoother 47 | """ 48 | num_timesteps = len(Y) 49 | 50 | # Run ekf 51 | m_post, P_post = ekf(m_0, P_0, f, Q, h, R, Y) 52 | 53 | # Compute Jacobians 54 | F, H = jacfwd(f), jacfwd(h) 55 | 56 | def _step(carry, t): 57 | """One step of EKS""" 58 | m_k, P_k = carry 59 | 60 | # Prediction step 61 | m_pred = f(m_post[t]) 62 | P_pred = F(m_post[t]) @ P_post[t] @ F(m_post[t]).T + Q 63 | G = P_post[t] @ F(m_post[t]).T @ jnp.linalg.inv(P_pred) 64 | 65 | # Update step 66 | m_sm = m_post[t] + G @ (m_k - m_pred) 67 | P_sm = P_post[t] + G @ (P_k - P_pred) @ G.T 68 | 69 | return (m_sm, P_sm), (m_sm, P_sm) 70 | 71 | carry = (m_post[-1], P_post[-1]) 72 | _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) 73 | m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) 74 | P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) 75 | 76 | return m_sm, P_sm 77 | 78 | 79 | def slf_additive(m_0, P_0, f, Q, h, R, Ef, Efdx, Eh, Ehdx, Y): 80 | """ 81 | Additive SLF with closed-form expectations (Sarkka Algorithm 5.10) 82 | """ 83 | num_timesteps = len(Y) 84 | 85 | def _step(carry, t): 86 | """One step of SLF""" 87 | m_k, P_k = carry 88 | 89 | # Update step 90 | v = Y[t] - Eh(m_k, P_k) 91 | S = jnp.atleast_2d(Ehdx(m_k, P_k) @ jnp.linalg.inv(P_k) @ Ehdx(m_k, P_k).T + R) 92 | K = Ehdx(m_k, P_k).T @ jnp.linalg.inv(S) 93 | m_post = m_k + K @ v 94 | P_post = P_k - K @ S @ K.T 95 | 96 | # Prediction step 97 | m_pred = Ef(m_post, P_post) 98 | P_pred = Efdx(m_post, P_post) @ jnp.linalg.inv(P_post) @ Efdx(m_post, P_post).T + Q 99 | 100 | return (m_pred, P_pred), (m_post, P_post) 101 | 102 | carry = (m_0, P_0) 103 | _, (ms, Ps) = lax.scan(_step, carry, jnp.arange(num_timesteps)) 104 | return ms, Ps 105 | 106 | 107 | def ukf(m_0, P_0, f, Q, h, R, alpha, beta, kappa, Y): 108 | """ 109 | Additive UKF (Sarkka Algorithm 5.14) 110 | """ 111 | num_timesteps, n = len(Y), P_0.shape[0] 112 | lamb = alpha**2 * (n + kappa) - n 113 | 114 | # Compute weights for mean and covariance estimates 115 | def compute_weights(n, alpha, beta, lamb): 116 | """Compute weights for UKF""" 117 | factor = 1 / (2 * (n + lamb)) 118 | w_mean = jnp.concatenate((jnp.array([lamb / (n + lamb)]), jnp.ones(2 * n) * factor)) 119 | w_cov = jnp.concatenate((jnp.array([lamb / (n + lamb) + (1 - alpha**2 + beta)]), jnp.ones(2 * n) * factor)) 120 | return w_mean, w_cov 121 | 122 | w_mean, w_cov = 
compute_weights(n, alpha, beta, lamb) 123 | 124 | def _step(carry, t): 125 | """One step of UKF""" 126 | m_k, P_k = carry 127 | 128 | # Update step: 129 | # 1. Form sigma points 130 | sigmas_update = compute_sigmas(m_k, P_k, n, lamb) 131 | # 2. Propagate the sigma points 132 | sigmas_update_prop = vmap(h, 0, 0)(sigmas_update) 133 | # 3. Compute params 134 | mu = jnp.tensordot(w_mean, sigmas_update_prop, axes=1) 135 | outer = lambda x, y: jnp.atleast_2d(x).T @ jnp.atleast_2d(y) 136 | outer = vmap(outer, 0, 0) 137 | S = jnp.tensordot(w_cov, outer(sigmas_update_prop - mu, sigmas_update_prop - mu), axes=1) + R 138 | C = jnp.tensordot(w_cov, outer(sigmas_update - m_k, sigmas_update_prop - mu), axes=1) 139 | # 4. Compute posterior 140 | K = C @ jnp.linalg.inv(S) 141 | m_post = m_k + K @ (Y[t] - mu) 142 | P_post = P_k - K @ S @ K.T 143 | 144 | # Prediction step: 145 | # 1. Form sigma points 146 | sigmas_pred = compute_sigmas(m_post, P_post, n, lamb) 147 | # 2. Propagate the sigma points 148 | sigmas_pred = vmap(f, 0, 0)(sigmas_pred) 149 | # 3. Compute predicted mean and covariance 150 | m_pred = jnp.tensordot(w_mean, sigmas_pred, axes=1) 151 | P_pred = jnp.tensordot(w_cov, outer(sigmas_pred - m_pred, sigmas_pred - m_pred), axes=1) + Q 152 | 153 | return (m_pred, P_pred), (m_post, P_post) 154 | 155 | # Find 2n+1 sigma points 156 | def compute_sigmas(m, P, n, lamb): 157 | """Compute sigma points""" 158 | disc = jnp.sqrt(n + lamb) * jnp.linalg.cholesky(P) 159 | sigma_plus = jnp.array([m + disc[:, i] for i in range(n)]) 160 | sigma_minus = jnp.array([m - disc[:, i] for i in range(n)]) 161 | return jnp.concatenate((jnp.array([m]), sigma_plus, sigma_minus)) 162 | 163 | carry = (m_0, P_0) 164 | _, (ms, Ps) = lax.scan(_step, carry, jnp.arange(num_timesteps)) 165 | return ms, Ps 166 | 167 | 168 | def uks(m_0, P_0, f, Q, h, R, alpha, beta, kappa, Y): 169 | """ 170 | First-order additive UKS 171 | """ 172 | num_timesteps, n = len(Y), P_0.shape[0] 173 | lamb = alpha**2 * (n + kappa) - n 174 | 175 | # Compute weights for mean and covariance estimates 176 | def compute_weights(n, alpha, beta, lamb): 177 | """Compute weights for UKS""" 178 | factor = 1 / (2 * (n + lamb)) 179 | w_mean = jnp.concatenate((jnp.array([lamb / (n + lamb)]), jnp.ones(2 * n) * factor)) 180 | w_cov = jnp.concatenate((jnp.array([lamb / (n + lamb) + (1 - alpha**2 + beta)]), jnp.ones(2 * n) * factor)) 181 | return w_mean, w_cov 182 | 183 | w_mean, w_cov = compute_weights(n, alpha, beta, lamb) 184 | 185 | # Run ukf 186 | m_post, P_post = ukf(m_0, P_0, f, Q, h, R, alpha, beta, kappa, Y) 187 | 188 | def _step(carry, t): 189 | """One step of UKS""" 190 | m_k, P_k = carry 191 | m_p, P_p = m_post[t], P_post[t] 192 | 193 | # Prediction step 194 | sigmas_pred = compute_sigmas(m_p, P_p, n, lamb) 195 | sigmas_pred_prop = vmap(f, 0, 0)(sigmas_pred) 196 | m_pred = jnp.tensordot(w_mean, sigmas_pred_prop, axes=1) 197 | outer = lambda x, y: jnp.atleast_2d(x).T @ jnp.atleast_2d(y) 198 | outer = vmap(outer, 0, 0) 199 | P_pred = jnp.tensordot(w_cov, outer(sigmas_pred_prop - m_pred, sigmas_pred_prop - m_pred), axes=1) + Q 200 | P_cross = jnp.tensordot(w_cov, outer(sigmas_pred - m_p, sigmas_pred_prop - m_pred), axes=1) 201 | G = P_cross @ jnp.linalg.inv(P_pred) 202 | 203 | # Update step 204 | m_sm = m_p + G @ (m_k - m_pred) 205 | P_sm = P_p + G @ (P_k - P_pred) @ G.T 206 | 207 | return (m_sm, P_sm), (m_sm, P_sm) 208 | 209 | # Find 2n+1 sigma points 210 | def compute_sigmas(m, P, n, lamb): 211 | """Compute sigma points""" 212 | disc = jnp.sqrt(n + lamb) * 
jnp.linalg.cholesky(P) 213 | sigma_plus = jnp.array([m + disc[:, i] for i in range(n)]) 214 | sigma_minus = jnp.array([m - disc[:, i] for i in range(n)]) 215 | return jnp.concatenate((jnp.array([m]), sigma_plus, sigma_minus)) 216 | 217 | carry = (m_post[-1], P_post[-1]) 218 | _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) 219 | m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) 220 | P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) 221 | 222 | return m_sm, P_sm 223 | -------------------------------------------------------------------------------- /dynamax/parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for managing parameters and their properties as PyTrees. 3 | """ 4 | import jax.numpy as jnp 5 | from jax import lax 6 | from jax.tree_util import tree_reduce, tree_map, register_pytree_node_class 7 | import tensorflow_probability.substrates.jax.bijectors as tfb 8 | from typing import Optional, runtime_checkable 9 | from typing_extensions import Protocol 10 | 11 | from dynamax.types import Scalar 12 | 13 | @runtime_checkable 14 | class ParameterSet(Protocol): 15 | """A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes. 16 | 17 | """ 18 | pass 19 | 20 | @runtime_checkable 21 | class PropertySet(Protocol): 22 | """A matching :class:`NamedTuple` with :class:`ParameterProperties` stored in the leaf nodes. 23 | 24 | """ 25 | pass 26 | 27 | 28 | @register_pytree_node_class 29 | class ParameterProperties: 30 | """A PyTree containing parameter metadata (properties). 31 | 32 | Note: the properties are stored in the aux_data of this PyTree so that 33 | changes will trigger recompilation of functions that rely on them. 34 | 35 | Args: 36 | trainable (bool): flag specifying whether or not to fit this parameter is adjustable. 37 | constrainer (Optional tfb.Bijector): bijector mapping to constrained form. 38 | 39 | """ 40 | def __init__(self, 41 | trainable: bool = True, 42 | constrainer: Optional[tfb.Bijector] = None) -> None: 43 | self.trainable = trainable 44 | self.constrainer = constrainer 45 | 46 | def tree_flatten(self): 47 | """Flatten the PyTree into a tuple of aux_data and children.""" 48 | return (), (self.trainable, self.constrainer) 49 | 50 | @classmethod 51 | def tree_unflatten(cls, aux_data, children): 52 | """Reconstruct the PyTree from the tuple of aux_data and children.""" 53 | return cls(*aux_data) 54 | 55 | def __repr__(self): 56 | """Return a string representation of the PyTree.""" 57 | return f"ParameterProperties(trainable={self.trainable}, constrainer={self.constrainer})" 58 | 59 | 60 | def to_unconstrained(params: ParameterSet, props: PropertySet) -> ParameterSet: 61 | """Convert the constrained parameters to unconstrained form. 62 | 63 | Args: 64 | params: (nested) named tuple whose leaf values are DeviceArrays containing 65 | parameter values. 66 | props: matching named tuple whose leaf values are ParameterProperties, 67 | containing an optional bijector that converts to unconstrained form, 68 | and a boolean flag specifying if the parameter is trainable or not. 69 | 70 | Returns: 71 | unc_params: named tuple containing parameters in unconstrained form. 
72 | 73 | """ 74 | to_unc = lambda value, prop: prop.constrainer.inverse(value) \ 75 | if prop.constrainer is not None else value 76 | is_leaf = lambda node: isinstance(node, (ParameterProperties,)) 77 | return tree_map(to_unc, params, props, is_leaf=is_leaf) 78 | 79 | 80 | def from_unconstrained(unc_params: ParameterSet, props: PropertySet) -> ParameterSet: 81 | """Convert the unconstrained parameters to constrained form. 82 | 83 | Args: 84 | unc_params: (nested) named tuple whose leaf values are DeviceArrays containing 85 | unconstrained parameter values. 86 | props: matching named tuple whose leaf values are ParameterProperties, 87 | containing an optional bijector that converts to unconstrained form, 88 | and a boolean flag specifying if the parameter is trainable or not. 89 | 90 | Returns: 91 | params: named tuple containing parameters in constrained form. 92 | If a parameter is marked with trainable=False (frozen) in the properties structure, 93 | it will be tagged with a "stop gradient". Thus the gradient of any loss function computed 94 | using these frozen constrained parameters will be zero. 95 | 96 | """ 97 | def from_unc(unc_value, prop): 98 | """Convert the unconstrained value to constrained form.""" 99 | value = prop.constrainer(unc_value) if prop.constrainer is not None else unc_value 100 | value = lax.stop_gradient(value) if not prop.trainable else value 101 | return value 102 | 103 | is_leaf = lambda node: isinstance(node, (ParameterProperties,)) 104 | return tree_map(from_unc, unc_params, props, is_leaf=is_leaf) 105 | 106 | 107 | def log_det_jac_constrain(params: ParameterSet, props: PropertySet) -> Scalar: 108 | """Log determinant of the Jacobian matrix evaluated at the unconstrained parameters. 109 | 110 | Let x be the unconstrained parameter and f(x) be the constrained parameter, so 111 | that in code, `props.constrainer` is the Bijector f. To perform Hamiltonian 112 | Monte Carlo (HMC) on the unconstrained parameters we need the log determinant of 113 | the forward Jacobian, |df(x) / dx|. In math, this falls out as follows: 114 | 115 | ..math: 116 | log p(x) = log p(f(x)) + log |df(x) / dx| 117 | 118 | Args: 119 | params: PyTree whose leaf values are DeviceArrays 120 | props: matching PyTree whose leaf values are ParameterProperties 121 | 122 | Returns: 123 | logdet: the log determinant of the forward Jacobian. 
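For example, here is a sketch of how this correction typically enters an HMC target density over the unconstrained parameters (`log_prior` and `log_likelihood` are hypothetical placeholders for model-specific terms, not dynamax functions):

```python
def unconstrained_log_prob(unc_params, props, emissions):
    # Hypothetical HMC target: log density of the unconstrained parameters.
    params = from_unconstrained(unc_params, props)
    return (log_prior(params)                      # placeholder: prior over constrained params
            + log_likelihood(params, emissions)    # placeholder: data log likelihood
            + log_det_jac_constrain(params, props))
```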
124 | """ 125 | unc_params = to_unconstrained(params, props) 126 | def _compute_logdet(unc_value, prop): 127 | """Compute the log determinant of the Jacobian matrix.""" 128 | logdet = prop.constrainer.forward_log_det_jacobian(unc_value).sum() \ 129 | if prop.constrainer is not None else 0.0 130 | return logdet if prop.trainable else 0.0 131 | 132 | is_leaf = lambda node: isinstance(node, (ParameterProperties,)) 133 | logdets = tree_map(_compute_logdet, unc_params, props, is_leaf=is_leaf) 134 | return tree_reduce(jnp.add, logdets, 0.0) 135 | -------------------------------------------------------------------------------- /dynamax/parameters_test.py: -------------------------------------------------------------------------------- 1 | """Tests for dynamax.parameters module""" 2 | import copy 3 | import jax.numpy as jnp 4 | import optax 5 | import tensorflow_probability.substrates.jax.bijectors as tfb 6 | 7 | from dynamax.parameters import ParameterProperties, to_unconstrained, from_unconstrained, log_det_jac_constrain 8 | from jax import jit, value_and_grad, lax 9 | from jax.tree_util import tree_map, tree_leaves 10 | from jaxtyping import Float, Array 11 | from typing import NamedTuple, Union 12 | 13 | 14 | class InitialParams(NamedTuple): 15 | """Dummy Initial state distribution parameters""" 16 | probs: Union[Float[Array, " state_dim"], ParameterProperties] 17 | 18 | class TransitionsParams(NamedTuple): 19 | """Dummy Transition matrix parameters""" 20 | transition_matrix: Union[Float[Array, "state_dim state_dim"], ParameterProperties] 21 | 22 | class EmissionsParams(NamedTuple): 23 | """Dummy Emission distribution parameters""" 24 | means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties] 25 | scales: Union[Float[Array, "state_dim emission_dim"], ParameterProperties] 26 | 27 | class Params(NamedTuple): 28 | """Dummy SSM parameters""" 29 | initial: InitialParams 30 | transitions: TransitionsParams 31 | emissions: EmissionsParams 32 | 33 | 34 | def make_params(): 35 | """Create a dummy set of parameters and properties""" 36 | params = Params( 37 | initial=InitialParams(probs=jnp.ones(3) / 3.0), 38 | transitions=TransitionsParams(transition_matrix=0.9 * jnp.eye(3) + 0.1 * jnp.ones((3, 3)) / 3), 39 | emissions=EmissionsParams(means=jnp.zeros((3, 2)), scales=jnp.ones((3, 2))) 40 | ) 41 | 42 | props = Params( 43 | initial=InitialParams(probs=ParameterProperties(trainable=False, constrainer=tfb.SoftmaxCentered())), 44 | transitions=TransitionsParams(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered())), 45 | emissions=EmissionsParams(means=ParameterProperties(), scales=ParameterProperties(constrainer=tfb.Softplus(), trainable=False)) 46 | ) 47 | return params, props 48 | 49 | 50 | def test_parameter_tofrom_unconstrained(): 51 | """Test that to_unconstrained and from_unconstrained are inverses""" 52 | params, props = make_params() 53 | unc_params = to_unconstrained(params, props) 54 | recon_params = from_unconstrained(unc_params, props) 55 | assert all(tree_leaves(tree_map(jnp.allclose, params, recon_params))) 56 | 57 | 58 | def test_parameter_pytree_jittable(): 59 | """Test that the parameter PyTree is jittable""" 60 | # If there's a problem with our PyTree registration, this should catch it. 
61 | params, props = make_params() 62 | 63 | @jit 64 | def get_trainable(params, props): 65 | """Return a PyTree of trainable parameters""" 66 | return tree_map(lambda node, prop: node if prop.trainable else None, 67 | params, props, 68 | is_leaf=lambda node: isinstance(node, ParameterProperties)) 69 | 70 | # first call, jit 71 | get_trainable(params, props) 72 | assert get_trainable._cache_size() == 1 73 | 74 | # change param values, don't jit 75 | params = params._replace(initial=params.initial._replace(probs=jnp.zeros(3))) 76 | get_trainable(params, props) 77 | assert get_trainable._cache_size() == 1 78 | 79 | # change param dtype, jit 80 | params = params._replace(initial=params.initial._replace(probs=jnp.zeros(3, dtype=int))) 81 | get_trainable(params, props) 82 | assert get_trainable._cache_size() == 2 83 | 84 | # change props, jit 85 | props.transitions.transition_matrix.trainable = False 86 | get_trainable(params, props) 87 | assert get_trainable._cache_size() == 3 88 | 89 | 90 | def test_parameter_constrained(): 91 | """Test that only trainable params are updated in gradient descent. 92 | """ 93 | params, props = make_params() 94 | original_params = copy.deepcopy(params) 95 | 96 | unc_params = to_unconstrained(params, props) 97 | def loss(unc_params): 98 | """Dummy loss function""" 99 | params = from_unconstrained(unc_params, props) 100 | log_initial_probs = jnp.log(params.initial.probs) 101 | log_transition_matrix = jnp.log(params.transitions.transition_matrix) 102 | means = params.emissions.means 103 | scales = params.emissions.scales 104 | 105 | lp = log_initial_probs[1] 106 | lp += log_transition_matrix[0,0] 107 | lp += log_transition_matrix[1,1] 108 | lp += log_transition_matrix[2,2] 109 | lp += jnp.sum(-0.5 * (1.0 - means[0])**2 / scales[0]**2) 110 | lp += jnp.sum(-0.5 * (2.0 - means[1])**2 / scales[1]**2) 111 | lp += jnp.sum(-0.5 * (3.0 - means[2])**2 / scales[2]**2) 112 | return -lp 113 | 114 | # Run a dummy optimization 115 | f = jit(value_and_grad(loss)) 116 | optimizer = optax.adam(1e-2) 117 | opt_state = optimizer.init(unc_params) 118 | 119 | def step(carry, args): 120 | """Optimization step""" 121 | unc_params, opt_state = carry 122 | loss, grads = f(unc_params) 123 | updates, opt_state = optimizer.update(grads, opt_state) 124 | unc_params = optax.apply_updates(unc_params, updates) 125 | return (unc_params, opt_state), loss 126 | 127 | initial_carry = (unc_params, opt_state) 128 | (unc_params, opt_state), losses = \ 129 | lax.scan(step, initial_carry, None, length=10) 130 | params = from_unconstrained(unc_params, props) 131 | 132 | assert jnp.allclose(params.initial.probs, original_params.initial.probs) 133 | assert not jnp.allclose(params.transitions.transition_matrix, original_params.transitions.transition_matrix) 134 | assert not jnp.allclose(params.emissions.means, original_params.emissions.means) 135 | assert jnp.allclose(params.emissions.scales, original_params.emissions.scales) 136 | 137 | 138 | def test_logdet_jacobian(): 139 | """Test that log_det_jac_constrain is correct""" 140 | params, props = make_params() 141 | unc_params = to_unconstrained(params, props) 142 | logdet = log_det_jac_constrain(params, props) 143 | 144 | # only the transition matrix is constrained and trainable 145 | f = props.transitions.transition_matrix.constrainer.forward_log_det_jacobian 146 | logdet_manual = f(unc_params.transitions.transition_matrix).sum() 147 | assert jnp.isclose(logdet, logdet_manual) 148 | 
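The round-trip pattern exercised by these tests can be summarized in a small, self-contained sketch (the toy parameter pytree below is illustrative):

```python
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.bijectors as tfb
from dynamax.parameters import (
    ParameterProperties, to_unconstrained, from_unconstrained, log_det_jac_constrain)

# A toy parameter pytree: a variance that must stay positive and an unconstrained weight.
params = {"variance": jnp.array(2.0), "weight": jnp.array(0.5)}
props = {"variance": ParameterProperties(constrainer=tfb.Softplus()),
         "weight": ParameterProperties()}  # no constrainer; trainable by default

unc_params = to_unconstrained(params, props)      # variance mapped through softplus^{-1}
recon = from_unconstrained(unc_params, props)     # round trip recovers the original values
logdet = log_det_jac_constrain(params, props)     # Jacobian correction used for HMC
```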
-------------------------------------------------------------------------------- /dynamax/slds/__init__.py: -------------------------------------------------------------------------------- 1 | from dynamax.slds.inference import DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal 2 | from dynamax.slds.models import SLDS 3 | -------------------------------------------------------------------------------- /dynamax/slds/inference_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the inference functions in dynamax/slds/inference.py 3 | """ 4 | import pytest 5 | 6 | import jax.numpy as jnp 7 | import jax.random as jr 8 | import dynamax.slds.mixture_kalman_filter_demo as kflib 9 | import jax 10 | 11 | from dynamax.slds import SLDS, DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal 12 | from functools import partial 13 | from functools import partial 14 | from jax.scipy.special import logit 15 | 16 | 17 | class TestRBPF(): 18 | """ 19 | Tests for the inference functions in dynamax/slds/inference.py 20 | """ 21 | ## Model definitions 22 | num_states = 3 23 | num_particles = 10 24 | state_dim = 4 25 | emission_dim = 4 26 | 27 | TT = 1 28 | A = jnp.array([[1, TT, 0, 0], 29 | [0, 1, 0, 0], 30 | [0, 0, 1, TT], 31 | [0, 0, 0, 1]]) 32 | 33 | 34 | B1 = jnp.array([0, 0, 0, 0]) 35 | B2 = jnp.array([-1.225, -0.35, 1.225, 0.35]) 36 | B3 = jnp.array([1.225, 0.35, -1.225, -0.35]) 37 | B = jnp.stack([B1, B2, B3], axis=0) 38 | 39 | Q = 0.2 * jnp.eye(4) 40 | R = 10.0 * jnp.diag(jnp.array([2, 1, 2, 1])) 41 | C = jnp.eye(4) 42 | 43 | transition_matrix = jnp.array([ 44 | [0.8, 0.1, 0.1], 45 | [0.1, 0.8, 0.1], 46 | [0.1, 0.1, 0.8] 47 | ]) 48 | 49 | discr_params = DiscreteParamsSLDS( 50 | initial_distribution=jnp.ones(num_states)/num_states, 51 | transition_matrix=transition_matrix, 52 | proposal_transition_matrix=transition_matrix 53 | ) 54 | 55 | lg_params = LGParamsSLDS( 56 | initial_mean=jnp.ones(state_dim), 57 | initial_cov=jnp.eye(state_dim), 58 | dynamics_weights=A, 59 | dynamics_cov=Q, 60 | dynamics_bias=jnp.array([B1, B2, B3]), 61 | dynamics_input_weights=None, 62 | emission_weights=C, 63 | emission_cov=R, 64 | emission_bias=None, 65 | emission_input_weights=None 66 | ) 67 | 68 | pre_params = ParamsSLDS( 69 | discrete=discr_params, 70 | linear_gaussian=lg_params 71 | ) 72 | 73 | params = pre_params.initialize(num_states, state_dim, emission_dim) 74 | 75 | ## Sample states and emissions 76 | key = jr.PRNGKey(1) 77 | slds = SLDS(num_states, state_dim, emission_dim) 78 | dstates, cstates, emissions = slds.sample(params, key, 100) 79 | 80 | ## Baseline Implementation parameters 81 | key_base = jr.PRNGKey(31) 82 | key_mean_init, key_sample, key_state, key_next = jr.split(key_base, 4) 83 | p_init = jnp.array([0.0, 1.0, 0.0]) 84 | 85 | mu_0 = 0.01 * jr.normal(key_mean_init, (num_particles, 4)) 86 | Sigma_0 = jnp.zeros((num_particles, 4,4)) 87 | s0 = jr.categorical(key_state, logit(p_init), shape=(num_particles,)) 88 | weights_0 = jnp.ones(num_particles) / num_particles 89 | init_config = (key_next, mu_0, Sigma_0, weights_0, s0) 90 | params1 = kflib.RBPFParamsDiscrete(A, B, C, Q, R, transition_matrix) 91 | 92 | def test_rbpf(self): 93 | """ 94 | Test the RBPF implementation 95 | """ 96 | # Baseline 97 | rbpf_optimal_part = partial(kflib.rbpf, params=self.params1, nparticles=self.num_particles) 98 | _, (mu_hist, Sigma_hist, weights_hist, s_hist, Ptk) = jax.lax.scan(rbpf_optimal_part, self.init_config, 
self.emissions) 99 | bl_post_mean = jnp.einsum("ts,tsm->tm", weights_hist, mu_hist) 100 | 101 | bl_rbpf_mse = ((bl_post_mean - self.cstates)[:, [0, 2]] ** 2).mean(axis=0).sum() 102 | # Dynamax 103 | out = rbpfilter(self.num_particles, self.params, self.emissions, self.key) 104 | means = out['means'] 105 | weights = out['weights'] 106 | dyn_post_mean = jnp.einsum("ts,tsm->tm", weights, means) 107 | dyn_rbpf_mse = ((dyn_post_mean - self.cstates)[:, [0, 2]] ** 2).mean(axis=0).sum() 108 | print(bl_rbpf_mse, dyn_rbpf_mse) 109 | assert jnp.allclose(bl_post_mean, dyn_post_mean, atol=10.0) 110 | 111 | def test_rbpf_optimal(self): 112 | """ 113 | Test the RBPF optimal implementation 114 | """ 115 | # Baseline 116 | rbpf_optimal_part = partial(kflib.rbpf_optimal, params=self.params1, nparticles=self.num_particles) 117 | _, (mu_hist, Sigma_hist, weights_hist, s_hist, Ptk) = jax.lax.scan(rbpf_optimal_part, self.init_config, self.emissions) 118 | bl_post_mean = jnp.einsum("ts,tsm->tm", weights_hist, mu_hist) 119 | bl_rbpf_mse = ((bl_post_mean - self.cstates)[:, [0, 2]] ** 2).mean(axis=0).sum() 120 | latent_hist_est = Ptk.mean(axis=1).argmax(axis=1) 121 | # Dynamax 122 | out = rbpfilter_optimal(self.num_particles, self.params, self.emissions, self.key) 123 | means = out['means'] 124 | weights = out['weights'] 125 | dyn_post_mean = jnp.einsum("ts,tsm->tm", weights, means) 126 | dyn_rbpf_mse = ((dyn_post_mean - self.cstates)[:, [0, 2]] ** 2).mean(axis=0).sum() 127 | print(bl_rbpf_mse, dyn_rbpf_mse) 128 | assert jnp.allclose(bl_post_mean, dyn_post_mean, atol=10.0) 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /dynamax/slds/mixture_kalman_filter_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo of a fitting an SLDS with the Rao-Blackwell Particle Filter. 3 | 4 | Author: Gerardo Durán-Martín (@gerdm) 5 | """ 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from dataclasses import dataclass 10 | from jax import random 11 | from jax.scipy.special import logit 12 | from jaxtyping import Array, Float 13 | 14 | @dataclass 15 | class RBPFParamsDiscrete: 16 | """ 17 | Rao-Blackwell Particle Filtering (RBPF) parameters for 18 | a system with discrete latent-space. 19 | We assume that the system evolves as 20 | z_next = A * z_old + B(u_old) + noise1_next 21 | x_next = C * z_next + noise2_next 22 | u_next ~ transition_matrix(u_old) 23 | 24 | where 25 | noise1_next ~ N(0, Q) 26 | noise2_next ~ N(0, R) 27 | """ 28 | A: Float[Array, "dim_hidden dim_hidden"] 29 | B: Float[Array, "dim_hidden dim_control"] 30 | C: Float[Array, "dim_emission dim_hidden"] 31 | Q: Float[Array, "dim_hidden dim_hidden"] 32 | R: Float[Array, "dim_emission dim_emission"] 33 | transition_matrix: Float[Array, "dim_control dim_control"] 34 | 35 | 36 | def draw_state(val, key, params): 37 | """ 38 | Simulate one step of a system that evolves as 39 | A z_{t-1} + Bk + eps, 40 | where eps ~ N(0, Q). 41 | 42 | Parameters 43 | ---------- 44 | val: tuple (int, jnp.array) 45 | (latent value of system, state value of system). 
46 | params: RBPFParamsDiscrete 47 | key: PRNGKey 48 | """ 49 | latent_old, state_old = val 50 | probabilities = params.transition_matrix[latent_old, :] 51 | logits = logit(probabilities) 52 | latent_new = random.categorical(key, logits) 53 | 54 | key_latent, key_obs = random.split(key) 55 | state_new = params.A @ state_old + params.B[latent_new, :] 56 | state_new = random.multivariate_normal(key_latent, state_new, params.Q) 57 | obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R) 58 | 59 | return (latent_new, state_new), (latent_new, state_new, obs_new) 60 | 61 | 62 | def kf_update(mu_t, Sigma_t, k, xt, params): 63 | """ 64 | Kalman filter update step. 65 | """ 66 | I = jnp.eye(len(mu_t)) 67 | mu_t_cond = params.A @ mu_t + params.B[k] 68 | Sigma_t_cond = params.A @ Sigma_t @ params.A.T + params.Q 69 | xt_cond = params.C @ mu_t_cond 70 | St = params.C @ Sigma_t_cond @ params.C.T + params.R 71 | 72 | Kt = Sigma_t_cond @ params.C.T @ jnp.linalg.inv(St) 73 | 74 | # Estimation update 75 | mu_t = mu_t_cond + Kt @ (xt - xt_cond) 76 | Sigma_t = (I - Kt @ params.C) @ Sigma_t_cond 77 | 78 | # Normalisation constant 79 | mean_norm = params.C @ mu_t_cond 80 | cov_norm = params.C @ Sigma_t_cond @ params.C.T + params.R 81 | Ltk = jax.scipy.stats.multivariate_normal.pdf(xt, mean_norm, cov_norm) 82 | 83 | return mu_t, Sigma_t, Ltk 84 | 85 | 86 | def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params): 87 | """ 88 | Rao-Blackwell Particle Filter step. 89 | """ 90 | log_p_next = logit(params.transition_matrix[st]) 91 | k = random.categorical(key, log_p_next) 92 | mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params) 93 | weight_t = weight_t * Ltk 94 | 95 | return mu_t, Sigma_t, weight_t, Ltk 96 | 97 | 98 | kf_update_vmap = jax.vmap(kf_update, in_axes=(None, None, 0, None, None), out_axes=0) 99 | 100 | 101 | def rbpf_step_optimal(key, weight_t, st, mu_t, Sigma_t, xt, params): 102 | """ 103 | Rao-Blackwell Particle Filter step with optimal proposal. 104 | """ 105 | # do Kalman step for all possible discrete latent states 106 | k = jnp.arange(len(params.transition_matrix)) 107 | mu_tk, Sigma_tk, Ltk = kf_update_vmap(mu_t, Sigma_t, k, xt, params) 108 | proposal = Ltk * params.transition_matrix[st] 109 | 110 | weight_tk = weight_t * proposal.sum() 111 | proposal = proposal / proposal.sum() 112 | 113 | return mu_tk, Sigma_tk, weight_tk, proposal 114 | 115 | 116 | # vectorised RBPF step 117 | rbpf_step_vec = jax.vmap(rbpf_step, in_axes=(0, 0, 0, 0, 0, None, None)) 118 | # vectorisedRBPF Step optimal 119 | rbpf_step_optimal_vec = jax.vmap(rbpf_step_optimal, in_axes=(0, 0, 0, 0, 0, None, None)) 120 | 121 | 122 | def rbpf(current_config, xt, params, nparticles=100): 123 | """ 124 | Rao-Blackwell Particle Filter using prior as proposal 125 | """ 126 | key, mu_t, Sigma_t, weights_t, st = current_config 127 | 128 | key_sample, key_state, key_next, key_reindex = random.split(key, 4) 129 | keys = random.split(key_sample, nparticles) 130 | 131 | st = random.categorical(key_state, logit(params.transition_matrix[st, :])) 132 | mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params) 133 | weights_t = weights_t / weights_t.sum() 134 | 135 | indices = jnp.arange(nparticles) 136 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True) 137 | st = st[pi] 138 | mu_t = mu_t[pi, ...] 139 | Sigma_t = Sigma_t[pi, ...] 
140 | weights_t = jnp.ones(nparticles) / nparticles 141 | 142 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, Ltk) 143 | 144 | 145 | def rbpf_optimal(current_config, xt, params, nparticles=100): 146 | """ 147 | Rao-Blackwell Particle Filter using optimal proposal 148 | """ 149 | key, mu_t, Sigma_t, weights_t, st = current_config 150 | 151 | key_sample, key_state, key_next, key_reindex = random.split(key, 4) 152 | keys = random.split(key_sample, nparticles) 153 | 154 | st = random.categorical(key_state, logit(params.transition_matrix[st, :])) 155 | mu_t, Sigma_t, weights_t, proposal = rbpf_step_optimal_vec(keys, weights_t, st, mu_t, Sigma_t, xt, params) 156 | 157 | 158 | # Resampling 159 | indices = jnp.arange(nparticles) 160 | pi = random.choice(key_reindex, indices, shape=(nparticles,), p=weights_t, replace=True) 161 | 162 | # Obtain optimal proposal distribution 163 | proposal_samp = proposal[pi, :] 164 | st = random.categorical(key, logit(proposal_samp)) 165 | 166 | mu_t = mu_t[pi, st, ...] 167 | Sigma_t = Sigma_t[pi, st, ...] 168 | 169 | weights_t = jnp.ones(nparticles) / nparticles 170 | 171 | return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp) 172 | -------------------------------------------------------------------------------- /dynamax/slds/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Switching linear dynamical systems models. 3 | """ 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import tensorflow_probability.substrates.jax.distributions as tfd 7 | 8 | from jax import lax 9 | from jax.tree_util import tree_map 10 | from jaxtyping import Array, Float 11 | from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN 12 | from typing import Optional, Tuple 13 | 14 | from dynamax.ssm import SSM 15 | from dynamax.slds.inference import ParamsSLDS 16 | from dynamax.types import PRNGKeyT 17 | 18 | 19 | class SLDS(SSM): 20 | """ 21 | Switching Linear Dynamical Systems (SLDS) model. 22 | 23 | Args: 24 | num_states: number of states $K$ 25 | state_dim: dimension of the state space $D$ 26 | emission_dim: dimension of the observation space $E$ 27 | input_dim: dimension of the input space $U$ 28 | """ 29 | def __init__(self, 30 | num_states: int, 31 | state_dim: int, 32 | emission_dim: int, 33 | input_dim: int=1): 34 | self.num_states = num_states 35 | self.state_dim = state_dim 36 | self.emission_dim = emission_dim 37 | self.input_dim = input_dim 38 | 39 | @property 40 | def emission_shape(self): 41 | """Shape of the emissions.""" 42 | return (self.emission_dim,) 43 | 44 | @property 45 | def inputs_shape(self): 46 | """Shape of the input distribution.""" 47 | return (self.input_dim,) if self.input_dim > 0 else None 48 | 49 | def initial_distribution(self, 50 | params: ParamsSLDS, 51 | dstate = int) \ 52 | -> tfd.Distribution: 53 | """ 54 | Return the initial distribution of the continuous latent states. 55 | """ 56 | params = params.linear_gaussian 57 | return MVN(params.initial_mean[dstate], params.initial_cov[dstate]) 58 | 59 | def transition_distribution(self, 60 | params: ParamsSLDS, 61 | dstate: int, 62 | cstate: Float[Array, " state_dim"], 63 | inputs: Optional[Float[Array, "ntime input_dim"]]=None 64 | ) -> tfd.Distribution: 65 | """ 66 | Return the transition distribution of the continuous latent states. 
67 | """ 68 | params = params.linear_gaussian 69 | inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) 70 | dynamics_input_weights = params.dynamics_input_weights if params.dynamics_input_weights is not None else jnp.zeros((self.num_states, self.state_dim, self.input_dim)) 71 | mean = params.dynamics_weights[dstate] @ cstate + dynamics_input_weights[dstate] @ inputs + params.dynamics_bias[dstate] 72 | return MVN(mean, params.dynamics_cov[dstate]) 73 | 74 | def emission_distribution(self, 75 | params: ParamsSLDS, 76 | dstate: int, 77 | cstate: Float[Array, " state_dim"], 78 | inputs: Optional[Float[Array, "ntime input_dim"]]=None) \ 79 | -> tfd.Distribution: 80 | """ 81 | Return the emission distribution of the observations. 82 | """ 83 | params = params.linear_gaussian 84 | inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) 85 | emission_input_weights = params.emission_input_weights if params.emission_input_weights is not None else jnp.zeros((self.num_states, self.emission_dim, self.input_dim)) 86 | mean = params.emission_weights[dstate] @ cstate + emission_input_weights[dstate] @ inputs + params.emission_bias[dstate] 87 | return MVN(mean, params.emission_cov[dstate]) 88 | 89 | def sample(self, 90 | params: ParamsSLDS, 91 | key: PRNGKeyT, 92 | num_timesteps: int, 93 | inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None 94 | ) -> Tuple[Float[Array, "num_timesteps state_dim"], 95 | Float[Array, "num_timesteps emission_dim"]]: 96 | r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$ and (optionally) inputs $u_{1:T}$. 97 | 98 | Args: 99 | params: model parameters $\theta$ 100 | key: random number generator 101 | num_timesteps: number of timesteps $T$ 102 | inputs: inputs $u_{1:T}$ 103 | 104 | Returns: 105 | latent states and emissions 106 | 107 | """ 108 | if not params.linear_gaussian.initialized: raise ValueError("ParamsSLDS must be initialized") 109 | 110 | def _step(prev_states, args): 111 | """Sample the next state and emission.""" 112 | key, inpt = args 113 | key0, key1, key2 = jr.split(key, 3) 114 | dstate, cstate = prev_states 115 | dstate = jr.choice(key0, jnp.arange(self.num_states), p = params.discrete.transition_matrix[dstate,:]) 116 | cstate = self.transition_distribution(params, dstate, cstate, inpt).sample(seed=key2) 117 | emission = self.emission_distribution(params, dstate, cstate, inpt).sample(seed=key1) 118 | return (dstate, cstate), (dstate, cstate, emission) 119 | 120 | # Sample the initial state 121 | key0, key1, key2, key = jr.split(key, 4) 122 | initial_input = tree_map(lambda x: x[0], inputs) 123 | initial_dstate = jr.choice(key0, jnp.arange(self.num_states), p = params.discrete.initial_distribution) 124 | initial_cstate = self.initial_distribution(params, initial_dstate).sample(seed=key1) 125 | initial_emission = self.emission_distribution(params, initial_dstate, initial_cstate, initial_input).sample(seed=key2) 126 | 127 | # Sample the remaining emissions and states 128 | next_keys = jr.split(key, num_timesteps - 1) 129 | next_inputs = tree_map(lambda x: x[1:], inputs) 130 | _, (next_dstates, next_cstates, next_emissions) = lax.scan(_step, (initial_dstate, initial_cstate), (next_keys, next_inputs)) 131 | 132 | # Concatenate the initial state and emission with the following ones 133 | expand_and_cat = lambda x0, x1T: jnp.concatenate((jnp.expand_dims(x0, 0), x1T)) 134 | dstates = tree_map(expand_and_cat, initial_dstate, next_dstates) 135 | cstates = tree_map(expand_and_cat, initial_cstate, 
next_cstates) 136 | emissions = tree_map(expand_and_cat, initial_emission, next_emissions) 137 | return dstates, cstates, emissions 138 | -------------------------------------------------------------------------------- /dynamax/types.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines type aliases for the dynamax package. 3 | """ 4 | from typing import Union 5 | from jaxtyping import Array, Float, Int 6 | 7 | PRNGKeyT = Array 8 | 9 | Scalar = Union[float, Float[Array, ""]] # python float or scalar jax device array with dtype float 10 | 11 | IntScalar = Union[int, Int[Array, ""]] 12 | -------------------------------------------------------------------------------- /dynamax/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/dynamax/utils/__init__.py -------------------------------------------------------------------------------- /dynamax/utils/bijectors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bijectors for converting between positive semi-definite matrices and real vectors. 3 | """ 4 | import tensorflow_probability.substrates.jax.bijectors as tfb 5 | 6 | # From https://www.tensorflow.org/probability/examples/ 7 | # TensorFlow_Probability_Case_Study_Covariance_Estimation 8 | class PSDToRealBijector(tfb.Chain): 9 | """ 10 | Bijector that maps a positive definite matrix to a real vector. 11 | """ 12 | def __init__(self, 13 | validate_args=False, 14 | validate_event_size=False, 15 | parameters=None, 16 | name=None): 17 | 18 | bijectors = [ 19 | tfb.Invert(tfb.FillTriangular()), 20 | tfb.TransformDiagonal(tfb.Invert(tfb.Exp())), 21 | tfb.Invert(tfb.CholeskyOuterProduct()), 22 | ] 23 | super().__init__(bijectors, validate_args, validate_event_size, parameters, name) 24 | 25 | 26 | class RealToPSDBijector(tfb.Chain): 27 | """ 28 | Bijector that maps a real vector to a positive definite matrix. 29 | """ 30 | def __init__(self, 31 | validate_args=False, 32 | validate_event_size=False, 33 | parameters=None, 34 | name=None): 35 | 36 | bijectors = [ 37 | tfb.CholeskyOuterProduct(), 38 | tfb.TransformDiagonal(tfb.Exp()), 39 | tfb.FillTriangular(), 40 | ] 41 | super().__init__(bijectors, validate_args, validate_event_size, parameters, name) 42 | -------------------------------------------------------------------------------- /dynamax/utils/optimize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions for optimization. 3 | """ 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import optax 7 | from jax import jit, lax, value_and_grad 8 | from jax.tree_util import tree_map 9 | from dynamax.utils.utils import pytree_len 10 | 11 | 12 | def sample_minibatches(key, dataset, batch_size, shuffle): 13 | """Sequence generator. 14 | 15 | NB: The generator does not preform as expected when used to yield data 16 | within jit'd code. This is likely because the generator internally 17 | updates a state with each yield (which doesn't play well with jit). 
18 |     """
19 |     n_data = pytree_len(dataset)
20 |     perm = jnp.where(shuffle, jr.permutation(key, n_data), jnp.arange(n_data))
21 |     for idx in range(0, n_data, batch_size):
22 |         yield tree_map(lambda x: x[perm[idx:min(idx + batch_size, n_data)]], dataset)
23 |
24 |
25 | def run_sgd(loss_fn,
26 |             params,
27 |             dataset,
28 |             optimizer=optax.adam(1e-3),
29 |             batch_size=1,
30 |             num_epochs=50,
31 |             shuffle=False,
32 |             key=jr.PRNGKey(0)):
33 |     """
34 |     Note that the dataset is a PyTree of arrays with leading batch dimension N,
35 |     where N is the number of independent sequences and
36 |     T is the length of each sequence. At each update step, a random subset of B
37 |     entire sequences (not individual time steps) is sampled, where B is the
38 |     batch size.
39 |
40 |     Args:
41 |         loss_fn: Objective function.
42 |         params: initial value of parameters to be estimated.
43 |         dataset: PyTree of data arrays with leading batch dimension
44 |         optimizer: Optimizer.
45 |         batch_size: Number of sequences used at each update step.
46 |         num_epochs: Number of passes (epochs) over the full dataset.
47 |         shuffle: Indicates whether to shuffle the sequences at each epoch.
48 |         key: RNG key.
49 |
50 |     Returns:
51 |         params: The optimized parameters giving a low loss.
52 |         losses: Average value of loss_fn over the minibatches in each epoch.
53 |     """
54 |     opt_state = optimizer.init(params)
55 |     num_batches = pytree_len(dataset)
56 |     num_complete_batches, leftover = jnp.divmod(num_batches, batch_size)
57 |     num_batches = num_complete_batches + jnp.where(leftover == 0, 0, 1)
58 |     loss_grad_fn = jit(value_and_grad(loss_fn))
59 |
60 |     if batch_size >= num_batches:
61 |         shuffle = False
62 |
63 |     keys = jr.split(key, num_epochs)
64 |     losses = []
65 |     for key in keys:
66 |         sample_generator = sample_minibatches(key, dataset, batch_size, shuffle)
67 |         avg_loss = 0.0
68 |         for itr in range(num_batches):
69 |             minibatch = next(sample_generator)
70 |             this_loss, grads = loss_grad_fn(params, minibatch)
71 |             updates, opt_state = optimizer.update(grads, opt_state)
72 |             params = optax.apply_updates(params, updates)
73 |             avg_loss = (avg_loss * itr + this_loss) / (itr + 1)
74 |         losses.append(avg_loss)
75 |     return params, jnp.stack(losses)
76 |
77 |
78 | def run_gradient_descent(objective,
79 |                          params,
80 |                          optimizer=optax.adam(1e-2),
81 |                          optimizer_state=None,
82 |                          num_mstep_iters=50):
83 |     """
84 |     Run gradient descent on the objective function.
85 |     """
86 |     if optimizer_state is None:
87 |         optimizer_state = optimizer.init(params)
88 |
89 |     # Minimize the negative expected log joint with gradient descent
90 |     loss_grad_fn = value_and_grad(objective)
91 |
92 |     # One step of the algorithm
93 |     def train_step(carry, args):
94 |         """One step of the algorithm."""
95 |         params, optimizer_state = carry
96 |         loss, grads = loss_grad_fn(params)
97 |         updates, optimizer_state = optimizer.update(grads, optimizer_state)
98 |         params = optax.apply_updates(params, updates)
99 |         return (params, optimizer_state), loss
100 |
101 |     # Run the optimizer
102 |     initial_carry = (params, optimizer_state)
103 |     (params, optimizer_state), losses = \
104 |         lax.scan(train_step, initial_carry, None, length=num_mstep_iters)
105 |
106 |     # Return the updated parameters
107 |     return params, optimizer_state, losses
108 |
--------------------------------------------------------------------------------
/dynamax/utils/plotting.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for plotting.
3 | """
4 | import jax.numpy as jnp
5 | from matplotlib.patches import Ellipse, transforms
6 | from matplotlib.colors import LinearSegmentedColormap
7 | from matplotlib import pyplot as plt
8 | import seaborn as sns
9 |
10 |
11 | _COLOR_NAMES = [
12 |     "windows blue",
13 |     "red",
14 |     "amber",
15 |     "faded green",
16 |     "dusty purple",
17 |     "orange",
18 |     "clay",
19 |     "pink",
20 |     "greyish",
21 |     "mint",
22 |     "light cyan",
23 |     "steel blue",
24 |     "forest green",
25 |     "pastel purple",
26 |     "salmon",
27 |     "dark brown",
28 | ]
29 | COLORS = sns.xkcd_palette(_COLOR_NAMES)
30 |
31 |
32 | def white_to_color_cmap(color, nsteps=256):
33 |     """Return a cmap which ranges from white to the specified color.
34 |     Ported from HIPS-LIB plotting functions [https://github.com/HIPS/hips-lib]
35 |     """
36 |     # Build linear segments from white (at 0) to the target color (at 1)
37 |     cdict = {
38 |         "red": ((0.0, 1.0, 1.0), (1.0, color[0], color[0])),
39 |         "green": ((0.0, 1.0, 1.0), (1.0, color[1], color[1])),
40 |         "blue": ((0.0, 1.0, 1.0), (1.0, color[2], color[2])),
41 |     }
42 |     cmap = LinearSegmentedColormap("white_color_colormap", cdict, nsteps)
43 |     return cmap
44 |
45 |
46 | def gradient_cmap(colors, nsteps=256, bounds=None):
47 |     """Return a colormap that interpolates between a set of colors.
48 |     Ported from HIPS-LIB plotting functions [https://github.com/HIPS/hips-lib]
49 |     """
50 |     ncolors = len(colors)
51 |     # assert colors.shape[1] == 3
52 |     if bounds is None:
53 |         bounds = jnp.linspace(0, 1, ncolors)
54 |
55 |     reds = []
56 |     greens = []
57 |     blues = []
58 |     alphas = []
59 |     for b, c in zip(bounds, colors):
60 |         reds.append((b, c[0], c[0]))
61 |         greens.append((b, c[1], c[1]))
62 |         blues.append((b, c[2], c[2]))
63 |         alphas.append((b, c[3], c[3]) if len(c) == 4 else (b, 1.0, 1.0))
64 |
65 |     cdict = {"red": tuple(reds), "green": tuple(greens), "blue": tuple(blues), "alpha": tuple(alphas)}
66 |
67 |     cmap = LinearSegmentedColormap("grad_colormap", cdict, nsteps)
68 |     return cmap
69 |
70 |
71 | CMAP = gradient_cmap(COLORS)
72 |
73 | # https://matplotlib.org/devdocs/gallery/statistics/confidence_ellipse.html
74 | def plot_ellipse(Sigma, mu, ax, n_std=3.0, facecolor="none", edgecolor="k", **kwargs):
75 |     """Plot an ellipse with centre `mu` and axes defined by `Sigma`."""
76 |     cov = Sigma
77 |     pearson = cov[0, 1] / jnp.sqrt(cov[0, 0] * cov[1, 1])
78 |
79 |     ell_radius_x = jnp.sqrt(1 + pearson)
80 |     ell_radius_y = jnp.sqrt(1 - pearson)
81 |
82 |     # if facecolor not in kwargs:
83 |     #     kwargs['facecolor'] = 'none'
84 |     # if edgecolor not in kwargs:
85 |     #     kwargs['edgecolor'] = 'k'
86 |
87 |     ellipse = Ellipse(
88 |         (0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2, facecolor=facecolor, edgecolor=edgecolor, **kwargs
89 |     )
90 |
91 |     scale_x = jnp.sqrt(cov[0, 0]) * n_std
92 |     mean_x = mu[0]
93 |
94 |     scale_y = jnp.sqrt(cov[1, 1]) * n_std
95 |     mean_y = mu[1]
96 |
97 |     transf = transforms.Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y)
98 |
99 |     ellipse.set_transform(transf + ax.transData)
100 |
101 |     return ax.add_patch(ellipse)
102 |
103 |
104 | def plot_uncertainty_ellipses(means, Sigmas, ax, n_std=3.0, label=None, **kwargs):
105 |     """Loop over means and Sigmas to add ellipses representing uncertainty."""
106 |     for i, (Sigma, mu) in enumerate(zip(Sigmas, means)):
107 |         plot_ellipse(Sigma, mu, ax, n_std,
108 |                      label=label if i == 0 else None,
109 |                      **kwargs)
110 |
111 | # Some custom params to make prettier plots.
112 | custom_rcparams_base = { 113 | "font.size" : 13.0, 114 | "font.sans-serif" : ['Helvetica Neue', 'Lucida Grande', 'Verdana', 'Geneva', 'Lucid', 'Arial', 'Avant Garde', 'sans-serif'], 115 | "text.color" : "555555", 116 | "axes.facecolor" : "white", ## axes background color 117 | "axes.edgecolor" : "555555", ## axes edge color 118 | "axes.linewidth" : 1, ## edge linewidth 119 | "axes.titlesize" : 14, ## fontsize of the axes title 120 | "axes.titlepad" : 10.0, ## pad between axes and title in points 121 | "axes.labelcolor" : "555555", 122 | "axes.spines.top" : False, 123 | "axes.spines.right" : False, 124 | "axes.prop_cycle" : plt.cycler('color', ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', 125 | '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']), 126 | "xtick.color" : "555555", 127 | "ytick.color" : "555555", 128 | "grid.color" : "eeeeee", ## grid color 129 | "legend.frameon" : False, ## if True, draw the legend on a background patch 130 | "figure.titlesize" : 16, ## size of the figure title (Figure.suptitle()) 131 | "figure.facecolor" : "white", ## figure facecolor 132 | "figure.frameon" : False, ## enable figure frame 133 | "figure.subplot.top" : 0.91, ## the top of the subplots of the figure 134 | } 135 | 136 | # Some custom params specifically designed for plots in a notebook. 137 | custom_rcparams_notebook = { 138 | **custom_rcparams_base, 139 | "figure.figsize": (7.0, 5.0), 140 | "axes.labelsize": 14, 141 | "xtick.labelsize": 12, 142 | "ytick.labelsize": 12, 143 | "legend.fontsize": 12, 144 | "grid.linewidth": 1, 145 | "lines.linewidth": 1.75, 146 | "patch.linewidth": .3, 147 | "lines.markersize": 7, 148 | "lines.markeredgewidth": 0, 149 | "xtick.major.width": 1, 150 | "ytick.major.width": 1, 151 | "xtick.minor.width": .5, 152 | "ytick.minor.width": .5, 153 | "xtick.major.pad": 7, 154 | "ytick.major.pad": 7 155 | } 156 | 157 | -------------------------------------------------------------------------------- /dynamax/utils/test_optimize.py: -------------------------------------------------------------------------------- 1 | from dynamax.utils.optimize import run_sgd 2 | import jax.numpy as jnp 3 | from optax import adam 4 | from numpy.testing import assert_allclose 5 | 6 | 7 | def test_run_sgd(): 8 | """Test that run_sgd solves an exactly solvable problem.""" 9 | 10 | def _loss(a, x): 11 | return jnp.sum((x - a)**2 / 2) 12 | 13 | # Average `mini_batch_1`: 0.0 14 | mini_batch_1 = jnp.array([ 0.5333575 , 1.5523977 , -0.34479547, -0.80614984, -0.93481004]) 15 | # Average `mini_batch_2`: 1.0 16 | mini_batch_2 = jnp.array([0.52032334, 1.6625587 , 1.1381058 , 1.2635592 , 0.41545272]) 17 | # Average `X_train`: 0.5 18 | X_train = jnp.concatenate([mini_batch_1, mini_batch_2]).reshape(-1, 1) 19 | 20 | param_init = jnp.array(0.5) 21 | settings = { 22 | 'params': param_init, 'num_epochs': 10_000, 'optimizer': adam(1e-3), 'batch_size': 5, 23 | } 24 | # Train on mini_batch_1 with batch size five (=full dataset). 25 | solution_mini_batch_1, _ = run_sgd(_loss, dataset=mini_batch_1.reshape(-1, 1), **settings) 26 | assert_allclose(solution_mini_batch_1, 0.0, atol=1e-3, rtol=1e-3) 27 | 28 | # Train on X_train with mini batch size five (=half the dataset). 
29 |     solution, losses = run_sgd(_loss, dataset=X_train, **settings)
30 |     assert_allclose(solution, 0.5, atol=1e-3)
31 |     num_batches = len(X_train) / len(mini_batch_1)
32 |     assert_allclose(losses[-1], _loss(0.5, X_train) / num_batches, atol=1e-3, rtol=1e-3)
33 |
--------------------------------------------------------------------------------
/dynamax/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for the library.
3 | """
4 | import jax
5 | import jaxlib
6 | import jax.numpy as jnp
7 | import jax.random as jr
8 |
9 | from functools import partial
10 | from jax import jit
11 | from jax import vmap
12 | from jax.tree_util import tree_map, tree_leaves, tree_flatten, tree_unflatten
13 | from jaxtyping import Array, Int
14 | from scipy.optimize import linear_sum_assignment
15 | from jax.scipy.linalg import cho_factor, cho_solve
16 |
17 | def has_tpu():
18 |     """Check if the current device is a TPU."""
19 |     try:
20 |         return isinstance(jax.devices()[0], jaxlib.xla_extension.TpuDevice)
21 |     except:
22 |         return False
23 |
24 |
25 | @jit
26 | def pad_sequences(observations, valid_lens, pad_val=0):
27 |     """
28 |     Pad ragged sequences to a fixed length.
29 |     Parameters
30 |     ----------
31 |     observations : array(N, seq_len)
32 |         All observation sequences
33 |     valid_lens : array(N,)
34 |         Valid length of each observation sequence
35 |     pad_val : int
36 |         Value used to replace the invalid (padded-out) entries of each observation sequence
37 |     Returns
38 |     -------
39 |     * tuple(array(N, seq_len), array(N,))
40 |         Padded observation sequences and their valid lengths
41 |     """
42 |
43 |     def pad(seq, len):
44 |         """Pad a single sequence."""
45 |         idx = jnp.arange(1, seq.shape[0] + 1)
46 |         return jnp.where(idx <= len, seq, pad_val)
47 |
48 |     dataset = vmap(pad, in_axes=(0, 0))(observations, valid_lens), valid_lens
49 |     return dataset
50 |
51 |
52 | def monotonically_increasing(x, atol=0., rtol=0.):
53 |     """Check if an array is monotonically increasing."""
54 |     thresh = atol + rtol*jnp.abs(x[:-1])
55 |     return jnp.all(jnp.diff(x) >= -thresh)
56 |
57 |
58 | def pytree_len(pytree):
59 |     """Return the length of the leading (batch) dimension of a PyTree's leaves."""
60 |     if pytree is None:
61 |         return 0
62 |     else:
63 |         return len(tree_leaves(pytree)[0])
64 |
65 |
66 | def pytree_sum(pytree, axis=None, keepdims=False, where=None):
67 |     """Sum all the leaves in a PyTree."""
68 |     return tree_map(partial(jnp.sum, axis=axis, keepdims=keepdims, where=where), pytree)
69 |
70 |
71 | def pytree_slice(pytree, slc):
72 |     """Slice all the leaves in a PyTree."""
73 |     return tree_map(lambda x: x[slc], pytree)
74 |
75 |
76 | def pytree_stack(pytrees):
77 |     """Stack all the leaves in a list of PyTrees."""
78 |     _, treedef = tree_flatten(pytrees[0])
79 |     leaves = [tree_leaves(tree) for tree in pytrees]
80 |     return tree_unflatten(treedef, [jnp.stack(vals) for vals in zip(*leaves)])
81 |
82 | def random_rotation(seed, n, theta=None):
83 |     r"""Helper function to create a rotating linear system.
84 |
85 |     Args:
86 |         seed (jax.random.PRNGKey): JAX random seed.
87 |         n (int): Dimension of the rotation matrix.
88 |         theta (float, optional): If specified, this is the angle of the rotation; otherwise
89 |             a random angle is drawn uniformly from :math:`[0, \pi / 2)`. Defaults to None.
90 |     Returns:
91 |         Array of shape (n, n): a random rotation by angle `theta` within a random two-dimensional subspace (a random 1x1 scaling if n == 1).
92 |     """
93 |
94 |     key1, key2 = jr.split(seed)
95 |
96 |     if theta is None:
97 |         # Sample a random, slow rotation
98 |         theta = 0.5 * jnp.pi * jr.uniform(key1)
99 |
100 |     if n == 1:
101 |         return jr.uniform(key1) * jnp.eye(1)
102 |
103 |     rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)], [jnp.sin(theta), jnp.cos(theta)]])
104 |     out = jnp.eye(n)
105 |     out = out.at[:2, :2].set(rot)
106 |     q = jnp.linalg.qr(jr.uniform(key2, shape=(n, n)))[0]
107 |     return q.dot(out).dot(q.T)
108 |
109 |
110 | def ensure_array_has_batch_dim(tree, instance_shapes):
111 |     """Add a batch dimension to a PyTree, if necessary.
112 |
113 |     Example: If `tree` is an array of shape (T, D) where `T` is
114 |     the number of time steps and `D` is the emission dimension,
115 |     and if `instance_shapes` is a tuple (D,), then the return
116 |     value is the array with an added batch dimension, with
117 |     shape (1, T, D).
118 |
119 |     Example: If `tree` is an array of shape (N, T, D) and
120 |     `instance_shapes` is a tuple (D,), then the return
121 |     value is simply `tree`, since it already has a batch
122 |     dimension (of length N).
123 |
124 |     Example: If `tree = (A, B)` is a tuple of arrays with
125 |     `A.shape = (100,2)` and `B.shape = (100,4)`, and
126 |     `instance_shapes = ((2,), (4,))`, then the return value
127 |     is equivalent to `(jnp.expand_dims(A, 0), jnp.expand_dims(B, 0))`.
128 |
129 |     Args:
130 |         tree: PyTree whose leaves' shapes are either
131 |             (batch, length) + instance_shape or (length,) + instance_shape.
132 |             If the latter, this function adds a batch dimension of 1 to
133 |             each leaf node.
134 |
135 |         instance_shapes: matching PyTree where the "leaves" are
136 |             tuples of integers specifying the shape of one "instance" or
137 |             entry in the array.
138 |     """
139 |     def _expand_dim(x, shp):
140 |         """Add a batch dimension to an array, if necessary."""
141 |         ndim = len(shp)
142 |         assert x.ndim > ndim, "array does not match expected shape!"
143 |         assert all([(d1 == d2) for d1, d2 in zip(x.shape[-ndim:], shp)]), \
144 |             "array does not match expected shape!"
145 |
146 |         if x.ndim == ndim + 2:
147 |             # x already has a batch dim
148 |             return x
149 |         elif x.ndim == ndim + 1:
150 |             # x has a leading time dimension but no batch dim
151 |             return jnp.expand_dims(x, 0)
152 |         else:
153 |             raise Exception("array has too many dimensions!")
154 |
155 |     if tree is None:
156 |         return None
157 |     else:
158 |         return tree_map(_expand_dim, tree, instance_shapes)
159 |
160 |
161 | def compute_state_overlap(
162 |     z1: Int[Array, " num_timesteps"],
163 |     z2: Int[Array, " num_timesteps"]
164 | ):
165 |     """
166 |     Compute a matrix describing the state-wise overlap between two state vectors
167 |     ``z1`` and ``z2``.
168 |
169 |     The state vectors should both be of shape ``(T,)`` and be integer typed.
170 |
171 |     Args:
172 |         z1: The first state vector.
173 |         z2: The second state vector.
174 |
175 |     Returns:
176 |         overlap matrix: (K, K) matrix whose (i, j) entry counts the time steps where ``z1 == i`` and ``z2 == j``.
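        For example (illustrative): with ``z1 = [0, 0, 1]`` and ``z2 = [0, 1, 1]``, the
        overlap matrix is 2 x 2, with ``overlap[0, 0] = 1``, ``overlap[0, 1] = 1``,
        ``overlap[1, 0] = 0``, and ``overlap[1, 1] = 1``.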
177 | """ 178 | assert z1.shape == z2.shape 179 | assert z1.min() >= 0 and z2.min() >= 0 180 | 181 | K = max(max(z1), max(z2)) + 1 182 | 183 | overlap = jnp.sum( 184 | (z1[:, None] == jnp.arange(K))[:, :, None] 185 | & (z2[:, None] == jnp.arange(K))[:, None, :], 186 | axis=0, 187 | ) 188 | return overlap 189 | 190 | 191 | def find_permutation( 192 | z1: Int[Array, " num_timesteps"], 193 | z2: Int[Array, " num_timesteps"] 194 | ): 195 | """ 196 | Find the permutation of the state labels in sequence ``z1`` so that they 197 | best align with the labels in ``z2``. 198 | 199 | Args: 200 | z1: The first state vector. 201 | z2: The second state vector. 202 | 203 | Returns: 204 | permutation such that ``jnp.take(perm, z1)`` best aligns with ``z2``. 205 | Thus, ``len(perm) = min(z1.max(), z2.max()) + 1``. 206 | 207 | """ 208 | overlap = compute_state_overlap(z1, z2) 209 | _, perm = linear_sum_assignment(-overlap) 210 | return perm 211 | 212 | 213 | def psd_solve(A, b, diagonal_boost=1e-9): 214 | """A wrapper for coordinating the linalg solvers used in the library for psd matrices.""" 215 | A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1]) 216 | L, lower = cho_factor(A, lower=True) 217 | x = cho_solve((L, lower), b) 218 | return x 219 | 220 | def symmetrize(A): 221 | """Symmetrize one or more matrices.""" 222 | return 0.5 * (A + jnp.swapaxes(A, -1, -2)) 223 | -------------------------------------------------------------------------------- /dynamax/utils/utils_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests of the utility functions. 3 | """ 4 | import jax.numpy as jnp 5 | 6 | from dynamax.utils.utils import find_permutation 7 | 8 | def test_find_permutation(): 9 | """Test the find_permutation function 10 | """ 11 | z1 = jnp.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) 12 | z2 = jnp.array([1, 1, 0, 0, 3, 3, 2, 2, 4, 4]) 13 | true_perm = jnp.array([1, 0, 3, 2, 4]) 14 | perm = find_permutation(z1, z2) 15 | assert jnp.allclose(jnp.take(perm, z1), z2) 16 | assert jnp.allclose(true_perm, perm) 17 | 18 | 19 | def test_find_permutation_unmatched(): 20 | """Test the find_permutation function with K_2 > K_1. 21 | """ 22 | z1 = jnp.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) 23 | z2 = jnp.array([5, 5, 6, 6, 7, 7, 8, 8, 9, 9]) 24 | true_perm = jnp.array([5, 6, 7, 8, 9, 0, 1, 2, 3, 4]) 25 | perm = find_permutation(z1, z2) 26 | assert jnp.allclose(true_perm, perm) 27 | 28 | 29 | def test_find_permutation_unmatched_v2(): 30 | """Test the find_permutation function with K_2 < K_1. 31 | """ 32 | z1 = jnp.array([5, 5, 6, 6, 7, 7, 8, 8, 9, 9]) 33 | z2 = jnp.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) 34 | true_perm = jnp.array([5, 6, 7, 8, 9, 0, 1, 2, 3, 4]) 35 | perm = find_permutation(z1, z2) 36 | assert jnp.allclose(true_perm, perm) 37 | 38 | 39 | def test_find_permutation_unmatched_v3(): 40 | """Test the find_permutation function with a more complex assignment. 41 | """ 42 | z1 = jnp.array([0, 5, 5, 1, 6, 6, 2, 7, 7, 3, 8, 8, 4, 9, 9]) 43 | z2 = jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]) 44 | true_perm = jnp.array([5, 6, 7, 8, 9, 0, 1, 2, 3, 4]) 45 | perm = find_permutation(z1, z2) 46 | assert jnp.allclose(true_perm, perm) -------------------------------------------------------------------------------- /dynamax/warnings.py: -------------------------------------------------------------------------------- 1 | """ 2 | TensorFlow Probability logs a few annoying messages. We suppress these by default. 
3 | """ 4 | import logging 5 | import warnings 6 | 7 | 8 | class CheckTypesFilter(logging.Filter): 9 | """ 10 | Catch "check_types" warnings that are sent to the logger 11 | """ 12 | def filter(self, record): 13 | """Filter out check_types warnings""" 14 | return "check_types" not in record.getMessage() 15 | 16 | 17 | logger = logging.getLogger() 18 | logger.addFilter(CheckTypesFilter()) 19 | 20 | 21 | # Catch UserWarning: Explicitly requested dtype... 22 | warnings.filterwarnings("ignore", category=UserWarning, message="Explicitly requested dtype") 23 | warnings.filterwarnings("ignore", category=DeprecationWarning, message="Using or importing the ABCs") 24 | -------------------------------------------------------------------------------- /logo/dynamax.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/logo/dynamax.ai -------------------------------------------------------------------------------- /logo/dynamax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/logo/dynamax.png -------------------------------------------------------------------------------- /logo/logo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/logo/logo.gif -------------------------------------------------------------------------------- /logo/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/logo/mask.png -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | @article{vyas2020computation, 2 | title={Computation through neural population dynamics}, 3 | author={Vyas, Saurabh and Golub, Matthew D and Sussillo, David and Shenoy, Krishna V}, 4 | journal={Annual review of neuroscience}, 5 | volume={43}, 6 | number={1}, 7 | pages={249--275}, 8 | year={2020}, 9 | publisher={Annual Reviews}, 10 | doi={10.1146/annurev-neuro-092619-094115} 11 | } 12 | 13 | @book{murphy2023probabilistic, 14 | author = "Kevin P. 
Murphy", 15 | title = "Probabilistic Machine Learning: Advanced Topics", 16 | publisher = "MIT Press", 17 | year = 2023, 18 | url = "http://probml.github.io/book2" 19 | } 20 | 21 | @book{sarkka2023bayesian, 22 | title={Bayesian filtering and smoothing}, 23 | author={S{\"a}rkk{\"a}, Simo and Svensson, Lennart}, 24 | volume={17}, 25 | year={2023}, 26 | publisher={Cambridge University Press}, 27 | doi={10.1017/CBO9781139344203} 28 | } 29 | 30 | 31 | @misc{jax, 32 | author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, 33 | title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, 34 | url = {http://github.com/google/jax}, 35 | version = {0.3.13}, 36 | year = {2018}, 37 | } 38 | 39 | @inproceedings{zhao2023revisiting, 40 | title={Revisiting structured variational autoencoders}, 41 | author={Zhao, Yixiu and Linderman, Scott}, 42 | booktitle={International Conference on Machine Learning}, 43 | pages={42046--42057}, 44 | year={2023}, 45 | organization={PMLR}, 46 | doi={10.48550/arXiv.2305.16543} 47 | } 48 | 49 | @article{lee2023switching, 50 | title={Switching autoregressive low-rank tensor models}, 51 | author={Lee, Hyun Dong and Warrington, Andrew and Glaser, Joshua and Linderman, Scott}, 52 | journal={Advances in Neural Information Processing Systems}, 53 | volume={36}, 54 | pages={57976--58010}, 55 | year={2023}, 56 | doi={10.48550/arXiv.2306.03291} 57 | } 58 | 59 | @inproceedings{chang2023low, 60 | title = {Low-rank extended {K}alman filtering for online learning of neural networks from streaming data}, 61 | author = {Chang, Peter G. and Dur\'an-Mart\'in, Gerardo and Shestopaloff, Alex and Jones, Matt and Murphy, Kevin P}, 62 | booktitle = {Proceedings of The 2nd Conference on Lifelong Learning Agents}, 63 | pages = {1025--1071}, 64 | year = {2023}, 65 | editor = {Chandar, Sarath and Pascanu, Razvan and Sedghi, Hanie and Precup, Doina}, 66 | volume = {232}, 67 | series = {Proceedings of Machine Learning Research}, 68 | month = {22--25 Aug}, 69 | publisher = {PMLR}, 70 | doi={10.48550/arXiv.2305.19535}, 71 | } 72 | 73 | 74 | @article{weinreb2024keypoint, 75 | author = {Weinreb, Caleb and Pearl, Jonah E. and Lin, Sherry and Osman, Mohammed Abdal Monium and Zhang, Libby and Annapragada, Sidharth and Conlin, Eli and Hoffmann, Red and Makowska, Sofia and Gillis, Winthrop F. and Jay, Maya and Ye, Shaokai and Mathis, Alexander and Mathis, Mackenzie W. and Pereira, Talmo and Linderman, Scott W. 
and Datta, Sandeep Robert}, 76 | date = {2024/07/01}, 77 | id = {Weinreb2024}, 78 | journal = {Nature Methods}, 79 | number = {7}, 80 | pages = {1329--1339}, 81 | title = {Keypoint-{M}o{S}eq: parsing behavior by linking point tracking to pose dynamics}, 82 | volume = {21}, 83 | year = {2024}, 84 | doi={10.1038/s41592-024-02318-2}, 85 | } 86 | 87 | @misc{pyhsmm, 88 | author = {Matthew James Johnson}, 89 | title = {{PyHSMM}: Bayesian inference in HSMMs and HMMs}, 90 | url = {https://github.com/mattjj/pyhsmm}, 91 | version = {0.0.0}, 92 | year = {2020}, 93 | } 94 | 95 | @misc{eeasensors, 96 | author = {Adrien Corenflos and Simo Särkkä}, 97 | title = {Code Companion for {B}ayesian {F}iltering and {S}moothing}, 98 | url = {https://github.com/EEA-sensors/Bayesian-Filtering-and-Smoothing}, 99 | version = {1.0}, 100 | year = {2021}, 101 | } 102 | 103 | 104 | @misc{ssm, 105 | author = {Linderman, Scott and Antin, Benjamin and Zoltowski, David and Glaser, Joshua}, 106 | title = {{SSM: Bayesian Learning and Inference for State Space Models}}, 107 | url = {https://github.com/lindermanlab/ssm}, 108 | version = {0.0.1}, 109 | year = {2020} 110 | } 111 | 112 | @misc{jsl, 113 | author = {Duran-Martin, Gerardo and Murphy, Kevin and Kara, Aleyna}, 114 | title = {{JSL: JAX State-Space models (SSM) Library}}, 115 | url={https://github.com/probml/JSL}, 116 | year={2022} 117 | } 118 | 119 | @inproceedings{seabold2010statsmodels, 120 | title={statsmodels: {E}conometric and statistical modeling with python}, 121 | author={Seabold, Skipper and Perktold, Josef}, 122 | booktitle={9th Python in Science Conference}, 123 | year={2010}, 124 | doi={10.25080/majora-92bf1922-011} 125 | } 126 | 127 | @misc{hmmlearn, 128 | author={Ron Weiss and Shiqiao Du and Jaques Grobler and David Cournapeau and Fabian Pedregosa and Gael Varoquaux and Andreas Mueller and Bertrand Thirion and Daniel Nouri and Gilles Louppe and Jake Vanderplas and John Benediktsson and Lars Buitinck and Mikhail Korobov and Robert McGibbon and Stefano Lattarini and Vlad Niculae and Alexandre Gramfort and Sergei Lebedev and Daniela Huppenkothen and Christopher Farrow and Alexandr Yanenko and Antony Lee and Matthew Danielson and Alex Rockhill}, 129 | title={hmmlearn}, 130 | url={https://github.com/hmmlearn/hmmlearn}, 131 | version={0.3.2}, 132 | year={2024}, 133 | } 134 | 135 | @ARTICLE{durbin1998biological, 136 | title = "Biological sequence analysis: {P}robabilistic models of proteins 137 | and nucleic acids", 138 | author = "Durbin, Richard and Eddy, Sean R and Krogh, Anders and Mitchison, 139 | Graeme", 140 | publisher = "Cambridge University Press", 141 | month = apr, 142 | year = 1998, 143 | doi={10.1017/cbo9780511790492}, 144 | } 145 | 146 | @article{patterson2008state, 147 | title={State-space models of individual animal movement}, 148 | author={Patterson, Toby A and Thomas, Len and Wilcox, Chris and Ovaskainen, Otso and Matthiopoulos, Jason}, 149 | journal={Trends in ecology \& evolution}, 150 | volume={23}, 151 | number={2}, 152 | pages={87--94}, 153 | year={2008}, 154 | publisher={Elsevier}, 155 | doi={10.1016/j.tree.2007.10.009} 156 | } 157 | 158 | @article{jacquier2002bayesian, 159 | title={Bayesian analysis of stochastic volatility models}, 160 | author={Jacquier, Eric and Polson, Nicholas G and Rossi, Peter E}, 161 | journal={Journal of Business \& Economic Statistics}, 162 | volume={20}, 163 | number={1}, 164 | pages={69--87}, 165 | year={2002}, 166 | publisher={Taylor \& Francis}, 167 | doi={10.1198/073500102753410408} 168 | } 169 | 170 | 
@article{ott2004local, 171 | title={A local ensemble {K}alman filter for atmospheric data assimilation}, 172 | author={Ott, Edward and Hunt, Brian R and Szunyogh, Istvan and Zimin, Aleksey V and Kostelich, Eric J and Corazza, Matteo and Kalnay, Eugenia and Patil, DJ and Yorke, James A}, 173 | journal={Tellus A: Dynamic Meteorology and Oceanography}, 174 | volume={56}, 175 | number={5}, 176 | pages={415--428}, 177 | year={2004}, 178 | publisher={Taylor \& Francis}, 179 | doi={10.3402/tellusa.v56i5.14462} 180 | } 181 | 182 | @article{stone1975parallel, 183 | title={Parallel tridiagonal equation solvers}, 184 | author={Stone, Harold S}, 185 | journal={ACM Transactions on Mathematical Software (TOMS)}, 186 | volume={1}, 187 | number={4}, 188 | pages={289--307}, 189 | year={1975}, 190 | publisher={ACM New York, NY, USA}, 191 | doi={10.1145/355656.355657} 192 | } 193 | 194 | @article{sarkka2020temporal, 195 | title={Temporal parallelization of {B}ayesian smoothers}, 196 | author={S{\"a}rkk{\"a}, Simo and Garc{\'\i}a-Fern{\'a}ndez, {\'A}ngel F}, 197 | journal={IEEE Transactions on Automatic Control}, 198 | volume={66}, 199 | number={1}, 200 | pages={299--306}, 201 | year={2020}, 202 | publisher={IEEE}, 203 | doi={10.1109/TAC.2020.2976316} 204 | } 205 | 206 | @article{hassan2021temporal, 207 | title={Temporal parallelization of inference in hidden {M}arkov models}, 208 | author={Hassan, Syeda Sakira and S{\"a}rkk{\"a}, Simo and Garc{\'\i}a-Fern{\'a}ndez, {\'A}ngel F}, 209 | journal={IEEE Transactions on Signal Processing}, 210 | volume={69}, 211 | pages={4875--4887}, 212 | year={2021}, 213 | publisher={IEEE}, 214 | doi={10.1109/TSP.2021.3103338} 215 | } 216 | 217 | 218 | @misc{sts-jax, 219 | author={Xinglong Li and Kevin Murphy}, 220 | title={Structural Time Series (STS) in JAX}, 221 | url={https://github.com/probml/sts-jax}, 222 | year={2022}, 223 | } 224 | 225 | @article{dalle2024hiddenmarkovmodels, 226 | title={{HiddenMarkovModels.jl: Generic, fast and reliable state space modeling}}, 227 | author={Dalle, Guillaume}, 228 | journal={Journal of Open Source Software}, 229 | volume={9}, 230 | number={96}, 231 | pages={6436}, 232 | year={2024}, 233 | doi={10.21105/joss.06436} 234 | } -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'Dynamax: A Python package for probabilistic state space modeling with JAX' 3 | tags: 4 | - Python 5 | - State space models 6 | - dynamics 7 | - JAX 8 | 9 | authors: 10 | - name: Scott W. 
Linderman 11 | orcid: 0000-0002-3878-9073 12 | affiliation: "1" # (Multiple affiliations must be quoted) 13 | corresponding: true 14 | - name: Peter Chang 15 | affiliation: "2" 16 | - name: Giles Harper-Donnelly 17 | affiliation: "3" 18 | - name: Aleyna Kara 19 | affiliation: "4" 20 | - name: Xinglong Li 21 | affiliation: "5" 22 | - name: Gerardo Duran-Martin 23 | affiliation: "6" 24 | - name: Kevin Murphy 25 | affiliation: "7" 26 | corresponding: true 27 | affiliations: 28 | - name: Department of Statistics and Wu Tsai Neurosciences Institute, Stanford University, USA 29 | index: 1 30 | - name: CSAIL, Massachusetts Institute of Technology, USA 31 | index: 2 32 | - name: Cambridge University, England, UK 33 | index: 3 34 | - name: Computer Science Department, Technical University of Munich Garching, Germany 35 | index: 4 36 | - name: Statistics Department, University of British Columbia, Canada 37 | index: 5 38 | - name: Queen Mary University of London, England, UK 39 | index: 6 40 | - name: Google DeepMind, USA 41 | index: 7 42 | 43 | date: 19 July 2024 44 | bibliography: paper.bib 45 | 46 | --- 47 | 48 | # Summary 49 | 50 | State space models (SSMs) are fundamental tools for modeling sequential data. They are broadly used across engineering disciplines like signal processing and control theory, as well as scientific domains like neuroscience [@vyas2020computation], genetics [@durbin1998biological], ecology [@patterson2008state], computational ethology [@weinreb2024keypoint], economics [@jacquier2002bayesian], and climate science [@ott2004local]. Fast and robust tools for state space modeling are crucial to researchers in all of these application areas. 51 | 52 | State space models specify a probability distribution over a sequence of observations, $y_1, \ldots y_T$, where $y_t$ denotes the observation at time $t$. The key assumption of an SSM is that the observations arise from a sequence of _latent states_, $z_1, \ldots, z_T$, which evolve according to a _dynamics model_ (a.k.a., transition model). An SSM may also use inputs (a.k.a., controls or covariates), $u_1,\ldots,u_T$, to steer the latent state dynamics and influence the observations. 53 | For example, in a neuroscience application from @vyas2020computation, $y_t$ represents a vector of spike counts from $\sim 1000$ measured neurons, and $z_t$ is a lower dimensional latent state that changes slowly over time and captures correlations among the measured neurons. If sensory inputs to the neural circuit are known, they can be encoded in $u_t$. 54 | In the computational ethology application of @weinreb2024keypoint, $y_t$ represents a vector of 3D locations for several key points on an animal's body, and $z_t$ is a discrete behavioral state that specifies how the animal's posture changes over time. 55 | In both examples, there are two main objectives: First, we aim to infer the latent states $z_t$ that best explain the observed data; formally, this is called _state inference_. 56 | Second, we need to estimate the dynamics that govern how latent states evolve; formally, this is part of the _parameter estimation_ process. 57 | `Dynamax` provides algorithms for state inference and parameter estimation in a variety of SSMs. 58 | 59 | There are a few key design choices to make when constructing an SSM: 60 | 61 | - What is the type of latent state? E.g., is $z_t$ a continuous or discrete random variable? 62 | - How do the latent states evolve over time? E.g., are the dynamics linear or nonlinear? 63 | - How are the observations distributed? 
E.g., are they Gaussian, Poisson, etc.? 64 | 65 | Some design choices are so common they have their own names. Hidden Markov models (HMM) are SSMs with discrete latent states, and linear dynamical systems (LDS) are SSMs with continuous latent states, linear dynamics, and additive Gaussian noise. `Dynamax` supports canonical SSMs and allows the user to construct bespoke models as needed, simply by inheriting from a base class and specifying a few model-specific functions. For example, see the _Creating Custom HMMs_ tutorial in the Dynamax documentation. 66 | 67 | Finally, even for canonical models, there are several algorithms for state inference and parameter estimation. `Dynamax` provides robust implementations of several low-level inference algorithms to suit a variety of applications, allowing users to choose among a host of models and algorithms for their application. More information about state space models and algorithms for state inference and parameter estimation can be found in the textbooks by @murphy2023probabilistic and @sarkka2023bayesian. 68 | 69 | 70 | # Statement of need 71 | 72 | `Dynamax` is an open-source Python package for state space modeling. Since it is built with `JAX` [@jax], it supports just-in-time (JIT) compilation for hardware acceleration on CPU, GPU, and TPU machines. It also supports automatic differentiation for gradient-based model learning. While other libraries exist for state space modeling in Python [@pyhsmm; @ssm; @eeasensors; @seabold2010statsmodels; @hmmlearn] and Julia [@dalle2024hiddenmarkovmodels], `Dynamax` provides a diverse combination of low-level inference algorithms and high-level modeling objects that can support a wide range of research applications in JAX. Additionally, `Dynamax` implements parallel message passing algorithms that leverage the associative scan (a.k.a., parallel scan) primitive in JAX to take full advantage of modern hardware accelerators. Currently, these primitives are not natively supported in other frameworks like PyTorch. While various subsets of these models and algorithms may be found in other libraries, Dynamax is a "one stop shop" for state space modeling in JAX. 73 | 74 | The API for `Dynamax` is divided into two parts: a set of core, functionally pure, low-level inference algorithms, and a high-level, object oriented module for constructing and fitting probabilistic SSMs. The low-level inference API provides message passing algorithms for several common types of SSMs. For example, `Dynamax` provides `JAX` implementations for: 75 | 76 | - Forward-Backward algorithms for discrete-state hidden Markov models (HMMs), 77 | - Kalman filtering and smoothing algorithms for linear Gaussian SSMs, 78 | - Extended and unscented generalized Kalman filtering and smoothing for nonlinear and/or non-Gaussian SSMs, and 79 | - Parallel message passing routines that leverage GPU or TPU acceleration to perform message passing in $O(\log T)$ time on a parallel machine [@stone1975parallel; @sarkka2020temporal; @hassan2021temporal]. Note that these routines are not simply parallelizing over batches of time series, but rather using a parallel algorithm with sublinear depth or span. 80 | 81 | The high-level model API makes it easy to construct, fit, and inspect HMMs and linear Gaussian SSMs. Finally, the online `Dynamax` documentation and tutorials provide a wealth of resources for state space modeling experts and newcomers alike. 82 | 83 | `Dynamax` has supported several publications. 
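As a concrete illustration of the high-level interface, the following minimal sketch, in the spirit of the package documentation, constructs a Gaussian HMM, samples synthetic data from it, and fits it with EM (the model sizes, random seeds, and iteration counts below are arbitrary illustrative choices):

```python
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

# A 3-state Gaussian HMM with 2-dimensional emissions.
hmm = GaussianHMM(num_states=3, emission_dim=2)

# Initialize the parameters and their properties (e.g., which ones are trainable).
params, props = hmm.initialize(jr.PRNGKey(0))

# Sample synthetic latent states and emissions, then fit the model with EM.
states, emissions = hmm.sample(params, jr.PRNGKey(1), num_timesteps=1000)
params, log_probs = hmm.fit_em(params, props, emissions, num_iters=50)
```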
The low-level API has been used in machine learning research [@zhao2023revisiting; @lee2023switching; @chang2023low]. Special purpose libraries have been built on top of `Dynamax`, like the Keypoint-MoSeq library for modeling animal behavior [@weinreb2024keypoint] and the Structural Time Series in JAX library, `sts-jax` [@sts-jax]. Finally, the `Dynamax` tutorials are used as reference examples in a major machine learning textbook [@murphy2023probabilistic]. 84 | 85 | # Acknowledgements 86 | 87 | A significant portion of this library was developed while S.W.L. was a Visiting Faculty Researcher at Google and P.C., G.H.D., A.K., and X.L. were Google Summer of Code participants. 88 | 89 | # References 90 | -------------------------------------------------------------------------------- /paper/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/dynamax/800fae691edc7a372605a230d91344bd4420fd93/paper/paper.pdf -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 30.3.0", "wheel", "versioneer[toml]==0.29"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dynamax" 7 | dynamic = ["version"] 8 | requires-python = ">= 3.10" 9 | dependencies = [ 10 | "jax>=0.3.15", 11 | "jaxlib", 12 | "fastprogress", 13 | "optax", 14 | "tensorflow_probability", 15 | "scikit-learn", 16 | "jaxtyping", 17 | "typing-extensions", 18 | "numpy" 19 | ] 20 | 21 | authors = [ 22 | {name="Scott Linderman"}, 23 | {name="Peter Chang"}, 24 | {name="Giles Harper-Donnelly"}, 25 | {name="Aleyna Kara"}, 26 | {name="Xinglong Li"}, 27 | {name="Kevin Murphy"} 28 | ] 29 | maintainers = [ 30 | {name="Scott Linderman", email="scott.linderman@stanford.edu"} 31 | ] 32 | description = "Dynamic State Space Models in JAX." 
33 | readme = "README.md" 34 | license = {file="LICENSE"} 35 | classifiers = ["Programming Language :: Python"] 36 | 37 | [project.urls] 38 | homepage = "https://github.com/probml/dynamax" 39 | documentation = "https://probml.github.io/dynamax/" 40 | repository = "https://github.com/probml/dynamax" 41 | 42 | [project.optional-dependencies] 43 | notebooks = [ 44 | "matplotlib", 45 | "seaborn", 46 | "flax", 47 | "blackjax", 48 | "graphviz", 49 | "scipy" 50 | ] 51 | 52 | doc = [ 53 | "matplotlib", 54 | "seaborn", 55 | "flax", 56 | "blackjax", 57 | "graphviz", 58 | "scipy", 59 | "sphinx", 60 | "sphinx-autobuild", 61 | "sphinx_autodoc_typehints", 62 | "sphinx-math-dollar", 63 | "myst-nb", 64 | "jupytext", 65 | "sphinx-book-theme" 66 | ] 67 | 68 | test = [ 69 | "codecov", 70 | "coverage", 71 | "pytest>=3.9", 72 | "pytest-cov", 73 | "interrogate>=1.5.0" 74 | ] 75 | 76 | dev = [ 77 | "matplotlib", 78 | "seaborn", 79 | "flax", 80 | "blackjax", 81 | "graphviz", 82 | "scipy", 83 | "sphinx", 84 | "sphinx-autobuild", 85 | "sphinx_autodoc_typehints", 86 | "sphinx-math-dollar", 87 | "myst-nb", 88 | "jupytext", 89 | "sphinx-book-theme", 90 | "codecov", 91 | "coverage", 92 | "pytest>=3.9", 93 | "pytest-cov", 94 | "interrogate>=1.5.0" 95 | ] 96 | 97 | [tool.setuptools.packages.find] 98 | exclude = ["logo", "docs"] 99 | 100 | 101 | [tool.versioneer] 102 | VCS = "git" 103 | style = "pep440-pre" 104 | versionfile_source = "dynamax/_version.py" 105 | versionfile_build = "dynamax/_version.py" 106 | tag_prefix = "" 107 | parentdir_prefix = "" 108 | 109 | [tool.black] 110 | line-length = 120 111 | 112 | [tool.interrogate] 113 | ignore-init-method = true 114 | ignore-init-module = true 115 | fail-under = 66 116 | verbose = 2 117 | quiet = false 118 | color = true 119 | 120 | [tool.ruff.lint] 121 | ignore = ["F722"] 122 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import setuptools 3 | import versioneer 4 | 5 | if __name__ == '__main__': 6 | setuptools.setup(name='dynamax', 7 | version=versioneer.get_version(), 8 | cmdclass=versioneer.get_cmdclass() 9 | ) 10 | --------------------------------------------------------------------------------