├── LICENSE ├── README.md ├── notebooks ├── README.md ├── cubature_example.ipynb ├── ekf_example.ipynb ├── linear_test.ipynb └── runtime.ipynb ├── parsmooth ├── __init__.py ├── cubature_common │ └── __init__.py ├── models │ ├── __init__.py │ ├── bearings.py │ └── linear.py ├── parallel │ ├── __init__.py │ ├── ckf.py │ ├── cks.py │ ├── ekf.py │ ├── eks.py │ └── operators.py ├── sequential │ ├── __init__.py │ ├── cubature.py │ └── extended.py └── utils.py ├── requirements.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 EEA-sensors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parallel Iterated Extended and Sigma-Point Kalman Smoothers 2 | 3 | Companion code in JAX for the paper Parallel Iterated Extended and Sigma-Point Kalman Smoothers [2]. 4 | 5 | What is it? 6 | ----------- 7 | 8 | This is an implementation of parallelized Extended and Sigma-Points Bayesian Filters and Smoothers with CPU/GPU/TPU support coded using [JAX](https://github.com/google/jax) primitives, in particular [associative scan](https://en.wikipedia.org/wiki/Prefix_sum?wprov=sfla1). 9 | 10 | Supported features 11 | ------------------ 12 | 13 | * Extended Kalman Filtering and Smoothing 14 | * Cubature Kalman Filtering and Smoothing 15 | * Iterated versions of the above 16 | 17 | Installation 18 | ------------ 19 | - With GPU CUDA 11.0 support 20 | - Using pip 21 | Run `pip install https://github.com/EEA-sensors/parallel-non-linear-gaussian-smoothers.git -f https://storage.googleapis.com/jax-releases/jax_releases.html` 22 | - By cloning 23 | Clone https://github.com/EEA-sensors/parallel-non-linear-gaussian-smoothers.git 24 | Run `pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_releases.html` 25 | Run `python setup.py [install|develop]` depending on the level of installation you want 26 | - Without GPU support 27 | - By cloning 28 | Clone https://github.com/EEA-sensors/parallel-non-linear-gaussian-smoothers.git 29 | Run `python setup.py [install|develop] --no-deps` depending on the level of installation you want 30 | 31 | Manually install the dependencies `jax` and `jaxlib`, and for examples only `matplotlib`, `numba`, `tqdm` 32 | 33 | Example 34 | ------- 35 | 36 | ```python 37 | from parsmooth.parallel import ieks 38 | from parsmooth.utils import MVNormalParameters 39 | 40 | initial_guess = MVNormalParameters(...) 41 | data = ... 42 | Q = ... # transition noise covariance matrix 43 | R = ... # observation error covariance matrix 44 | 45 | def transition_function(x): 46 | ... 47 | return next_x 48 | 49 | def observation_function(x): 50 | ... 51 | return obs 52 | 53 | iterated_smoothed_trajectories = ieks(initial_guess, 54 | data, 55 | transition_function, 56 | Q, 57 | observation_function, 58 | R, 59 | n_iter=100) # runs the parallel IEKS 100 times. 60 | 61 | ``` 62 | 63 | For more examples, see the [notebooks](https://github.com/EEA-sensors/parallel-non-linear-gaussian-smoothers/tree/master/notebooks). 64 | 65 | Acknowlegments 66 | -------------- 67 | This JAX-based code was created by [Adrien Corenflos](https://adriencorenflos.github.io/) to implement the original idea by [Fatemeh Yaghoobi](https://fatameh-yaghoobi.github.io/) [2] who provided the initial code for the parallelized extended Kalman filter in pure Python. The sequential cubature filtering code was adapted from some original code by [Zheng Zhao](https://users.aalto.fi/~zhaoz1/). 68 | 69 | References 70 | ---------- 71 | 72 | [1] S. Särkkä and A. F. García-Fernández. *Temporal Parallelization of Bayesian Smoothers.* In: IEEE Transactions on Automatic Control 2020. 73 | 74 | [2] F. Yaghoobi and A. Corenflos and S. Hassan and S. Särkkä. *Parallel Iterated Extended and Sigma-Points Kalman Smoothers.* To appear in Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). ([arXiv](http://arxiv.org/abs/2102.00514)) 75 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | ### Author: [Adrien Corenflos](https://github.com/AdrienCorenflos/) 2 | 3 | These notebooks illustrate the parallel implementation of iterated Kalman filters and smoothers on GPU. They can be downloaded to be run locally or on Google Colab. 4 | -------------------------------------------------------------------------------- /notebooks/ekf_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 16, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import jax.numpy as jnp\n", 17 | "from jax.config import config\n", 18 | "\n", 19 | "from parsmooth.sequential import ieks as seq_ieks\n", 20 | "from parsmooth.models.bearings import get_data, make_parameters, plot_bearings\n", 21 | "from parsmooth.parallel import ekf, eks, ieks\n", 22 | "from parsmooth.utils import MVNormalParameters\n", 23 | "\n", 24 | "# from jax import jit\n", 25 | "\n", 26 | "jax.config.update(\"jax_debug_nans\", False)\n", 27 | "jax.config.update(\"jax_enable_x64\", False)\n", 28 | "jax.config.update(\"jax_platform_name\", \"gpu\")" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "### Input parameters" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "s1 = jnp.array([-1.5, 0.5]) # First sensor location\n", 45 | "s2 = jnp.array([1., 1.]) # Second sensor location\n", 46 | "r = 0.5 # Observation noise (stddev)\n", 47 | "dt = 0.01 # discretization time step\n", 48 | "x0 = jnp.array([0.1, 0.2, 1, 0]) # initial true location\n", 49 | "qc = 0.01 # discretization noise\n", 50 | "qw = 0.1 # discretization noise\n", 51 | "\n", 52 | "T = 1000 # number of observations" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "### Get parameters" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Get data" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "ts, true_states, observations = get_data(x0, dt, r, T, s1, s2, random_state=42)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "### We can now run the filter and smoother" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Initial state guess" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "m = jnp.array([-1., -1., 0., 0., 0.])\n", 108 | "P = jnp.eye(5)\n", 109 | "\n", 110 | "initial_guess = MVNormalParameters(m, P)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "Run the filter" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "initial_states = MVNormalParameters(jnp.zeros((T + 1, 5)), jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T + 1, axis=0))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "filtered = ekf(initial_guess,\n", 136 | " observations,\n", 137 | " transition_function,\n", 138 | " Q,\n", 139 | " observation_function,\n", 140 | " R,\n", 141 | " initial_states.mean)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 8, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "smoothed = eks(transition_function, Q, filtered)\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "source": [ 156 | "### And the iterated one" 157 | ], 158 | "metadata": { 159 | "collapsed": false 160 | } 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 9, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "iterated_smoothed_trajectories = ieks(initial_guess,\n", 169 | " observations,\n", 170 | " transition_function,\n", 171 | " Q,\n", 172 | " observation_function,\n", 173 | " R,\n", 174 | " n_iter=15)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 17, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "seq_iterated_smoothed_trajectories = seq_ieks(initial_guess,\n", 184 | " observations,\n", 185 | " transition_function,\n", 186 | " Q,\n", 187 | " observation_function,\n", 188 | " R,\n", 189 | " None,\n", 190 | " n_iter=15,\n", 191 | " propagate_first=False)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### For comparison we can run the sequential iterated smoother too" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "### Plot the result" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 11, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": "
", 216 | "image/png": "\n" 217 | }, 218 | "metadata": { 219 | "needs_background": "light" 220 | }, 221 | "output_type": "display_data" 222 | } 223 | ], 224 | "source": [ 225 | "plot_bearings([true_states, filtered.mean, smoothed.mean, iterated_smoothed_trajectories.mean,\n", 226 | " seq_iterated_smoothed_trajectories.mean],\n", 227 | " [\"True\", \"Filter\", \"Smoother\", \"Iterated Smoother\", \"Seq - Iterated Smoother\"],\n", 228 | " s1, s2, figsize=(15, 10), quiver=False)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 12, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "text/plain": "(1001, 5)" 239 | }, 240 | "execution_count": 12, 241 | "metadata": {}, 242 | "output_type": "execute_result" 243 | } 244 | ], 245 | "source": [ 246 | "iterated_smoothed_trajectories.mean.shape" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 13, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/plain": "(1000, 5)" 257 | }, 258 | "execution_count": 13, 259 | "metadata": {}, 260 | "output_type": "execute_result" 261 | } 262 | ], 263 | "source": [ 264 | "seq_iterated_smoothed_trajectories.mean.shape\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 13, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 13, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 13, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 13, 298 | "outputs": [], 299 | "source": [], 300 | "metadata": { 301 | "collapsed": false, 302 | "pycharm": { 303 | "name": "#%%\n" 304 | } 305 | } 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 13, 310 | "outputs": [], 311 | "source": [], 312 | "metadata": { 313 | "collapsed": false, 314 | "pycharm": { 315 | "name": "#%%\n" 316 | } 317 | } 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 13, 322 | "outputs": [], 323 | "source": [], 324 | "metadata": { 325 | "collapsed": false, 326 | "pycharm": { 327 | "name": "#%%\n" 328 | } 329 | } 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 13, 334 | "outputs": [], 335 | "source": [], 336 | "metadata": { 337 | "collapsed": false, 338 | "pycharm": { 339 | "name": "#%%\n" 340 | } 341 | } 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 13, 346 | "outputs": [], 347 | "source": [], 348 | "metadata": { 349 | "collapsed": false, 350 | "pycharm": { 351 | "name": "#%%\n" 352 | } 353 | } 354 | } 355 | ], 356 | "metadata": { 357 | "kernelspec": { 358 | "name": "python3", 359 | "language": "python", 360 | "display_name": "Python 3 (ipykernel)" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 3 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython3", 372 | "version": "3.8.5" 373 | } 374 | }, 375 | "nbformat": 4, 376 | "nbformat_minor": 4 377 | } -------------------------------------------------------------------------------- /notebooks/linear_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "metadata": { 13 | "ExecuteTime": { 14 | "end_time": "2024-08-09T19:57:28.160509Z", 15 | "start_time": "2024-08-09T19:57:27.779511Z" 16 | } 17 | }, 18 | "source": [ 19 | "import jax.numpy as jnp\n", 20 | "from parsmooth.parallel import ekf, eks\n", 21 | "from parsmooth.sequential import ekf as seq_ekf, eks as seq_eks, ckf as seq_ckf, cks as seq_cks\n", 22 | "from parsmooth.models.linear import get_data, make_parameters\n", 23 | "from parsmooth.utils import MVNormalParameters" 24 | ], 25 | "outputs": [], 26 | "execution_count": 1 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Input parameters" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "metadata": { 38 | "ExecuteTime": { 39 | "end_time": "2024-08-09T19:57:28.478071Z", 40 | "start_time": "2024-08-09T19:57:28.162240Z" 41 | } 42 | }, 43 | "source": [ 44 | "r = 0.5\n", 45 | "q = 0.1\n", 46 | "x0 = jnp.array([0., 0.]) # initial true location\n", 47 | "\n", 48 | "T = 1000 # number of observations" 49 | ], 50 | "outputs": [ 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "2024-08-09 20:57:28.400845: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" 56 | ] 57 | } 58 | ], 59 | "execution_count": 2 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "### Get parameters" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "metadata": { 71 | "ExecuteTime": { 72 | "end_time": "2024-08-09T19:57:28.871458Z", 73 | "start_time": "2024-08-09T19:57:28.479254Z" 74 | } 75 | }, 76 | "source": [ 77 | "A, H, Q, R, observation_function, transition_function = make_parameters(r, q)" 78 | ], 79 | "outputs": [], 80 | "execution_count": 3 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "ExecuteTime": { 86 | "end_time": "2024-08-09T19:57:28.875619Z", 87 | "start_time": "2024-08-09T19:57:28.872932Z" 88 | } 89 | }, 90 | "source": [ 91 | "observation_function = jnp.vectorize(observation_function, signature=\"(m)->(d)\")\n", 92 | "transition_function = jnp.vectorize(transition_function, signature=\"(m)->(m)\")" 93 | ], 94 | "outputs": [], 95 | "execution_count": 4 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### Get data" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "ExecuteTime": { 108 | "end_time": "2024-08-09T19:57:29.444821Z", 109 | "start_time": "2024-08-09T19:57:28.876503Z" 110 | } 111 | }, 112 | "source": [ 113 | "ts, true_states, observations = get_data(x0, A, H, R, Q, T, 42)" 114 | ], 115 | "outputs": [], 116 | "execution_count": 5 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "### We can now run the filter" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "Initial state guess" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "ExecuteTime": { 136 | "end_time": "2024-08-09T19:57:29.449773Z", 137 | "start_time": "2024-08-09T19:57:29.445689Z" 138 | } 139 | }, 140 | "source": [ 141 | "m = jnp.array([0., 0.])\n", 142 | "P = jnp.eye(2)\n", 143 | "\n", 144 | "initial_guess = MVNormalParameters(m, P)" 145 | ], 146 | "outputs": [], 147 | "execution_count": 6 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "### We can now run the smoother" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Run the filters" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "ExecuteTime": { 167 | "end_time": "2024-08-09T19:57:42.814995Z", 168 | "start_time": "2024-08-09T19:57:29.451638Z" 169 | } 170 | }, 171 | "source": [ 172 | "par_ekf_filtered = ekf(initial_guess, observations, transition_function, Q, observation_function, R)\n", 173 | "seq_ekf_ll, seq_ekf_filtered = seq_ekf(initial_guess, observations, transition_function, Q, observation_function, R)\n", 174 | "par_ckf_ll, seq_ckf_filtered = seq_ckf(initial_guess, observations, transition_function, Q, observation_function, R)" 175 | ], 176 | "outputs": [], 177 | "execution_count": 7 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "Compare:" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "metadata": { 189 | "ExecuteTime": { 190 | "end_time": "2024-08-09T19:57:43.061201Z", 191 | "start_time": "2024-08-09T19:57:42.816350Z" 192 | } 193 | }, 194 | "source": [ 195 | "print(seq_ekf_ll, par_ckf_ll)\n", 196 | "\n", 197 | "print(jnp.max(jnp.abs(par_ekf_filtered.mean - seq_ekf_filtered.mean)))\n", 198 | "print(jnp.max(jnp.abs(par_ekf_filtered.mean - seq_ckf_filtered.mean)))\n", 199 | "\n", 200 | "print(jnp.max(jnp.abs(par_ekf_filtered.cov - seq_ekf_filtered.cov)))\n", 201 | "print(jnp.max(jnp.abs(par_ekf_filtered.cov - seq_ckf_filtered.cov)))" 202 | ], 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "-1178.5336 -1178.5336\n", 209 | "1.7851591e-05\n", 210 | "1.7851591e-05\n", 211 | "2.346933e-06\n", 212 | "2.346933e-06\n" 213 | ] 214 | } 215 | ], 216 | "execution_count": 8 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "Run the smoothers" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "ExecuteTime": { 229 | "end_time": "2024-08-09T19:57:46.092218Z", 230 | "start_time": "2024-08-09T19:57:43.063031Z" 231 | } 232 | }, 233 | "source": [ 234 | "par_eks_smoothed = eks(transition_function, Q, par_ekf_filtered, par_ekf_filtered.mean)\n", 235 | "seq_eks_smoothed = seq_eks(transition_function, Q, par_ekf_filtered)\n", 236 | "seq_cks_smoothed = seq_cks(transition_function, Q, par_ekf_filtered)" 237 | ], 238 | "outputs": [], 239 | "execution_count": 9 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "ExecuteTime": { 245 | "end_time": "2024-08-09T19:57:46.100782Z", 246 | "start_time": "2024-08-09T19:57:46.093902Z" 247 | } 248 | }, 249 | "source": [ 250 | "print(jnp.max(jnp.abs(par_eks_smoothed.mean - seq_eks_smoothed.mean)))\n", 251 | "print(jnp.max(jnp.abs(par_eks_smoothed.mean - seq_cks_smoothed.mean)))\n", 252 | "\n", 253 | "print(jnp.max(jnp.abs(par_eks_smoothed.cov - seq_eks_smoothed.cov)))\n", 254 | "print(jnp.max(jnp.abs(par_eks_smoothed.cov - seq_cks_smoothed.cov)))" 255 | ], 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "6.023445e-05\n", 262 | "6.01972e-05\n", 263 | "2.7239323e-05\n", 264 | "2.7239323e-05\n" 265 | ] 266 | } 267 | ], 268 | "execution_count": 10 269 | }, 270 | { 271 | "cell_type": "code", 272 | "metadata": { 273 | "ExecuteTime": { 274 | "end_time": "2024-08-09T19:57:46.111310Z", 275 | "start_time": "2024-08-09T19:57:46.105079Z" 276 | } 277 | }, 278 | "source": [], 279 | "outputs": [], 280 | "execution_count": 10 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "name": "python3", 286 | "language": "python", 287 | "display_name": "Python 3 (ipykernel)" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | "name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.8.5" 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 4 304 | } 305 | -------------------------------------------------------------------------------- /notebooks/runtime.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Runtime experiments for CPU and GPU benchmarking of our algorithms" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "id": "6-Car3VDUC1b" 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Model name: Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "!lscpu |grep 'Model name'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "colab": { 34 | "base_uri": "https://localhost:8080/", 35 | "height": 51 36 | }, 37 | "id": "AQzvVUCESkGP", 38 | "outputId": "e998c8f0-3abe-48de-bbfe-6f4f09662af2" 39 | }, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "Mon Oct 19 21:43:43 2020 \n", 46 | "+-----------------------------------------------------------------------------+\n", 47 | "| NVIDIA-SMI 455.23.05 Driver Version: 455.23.05 CUDA Version: 11.1 |\n", 48 | "|-------------------------------+----------------------+----------------------+\n", 49 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 50 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 51 | "| | | MIG M. |\n", 52 | "|===============================+======================+======================|\n", 53 | "| 0 Quadro P2000 On | 00000000:91:00.0 On | N/A |\n", 54 | "| 47% 36C P5 8W / 75W | 4722MiB / 5050MiB | 3% Default |\n", 55 | "| | | N/A |\n", 56 | "+-------------------------------+----------------------+----------------------+\n", 57 | " \n", 58 | "+-----------------------------------------------------------------------------+\n", 59 | "| Processes: |\n", 60 | "| GPU GI CI PID Type Process name GPU Memory |\n", 61 | "| ID ID Usage |\n", 62 | "|=============================================================================|\n", 63 | "| 0 N/A N/A 1177 G /usr/lib/xorg/Xorg 65MiB |\n", 64 | "| 0 N/A N/A 1217 G /usr/bin/gnome-shell 80MiB |\n", 65 | "| 0 N/A N/A 1797 G /usr/lib/xorg/Xorg 385MiB |\n", 66 | "| 0 N/A N/A 2014 G /usr/bin/gnome-shell 34MiB |\n", 67 | "| 0 N/A N/A 2333 G ...oken=12133755331783826913 10MiB |\n", 68 | "| 0 N/A N/A 3158 G .../debug.log --shared-files 43MiB |\n", 69 | "| 0 N/A N/A 4660 G ...AAAAAAAAA= --shared-files 142MiB |\n", 70 | "| 0 N/A N/A 30637 C ...rallelEKF/venv/bin/python 3951MiB |\n", 71 | "+-----------------------------------------------------------------------------+\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "!nvidia-smi" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "id": "D66bXe-p0wdC" 83 | }, 84 | "source": [ 85 | "### Imports" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 36, 91 | "metadata": { 92 | "id": "cfLgDa4d0wdD" 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "from jax import jit, devices, make_jaxpr\n", 97 | "from jax.config import config\n", 98 | "import jax.numpy as jnp\n", 99 | "from matplotlib import rcParams\n", 100 | "import matplotlib.pyplot as plt\n", 101 | "import numpy as np\n", 102 | "import pandas as pd\n", 103 | "import time\n", 104 | "import tqdm\n", 105 | "\n", 106 | "from parsmooth.parallel import ieks, icks\n", 107 | "from parsmooth.sequential import ieks as seq_ieks, icks as seq_icks\n", 108 | "from parsmooth.models.bearings import get_data, make_parameters, plot_bearings\n", 109 | "from parsmooth.utils import MVNormalParameters" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 30, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# rcParams['font.family'] = 'sans-serif'\n", 119 | "rcParams['font.sans-serif'] = ['Computer Modern Sans serif']" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": { 125 | "id": "7n5s5J1S0wdH" 126 | }, 127 | "source": [ 128 | "### Input parameters" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 5, 134 | "metadata": { 135 | "colab": { 136 | "base_uri": "https://localhost:8080/", 137 | "height": 51 138 | }, 139 | "id": "I7udHQWm0wdH", 140 | "outputId": "112b06e5-55f9-4d22-f4db-27d7df41cb1b" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "s1 = jnp.array([-1.5, 0.5]) # First sensor location\n", 145 | "s2 = jnp.array([1., 1.]) # Second sensor location\n", 146 | "r = 5. # Observation noise (stddev) - Large because IEKS is not very stable\n", 147 | "dt = 0.01 # discretization time step\n", 148 | "x0 = jnp.array([0.1, 0.2, 1, 0]) # initial true location\n", 149 | "qc = 0.1 # noise - Large because IEKS is not very stable\n", 150 | "qw = 0.1 # noise - Small because IEKS is not very stable\n", 151 | "\n", 152 | "T = 100 # number of observations" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": { 158 | "id": "rjRijP8D0wdL" 159 | }, 160 | "source": [ 161 | "### Get parameters" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": { 168 | "id": "MnBGuowI0wdL" 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 7, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "transition_function = jnp.vectorize(transition_function, signature=\"(m)->(m)\")\n", 182 | "observation_function = jnp.vectorize(observation_function, signature=\"(m)->(d)\")" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "id": "EEWK8aah0wdO" 189 | }, 190 | "source": [ 191 | "### Get data" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 8, 197 | "metadata": { 198 | "id": "tMnoVBYc0wdP" 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "ts, true_states, observations = get_data(x0, dt, r, T, s1, s2, qw, random_state=42)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "id": "wKIP9WPu0wdS" 209 | }, 210 | "source": [ 211 | "### We can now run the filter" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "id": "MNhOFYON0wdS" 218 | }, 219 | "source": [ 220 | "Initial state guess" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 9, 226 | "metadata": { 227 | "id": "0Z9b5Hpe0wdT" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "m = jnp.array([-1., -1., 0., 0., 0.])\n", 232 | "P = jnp.eye(5)\n", 233 | "\n", 234 | "initial_guess = MVNormalParameters(m, P)\n", 235 | "initial_linearization_points = jnp.zeros((T, 5), dtype=m.dtype)\n", 236 | "initial_linearization_covariances = jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T, axis=0)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "id": "SzlrSEnn0wdi" 243 | }, 244 | "source": [ 245 | "### Sequential vs Parallel computation time comparison" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 10, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "gpu_par_ieks = jit(ieks, static_argnums=(2, 4, 7), backend=\"gpu\")\n", 255 | "cpu_par_ieks = jit(ieks, static_argnums=(2, 4, 7), backend=\"cpu\")\n", 256 | "\n", 257 | "gpu_seq_ieks = jit(seq_ieks, static_argnums=(2, 4, 7), backend=\"gpu\")\n", 258 | "cpu_seq_ieks = jit(seq_ieks, static_argnums=(2, 4, 7), backend=\"cpu\")\n", 259 | "\n", 260 | "gpu_par_icks = jit(icks, static_argnums=(2, 4, 7), backend=\"gpu\")\n", 261 | "cpu_par_icks = jit(icks, static_argnums=(2, 4, 7), backend=\"cpu\")\n", 262 | "\n", 263 | "gpu_seq_icks = jit(seq_icks, static_argnums=(2, 4, 7), backend=\"gpu\")\n", 264 | "cpu_seq_icks = jit(seq_icks, static_argnums=(2, 4, 7), backend=\"cpu\")" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 11, 270 | "metadata": { 271 | "id": "8ihUWWOniFgy" 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "def profile_smoother(s_method, lengths, n_runs=1, n_iter=10):\n", 276 | " res_mean = []\n", 277 | " for j in tqdm.tqdm(lengths):\n", 278 | " observations_slice = observations[:j]\n", 279 | " init_linearizations_points_slice = initial_linearization_points[:j]\n", 280 | " init_linearizations_covs_slice = initial_linearization_covariances[:j]\n", 281 | " init_linearizations_states = MVNormalParameters(init_linearizations_points_slice, init_linearizations_covs_slice)\n", 282 | " args = initial_guess, observations_slice, transition_function, Q, observation_function, R, init_linearizations_states, n_iter\n", 283 | " s = s_method(*args) # this is a call used for compiling the function, this is a bit slow at the moment in JAX and shouldn't be taken into account for benchmarking.\n", 284 | " # they are currently working on AOT compilation, which would then reduce the overhead substantially. \n", 285 | " s.mean.block_until_ready()\n", 286 | " run_times = []\n", 287 | " for _ in range(n_runs):\n", 288 | " tic = time.time()\n", 289 | " s_states = s_method(*args)\n", 290 | " s_states.mean.block_until_ready()\n", 291 | " toc = time.time()\n", 292 | " run_times.append(toc - tic)\n", 293 | " res_mean.append(np.mean(run_times))\n", 294 | " return np.array(res_mean)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": { 300 | "id": "VrId5I8-0wdj" 301 | }, 302 | "source": [ 303 | "Let's now run the sequential vs the parallel implementation to see the performance gain coming from such parallelisation" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 12, 309 | "metadata": { 310 | "id": "L4aio98JjBvs" 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "lengths_space = np.logspace(1, np.log10(T), num=20).astype(np.int32)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 13, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "name": "stderr", 324 | "output_type": "stream", 325 | "text": [ 326 | "100%|██████████| 3/3 [01:58<00:00, 39.46s/it]\n", 327 | "100%|██████████| 3/3 [00:07<00:00, 2.49s/it]\n", 328 | "100%|██████████| 3/3 [02:21<00:00, 47.26s/it]\n", 329 | "100%|██████████| 3/3 [00:08<00:00, 2.96s/it]\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "gpu_par_ieks_time = profile_smoother(gpu_par_ieks, lengths_space)\n", 335 | "cpu_par_ieks_time = profile_smoother(cpu_par_ieks, lengths_space)\n", 336 | "\n", 337 | "gpu_seq_ieks_time = profile_smoother(gpu_seq_ieks, lengths_space)\n", 338 | "cpu_seq_ieks_time = profile_smoother(cpu_seq_ieks, lengths_space)\n", 339 | "\n", 340 | "gpu_par_icks_time = profile_smoother(gpu_par_icks, lengths_space)\n", 341 | "cpu_par_icks_time = profile_smoother(cpu_par_icks, lengths_space)\n", 342 | "\n", 343 | "gpu_seq_icks_time = profile_smoother(gpu_seq_icks, lengths_space)\n", 344 | "cpu_seq_icks_time = profile_smoother(cpu_seq_icks, lengths_space)\n" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 41, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "data = np.stack([\n", 354 | " gpu_par_ieks_time,\n", 355 | " cpu_par_ieks_time,\n", 356 | " gpu_seq_ieks_time,\n", 357 | " cpu_seq_ieks_time,\n", 358 | " gpu_par_icks_time,\n", 359 | " cpu_par_icks_time,\n", 360 | " gpu_seq_icks_time,\n", 361 | " cpu_seq_icks_time],\n", 362 | " axis=1)\n", 363 | "\n", 364 | "columns = [\"GPU_par_IEKS\",\n", 365 | " \"CPU_par_IEKS\",\n", 366 | " \"GPU_seq_IEKS\",\n", 367 | " \"CPU_seq_IEKS\",\n", 368 | " \"GPU_par_ICKS\",\n", 369 | " \"CPU_par_ICKS\",\n", 370 | " \"GPU_seq_ICKS\",\n", 371 | " \"CPU_seq_ICKS\"]\n", 372 | "\n", 373 | "df = pd.DataFrame(index=lengths_space, data=data, colums=columns)\n", 374 | "df.to_csv(\"...\")" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [] 383 | } 384 | ], 385 | "metadata": { 386 | "colab": { 387 | "name": "ekf_example.ipynb", 388 | "provenance": [], 389 | "toc_visible": true 390 | }, 391 | "kernelspec": { 392 | "display_name": "parsmooth", 393 | "language": "python", 394 | "name": "parsmooth" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.8.5" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 4 411 | } -------------------------------------------------------------------------------- /parsmooth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEA-sensors/parallel-non-linear-gaussian-smoothers/4fb19a189f1338e8cec9dafd2d8088c233875b5d/parsmooth/__init__.py -------------------------------------------------------------------------------- /parsmooth/cubature_common/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Tuple 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from ..utils import MVNormalParameters 8 | 9 | SigmaPoints = namedtuple( 10 | 'SigmaPoints', ['points', 'wm', 'wc'] 11 | ) 12 | 13 | 14 | def mean_sigma_points(points): 15 | """ 16 | Computes the mean of sigma points 17 | 18 | Parameters 19 | ---------- 20 | points: SigmaPoints 21 | The sigma points 22 | 23 | Returns 24 | ------- 25 | mean: array_like 26 | the mean of the sigma points 27 | """ 28 | return jnp.dot(points.wm, points.points) 29 | 30 | 31 | def covariance_sigma_points(points_1, mean_1, points_2, mean_2): 32 | """ 33 | Computes the covariance between two sets of sigma points 34 | 35 | Parameters 36 | ---------- 37 | points_1: SigmaPoints 38 | first set of sigma points 39 | mean_1: array_like 40 | assumed mean of the first set of points 41 | points_2: SigmaPoints 42 | second set of sigma points 43 | points_1: SigmaPoints 44 | assumed mean of the second set of points 45 | 46 | Returns 47 | ------- 48 | cov: array_like 49 | the covariance of the two sets 50 | """ 51 | one = (points_1.points - mean_1.reshape(1, -1)).T * points_1.wc.reshape(1, -1) 52 | two = points_2.points - mean_2.reshape(1, -1) 53 | return jnp.dot(one, two) 54 | 55 | 56 | def cubature_weights(n_dim: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 57 | """ Computes the weights associated with the spherical cubature method. 58 | The number of sigma-points is 2 * n_dim 59 | 60 | Parameters 61 | ---------- 62 | n_dim: int 63 | Dimensionality of the problem 64 | 65 | Returns 66 | ------- 67 | wm: np.ndarray 68 | Weights means 69 | wc: np.ndarray 70 | Weights covariances 71 | xi: np.ndarray 72 | Orthogonal vectors 73 | """ 74 | wm = np.ones(shape=(2 * n_dim,)) / (2 * n_dim) 75 | wc = wm 76 | xi = np.concatenate([np.eye(n_dim), -np.eye(n_dim)], axis=0) * np.sqrt(n_dim) 77 | 78 | return wm, wc, xi 79 | 80 | 81 | def get_sigma_points(mv_normal_parameters: MVNormalParameters) -> SigmaPoints: 82 | """ Computes the sigma-points for a given mv normal distribution 83 | The number of sigma-points is 2*n_dim 84 | 85 | Parameters 86 | ---------- 87 | mv_normal_parameters: MVNormalParameters 88 | Mean and Covariance of the distribution 89 | 90 | Returns 91 | ------- 92 | out: SigmaPoints 93 | sigma points for the spherical cubature transform 94 | 95 | """ 96 | mean = mv_normal_parameters.mean 97 | n_dim = mean.shape[0] 98 | 99 | wm, wc, xi = cubature_weights(n_dim) 100 | 101 | sigma_points = jnp.repeat(mean.reshape(1, -1), wm.shape[0], axis=0) \ 102 | + jnp.dot(jnp.linalg.cholesky(mv_normal_parameters.cov), xi.T).T 103 | 104 | return SigmaPoints(sigma_points, wm, wc) 105 | 106 | 107 | def get_mv_normal_parameters(sigma_points: SigmaPoints) -> MVNormalParameters: 108 | """ Computes the MV Normal distribution parameters associated with the sigma points 109 | 110 | Parameters 111 | ---------- 112 | sigma_points: SigmaPoints 113 | shape of sigma_points.points is (n_dim, 2*n_dim) 114 | Returns 115 | ------- 116 | out: MVNormalParameters 117 | Mean and covariance of RV of dimension K computed from sigma-points 118 | """ 119 | m = mean_sigma_points(sigma_points) 120 | cov = covariance_sigma_points(sigma_points, m, sigma_points, m) 121 | return MVNormalParameters(m, cov=cov) 122 | -------------------------------------------------------------------------------- /parsmooth/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEA-sensors/parallel-non-linear-gaussian-smoothers/4fb19a189f1338e8cec9dafd2d8088c233875b5d/parsmooth/models/__init__.py -------------------------------------------------------------------------------- /parsmooth/models/bearings.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numba as nb 6 | import numpy as np 7 | import scipy.linalg as linalg 8 | from jax import lax, jit 9 | 10 | __all__ = ["make_parameters", "get_data"] 11 | 12 | 13 | def _transition_function(x, dt): 14 | """ Deterministic transition function used in the state space model 15 | 16 | Parameters 17 | ---------- 18 | x: array_like 19 | The current state 20 | dt: float 21 | Time step between observations 22 | 23 | Returns 24 | ------- 25 | out: array_like 26 | The transitioned state 27 | """ 28 | w = x[-1] 29 | predicate = jnp.abs(w) < 1e-6 30 | 31 | coswt = jnp.cos(w * dt) 32 | sinwt = jnp.sin(w * dt) 33 | 34 | def true_fun(_): 35 | return coswt, 0., sinwt, dt 36 | 37 | def false_fun(_): 38 | coswto = coswt - 1 39 | return coswt, coswto / w, sinwt, sinwt / w 40 | 41 | coswt, coswtopw, sinwt, sinwtpw = lax.cond(predicate, true_fun, false_fun, None) 42 | 43 | F = jnp.array([[1, 0, sinwtpw, -coswtopw, 0], 44 | [0, 1, coswtopw, sinwtpw, 0], 45 | [0, 0, coswt, sinwt, 0], 46 | [0, 0, -sinwt, coswt, 0], 47 | [0, 0, 0, 0, 1]]) 48 | return F @ x 49 | 50 | 51 | def _observation_function(x, s1, s2): 52 | """ 53 | Returns the observed angles as function of the state and the sensors locations 54 | 55 | Parameters 56 | ---------- 57 | x: array_like 58 | The current state 59 | s1: array_like 60 | The first sensor location 61 | s2: array_like 62 | The second sensor location 63 | 64 | Returns 65 | ------- 66 | y: array_like 67 | The observed angles, the first component is the angle w.r.t. the first sensor, the second w.r.t the second. 68 | """ 69 | return jnp.array([jnp.arctan2(x[1] - s1[1], x[0] - s1[0]), 70 | jnp.arctan2(x[1] - s2[1], x[0] - s2[0])]) 71 | 72 | 73 | @partial(jnp.vectorize, excluded=(1, 2), signature="(m)->(d)") 74 | def inverse_bearings(observation, s1, s2): 75 | """ 76 | Inverse the bearings observation to the location as if there was no noise, 77 | This is only used to provide an initial point for the linearization of the IEKS and ICKS. 78 | 79 | Parameters 80 | ---------- 81 | observation: (2) array 82 | The bearings observation 83 | s1: (2) array 84 | The first sensor position 85 | s2: (2) array 86 | The second sensor position 87 | 88 | Returns 89 | ------- 90 | out: (2) array 91 | The inversed position of the state 92 | """ 93 | tan_theta = jnp.tan(observation) 94 | A = jnp.array([[tan_theta[0], -1], 95 | [tan_theta[1], -1]]) 96 | b = jnp.array([s1[0] * tan_theta[0] - s1[1], 97 | s2[0] * tan_theta[1] - s2[1]]) 98 | return jnp.linalg.solve(A, b) 99 | 100 | 101 | def make_parameters(qc, qw, r, dt, s1, s2): 102 | """ Discretizes the model with continuous transition noise qc, for step-size dt. 103 | The model is described in "Multitarget-multisensor tracking: principles and techniques" by 104 | Bar-Shalom, Yaakov and Li, Xiao-Rong 105 | 106 | Parameters 107 | ---------- 108 | qc: float 109 | Transition covariance of the continuous SSM 110 | qw: float 111 | Transition covariance of the continuous SSM 112 | r: float 113 | Observation error standard deviation 114 | dt: float 115 | Discretization time step 116 | s1: array_like 117 | The location of the first sensor 118 | s2: array_like 119 | The location of the second sensor 120 | 121 | Returns 122 | ------- 123 | Q: array_like 124 | The transition covariance matrix for the discrete SSM 125 | R: array_like 126 | The observation covariance matrix 127 | observation_function: callable 128 | The observation function 129 | transition_function: callable 130 | The transition function 131 | """ 132 | 133 | Q = jnp.array([[qc * dt ** 3 / 3, 0, qc * dt ** 2 / 2, 0, 0], 134 | [0, qc * dt ** 3 / 3, 0, qc * dt ** 2 / 2, 0], 135 | [qc * dt ** 2 / 2, 0, qc * dt, 0, 0], 136 | [0, qc * dt ** 2 / 2, 0, qc * dt, 0], 137 | [0, 0, 0, 0, dt * qw]]) 138 | 139 | R = r ** 2 * jnp.eye(2) 140 | 141 | observation_function = jit(partial(_observation_function, s1=s1, s2=s2)) 142 | transition_function = jit(partial(_transition_function, dt=dt)) 143 | 144 | return Q, R, observation_function, transition_function 145 | 146 | 147 | @nb.njit 148 | def _get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states): 149 | for i, a in enumerate(a_s): 150 | with nb.objmode(x='float32[::1]'): 151 | F = np.array([[0, 0, 1, 0], 152 | [0, 0, 0, 1], 153 | [0, 0, 0, a], 154 | [0, 0, -a, 0]], dtype=np.float32) 155 | x = linalg.expm(F * dt) @ x 156 | y1 = np.arctan2(x[1] - s1[1], x[0] - s1[0]) + r * normals[i, 0] 157 | y2 = np.arctan2(x[1] - s2[1], x[0] - s2[0]) + r * normals[i, 1] 158 | 159 | observations[i] = [y1, y2] 160 | observations[i] = [y1, y2] 161 | true_states[i] = np.concatenate((x, np.array([a]))) 162 | # return true_states, observations 163 | 164 | 165 | def get_data(x0, dt, r, T, s1, s2, q=10., random_state=None): 166 | """ 167 | 168 | Parameters 169 | ---------- 170 | x0: array_like 171 | true initial state 172 | dt: float 173 | time step for observations 174 | r: float 175 | observation model standard deviation 176 | T: int 177 | number of time steps 178 | s1: array_like 179 | The location of the first sensor 180 | s2: array_like 181 | The location of the second sensor 182 | q: float 183 | noise of the angular momentum 184 | random_state: np.random.RandomState or int, optional 185 | numpy random state 186 | 187 | Returns 188 | ------- 189 | ts: array_like 190 | array of time steps 191 | true_states: array_like 192 | array of true states 193 | observations: array_like 194 | array of observations 195 | """ 196 | if random_state is None or isinstance(random_state, int): 197 | random_state = np.random.RandomState(random_state) 198 | a_s = 1 + q * dt * np.cumsum(random_state.randn(T)) 199 | a_s = a_s.astype(np.float32) 200 | s1 = np.asarray(s1, dtype=np.float32) 201 | s2 = np.asarray(s2, dtype=np.float32) 202 | 203 | x = np.copy(x0).astype(np.float32) 204 | observations = np.empty((T, 2), dtype=np.float32) 205 | true_states = np.zeros((T+1, 5), dtype=np.float32) 206 | ts = np.linspace(dt, (T + 1) * dt, T).astype(np.float32) 207 | true_states[0, :4] = x 208 | normals = random_state.randn(T, 2).astype(np.float32) 209 | 210 | _get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states[1:]) 211 | return ts, true_states, observations 212 | 213 | 214 | def plot_bearings(states, labels, s1, s2, figsize=(10, 10), quiver=False): 215 | """ 216 | 217 | Parameters 218 | ---------- 219 | states: list of array_like 220 | list of states to plot 221 | labels: list of str 222 | list of lables for the states 223 | s1: array_like 224 | first sensor 225 | s2: array_like 226 | second sensor 227 | figsize: tuple of int 228 | figure size in inches 229 | quiver: bool 230 | show the velocity field 231 | 232 | Returns 233 | ------- 234 | 235 | """ 236 | fig, ax = plt.subplots(figsize=figsize) 237 | 238 | if not isinstance(states, list): 239 | states = [states] 240 | 241 | if not isinstance(labels, list): 242 | labels = [labels] 243 | 244 | for label, state in zip(labels, states): 245 | ax.plot(*state[:, :2].T, linestyle='--', label=label, alpha=0.75) 246 | if quiver: 247 | ax.quiver(*state[::10].T, units='xy', scale=4, width=0.01) 248 | ax.scatter(*s1, marker="o", s=200, label="Sensor 1", color='k') 249 | ax.scatter(*s2, marker="x", s=200, label="Sensor 2", color='k') 250 | 251 | ax.legend(loc="lower left") 252 | -------------------------------------------------------------------------------- /parsmooth/models/linear.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from jax import jit 6 | 7 | __all__ = ["make_parameters", "get_data"] 8 | 9 | 10 | def _transition_function(x, A): 11 | """ Deterministic transition function used in the state space model 12 | 13 | Parameters 14 | ---------- 15 | x: array_like 16 | The current state 17 | A: array_like 18 | transition matrix 19 | 20 | Returns 21 | ------- 22 | out: array_like 23 | The transitioned state 24 | """ 25 | return jnp.dot(A, x) 26 | 27 | 28 | def _observation_function(x, H): 29 | """ 30 | Returns the observed angles as function of the state and the sensors locations 31 | 32 | Parameters 33 | ---------- 34 | x: array_like 35 | The current state 36 | H: array_like 37 | observation matrix 38 | 39 | Returns 40 | ------- 41 | y: array_like 42 | The observed data 43 | """ 44 | return jnp.dot(H, x) 45 | 46 | 47 | def make_parameters(r, q): 48 | A = 0.5 * jnp.eye(2) 49 | Q = q * jnp.eye(2) 50 | R = r * jnp.eye(1) 51 | H = jnp.array([[1., 0.5]]) 52 | 53 | observation_function = jit(partial(_observation_function, H=H)) 54 | transition_function = jit(partial(_transition_function, A=A)) 55 | 56 | return A, H, Q, R, observation_function, transition_function 57 | 58 | 59 | def get_data(x0, A, H, R, Q, T, random_state=None): 60 | """ 61 | 62 | Parameters 63 | ---------- 64 | x0: array_like 65 | true initial state 66 | A: array_like 67 | transition matrix 68 | H: array_like 69 | transition matrix 70 | R: array_like 71 | observation model covariance 72 | Q: array_like 73 | noise covariance 74 | s1: array_like 75 | The location of the first sensor 76 | s2: array_like 77 | The location of the second sensor 78 | random_state: np.random.RandomState or int, optional 79 | numpy random state 80 | 81 | Returns 82 | ------- 83 | ts: array_like 84 | array of time steps 85 | true_states: array_like 86 | array of true states 87 | observations: array_like 88 | array of observations 89 | """ 90 | if random_state is None or isinstance(random_state, int): 91 | random_state = np.random.RandomState(random_state) 92 | 93 | R_shape = R.shape[0] 94 | Q_shape = Q.shape[0] 95 | normals = random_state.randn(T, Q_shape + R_shape).astype(np.float32) 96 | chol_R = np.linalg.cholesky(R) 97 | chol_Q = np.linalg.cholesky(Q) 98 | 99 | x = np.copy(x0).astype(np.float32) 100 | observations = np.empty((T, R_shape), dtype=np.float32) 101 | true_states = np.empty((T, Q_shape), dtype=np.float32) 102 | 103 | for i in range(T): 104 | x = A @ x + chol_Q @ normals[i, :Q_shape] 105 | y = H @ x + chol_R @ normals[i, Q_shape:] 106 | true_states[i] = x 107 | observations[i] = y 108 | 109 | ts = np.linspace(1, T, T).astype(np.float32) 110 | 111 | return ts, true_states, observations 112 | -------------------------------------------------------------------------------- /parsmooth/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .ckf import filter_routine as ckf 2 | from .cks import smoother_routine as cks, iterated_smoother_routine as icks 3 | from .ekf import filter_routine as ekf 4 | from .eks import smoother_routine as eks, iterated_smoother_routine as ieks 5 | -------------------------------------------------------------------------------- /parsmooth/parallel/ckf.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.scipy.linalg as jlinalg 6 | from jax import lax, vmap 7 | 8 | from parsmooth.utils import MVNormalParameters 9 | from .operators import filtering_operator 10 | from ..cubature_common import get_sigma_points, get_mv_normal_parameters, covariance_sigma_points, SigmaPoints 11 | 12 | 13 | def make_associative_filtering_params(observation_function, Rk, transition_function, Qk_1, yk, i, initial_state, 14 | prev_linearization_state, linearization_state, propagate_first): 15 | predicate = i == 0 16 | 17 | def _first(_): 18 | return _make_associative_filtering_params_first(observation_function, Rk, transition_function, Qk_1, 19 | initial_state, prev_linearization_state, linearization_state, 20 | yk, propagate_first) 21 | 22 | def _generic(_): 23 | return _make_associative_filtering_params_generic(observation_function, Rk, transition_function, Qk_1, 24 | prev_linearization_state, linearization_state, yk) 25 | 26 | return lax.cond(predicate, 27 | _first, # take initial 28 | _generic, # take generic 29 | None) 30 | 31 | 32 | def _make_associative_filtering_params_first(observation_function, R, transition_function, Q, initial_state, 33 | prev_linearization_state, linearization_state, y, propagate_first): 34 | # Prediction part 35 | 36 | if propagate_first: 37 | initial_sigma_points = get_sigma_points(prev_linearization_state) 38 | propagated_points = transition_function(initial_sigma_points.points) 39 | propagated_sigma_points = SigmaPoints(propagated_points, initial_sigma_points.wm, initial_sigma_points.wc) 40 | propagated_state = get_mv_normal_parameters(propagated_sigma_points) 41 | 42 | pred_cross_covariance = covariance_sigma_points(initial_sigma_points, prev_linearization_state.mean, 43 | propagated_sigma_points, 44 | propagated_state.mean) 45 | 46 | F = jlinalg.solve(prev_linearization_state.cov, pred_cross_covariance, 47 | assume_a="pos").T # Linearized transition function 48 | 49 | m = propagated_state.mean + F @ (initial_state.mean - prev_linearization_state.mean) 50 | P = propagated_state.cov + Q + F @ (initial_state.cov - prev_linearization_state.cov) @ F.T 51 | linearization_points = get_sigma_points(linearization_state) 52 | obs_points = observation_function(linearization_points.points) 53 | obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm, linearization_points.wc) 54 | obs_mvn = get_mv_normal_parameters(obs_sigma_points) 55 | update_cross_covariance = covariance_sigma_points(linearization_points, linearization_state.mean, 56 | obs_sigma_points, obs_mvn.mean) 57 | 58 | H = jlinalg.solve(linearization_state.cov, update_cross_covariance, assume_a="pos").T 59 | d = obs_mvn.mean - jnp.dot(H, linearization_state.mean) 60 | predicted_observation = H @ m + d 61 | 62 | S = H @ (P - linearization_state.cov) @ H.T + R + obs_mvn.cov 63 | else: 64 | m = initial_state.mean 65 | P = initial_state.cov 66 | linearization_points = get_sigma_points(prev_linearization_state) 67 | obs_points = observation_function(linearization_points.points) 68 | obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm, linearization_points.wc) 69 | obs_mvn = get_mv_normal_parameters(obs_sigma_points) 70 | update_cross_covariance = covariance_sigma_points(linearization_points, linearization_state.mean, 71 | obs_sigma_points, obs_mvn.mean) 72 | 73 | H = jlinalg.solve(prev_linearization_state.cov, update_cross_covariance, assume_a="pos").T 74 | d = obs_mvn.mean - jnp.dot(H, prev_linearization_state.mean) 75 | predicted_observation = H @ m + d 76 | 77 | S = H @ (P - prev_linearization_state.cov) @ H.T + R + obs_mvn.cov 78 | 79 | K = jlinalg.solve(S, H @ P, assume_a="pos").T 80 | A = jnp.zeros_like(initial_state.cov) 81 | b = m + K @ (y - predicted_observation) 82 | C = P - K @ S @ K.T 83 | 84 | eta = jnp.zeros_like(initial_state.mean) 85 | J = jnp.zeros_like(initial_state.cov) 86 | 87 | return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T) 88 | 89 | 90 | def _make_associative_filtering_params_generic(observation_function, Rk, transition_function, Qk_1, 91 | prev_linearization_state, linearization_state, yk): 92 | # Prediction part 93 | sigma_points = get_sigma_points(prev_linearization_state) 94 | 95 | propagated_points = transition_function(sigma_points.points) 96 | propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm, sigma_points.wc) 97 | propagated_state = get_mv_normal_parameters(propagated_sigma_points) 98 | 99 | pred_cross_covariance = covariance_sigma_points(sigma_points, prev_linearization_state.mean, 100 | propagated_sigma_points, 101 | propagated_state.mean) 102 | 103 | F = jlinalg.solve(prev_linearization_state.cov, pred_cross_covariance, 104 | assume_a="pos").T # Linearized transition function 105 | pred_mean_residual = propagated_state.mean - F @ prev_linearization_state.mean 106 | pred_cov_residual = propagated_state.cov - F @ prev_linearization_state.cov @ F.T + Qk_1 107 | 108 | # Update part 109 | linearization_points = get_sigma_points(linearization_state) 110 | obs_points = observation_function(linearization_points.points) 111 | obs_sigma_points = SigmaPoints(obs_points, linearization_points.wm, linearization_points.wc) 112 | obs_mvn = get_mv_normal_parameters(obs_sigma_points) 113 | update_cross_covariance = covariance_sigma_points(linearization_points, 114 | linearization_state.mean, 115 | obs_sigma_points, 116 | obs_mvn.mean) 117 | 118 | H = jlinalg.solve(linearization_state.cov, update_cross_covariance, assume_a="pos").T 119 | obs_mean_residual = obs_mvn.mean - jnp.dot(H, linearization_state.mean) 120 | obs_cov_residual = obs_mvn.cov - H @ linearization_state.cov @ H.T 121 | 122 | S = H @ pred_cov_residual @ H.T + Rk + obs_cov_residual # total residual covariance 123 | total_obs_residual = (yk - H @ pred_mean_residual - obs_mean_residual) 124 | S_invH = jlinalg.solve(S, H, assume_a="pos") 125 | 126 | K = (S_invH @ pred_cov_residual).T 127 | A = F - K @ H @ F 128 | b = pred_mean_residual + K @ total_obs_residual 129 | C = pred_cov_residual - K @ S @ K.T 130 | 131 | temp = (S_invH @ F).T 132 | HF = H @ F 133 | 134 | eta = temp @ total_obs_residual 135 | J = temp @ HF 136 | return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T) 137 | 138 | 139 | def filter_routine(initial_state: MVNormalParameters, 140 | observations: jnp.ndarray, 141 | transition_function: Callable, 142 | transition_covariance: jnp.ndarray, 143 | observation_function: Callable, 144 | observation_covariance: jnp.ndarray, 145 | linearization_states: MVNormalParameters = None, 146 | propagate_first: bool = True): 147 | """ Computes the predict-update routine of the Cubature Kalman Filter equations 148 | using temporal parallelization and returns a series of filtered_states TODO:reference 149 | 150 | Parameters 151 | ---------- 152 | initial_state: MVNormalParameters 153 | prior belief on the initial state distribution 154 | observations: (n, K) array 155 | array of n observations of dimension K 156 | transition_function: callable 157 | transition function of the state space model 158 | transition_covariance: (D, D) array 159 | transition covariance for each time step 160 | observation_function: callable 161 | observation function of the state space model 162 | observation_covariance: (K, K) array 163 | observation error covariances for each time step 164 | linearization_states: MVNormalParameters, optional 165 | in the case of Sigma-Point . 166 | propagate_first: bool, optional 167 | Is the first step a transition or an update? i.e. False if the initial time step has 168 | an associated observation. Default is True. 169 | 170 | Returns 171 | ------- 172 | filtered_states: MVNormalParameters 173 | list of filtered states 174 | 175 | """ 176 | n_observations = observations.shape[0] 177 | x_dim = initial_state.mean.shape[0] 178 | dtype = initial_state.mean.dtype 179 | 180 | if linearization_states is not None: 181 | if propagate_first: 182 | x_k_1_s = jax.tree_map(lambda z: z[:-1], linearization_states) 183 | x_k_s = jax.tree_map(lambda z: z[1:], linearization_states) 184 | else: 185 | x_k_1_s = jax.tree_map(lambda z: jnp.concatenate([z[None, 0], z[:-1]], 0), linearization_states) 186 | x_k_s = linearization_states 187 | else: 188 | 189 | m_k_s = jnp.zeros((n_observations, x_dim), dtype=dtype) 190 | P_k_s = jnp.repeat(jnp.eye(x_dim)[None, ...], n_observations, axis=0) 191 | x_k_1_s = x_k_s = MVNormalParameters(m_k_s, P_k_s) 192 | 193 | @vmap 194 | def make_params(obs, i, prev_linearization_state, linearisation_state): 195 | return make_associative_filtering_params(observation_function, observation_covariance, transition_function, 196 | transition_covariance, obs, i, initial_state, 197 | prev_linearization_state, linearisation_state, propagate_first) 198 | 199 | As, bs, Cs, etas, Js = make_params(observations, jnp.arange(n_observations), x_k_1_s, 200 | x_k_s) 201 | _, filtered_means, filtered_covariances, _, _ = lax.associative_scan(filtering_operator, (As, bs, Cs, etas, Js)) 202 | 203 | filtered_states = MVNormalParameters(filtered_means, filtered_covariances) 204 | if propagate_first: 205 | filtered_states = jax.tree_map(lambda x, y: jnp.concatenate([x[None, ...], y], 0), 206 | initial_state, filtered_states) 207 | return filtered_states 208 | -------------------------------------------------------------------------------- /parsmooth/parallel/cks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax.numpy as jnp 4 | import jax.scipy.linalg as jlinalg 5 | from jax import lax, vmap 6 | 7 | from parsmooth.parallel.ckf import filter_routine 8 | from parsmooth.utils import MVNormalParameters 9 | from .operators import smoothing_operator 10 | from ..cubature_common import get_sigma_points, SigmaPoints, get_mv_normal_parameters, covariance_sigma_points 11 | 12 | 13 | def make_associative_smoothing_params(transition_function, Qk, i, n, filtered_state, linearization_state): 14 | predicate = i == n - 1 15 | 16 | def _last(_): 17 | return filtered_state.mean, jnp.zeros_like(filtered_state.cov), filtered_state.cov 18 | 19 | def _generic(_): 20 | return _make_associative_smoothing_params_generic(transition_function, Qk, filtered_state, linearization_state) 21 | 22 | return lax.cond(predicate, 23 | _last, # take initial 24 | _generic, # take generic 25 | None) 26 | 27 | 28 | def _make_associative_smoothing_params_generic(transition_function, Qk, filtered_state, linearization_state): 29 | # Prediction part 30 | sigma_points = get_sigma_points(linearization_state) 31 | 32 | propagated_points = transition_function(sigma_points.points) 33 | propagated_sigma_points = SigmaPoints(propagated_points, sigma_points.wm, sigma_points.wc) 34 | propagated_state = get_mv_normal_parameters(propagated_sigma_points) 35 | 36 | pred_cross_covariance = covariance_sigma_points(sigma_points, 37 | linearization_state.mean, 38 | propagated_sigma_points, 39 | propagated_state.mean) 40 | 41 | F = jlinalg.solve(linearization_state.cov, pred_cross_covariance, 42 | assume_a="pos").T # Linearized transition function 43 | 44 | Pp = Qk + propagated_state.cov + F @ (filtered_state.cov - linearization_state.cov) @ F.T 45 | 46 | E = jlinalg.solve(Pp, F @ filtered_state.cov, assume_a="pos").T 47 | 48 | g = filtered_state.mean - E @ (propagated_state.mean + F @ (filtered_state.mean - linearization_state.mean)) 49 | L = filtered_state.cov - E @ F @ filtered_state.cov 50 | 51 | return g, E, 0.5 * (L + L.T) 52 | 53 | 54 | def smoother_routine(transition_function: Callable, 55 | transition_covariance: jnp.ndarray, 56 | filtered_states: MVNormalParameters, 57 | linearization_states: MVNormalParameters = None): 58 | """ Computes the predict-update routine of the Extended Kalman Filter equations 59 | using temporal parallelization and returns a series of filtered_states TODO:reference 60 | 61 | Parameters 62 | ---------- 63 | transition_function: callable 64 | transition function of the state space model 65 | transition_covariance: (D, D) array 66 | transition covariance for each time step 67 | observation error covariances for each time step 68 | filtered_states: MVNormalParameters 69 | states resulting from (iterated) EKF 70 | linearization_states: MVNormalParameters, optional 71 | states at which to compute the cubature linearized functions 72 | 73 | Returns 74 | ------- 75 | filtered_states: MVNormalParameters 76 | list of filtered states 77 | 78 | """ 79 | n_observations = filtered_states.mean.shape[0] 80 | 81 | @vmap 82 | def make_params(i, filtered_state, linearization_state): 83 | if linearization_state is None: 84 | linearization_state = filtered_state 85 | return make_associative_smoothing_params(transition_function, transition_covariance, 86 | i, n_observations, filtered_state, linearization_state) 87 | 88 | gs, Es, Ls = make_params(jnp.arange(n_observations), filtered_states, linearization_states) 89 | 90 | smoothed_means, _, smoothed_covariances = lax.associative_scan(smoothing_operator, (gs, Es, Ls), reverse=True) 91 | 92 | return vmap(MVNormalParameters)(smoothed_means, smoothed_covariances) 93 | 94 | 95 | def iterated_smoother_routine(initial_state: MVNormalParameters, 96 | observations: jnp.ndarray, 97 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 98 | transition_covariance: jnp.ndarray, 99 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 100 | observation_covariance: jnp.ndarray, 101 | initial_linearization_states: MVNormalParameters = None, 102 | n_iter: int = 100, 103 | propagate_first: bool = True): 104 | """ 105 | Computes the Gauss-Newton iterated cubature Kalman smoother 106 | 107 | Parameters 108 | ---------- 109 | initial_state: MVNormalParameters 110 | prior belief on the initial state distribution 111 | observations: (n, K) array 112 | array of n observations of dimension K 113 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 114 | transition function of the state space model 115 | transition_covariance: (D, D) array 116 | transition covariances for each time step, if passed only one, it is repeated n times 117 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 118 | observation function of the state space model 119 | observation_covariance: (K, K) array 120 | observation error covariances for each time step, if passed only one, it is repeated n times 121 | initial_linearization_states: MVNormalParameters, optional 122 | points at which to compute the jacobians durning the first pass. 123 | n_iter: int 124 | number of times the filter-smoother routine is computed 125 | propagate_first: bool, optional 126 | Is the first step a transition or an update? i.e. False if the initial time step has 127 | an associated observation. Default is True. 128 | Returns 129 | ------- 130 | iterated_smoothed_trajectories: MVNormalParameters 131 | The result of the smoothing routine 132 | 133 | """ 134 | 135 | def body(linearization_points, _): 136 | filtered_states = filter_routine(initial_state, observations, transition_function, transition_covariance, 137 | observation_function, observation_covariance, linearization_points, 138 | propagate_first) 139 | return smoother_routine(transition_function, transition_covariance, filtered_states, 140 | linearization_points), None 141 | 142 | if initial_linearization_states is None: 143 | initial_linearization_states = body(None, None)[0] 144 | 145 | iterated_smoothed_trajectories, _ = lax.scan(body, initial_linearization_states, jnp.arange(n_iter)) 146 | return iterated_smoothed_trajectories 147 | -------------------------------------------------------------------------------- /parsmooth/parallel/ekf.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.scipy.linalg as jlinalg 6 | from jax import lax, vmap, jacfwd 7 | 8 | from parsmooth.utils import MVNormalParameters 9 | from .operators import filtering_operator 10 | 11 | 12 | def make_associative_filtering_params(observation_function, Rk, transition_function, Qk_1, yk, i, m0, P0, x_k_1, x_k, 13 | propagate_first): 14 | predicate = i == 0 15 | 16 | jac_obs = jacfwd(observation_function, 0) 17 | jac_trans = jacfwd(transition_function, 0) 18 | 19 | def _first(_): 20 | return _make_associative_filtering_params_first(observation_function, jac_obs, Rk, transition_function, 21 | jac_trans, Qk_1, m0, P0, x_k_1, x_k, yk, propagate_first) 22 | 23 | def _generic(_): 24 | return _make_associative_filtering_params_generic(observation_function, jac_obs, Rk, transition_function, 25 | jac_trans, x_k_1, x_k, Qk_1, yk) 26 | 27 | return lax.cond(predicate, 28 | _first, # take initial 29 | _generic, # take generic 30 | None) 31 | 32 | 33 | def _make_associative_filtering_params_first(observation_function, jac_observation_function, R, transition_function, 34 | jac_transition_function, Q, m0, P0, x_k_1, x_k, y, propagate_first): 35 | if propagate_first: 36 | F = jac_transition_function(x_k_1) 37 | m = F @ (m0 - x_k_1) + transition_function(x_k_1) 38 | P = F @ P0 @ F.T + Q 39 | H = jac_observation_function(x_k) 40 | alpha = observation_function(x_k) + H @ (m - x_k) 41 | else: 42 | P = P0 43 | m = m0 44 | H = jac_observation_function(x_k_1) 45 | alpha = observation_function(x_k_1) + H @ (m0 - x_k_1) 46 | 47 | S = H @ P @ H.T + R 48 | K = jlinalg.solve(S, H @ P, assume_a="pos").T 49 | A = jnp.zeros_like(P0) 50 | 51 | b = m + K @ (y - alpha) 52 | C = P - (K @ S @ K.T) 53 | 54 | eta = jnp.zeros_like(m0) 55 | J = jnp.zeros_like(P0) 56 | 57 | return A, b, C, eta, J 58 | 59 | 60 | def _make_associative_filtering_params_generic(observation_function, jac_observation_function, Rk, transition_function, 61 | jac_transition_function, x_k_1, x_k, Qk_1, yk): 62 | F = jac_transition_function(x_k_1) 63 | H = jac_observation_function(x_k) 64 | 65 | F_x_k_1 = F @ x_k_1 66 | x_k_hat = transition_function(x_k_1) 67 | 68 | alpha = observation_function(x_k) + H @ (x_k_hat - F_x_k_1 - x_k) 69 | residual = yk - alpha 70 | HQ = H @ Qk_1 71 | 72 | S = HQ @ H.T + Rk 73 | S_invH = jlinalg.solve(S, H, assume_a="pos") 74 | K = (S_invH @ Qk_1).T 75 | A = F - K @ H @ F 76 | b = K @ residual + x_k_hat - F_x_k_1 77 | C = Qk_1 - K @ H @ Qk_1 78 | 79 | HF = H @ F 80 | 81 | temp = (S_invH @ F).T 82 | eta = temp @ residual 83 | J = temp @ HF 84 | 85 | return A, b, C, eta, J 86 | 87 | 88 | def filter_routine(initial_state: MVNormalParameters, 89 | observations: jnp.ndarray, 90 | transition_function: Callable, 91 | transition_covariance: jnp.ndarray, 92 | observation_function: Callable, 93 | observation_covariance: jnp.ndarray, 94 | linearization_points: jnp.ndarray = None, 95 | propagate_first: bool = True 96 | ): 97 | """ Computes the predict-update routine of the Extended Kalman Filter equations 98 | using temporal parallelization and returns a series of filtered_states TODO:reference 99 | 100 | Parameters 101 | ---------- 102 | initial_state: MVNormalParameters 103 | prior belief on the initial state distribution 104 | observations: (n, K) array 105 | array of n observations of dimension K 106 | transition_function: callable 107 | transition function of the state space model 108 | transition_covariance: (D, D) array 109 | transition covariance for each time step 110 | observation_function: callable 111 | observation function of the state space model 112 | observation_covariance: (K, K) array 113 | observation error covariances for each time step 114 | linearization_points: (n, D) array, optional 115 | points at which to compute the jacobians. 116 | propagate_first: bool, optional 117 | Is the first step a transition or an update? i.e. False if the initial time step has 118 | an associated observation. Default is True. 119 | Returns 120 | ------- 121 | filtered_states: MVNormalParameters 122 | list of filtered states 123 | 124 | """ 125 | n_observations = observations.shape[0] 126 | x_dim = initial_state.mean.shape[0] 127 | dtype = initial_state.mean.dtype 128 | 129 | @vmap 130 | def make_params(obs, i, x_k_1, x_k): 131 | return make_associative_filtering_params(observation_function, observation_covariance, 132 | transition_function, transition_covariance, obs, 133 | i, initial_state.mean, 134 | initial_state.cov, x_k_1, x_k, propagate_first) 135 | 136 | if linearization_points is not None: 137 | if propagate_first: 138 | x_k_1_s = linearization_points[:-1] 139 | x_k_s = linearization_points[1:] 140 | else: 141 | x_k_1_s = jnp.concatenate([linearization_points[None, 0], linearization_points[:-1]]) 142 | x_k_s = linearization_points 143 | else: 144 | x_k_1_s = x_k_s = jnp.zeros((n_observations, x_dim), dtype=dtype) 145 | 146 | As, bs, Cs, etas, Js = make_params(observations, jnp.arange(n_observations), x_k_1_s, x_k_s) 147 | _, filtered_means, filtered_covariances, _, _ = lax.associative_scan(filtering_operator, (As, bs, Cs, etas, Js)) 148 | filtered_states = MVNormalParameters(filtered_means, filtered_covariances) 149 | if propagate_first: 150 | filtered_states = jax.tree_map(lambda x, y: jnp.concatenate([x[None, ...], y], 0), 151 | initial_state, filtered_states) 152 | return filtered_states 153 | -------------------------------------------------------------------------------- /parsmooth/parallel/eks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax.numpy as jnp 4 | import jax.scipy.linalg as jlinalg 5 | from jax import lax, vmap, jacfwd 6 | 7 | from parsmooth.utils import MVNormalParameters 8 | from .ekf import filter_routine 9 | from .operators import smoothing_operator 10 | 11 | 12 | def make_associative_smoothing_params(transition_function, Qk, i, n, mk, Pk, xk): 13 | predicate = i == n - 1 14 | 15 | jac_trans = jacfwd(transition_function, 0) 16 | 17 | def _last(_): 18 | return mk, jnp.zeros_like(Pk), Pk 19 | 20 | def _generic(_): 21 | return _make_associative_smoothing_params_generic(transition_function, jac_trans, Qk, mk, Pk, xk) 22 | 23 | return lax.cond(predicate, 24 | _last, # take initial 25 | _generic, # take generic 26 | None) 27 | 28 | 29 | def _make_associative_smoothing_params_generic(transition_function, jac_transition_function, Qk, mk, Pk, xk): 30 | F = jac_transition_function(xk) 31 | Pp = F @ Pk @ F.T + Qk 32 | 33 | E = jlinalg.solve(Pp, F @ Pk, assume_a="pos").T 34 | 35 | g = mk - E @ (transition_function(xk) + F @ (mk - xk)) 36 | L = Pk - E @ Pp @ E.T 37 | 38 | return g, E, L 39 | 40 | 41 | def smoother_routine(transition_function: Callable, 42 | transition_covariance: jnp.ndarray, 43 | filtered_states: MVNormalParameters, 44 | linearisation_points: jnp.ndarray = None): 45 | """ Computes the predict-update routine of the Extended Kalman Filter equations 46 | using temporal parallelization and returns a series of filtered_states TODO:reference 47 | 48 | Parameters 49 | ---------- 50 | transition_function: callable 51 | transition function of the state space model 52 | transition_covariance: (D, D) array 53 | transition covariance for each time step 54 | observation error covariances for each time step 55 | filtered_states: MVNormalParameters 56 | states resulting from (iterated) EKF 57 | linearisation_points: (n, D) array, optional 58 | points at which to compute the jacobians, typically previous run. 59 | 60 | Returns 61 | ------- 62 | filtered_states: MVNormalParameters 63 | list of filtered states 64 | 65 | """ 66 | n_observations = filtered_states.mean.shape[0] 67 | 68 | if linearisation_points is None: 69 | linearisation_points = filtered_states.mean 70 | 71 | @vmap 72 | def make_params(i, mk, Pk, xk): 73 | return make_associative_smoothing_params(transition_function, transition_covariance, 74 | i, n_observations, mk, Pk, xk) 75 | 76 | gs, Es, Ls = make_params(jnp.arange(n_observations), filtered_states.mean, 77 | filtered_states.cov, linearisation_points) 78 | 79 | smoothed_means, _, smoothed_covariances = lax.associative_scan(smoothing_operator, (gs, Es, Ls), reverse=True) 80 | return MVNormalParameters(smoothed_means, smoothed_covariances) 81 | 82 | 83 | def iterated_smoother_routine(initial_state: MVNormalParameters, 84 | observations: jnp.ndarray, 85 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 86 | transition_covariance: jnp.ndarray, 87 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 88 | observation_covariance: jnp.ndarray, 89 | initial_linearization_points: jnp.ndarray = None, 90 | n_iter: int = 100, 91 | propagate_first: bool = True): 92 | """ 93 | Computes the Gauss-Newton iterated extended Kalman smoother 94 | 95 | Parameters 96 | ---------- 97 | initial_state: MVNormalParameters 98 | prior belief on the initial state distribution 99 | observations: (n, K) array 100 | array of n observations of dimension K 101 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 102 | transition function of the state space model 103 | transition_covariance: (D, D) array 104 | transition covariances for each time step, if passed only one, it is repeated n times 105 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 106 | observation function of the state space model 107 | observation_covariance: (K, K) array 108 | observation error covariances for each time step, if passed only one, it is repeated n times 109 | initial_linearization_points: (N, D) array, optional 110 | points at which to compute the jacobians durning the first pass. 111 | n_iter: int 112 | number of times the filter-smoother routine is computed 113 | propagate_first: bool, optional 114 | Is the first step a transition or an update? i.e. False if the initial time step has 115 | an associated observation. Default is True. 116 | Returns 117 | ------- 118 | iterated_smoothed_trajectories: MVNormalParameters 119 | The result of the smoothing routine 120 | 121 | """ 122 | 123 | def body(linearization_points, _): 124 | if linearization_points is not None: 125 | linearization_points = linearization_points.mean if isinstance(linearization_points, 126 | MVNormalParameters) else linearization_points 127 | filtered_states = filter_routine(initial_state, observations, transition_function, transition_covariance, 128 | observation_function, observation_covariance, linearization_points, 129 | propagate_first) 130 | return smoother_routine(transition_function, transition_covariance, filtered_states, 131 | linearization_points), None 132 | 133 | if initial_linearization_points is None: 134 | initial_linearization_points = body(None, None)[0] 135 | 136 | iterated_smoothed_trajectories, _ = lax.scan(body, initial_linearization_points, jnp.arange(n_iter)) 137 | return iterated_smoothed_trajectories 138 | -------------------------------------------------------------------------------- /parsmooth/parallel/operators.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.scipy.linalg as jlinalg 3 | from jax import vmap 4 | 5 | 6 | @vmap 7 | def filtering_operator(elem1, elem2): 8 | """ 9 | Associative operator described in TODO: put the reference 10 | 11 | Parameters 12 | ---------- 13 | elem1: tuple of array 14 | a_i, b_i, C_i, eta_i, J_i 15 | elem2: tuple of array 16 | a_j, b_j, C_j, eta_j, J_j 17 | 18 | Returns 19 | ------- 20 | 21 | """ 22 | A1, b1, C1, eta1, J1 = elem1 23 | A2, b2, C2, eta2, J2 = elem2 24 | dim = b1.shape[0] 25 | 26 | I_dim = jnp.eye(dim) 27 | 28 | IpCJ = I_dim + jnp.dot(C1, J2) 29 | IpJC = I_dim + jnp.dot(J2, C1) 30 | 31 | AIpCJ_inv = jlinalg.solve(IpCJ.T, A2.T, assume_a="gen").T 32 | AIpJC_inv = jlinalg.solve(IpJC.T, A1, assume_a="gen").T 33 | 34 | A = jnp.dot(AIpCJ_inv, A1) 35 | b = jnp.dot(AIpCJ_inv, b1 + jnp.dot(C1, eta2)) + b2 36 | C = jnp.dot(AIpCJ_inv, jnp.dot(C1, A2.T)) + C2 37 | eta = jnp.dot(AIpJC_inv, eta2 - jnp.dot(J2, b1)) + eta1 38 | J = jnp.dot(AIpJC_inv, jnp.dot(J2, A1)) + J1 39 | return A, b, 0.5 * (C + C.T), eta, 0.5 * (J + J.T) 40 | 41 | 42 | @vmap 43 | def smoothing_operator(elem1, elem2): 44 | """ 45 | Associative operator described in TODO: put the reference 46 | 47 | Parameters 48 | ---------- 49 | elem1: tuple of array 50 | g_i, E_i, L_i 51 | elem2: tuple of array 52 | g_j, E_j, L_j 53 | 54 | Returns 55 | ------- 56 | 57 | """ 58 | g1, E1, L1 = elem1 59 | g2, E2, L2 = elem2 60 | 61 | g = E2 @ g1 + g2 62 | E = E2 @ E1 63 | L = E2 @ L1 @ E2.T + L2 64 | return g, E, 0.5 * (L + L.T) 65 | -------------------------------------------------------------------------------- /parsmooth/sequential/__init__.py: -------------------------------------------------------------------------------- 1 | from .cubature import filter_routine as ckf, smoother_routine as cks, iterated_smoother_routine as icks 2 | from .extended import filter_routine as ekf, smoother_routine as eks, iterated_smoother_routine as ieks 3 | -------------------------------------------------------------------------------- /parsmooth/sequential/cubature.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.scipy.linalg as jlinalg 6 | from jax import lax 7 | from jax.scipy.stats import multivariate_normal 8 | 9 | from ..cubature_common import SigmaPoints, get_sigma_points, get_mv_normal_parameters, covariance_sigma_points 10 | from ..utils import MVNormalParameters, make_matrices_parameters 11 | 12 | 13 | def predict(transition_function: Callable, 14 | transition_covariance: jnp.ndarray, 15 | previous_state: MVNormalParameters, 16 | linearization_state: MVNormalParameters, 17 | return_linearized_transition: bool = False 18 | ) -> MVNormalParameters: 19 | """ Computes the cubature Kalman filter linearization of :math:`x_{t+1} = f(x_t, \mathcal{N}(0, \Sigma))` 20 | 21 | Parameters 22 | ---------- 23 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 24 | transition function of the state space model 25 | transition_covariance: (D,D) array 26 | covariance :math:`\Sigma` of the noise fed to transition_function 27 | previous_state: MVNormalParameters 28 | previous state for the filter x 29 | linearization_state: MVNormalParameters 30 | state for the linearization of the prediction 31 | return_linearized_transition: bool, optional 32 | Returns the linearized transition matrix A 33 | 34 | Returns 35 | ------- 36 | mvn_parameters: MVNormalParameters 37 | Propagated approximate Normal distribution 38 | 39 | F: array_like 40 | returned if return_linearized_transition is True 41 | """ 42 | if linearization_state is None: 43 | linearization_state = previous_state 44 | 45 | sigma_points = get_sigma_points(linearization_state) 46 | propagated_points = transition_function(sigma_points.points) 47 | propagated_sigma_points = SigmaPoints(propagated_points, 48 | sigma_points.wm, 49 | sigma_points.wc) 50 | 51 | propagated_state = get_mv_normal_parameters(propagated_sigma_points) 52 | cross_covariance = covariance_sigma_points(sigma_points, linearization_state.mean, propagated_sigma_points, 53 | propagated_state.mean) 54 | 55 | F = jlinalg.solve(linearization_state.cov, cross_covariance, assume_a="pos").T # Linearized transition function 56 | b = propagated_state.mean - jnp.dot(F, linearization_state.mean) # Linearized offset 57 | 58 | mean = F @ previous_state.mean + b 59 | cov = transition_covariance + propagated_state.cov + F @ (previous_state.cov - linearization_state.cov) @ F.T 60 | if return_linearized_transition: 61 | return MVNormalParameters(mean, cov), F 62 | return MVNormalParameters(mean, 0.5 * (cov + cov.T)) 63 | 64 | 65 | def update(observation_function: Callable, 66 | observation_covariance: jnp.ndarray, 67 | predicted_state: MVNormalParameters, 68 | observation: jnp.ndarray, 69 | linearization_state: MVNormalParameters) -> MVNormalParameters: 70 | """ Computes the extended kalman filter linearization of :math:`x_t \mid y_t` 71 | 72 | Parameters 73 | ---------- 74 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 75 | observation function of the state space model 76 | observation_covariance: (K,K) array 77 | observation_error :math:`\Sigma` fed to observation_function 78 | predicted_state: MVNormalParameters 79 | predicted approximate mv normal parameters of the filter :math:`x` 80 | observation: (K) array 81 | Observation :math:`y` 82 | linearization_state: MVNormalParameters 83 | state for the linearization of the update 84 | 85 | Returns 86 | ------- 87 | updated_mvn_parameters: MVNormalParameters 88 | filtered state 89 | """ 90 | if linearization_state is None: 91 | linearization_state = predicted_state 92 | sigma_points = get_sigma_points(linearization_state) 93 | obs_points = observation_function(sigma_points.points) 94 | obs_sigma_points = SigmaPoints(obs_points, sigma_points.wm, sigma_points.wc) 95 | 96 | obs_state = get_mv_normal_parameters(obs_sigma_points) 97 | cross_covariance = covariance_sigma_points(sigma_points, linearization_state.mean, obs_sigma_points, 98 | obs_state.mean) 99 | 100 | H = jlinalg.solve(linearization_state.cov, cross_covariance, assume_a="pos").T # linearized observation function 101 | 102 | d = obs_state.mean - jnp.dot(H, linearization_state.mean) # linearized observation offset 103 | 104 | residual_cov = H @ (predicted_state.cov - linearization_state.cov) @ H.T + \ 105 | observation_covariance + obs_state.cov 106 | 107 | gain = jlinalg.solve(residual_cov, H @ predicted_state.cov).T 108 | 109 | predicted_observation = H @ predicted_state.mean + d 110 | 111 | residual = observation - predicted_observation 112 | mean = predicted_state.mean + gain @ residual 113 | cov = predicted_state.cov - gain @ residual_cov @ gain.T 114 | loglikelihood = multivariate_normal.logpdf(residual, jnp.zeros_like(residual), residual_cov) 115 | 116 | return loglikelihood, MVNormalParameters(mean, 0.5 * (cov + cov.T)) 117 | 118 | 119 | def filter_routine(initial_state: MVNormalParameters, 120 | observations: jnp.ndarray, 121 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 122 | transition_covariances: jnp.ndarray, 123 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 124 | observation_covariances: jnp.ndarray, 125 | linearization_states: MVNormalParameters = None, 126 | propagate_first: bool = True) -> Tuple[float, MVNormalParameters]: 127 | """ Computes the predict-update routine of the cubature Kalman Filter equations and returns a series of filtered_states 128 | 129 | Parameters 130 | ---------- 131 | initial_state: MVNormalParameters 132 | prior belief on the initial state distribution 133 | observations: (n, K) array 134 | array of n observations of dimension K 135 | transition_function: callable :math:`f(x_t, \epsilon_t) \mapsto x_{t-1}` 136 | transition function of the state space model 137 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 138 | transition covariances for each time step, if passed only one, it is repeated n times 139 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 140 | observation function of the state space model 141 | observation_covariances: (K, K) or (1, K, K) or (n, K, K) array 142 | observation error covariances for each time step, if passed only one, it is repeated n times 143 | linearization_states: MVNormalParameters, optional 144 | states for the cubature linearization 145 | propagate_first: bool, optional 146 | Is the first step a transition or an update? i.e. False if the initial time step has 147 | an associated observation. Default is True. 148 | 149 | Returns 150 | ------- 151 | loglikelihood: float 152 | Marginal loglikelihood of the observations given the parameters 153 | filtered_states: MVNormalParameters 154 | list of filtered states 155 | """ 156 | 157 | n_observations = observations.shape[0] 158 | 159 | transition_covariances, observation_covariances = list(map( 160 | lambda z: make_matrices_parameters(z, n_observations), 161 | [transition_covariances, 162 | observation_covariances])) 163 | 164 | def prop_first_body(carry, inputs): 165 | running_ell, state, prev_linearization_state = carry 166 | observation, transition_covariance, observation_covariance, linearization_state = inputs 167 | predicted_state = predict(transition_function, transition_covariance, state, prev_linearization_state) 168 | loglikelihood, updated_state = update(observation_function, observation_covariance, predicted_state, 169 | observation, linearization_state) 170 | 171 | return (running_ell + loglikelihood, updated_state, linearization_state), updated_state 172 | 173 | def update_first_body(carry, inputs): 174 | running_ell, state, _ = carry 175 | observation, transition_covariance, observation_covariance, linearization_point = inputs 176 | loglikelihood, updated_state = update(observation_function, observation_covariance, state, 177 | observation, linearization_point) 178 | predicted_state = predict(transition_function, transition_covariance, updated_state, linearization_point) 179 | return (running_ell + loglikelihood, predicted_state, linearization_point), updated_state 180 | 181 | body = prop_first_body if propagate_first else update_first_body 182 | 183 | initial_linearization_state = jax.tree_map(lambda z: z[0], linearization_states) 184 | if propagate_first: 185 | linearization_states = jax.tree_map(lambda z: z[1:], linearization_states) 186 | 187 | (ell, *_), filtered_states = lax.scan(body, 188 | (0., initial_state, initial_linearization_state), 189 | [observations, 190 | transition_covariances, 191 | observation_covariances, 192 | linearization_states], 193 | length=n_observations) 194 | 195 | if propagate_first: 196 | filtered_states = jax.tree_map(lambda y, z: jnp.concatenate([y[None, ...], z], 0), initial_state, 197 | filtered_states) 198 | return ell, filtered_states 199 | 200 | 201 | def smooth(transition_function: Callable[[jnp.ndarray], jnp.ndarray], 202 | transition_covariance: jnp.array, 203 | filtered_state: MVNormalParameters, 204 | previous_smoothed: MVNormalParameters, 205 | linearization_state: MVNormalParameters) -> MVNormalParameters: 206 | """ 207 | One step cubature kalman smoother 208 | 209 | Parameters 210 | ---------- 211 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 212 | transition function of the state space model 213 | transition_covariance: (D,D) array 214 | covariance :math:`\Sigma` of the noise fed to transition_function 215 | filtered_state: MVNormalParameters 216 | mean and cov computed by Kalman Filtering 217 | previous_smoothed: MVNormalParameters 218 | smoothed state of the previous step 219 | linearization_state: MVNormalParameters 220 | state for the cubature linearization 221 | 222 | Returns 223 | ------- 224 | smoothed_state: MVNormalParameters 225 | smoothed state 226 | """ 227 | predicted_state, F = predict(transition_function, transition_covariance, filtered_state, linearization_state, True) 228 | smoothing_gain = jnp.linalg.solve(predicted_state.cov, F @ filtered_state.cov).T 229 | mean = filtered_state.mean + smoothing_gain @ (previous_smoothed.mean - predicted_state.mean) 230 | cov = filtered_state.cov + smoothing_gain @ (previous_smoothed.cov - predicted_state.cov) @ smoothing_gain.T 231 | return MVNormalParameters(mean, 0.5 * (cov + cov.T)) 232 | 233 | 234 | def smoother_routine(transition_function: Callable[[jnp.ndarray], jnp.ndarray], 235 | transition_covariances: jnp.ndarray, 236 | filtered_states: MVNormalParameters, 237 | linearization_states: MVNormalParameters = None, 238 | propagate_first: bool = True 239 | ) -> MVNormalParameters: 240 | """ Computes the cubature Rauch-Tung-Striebel smoother routine and returns a series of smoothed_states 241 | 242 | Parameters 243 | ---------- 244 | filtered_states: MVNormalParameters 245 | Filtered states obtained from Kalman Filter 246 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 247 | transition function of the state space model 248 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 249 | transition covariances for each time step, if passed only one, it is repeated n times 250 | linearization_states: MVNormalParameters, optional 251 | states for the cubature linearization 252 | propagate_first: bool, optional 253 | Is the first step a transition or an update? i.e. False if the initial time step has 254 | an associated observation. Default is True. 255 | 256 | Returns 257 | ------- 258 | smoothed_states: MVNormalParameters 259 | list of smoothed states 260 | """ 261 | n_observations = filtered_states.mean.shape[0] 262 | if propagate_first: 263 | transition_covariances = make_matrices_parameters(transition_covariances, n_observations - 1) 264 | else: 265 | transition_covariances = make_matrices_parameters(transition_covariances, n_observations) 266 | transition_covariances = transition_covariances[:-1] 267 | 268 | def body(state, inputs): 269 | filtered, transition_covariance, linearization_state = inputs 270 | if linearization_state is None: 271 | linearization_state = filtered 272 | smoothed_state = smooth(transition_function, transition_covariance, filtered, state, linearization_state) 273 | return smoothed_state, smoothed_state 274 | 275 | last_state = MVNormalParameters(filtered_states.mean[-1], filtered_states.cov[-1]) 276 | filtered_states, linearization_states = jax.tree_map(lambda x: x[:-1], 277 | [filtered_states, linearization_states]) 278 | _, smoothed_states = lax.scan(body, 279 | last_state, 280 | [filtered_states, transition_covariances, linearization_states], 281 | reverse=True) 282 | 283 | smoothed_states = jax.tree_map(lambda y, z: jnp.concatenate([y, z[None, ...]], 0), smoothed_states, 284 | last_state) 285 | 286 | return smoothed_states 287 | 288 | 289 | def iterated_smoother_routine(initial_state: MVNormalParameters, 290 | observations: jnp.ndarray, 291 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 292 | transition_covariances: jnp.ndarray, 293 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 294 | observation_covariances: jnp.ndarray, 295 | initial_linearization_states: jnp.ndarray = None, 296 | n_iter: int = 100, 297 | propagate_first: bool = False) -> MVNormalParameters: 298 | """ 299 | Computes the Gauss-Newton iterated extended Kalman smoother 300 | 301 | Parameters 302 | ---------- 303 | initial_state: MVNormalParameters 304 | prior belief on the initial state distribution 305 | observations: (n, K) array 306 | array of n observations of dimension K 307 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 308 | transition function of the state space model 309 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 310 | transition covariances for each time step, if passed only one, it is repeated n times 311 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 312 | observation function of the state space model 313 | observation_covariances: (K, K) or (1, K, K) or (n, K, K) array 314 | observation error covariances for each time step, if passed only one, it is repeated n times 315 | initial_linearization_states: MVNormalParameters, optional 316 | states for linearization of the first pass. 317 | n_iter: int 318 | number of times the filter-smoother routine is computed 319 | propagate_first: bool, optional 320 | Is the first step a transition or an update? i.e. False if the initial time step has 321 | an associated observation. Default is True. 322 | 323 | Returns 324 | ------- 325 | iterated_smoothed_trajectories: MVNormalParameters 326 | The result of the smoothing routine 327 | 328 | """ 329 | n_observations = observations.shape[0] 330 | transition_covariances, observation_covariances = list(map( 331 | lambda z: make_matrices_parameters(z, n_observations), 332 | [transition_covariances, 333 | observation_covariances])) 334 | 335 | def body(linearization_points, _): 336 | _, filtered_states = filter_routine(initial_state, observations, transition_function, transition_covariances, 337 | observation_function, observation_covariances, linearization_points, 338 | propagate_first) 339 | return smoother_routine(transition_function, transition_covariances, filtered_states, 340 | linearization_points, propagate_first), None 341 | 342 | if initial_linearization_states is None: 343 | initial_linearization_states = body(None, None)[0] 344 | 345 | iterated_smoothed_trajectories, _ = lax.scan(body, initial_linearization_states, jnp.arange(n_iter)) 346 | return iterated_smoothed_trajectories 347 | -------------------------------------------------------------------------------- /parsmooth/sequential/extended.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.scipy.linalg as jlag 6 | from jax import lax, jacfwd 7 | from jax.scipy.stats import multivariate_normal 8 | 9 | from ..utils import MVNormalParameters, make_matrices_parameters 10 | 11 | __all__ = ["filter_routine", "smoother_routine"] 12 | 13 | 14 | def predict(transition_function: Callable[[jnp.ndarray], jnp.ndarray], 15 | transition_covariance: jnp.ndarray, 16 | prior: MVNormalParameters, 17 | linearization_point: jnp.ndarray) -> MVNormalParameters: 18 | """ Computes the extended kalman filter linearization of :math:`x_{t+1} = f(x_t, \mathcal{N}(0, \Sigma))` 19 | 20 | Parameters 21 | ---------- 22 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 23 | transition function of the state space model 24 | transition_covariance: (D,D) array 25 | covariance :math:`\Sigma` of the noise fed to transition_function 26 | prior: MVNormalParameters 27 | prior state of the filter x 28 | linearization_point: jnp.ndarray 29 | Where to compute the Jacobian 30 | 31 | Returns 32 | ------- 33 | out: MVNormalParameters 34 | Predicted state 35 | """ 36 | if linearization_point is None: 37 | linearization_point = prior.mean 38 | jac_x = jacfwd(transition_function, 0)(linearization_point) 39 | cov = jnp.dot(jac_x, jnp.dot(prior.cov, jac_x.T)) + transition_covariance 40 | mean = transition_function(linearization_point) 41 | mean = mean + jnp.dot(jac_x, prior.mean - linearization_point) 42 | return MVNormalParameters(mean, 0.5 * (cov + cov.T)) 43 | 44 | 45 | def update(observation_function: Callable[[jnp.ndarray], jnp.ndarray], 46 | observation_covariance: jnp.ndarray, 47 | predicted: MVNormalParameters, 48 | observation: jnp.ndarray, 49 | linearization_point: jnp.ndarray) -> Tuple[float, MVNormalParameters]: 50 | """ Computes the extended kalman filter linearization of :math:`x_t \mid y_t` 51 | 52 | Parameters 53 | ---------- 54 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 55 | observation function of the state space model 56 | observation_covariance: (K,K) array 57 | observation_error :math:`\Sigma` fed to observation_function 58 | predicted: MVNormalParameters 59 | predicted state of the filter :math:`x` 60 | observation: (K) array 61 | Observation :math:`y` 62 | linearization_point: jnp.ndarray 63 | Where to compute the Jacobian 64 | 65 | Returns 66 | ------- 67 | loglikelihood: float 68 | Log-likelihood increment for observation 69 | updated_state: MVNormalParameters 70 | filtered state 71 | """ 72 | if linearization_point is None: 73 | linearization_point = predicted.mean 74 | jac_x = jacfwd(observation_function, 0)(linearization_point) 75 | 76 | obs_mean = observation_function(linearization_point) + jnp.dot(jac_x, predicted.mean - linearization_point) 77 | 78 | residual = observation - obs_mean 79 | residual_covariance = jnp.dot(jac_x, jnp.dot(predicted.cov, jac_x.T)) 80 | residual_covariance = residual_covariance + observation_covariance 81 | 82 | gain = jnp.dot(predicted.cov, jlag.solve(residual_covariance, jac_x, assume_a="pos").T) 83 | 84 | mean = predicted.mean + jnp.dot(gain, residual) 85 | cov = predicted.cov - jnp.dot(gain, jnp.dot(residual_covariance, gain.T)) 86 | updated_state = MVNormalParameters(mean, 0.5 * (cov + cov.T)) 87 | 88 | loglikelihood = multivariate_normal.logpdf(residual, jnp.zeros_like(residual), residual_covariance) 89 | return loglikelihood, updated_state 90 | 91 | 92 | def filter_routine(initial_state: MVNormalParameters, 93 | observations: jnp.ndarray, 94 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 95 | transition_covariances: jnp.ndarray, 96 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 97 | observation_covariances: jnp.ndarray, 98 | linearization_points: jnp.ndarray = None, 99 | propagate_first: bool = True) -> Tuple[float, MVNormalParameters]: 100 | """ Computes the predict-update routine of the Kalman Filter equations and returns a series of filtered_states 101 | 102 | Parameters 103 | ---------- 104 | initial_state: MVNormalParameters 105 | prior belief on the initial state distribution 106 | observations: (n, K) array 107 | array of n observations of dimension K 108 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 109 | transition function of the state space model 110 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 111 | transition covariances for each time step, if passed only one, it is repeated n times 112 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 113 | observation function of the state space model 114 | observation_covariances: (K, K) or (1, K, K) or (n, K, K) array 115 | observation error covariances for each time step, if passed only one, it is repeated n times 116 | linearization_points: (n, D) array, optional 117 | points at which to compute the jacobians. 118 | propagate_first: bool, optional 119 | Is the first step a transition or an update? i.e. False if the initial time step has 120 | an associated observation. Default is True. 121 | 122 | Returns 123 | ------- 124 | loglikelihood: float 125 | Marginal loglikelihood of the observations given the parameters 126 | filtered_states: MVNormalParameters 127 | list of filtered states 128 | """ 129 | n_observations = observations.shape[0] 130 | 131 | transition_covariances, observation_covariances = list(map( 132 | lambda z: make_matrices_parameters(z, n_observations), 133 | [transition_covariances, 134 | observation_covariances])) 135 | 136 | def prop_first_body(carry, inputs): 137 | running_ell, state, prev_linearization_point = carry 138 | observation, transition_covariance, observation_covariance, linearization_point = inputs 139 | predicted_state = predict(transition_function, transition_covariance, state, prev_linearization_point) 140 | loglikelihood, updated_state = update(observation_function, observation_covariance, predicted_state, 141 | observation, linearization_point) 142 | 143 | return (running_ell + loglikelihood, updated_state, linearization_point), updated_state 144 | 145 | def update_first_body(carry, inputs): 146 | running_ell, state, _ = carry 147 | observation, transition_covariance, observation_covariance, linearization_point = inputs 148 | loglikelihood, updated_state = update(observation_function, observation_covariance, state, 149 | observation, linearization_point) 150 | predicted_state = predict(transition_function, transition_covariance, updated_state, linearization_point) 151 | return (running_ell + loglikelihood, predicted_state, linearization_point), updated_state 152 | 153 | body = prop_first_body if propagate_first else update_first_body 154 | 155 | if linearization_points is not None: 156 | initial_linearization_point = linearization_points[0] if linearization_points is not None else None 157 | linearization_points = linearization_points[1:] if propagate_first else linearization_points 158 | else: 159 | initial_linearization_point = linearization_points = None 160 | 161 | (ell, *_), filtered_states = lax.scan(body, 162 | (0., initial_state, initial_linearization_point), 163 | [observations, 164 | transition_covariances, 165 | observation_covariances, 166 | linearization_points], 167 | length=n_observations) 168 | 169 | if propagate_first: 170 | filtered_states = jax.tree_map(lambda y, z: jnp.concatenate([y[None, ...], z], 0), initial_state, 171 | filtered_states) 172 | 173 | return ell, filtered_states 174 | 175 | 176 | def smooth(transition_function: Callable[[jnp.ndarray], jnp.ndarray], 177 | transition_covariance: jnp.array, 178 | filtered_state: MVNormalParameters, 179 | previous_smoothed: MVNormalParameters, 180 | linearization_point: jnp.ndarray) -> MVNormalParameters: 181 | """ 182 | One step extended kalman smoother 183 | 184 | Parameters 185 | ---------- 186 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 187 | transition function of the state space model 188 | transition_covariance: (D,D) array 189 | covariance :math:`\Sigma` of the noise fed to transition_function 190 | filtered_state: MVNormalParameters 191 | mean and cov computed by Kalman Filtering 192 | previous_smoothed: MVNormalParameters, 193 | smoothed state of the previous step 194 | linearization_point: jnp.ndarray 195 | Where to compute the Jacobian 196 | 197 | Returns 198 | ------- 199 | smoothed_state: MVNormalParameters 200 | smoothed state 201 | """ 202 | 203 | jac_x = jacfwd(transition_function, 0)(linearization_point) 204 | 205 | mean = transition_function(linearization_point) + jnp.dot(jac_x, filtered_state.mean - linearization_point) 206 | mean_diff = previous_smoothed.mean - mean 207 | 208 | cov = jnp.dot(jac_x, jnp.dot(filtered_state.cov, jac_x.T)) + transition_covariance 209 | cov_diff = previous_smoothed.cov - cov 210 | 211 | gain = jnp.dot(filtered_state.cov, jlag.solve(cov, jac_x, assume_a="pos").T) 212 | 213 | mean = filtered_state.mean + jnp.dot(gain, mean_diff) 214 | cov = filtered_state.cov + jnp.dot(gain, jnp.dot(cov_diff, gain.T)) 215 | return MVNormalParameters(mean, cov) 216 | 217 | 218 | def smoother_routine(transition_function: Callable[[jnp.ndarray], jnp.ndarray], 219 | transition_covariances: jnp.ndarray, 220 | filtered_states: MVNormalParameters, 221 | linearization_points: jnp.ndarray = None, 222 | propagate_first: bool = True, 223 | ) -> MVNormalParameters: 224 | """ Computes the extended Rauch-Tung-Striebel (a.k.a extended Kalman) smoother routine and returns a series of smoothed_states 225 | 226 | Parameters 227 | ---------- 228 | filtered_states: MVNormalParameters 229 | Filtered states obtained from Kalman Filter 230 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 231 | transition function of the state space model 232 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 233 | transition covariances for each time step, if passed only one, it is repeated n times 234 | linearization_points: (n, D) array, optional 235 | points at which to compute the jacobians. 236 | propagate_first: bool, optional 237 | Is the first step a transition or an update? i.e. False if the initial time step has 238 | an associated observation. Default is True. 239 | 240 | Returns 241 | ------- 242 | smoothed_states: MVNormalParameters 243 | list of smoothed states 244 | """ 245 | n_observations = filtered_states.mean.shape[0] 246 | if propagate_first: 247 | transition_covariances = make_matrices_parameters(transition_covariances, n_observations - 1) 248 | else: 249 | transition_covariances = make_matrices_parameters(transition_covariances, n_observations) 250 | transition_covariances = transition_covariances[:-1] 251 | 252 | def body(state, inputs): 253 | filtered, transition_covariance, linearization_point = inputs 254 | if linearization_point is None: 255 | linearization_point = filtered.mean 256 | smoothed_state = smooth(transition_function, transition_covariance, filtered, state, linearization_point) 257 | return smoothed_state, smoothed_state 258 | 259 | last_state = MVNormalParameters(filtered_states.mean[-1], filtered_states.cov[-1]) 260 | filtered_states, linearization_points = jax.tree_map(lambda x: x[:-1], 261 | [filtered_states, linearization_points]) 262 | _, smoothed_states = lax.scan(body, 263 | last_state, 264 | [filtered_states, transition_covariances, linearization_points], 265 | reverse=True) 266 | 267 | smoothed_states = jax.tree_map(lambda y, z: jnp.concatenate([y, z[None, ...]], 0), smoothed_states, 268 | last_state) 269 | 270 | return smoothed_states 271 | 272 | 273 | def iterated_smoother_routine(initial_state: MVNormalParameters, 274 | observations: jnp.ndarray, 275 | transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 276 | transition_covariances: jnp.ndarray, 277 | observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 278 | observation_covariances: jnp.ndarray, 279 | initial_linearization_points: jnp.ndarray = None, 280 | n_iter: int = 100, 281 | propagate_first: bool = True): 282 | """ 283 | Computes the Gauss-Newton iterated extended Kalman smoother 284 | 285 | Parameters 286 | ---------- 287 | initial_state: MVNormalParameters 288 | prior belief on the initial state distribution 289 | observations: (n, K) array 290 | array of n observations of dimension K 291 | transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t-1}` 292 | transition function of the state space model 293 | transition_covariances: (D, D) or (1, D, D) or (n, D, D) array 294 | transition covariances for each time step, if passed only one, it is repeated n times 295 | observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t` 296 | observation function of the state space model 297 | observation_covariances: (K, K) or (1, K, K) or (n, K, K) array 298 | observation error covariances for each time step, if passed only one, it is repeated n times 299 | initial_linearization_points: jnp.ndarray , optional 300 | points at which to compute the jacobians durning the first pass. 301 | n_iter: int 302 | number of times the filter-smoother routine is computed 303 | propagate_first: bool, optional 304 | Is the first step a transition or an update? i.e. False if the initial time step has 305 | an associated observation. Default is True. 306 | 307 | Returns 308 | ------- 309 | iterated_smoothed_trajectories: MVNormalParameters 310 | The result of the smoothing routine 311 | 312 | """ 313 | n_observations = observations.shape[0] 314 | 315 | transition_covariances, observation_covariances = list(map( 316 | lambda z: make_matrices_parameters(z, n_observations), 317 | [transition_covariances, 318 | observation_covariances])) 319 | 320 | def body(curr_smoother, _): 321 | if curr_smoother is not None: 322 | linearization_points = curr_smoother.mean if isinstance(curr_smoother, 323 | MVNormalParameters) else curr_smoother 324 | else: 325 | linearization_points = None 326 | 327 | _, filtered_states = filter_routine(initial_state, observations, transition_function, transition_covariances, 328 | observation_function, observation_covariances, linearization_points, 329 | propagate_first) 330 | return smoother_routine(transition_function, transition_covariances, filtered_states, 331 | linearization_points, propagate_first), None 332 | 333 | if initial_linearization_points is None: 334 | initial_linearization_points = body(None, None)[0] 335 | 336 | iterated_smoothed_trajectories, _ = lax.scan(body, initial_linearization_points, jnp.arange(n_iter)) 337 | return iterated_smoothed_trajectories 338 | -------------------------------------------------------------------------------- /parsmooth/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import jax.numpy as jnp 4 | import jax.scipy.linalg as jlnialg 5 | import numpy as np 6 | 7 | __all__ = ["MVNormalParameters", "make_matrices_parameters"] 8 | 9 | MVNormalParameters = namedtuple("MVNormalParameters", ["mean", "cov"]) 10 | 11 | 12 | def make_matrices_parameters(matrix: jnp.ndarray or np.array, n_observations: int) -> jnp.array: 13 | """ Processes a matrix (or "list" thereof) to be able to be iterated over n_observations times 14 | 15 | Parameters 16 | ---------- 17 | matrix: array 18 | Matrix to be processed 19 | n_observations: int 20 | First dimension of the returned array 21 | 22 | Returns 23 | ------- 24 | 25 | """ 26 | if jnp.ndim(matrix) <= 2: 27 | return jnp.tile(matrix, (n_observations, 1, 1)) 28 | elif jnp.ndim(matrix) == 3: 29 | if matrix.shape[0] == 1: 30 | return jnp.repeat(matrix, n_observations, 0) 31 | if matrix.shape[0] == n_observations: 32 | return matrix 33 | raise ValueError("if matrix has 3 dimensions, its first dimension must be of size 1 or n_observations") 34 | raise ValueError("matrix has more than 3 dimensions") 35 | 36 | 37 | # The real logic 38 | def _make_associative_filtering_params(args): 39 | Hk, Rk, Fk_1, Qk_1, uk_1, yk, dk, I_dim = args 40 | 41 | # FIRST TERM 42 | ############ 43 | 44 | # temp variable 45 | HQ = jnp.dot(Hk, Qk_1) # Hk @ Qk_1 46 | 47 | Sk = jnp.dot(HQ, Hk.T) + Rk 48 | Kk = jlnialg.solve(Sk, HQ, assume_a="pos").T # using the fact that S and Q are symmetric 49 | 50 | # temp variable: 51 | I_KH = I_dim - jnp.dot(Kk, Hk) # I - Kk @ Hk 52 | 53 | Ck = jnp.dot(I_KH, Qk_1) 54 | 55 | residual = (yk - jnp.dot(Hk, uk_1) - dk) 56 | 57 | bk = uk_1 + jnp.dot(Kk, residual) 58 | Ak = jnp.dot(I_KH, Fk_1) 59 | 60 | # SECOND TERM 61 | ############# 62 | HF = jnp.dot(Hk, Fk_1) 63 | FHS_inv = jsolve(Sk, HF).T 64 | 65 | etak = jnp.dot(FHS_inv, residual) 66 | Jk = jnp.dot(FHS_inv, HF) 67 | 68 | return Ak, bk, Ck, etak, Jk 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | matplotlib 3 | numba -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Author: Adrien Corenflos 2 | 3 | """Install parsmooth.""" 4 | 5 | from setuptools import setup, find_packages 6 | 7 | with open('requirements.txt') as f: 8 | requirements = f.read().splitlines() 9 | 10 | setup( 11 | name='parsmooth', 12 | version='0.1', 13 | description='Parallel Extended Kalman Filter.', 14 | author='Adrien Corenflos', 15 | author_email='adrien.corenflos@gmail.com', 16 | url='https://github.com/AdrienCorenflos/parallelEKF', 17 | packages=find_packages(), 18 | install_requires=requirements, 19 | ) 20 | --------------------------------------------------------------------------------