├── sts_jax ├── __init__.py ├── causal_impact │ ├── __init__.py │ └── causal_impact.py ├── figures │ ├── causal_obs.png │ ├── comparison.png │ ├── electr_obs.png │ ├── poisson_obs.png │ ├── causal_forecast.png │ ├── electr_forecast.png │ └── poisson_forecast.png └── structural_time_series │ ├── __init__.py │ ├── learning.py │ ├── sts_model.py │ ├── sts_ssm.py │ └── sts_components.py ├── test-requirements.txt ├── lint-requirements.txt ├── setup.cfg ├── pyproject.toml ├── .gitignore ├── requirements.txt ├── Makefile ├── setup.py ├── .github └── workflows │ └── ci.yml ├── LICENSE ├── tests └── structural_time_series │ ├── test_autoregressive.py │ └── test_local_linear_trend.py └── README.md /sts_jax/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sts_jax/causal_impact/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | 3 | pytest>=7.0.1 4 | pytest-cov>=3.0.0 -------------------------------------------------------------------------------- /lint-requirements.txt: -------------------------------------------------------------------------------- 1 | black>=22.3.0 2 | flake8>=4.0.1 3 | isort>=5.10.1 4 | pre-commit>=2.19.0 -------------------------------------------------------------------------------- /sts_jax/figures/causal_obs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/causal_obs.png -------------------------------------------------------------------------------- /sts_jax/figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/comparison.png -------------------------------------------------------------------------------- /sts_jax/figures/electr_obs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/electr_obs.png -------------------------------------------------------------------------------- /sts_jax/figures/poisson_obs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/poisson_obs.png -------------------------------------------------------------------------------- /sts_jax/figures/causal_forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/causal_forecast.png -------------------------------------------------------------------------------- /sts_jax/figures/electr_forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/electr_forecast.png -------------------------------------------------------------------------------- /sts_jax/figures/poisson_forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/sts-jax/main/sts_jax/figures/poisson_forecast.png 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | extend-ignore = E203, E501, F722 4 | per-file-ignores = */__init__.py:F401,F403 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [tool.black] 5 | line-length = 120 6 | 7 | [tool.pytest.ini_options] 8 | addopts = [ 9 | "-v", 10 | "--cov=sts_jax", 11 | "--color=yes", 12 | ] 13 | testpaths = "tests" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | docs/_build/ 3 | build/ 4 | dist/ 5 | 6 | *egg-info 7 | *.ipynb_checkpoints 8 | 9 | # ignore figures unless manually added 10 | # *.png 11 | *.jpg 12 | *.pdf 13 | *-dot 14 | *.DS_Store 15 | .vscode/ 16 | 17 | .coverage -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax>=0.3.15 2 | jaxlib 3 | optax 4 | chex 5 | tensorflow>=2.11 6 | tensorflow_probability 7 | matplotlib 8 | seaborn 9 | tqdm 10 | flax 11 | scikit-learn 12 | blackjax 13 | jaxopt 14 | jaxtyping 15 | typing-extensions 16 | dynamax @ git+https://github.com/probml/dynamax.git#egg=dynamax -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init lint check_lint test 2 | 3 | init: 4 | python -m pip install -e . 5 | 6 | lint: 7 | pip install -r lint-requirements.txt 8 | isort . 9 | black . 10 | 11 | check_lint: 12 | pip install -r lint-requirements.txt 13 | flake8 . 14 | isort --check-only . 15 | black --diff --check --fast . 
16 | 17 | test: 18 | pip install -r test-requirements.txt 19 | pytest 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from distutils.core import setup 3 | 4 | import setuptools 5 | 6 | with open("requirements.txt") as f: 7 | requirements = list(map(lambda x: x.strip(), f.read().strip().splitlines())) 8 | 9 | setup( 10 | name="sts_jax", 11 | version="0.1", 12 | description="JAX code for structural time series", 13 | url="https://github.com/probml/sts-jax", 14 | install_requires=requirements, 15 | packages=setuptools.find_packages(), 16 | ) 17 | -------------------------------------------------------------------------------- /sts_jax/structural_time_series/__init__.py: -------------------------------------------------------------------------------- 1 | from sts_jax.structural_time_series.sts_components import ( 2 | Autoregressive, 3 | Cycle, 4 | LinearRegression, 5 | LocalLinearTrend, 6 | SeasonalDummy, 7 | SeasonalTrig, 8 | ) 9 | from sts_jax.structural_time_series.sts_model import StructuralTimeSeries 10 | 11 | __all__ = [ 12 | "Autoregressive", 13 | "Cycle", 14 | "LinearRegression", 15 | "LocalLinearTrend", 16 | "SeasonalDummy", 17 | "SeasonalTrig", 18 | "StructuralTimeSeries", 19 | ] 20 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: [push] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python 15 | uses: actions/setup-python@v3 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Run lint 19 | run: | 20 | make init 21 | make check_lint 22 | test: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | matrix: 26 | python-version: ["3.8", "3.9", "3.10"] 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | - name: Set up Python 31 | uses: actions/setup-python@v3 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | - name: Run tests 35 | run: | 36 | make init 37 | make test 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/structural_time_series/test_autoregressive.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as jr 3 | import pytest 4 | import tensorflow as tf 5 | import tensorflow_probability as tfp 6 | from jax import lax 7 | from tensorflow_probability.substrates.jax.distributions import ( 8 | MultivariateNormalFullCovariance as MVN, 9 | ) 10 | 11 | from sts_jax.structural_time_series.sts_components import Autoregressive 12 | from sts_jax.structural_time_series.sts_model import StructuralTimeSeries as STS 13 | 14 | 15 | def _build_models(time_steps, key): 16 | 17 | keys = jr.split(key, 5) 18 | standard_mvn = MVN(jnp.zeros(1), jnp.eye(1)) 19 | 20 | # Generate parameters of the STS component 21 | level_scale = 5.0 22 | coef = 0.8 23 | initial_level = standard_mvn.sample(seed=keys[0]) 24 | 25 | obs_noise_scale = 4.0 26 | 27 | # Generate observed time series using the SSM representation. 28 | F = jnp.array([[coef]]) 29 | H = jnp.array([[1]]) 30 | Q = jnp.array([[level_scale]]) 31 | R = obs_noise_scale 32 | 33 | def _step(current_state, key): 34 | key1, key2 = jr.split(key) 35 | current_obs = H @ current_state + R * standard_mvn.sample(seed=key1) 36 | next_state = F @ current_state + Q @ MVN(jnp.zeros(1), jnp.eye(1)).sample(seed=key2) 37 | return next_state, current_obs 38 | 39 | initial_state = initial_level 40 | key_seq = jr.split(keys[2], time_steps) 41 | _, obs_time_series = lax.scan(_step, initial_state, key_seq) 42 | 43 | # Build the STS model using tfp module. 44 | tfp_comp = tfp.sts.Autoregressive(order=1, observed_time_series=obs_time_series, name="ar") 45 | tfp_model = tfp.sts.Sum([tfp_comp], observed_time_series=obs_time_series) 46 | 47 | # Build the dynamax STS model. 48 | dynamax_comp = Autoregressive(order=1, name="ar") 49 | dynamax_model = STS([dynamax_comp], obs_time_series=obs_time_series) 50 | 51 | # Set the parameters to the parameters learned by the tfp module and fix the parameters. 52 | tfp_vi_posterior = tfp.sts.build_factored_surrogate_posterior(tfp_model) 53 | tfp.vi.fit_surrogate_posterior( 54 | target_log_prob_fn=tfp_model.joint_distribution(obs_time_series).log_prob, 55 | surrogate_posterior=tfp_vi_posterior, 56 | optimizer=tf.optimizers.Adam(learning_rate=0.1), 57 | num_steps=200, 58 | jit_compile=True, 59 | ) 60 | vi_dists, _ = tfp_vi_posterior.distribution.sample_distributions() 61 | tfp_params = tfp_vi_posterior.sample(sample_shape=(1,)) 62 | 63 | dynamax_model.params["ar"]["cov_level"] = jnp.atleast_2d(jnp.array(tfp_params["ar/_level_scale"] ** 2)) 64 | dynamax_model.params["ar"]["coef"] = jnp.array(tfp_params["ar/_coefficients"])[0] 65 | dynamax_model.params["obs_model"]["cov"] = jnp.atleast_2d(jnp.array(tfp_params["observation_noise_scale"] ** 2)) 66 | 67 | return (tfp_model, tfp_params, dynamax_model, dynamax_model.params, obs_time_series, vi_dists) 68 | 69 | 70 | def test_autoregress(time_steps=150, key=jr.PRNGKey(3)): 71 | 72 | tfp_model, tfp_params, dynamax_model, dynamax_params, obs_time_series, vi_dists = _build_models(time_steps, key) 73 | 74 | # Fit and forecast with the tfp module. 
75 | # Not use tfp.sts.decompose_by_component() since its output series is centered at 0. 76 | masked_time_series = tfp.sts.MaskedTimeSeries( 77 | time_series=obs_time_series, is_missing=tf.math.is_nan(obs_time_series) 78 | ) 79 | tfp_posterior = tfp.sts.impute_missing_values( 80 | tfp_model, masked_time_series, tfp_params, include_observation_noise=False 81 | ) 82 | 83 | tfp_posterior_mean = jnp.array(tfp_posterior.mean()).squeeze() 84 | tfp_posterior_scale = jnp.array(jnp.array(tfp_posterior.stddev())).squeeze() 85 | 86 | # Fit and forecast with dynamax 87 | dynamax_posterior = dynamax_model.decompose_by_component(dynamax_params, obs_time_series) 88 | dynamax_posterior_mean = dynamax_model._uncenter_obs(dynamax_posterior["ar"]["pos_mean"]).squeeze() 89 | dynamax_posterior_cov = dynamax_posterior["ar"]["pos_cov"].squeeze() 90 | 91 | # Compare posterior inference by tfp and dynamax. 92 | # In comparing the smoothed posterior, we omit the first 5 time steps, 93 | # since the tfp and the dynamax implementations of STS has different settings in 94 | # distributions of initial state, which will influence the posterior inference of 95 | # the first few states. 96 | len_step = jnp.abs(tfp_posterior_mean[1:] - tfp_posterior_mean[:-1]).mean() 97 | assert jnp.allclose(tfp_posterior_mean[5:], dynamax_posterior_mean[5:], atol=len_step) 98 | assert jnp.allclose(tfp_posterior_scale[5:], jnp.sqrt(dynamax_posterior_cov)[5:], rtol=1e-2) 99 | 100 | 101 | @pytest.mark.skip( 102 | reason="Skipped because the forecast mean and variances are now computed as sample mean and variance, for dynamax model" 103 | ) 104 | def test_autoregress_forecast(time_steps=150, key=jr.PRNGKey(3)): 105 | tfp_model, tfp_params, dynamax_model, dynamax_params, obs_time_series, vi_dists = _build_models(time_steps, key) 106 | 107 | masked_time_series = tfp.sts.MaskedTimeSeries( 108 | time_series=obs_time_series, is_missing=tf.math.is_nan(obs_time_series) 109 | ) 110 | 111 | tfp_posterior = tfp.sts.impute_missing_values( 112 | tfp_model, masked_time_series, tfp_params, include_observation_noise=False 113 | ) 114 | 115 | tfp_posterior_mean = jnp.array(tfp_posterior.mean()).squeeze() 116 | 117 | tfp_forecasts = tfp.sts.forecast( 118 | tfp_model, obs_time_series, parameter_samples=tfp_params, num_steps_forecast=50, include_observation_noise=True 119 | ) 120 | 121 | tfp_forecast_mean = jnp.array(tfp_forecasts.mean()).squeeze() 122 | tfp_forecast_scale = jnp.array(tfp_forecasts.stddev()).squeeze() 123 | 124 | dynamax_forecast = dynamax_model.forecast(dynamax_params, obs_time_series, num_forecast_steps=50)[1] 125 | dynamax_forecast_mean = jnp.concatenate(dynamax_forecast).mean(axis=0).squeeze() 126 | dynamax_forecast_cov = jnp.concatenate(dynamax_forecast).var(axis=0).squeeze() 127 | # Compare forecast by tfp and dynamax. 
128 | len_step = jnp.abs(tfp_posterior_mean[1:] - tfp_posterior_mean[:-1]).mean() 129 | assert jnp.allclose(tfp_forecast_mean, dynamax_forecast_mean, atol=0.5 * len_step) 130 | assert jnp.allclose(tfp_forecast_scale, jnp.sqrt(dynamax_forecast_cov), rtol=5e-2) 131 | -------------------------------------------------------------------------------- /tests/structural_time_series/test_local_linear_trend.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.random as jr 3 | import pytest 4 | import tensorflow as tf 5 | import tensorflow_probability as tfp 6 | from jax import lax 7 | from tensorflow_probability.substrates.jax.distributions import ( 8 | MultivariateNormalFullCovariance as MVN, 9 | ) 10 | 11 | from sts_jax.structural_time_series.sts_components import LocalLinearTrend 12 | from sts_jax.structural_time_series.sts_model import StructuralTimeSeries as STS 13 | 14 | 15 | def _build_models(time_steps, key): 16 | 17 | keys = jr.split(key, 5) 18 | standard_mvn = MVN(jnp.zeros(1), jnp.eye(1)) 19 | 20 | # Generate parameters of the STS component 21 | level_scale = 5 22 | slope_scale = 0.5 23 | initial_level = standard_mvn.sample(seed=keys[0]) 24 | initial_slope = standard_mvn.sample(seed=keys[1]) 25 | 26 | obs_noise_scale = 10 27 | 28 | # Generate observed time series using the SSM representation. 29 | F = jnp.array([[1, 1], [0, 1]]) 30 | H = jnp.array([[1, 0]]) 31 | Q = jnp.block([[level_scale, 0], [0, slope_scale]]) 32 | R = obs_noise_scale 33 | 34 | def _step(current_state, key): 35 | key1, key2 = jr.split(key) 36 | current_obs = H @ current_state + R * standard_mvn.sample(seed=key1) 37 | next_state = F @ current_state + Q @ MVN(jnp.zeros(2), jnp.eye(2)).sample(seed=key2) 38 | return next_state, current_obs 39 | 40 | initial_state = jnp.concatenate((initial_level, initial_slope)) 41 | key_seq = jr.split(keys[2], time_steps) 42 | _, obs_time_series = lax.scan(_step, initial_state, key_seq) 43 | 44 | # Build the STS model using tfp module. 45 | tfp_comp = tfp.sts.LocalLinearTrend(observed_time_series=obs_time_series, name="local_linear_trend") 46 | tfp_model = tfp.sts.Sum([tfp_comp], observed_time_series=obs_time_series) 47 | 48 | # Build the dynamax STS model. 49 | dynamax_comp = LocalLinearTrend(name="local_linear_trend") 50 | dynamax_model = STS([dynamax_comp], obs_time_series=obs_time_series) 51 | 52 | # Set the parameters to the parameters learned by the tfp module and fix the parameters. 
53 | tfp_vi_posterior = tfp.sts.build_factored_surrogate_posterior(tfp_model) 54 | tfp.vi.fit_surrogate_posterior( 55 | target_log_prob_fn=tfp_model.joint_distribution(obs_time_series).log_prob, 56 | surrogate_posterior=tfp_vi_posterior, 57 | optimizer=tf.optimizers.Adam(learning_rate=0.1), 58 | num_steps=200, 59 | jit_compile=True, 60 | ) 61 | vi_dists, _ = tfp_vi_posterior.distribution.sample_distributions() 62 | tfp_params = tfp_vi_posterior.sample(sample_shape=(1,)) 63 | 64 | dynamax_model.params["local_linear_trend"]["cov_level"] = jnp.atleast_2d( 65 | jnp.array(tfp_params["local_linear_trend/_level_scale"] ** 2) 66 | ) 67 | dynamax_model.params["local_linear_trend"]["cov_slope"] = jnp.atleast_2d( 68 | jnp.array(tfp_params["local_linear_trend/_slope_scale"] ** 2) 69 | ) 70 | dynamax_model.params["obs_model"]["cov"] = jnp.atleast_2d(jnp.array(tfp_params["observation_noise_scale"] ** 2)) 71 | 72 | return (tfp_model, tfp_params, dynamax_model, dynamax_model.params, obs_time_series, vi_dists) 73 | 74 | 75 | def test_local_linear_trend(time_steps=150, key=jr.PRNGKey(3)): 76 | 77 | tfp_model, tfp_params, dynamax_model, dynamax_params, obs_time_series, vi_dists = _build_models(time_steps, key) 78 | 79 | # Fit and forecast with the tfp module. 80 | # Not use tfp.sts.decmopose_by_component() since its output series is centered at 0. 81 | masked_time_series = tfp.sts.MaskedTimeSeries( 82 | time_series=obs_time_series, is_missing=tf.math.is_nan(obs_time_series) 83 | ) 84 | tfp_posterior = tfp.sts.impute_missing_values( 85 | tfp_model, masked_time_series, tfp_params, include_observation_noise=False 86 | ) 87 | 88 | tfp_posterior_mean = jnp.array(tfp_posterior.mean()).squeeze() 89 | tfp_posterior_scale = jnp.array(jnp.array(tfp_posterior.stddev())).squeeze() 90 | 91 | # Fit and forecast with dynamax 92 | dynamax_posterior = dynamax_model.decompose_by_component(dynamax_params, obs_time_series) 93 | dynamax_posterior_mean = dynamax_model._uncenter_obs(dynamax_posterior["local_linear_trend"]["pos_mean"]).squeeze() 94 | dynamax_posterior_cov = dynamax_posterior["local_linear_trend"]["pos_cov"].squeeze() 95 | 96 | # Compare posterior inference by tfp and dynamax. 97 | # In comparing the smoothed posterior, we omit the first N time steps, 98 | # since the tfp and the dynamax implementations of STS has different settings in 99 | # distributions of initial state, which will influence the posterior inference of 100 | # the first few states. 101 | start = 10 102 | 103 | print(tfp_posterior_mean[start : start + 5]) 104 | print(dynamax_posterior_mean[start : start + 5]) 105 | print(tfp_posterior_scale[start : start + 5]) 106 | print(jnp.sqrt(dynamax_posterior_cov[start : start + 5])) 107 | 108 | assert jnp.allclose(tfp_posterior_mean[start:], dynamax_posterior_mean[start:], atol=1e-1, rtol=1e-1) 109 | assert jnp.allclose(tfp_posterior_scale[start:], jnp.sqrt(dynamax_posterior_cov)[start:], atol=1e-1, rtol=1e-1) 110 | 111 | 112 | @pytest.mark.skip( 113 | reason="Skipped because the forecast mean and variances are now computed as sample mean and variance, for dynamax model" 114 | ) 115 | def test_local_linear_trend_forecast(time_steps=150, key=jr.PRNGKey(3)): 116 | tfp_model, tfp_params, dynamax_model, dynamax_params, obs_time_series, vi_dists = _build_models(time_steps, key) 117 | 118 | # Fit and forecast with the tfp module. 119 | # Not use tfp.sts.decmopose_by_component() since its output series is centered at 0. 
120 |     masked_time_series = tfp.sts.MaskedTimeSeries(
121 |         time_series=obs_time_series, is_missing=tf.math.is_nan(obs_time_series)
122 |     )
123 |     tfp_posterior = tfp.sts.impute_missing_values(
124 |         tfp_model, masked_time_series, tfp_params, include_observation_noise=False
125 |     )
126 | 
127 |     tfp_posterior_mean = jnp.array(tfp_posterior.mean()).squeeze()
128 | 
129 |     tfp_forecasts = tfp.sts.forecast(
130 |         tfp_model, obs_time_series, parameter_samples=tfp_params, num_steps_forecast=50, include_observation_noise=True
131 |     )
132 | 
133 |     tfp_forecast_mean = jnp.array(tfp_forecasts.mean()).squeeze()
134 |     tfp_forecast_scale = jnp.array(tfp_forecasts.stddev()).squeeze()
135 | 
136 |     dynamax_forecast = dynamax_model.forecast(dynamax_params, obs_time_series, num_forecast_steps=50)[1]
137 |     dynamax_forecast_mean = jnp.concatenate(dynamax_forecast).mean(axis=0).squeeze()
138 |     dynamax_forecast_cov = jnp.concatenate(dynamax_forecast).var(axis=0).squeeze()
139 | 
140 |     len_step = jnp.abs(tfp_posterior_mean[1:] - tfp_posterior_mean[:-1]).mean()
141 |     assert jnp.allclose(tfp_forecast_mean, dynamax_forecast_mean, atol=0.5 * len_step)
142 |     assert jnp.allclose(tfp_forecast_scale, jnp.sqrt(dynamax_forecast_cov), rtol=5e-2)
143 | 
--------------------------------------------------------------------------------
/sts_jax/structural_time_series/learning.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 | from typing import Optional, Tuple
4 | 
5 | import blackjax
6 | import jax.numpy as jnp
7 | import jax.random as jr
8 | import jax.scipy.stats.norm as norm
9 | import optax
10 | from dynamax.parameters import (
11 |     from_unconstrained,
12 |     log_det_jac_constrain,
13 |     to_unconstrained,
14 | )
15 | from dynamax.types import PRNGKey
16 | from dynamax.utils.utils import ensure_array_has_batch_dim, pytree_stack
17 | from fastprogress.fastprogress import progress_bar
18 | from jax import jit, lax, value_and_grad, vmap
19 | from jax.tree_util import tree_leaves, tree_map
20 | from jaxtyping import Array, Float
21 | 
22 | from .sts_components import ParamPropertiesSTS, ParamsSTS
23 | from .sts_ssm import StructuralTimeSeriesSSM
24 | 
25 | 
26 | def fit_vi(
27 |     model: StructuralTimeSeriesSSM,
28 |     initial_params: ParamsSTS,
29 |     param_props: ParamPropertiesSTS,
30 |     num_samples: int,
31 |     emissions: Float[Array, "num_timesteps dim_obs"],
32 |     inputs: Optional[Float[Array, "num_timesteps dim_inputs"]] = None,
33 |     optimizer: optax.GradientTransformation = optax.adam(1e-1),
34 |     K: int = 1,
35 |     key: PRNGKey = jr.PRNGKey(0),
36 |     num_step_iters: int = 50,
37 | ) -> Tuple[ParamsSTS, Float[Array, " num_samples"]]:
38 |     r"""
39 |     ADVI approximates the posterior distribution p of the unconstrained global parameters
40 |     with a factorized multivariate normal distribution:
41 |     $$
42 |     q = \prod_{k=1}^{K} q_k(\mu_k, \sigma_k),
43 |     $$
44 |     where K is the dimension of p.
45 | 
46 |     The hyper-parameters of q to be optimized over are $(\mu_k, \log \sigma_k)_{k=1}^{K}$.
47 | 
48 |     The reparameterization trick is employed to reduce the variance of SGD,
49 |     which is achieved by writing KL(q || p) as an expectation over the standard normal distribution,
50 |     so that a sample from q is obtained as
51 |     s = z * exp(log_sigma_k) + mu_k,
52 |     where z is a sample from the standard multivariate normal distribution.
53 | 
54 |     Args:
55 |         num_samples (int): number of samples to be returned from the fitted approximation q.
56 |         K (int): number of Monte Carlo samples from q used in each evaluation of the ELBO.
57 | 
58 |     Returns:
59 |         Samples from the approximate posterior q, and the sequence of ELBO losses.
60 |     """
61 |     key1, key2 = jr.split(key, 2)
62 |     # Make sure the emissions and covariates have batch dimensions
63 |     batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
64 |     batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)
65 | 
66 |     initial_unc_params = to_unconstrained(initial_params, param_props)
67 | 
68 |     @jit
69 |     def unnorm_log_pos(_unc_params):
70 |         params = from_unconstrained(_unc_params, param_props)
71 |         log_det_jac = log_det_jac_constrain(params, param_props)
72 |         log_pri = model.log_prior(params) + log_det_jac
73 |         batch_lls = vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
74 |         lp = batch_lls.sum() + log_pri
75 |         return lp
76 | 
77 |     @jit
78 |     def elbo(vi_hyper, key):
79 |         """Evaluate a one-sample Monte Carlo estimate of the ELBO at a sample from q."""
80 |         keys = iter(jr.split(key, 10))
81 |         # Turn VI parameters and fixed noises (one independent draw per dimension) into samples of q.
82 |         unc_params = tree_map(
83 |             lambda mu, ls: mu + jnp.exp(ls) * jr.normal(next(keys), ls.shape), vi_hyper["mu"], vi_hyper["log_sig"]
84 |         )
85 |         log_probs = unnorm_log_pos(unc_params)
86 |         log_q = jnp.array(
87 |             tree_leaves(
88 |                 tree_map(
89 |                     lambda x, *p: norm.logpdf(x, p[0], jnp.exp(p[1])).sum(),
90 |                     unc_params,
91 |                     vi_hyper["mu"],
92 |                     vi_hyper["log_sig"],
93 |                 )
94 |             )
95 |         ).sum()
96 |         return log_probs - log_q
97 | 
98 |     def loss_fn(vi_hyp, key):
99 |         return -jnp.mean(vmap(partial(elbo, vi_hyp))(jr.split(key, K)))
100 | 
101 |     # Fit
102 |     curr_vi_mus = initial_unc_params
103 |     curr_vi_log_sigs = tree_map(lambda x: jnp.zeros(x.shape), initial_unc_params)
104 |     curr_vi_hyper = OrderedDict()
105 |     curr_vi_hyper["mu"] = curr_vi_mus
106 |     curr_vi_hyper["log_sig"] = curr_vi_log_sigs
107 | 
108 |     # Optimize
109 |     opt_state = optimizer.init(curr_vi_hyper)
110 |     loss_grad_fn = value_and_grad(loss_fn)
111 | 
112 |     def train_step(carry, key):
113 |         vi_hyp, opt_state = carry
114 |         loss, grads = loss_grad_fn(vi_hyp, key)
115 |         updates, opt_state = optimizer.update(grads, opt_state)
116 |         vi_hyp = optax.apply_updates(vi_hyp, updates)
117 |         return (vi_hyp, opt_state), loss
118 | 
119 |     # Run the optimizer
120 |     initial_carry = (curr_vi_hyper, opt_state)
121 |     (vi_hyp_fitted, opt_state), losses = lax.scan(train_step, initial_carry, jr.split(key1, num_step_iters))
122 | 
123 |     # Sample from the learned approximate posterior q
124 |     def vi_sample(key):
125 |         return from_unconstrained(
126 |             tree_map(
127 |                 lambda mu, s: mu + jnp.exp(s) * jr.normal(key, s.shape), vi_hyp_fitted["mu"], vi_hyp_fitted["log_sig"]
128 |             ),
129 |             param_props,
130 |         )
131 | 
132 |     samples = vmap(vi_sample)(jr.split(key2, num_samples))
133 | 
134 |     return samples, losses
135 | 
136 | 
137 | def fit_hmc(
138 |     model: StructuralTimeSeriesSSM,
139 |     initial_params: ParamsSTS,
140 |     param_props: ParamPropertiesSTS,
141 |     num_samples: int,
142 |     emissions: Float[Array, "num_timesteps dim_obs"],
143 |     inputs: Optional[Float[Array, "num_timesteps dim_inputs"]] = None,
144 |     key: PRNGKey = jr.PRNGKey(0),
145 |     warmup_steps: int = 100,
146 |     verbose: bool = True,
147 | ) -> Tuple[ParamsSTS, Float[Array, " num_samples"]]:
148 |     """Sample parameters of the model using HMC."""
149 |     # Make sure the emissions and covariates have batch dimensions
150 |     batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
151 |
batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)
152 | 
153 |     initial_unc_params = to_unconstrained(initial_params, param_props)
154 | 
155 |     # The unnormalized log posterior that the HMC sampler targets
156 |     def unnorm_log_pos(_unc_params):
157 |         params = from_unconstrained(_unc_params, param_props)
158 |         log_det_jac = log_det_jac_constrain(params, param_props)
159 |         log_pri = model.log_prior(params) + log_det_jac
160 |         batch_lls = vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
161 |         lp = log_pri + batch_lls.sum()
162 |         return lp
163 | 
164 |     # Initialize the HMC sampler using window_adaptation
165 |     warmup = blackjax.window_adaptation(blackjax.nuts, unnorm_log_pos, num_steps=warmup_steps, progress_bar=verbose)
166 |     init_key, key = jr.split(key)
167 |     hmc_initial_state, hmc_kernel, _ = warmup.run(init_key, initial_unc_params)
168 | 
169 |     @jit
170 |     def hmc_step(hmc_state, step_key):
171 |         next_hmc_state, _ = hmc_kernel(step_key, hmc_state)
172 |         params = from_unconstrained(next_hmc_state.position, param_props)
173 |         return next_hmc_state, params
174 | 
175 |     # Start sampling.
176 |     log_probs = []
177 |     samples = []
178 |     hmc_state = hmc_initial_state
179 |     pbar = progress_bar(range(num_samples)) if verbose else range(num_samples)
180 |     for _ in pbar:
181 |         step_key, key = jr.split(key)
182 |         hmc_state, params = hmc_step(hmc_state, step_key)
183 |         log_probs.append(-hmc_state.potential_energy)
184 |         samples.append(params)
185 | 
186 |     # Combine the samples into a single pytree
187 |     return pytree_stack(samples), jnp.array(log_probs)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sts-jax
2 | Structural Time Series (STS) in JAX
3 | 
4 | This library has a similar design to [tfp.sts](https://www.tensorflow.org/probability/api_docs/python/tfp/sts),
5 | but is built entirely in JAX,
6 | and uses the [Dynamax](https://github.com/probml/dynamax/tree/main/dynamax) library
7 | for state-space models.
8 | We also include an implementation of the
9 | [causal impact](https://google.github.io/CausalImpact/) method.
10 | This has a similar design to [tfcausalimpact](https://github.com/WillianFuks/tfcausalimpact),
11 | but is built entirely in JAX.
12 | 
13 | ## Installation
14 | 
15 | To install the latest development branch:
16 | 
17 | ``` {.console}
18 | pip install git+https://github.com/probml/sts-jax
19 | ```
20 | or use
21 | ``` {.console}
22 | git clone git@github.com:probml/sts-jax.git
23 | cd sts-jax
24 | pip install -e .
25 | ```
26 | 
27 | ## What are structural time series (STS) models?
28 | 
29 | The STS model is a linear state space model with a specific structure. In particular,
30 | the latent state $z_t$ is a composition of the states of all latent components:
31 | 
32 | $$z_t = [c_{1, t}, c_{2, t}, ...]$$
33 | 
34 | where $c_{i,t}$ is the state of latent component $c_i$ at time step $t$.
35 | 
36 | The STS model (with scalar Gaussian observations) takes the form:
37 | 
38 | $$y_t = H_t z_t + u_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, \sigma^2_t)$$
39 | 
40 | $$z_{t+1} = F_t z_t + R_t \eta_t, \qquad \eta_t \sim \mathcal{N}(0, Q_t)$$
41 | 
42 | where
43 | 
44 | * $y_t$: observation (emission) at time $t$.
45 | * $\sigma^2_t$: variance of the observation noise.
46 | * $H_t$: emission matrix, which sums up the contributions of all latent components.
47 | * $u_t = x_t^T \beta$: regression component from external inputs.
48 | * $F_t$: fixed transition matrix of the latent dynamics.
49 | * $R_t$: the selection matrix, whose columns are a subset of the standard basis vectors $e_i$; it converts
50 | the non-singular covariance matrix into the (possibly singular) covariance matrix of
51 | the latent state $z_t$.
52 | * $Q_t$: non-singular covariance matrix of the latent state, so the dimension of $Q_t$
53 | can be smaller than the dimension of $z_t$.
54 | 
55 | The covariance matrix of the latent dynamics model takes the form $R Q R^T$, where $Q$ is
56 | a non-singular matrix (block diagonal), and $R$ is the selection matrix.
57 | 
58 | More information on STS models can be found in these books:
59 | 
60 | > - "Machine Learning: Advanced Topics", K. Murphy, MIT Press 2023.
61 | >   Available at <https://probml.github.io/book2>.
62 | > - "Time Series Analysis by State Space Methods (2nd edn)", James Durbin, Siem Jan Koopman,
63 | >   Oxford University Press, 2012.
64 | 
65 | ## Usage
66 | 
67 | In this library, an STS model is constructed by providing the observed time series and specifying a list of
68 | components and the distribution family of the observation. This library implements
69 | common STS components including the **local linear trend** component, the **seasonal** component,
70 | the **cycle** component, the **autoregressive** component, and the **regression** component.
71 | The observed time series can follow either the **Gaussian**
72 | distribution or the **Poisson** distribution. (Other likelihood functions can also be added.)
73 | 
74 | Internally, the STS model is converted to the corresponding state space model (SSM), and inference
75 | and learning of parameters are performed on the SSM.
76 | If the observation $Y_t$ follows a Gaussian distribution, the inference of the latent variables
77 | $Z_{1:T}$ (given the parameters) is based on the
78 | [Kalman filter](https://github.com/probml/dynamax/tree/main/dynamax/linear_gaussian_ssm).
79 | Alternatively, if the observation $Y_t$ follows a Poisson distribution, with
80 | a mean given by $E[Y_t|Z_t] = e^{H_t Z_t + u_t}$, the inference of the
81 | latent variables $Z_{1:T}$ is based on a generalization of the extended
82 | Kalman filter, which we call the
83 | [conditional moment Gaussian filter](https://github.com/probml/dynamax/tree/main/dynamax/generalized_gaussian_ssm),
84 | based on [Tronarp 2018](https://acris.aalto.fi/ws/portalfiles/portal/17669270/cm_parapub.pdf).
85 | 
86 | The marginal likelihood of $Y_{1:T}$ conditioned on the parameters can be evaluated as a
87 | byproduct of the forward filtering process.
88 | This can then be used to learn the parameters of the STS model,
89 | using **MLE** (based on SGD implemented in the library [optax](https://github.com/deepmind/optax)),
90 | **ADVI** (using a Gaussian posterior approximation on the unconstrained parameter space),
91 | or **HMC** (from the library [blackjax](https://github.com/blackjax-devs/blackjax)).
92 | The parameter estimation is done offline, given one or more historical time series.
93 | These parameters can then be used for forecasting the future.
94 | 
95 | Below we illustrate the API applied to some example datasets.
96 | 
97 | ## Electricity demand
98 | 
99 | This example is adapted from the [TFP blog](https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html).
100 | See [this file](./sts_jax/structural_time_series/demos/sts_electric_demo.ipynb) for a runnable version
101 | of this demo.
102 | 
103 | The problem of interest is to forecast electricity demand in Victoria, Australia.
104 | The dataset contains hourly records of electricity demand and temperature measurements
105 | from the first 8 weeks of 2014. The following plot shows the training
106 | portion of the data, which contains the measurements from the first 6 weeks.
107 | 
108 |

<p align="center">
109 | <img src="./sts_jax/figures/electr_obs.png" alt="drawing"/>
110 | </p>
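As a quick aside, the state-space form defined earlier is easy to simulate directly. The following toy sketch is illustrative only (the matrices and noise scales are made up, and this is not the library API), but it generates a series from a local linear trend model of exactly this form:

```python
import jax.numpy as jnp
import jax.random as jr

F = jnp.array([[1.0, 1.0], [0.0, 1.0]])  # transition matrix F (local linear trend)
H = jnp.array([[1.0, 0.0]])              # emission matrix H
key, z = jr.PRNGKey(0), jnp.zeros(2)
ys = []
for _ in range(100):
    key, k1, k2 = jr.split(key, 3)
    ys.append(H @ z + 0.5 * jr.normal(k1, (1,)))  # y_t = H z_t + eps_t
    z = F @ z + 0.1 * jr.normal(k2, (2,))         # z_{t+1} = F z_t + eta_t
toy_time_series = jnp.stack(ys)                   # shape (100, 1)
```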

111 | 112 | We now build a model where the demand linearly depends on the temperature, 113 | but also has two seasonal components, and an auto-regressive component. 114 | 115 | ```python 116 | import sts_jax.structural_time_series.sts_model as sts 117 | 118 | hour_of_day_effect = sts.SeasonalDummy(num_seasons=24, 119 | name='hour_of_day_effect') 120 | day_of_week_effect = sts.SeasonalTrig(num_seasons=7, num_steps_per_season=24, 121 | name='day_of_week_effect') 122 | temperature_effect = sts.LinearRegression(dim_covariates=1, add_bias=True, 123 | name='temperature_effect') 124 | autoregress_effect = sts.Autoregressive(order=1, 125 | name='autoregress_effect') 126 | 127 | # The STS model is constructed by providing the observed time series, 128 | # specifying a list of components and the distribution family of the observations. 129 | model = sts.StructuralTimeSeries( 130 | [hour_of_day_effect, day_of_week_effect, temperature_effect, autoregress_effect], 131 | obs_time_series, 132 | obs_distribution='Gaussian', 133 | covariates=temperature_training_data) 134 | 135 | ``` 136 | In this case, we choose to fit the model using MLE. 137 | 138 | ```python 139 | # Perform the MLE estimation of parameters via SGD implemented in dynamax library. 140 | opt_param, _losses = model.fit_mle(obs_time_series, 141 | covariates=temperature_training_data, 142 | num_steps=2000) 143 | ``` 144 | 145 | We can now plug in the parameters and the future inputs, 146 | and use ancestral sampling from the 147 | filtered posterior to forecast future observations. 148 | 149 | ```python 150 | # The 'forecast' method samples the future means and future observations from the 151 | # predictive distribution, conditioned on the parameters of the model. 152 | forecast_means, forecasts = model.forecast(opt_param, 153 | obs_time_series, 154 | num_forecast_steps, 155 | past_covariates=temperature_training_data, 156 | forecast_covariates=temperature_predict_data) 157 | ``` 158 | 159 | The following plot shows the mean and 95\% probability interval of the forecast. 160 |

<p align="center">
161 | <img src="./sts_jax/figures/electr_forecast.png" alt="drawing"/>
162 | </p>
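For reference, the forecast mean and the interval shown above can be computed directly from the sampled forecasts. This sketch assumes that `forecasts` (returned by `model.forecast` above) stacks the sampled future observations along its first axis; the exact output shape is an assumption here:

```python
import jax.numpy as jnp

forecast_mean = forecasts.mean(axis=0).squeeze()
lower, upper = jnp.quantile(forecasts.squeeze(), jnp.array([0.025, 0.975]), axis=0)
```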

163 | 164 | ## CO2 levels 165 | 166 | This example is adapted from the [TFP blog](https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html). 167 | See [this file](./sts_jax/structural_time_series/demos/sts_co2_demo.ipynb) for a runnable version 168 | of the demo, which is similar to the electricity example. 169 | 170 | ## Time series with Poisson observations 171 | 172 | We can also fit STS models with discrete observations following the Poisson 173 | distribution. Internally, the inference of the latent states $Z_{1:T}$ in the corresponding SSM 174 | is based on the (generalized) extended Kalman filter implemented 175 | in the library dynamax. An STS model for a Poisson-distributed time series can be constructed 176 | simply by specifying observation distribution to be 'Poisson'. Everything else is the same 177 | as the Gaussian case. 178 | 179 | Below we create a synthetic dataset, following [this TFP example](https://www.tensorflow.org/probability/examples/STS_approximate_inference_for_models_with_non_Gaussian_observations). 180 | See [this file](./sts_jax/structural_time_series/demos/sts_poisson_demo.ipynb) for a runnable version 181 | of this demo. 182 | 183 | 184 | ```python 185 | import sts_jax.structural_time_series.sts_model as sts 186 | 187 | # This example uses a synthetic dataset and the STS model contains only a 188 | # local linear trend component. 189 | trend = sts.LocalLinearTrend() 190 | model = sts.StructuralTimeSeries([trend], 191 | obs_distribution='Poisson', 192 | obs_time_series=counts_training) 193 | 194 | # Fit the model using HMC algorithm 195 | param_samples, _log_probs = model.fit_hmc(num_samples=200, 196 | obs_time_series=counts_training) 197 | 198 | # Forecast into the future given samples of parameters returned by the HMC algorithm. 199 | forecasts = model.forecast(param_samples, obs_time_series, num_forecast_steps)[1] 200 | ``` 201 |

<p align="center">
202 | <img src="./sts_jax/figures/poisson_forecast.png" alt="drawing"/>
203 | </p>
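As noted above, the Poisson model ties the latent state to the observations through an exponential link, $E[Y_t|Z_t] = e^{H_t Z_t + u_t}$. A toy sketch of this emission model (the arrays below are illustrative values, not the library's internal API):

```python
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

H_t = jnp.array([[1.0, 0.0]])    # emission matrix (illustrative)
z_t = jnp.array([0.3, -0.1])     # latent state (illustrative)
u_t = jnp.array([0.5])           # regression term (illustrative)
rate = jnp.exp(H_t @ z_t + u_t)  # conditional mean E[y_t | z_t]
y_dist = tfd.Poisson(rate=rate)  # emission distribution of y_t
```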

204 | 
205 | ### Comparison to TFP
206 | 
207 | The TFP approach to STS with non-conjugate likelihoods is to perform
208 | HMC on the joint distribution of the latent states $Z_{1:T}$ and the parameters, conditioned
209 | on the observations $Y_{1:T}$. Since the dimension of the state space grows linearly
210 | with the length of the time series to be fitted, the implementation will be inefficient
211 | when $T$ is relatively large. By contrast, we (approximately) marginalize out $Z_{1:T}$
212 | using a generalized extended Kalman filter,
213 | and just perform HMC in the collapsed parameter space. This is much faster, but yields
214 | comparable error, as we show below. (The number of burn-in steps of HMC in the TFP-STS
215 | implementation is adjusted so that the forecast errors of the two implementations
216 | are comparable.)
217 | 
218 |

<p align="center">
219 | <img src="./sts_jax/figures/comparison.png" alt="drawing"/>
220 | </p>
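Schematically, the collapsed log posterior that our HMC targets has the following form (simplified from `sts_jax/structural_time_series/learning.py`; `model`, `params`, and `obs_time_series` stand in for the objects defined earlier):

```python
def unnorm_log_pos(params):
    # Z_{1:T} is marginalized out inside marginal_log_prob by the (generalized)
    # Kalman filter, so HMC only explores the low-dimensional parameter space.
    return model.log_prior(params) + model.marginal_log_prob(params, obs_time_series)
```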

221 | 222 | ## Causal Impact 223 | 224 | The [causal impact](https://google.github.io/CausalImpact/CausalImpact.html) 225 | method is implemented on top of the STS-JAX package. 226 | 227 | Below we show an example, where Y is the output time series and X is a parallel 228 | set of input covariates. We notice a sudden change in the response variable at time $t=70$, 229 | caused by some kind of intervention (e.g., launching an ad campaign). 230 | We define the causal impact of this intervention 231 | to be the change in the observed output compared to what we would have 232 | expected had the intervention not happened. 233 | See [this file](./sts_jax/causal_impact/causal_impact_demo.ipynb) 234 | for a runnable version of this demo. 235 | (See also the [CausalPy](https://www.pymc-labs.io/blog-posts/causalpy-a-new-package-for-bayesian-causal-inference-for-quasi-experiments/) 236 | package for some related methods.) 237 | 238 |

<p align="center">
239 | <img src="./sts_jax/figures/causal_obs.png" alt="drawing"/>
240 | </p>
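The effect estimates reported below reduce to simple array arithmetic once the counterfactual forecast is available. Here is a self-contained toy sketch (made-up numbers) of the pointwise and cumulative effects, mirroring the computation in `causal_impact.py`:

```python
import jax.numpy as jnp

observed = jnp.array([10.0, 11.0, 15.0, 16.0])    # post-intervention observations (toy)
predicted = jnp.array([10.0, 11.0, 12.0, 12.5])   # counterfactual forecast (toy)
pointwise_effect = observed - predicted           # effect at each time step
cumulative_effect = jnp.cumsum(pointwise_effect)  # running total of the effect
```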

241 | 242 | This is how we run inference: 243 | 244 | ```python 245 | from sts_jax.causal_impact.causal_impact import causal_impact 246 | 247 | # The causal impact is inferred by providing the target time series and covariates, 248 | # specifying the intervention time and the distribution family of the observation. 249 | # If the STS model is not given, an STS model with only a local linear trend component 250 | # in addition to the regression component is constructed by default internally. 251 | impact = causal_impact(obs_time_series, 252 | intervention_timepoint, 253 | 'Gaussian', 254 | covariates, 255 | sts_model=None) 256 | 257 | ``` 258 | 259 | 260 | The format of the output from our 261 | causal impact code follows that of the R package 262 | [CausalImpact](https://google.github.io/CausalImpact/CausalImpact.html), 263 | and is shown below. 264 | 265 | ```python 266 | impact.plot() 267 | ``` 268 | 269 |

<p align="center">
270 | <img src="./sts_jax/figures/causal_forecast.png" alt="drawing"/>
271 | </p>
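The tail-area probability at the bottom of the summary below is the (add-one-smoothed) fraction of posterior samples of the cumulative effect lying on the non-dominant side of zero. A toy sketch mirroring the computation in `causal_impact.py`:

```python
import jax.numpy as jnp

effects_sum = jnp.array([3.1, 2.4, 4.0, -0.2, 1.7])  # toy samples of the summed effect
p_tail = float(1 + min((effects_sum >= 0).sum(), (effects_sum <= 0).sum())) / (1 + effects_sum.shape[0])
```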

272 | 273 | ```python 274 | impact.print_summary() 275 | 276 | Posterior inference of the causal impact: 277 | 278 | Average Cumulative 279 | Actual 129.93 3897.88 280 | 281 | Prediction (s.d.) 120.01 (2.04) 3600.42 (61.31) 282 | 95% CI [114.82, 123.07] [3444.72, 3692.09] 283 | 284 | Absolute effect (s.d.) 9.92 (2.04) 297.45 (61.31) 285 | 95% CI [6.86, 15.11] [205.78, 453.16] 286 | 287 | Relative effect (s.d.) 8.29% (1.89%) 8.29% (1.89%) 288 | 95% CI [5.57%, 13.16%] [5.57%, 13.16%] 289 | 290 | Posterior tail-area probability p: 0.0050 291 | Posterior prob of a causal effect: 99.50% 292 | ``` 293 | 294 | 295 | 296 | ## About 297 | 298 | Authors: [Xinlong Xi](https://www.stat.ubc.ca/users/xinglong-li), 299 | [Kevin Murphy](https://www.cs.ubc.ca/~murphyk/). 300 | 301 | MIT License. 2022 302 | -------------------------------------------------------------------------------- /sts_jax/causal_impact/causal_impact.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Optional 3 | 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | import matplotlib.pyplot as plt 7 | from dynamax.types import PRNGKey 8 | from jaxtyping import Array, Float 9 | 10 | import sts_jax.structural_time_series as sts 11 | 12 | 13 | class CausalImpact: 14 | """A wrapper class of helper functions of the causal impact""" 15 | 16 | def __init__( 17 | self, 18 | sts_model: sts.StructuralTimeSeries, 19 | intervention_time: int, 20 | predict: dict, 21 | effect: dict, 22 | summary: dict, 23 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 24 | ) -> None: 25 | """ 26 | Args: 27 | sts_model: an instance of the StructualTimeSeries. 28 | intervention_time: time point of the intervention. 
29 | effect: a dictionary containing pointwise effect and cumulative effect returned by 30 | the function 'causal_impact' 31 | """ 32 | self.intervention_time = intervention_time 33 | self.sts_model = sts_model 34 | self.predict_point = predict["pointwise"] 35 | self.predict_interval = predict["interval"] 36 | self.impact_point = effect["pointwise"] 37 | self.impact_cumulat = effect["cumulative"] 38 | self.summary = summary 39 | self.time_series = obs_time_series 40 | 41 | def plot(self) -> None: 42 | """Plot the effect.""" 43 | x = jnp.arange(self.time_series.shape[0]) 44 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(9, 6), sharex=True, layout="constrained") 45 | 46 | # Plot the original obvervation and the counterfactual predict 47 | ax1.plot(x, self.time_series, color="black", lw=2, label="Observation") 48 | ax1.plot(x, self.predict_point, linestyle="dashed", color="blue", lw=2, label="Prediction") 49 | ax1.fill_between(x, self.predict_interval[0], self.predict_interval[1], color="blue", alpha=0.2) 50 | ax1.axvline(x=self.intervention_time, linestyle="dashed", color="gray", lw=2) 51 | ax1.set_title("Original time series") 52 | 53 | # Plot the pointwise causal impact 54 | ax2.plot(x, self.impact_point[0], linestyle="dashed", color="blue") 55 | ax2.fill_between(x, self.impact_point[1][0], self.impact_point[1][1], color="blue", alpha=0.2) 56 | ax2.axvline(x=self.intervention_time, linestyle="dashed", color="gray", lw=2) 57 | ax2.set_title("Poinwise causal impact") 58 | 59 | # Plot the cumulative causal impact 60 | ax3.plot(x, self.impact_cumulat[0], linestyle="dashed", color="blue") 61 | ax3.fill_between(x, self.impact_cumulat[1][0], self.impact_cumulat[1][1], color="blue", alpha=0.2) 62 | ax3.axvline(x=self.intervention_time, linestyle="dashed", color="gray", lw=2) 63 | ax3.set_title("Cumulative causal impact") 64 | 65 | return fig, ax1, ax2, ax3 66 | 67 | def print_summary(self) -> None: 68 | """Print the summary of the inferred effect as a table.""" 69 | # Number of columns for each column of the table to be printed 70 | ncol1, ncol2, ncol3 = 25, 20, 20 71 | ci_level = self.summary["confidence_level"] 72 | 73 | # Summary statistics of the post-intervention observation 74 | actual = self.summary["actual"] 75 | f_actual = f"{'Actual': <{ncol1}}" f"{actual.average: ^{ncol2}.2f}" f"{actual.cumulative: ^{ncol3}.2f}\n" 76 | 77 | # Summary statistics of the post-intervention prediction 78 | pred = self.summary["pred"] 79 | pred_sd = self.summary["pred_sd"] 80 | pred_l = self.summary["pred_lower"] 81 | pred_r = self.summary["pred_upper"] 82 | f_pred = ( 83 | f"{'Prediction (s.d.)': <{ncol1}}" 84 | f'{f"{pred.average:.2f} ({pred_sd.average:.2f})": ^{ncol2}}' 85 | f'{f"{pred.cumulative:.2f} ({pred_sd.cumulative:.2f})": ^{ncol3}}' 86 | ) 87 | f_pred_ci = ( 88 | f'{f"{ci_level:.0%} CI": <{ncol1}}' 89 | f'{f"[{pred_l.average:.2f}, {pred_r.average:.2f}]": ^{ncol2}}' 90 | f'{f"[{pred_l.cumulative:.2f}, {pred_r.cumulative:.2f}]": ^{ncol3}}' 91 | ) 92 | 93 | # Summary statistics of the absolute post-invervention effect 94 | abs_e = self.summary["abs_effect"] 95 | abs_sd = self.summary["abs_effect_sd"] 96 | abs_l = self.summary["abs_effect_lower"] 97 | abs_r = self.summary["abs_effect_upper"] 98 | f_abs = ( 99 | f"{'Absolute effect (s.d.)': <{ncol1}}" 100 | f'{f"{abs_e.average:.2f} ({abs_sd.average:.2f})": ^{ncol2}}' 101 | f'{f"{abs_e.cumulative:.2f} ({abs_sd.cumulative:.2f})": ^{ncol3}}' 102 | ) 103 | f_abs_ci = ( 104 | f'{f"{ci_level:.0%} CI": <{ncol1}}' 105 | f'{f"[{abs_l.average:.2f}, 
{abs_r.average:.2f}]": ^{ncol2}}' 106 | f'{f"[{abs_l.cumulative:.2f}, {abs_r.cumulative:.2f}]": ^{ncol3}}' 107 | ) 108 | 109 | # Summary statistics of the relative post-intervention effect 110 | rel_e = self.summary["rel_effect"] 111 | rel_sd = self.summary["rel_effect_sd"] 112 | rel_l = self.summary["rel_effect_lower"] 113 | rel_r = self.summary["rel_effect_upper"] 114 | f_rel = ( 115 | f"{'Relative effect (s.d.)': <{ncol1}}" 116 | f'{f"{rel_e.average:.2%} ({rel_sd.average:.2%})": ^{ncol2}}' 117 | f'{f"{rel_e.cumulative:.2%} ({rel_sd.cumulative:.2%})": ^{ncol3}}' 118 | ) 119 | f_rel_ci = ( 120 | f'{f"{ci_level:.0%} CI": <{ncol1}}' 121 | f'{f"[{rel_l.average:.2%}, {rel_r.average:.2%}]": ^{ncol2}}' 122 | f'{f"[{rel_l.cumulative:.2%}, {rel_r.cumulative:.2%}]": ^{ncol3}}' 123 | ) 124 | 125 | # Format all statistics 126 | summary_stats = ( 127 | f"Posterior inference of the causal impact:\n" 128 | f"\n" 129 | f"{'': <{ncol1}}{'Average': ^{ncol2}}{'Cumulative': ^{ncol3}}\n" 130 | f"{f_actual}\n" 131 | f"{f_pred}\n{f_pred_ci}\n" 132 | f"\n" 133 | f"{f_abs}\n{f_abs_ci}\n" 134 | f"\n" 135 | f"{f_rel}\n{f_rel_ci}\n" 136 | f"\n" 137 | f"Posterior tail-area probability p: {self.summary['tail_prob']:.4f}\n" 138 | f"Posterior prob of a causal effect: {1-self.summary['tail_prob']:.2%}\n" 139 | ) 140 | print(summary_stats) 141 | 142 | 143 | def causal_impact( 144 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 145 | intervention_timepoint: int, 146 | obs_distribution: str = "Gaussian", 147 | covariates: Optional[Float[Array, "num_timesteps dim_obs"]] = None, 148 | sts_model: sts.StructuralTimeSeries = None, 149 | confidence_level: Float = 0.95, 150 | key: PRNGKey = jr.PRNGKey(0), 151 | num_samples: int = 200, 152 | ) -> CausalImpact: 153 | r"""Inferring the causal impact of an intervention on a time series via the structural 154 | time series (STS) model. 155 | 156 | Args: 157 | obs_time_series: observed time series. 158 | intervention_time_point: the time point when the intervention took place. 159 | obs_distribution: distribution family of the observation, can be either 'Gaussian' or 160 | 'Poisson' 161 | covariates: covariates of the regression component of the STS model. 162 | sts_model: an instance of StructuralTimeSeries, if not given, an STS model with a local 163 | linear latent component is used by default. If covariates is not None, a linear 164 | regression term will also be added to the default STS model. 165 | confidence_level: confidence level of the prediction interval. 166 | num_samples: number of samples used to estimate the prediction mean and prediction 167 | interval. 168 | 169 | Returns: 170 | An instance of CausalImpact. 
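    Example (a sketch following the README usage; the data shapes are illustrative):

        impact = causal_impact(obs_time_series, intervention_timepoint=70,
                               obs_distribution='Gaussian', covariates=covariates)
        impact.plot()
        impact.print_summary()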
171 | """ 172 | 173 | assert obs_distribution in ["Gaussian", "Poisson"] 174 | if sts_model is not None: 175 | assert obs_distribution == sts_model.obs_distribution 176 | 177 | prob_lower, prob_upper = 0.5 - confidence_level / 2.0, 0.5 + confidence_level / 2.0 178 | key1, key2, key3 = jr.split(key, 3) 179 | num_timesteps, dim_obs = obs_time_series.shape 180 | 181 | # Split the data into pre-intervention period and post-intervention period 182 | time_series_pre = obs_time_series[:intervention_timepoint] 183 | time_series_pos = obs_time_series[intervention_timepoint:] 184 | 185 | if covariates is not None: 186 | dim_covariates = covariates.shape[-1] 187 | covariates_pre = covariates[:intervention_timepoint] 188 | covariates_pos = covariates[intervention_timepoint:] 189 | else: 190 | covariates_pre = covariates_pos = None 191 | 192 | # Construct the default STS model with only one local linear trend latent component. 193 | if sts_model is None: 194 | local_linear_trend = sts.LocalLinearTrend() 195 | components = [local_linear_trend] 196 | # Add one linear regression component if covariates is not None. 197 | if covariates is not None: 198 | linear_regression = sts.LinearRegression(dim_covariates=dim_covariates) 199 | components.append(linear_regression) 200 | sts_model = sts.StructuralTimeSeries(components, obs_time_series, covariates, obs_distribution) 201 | 202 | # Fit the STS model, sample from the past and forecast. 203 | # Model fitting 204 | params_posterior_samples, _ = sts_model.fit_hmc(num_samples, time_series_pre, covariates=covariates_pre, key=key1) 205 | # Sample observations from the posterior predictive sample given paramters of the STS model. 206 | posterior_sample_means, posterior_samples = sts_model.posterior_sample( 207 | params_posterior_samples, time_series_pre, covariates_pre, key=key2 208 | ) 209 | # Forecast by sampling observations from the predictive distribution in the future. 
210 | forecast_means, forecast_samples = sts_model.forecast( 211 | params_posterior_samples, time_series_pre, time_series_pos.shape[0], 100, covariates_pre, covariates_pos, key3 212 | ) 213 | forecast_means = forecast_means.mean(axis=1) 214 | forecast_samples = forecast_samples.mean(axis=1) 215 | 216 | predict_means = jnp.concatenate((posterior_sample_means, forecast_means), axis=1).squeeze() 217 | predict_observations = jnp.concatenate((posterior_samples, forecast_samples), axis=1).squeeze() 218 | 219 | confidence_bounds = jnp.quantile(predict_observations, jnp.array([prob_lower, prob_upper]), axis=0) 220 | predict_point = predict_means.mean(axis=0) 221 | predict_interval_upper = confidence_bounds[0] 222 | predict_interval_lower = confidence_bounds[1] 223 | 224 | cum_predict_point = jnp.cumsum(predict_point) 225 | cum_confidence_bounds = jnp.quantile( 226 | predict_observations.cumsum(axis=1), jnp.array([prob_lower, prob_upper]), axis=0 227 | ) 228 | cum_predict_interval_upper = cum_confidence_bounds[0] 229 | cum_predict_interval_lower = cum_confidence_bounds[1] 230 | 231 | # Evaluate the causal impact 232 | impact_point = obs_time_series.squeeze() - predict_point 233 | impact_interval_lower = obs_time_series.squeeze() - predict_interval_upper 234 | impact_interval_upper = obs_time_series.squeeze() - predict_interval_lower 235 | 236 | cum_obs = jnp.cumsum(obs_time_series.squeeze()) 237 | cum_impact_point = cum_obs - cum_predict_point 238 | cum_impact_interval_lower = cum_obs - cum_predict_interval_upper 239 | cum_impact_interval_upper = cum_obs - cum_predict_interval_lower 240 | 241 | impact = { 242 | "pointwise": (impact_point, (impact_interval_lower, impact_interval_upper)), 243 | "cumulative": (cum_impact_point, (cum_impact_interval_lower, cum_impact_interval_upper)), 244 | } 245 | 246 | predict = {"pointwise": predict_point, "interval": confidence_bounds} 247 | 248 | summary = dict() 249 | summary["confidence_level"] = confidence_level 250 | Stats = namedtuple("Stats", ["average", "cumulative"]) 251 | 252 | # Summary statistics of the post-intervention observation 253 | summary["actual"] = Stats(average=time_series_pos.mean(), cumulative=time_series_pos.sum()) 254 | 255 | # Summary statistics of the post-intervention prediction 256 | summary["pred"] = Stats(average=forecast_means.mean(axis=0).mean(), cumulative=forecast_means.mean(axis=0).sum()) 257 | summary["pred_lower"] = Stats( 258 | average=jnp.quantile(forecast_samples.mean(axis=1), prob_lower), 259 | cumulative=jnp.quantile(forecast_samples.sum(axis=1), prob_lower), 260 | ) 261 | summary["pred_upper"] = Stats( 262 | average=jnp.quantile(forecast_samples.mean(axis=1), prob_upper), 263 | cumulative=jnp.quantile(forecast_samples.sum(axis=1), prob_upper), 264 | ) 265 | summary["pred_sd"] = Stats( 266 | average=jnp.std(forecast_samples.mean(axis=1)), cumulative=jnp.std(forecast_samples.sum(axis=1)) 267 | ) 268 | 269 | # Summary statistics of the absolute post-invervention effect 270 | effect_means = time_series_pos - forecast_means 271 | effects = time_series_pos - forecast_samples 272 | summary["abs_effect"] = Stats(average=effect_means.mean(axis=0).mean(), cumulative=effect_means.mean(axis=0).sum()) 273 | summary["abs_effect_lower"] = Stats( 274 | average=jnp.quantile(effects.mean(axis=1), prob_lower), cumulative=jnp.quantile(effects.sum(axis=1), prob_lower) 275 | ) 276 | summary["abs_effect_upper"] = Stats( 277 | average=jnp.quantile(effects.mean(axis=1), prob_upper), cumulative=jnp.quantile(effects.sum(axis=1), prob_upper) 278 | 
)
279 |     summary["abs_effect_sd"] = Stats(average=jnp.std(effects.mean(axis=1)), cumulative=jnp.std(effects.sum(axis=1)))
280 | 
281 |     # Summary statistics of the relative post-intervention effect
282 |     rel_effect_means_sum = effect_means.sum(axis=1) / forecast_means.sum(axis=1)
283 |     rel_effects_sum = effects.sum(axis=1) / forecast_samples.sum(axis=1)
284 |     summary["rel_effect"] = Stats(average=rel_effect_means_sum.mean(), cumulative=rel_effect_means_sum.mean())
285 |     summary["rel_effect_lower"] = Stats(
286 |         average=jnp.quantile(rel_effects_sum, prob_lower), cumulative=jnp.quantile(rel_effects_sum, prob_lower)
287 |     )
288 |     summary["rel_effect_upper"] = Stats(
289 |         average=jnp.quantile(rel_effects_sum, prob_upper), cumulative=jnp.quantile(rel_effects_sum, prob_upper)
290 |     )
291 |     summary["rel_effect_sd"] = Stats(average=jnp.std(rel_effects_sum), cumulative=jnp.std(rel_effects_sum))
292 | 
293 |     # Add one-sided tail-area probability of the overall impact
294 |     effects_sum = effects.sum(axis=1)
295 |     p_tail = float(1 + min((effects_sum >= 0).sum(), (effects_sum <= 0).sum())) / (1 + effects_sum.shape[0])
296 |     summary["tail_prob"] = p_tail
297 | 
298 |     return CausalImpact(sts_model, intervention_timepoint, predict, impact, summary, obs_time_series)
299 | 
--------------------------------------------------------------------------------
/sts_jax/structural_time_series/sts_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import List, Optional, Tuple, Union
3 | 
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | import optax
7 | import tensorflow_probability.substrates.jax.bijectors as tfb
8 | from dynamax.parameters import ParameterProperties
9 | from dynamax.types import PRNGKey, Scalar
10 | from dynamax.utils.bijectors import RealToPSDBijector
11 | from dynamax.utils.distributions import InverseWishart as IW
12 | from jax import jit, vmap
13 | from jax.tree_util import tree_leaves, tree_map
14 | from jaxtyping import Array, Float
15 | from tensorflow_probability.substrates.jax import distributions as tfd
16 | 
17 | from .learning import fit_hmc, fit_vi
18 | from .sts_components import ParamPropertiesSTS, ParamsSTS, STSComponent, STSRegression
19 | from .sts_ssm import StructuralTimeSeriesSSM
20 | 
21 | 
22 | class StructuralTimeSeries:
23 |     r"""The class of the Bayesian structural time series (STS) model.
24 | 
25 |     The STS model is a linear state space model with a specific structure. In particular,
26 |     the latent state $z_t$ is a composition of the states of all latent components:
27 | 
28 |     $$z_t = [c_{1, t}, c_{2, t}, ...]$$
29 | 
30 |     where $c_{i,t}$ is the state of latent component $c_i$ at time step $t$.
31 | 
32 |     The STS model takes the form:
33 | 
34 |     $$y_t = H_t z_t + u_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, \Sigma_t)$$
35 |     $$z_{t+1} = F_t z_t + R_t \eta_t, \qquad \eta_t \sim \mathcal{N}(0, Q_t)$$
36 | 
37 |     where
38 | 
39 |     * $H_t$: emission matrix, which sums up the contributions of all latent components.
40 |     * $u_t$: the contribution of the regression component.
41 |     * $F_t$: transition matrix of the latent dynamics.
42 |     * $R_t$: the selection matrix, whose columns are a subset of the columns of the identity matrix $I$; it converts
43 |     the non-singular covariance matrix into a (possibly singular) covariance matrix for
44 |     the latent state $z_t$.
45 |     * $Q_t$: non-singular covariance matrix of the latent state, so the dimension of $Q_t$
46 |     can be smaller than the dimension of $z_t$.
47 | 48 | The covariance matrix of the latent dynamics model takes the form $R Q R^T$, where $Q$ is 49 | a nonsingular matrix (block diagonal), and $R$ is the selecting matrix. For example, 50 | for an STS model for a 1-d time series with a local linear component and a (dummy) seasonal 51 | component with 4 seasons, $R$ and $Q$ take the form 52 | $$ 53 | Q = \begin{bmatrix} 54 | v_1 & 0 & 0 \\ 55 | 0 & v_2 & 0 \\ 56 | 0 & 0 & v_3 57 | \end{bmatrix}, 58 | \qquad 59 | R = \begin{bmatrix} 60 | 1 & 0 & 0 \\ 61 | 0 & 1 & 0 \\ 62 | 0 & 0 & 1 \\ 63 | 0 & 0 & 0 \\ 64 | 0 & 0 & 0 65 | \end{bmatrix} 66 | $$ 67 | 68 | where $v_1$ and $v_2$ are the variances of the 'level' part and the 'slope' part of the 69 | local linear component, and $v_3$ is the variance of the disturbance of the seasonal 70 | component. 71 | """ 72 | 73 | def __init__( 74 | self, 75 | components: List[Union[STSComponent, STSRegression]], 76 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 77 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 78 | obs_distribution: str = "Gaussian", 79 | obs_cov_prior: tfd.Distribution = None, 80 | obs_cov_constrainer: tfb.Bijector = None, 81 | constant_offset: bool = True, 82 | name: str = "sts_model", 83 | ) -> None: 84 | r""" 85 | Args: 86 | components: list of components of the STS model; each is an instance of 87 | STSComponent or STSRegression, and each component must have a unique name. 88 | obs_time_series: observed time series to be modeled. 89 | covariates: series of the covariates (if there is an STSRegression component). 90 | obs_distribution: the distribution family of the observed time series. 91 | Currently it can be either 'Gaussian' or 'Poisson'. 92 | obs_cov_prior: the prior distribution of the covariance matrix of the observation 93 | $y_t$ at each time step. This is only required when obs_distribution='Gaussian'. 94 | obs_cov_constrainer: a bijector whose inverse operator transforms the observation 95 | covariance matrix to unconstrained space, for the purpose of learning. 96 | constant_offset: If true, the observed time series will be centered before it is 97 | fitted. If obs_distribution='Poisson', log(obs_time_series) is centered. 98 | name: name of the STS model. 99 | """ 100 | names = [c.name for c in components] 101 | assert len(set(names)) == len(names), "Components should not share the same name." 102 | assert obs_distribution in [ 103 | "Gaussian", 104 | "Poisson", 105 | ], "The distribution of observations must be Gaussian or Poisson." 106 | 107 | self.name = name 108 | self.dim_obs = obs_time_series.shape[-1] 109 | self.obs_distribution = obs_distribution 110 | self.dim_covariate = covariates.shape[-1] if covariates is not None else 0 111 | 112 | # Convert the time series into the unconstrained space if obs_distribution is not Gaussian 113 | obs_unconstrained = self._unconstrain_obs(obs_time_series) 114 | self.offset = obs_unconstrained.mean(axis=0) if constant_offset else 0.0 115 | obs_centered = self._center_obs(obs_time_series) 116 | obs_centered_unconstrained = self._unconstrain_obs(obs_centered) 117 | 118 | # Initialize model parameters with the observed time series 119 | initial = obs_centered_unconstrained[0] 120 | obs_scale = jnp.std(jnp.abs(jnp.diff(obs_centered_unconstrained, axis=0)), axis=0) 121 | # If a regression component is included, remove the effect of the regression model 122 | # before initializing parameters of the time series.
123 | regression = None 124 | for c in components: 125 | if isinstance(c, STSRegression): 126 | assert len(components) > 1, "The STS model cannot contain only a regression component!" 127 | regression = c 128 | regression.initialize_params(covariates, obs_centered_unconstrained) 129 | residuals = obs_centered_unconstrained - regression.get_reg_value(regression.params, covariates) 130 | initial = residuals[0] 131 | obs_scale = jnp.std(jnp.abs(jnp.diff(residuals, axis=0)), axis=0) 132 | for c in components: 133 | if not isinstance(c, STSRegression): 134 | c.initialize_params(initial, obs_scale) 135 | 136 | # Aggregate components 137 | self.initial_distributions = OrderedDict() 138 | self.param_props = OrderedDict() 139 | self.param_priors = OrderedDict() 140 | self.params = OrderedDict() 141 | self.trans_mat_getters = OrderedDict() 142 | self.trans_cov_getters = OrderedDict() 143 | self.obs_mats = OrderedDict() 144 | self.cov_select_mats = OrderedDict() 145 | 146 | for c in components: 147 | if not isinstance(c, STSRegression): 148 | self.initial_distributions[c.name] = c.initial_distribution 149 | self.param_props[c.name] = c.param_props 150 | self.param_priors[c.name] = c.param_priors 151 | self.params[c.name] = c.params 152 | self.trans_mat_getters[c.name] = c.get_trans_mat 153 | self.trans_cov_getters[c.name] = c.get_trans_cov 154 | self.obs_mats[c.name] = c.obs_mat 155 | self.cov_select_mats[c.name] = c.cov_select_mat 156 | 157 | # Add parameters of the observation model if the observed time series is 158 | # normally distributed. 159 | if self.obs_distribution == "Gaussian": 160 | if obs_cov_prior is None: 161 | obs_cov_prior = IW(df=self.dim_obs, scale=1e-4 * obs_scale**2 * jnp.eye(self.dim_obs)) 162 | if obs_cov_constrainer is None: 163 | obs_cov_constrainer = RealToPSDBijector() 164 | obs_cov_props = ParameterProperties(trainable=True, constrainer=obs_cov_constrainer) 165 | obs_cov = obs_cov_prior.mode() 166 | self.param_props["obs_model"] = OrderedDict({"cov": obs_cov_props}) 167 | self.param_priors["obs_model"] = OrderedDict({"cov": obs_cov_prior}) 168 | self.params["obs_model"] = OrderedDict({"cov": obs_cov}) 169 | 170 | # Always put the regression term at the last position of the OrderedDict.
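# (StructuralTimeSeriesSSM relies on this ordering: it recovers the regression
# component by taking the last key of 'params'; see sts_ssm.py.)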
171 | if regression is not None: 172 | self.param_props[regression.name] = regression.param_props 173 | self.param_priors[regression.name] = regression.param_priors 174 | self.params[regression.name] = regression.params 175 | self.reg_func = regression.get_reg_value 176 | else: 177 | self.reg_func = None 178 | 179 | def as_ssm(self) -> StructuralTimeSeriesSSM: 180 | """Convert the STS model to a state space model.""" 181 | sts_ssm = StructuralTimeSeriesSSM( 182 | self.params, 183 | self.param_props, 184 | self.param_priors, 185 | self.trans_mat_getters, 186 | self.trans_cov_getters, 187 | self.obs_mats, 188 | self.cov_select_mats, 189 | self.initial_distributions, 190 | self.reg_func, 191 | self.obs_distribution, 192 | self.dim_covariate, 193 | ) 194 | return sts_ssm 195 | 196 | def sample( 197 | self, 198 | sts_params: ParamsSTS, 199 | num_timesteps: int, 200 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 201 | key: PRNGKey = jr.PRNGKey(0), 202 | ) -> Tuple[Float[Array, "num_timesteps dim_obs"], Float[Array, "num_timesteps dim_obs"]]: 203 | """Sample observed time series given model parameters.""" 204 | sts_params = self._ensure_param_has_batch_dim(sts_params) 205 | sts_ssm = self.as_ssm() 206 | 207 | @jit 208 | def single_sample(sts_param): 209 | sample_mean, sample_obs = sts_ssm.sample(sts_param, num_timesteps, covariates, key) 210 | return self._uncenter_obs(sample_mean), self._uncenter_obs(sample_obs) 211 | 212 | sample_means, sample_obs = vmap(single_sample)(sts_params) 213 | 214 | return sample_means, sample_obs 215 | 216 | def marginal_log_prob( 217 | self, 218 | sts_params: ParamsSTS, 219 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 220 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 221 | ) -> Scalar: 222 | """Compute marginal log likelihood of the observed time series given model parameters.""" 223 | obs_centered = self._center_obs(obs_time_series) 224 | sts_ssm = self.as_ssm() 225 | 226 | return sts_ssm.marginal_log_prob(sts_params, obs_centered, covariates) 227 | 228 | def posterior_sample( 229 | self, 230 | sts_params: ParamsSTS, 231 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 232 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 233 | key: PRNGKey = jr.PRNGKey(0), 234 | ) -> Tuple[Float[Array, "num_params num_timesteps dim_obs"], Float[Array, "num_params num_timesteps dim_obs"]]: 235 | """Sample latent states from their posterior given model parameters.""" 236 | sts_params = self._ensure_param_has_batch_dim(sts_params) 237 | obs_centered = self._center_obs(obs_time_series) 238 | sts_ssm = self.as_ssm() 239 | 240 | @jit 241 | def single_sample(sts_param): 242 | predictive_mean, predictive_obs = sts_ssm.posterior_sample(sts_param, obs_centered, covariates, key) 243 | return self._uncenter_obs(predictive_mean), self._uncenter_obs(predictive_obs) 244 | 245 | predictive_means, predictive_samples = vmap(single_sample)(sts_params) 246 | 247 | return predictive_means, predictive_samples 248 | 249 | def fit_mle( 250 | self, 251 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 252 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 253 | num_steps: int = 1000, 254 | initial_params: ParamsSTS = None, 255 | param_props: ParamPropertiesSTS = None, 256 | optimizer: optax.GradientTransformation = optax.adam(1e-1), 257 | key: PRNGKey = jr.PRNGKey(0), 258 | ) -> Tuple[ParamsSTS, Float[Array, " num_steps"]]: 259 | """Perform maximum likelihood
estimation of the parameters of the STS model.""" 260 | obs_centered = self._center_obs(obs_time_series) 261 | sts_ssm = self.as_ssm() 262 | curr_params = sts_ssm.params if initial_params is None else initial_params 263 | if param_props is None: 264 | param_props = sts_ssm.param_props 265 | 266 | optimal_params, losses = sts_ssm.fit_sgd( 267 | curr_params, 268 | param_props, 269 | obs_centered, 270 | num_epochs=num_steps, 271 | key=key, 272 | inputs=covariates, 273 | optimizer=optimizer, 274 | ) 275 | 276 | return optimal_params, losses 277 | 278 | def fit_vi( 279 | self, 280 | num_samples: int, 281 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 282 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 283 | initial_params: ParamsSTS = None, 284 | param_props: ParamPropertiesSTS = None, 285 | num_step_iters: int = 50, 286 | key: PRNGKey = jr.PRNGKey(0), 287 | ): 288 | """Sample parameters of the STS model from the ADVI posterior.""" 289 | sts_ssm = self.as_ssm() 290 | if initial_params is None: 291 | initial_params = sts_ssm.params 292 | if param_props is None: 293 | param_props = sts_ssm.param_props 294 | 295 | obs_centered = self._center_obs(obs_time_series) 296 | param_samps, losses = fit_vi( 297 | sts_ssm, initial_params, param_props, num_samples, obs_centered, covariates, key, num_step_iters 298 | ) 299 | elbo = -losses 300 | 301 | return param_samps, elbo 302 | 303 | def fit_hmc( 304 | self, 305 | num_samples: int, 306 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 307 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 308 | initial_params: ParamsSTS = None, 309 | param_props: ParamPropertiesSTS = None, 310 | warmup_steps: int = 100, 311 | verbose: bool = True, 312 | key: PRNGKey = jr.PRNGKey(0), 313 | ): 314 | """Sample parameters of the STS model from their posterior distributions with HMC (NUTS).""" 315 | sts_ssm = self.as_ssm() 316 | # Initialize via MLE if initial parameters are not given. 317 | if initial_params is None: 318 | initial_params, _losses = self.fit_mle(obs_time_series, covariates, num_steps=500) 319 | if param_props is None: 320 | param_props = self.param_props 321 | 322 | obs_centered = self._center_obs(obs_time_series) 323 | param_samps, param_log_probs = fit_hmc( 324 | sts_ssm, initial_params, param_props, num_samples, obs_centered, covariates, key, warmup_steps, verbose 325 | ) 326 | return param_samps, param_log_probs 327 | 328 | def decompose_by_component( 329 | self, 330 | sts_params: ParamsSTS, 331 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 332 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 333 | num_pos_samples: int = 100, 334 | key: PRNGKey = jr.PRNGKey(0), 335 | ) -> OrderedDict: 336 | r"""Decompose the STS model into components and return the means and variances 337 | of the marginal posterior of each component. 338 | 339 | The marginal posterior of each component is obtained by averaging the 340 | conditional posteriors of that component, computed with the Kalman smoother conditioned 341 | on sts_params. Each entry of sts_params is a posterior sample of the model parameters 342 | conditioned on obs_time_series. 343 | 344 | The marginal posterior mean and variance are computed using the formulas 345 | $$E[X] = E[E[X|Y]],$$ 346 | $$Var(X) = E[Var(X|Y)] + Var[E[X|Y]],$$ 347 | which hold for any random variables X and Y. 348 | 349 | Returns: 350 | component_dists: (OrderedDict) each item is a dict holding the posterior means 351 | ('pos_mean') and variances ('pos_cov') of one component.
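Example (an illustrative sketch, not taken from the package docs; it assumes a model 'sts_model' containing a LocalLinearTrend component with its default name 'local_linear_trend' and an observed series 'obs_time_series'):

    param_samples, _ = sts_model.fit_hmc(200, obs_time_series)
    component_dists = sts_model.decompose_by_component(param_samples, obs_time_series)
    trend_mean = component_dists['local_linear_trend']['pos_mean']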
352 | """ 353 | 354 | # Sample parameters from the posterior if they are not given. 355 | if sts_params is None: 356 | sts_params, _ = self.fit_hmc(num_pos_samples, obs_time_series, covariates, key=key)  # fit_hmc returns (samples, log_probs) 357 | 358 | sts_params = self._ensure_param_has_batch_dim(sts_params) 359 | obs_centered = self._center_obs(obs_time_series) 360 | 361 | @jit 362 | def single_decompose(sts_param): 363 | sts_ssm = self.as_ssm() 364 | return sts_ssm.component_posterior(sts_param, obs_centered, covariates) 365 | 366 | component_conditional_pos = vmap(single_decompose)(sts_params) 367 | 368 | component_dists = OrderedDict() 369 | # Obtain the marginal posterior 370 | for c, pos in component_conditional_pos.items(): 371 | means = pos["pos_mean"] 372 | covs = pos["pos_cov"] 373 | # Use the formula: E[X] = E[E[X|Y]] 374 | mean_series = means.mean(axis=0) 375 | # Use the formula: Var(X) = E[Var(X|Y)] + Var(E[X|Y]) 376 | cov_series = jnp.mean(covs, axis=0)[..., 0] + jnp.var(means, axis=0) 377 | component_dists[c] = {"pos_mean": mean_series, "pos_cov": cov_series} 378 | 379 | return component_dists 380 | 381 | def forecast( 382 | self, 383 | sts_params: ParamsSTS, 384 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 385 | num_forecast_steps: int, 386 | num_forecast_samples: int = 100, 387 | past_covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 388 | forecast_covariates: Optional[Float[Array, "num_forecast_steps dim_covariates"]] = None, 389 | key: PRNGKey = jr.PRNGKey(0), 390 | ) -> Tuple[ 391 | Float[Array, "num_params num_forecast_samples num_forecast_steps dim_obs"], 392 | Float[Array, "num_params num_forecast_samples num_forecast_steps dim_obs"], 393 | ]: 394 | """Forecast. 395 | 396 | Args: 397 | sts_params: parameters of the STS model, or a batch of STS parameters. 398 | obs_time_series: observed time series. 399 | num_forecast_steps: number of forecast steps. 400 | num_forecast_samples: number of samples for each STS parameter, used to compute 401 | summary statistics of the forecast, conditioned on the STS parameter. 402 | past_covariates: inputs of the regression component of the STS model, corresponding 403 | to the observed time series. 404 | forecast_covariates: inputs of the regression component of the STS model, 405 | used in forecasting. 406 | """ 407 | sts_params = self._ensure_param_has_batch_dim(sts_params) 408 | obs_centered = self._center_obs(obs_time_series) 409 | sts_ssm = self.as_ssm() 410 | 411 | @jit 412 | def single_forecast(sts_param): 413 | _forecast_mean, _forecast_obs = sts_ssm.forecast( 414 | sts_param, 415 | obs_centered, 416 | num_forecast_steps, 417 | num_forecast_samples, 418 | past_covariates, 419 | forecast_covariates, 420 | key, 421 | ) 422 | forecast_mean = vmap(self._uncenter_obs)(_forecast_mean) 423 | forecast_obs = vmap(self._uncenter_obs)(_forecast_obs) 424 | return forecast_mean, forecast_obs 425 | 426 | forecast_means, forecast_obs = vmap(single_forecast)(sts_params) 427 | 428 | return forecast_means, forecast_obs 429 | 430 | def _ensure_param_has_batch_dim(self, sts_params): 431 | """Turn parameters into a batch if only one parameter set is given.""" 432 | # All latent components except for 'LinearRegression' have transition covariances 433 | # and the linear regression has a coefficient matrix. 434 | # When the observation is Gaussian, the observation model also has a covariance matrix. 435 | # So here we assume that the largest dimension of parameters is 2.
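# For example, under a Gaussian observation model the covariance parameter has
# shape (dim_obs, dim_obs) -> ndim 2, whereas a batch of posterior samples of
# that parameter has shape (num_samples, dim_obs, dim_obs) -> ndim 3.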
436 | param_list = tree_leaves(sts_params) 437 | max_params_dim = max([len(x.shape) for x in param_list]) 438 | if max_params_dim > 2: 439 | return sts_params 440 | else: 441 | return tree_map(lambda x: jnp.expand_dims(x, 0), sts_params) 442 | 443 | def _constrain_obs(self, obs_time_series): 444 | if self.obs_distribution == "Gaussian": 445 | return obs_time_series 446 | elif self.obs_distribution == "Poisson": 447 | return jnp.exp(obs_time_series) 448 | 449 | def _unconstrain_obs(self, obs_time_series_constrained): 450 | if self.obs_distribution == "Gaussian": 451 | return obs_time_series_constrained 452 | elif self.obs_distribution == "Poisson": 453 | return jnp.log(obs_time_series_constrained) 454 | 455 | def _center_obs(self, obs_time_series): 456 | if self.obs_distribution == "Gaussian": 457 | obs_unconstrained = self._unconstrain_obs(obs_time_series) 458 | return self._constrain_obs(obs_unconstrained - self.offset) 459 | elif self.obs_distribution == "Poisson": 460 | return obs_time_series 461 | 462 | def _uncenter_obs(self, obs_time_series_centered): 463 | if self.obs_distribution == "Gaussian": 464 | obs_centered_unconstrained = self._unconstrain_obs(obs_time_series_centered) 465 | return self._constrain_obs(obs_centered_unconstrained + self.offset) 466 | elif self.obs_distribution == "Poisson": 467 | return obs_time_series_centered 468 | -------------------------------------------------------------------------------- /sts_jax/structural_time_series/sts_ssm.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import Callable, List, Optional, Tuple 4 | 5 | import jax.numpy as jnp 6 | import jax.random as jr 7 | import jax.scipy as jsp 8 | from dynamax.generalized_gaussian_ssm.inference import EKFIntegrals 9 | from dynamax.generalized_gaussian_ssm.inference import ( 10 | iterated_conditional_moments_gaussian_filter as cmgf_filt, 11 | ) 12 | from dynamax.generalized_gaussian_ssm.inference import ( 13 | iterated_conditional_moments_gaussian_smoother as cmgf_smooth, 14 | ) 15 | from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM 16 | from dynamax.linear_gaussian_ssm import ( 17 | ParamsLGSSM, 18 | ParamsLGSSMDynamics, 19 | ParamsLGSSMEmissions, 20 | ParamsLGSSMInitial, 21 | lgssm_filter, 22 | lgssm_posterior_sample, 23 | lgssm_smoother, 24 | ) 25 | from dynamax.ssm import SSM 26 | from dynamax.types import PRNGKey, Scalar 27 | from jax import lax, vmap 28 | from jaxtyping import Array, Float 29 | from tensorflow_probability.substrates.jax import distributions as tfd 30 | from tensorflow_probability.substrates.jax.distributions import ( 31 | MultivariateNormalFullCovariance as MVN, 32 | ) 33 | from tensorflow_probability.substrates.jax.distributions import Poisson as Pois 34 | 35 | from .sts_components import ParamPriorsSTS, ParamPropertiesSTS, ParamsSTS 36 | 37 | # from dynamax.generalized_gaussian_ssm import ( 38 | # ParamsGGSSM, 39 | # EKFIntegrals, 40 | # iterated_conditional_moments_gaussian_filter as cmgf_filt, 41 | # iterated_conditional_moments_gaussian_smoother as cmgf_smooth 42 | # ) 43 | 44 | 45 | class StructuralTimeSeriesSSM(SSM): 46 | """Formulate the structural time series (STS) model as a linear SSM, 47 | which always has a block-diagonal dynamics covariance matrix and fixed transition matrices.
48 | """ 49 | 50 | def __init__( 51 | self, 52 | params: ParamsSTS, 53 | param_props: ParamPropertiesSTS, 54 | param_priors: ParamPriorsSTS, 55 | trans_mat_getters: List, 56 | trans_cov_getters: List, 57 | obs_mats: List, 58 | cov_select_mats: List, 59 | initial_distributions: List, 60 | reg_func: Callable = None, 61 | obs_distribution: str = "Gaussian", 62 | dim_covariates: int = 0, 63 | ) -> None: 64 | """ 65 | Args: 66 | params: parameters of the STS model, an instance of OrderedDict; each item is 67 | the parameters of one component. 68 | param_props: properties of parameters of the STS model, having the same tree structure 69 | as 'params'; each leaf node is an instance of 'ParameterProperties' for the 70 | parameter in the corresponding leaf node of 'params'. 71 | param_priors: priors of parameters of the STS model, having the same tree structure 72 | as 'params'; each leaf node is a prior distribution for the parameter in the 73 | corresponding leaf node of 'params'. 74 | trans_mat_getters: list of transition_matrix_getters, one for each latent component. 75 | trans_cov_getters: list of nonsingular_transition_covariance_getters, one for each 76 | latent component. 77 | obs_mats: list of observation matrices, one for each latent component. 78 | cov_select_mats: list of transition_covariance_selecting matrices, one for each 79 | latent component. 80 | initial_distributions: list of initial distributions for the latent state, one for each 81 | latent component. 82 | reg_func: regression function of the regression component. 83 | obs_distribution: distribution family of the observation, can be 'Gaussian' or 84 | 'Poisson'. 85 | dim_covariates: dimension of the covariates. 86 | """ 87 | 88 | self.params = params 89 | self.param_props = param_props 90 | self.param_priors = param_priors 91 | 92 | self.trans_mat_getters = trans_mat_getters 93 | self.trans_cov_getters = trans_cov_getters 94 | self.component_obs_mats = obs_mats 95 | self.cov_select_mats = cov_select_mats 96 | 97 | self.latent_comp_names = cov_select_mats.keys() 98 | self.obs_distribution = obs_distribution 99 | 100 | # Combine means and covariances of the initial state. 101 | self.initial_mean = jnp.concatenate([init_pri.mode() for init_pri in initial_distributions.values()]) 102 | self.initial_cov = jsp.linalg.block_diag( 103 | *[init_pri.covariance() for init_pri in initial_distributions.values()] 104 | ) 105 | 106 | # Combine fixed observation matrices and the covariance selecting matrices. 107 | self.obs_mat = jnp.concatenate([*obs_mats.values()], axis=1) 108 | self.cov_select_mat = jsp.linalg.block_diag(*cov_select_mats.values()) 109 | 110 | # Dimensions of the observation and the latent state 111 | self.dim_obs, self.dim_state = self.obs_mat.shape 112 | # Rank of the latent state 113 | self.dim_comp = self.get_trans_cov(self.params, 0).shape[0] 114 | # Dimension of the covariates 115 | self.dim_covariates = dim_covariates 116 | 117 | # Pick out the regression component if there is one. 118 | if reg_func is not None: 119 | # Regression component is always the last component if there is one.
120 | self.reg_name = list(params.keys())[-1] 121 | self.regression = reg_func 122 | 123 | @property 124 | def emission_shape(self): 125 | return (self.dim_obs,) 126 | 127 | @property 128 | def inputs_shape(self): 129 | return (self.dim_covariates,) 130 | 131 | def log_prior(self, params: ParamsSTS) -> Scalar: 132 | """Log prior probability of parameters.""" 133 | lp = 0.0 134 | for c_name, c_priors in self.param_priors.items(): 135 | for p_name, p_pri in c_priors.items(): 136 | lp += p_pri.log_prob(params[c_name][p_name]) 137 | return lp 138 | 139 | def initial_distribution(self): 140 | """Distribution of the initial state of the SSM form of the STS model.""" 141 | return MVN(self.initial_mean, self.initial_cov) 142 | 143 | def transition_distribution(self, state): 144 | """This is a must-have method of SSM. 145 | Not implemented here because tfp.distributions does not support a multivariate normal 146 | distribution with a singular covariance matrix. 147 | """ 148 | raise NotImplementedError 149 | 150 | def emission_distribution( 151 | self, state: Float[Array, " dim_state"], obs_input: Float[Array, " dim_obs"] 152 | ) -> tfd.Distribution: 153 | """Emission distribution of the SSM at one time step. 154 | The argument 'obs_input' is not the covariate of the STS model; it is either an array 155 | of 0's or the output of the regression component at the current time step, which will 156 | be added directly to the observation model. 157 | """ 158 | if self.obs_distribution == "Gaussian": 159 | return MVN(self.obs_mat @ state + obs_input, self.params["obs_model"]["cov"]) 160 | elif self.obs_distribution == "Poisson": 161 | unc_rates = self.obs_mat @ state + obs_input 162 | return Pois(rate=self._emission_constrainer(unc_rates)) 163 | 164 | def sample( 165 | self, 166 | params: ParamsSTS, 167 | num_timesteps: int, 168 | initial_state: Optional[Float[Array, " dim_states"]] = None, 169 | initial_timestep: int = 0, 170 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 171 | key: PRNGKey = jr.PRNGKey(0), 172 | ) -> Tuple[Float[Array, "num_timesteps dim_obs"], Float[Array, "num_timesteps dim_obs"]]: 173 | """Sample a sequence of latent states and emissions, given parameters of the STS model. 174 | 175 | Args: 176 | initial_state: the latent state at the first sampled time step. 177 | initial_timestep: starting time step of sampling; only used when the transition 178 | matrix is time-dependent. 179 | """ 180 | 181 | if covariates is not None: 182 | # If there is a regression component, set the inputs of the emission model 183 | # of the SSM as the fitted value of the regression component. 184 | inputs = self.regression(params[self.reg_name], covariates) 185 | else: 186 | inputs = jnp.zeros((num_timesteps, self.dim_obs)) 187 | 188 | key1, key2 = jr.split(key, 2) 189 | if initial_state is None: 190 | initial_state = self.initial_distribution().sample(seed=key1) 191 | 192 | get_trans_mat = partial(self.get_trans_mat, params) 193 | get_trans_cov = partial(self.get_trans_cov, params) 194 | 195 | def _step(curr_state, args): 196 | key, input, t = args 197 | key1, key2 = jr.split(key, 2) 198 | # The latent state of the next time point. 199 | next_state = get_trans_mat(t) @ curr_state + self.cov_select_mat @ MVN( 200 | jnp.zeros(self.dim_comp), get_trans_cov(t) 201 | ).sample(seed=key1) 202 | curr_obs = self.emission_distribution(curr_state, input).sample(seed=key2) 203 | return next_state, (curr_state, curr_obs) 204 | 205 | # Sample the following emissions and states.
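# lax.scan carries the latent state forward one step at a time: each step draws
# the disturbance in the low-rank space of dimension dim_comp and lifts it into
# the full state space through cov_select_mat (see _step above).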
206 | keys = jr.split(key2, num_timesteps) 207 | _, (states, sample_obs) = lax.scan( 208 | _step, initial_state, (keys, inputs, initial_timestep + jnp.arange(num_timesteps)) 209 | ) 210 | sample_mean = self._emission_constrainer(states @ self.obs_mat.T + inputs) 211 | return sample_mean, sample_obs 212 | 213 | def marginal_log_prob( 214 | self, 215 | params: ParamsSTS, 216 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 217 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 218 | ) -> Scalar: 219 | """Compute marginal log likelihood of the observed time series given model parameters.""" 220 | if covariates is not None: 221 | # If there is a regression component, set the inputs of the emission model 222 | # of the SSM as the fitted value of the regression component. 223 | inputs = self.regression(params[self.reg_name], covariates) 224 | else: 225 | inputs = jnp.zeros(obs_time_series.shape) 226 | 227 | # Convert the model to SSM and perform filtering. 228 | ssm_params = self._to_ssm_params(params) 229 | filtered_posterior = self._ssm_filter(params=ssm_params, emissions=obs_time_series, inputs=inputs) 230 | return filtered_posterior.marginal_loglik 231 | 232 | def posterior_sample( 233 | self, 234 | params: ParamsSTS, 235 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 236 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 237 | key: PRNGKey = jr.PRNGKey(0), 238 | ) -> Tuple[Float[Array, "num_timesteps dim_obs"], Float[Array, "num_timesteps dim_obs"]]: 239 | """Sample latent states from the posterior distribution, as well as the predictive 240 | observations, given model parameters. 241 | """ 242 | if covariates is not None: 243 | # If there is a regression component, set the inputs of the emission model 244 | # of the SSM as the fitted value of the regression component. 245 | inputs = self.regression(params[self.reg_name], covariates) 246 | else: 247 | inputs = jnp.zeros(obs_time_series.shape) 248 | 249 | # Convert the STS model to SSM. 250 | ssm_params = self._to_ssm_params(params) 251 | 252 | # Sample latent state. 253 | key1, key2 = jr.split(key, 2) 254 | states = self._ssm_posterior_sample(ssm_params, obs_time_series, inputs, key1) 255 | 256 | # Sample predictive observations conditioned on posterior samples of latent states. 257 | def obs_sampler(state, input, key): 258 | return self.emission_distribution(state, input).sample(seed=key) 259 | 260 | keys = jr.split(key2, obs_time_series.shape[0]) 261 | 262 | predictive_obs = vmap(obs_sampler)(states, inputs, keys) 263 | predictive_mean = self._emission_constrainer(states @ self.obs_mat.T + inputs) 264 | return predictive_mean, predictive_obs 265 | 266 | def component_posterior( 267 | self, 268 | params: ParamsSTS, 269 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 270 | covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 271 | ) -> OrderedDict: 272 | """Decompose the STS model into components and return the means and variances 273 | of the marginal posterior of each component. 274 | """ 275 | component_pos = OrderedDict() 276 | if covariates is not None: 277 | # If there is a regression component, set the inputs of the emission model 278 | # of the SSM as the fitted value of the regression component. 279 | inputs = self.regression(params[self.reg_name], covariates) 280 | # Add the component corresponding to the regression component, which has no variance. 
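# (The regression output enters the emission model deterministically, so its
# conditional posterior variance is exactly zero.)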
281 | component_pos[self.reg_name] = { 282 | "pos_mean": inputs, 283 | "pos_cov": jnp.zeros((*obs_time_series.shape, self.dim_obs)), 284 | } 285 | else: 286 | inputs = jnp.zeros(obs_time_series.shape) 287 | 288 | # Convert the STS model to SSM. 289 | ssm_params = self._to_ssm_params(params) 290 | 291 | # Infer the posterior of the joint SSM model. 292 | pos = self._ssm_smoother(ssm_params, obs_time_series, inputs) 293 | mu_pos = pos.smoothed_means 294 | var_pos = pos.smoothed_covariances 295 | 296 | # Decompose by latent component. 297 | _loc = 0 298 | for c, obs_mat in self.component_obs_mats.items(): 299 | # Extract posterior mean and covariances of each component from the latent state. 300 | c_dim = obs_mat.shape[1] 301 | c_mean = mu_pos[:, _loc : _loc + c_dim] 302 | c_cov = var_pos[:, _loc : _loc + c_dim, _loc : _loc + c_dim] 303 | # Posterior emission of the single component. 304 | c_obs_mean_unc = vmap(jnp.matmul, (None, 0))(obs_mat, c_mean) 305 | c_obs_mean = self._emission_constrainer(c_obs_mean_unc) 306 | if self.obs_distribution == "Gaussian": 307 | c_obs_cov = vmap(lambda s, m: m @ s @ m.T, (0, None))(c_cov, obs_mat) 308 | elif self.obs_distribution == "Poisson": 309 | # Set the covariance to be 0 if the distribution of the observation is Poisson. 310 | c_obs_cov = jnp.zeros((*obs_time_series.shape, self.dim_obs)) 311 | component_pos[c] = {"pos_mean": c_obs_mean, "pos_cov": c_obs_cov} 312 | _loc += c_dim 313 | return component_pos 314 | 315 | def forecast( 316 | self, 317 | params: ParamsSTS, 318 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 319 | num_forecast_steps: int, 320 | num_forecast_samples: int = 100, 321 | past_covariates: Optional[Float[Array, "num_timesteps dim_covariates"]] = None, 322 | forecast_covariates: Optional[Float[Array, "num_forecast_steps dim_covariates"]] = None, 323 | key: PRNGKey = jr.PRNGKey(0), 324 | ) -> Tuple[ 325 | Float[Array, "num_forecast_samples num_forecast_steps dim_obs"], 326 | Float[Array, "num_forecast_samples num_forecast_steps dim_obs"], 327 | ]: 328 | """Forecast the time series.""" 329 | if forecast_covariates is not None: 330 | # If there is a regression component, set the inputs of the emission model 331 | # of the SSM as the fitted value of the regression component. 332 | past_inputs = self.regression(params[self.reg_name], past_covariates) 333 | else: 334 | past_inputs = jnp.zeros(obs_time_series.shape) 335 | 336 | # Convert the STS model to SSM. 337 | ssm_params = self._to_ssm_params(params) 338 | get_trans_mat = partial(self.get_trans_mat, params) 339 | get_trans_cov = partial(self.get_trans_cov, params) 340 | 341 | # Filter the observed time series to initialize the forecast. 342 | filtered_posterior = self._ssm_filter(params=ssm_params, emissions=obs_time_series, inputs=past_inputs) 343 | filtered_mean = filtered_posterior.filtered_means 344 | filtered_cov = filtered_posterior.filtered_covariances 345 | 346 | # The first time step of the forecast. 347 | t0 = obs_time_series.shape[0] - 1 348 | initial_state_mean = get_trans_mat(t0) @ filtered_mean[-1] 349 | initial_state_cov = ( 350 | get_trans_mat(t0) @ filtered_cov[-1] @ get_trans_mat(t0).T 351 | + self.cov_select_mat @ get_trans_cov(t0) @ self.cov_select_mat.T 352 | ) 353 | initial_states = MVN(initial_state_mean, initial_state_cov).sample(num_forecast_samples, seed=key) 354 | 355 | # Forecast by sampling from the STS model conditioned on the parameters and initialized 356 | # using the filtered posterior.
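# Each forecast trajectory starts from a draw of the one-step-ahead predictive
# distribution z_{T+1} ~ N(F_T m_T, F_T P_T F_T^T + R Q_T R^T), built above from
# the last filtered mean m_T and covariance P_T.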
357 | def single_forecast(initial_state, key): 358 | return self.sample(params, num_forecast_steps, initial_state, t0, forecast_covariates, key) 359 | 360 | forecast_mean, forecast_obs = vmap(single_forecast)(initial_states, jr.split(key, num_forecast_samples)) 361 | return forecast_mean, forecast_obs 362 | 363 | def one_step_predict(self, params, obs_time_series, covariates=None): 364 | """One-step-ahead prediction. 365 | This is a by-product of the Kalman filter. 366 | A general method for one-step-ahead prediction is to be added to the class 367 | dynamax.LinearGaussianSSM. 368 | """ 369 | raise NotImplementedError 370 | 371 | def get_trans_mat(self, params: ParamsSTS, t: int) -> Float[Array, "dim_state dim_state"]: 372 | """Evaluate the transition matrix of the latent state at time step t, 373 | conditioned on parameters of the model. 374 | """ 375 | trans_mat = [] 376 | for c_name in self.latent_comp_names: 377 | # Obtain the transition matrix of each single latent component. 378 | trans_getter = self.trans_mat_getters[c_name] 379 | c_trans_mat = trans_getter(params[c_name], t) 380 | trans_mat.append(c_trans_mat) 381 | return jsp.linalg.block_diag(*trans_mat) 382 | 383 | def get_trans_cov(self, params: ParamsSTS, t: int) -> Float[Array, "order_state order_state"]: 384 | """Evaluate the covariance of the latent dynamics at time step t, 385 | conditioned on parameters of the model. 386 | """ 387 | trans_cov = [] 388 | for c_name in self.latent_comp_names: 389 | # Obtain the covariance of each single latent component. 390 | cov_getter = self.trans_cov_getters[c_name] 391 | c_trans_cov = cov_getter(params[c_name], t) 392 | trans_cov.append(c_trans_cov) 393 | return jsp.linalg.block_diag(*trans_cov) 394 | 395 | def _to_ssm_params(self, params): 396 | """Convert the STS model into the form of the corresponding SSM model.""" 397 | get_trans_mat = partial(self.get_trans_mat, params) 398 | 399 | def get_sparse_cov(t): 400 | return self.cov_select_mat @ self.get_trans_cov(params, t) @ self.cov_select_mat.T 401 | 402 | if self.obs_distribution == "Gaussian": 403 | return ParamsLGSSM( 404 | initial=ParamsLGSSMInitial(mean=self.initial_mean, cov=self.initial_cov), 405 | dynamics=ParamsLGSSMDynamics( 406 | weights=get_trans_mat, 407 | bias=jnp.zeros(self.dim_state), 408 | input_weights=jnp.zeros((self.dim_state, 1)), 409 | cov=get_sparse_cov, 410 | ), 411 | emissions=ParamsLGSSMEmissions( 412 | weights=self.obs_mat, 413 | bias=jnp.zeros(self.dim_obs), 414 | input_weights=jnp.eye(self.dim_obs), 415 | cov=params["obs_model"]["cov"], 416 | ), 417 | ) 418 | elif self.obs_distribution == "Poisson": 419 | # The current formulation of the dynamics function cannot depend on t. 420 | trans_mat = get_trans_mat(0) 421 | sparse_cov = get_sparse_cov(0) 422 | return ParamsGGSSM( 423 | initial_mean=self.initial_mean, 424 | initial_covariance=self.initial_cov, 425 | dynamics_function=lambda z, u: trans_mat @ z, 426 | dynamics_covariance=sparse_cov, 427 | emission_mean_function=lambda z, u: self._emission_constrainer(self.obs_mat @ z + u), 428 | emission_cov_function=lambda z, u: jnp.diag(self._emission_constrainer(self.obs_mat @ z)), 429 | emission_dist=lambda mu, Sigma: Pois(rate=mu), 430 | ) 431 | 432 | def _ssm_filter(self, params, emissions, inputs): 433 | """The filter of the corresponding SSM model.""" 434 | if self.obs_distribution == "Gaussian": 435 | return lgssm_filter(params=params, emissions=emissions, inputs=inputs) 436 | elif self.obs_distribution == "Poisson": 437 | return cmgf_filt( 438 |
model_params=params, inf_params=EKFIntegrals(), emissions=emissions, inputs=inputs, num_iter=2 439 | ) 440 | 441 | def _ssm_smoother(self, params, emissions, inputs): 442 | """The smoother of the corresponding SSM model.""" 443 | if self.obs_distribution == "Gaussian": 444 | return lgssm_smoother(params=params, emissions=emissions, inputs=inputs) 445 | elif self.obs_distribution == "Poisson": 446 | return cmgf_smooth(params=params, inf_params=EKFIntegrals(), emissions=emissions, inputs=inputs, num_iter=2) 447 | 448 | def _ssm_posterior_sample(self, ssm_params, obs_time_series, inputs, key): 449 | """The posterior sampler of the corresponding SSM model.""" 450 | if self.obs_distribution == "Gaussian": 451 | return lgssm_posterior_sample(key=key, params=ssm_params, emissions=obs_time_series, inputs=inputs) 452 | elif self.obs_distribution == "Poisson": 453 | # Currently the posterior_sample for the STS model with Poisson likelihood 454 | # simply returns the filtered means. 455 | return self._ssm_filter(ssm_params, obs_time_series, inputs).filtered_means 456 | 457 | def _emission_constrainer(self, emission): 458 | """Transform the unconstrained emission into the (possibly) constrained space.""" 459 | if self.obs_distribution == "Gaussian": 460 | return emission 461 | elif self.obs_distribution == "Poisson": 462 | return jnp.exp(emission) 463 | -------------------------------------------------------------------------------- /sts_jax/structural_time_series/sts_components.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections import OrderedDict 3 | 4 | import jax.numpy as jnp 5 | import jax.scipy as jsp 6 | import tensorflow_probability.substrates.jax.bijectors as tfb 7 | from dynamax.parameters import ParameterProperties 8 | from dynamax.utils.bijectors import RealToPSDBijector 9 | from dynamax.utils.distributions import InverseWishart as IW 10 | from dynamax.utils.distributions import MatrixNormalPrecision as MNP 11 | from jax import lax 12 | from jaxtyping import Array, Float 13 | from tensorflow_probability.substrates.jax import distributions as tfd 14 | from tensorflow_probability.substrates.jax.distributions import ( 15 | MultivariateNormalDiag as MVNDiag, 16 | ) 17 | from tensorflow_probability.substrates.jax.distributions import ( 18 | MultivariateNormalFullCovariance as MVN, 19 | ) 20 | from tensorflow_probability.substrates.jax.distributions import Uniform 21 | 22 | 23 | class ParamsSTSComponent(OrderedDict): 24 | """A :class: 'OrderedDict' with each item being an instance of :class: 'jax.DeviceArray'.""" 25 | 26 | pass 27 | 28 | 29 | class ParamsSTS(OrderedDict): 30 | """A :class: 'OrderedDict' with each item being an instance of :class: 'OrderedDict'.""" 31 | 32 | pass 33 | 34 | 35 | class ParamPropertiesSTS(OrderedDict): 36 | """A :class: 'OrderedDict' with each item being an instance of :class: 'OrderedDict', 37 | having the same pytree structure as 'ParamsSTS'. 38 | """ 39 | 40 | pass 41 | 42 | 43 | class ParamPriorsSTS(OrderedDict): 44 | """A :class: 'OrderedDict' with each item being an instance of :class: 'OrderedDict', 45 | having the same pytree structure as 'ParamsSTS'.
46 | """ 47 | 48 | pass 49 | 50 | 51 | # helper function 52 | def _initial_statistics(initial_prior, initial_mean, initial_cov): 53 | if initial_prior is not None: 54 | return initial_prior.mean(), initial_prior.covariance() 55 | else: 56 | return initial_mean, initial_cov 57 | 58 | 59 | def _set_prior(input_prior, default_prior): 60 | return default_prior if input_prior is None else input_prior 61 | 62 | 63 | ######################### 64 | # Abstract Components # 65 | ######################### 66 | 67 | 68 | class STSComponent(ABC): 69 | r"""A base class for latent components of structural time series (STS) models. 70 | 71 | **Abstract Methods** 72 | 73 | A latent component of the STS model that inherits from 'STSComponent' must implement 74 | a few key functions and properties: 75 | 76 | * :meth: 'initialize_params' initializes parameters of the latent component, 77 | given the initial value and scale of steps of the observed time series. 78 | * :meth: 'get_trans_mat' returns the transition matrix, $F[t]$, of the latent component 79 | at time step $t$. 80 | * :meth: 'get_trans_cov' returns the nonsingular covariance matrix $Q[t]$ of the latent 81 | component at time step $t$. 82 | * :attr: 'obs_mat' returns the observation (emission) matrix $H$ for the latent component. 83 | * :attr: 'cov_select_mat' returns the selecting matrix $R$ that expands the nonsingular 84 | covariance matrix $Q[t]$ in each time step into a (possibly singular) covariance 85 | matrix of shape (dim_state, dim_state). 86 | * :attr: 'name' returns the unique name of the latent component. 87 | * :attr: 'dim_obs' returns the dimension of the observation in each step of the observed 88 | time series. 89 | * :attr: 'initial_distribution' returns the distribution of the initial latent 90 | state of the component, which is an instance of the class 91 | MultivariateNormalFullCovariance 92 | from tensorflow_probability.substrates.jax.distributions. 93 | * :attr: 'params' returns parameters of the component, which is an instance of OrderedDict, 94 | forming a JAX pytree. 95 | * :attr: 'param_props' returns parameter properties of each item in 'params'. 96 | param_props has the same pytree structure as 'params', and each leaf is an instance 97 | of dynamax.parameters.ParameterProperties, which specifies the constrainer of 98 | each parameter and whether that parameter is trainable. 99 | * :attr: 'param_priors' returns the prior distribution of each item in 'params'. 100 | param_priors has the same pytree structure as 'params', and each leaf is an instance 101 | of tfd.Distribution. 102 | """ 103 | 104 | def __init__(self, name: str, dim_obs: int = 1) -> None: 105 | self.name = name 106 | self.dim_obs = dim_obs 107 | self.initial_distribution = None 108 | 109 | self.params = OrderedDict() 110 | self.param_props = OrderedDict() 111 | self.param_priors = OrderedDict() 112 | 113 | @abstractmethod 114 | def initialize_params(self, obs_initial: Float[Array, " dim_obs"], obs_scale: Float[Array, " dim_obs"]) -> None: 115 | r"""Initialize parameters of the component given statistics of the observed time series. 116 | 117 | Args: 118 | obs_initial: the first observation in the observed time series $z_0$. 119 | obs_scale: vector of standard deviations of each dimension of the observed time series. 120 | 121 | Returns: 122 | No returns. Update self.params and self.initial_distribution directly.
123 | """ 124 | raise NotImplementedError 125 | 126 | @abstractmethod 127 | def get_trans_mat(self, params: ParamsSTSComponent, t: int) -> Float[Array, "dim_state dim_state"]: 128 | r"""Compute the transition matrix, $F[t]$, of the latent component at time step $t$. 129 | 130 | Args: 131 | params: parameters of the latent component, having the same tree structure as 132 | self.params. 133 | t: time point at which the transition matrix is to be evaluated. 134 | 135 | Returns: 136 | transition matrix, $F[t]$, of the latent component at time step $t$ 137 | """ 138 | raise NotImplementedError 139 | 140 | @abstractmethod 141 | def get_trans_cov(self, params: ParamsSTSComponent, t: int) -> Float[Array, "rank_state rank_state"]: 142 | r"""Compute the nonsingular covariance matrix, $Q[t]$, of the latent component at 143 | time step $t$. 144 | 145 | Args: 146 | params: parameters of the latent component, having the same tree structure as 147 | self.params. 148 | t: time point at which the covariance matrix is to be evaluated. 149 | 150 | Returns: 151 | nonsingular covariance matrix, $Q[t]$, of the latent component at time step $t$ 152 | """ 153 | raise NotImplementedError 154 | 155 | @property 156 | @abstractmethod 157 | def obs_mat(self) -> Float[Array, "dim_obs dim_state"]: 158 | r"""Returns the observation (emission) matrix $H$ for the latent component.""" 159 | raise NotImplementedError 160 | 161 | @property 162 | @abstractmethod 163 | def cov_select_mat(self) -> Float[Array, "dim_state rank_state"]: 164 | r"""Returns the selecting matrix $R$ that expands the nonsingular covariance matrix 165 | $Q[t]$ in each time step into a (possibly singular) covariance matrix of shape 166 | (dim_state, dim_state). 167 | 168 | Returns: 169 | selecting matrix $R$ of shape (dim_state, rank_state) 170 | """ 171 | raise NotImplementedError 172 | 173 | 174 | class STSRegression(ABC): 175 | r"""A base class for regression components of structural time series (STS) models. 176 | 177 | The regression component is not treated as a latent component of the STS model. 178 | Instead, the value of the regression model in each time step is added to the observation 179 | model without adding random noise. So the regression model need not be linear; 180 | any model is valid as long as the output of the model has the same dimension 181 | as the observed time series. 182 | 183 | **Abstract Methods** 184 | 185 | Models that inherit from `STSRegression` must implement a few key functions and properties: 186 | 187 | * :meth: 'initialize_params' initializes parameters of the regression model, 188 | given covariates and the observed time series. 189 | * :attr: 'name' returns the unique name of the regression component. 190 | * :attr: 'dim_obs' returns the dimension of the observation in each step of the observed 191 | time series. This equals the dimension of the output of the regression model. 192 | * :attr: 'params' returns parameters of the regression function, which is an instance of 193 | OrderedDict, forming a JAX pytree. 194 | * :attr: 'param_props' returns parameter properties of each item in 'params'. 195 | param_props has the same pytree structure as 'params', and each leaf is an instance 196 | of dynamax.parameters.ParameterProperties, which specifies the constrainer of 197 | each parameter and whether that parameter is trainable. 198 | * :attr: 'param_priors' returns the prior distribution of each item in 'params'.
199 | param_priors has the same pytree structure as 'params', and each leaf is an instance 200 | of tfd.Distribution. 201 | """ 202 | 203 | def __init__(self, name: str, dim_obs: int = 1) -> None: 204 | self.name = name 205 | self.dim_obs = dim_obs 206 | 207 | self.params = OrderedDict() 208 | self.param_props = OrderedDict() 209 | self.param_priors = OrderedDict() 210 | 211 | @abstractmethod 212 | def initialize_params( 213 | self, 214 | covariates: Float[Array, "num_timesteps dim_covariates"], 215 | obs_time_series: Float[Array, "num_timesteps dim_obs"], 216 | ) -> None: 217 | r"""Initialize parameters of the regression model by minimizing a certain loss function, 218 | given the series of covariates and the observed time series. 219 | 220 | Args: 221 | covariates: series of covariates of the regression function. 222 | obs_time_series: observed time series. 223 | 224 | Returns: 225 | No returns. Update self.params directly. 226 | """ 227 | raise NotImplementedError 228 | 229 | @abstractmethod 230 | def get_reg_value( 231 | self, params: ParamsSTSComponent, covariates: Float[Array, "num_timesteps dim_covariates"] 232 | ) -> Float[Array, "num_timesteps dim_obs"]: 233 | r"""Returns the sequence of values of the regression model evaluated at the 234 | given parameters and the sequence of covariates. 235 | 236 | Args: 237 | params: parameters at which the regression model is to be evaluated. 238 | covariates: sequence of covariates at which the regression model is to be evaluated. 239 | 240 | Returns: 241 | values: the sequence of values of the regression model. 242 | """ 243 | raise NotImplementedError 244 | 245 | 246 | ######################### 247 | # Concrete Components # 248 | ######################### 249 | 250 | 251 | class LocalLinearTrend(STSComponent): 252 | r"""The local linear trend component of the structural time series (STS) model. 253 | 254 | The latent state has two parts $[level, slope]$, having dimension 2 * dim_obs. 255 | The dynamics are: 256 | $$level[t+1] = level[t] + slope[t] + \mathcal{N}(0, cov_level)$$ 257 | $$slope[t+1] = slope[t] + \mathcal{N}(0, cov_slope)$$ 258 | 259 | In the case $dim_obs = 1$, the transition matrix $F$ and the observation matrix $H$ are 260 | $$ 261 | F = \begin{bmatrix} 262 | 1 & 1 \\ 263 | 0 & 1 264 | \end{bmatrix}, 265 | \qquad 266 | H = [ 1, 0 ].
267 | $$ 268 | """ 269 | 270 | def __init__( 271 | self, 272 | dim_obs: int = 1, 273 | name: str = "local_linear_trend", 274 | level_cov_prior: tfd.Distribution = None, 275 | slope_cov_prior: tfd.Distribution = None, 276 | initial_level_prior: tfd.MultivariateNormalFullCovariance = None, 277 | initial_slope_prior: tfd.MultivariateNormalFullCovariance = None, 278 | cov_constrainer: tfb.Bijector = None, 279 | ) -> None: 280 | super().__init__(name=name, dim_obs=dim_obs) 281 | self.level_cov_pri = level_cov_prior 282 | self.slope_cov_pri = slope_cov_prior 283 | self.initial_level_pri = initial_level_prior 284 | self.initial_slope_pri = initial_slope_prior 285 | self.cov_constrain = RealToPSDBijector() if cov_constrainer is None else cov_constrainer 286 | 287 | self.initial_distribution = MVN(jnp.zeros(2 * dim_obs), jnp.eye(2 * dim_obs)) 288 | 289 | self.param_props["cov_level"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain) 290 | self.param_priors["cov_level"] = _set_prior(level_cov_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs))) 291 | self.params["cov_level"] = self.param_priors["cov_level"].mode() 292 | 293 | self.param_props["cov_slope"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain) 294 | self.param_priors["cov_slope"] = _set_prior(slope_cov_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs))) 295 | self.params["cov_slope"] = self.param_priors["cov_slope"].mode() 296 | 297 | # The local linear trend component has a fixed transition matrix. 298 | self._trans_mat = jnp.kron(jnp.array([[1, 1], [0, 1]]), jnp.eye(dim_obs)) 299 | 300 | # Fixed observation matrix. 301 | self._obs_mat = jnp.kron(jnp.array([1, 0]), jnp.eye(dim_obs)) 302 | 303 | # Covariance selection matrix. 304 | self._cov_select_mat = jnp.eye(2 * dim_obs) 305 | 306 | def initialize_params(self, obs_initial: Float[Array, " dim_obs"], obs_scale: Float[Array, " dim_obs"]) -> None: 307 | dim_obs = len(obs_initial) 308 | 309 | # Initialize the distribution of the initial state. 310 | initial_level_mean, initial_level_cov = _initial_statistics( 311 | self.initial_level_pri, obs_initial, jnp.diag(obs_scale**2) 312 | ) 313 | initial_slope_mean, initial_slope_cov = _initial_statistics( 314 | self.initial_slope_pri, jnp.zeros(dim_obs), jnp.diag(obs_scale) 315 | ) 316 | initial_mean = jnp.concatenate((initial_level_mean, initial_slope_mean)) 317 | initial_cov = jsp.linalg.block_diag(initial_level_cov, initial_slope_cov) 318 | # initial_cov = jnp.kron(jnp.eye(2), jnp.diag(obs_scale**2)) 319 | self.initial_distribution = MVN(initial_mean, initial_cov) 320 | 321 | # Initialize parameters. 
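# (A note on the defaults below, added for clarity: when no prior is supplied,
# the InverseWishart prior scales are tied to the empirical step scale
# obs_scale of the series, so the covariance modes start small relative to
# the data.)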
322 | self.param_priors["cov_level"] = _set_prior( 323 | self.level_cov_pri, IW(df=dim_obs, scale=1e-3 * obs_scale**2 * jnp.eye(dim_obs)) 324 | ) 325 | self.params["cov_level"] = self.param_priors["cov_level"].mode() 326 | self.param_priors["cov_slope"] = _set_prior( 327 | self.slope_cov_pri, IW(df=dim_obs, scale=1e-3 * obs_scale**2 * jnp.eye(dim_obs)) 328 | ) 329 | self.params["cov_slope"] = self.param_priors["cov_slope"].mode() 330 | 331 | def get_trans_mat(self, params: ParamsSTSComponent, t: int) -> Float[Array, "2*dim_obs 2*dim_obs"]: 332 | return self._trans_mat 333 | 334 | def get_trans_cov(self, params: ParamsSTSComponent, t: int) -> Float[Array, "2*dim_obs 2*dim_obs"]: 335 | _shape = params["cov_level"].shape 336 | return jnp.block([[params["cov_level"], jnp.zeros(_shape)], [jnp.zeros(_shape), params["cov_slope"]]]) 337 | 338 | @property 339 | def obs_mat(self) -> Float[Array, "dim_obs 2*dim_obs"]: 340 | return self._obs_mat 341 | 342 | @property 343 | def cov_select_mat(self) -> Float[Array, "2*dim_obs 2*dim_obs"]: 344 | return self._cov_select_mat 345 | 346 | 347 | class Autoregressive(STSComponent): 348 | r"""The autoregressive (AR) latent component of the structural time series (STS) model. 349 | 350 | The autoregressive model of order $p$, i.e., AR(p), is defined as 351 | 352 | $$z_t = \sum_{j=1}^{p} w_j z_{t-j} + \epsilon_t$$ 353 | 354 | where 355 | 356 | * $w_j$'s are the coefficients of the autoregression. Each $w_j$ is constrained to $|w_j| < 1$ 357 | (via a Tanh bijector by default) to help keep the autoregressive process stationary. 358 | * $\epsilon_t$ is the disturbance that follows a Gaussian distribution with mean 0 359 | and covariance $cov_level$. 360 | 361 | There are multiple ways to formulate the AR(p) component in the STS model. We set the 362 | state vector of the component to be the history of the latent state, with dimension 363 | $p*dim_obs$.
364 | 365 | If $dim_obs=1$ and $p=3$, the transition matrix and the observation matrix are 366 | $$ 367 | F = \begin{bmatrix} 368 | w_1 & w_2 & w_3 \\ 369 | 1 & 0 & 0 \\ 370 | 0 & 1 & 0 371 | \end{bmatrix}, 372 | \qquad 373 | H = [1, 0, 0] 374 | $$ 375 | """ 376 | 377 | def __init__( 378 | self, 379 | order: int, 380 | dim_obs: int = 1, 381 | name: str = "ar", 382 | coefficients_prior: tfd.Distribution = None, 383 | cov_level_prior: tfd.Distribution = None, 384 | initial_state_prior: tfd.MultivariateNormalFullCovariance = None, 385 | coefficient_constrainer: tfb.Bijector = None, 386 | cov_constrainer: tfb.Bijector = None, 387 | ) -> None: 388 | super().__init__(name=name, dim_obs=dim_obs) 389 | self.coef_pri = coefficients_prior 390 | self.cov_level_pri = cov_level_prior 391 | self.initial_state_pri = initial_state_prior 392 | self.coef_constrain = tfb.Tanh() if coefficient_constrainer is None else coefficient_constrainer 393 | self.cov_constrain = RealToPSDBijector() if cov_constrainer is None else cov_constrainer 394 | 395 | self.order = order 396 | self.initial_distribution = MVN(jnp.zeros(order * dim_obs), jnp.eye(order * dim_obs)) 397 | 398 | self.param_props["cov_level"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain) 399 | self.param_priors["cov_level"] = _set_prior(cov_level_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs))) 400 | self.params["cov_level"] = self.param_priors["cov_level"].mode() 401 | 402 | self.param_props["coef"] = ParameterProperties(trainable=True, constrainer=self.coef_constrain) 403 | self.param_priors["coef"] = _set_prior(coefficients_prior, MVNDiag(jnp.zeros(order), jnp.ones(order))) 404 | self.params["coef"] = self.param_priors["coef"].mode() 405 | 406 | # Fixed observation matrix. 407 | self._obs_mat = jnp.kron(jnp.eye(order)[0], jnp.eye(dim_obs)) 408 | 409 | # Covariance selection matrix. 410 | self._cov_select_mat = jnp.kron(jnp.eye(order)[:, 0], jnp.eye(dim_obs)) 411 | 412 | def initialize_params(self, obs_initial: Float[Array, " dim_obs"], obs_scale: Float[Array, " dim_obs"]) -> None: 413 | dim_obs = len(obs_initial) 414 | 415 | # Initialize the distribution of the initial state. 416 | # if self.initial_state_pri is not None: 417 | # initial_state_mean = self.initial_state_pri.mean() 418 | # initial_state_cov = self.initial_state_pri.covariance() 419 | # else: 420 | # initial_state_mean = obs_initial 421 | # initial_state_cov = jnp.diag(obs_scale**2) 422 | initial_state_mean, initial_state_cov = _initial_statistics( 423 | self.initial_state_pri, obs_initial, jnp.diag(obs_scale**2) 424 | ) 425 | initial_mean = jnp.kron(jnp.eye(self.order)[0], initial_state_mean) 426 | initial_cov = jnp.kron(jnp.eye(self.order), initial_state_cov) 427 | self.initial_distribution = MVN(initial_mean, initial_cov) 428 | 429 | # Initialize parameters.
430 | self.param_priors["cov_level"] = _set_prior( 431 | self.cov_level_pri, IW(df=dim_obs, scale=1e-3 * obs_scale**2 * jnp.eye(dim_obs)) 432 | ) 433 | self.params["cov_level"] = self.param_priors["cov_level"].mode() 434 | 435 | def get_trans_mat(self, params: ParamsSTSComponent, t: int) -> Float[Array, "order*dim_obs order*dim_obs"]: 436 | trans_mat = jnp.vstack((params["coef"], jnp.eye(self.order)[:-1])) 437 | return jnp.kron(trans_mat, jnp.eye(self.dim_obs)) 438 | 439 | def get_trans_cov(self, params: ParamsSTSComponent, t: int) -> Float[Array, "dim_obs dim_obs"]: 440 | return params["cov_level"] 441 | 442 | @property 443 | def obs_mat(self) -> Float[Array, "dim_obs order*dim_obs"]: 444 | return self._obs_mat 445 | 446 | @property 447 | def cov_select_mat(self) -> Float[Array, "order*dim_obs dim_obs"]: 448 | return self._cov_select_mat 449 | 450 | 451 | class SeasonalDummy(STSComponent): 452 | r"""The (dummy) seasonal component of the structural time series (STS) model. 453 | 454 | Since at any step $t$ the seasonal effects satisfy the constraint 455 | 456 | $$\sum_{j=1}^{num_seasons} s_{t-j} = 0,$$ 457 | 458 | the (random) seasonal effect at the next time step takes the form: 459 | 460 | $$ s_{t+1} = -\sum_{j=1}^{num_seasons-1} s_{t+1-j} + w_{t+1}$$ 461 | 462 | where 463 | 464 | * $w_{t+1}$ is the stochastic noise of the seasonal effect following a normal distribution 465 | with mean 0 and covariance $drift_cov$. 466 | So the latent state corresponding to the seasonal component has dimension 467 | $(num_seasons - 1) * dim_obs$. 468 | 469 | If $dim_obs = 1$, and suppose that $num_seasons = 4$, the transition matrix and 470 | the observation matrix are 471 | $$ 472 | F = \begin{bmatrix} 473 | -1 & -1 & -1 \\ 474 | 1 & 0 & 0 \\ 475 | 0 & 1 & 0 476 | \end{bmatrix}, 477 | \qquad 478 | H = [ 1, 0, 0 ] 479 | $$ 480 | 481 | Args (in addition to 'name' and 'dim_obs'): 482 | 'num_seasons' is the number of seasons. 483 | 'num_steps_per_season' is the number of consecutive steps over which each seasonal effect stays constant. 484 | For example, if an STS model has a weekly seasonal effect but the data is measured 485 | daily, then num_steps_per_season = 7; 486 | and if an STS model has a daily seasonal effect but the data is measured hourly, 487 | then num_steps_per_season = 24. 488 | """ 489 | 490 | def __init__( 491 | self, 492 | num_seasons: int, 493 | num_steps_per_season: int = 1, 494 | dim_obs: int = 1, 495 | name: str = "seasonal_dummy", 496 | drift_cov_prior: tfd.Distribution = None, 497 | initial_effect_prior: tfd.MultivariateNormalFullCovariance = None, 498 | cov_constrainer: tfb.Bijector = None, 499 | ) -> None: 500 | super().__init__(name=name, dim_obs=dim_obs) 501 | self.drift_cov_pri = drift_cov_prior 502 | self.initial_effect_pri = initial_effect_prior 503 | self.cov_constrain = RealToPSDBijector() if cov_constrainer is None else cov_constrainer 504 | 505 | self.num_seasons = num_seasons 506 | self.steps_per_season = num_steps_per_season 507 | 508 | _c = self.num_seasons - 1 509 | self.initial_distribution = MVN(jnp.zeros(_c * dim_obs), jnp.eye(_c * dim_obs)) 510 | 511 | self.param_props["drift_cov"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain) 512 | self.param_priors["drift_cov"] = _set_prior(drift_cov_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs))) 513 | self.params["drift_cov"] = self.param_priors["drift_cov"].mode() 514 | 515 | # The seasonal component has a fixed transition matrix.
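# For num_seasons=4 and dim_obs=1 the matrix built below is
#   [[-1, -1, -1],
#    [ 1,  0,  0],
#    [ 0,  1,  0]];
# the Kronecker product with the identity replicates this block for each
# observed dimension.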

class SeasonalTrig(STSComponent):
    r"""The trigonometric seasonal component of the structural time series (STS) model.

    The (random) seasonal effect at time step $t$ takes the form

    $$\gamma_t = \sum_{j=1}^{\lfloor s/2 \rfloor} \gamma_{j,t},$$

    and the state is updated by

    $$\gamma_{j, t+1} = \cos(\lambda_j) \gamma_{j,t} + \sin(\lambda_j) \gamma^*_{j,t} + w_{j,t}$$
    $$\gamma^*_{j, t+1} = -\sin(\lambda_j) \gamma_{j,t} + \cos(\lambda_j) \gamma^*_{j,t} + w^*_{j,t}$$

    where

    * $s$ is the number of seasons.
    * $\lambda_j = 2\pi j / s$ is the frequency of block $j = 1, ..., \lfloor s/2 \rfloor$.
    * $w_{j,t}$ and $w^*_{j,t}$ are stochastic noises of the seasonal effect, following a
      normal distribution with mean zero and a common covariance $drift_cov$ for all $j$ and $t$.

    The latent state corresponding to the seasonal component has dimension $(s-1) * dim_obs$.
    If $s$ is odd, then $s - 1 = 2 \cdot (s-1)/2$, so there are $j = 1, ..., (s-1)/2$ blocks.
    If $s$ is even, then $s - 1 = 2 \cdot (s/2) - 1$, so there are $j = 1, ..., s/2$ blocks,
    but the last dimension is dropped in this case, since it plays no role in the
    observation.

    If $dim_obs = 1$, then for each $j = 1, ..., \lfloor s/2 \rfloor$:
    $$
    F_j = \begin{bmatrix}
              \cos(\lambda_j) & \sin(\lambda_j) \\
             -\sin(\lambda_j) & \cos(\lambda_j)
          \end{bmatrix},
    \qquad
    H_j = [1, 0]
    $$

    Args (in addition to 'name' and 'dim_obs'):
        'num_seasons' is the number of seasons.
        'num_steps_per_season' is the number of consecutive steps during which each
            seasonal effect stays constant. For example, if an STS model has a weekly
            seasonal effect but the data is measured daily, then num_steps_per_season = 7;
            if it has a daily seasonal effect but the data is measured hourly,
            then num_steps_per_season = 24.
    """

    def __init__(
        self,
        num_seasons: int,
        num_steps_per_season: int = 1,
        dim_obs: int = 1,
        name: str = "seasonal_trig",
        drift_cov_prior: tfd.Distribution = None,
        initial_effect_prior: tfd.MultivariateNormalFullCovariance = None,
        cov_constrainer: tfb.Bijector = None,
    ) -> None:
        super().__init__(name=name, dim_obs=dim_obs)
        self.drift_cov_pri = drift_cov_prior
        self.initial_effect_pri = initial_effect_prior
        self.cov_constrain = RealToPSDBijector() if cov_constrainer is None else cov_constrainer

        self.num_seasons = num_seasons
        self.steps_per_season = num_steps_per_season

        _c = num_seasons - 1
        self.initial_distribution = MVN(jnp.zeros(_c * dim_obs), jnp.eye(_c * dim_obs))

        self.param_props["drift_cov"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain)
        self.param_priors["drift_cov"] = _set_prior(drift_cov_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs)))
        self.params["drift_cov"] = self.param_priors["drift_cov"].mode()

        # The seasonal component has a fixed block-diagonal transition matrix,
        # with one 2x2 rotation block per frequency lambda_j = 2*pi*j/num_seasons.
        num_pairs = num_seasons // 2
        _trans_mat = jnp.zeros((2 * num_pairs, 2 * num_pairs))
        for j in range(1, num_pairs + 1):
            lamb_j = (2 * j * jnp.pi) / num_seasons
            block_j = jnp.array([[jnp.cos(lamb_j), jnp.sin(lamb_j)], [-jnp.sin(lamb_j), jnp.cos(lamb_j)]])
            _trans_mat = _trans_mat.at[2 * (j - 1) : 2 * j, 2 * (j - 1) : 2 * j].set(block_j)
        if num_seasons % 2 == 0:
            # Drop the last dimension: it does not contribute to the observation.
            _trans_mat = _trans_mat[:-1, :-1]
        self._trans_mat = jnp.kron(_trans_mat, jnp.eye(dim_obs))

        # Fixed observation matrix: reads the first coordinate of each rotation block.
        _obs_mat = jnp.tile(jnp.array([1, 0]), num_pairs)
        if num_seasons % 2 == 0:
            _obs_mat = _obs_mat[:-1]
        self._obs_mat = jnp.kron(_obs_mat, jnp.eye(dim_obs))

        # Covariance selection matrix: noise enters every coordinate of the state.
        self._cov_select_mat = jnp.eye(_c * dim_obs)
    def initialize_params(self, obs_initial: Float[Array, " dim_obs"], obs_scale: Float[Array, " dim_obs"]) -> None:
        # Initialize the distribution of the initial state.
        dim_obs = len(obs_initial)
        initial_effect_mean, initial_effect_cov = _initial_statistics(
            self.initial_effect_pri, jnp.zeros(dim_obs), jnp.diag(obs_scale**2)
        )
        initial_mean = jnp.kron(jnp.eye(self.num_seasons - 1)[0], initial_effect_mean)
        initial_cov = jnp.kron(jnp.eye(self.num_seasons - 1), initial_effect_cov)
        self.initial_distribution = MVN(initial_mean, initial_cov)

        # Initialize parameters.
        self.param_priors["drift_cov"] = _set_prior(
            self.drift_cov_pri, IW(df=dim_obs, scale=1e-3 * obs_scale**2 * jnp.eye(dim_obs))
        )
        self.params["drift_cov"] = self.param_priors["drift_cov"].mode()

    def get_trans_mat(
        self, params: ParamsSTSComponent, t: int
    ) -> Float[Array, " (num_seasons-1)*dim_obs (num_seasons-1)*dim_obs"]:
        return lax.cond(
            t % self.steps_per_season == 0,
            lambda: self._trans_mat,
            lambda: jnp.eye((self.num_seasons - 1) * self.dim_obs),
        )

    def get_trans_cov(
        self, params: ParamsSTSComponent, t: int
    ) -> Float[Array, " (num_seasons-1)*dim_obs (num_seasons-1)*dim_obs"]:
        # Between season boundaries the drift is (effectively) zero.
        return lax.cond(
            t % self.steps_per_season == 0,
            lambda: jnp.kron(jnp.eye(self.num_seasons - 1), params["drift_cov"]),
            lambda: jnp.eye((self.num_seasons - 1) * self.dim_obs) * 1e-32,
        )

    @property
    def obs_mat(self) -> Float[Array, " dim_obs (num_seasons-1)*dim_obs"]:
        return self._obs_mat

    @property
    def cov_select_mat(self) -> Float[Array, " (num_seasons-1)*dim_obs (num_seasons-1)*dim_obs"]:
        return self._cov_select_mat
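
# --- Illustrative usage sketch (not part of the library API) -----------------
# The latent dimension of the trigonometric component is (num_seasons - 1) *
# dim_obs for both odd and even num_seasons. The helper name is hypothetical.
def _sketch_seasonal_trig_dims(num_seasons: int = 5):
    seas = SeasonalTrig(num_seasons=num_seasons, dim_obs=1)
    F = seas.get_trans_mat(seas.params, t=0)
    assert F.shape == (num_seasons - 1, num_seasons - 1)
    return F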

class Cycle(STSComponent):
    r"""The cycle component of the structural time series model.

    The latent state stacks the cycle effect $\gamma_t$ and an auxiliary process
    $\gamma^*_t$, which are updated by

    $$\gamma_{t+1} = damp \cdot (\cos(freq) \gamma_t + \sin(freq) \gamma^*_t) + w_t$$
    $$\gamma^*_{t+1} = damp \cdot (-\sin(freq) \gamma_t + \cos(freq) \gamma^*_t) + w^*_t$$

    where

    * $w_t$ and $w^*_t$ are stochastic noises of the cycle effect, following a normal
      distribution with mean zero and a common covariance $drift_cov$ for all $t$.

    The latent state corresponding to the cycle component has dimension $2 * dim_obs$.

    If $dim_obs = 1$, the transition matrix and the observation matrix are:
    $$
    F = damp \cdot \begin{bmatrix}
                        \cos(freq) & \sin(freq) \\
                       -\sin(freq) & \cos(freq)
                   \end{bmatrix},
    \qquad
    H = [1, 0],
    $$

    where

    * $damp$ is the damping factor, with $0 < damp < 1$.
    * $freq$ is the frequency factor, with $0 < freq < 2\pi$,
      so the period of the cycle is $2\pi / freq$.
    """

    def __init__(
        self,
        dim_obs: int = 1,
        name: str = "cycle",
        damping_factor_prior: tfd.Distribution = None,
        frequency_prior: tfd.Distribution = None,
        drift_cov_prior: tfd.Distribution = None,
        initial_effect_prior: tfd.MultivariateNormalFullCovariance = None,
        cov_constrainer: tfb.Bijector = None,
    ) -> None:
        super().__init__(name=name, dim_obs=dim_obs)
        self.damp_pri = damping_factor_prior
        self.frequency_pri = frequency_prior
        self.drift_cov_pri = drift_cov_prior
        self.initial_effect_pri = initial_effect_prior
        self.cov_constrain = RealToPSDBijector() if cov_constrainer is None else cov_constrainer

        self.initial_distribution = MVN(jnp.zeros(2 * dim_obs), jnp.eye(2 * dim_obs))

        # Parameters of the component.
        self.param_props["damp"] = ParameterProperties(trainable=True, constrainer=tfb.Sigmoid())
        self.param_priors["damp"] = _set_prior(damping_factor_prior, Uniform(low=0.0, high=1.0))
        self.params["damp"] = self.param_priors["damp"].mode()

        self.param_props["frequency"] = ParameterProperties(
            trainable=True, constrainer=tfb.Sigmoid(low=0.0, high=2 * jnp.pi)
        )
        self.param_priors["frequency"] = _set_prior(frequency_prior, Uniform(low=0.0, high=2 * jnp.pi))
        self.params["frequency"] = self.param_priors["frequency"].mode()

        self.param_props["drift_cov"] = ParameterProperties(trainable=True, constrainer=self.cov_constrain)
        self.param_priors["drift_cov"] = _set_prior(drift_cov_prior, IW(df=dim_obs, scale=jnp.eye(dim_obs)))
        self.params["drift_cov"] = self.param_priors["drift_cov"].mode()

        # Fixed observation matrix: reads the cycle effect gamma_t only.
        self._obs_mat = jnp.kron(jnp.array([1, 0]), jnp.eye(dim_obs))

        # Covariance selection matrix.
        self._cov_select_mat = jnp.kron(jnp.array([[1], [0]]), jnp.eye(dim_obs))

    def initialize_params(self, obs_initial: Float[Array, " dim_obs"], obs_scale: Float[Array, " dim_obs"]) -> None:
        # Initialize the distribution of the initial state.
        dim_obs = len(obs_initial)
        initial_effect_mean, initial_effect_cov = _initial_statistics(
            self.initial_effect_pri, jnp.zeros(dim_obs), jnp.diag(obs_scale**2)
        )
        # The cycle state stacks (gamma_t, gamma*_t), i.e. two blocks of size dim_obs.
        initial_mean = jnp.kron(jnp.eye(2)[0], initial_effect_mean)
        initial_cov = jnp.kron(jnp.eye(2), initial_effect_cov)
        self.initial_distribution = MVN(initial_mean, initial_cov)

        # Initialize parameters.
        self.param_priors["drift_cov"] = _set_prior(
            self.drift_cov_pri, IW(df=dim_obs, scale=1e-3 * obs_scale**2 * jnp.eye(dim_obs))
        )
        self.params["drift_cov"] = self.param_priors["drift_cov"].mode()

    def get_trans_mat(self, params: ParamsSTSComponent, t: int) -> Float[Array, "2*dim_obs 2*dim_obs"]:
        freq = params["frequency"]
        damp = params["damp"]
        # Damped rotation by the angle freq.
        _trans_mat = jnp.array([[jnp.cos(freq), jnp.sin(freq)], [-jnp.sin(freq), jnp.cos(freq)]])
        trans_mat = damp * _trans_mat
        return jnp.kron(trans_mat, jnp.eye(self.dim_obs))

    def get_trans_cov(self, params: ParamsSTSComponent, t: int) -> Float[Array, "dim_obs dim_obs"]:
        return params["drift_cov"]

    @property
    def obs_mat(self) -> Float[Array, "dim_obs 2*dim_obs"]:
        return self._obs_mat

    @property
    def cov_select_mat(self) -> Float[Array, "2*dim_obs dim_obs"]:
        return self._cov_select_mat
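
# --- Illustrative usage sketch (not part of the library API) -----------------
# Since the period of the cycle is 2*pi/freq, a cycle of length 12 corresponds
# to freq = 2*pi/12. The helper name and parameter values are illustrative only.
def _sketch_cycle_trans_mat():
    cycle = Cycle(dim_obs=1)
    params = dict(cycle.params)
    params["damp"] = jnp.array(0.9)
    params["frequency"] = jnp.array(2 * jnp.pi / 12)
    return cycle.get_trans_mat(params, t=0)  # 0.9 * (2x2 rotation by pi/6)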

class LinearRegression(STSRegression):
    r"""The linear regression component of the structural time series (STS) model.

    The formula of the linear regression component is

    $$y_t = W u_t,$$

    where

    * $u_t$ is the model input at time step $t$:
      $u_t = covariates[t]$ if no bias term is added, and
      $u_t = [covariates[t], 1]$ if a bias term is added to the model.
    * $W$ is the coefficient matrix, of shape $(dim_obs, dim_covariates)$ if no bias
      term is added, and $(dim_obs, dim_covariates + 1)$ if a bias term is added.
    """

    def __init__(
        self,
        dim_covariates: int,
        add_bias: bool = True,
        dim_obs: int = 1,
        name: str = "linear_regression",
        weights_prior: tfd.Distribution = None,
    ) -> None:
        super().__init__(name=name, dim_obs=dim_obs)
        self.weights_pri = weights_prior
        self.add_bias = add_bias

        dim_inputs = dim_covariates + 1 if add_bias else dim_covariates

        self.param_props["weights"] = ParameterProperties(trainable=True, constrainer=tfb.Identity())
        self.param_priors["weights"] = _set_prior(
            weights_prior,
            MNP(
                loc=jnp.zeros((dim_obs, dim_inputs)), row_covariance=jnp.eye(dim_obs), col_precision=jnp.eye(dim_inputs)
            ),
        )
        self.params["weights"] = jnp.zeros((dim_obs, dim_inputs))

    def initialize_params(
        self,
        covariates: Float[Array, "num_timesteps dim_covariates"],
        obs_time_series: Float[Array, "num_timesteps dim_obs"],
    ) -> None:
        # Initialize the weights with the least-squares solution on the training data,
        # whether or not a bias term is added.
        if self.add_bias:
            inputs = jnp.concatenate((covariates, jnp.ones((covariates.shape[0], 1))), axis=1)
        else:
            inputs = covariates
        W = jnp.linalg.solve(inputs.T @ inputs, inputs.T @ obs_time_series).T
        self.params["weights"] = W

    def get_reg_value(
        self, params: ParamsSTSComponent, covariates: Float[Array, "num_timesteps dim_covariates"]
    ) -> Float[Array, "num_timesteps dim_obs"]:
        if self.add_bias:
            inputs = jnp.concatenate((covariates, jnp.ones((covariates.shape[0], 1))), axis=1)
            return inputs @ params["weights"].T
        else:
            return covariates @ params["weights"].T
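
# --- Illustrative usage sketch (not part of the library API) -----------------
# The least-squares initialization above solves the normal equations. This
# hypothetical check recovers a known slope and bias from synthetic data.
def _sketch_linear_regression_init():
    x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)  # covariates, shape (50, 1)
    y = 2.0 * x + 1.0  # targets with slope 2 and bias 1
    reg = LinearRegression(dim_covariates=1, add_bias=True, dim_obs=1)
    reg.initialize_params(x, y)
    return reg.params["weights"]  # approximately [[2.0, 1.0]]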
--------------------------------------------------------------------------------