├── .github
└── workflows
│ └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── INPUT.md
├── README.md
├── REQUIREMENTS.txt
├── codecov.yml
├── conda_envs
└── pymc_statespace.yml
├── data
└── nile.csv
├── examples
├── ARMA Example.ipynb
├── Custom SSM - Daily Seasonality.ipynb
├── Nile Local Level Model.ipynb
└── VARMAX Example.ipynb
├── pymc_statespace
├── __init__.py
├── core
│ ├── __init__.py
│ ├── representation.py
│ └── statespace.py
├── filters
│ ├── __init__.py
│ ├── kalman_filter.py
│ ├── kalman_smoother.py
│ └── utilities.py
├── models
│ ├── SARIMAX.py
│ ├── VARMAX.py
│ ├── __init__.py
│ ├── local_level.py
│ └── utilities.py
└── utils
│ ├── __init__.py
│ ├── numba_linalg.py
│ ├── pytensor_scipy.py
│ └── simulation.py
├── pyproject.toml
├── setup.cfg
├── setup.py
├── svgs
├── 17a925cbd243c9dd3ce20fd8558993f1.svg
├── 17b59c002f249204f24e31507dc4957d.svg
├── 2d6502498c2ef42e278a774086e73a26.svg
├── 3b5e41543d7fc8cedf98ec609b343134.svg
├── 4881244fda86ce9792cccafb3bb7eb0c.svg
├── 523b266d36c270dbbb5daf2c9092ce0f.svg
├── 54221efbfb5e69569dfe8ddea785093a.svg
├── 5dfc2ae9e19de8995e3edb059f7abd19.svg
├── 92b8c1194757fb3131cda468a34be85f.svg
├── a06c0e58d4d162b0e87d32927c9812db.svg
├── a13d89295e999545a129b2d412e99f6d.svg
├── c1133ef3a57a33193b33c998a34246cc.svg
├── cac7e81ebde5e530e639eae5389f149e.svg
├── d523a14b8179ebe46f0ed16895ee46f0.svg
├── dfbe60bd49a89dc2de3950ee0ffab3f3.svg
├── edcff444fd5240add1c47d2de50ebd7e.svg
├── f566e90ed17c5292db4600846e0ace27.svg
└── ff25a8f22c7430ca572d33206c0a9176.svg
└── tests
├── __init__.py
├── test_VARMAX.py
├── test_data
├── nile.csv
└── statsmodels_macrodata_processed.csv
├── test_kalman_filter.py
├── test_local_level.py
├── test_numba_linalg.py
├── test_pytensor_scipy.py
├── test_representation.py
├── test_simulations.py
├── test_statespace.py
└── utilities
├── __init__.py
├── statsmodel_local_level.py
└── test_helpers.py
/.github/workflows/run_tests.yml:
--------------------------------------------------------------------------------
1 | name: run_tests
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 |
9 | jobs:
10 | unittest:
11 | strategy:
12 | fail-fast: false
13 | matrix:
14 | os: [ubuntu-latest, windows-latest]
15 | python-version: ["3.9"]
16 | test-subset:
17 | - |
18 | tests/
19 |
20 | runs-on: ${{ matrix.os }}
21 |
22 | env:
23 | TEST_SUBSET: ${{ matrix.test-subset }}
24 |
25 | defaults:
26 | run:
27 | shell: bash -l {0}
28 |
29 | steps:
30 | - uses: actions/checkout@v3
31 | - uses: actions/cache@v3
32 | env:
33 | # Increase this value to reset cache if pymc_statespace.yml has not changed
34 | CACHE_NUMBER: 0
35 | with:
36 | path: ~/conda_pkgs_dir
37 | key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
38 | hashFiles('conda_envs/pymc_statespace.yml') }}
39 | - name: Cache multiple paths
40 | uses: actions/cache@v3
41 | env:
42 | # Increase this value to reset cache if requirements.txt has not changed
43 | CACHE_NUMBER: 0
44 | with:
45 | path: |
46 | ~/.cache/pip
47 | $RUNNER_TOOL_CACHE/Python/*
48 | ~\AppData\Local\pip\Cache
49 | key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{
50 | hashFiles('requirements.txt') }}
51 | - uses: conda-incubator/setup-miniconda@v2
52 | with:
53 | miniforge-variant: Mambaforge
54 | miniforge-version: latest
55 | mamba-version: "*"
56 | activate-environment: pymc-statespace
57 | channel-priority: flexible
58 | environment-file: conda_envs/pymc_statespace.yml
59 | python-version: 3.9
60 | use-mamba: true
61 | use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
62 |
63 | - name: Install current branch
64 | run: |
65 | conda activate pymc-statespace
66 | pip install -e .
67 | python --version
68 |
69 | - name: Run tests
70 | run: |
71 | python -m pytest -vv --cov=pymc_statespace --cov-report=xml --no-cov-on-fail --cov-report term $TEST_SUBSET
72 | - name: Upload coverage to Codecov
73 | uses: codecov/codecov-action@v3
74 | with:
75 | token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
76 | env_vars: TEST_SUBSET
77 | name: ${{ matrix.os }}
78 | fail_ci_if_error: false
79 | verbose: true
80 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,windows,pycharm,jupyternotebooks
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,pycharm,jupyternotebooks
3 |
4 | **/__pycache__/
5 |
6 | ### JupyterNotebooks ###
7 | # gitignore template for Jupyter Notebooks
8 | # website: http://jupyter.org/
9 |
10 | .ipynb_checkpoints
11 | */.ipynb_checkpoints/*
12 |
13 | # IPython
14 | profile_default/
15 | ipython_config.py
16 |
17 | # Remove previous ipynb_checkpoints
18 | # git rm -r .ipynb_checkpoints/
19 |
20 | ### macOS ###
21 | # General
22 | .DS_Store
23 | .AppleDouble
24 | .LSOverride
25 |
26 | # Icon must end with two \r
27 | Icon
28 |
29 |
30 | # Thumbnails
31 | ._*
32 |
33 | # Files that might appear in the root of a volume
34 | .DocumentRevisions-V100
35 | .fseventsd
36 | .Spotlight-V100
37 | .TemporaryItems
38 | .Trashes
39 | .VolumeIcon.icns
40 | .com.apple.timemachine.donotpresent
41 |
42 | # Directories potentially created on remote AFP share
43 | .AppleDB
44 | .AppleDesktop
45 | Network Trash Folder
46 | Temporary Items
47 | .apdisk
48 |
49 | ### macOS Patch ###
50 | # iCloud generated files
51 | *.icloud
52 |
53 | ### PyCharm ###
54 | .idea/
55 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
56 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
57 |
58 | # User-specific stuff
59 | .idea/**/workspace.xml
60 | .idea/**/tasks.xml
61 | .idea/**/usage.statistics.xml
62 | .idea/**/dictionaries
63 | .idea/**/shelf
64 |
65 | # AWS User-specific
66 | .idea/**/aws.xml
67 |
68 | # Generated files
69 | .idea/**/contentModel.xml
70 |
71 | # Sensitive or high-churn files
72 | .idea/**/dataSources/
73 | .idea/**/dataSources.ids
74 | .idea/**/dataSources.local.xml
75 | .idea/**/sqlDataSources.xml
76 | .idea/**/dynamic.xml
77 | .idea/**/uiDesigner.xml
78 | .idea/**/dbnavigator.xml
79 |
80 | # Gradle
81 | .idea/**/gradle.xml
82 | .idea/**/libraries
83 |
84 | # Gradle and Maven with auto-import
85 | # When using Gradle or Maven with auto-import, you should exclude module files,
86 | # since they will be recreated, and may cause churn. Uncomment if using
87 | # auto-import.
88 | # .idea/artifacts
89 | # .idea/compiler.xml
90 | # .idea/jarRepositories.xml
91 | # .idea/modules.xml
92 | # .idea/*.iml
93 | # .idea/modules
94 | # *.iml
95 | # *.ipr
96 |
97 | # CMake
98 | cmake-build-*/
99 |
100 | # Mongo Explorer plugin
101 | .idea/**/mongoSettings.xml
102 |
103 | # File-based project format
104 | *.iws
105 |
106 | # IntelliJ
107 | out/
108 |
109 | # mpeltonen/sbt-idea plugin
110 | .idea_modules/
111 |
112 | # JIRA plugin
113 | atlassian-ide-plugin.xml
114 |
115 | # Cursive Clojure plugin
116 | .idea/replstate.xml
117 |
118 | # SonarLint plugin
119 | .idea/sonarlint/
120 |
121 | # Crashlytics plugin (for Android Studio and IntelliJ)
122 | com_crashlytics_export_strings.xml
123 | crashlytics.properties
124 | crashlytics-build.properties
125 | fabric.properties
126 |
127 | # Editor-based Rest Client
128 | .idea/httpRequests
129 |
130 | # Android studio 3.1+ serialized cache file
131 | .idea/caches/build_file_checksums.ser
132 |
133 | ### PyCharm Patch ###
134 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
135 |
136 | # *.iml
137 | # modules.xml
138 | # .idea/misc.xml
139 | # *.ipr
140 |
141 | # Sonarlint plugin
142 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
143 | .idea/**/sonarlint/
144 |
145 | # SonarQube Plugin
146 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
147 | .idea/**/sonarIssues.xml
148 |
149 | # Markdown Navigator plugin
150 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
151 | .idea/**/markdown-navigator.xml
152 | .idea/**/markdown-navigator-enh.xml
153 | .idea/**/markdown-navigator/
154 |
155 | # Cache file creation bug
156 | # See https://youtrack.jetbrains.com/issue/JBR-2257
157 | .idea/$CACHE_FILE$
158 |
159 | # CodeStream plugin
160 | # https://plugins.jetbrains.com/plugin/12206-codestream
161 | .idea/codestream.xml
162 |
163 | # Azure Toolkit for IntelliJ plugin
164 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
165 | .idea/**/azureSettings.xml
166 |
167 | ### Windows ###
168 | # Windows thumbnail cache files
169 | Thumbs.db
170 | Thumbs.db:encryptable
171 | ehthumbs.db
172 | ehthumbs_vista.db
173 |
174 | # Dump file
175 | *.stackdump
176 |
177 | # Folder config file
178 | [Dd]esktop.ini
179 |
180 | # Recycle Bin used on file shares
181 | $RECYCLE.BIN/
182 |
183 | # Windows Installer files
184 | *.cab
185 | *.msi
186 | *.msix
187 | *.msm
188 | *.msp
189 |
190 | # Windows shortcuts
191 | *.lnk
192 |
193 | # End of https://www.toptal.com/developers/gitignore/api/macos,windows,pycharm,jupyternotebooks
194 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.3.0
4 | hooks:
5 | - id: check-merge-conflict
6 | - id: check-toml
7 | - id: check-yaml
8 | - id: debug-statements
9 | - id: end-of-file-fixer
10 | exclude: .txt$
11 | - id: trailing-whitespace
12 | - id: requirements-txt-fixer
13 | - repo: https://github.com/PyCQA/isort
14 | rev: 5.12.0
15 | hooks:
16 | - id: isort
17 | name: isort
18 | - repo: https://github.com/asottile/pyupgrade
19 | rev: v3.3.1
20 | hooks:
21 | - id: pyupgrade
22 | args: [--py37-plus]
23 | - repo: https://github.com/psf/black
24 | rev: 22.12.0
25 | hooks:
26 | - id: black
27 | - id: black-jupyter
28 | - repo: https://github.com/PyCQA/pylint
29 | rev: v2.16.0b1
30 | hooks:
31 | - id: pylint
32 | args: [--rcfile=.pylintrc]
33 | files: ^pymc_statespace/
34 | - repo: local
35 | hooks:
36 | - id: no-relative-imports
37 | name: No relative imports
38 | entry: from \.[\.\w]* import
39 | types: [python]
40 | language: pygrep
41 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | # Use multiple processes to speed up Pylint.
3 | jobs=1
4 |
5 | # Allow loading of arbitrary C extensions. Extensions are imported into the
6 | # active Python interpreter and may run arbitrary code.
7 | unsafe-load-any-extension=no
8 |
9 | # Allow optimization of some AST trees. This will activate a peephole AST
10 | # optimizer, which will apply various small optimizations. For instance, it can
11 | # be used to obtain the result of joining multiple strings with the addition
12 | # operator. Joining a lot of strings can lead to a maximum recursion error in
13 | # Pylint and this flag can prevent that. It has one side effect, the resulting
14 | # AST will be different than the one from reality.
15 | optimize-ast=no
16 |
17 | [MESSAGES CONTROL]
18 |
19 | # Only show warnings with the listed confidence levels. Leave empty to show
20 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
21 | confidence=
22 |
23 | # Disable the message, report, category or checker with the given id(s). You
24 | # can either give multiple identifiers separated by comma (,) or put this
25 | # option multiple times (only on the command line, not in the configuration
26 | # file where it should appear only once).You can also use "--disable=all" to
27 | # disable everything first and then reenable specific checks. For example, if
28 | # you want to run only the similarities checker, you can use "--disable=all
29 | # --enable=similarities". If you want to run only the classes checker, but have
30 | # no Warning level messages displayed, use"--disable=all --enable=classes
31 | # --disable=W"
32 | disable=all
33 |
34 | # Enable the message, report, category or checker with the given id(s). You can
35 | # either give multiple identifier separated by comma (,) or put this option
36 | # multiple time. See also the "--disable" option for examples.
37 | enable=import-self,
38 | reimported,
39 | wildcard-import,
40 | misplaced-future,
41 | relative-import,
42 | deprecated-module,
43 | unpacking-non-sequence,
44 | invalid-all-object,
45 | undefined-all-variable,
46 | used-before-assignment,
47 | cell-var-from-loop,
48 | global-variable-undefined,
49 | dangerous-default-value,
50 | # redefined-builtin,
51 | redefine-in-handler,
52 | unused-import,
53 | unused-wildcard-import,
54 | global-variable-not-assigned,
55 | undefined-loop-variable,
56 | global-statement,
57 | global-at-module-level,
58 | bad-open-mode,
59 | redundant-unittest-assert,
60 | boolean-datetime,
61 | # unused-variable
62 |
63 |
64 | [REPORTS]
65 |
66 | # Set the output format. Available formats are text, parseable, colorized, msvs
67 | # (visual studio) and html. You can also give a reporter class, eg
68 | # mypackage.mymodule.MyReporterClass.
69 | output-format=parseable
70 |
71 | # Put messages in a separate file for each module / package specified on the
72 | # command line instead of printing them on stdout. Reports (if any) will be
73 | # written in a file name "pylint_global.[txt|html]".
74 | files-output=no
75 |
76 | # Tells whether to display a full report or only the messages
77 | reports=no
78 |
79 | # Python expression which should return a note less than 10 (10 is the highest
80 | # note). You have access to the variables errors warning, statement which
81 | # respectively contain the number of errors / warnings messages and the total
82 | # number of statements analyzed. This is used by the global evaluation report
83 | # (RP0004).
84 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
85 |
86 | [BASIC]
87 |
88 | # List of builtins function names that should not be used, separated by a comma
89 | bad-functions=map,filter,input
90 |
91 | # Good variable names which should always be accepted, separated by a comma
92 | good-names=i,j,k,ex,Run,_
93 |
94 | # Bad variable names which should always be refused, separated by a comma
95 | bad-names=foo,bar,baz,toto,tutu,tata
96 |
97 | # Colon-delimited sets of names that determine each other's naming style when
98 | # the name regexes allow several styles.
99 | name-group=
100 |
101 | # Include a hint for the correct naming format with invalid-name
102 | include-naming-hint=yes
103 |
104 | # Regular expression matching correct method names
105 | method-rgx=[a-z_][a-z0-9_]{2,30}$
106 |
107 | # Naming hint for method names
108 | method-name-hint=[a-z_][a-z0-9_]{2,30}$
109 |
110 | # Regular expression matching correct function names
111 | function-rgx=[a-z_][a-z0-9_]{2,30}$
112 |
113 | # Naming hint for function names
114 | function-name-hint=[a-z_][a-z0-9_]{2,30}$
115 |
116 | # Regular expression matching correct module names
117 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
118 |
119 | # Naming hint for module names
120 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
121 |
122 | # Regular expression matching correct attribute names
123 | attr-rgx=[a-z_][a-z0-9_]{2,30}$
124 |
125 | # Naming hint for attribute names
126 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$
127 |
128 | # Regular expression matching correct class attribute names
129 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
130 |
131 | # Naming hint for class attribute names
132 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
133 |
134 | # Regular expression matching correct constant names
135 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
136 |
137 | # Naming hint for constant names
138 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
139 |
140 | # Regular expression matching correct class names
141 | class-rgx=[A-Z_][a-zA-Z0-9]+$
142 |
143 | # Naming hint for class names
144 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
145 |
146 | # Regular expression matching correct argument names
147 | argument-rgx=[a-z_][a-z0-9_]{2,30}$
148 |
149 | # Naming hint for argument names
150 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$
151 |
152 | # Regular expression matching correct inline iteration names
153 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
154 |
155 | # Naming hint for inline iteration names
156 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
157 |
158 | # Regular expression matching correct variable names
159 | variable-rgx=[a-z_][a-z0-9_]{2,30}$
160 |
161 | # Naming hint for variable names
162 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$
163 |
164 | # Regular expression which should only match function or class names that do
165 | # not require a docstring.
166 | no-docstring-rgx=^_
167 |
168 | # Minimum line length for functions/classes that require docstrings, shorter
169 | # ones are exempt.
170 | docstring-min-length=-1
171 |
172 |
173 | [ELIF]
174 |
175 | # Maximum number of nested blocks for function / method body
176 | max-nested-blocks=5
177 |
178 |
179 | [FORMAT]
180 |
181 | # Maximum number of characters on a single line.
182 | max-line-length=100
183 |
184 | # Regexp for a line that is allowed to be longer than the limit.
185 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
186 |
187 | # Allow the body of an if to be on the same line as the test if there is no
188 | # else.
189 | single-line-if-stmt=no
190 |
191 | # List of optional constructs for which whitespace checking is disabled. `dict-
192 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
193 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
194 | # `empty-line` allows space-only lines.
195 | no-space-check=trailing-comma,dict-separator
196 |
197 | # Maximum number of lines in a module
198 | max-module-lines=1000
199 |
200 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
201 | # tab).
202 | indent-string=' '
203 |
204 | # Number of spaces of indent required inside a hanging or continued line.
205 | indent-after-paren=4
206 |
207 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
208 | expected-line-ending-format=
209 |
210 |
211 | [LOGGING]
212 |
213 | # Logging modules to check that the string format arguments are in logging
214 | # function parameter format
215 | logging-modules=logging
216 |
217 |
218 | [MISCELLANEOUS]
219 |
220 | # List of note tags to take in consideration, separated by a comma.
221 | notes=FIXME,XXX,TODO
222 |
223 |
224 | [SIMILARITIES]
225 |
226 | # Minimum lines number of a similarity.
227 | min-similarity-lines=4
228 |
229 | # Ignore comments when computing similarities.
230 | ignore-comments=yes
231 |
232 | # Ignore docstrings when computing similarities.
233 | ignore-docstrings=yes
234 |
235 | # Ignore imports when computing similarities.
236 | ignore-imports=no
237 |
238 |
239 | [SPELLING]
240 |
241 | # Spelling dictionary name. Available dictionaries: none. To make it working
242 | # install python-enchant package.
243 | spelling-dict=
244 |
245 | # List of comma separated words that should not be checked.
246 | spelling-ignore-words=
247 |
248 | # A path to a file that contains private dictionary; one word per line.
249 | spelling-private-dict-file=
250 |
251 | # Tells whether to store unknown words to indicated private dictionary in
252 | # --spelling-private-dict-file option instead of raising a message.
253 | spelling-store-unknown-words=no
254 |
255 |
256 | [TYPECHECK]
257 |
258 | # Tells whether missing members accessed in mixin class should be ignored. A
259 | # mixin class is detected if its name ends with "mixin" (case insensitive).
260 | ignore-mixin-members=yes
261 |
262 | # List of module names for which member attributes should not be checked
263 | # (useful for modules/projects where namespaces are manipulated during runtime
264 | # and thus existing member attributes cannot be deduced by static analysis. It
265 | # supports qualified module names, as well as Unix pattern matching.
266 | ignored-modules=
267 |
268 | # List of classes names for which member attributes should not be checked
269 | # (useful for classes with attributes dynamically set). This supports can work
270 | # with qualified names.
271 | ignored-classes=
272 |
273 | # List of members which are set dynamically and missed by pylint inference
274 | # system, and so shouldn't trigger E1101 when accessed. Python regular
275 | # expressions are accepted.
276 | generated-members=
277 |
278 |
279 | [VARIABLES]
280 |
281 | # Tells whether we should check for unused import in __init__ files.
282 | init-import=no
283 |
284 | # A regular expression matching the name of dummy variables (i.e. expectedly
285 | # not used).
286 | dummy-variables-rgx=_$|dummy
287 |
288 | # List of additional names supposed to be defined in builtins. Remember that
289 | # you should avoid to define new builtins when possible.
290 | additional-builtins=
291 |
292 | # List of strings which can identify a callback function by name. A callback
293 | # name must start or end with one of those strings.
294 | callbacks=cb_,_cb
295 |
296 |
297 | [CLASSES]
298 |
299 | # List of method names used to declare (i.e. assign) instance attributes.
300 | defining-attr-methods=__init__,__new__,setUp
301 |
302 | # List of valid names for the first argument in a class method.
303 | valid-classmethod-first-arg=cls
304 |
305 | # List of valid names for the first argument in a metaclass class method.
306 | valid-metaclass-classmethod-first-arg=mcs
307 |
308 | # List of member names, which should be excluded from the protected access
309 | # warning.
310 | exclude-protected=_asdict,_fields,_replace,_source,_make
311 |
312 |
313 | [DESIGN]
314 |
315 | # Maximum number of arguments for function / method
316 | max-args=5
317 |
318 | # Argument names that match this expression will be ignored. Default to name
319 | # with leading underscore
320 | ignored-argument-names=_.*
321 |
322 | # Maximum number of locals for function / method body
323 | max-locals=15
324 |
325 | # Maximum number of return / yield for function / method body
326 | max-returns=6
327 |
328 | # Maximum number of branch for function / method body
329 | max-branches=12
330 |
331 | # Maximum number of statements in function / method body
332 | max-statements=50
333 |
334 | # Maximum number of parents for a class (see R0901).
335 | max-parents=7
336 |
337 | # Maximum number of attributes for a class (see R0902).
338 | max-attributes=7
339 |
340 | # Minimum number of public methods for a class (see R0903).
341 | min-public-methods=2
342 |
343 | # Maximum number of public methods for a class (see R0904).
344 | max-public-methods=20
345 |
346 | # Maximum number of boolean expressions in a if statement
347 | max-bool-expr=5
348 |
349 |
350 | [IMPORTS]
351 |
352 | # Deprecated modules which should not be used, separated by a comma
353 | deprecated-modules=optparse
354 |
355 | # Create a graph of every (i.e. internal and external) dependencies in the
356 | # given file (report RP0402 must not be disabled)
357 | import-graph=
358 |
359 | # Create a graph of external dependencies in the given file (report RP0402 must
360 | # not be disabled)
361 | ext-import-graph=
362 |
363 | # Create a graph of internal dependencies in the given file (report RP0402 must
364 | # not be disabled)
365 | int-import-graph=
366 |
367 |
368 | [EXCEPTIONS]
369 |
370 | # Exceptions that will emit a warning when being caught. Defaults to
371 | # "Exception"
372 | overgeneral-exceptions=Exception
373 |
--------------------------------------------------------------------------------
/INPUT.md:
--------------------------------------------------------------------------------
1 | # PyMC StateSpace
2 | A system for Bayesian estimation of state space models using PyMC 5.0. This package is designed to mirror the functionality of the Statsmodels.api `tsa.statespace` module, except within a Bayesian estimation framework. To accomplish this, PyMC Statespace has a Kalman filter written in Pytensor, allowing the gradients of the iterative Kalman filter likelihood to be computed and provided to the PyMC NUTS sampler.
3 |
4 | ## State Space Models
5 | This package follows Statsmodels in using the Durbin and Koopman (2012) nomenclature for a linear state space model. Under this nomenclature, the model is written as:
6 |
7 | $y_t = Z_t \alpha_t + d_t + \varepsilon_t, \quad \varepsilon_t \sim N(0, H_t)$
8 | $\alpha_{t+1} = T_t \alpha_t + c_t + R_t \eta_t, \quad \eta_t \sim N(0, Q_t)$
9 | $\alpha_1 \sim N(a_1, P_1)$
10 |
11 | The objects in the above equation have the following shapes and meanings:
12 |
13 | - $y_t, p \times 1$, observation vector
14 | - $\alpha_t, m \times 1$, state vector
15 | - $\varepsilon_t, p \times 1$, observation noise vector
16 | - $\eta_t, r \times 1$, state innovation vector
17 |
18 |
19 | - $Z_t, p \times m$, the design matrix
20 | - $H_t, p \times p$, observation noise covariance matrix
21 | - $T_t, m \times m$, the transition matrix
22 | - $R_t, m \times r$, the selection matrix
23 | - $Q_t, r \times r$, the state innovation covariance matrix
24 |
25 |
26 | - $c_t, m \times 1$, the state intercept vector
27 | - $d_t, p \times 1$, the observation intercept vector
28 |
29 | The linear state space model is a workhorse in many disciplines, and is flexible enough to represent a wide range of models, including Box-Jenkins SARIMAX class models, time series decompositions, and models of multiple time series (VARMAX). Use of a Kalman filter allows for estimation of unobserved and missing variables. Taken together, these are a powerful class of models that can be further augmented by Bayesian estimation. This allows the researcher to integrate uncertainty about the true model when estimating model parameters.
30 |
31 |
32 | ## Example Usage
33 |
34 | Currently, local level and ARMA models are implemented. To initialize a model, simply create a model object, provide data, and any additional arguments unique to that model.
35 | ```python
36 | import pymc_statespace as pmss
37 | #make sure data is a 2d numpy array!
38 | arma_model = pmss.BayesianARMA(data = data[:, None], order=(1, 1))
39 | ```
40 | You will see a message letting you know the model was created, as well as telling you the expected parameter names you will need to create inside a PyMC model block.
41 | ```commandline
42 | Model successfully initialized! The following parameters should be assigned priors inside a PyMC model block: x0, sigma_state, rho, theta
43 | ```
44 |
45 | Next, a PyMC model is declared as usual, and the parameters can be passed into the state space model. This is done as follows:
46 | ```python
47 | with pm.Model() as pymc_mod:
48 |     x0 = pm.Normal('x0', mu=0, sigma=1.0, shape=arma_model.k_states)
49 |     sigma_state = pm.HalfNormal('sigma_state', sigma=1)
50 | 
51 |     rho = pm.TruncatedNormal('rho', mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, shape=p)
52 |     theta = pm.Normal('theta', mu=0.0, sigma=0.5, shape=q)
53 | 
54 |     arma_model.build_statespace_graph()
55 |     trace_1 = pm.sample(init='jitter+adapt_diag_grad', target_accept=0.9)
56 | ```
57 |
58 | After you place priors over the requested parameters, call `arma_model.build_statespace_graph()` to automatically put everything together into a state space system. If you are missing any parameters, you will get an error letting you know what else needs to be defined.
59 |
60 | And that's it! After this, you can sample the PyMC model as normal.
61 |
62 |
63 | ## Creating your own state space model
64 |
65 | Creating a custom state space model isn't complicated. Once again, the API follows the Statsmodels implementation. All models need to subclass the `PyMCStateSpace` class, and pass three values into the class super constructor: `data` (from which p is inferred), `k_states` (this is "m" in the shapes above), and `k_posdef` (this is "r" above). The user also needs to declare any required state space matrices. Here is an example of a simple local linear trend model:
66 |
67 | ```python
68 | def __init__(self, data):
69 | # Model shapes
70 | k_states = k_posdef = 2
71 |
72 | super().__init__(data, k_states, k_posdef)
73 |
74 | # Initialize the matrices
75 | self.ssm['design'] = np.array([[1.0, 0.0]])
76 | self.ssm['transition'] = np.array([[1.0, 1.0],
77 | [0.0, 1.0]])
78 | self.ssm['selection'] = np.eye(k_states)
79 |
80 | self.ssm['initial_state'] = np.array([[0.0],
81 | [0.0]])
82 | self.ssm['initial_state_cov'] = np.array([[1.0, 0.0],
83 | [0.0, 1.0]])
84 | ```
85 |
86 | You will also need to set up the `param_names` property, which returns a list of parameter names. This lets users know what priors they need to define, and lets the model gather the required parameters from the PyMC model context.
87 |
88 | ```python
89 | @property
90 | def param_names(self):
91 | return ['x0', 'P0', 'sigma_obs', 'sigma_state']
92 | ```
93 |
94 | `self.ssm` is a `PytensorRepresentation` class object that is created by the super constructor. Every model has a `self.ssm` and a `self.kalman_filter` created after the super constructor is called. All the matrices stored in `self.ssm` are Pytensor tensor variables, but numpy arrays can be passed to them for convenience. Behind the scenes, they will be converted to Pytensor tensors.
95 |
96 | Note that the names of the matrices correspond to the names listed above. They are (in the same order):
97 |
98 | - Z = design
99 | - H = obs_cov
100 | - T = transition
101 | - R = selection
102 | - Q = state_cov
103 | - c = state_intercept
104 | - d = obs_intercept
105 | - a1 = initial_state
106 | - P1 = initial_state_cov
107 |
108 | Indexing by name only will expose the entire matrix. A name can also be followed by the usual numpy slice notation to get a specific element, row, or column.
109 |
110 | The user also needs to implement an `update` method, which takes in a single Pytensor tensor as an argument. This method routes the parameters estimated by PyMC into the right spots in the state space matrices. The local level has at least two parameters to estimate: the variance of the level state innovations, and the variance of the trend state innovations. Here is the corresponding update method:
111 |
112 | ```python
113 | def update(self, theta: at.TensorVariable) -> None:
114 | # Observation covariance
115 | self.ssm['obs_cov', 0, 0] = theta[0]
116 |
117 | # State covariance
118 | self.ssm['state_cov', np.diag_indices(2)] = theta[1:]
119 | ```
120 |
121 | And that's it! By making a model you also gain access to simulation helper functions, including prior and posterior simulation from the predicted, filtered, and smoothed states, as well as stability analysis and impulse response functions.
122 |
123 | This package is very much a work in progress and really needs help! If you're interested in time series analysis and want to contribute, please reach out!
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UPDATE
2 |
3 | This repo is no longer maintained! Statespace models are now a part of [pymc-experimental](https://github.com/pymc-devs/pymc-experimental/), and are maintained by the PyMC development team. Please look over there for the most up-to-date Bayesian State Space models!
4 |
5 | # PyMC StateSpace
6 | A system for Bayesian estimation of state space models using PyMC 5.0. This package is designed to mirror the functionality of the Statsmodels.api `tsa.statespace` module, except within a Bayesian estimation framework. To accomplish this, PyMC Statespace has a Kalman filter written in Pytensor, allowing the gradients of the iterative Kalman filter likelihood to be computed and provided to the PyMC NUTS sampler.
7 |
8 | ## State Space Models
9 | This package follows Statsmodels in using the Durbin and Koopman (2012) nomenclature for a linear state space model. Under this nomenclature, the model is written as:
10 |
11 | $$y_t = Z_t \alpha_t + d_t + \varepsilon_t, \qquad \varepsilon_t \sim N(0, H_t)$$
12 |
13 | $$\alpha_{t+1} = T_t \alpha_t + c_t + R_t \eta_t, \qquad \eta_t \sim N(0, Q_t)$$
14 |
15 | The objects in the above equation have the following shapes and meanings:
16 |
17 | - $y_t \;\; (p \times 1)$, observation vector
18 | - $\alpha_t \;\; (m \times 1)$, state vector
19 | - $\varepsilon_t \;\; (p \times 1)$, observation noise vector
20 | - $\eta_t \;\; (r \times 1)$, state innovation vector
21 |
22 |
23 | - $Z \;\; (p \times m)$, the design matrix
24 | - $H \;\; (p \times p)$, observation noise covariance matrix
25 | - $T \;\; (m \times m)$, the transition matrix
26 | - $R \;\; (m \times r)$, the selection matrix
27 | - $Q \;\; (r \times r)$, the state innovation covariance matrix
28 |
29 |
30 | - $c \;\; (m \times 1)$, the state intercept vector
31 | - $d \;\; (p \times 1)$, the observation intercept vector
32 |
33 | The linear state space model is a workhorse in many disciplines, and is flexible enough to represent a wide range of models, including Box-Jenkins SARIMAX class models, time series decompositions, and models of multiple time series (VARMAX). Use of a Kalman filter allows for estimation of unobserved and missing variables. Taken together, these are a powerful class of models that can be further augmented by Bayesian estimation. This allows the researcher to integrate uncertainty about the true model when estimating model parameters.
34 |
35 |
36 | ## Example Usage
37 |
38 | Currently, local level and ARMA models are implemented. To initialize a model, simply create a model object, provide data, and any additional arguments unique to that model. In addition, you can choose from several Kalman Filter implementations. Currently these are `standard`, `steady_state`, `univariate`, `single`, and `cholesky`. For more information, see the docs (they're on the to-do list)
39 | ```python
40 | import pymc_statespace as pmss
41 | #make sure data is a 2d numpy array!
42 | arma_model = pmss.BayesianARMA(data = data[:, None], order=(1, 1), filter_type='standard', stationary_initialization=True)
43 | ```
44 | You will see a message letting you know the model was created, as well as telling you the expected parameter names you will need to create inside a PyMC model block.
45 | ```commandline
46 | Model successfully initialized! The following parameters should be assigned priors inside a PyMC model block: x0, sigma_state, rho, theta
47 | ```
48 |
49 | Next, a PyMC model is declared as usual, and the parameters can be passed into the state space model. This is done as follows:
50 | ```python
51 | with pm.Model() as arma_model:
52 | x0 = pm.Normal('x0', mu=0, sigma=1.0, shape=mod.k_states)
53 | sigma_state = pm.HalfNormal('sigma_state', sigma=1)
54 |
55 | rho = pm.TruncatedNormal('rho', mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, shape=p)
56 | theta = pm.Normal('theta', mu=0.0, sigma=0.5, shape=q)
57 |
58 | arma_model.build_statespace_graph()
59 | trace_1 = pm.sample(init='jitter+adapt_diag_grad', target_accept=0.9)
60 | ```
61 |
62 | After you place priors over the requested parameters, call `arma_model.build_statespace_graph()` to automatically put everything together into a state space system. If you are missing any parameters, you will get an error letting you know what else needs to be defined.
63 |
64 | And that's it! After this, you can sample the PyMC model as normal.
65 |
66 |
67 | ## Creating your own state space model
68 |
69 | Creating a custom state space model isn't complicated. Once again, the API follows the Statsmodels implementation. All models need to subclass the `PyMCStateSpace` class, and pass three values into the class's super constructor: `data` (from which p is inferred), `k_states` (this is "m" in the shapes above), and `k_posdef` (this is "r" above). The user also needs to declare any required state space matrices. Here is an example of a simple local linear trend model:
70 |
71 | ```python
72 | def __init__(self, data):
73 | # Model shapes
74 | k_states = k_posdef = 2
75 |
76 | super().__init__(data, k_states, k_posdef)
77 |
78 | # Initialize the matrices
79 | self.ssm['design'] = np.array([[1.0, 0.0]])
80 | self.ssm['transition'] = np.array([[1.0, 1.0],
81 | [0.0, 1.0]])
82 | self.ssm['selection'] = np.eye(k_states)
83 |
84 | self.ssm['initial_state'] = np.array([[0.0],
85 | [0.0]])
86 | self.ssm['initial_state_cov'] = np.array([[1.0, 0.0],
87 | [0.0, 1.0]])
88 | ```
89 | You will also need to set up the `param_names` class method, which returns a list of parameter names. This lets users know what priors they need to define, and lets the model gather the required parameters from the PyMC model context.
90 |
91 | ```python
92 | @property
93 | def param_names(self):
94 | return ['x0', 'P0', 'sigma_obs', 'sigma_state']
95 | ```
96 |
97 | `self.ssm` is an `PytensorRepresentation` class object that is created by the super constructor. Every model has a `self.ssm` and a `self.kalman_filter` created after the super constructor is called. All the matrices stored in `self.ssm` are Pytensor tensor variables, but numpy arrays can be passed to them for convenience. Behind the scenes, they will be converted to Pytensor tensors.
98 |
99 | Note that the names of the matrices correspond to the names listed above. They are (in the same order):
100 |
101 | - Z = design
102 | - H = obs_cov
103 | - T = transition
104 | - R = selection
105 | - Q = state_cov
106 | - c = state_intercept
107 | - d = obs_intercept
108 | - a1 = initial_state
109 | - P1 = initial_state_cov
110 |
111 | Indexing by name only will expose the entire matrix. A name can also be followed by the usual numpy slice notation to get a specific element, row, or column.
112 |
113 | The user also needs to implement an `update` method, which takes in a single Pytensor tensor as an argument. This method routes the parameters estimated by PyMC into the right spots in the state space matrices. The local level has at least two parameters to estimate: the variance of the level state innovations, and the variance of the trend state innovations. Here is the corresponding update method:
114 |
115 | ```python
116 | def update(self, theta: at.TensorVariable) -> None:
117 | # Observation covariance
118 | self.ssm['obs_cov', 0, 0] = theta[0]
119 |
120 | # State covariance
121 | self.ssm['state_cov', np.diag_indices(2)] = theta[1:]
122 | ```
123 |
124 | This function is why the order matters when flattening and concatenating the random variables inside the PyMC model. In this case, we must first pass `sigma_obs`, followed by `sigma_level`, then `sigma_trend`.
125 |
126 | But that's it! Obviously this API isn't great, and will be subject to change as the package evolves, but it should be enough to get a motivated researcher going. Happy estimation, and let me know all the bugs you find by opening an issue.
127 |
--------------------------------------------------------------------------------
/REQUIREMENTS.txt:
--------------------------------------------------------------------------------
1 | # Base dependencies
2 | pymc>=5.2
3 | pytensor
4 | numba>=0.57
5 | numpy
6 | pandas
7 | xarray
8 | matplotlib
9 | arviz
10 | statsmodels
11 |
12 | # Testing dependencies
13 | pre-commit
14 | pytest-cov>=2.5
15 | pytest>=3.0
16 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: yes
3 |
4 | coverage:
5 | precision: 2
6 | round: down
7 | range: "70...100"
8 | status:
9 | project:
10 | default:
11 | # basic
12 | target: auto
13 | threshold: 1%
14 | base: auto
15 | patch:
16 | default:
17 | # basic
18 | target: 50%
19 | threshold: 1%
20 | base: auto
21 |
22 | ignore:
23 | - "pymc_statespace/tests/*"
24 | - "pymc_statespace/examples/*"
25 | - "pymc_statespace/data/*"
26 |
27 | comment:
28 | layout: "reach, diff, flags, files"
29 | behavior: default
30 | require_changes: false # if true: only post the comment if coverage changes
31 | require_base: no # [yes :: must have a base report to post]
32 | require_head: yes # [yes :: must have a head report to post]
33 | branches: null # branch names that can post comment
34 |
--------------------------------------------------------------------------------
/conda_envs/pymc_statespace.yml:
--------------------------------------------------------------------------------
1 | name: pymc-statespace
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | # Base dependencies
6 | - pymc
7 | - pytensor
8 | - numba::numba>=0.57
9 | - numpy
10 | - pandas
11 | - xarray
12 | - matplotlib
13 | - arviz
14 | - statsmodels
15 | # Testing dependencies
16 | - pre-commit
17 | - pytest-cov>=2.5
18 | - pytest>=3.0
19 | - pytest-env
20 |
--------------------------------------------------------------------------------
/data/nile.csv:
--------------------------------------------------------------------------------
1 | "x"
2 | 1120
3 | 1160
4 | 963
5 | 1210
6 | 1160
7 | 1160
8 | 813
9 | 1230
10 | 1370
11 | 1140
12 | 995
13 | 935
14 | 1110
15 | 994
16 | 1020
17 | 960
18 | 1180
19 | 799
20 | 958
21 | 1140
22 | 1100
23 | 1210
24 | 1150
25 | 1250
26 | 1260
27 | 1220
28 | 1030
29 | 1100
30 | 774
31 | 840
32 | 874
33 | 694
34 | 940
35 | 833
36 | 701
37 | 916
38 | 692
39 | 1020
40 | 1050
41 | 969
42 | 831
43 | 726
44 | 456
45 | 824
46 | 702
47 | 1120
48 | 1100
49 | 832
50 | 764
51 | 821
52 | 768
53 | 845
54 | 864
55 | 862
56 | 698
57 | 845
58 | 744
59 | 796
60 | 1040
61 | 759
62 | 781
63 | 865
64 | 845
65 | 944
66 | 984
67 | 897
68 | 822
69 | 1010
70 | 771
71 | 676
72 | 649
73 | 846
74 | 812
75 | 742
76 | 801
77 | 1040
78 | 860
79 | 874
80 | 848
81 | 890
82 | 744
83 | 749
84 | 838
85 | 1050
86 | 918
87 | 986
88 | 797
89 | 923
90 | 975
91 | 815
92 | 1020
93 | 906
94 | 901
95 | 1170
96 | 912
97 | 746
98 | 919
99 | 718
100 | 714
101 | 740
102 |
--------------------------------------------------------------------------------
/pymc_statespace/__init__.py:
--------------------------------------------------------------------------------
1 | from pymc_statespace.models.local_level import BayesianLocalLevel
2 | from pymc_statespace.models.SARIMAX import BayesianARMA
3 | from pymc_statespace.models.VARMAX import BayesianVARMAX
4 |
5 | __all__ = ["BayesianLocalLevel", "BayesianARMA", "BayesianVARMAX"]
6 |
7 | __version__ = "0.0.1"
8 |
--------------------------------------------------------------------------------
/pymc_statespace/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/pymc_statespace/core/__init__.py
--------------------------------------------------------------------------------
/pymc_statespace/core/representation.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from typing import List, Optional, Tuple, Type, Union
3 |
4 | import numpy as np
5 | import pandas.core.tools.datetimes
6 | import pytensor.tensor as pt
7 | from pandas import DataFrame
8 | from pytensor.tensor import TensorVariable
9 |
10 | KeyLike = Union[Tuple[Union[str, int]], str]
11 |
12 |
def _preprocess_data(data: Union[DataFrame, np.ndarray], expected_dims: int = 3) -> np.ndarray:
    """
    Coerce user-provided data to a numpy array of rank ``expected_dims``.

    Parameters
    ----------
    data: DataFrame or ndarray
        Observed data. A pandas DataFrame is converted to its underlying numpy
        values; a numpy array is used as-is.
    expected_dims: int
        Minimum number of dimensions the returned array should have. Trailing
        singleton dimensions are appended until this rank is reached.

    Returns
    -------
    np.ndarray
        The data with at least ``expected_dims`` dimensions.

    Raises
    ------
    ValueError
        If ``data`` is neither a pandas DataFrame nor a numpy array.
    """
    # Use the directly-imported DataFrame name rather than relying on the
    # `pandas` attribute chain being populated by an unrelated submodule import.
    if isinstance(data, DataFrame):
        data = data.values
    elif not isinstance(data, np.ndarray):
        raise ValueError("Expected pandas Dataframe or numpy array as data")

    # Append singleton dimensions on the right until the requested rank is
    # reached (equivalent to the previous reduce/expand_dims trick).
    while data.ndim < expected_dims:
        data = np.expand_dims(data, -1)

    return data
25 |
26 |
class PytensorRepresentation:
    """
    A representation of a linear state space model, in Pytensor.

    Matrices are stored as Pytensor tensors with a trailing "time" dimension:
    size 1 when the matrix is static, or n_obs when it is time varying.
    Matrices are read and written with name-first indexing, e.g.
    ``ssm["transition", 0, 0] = 0.5``; the dummy time dimension is hidden from
    the user when the matrix is static.
    """

    def __init__(
        self,
        data: Union[DataFrame, np.ndarray],
        k_states: int,
        k_posdef: int,
        design: Optional[np.ndarray] = None,
        obs_intercept: Optional[np.ndarray] = None,
        obs_cov=None,
        transition=None,
        state_intercept=None,
        selection=None,
        state_cov=None,
        initial_state=None,
        initial_state_cov=None,
    ) -> None:
        """
        A representation of a State Space model, in Pytensor. Shamelessly copied from the Statsmodels.api
        implementation found here:

        https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py

        Parameters
        ----------
        data: ArrayLike
            Array of observed data (called exog in statsmodels)
        k_states: int
            Number of hidden states
        k_posdef: int
            Number of states that have exogenous shocks; also the rank of the selection matrix R.
        design: ArrayLike, optional
            Design matrix, denoted 'Z' in [1].
        obs_intercept: ArrayLike, optional
            Constant vector in the observation equation, denoted 'd' in [1]. Currently
            not used.
        obs_cov: ArrayLike, optional
            Covariance matrix for multivariate-normal errors in the observation equation. Denoted 'H' in
            [1].
        transition: ArrayLike, optional
            Transition equation that updates the hidden state between time-steps. Denoted 'T' in [1].
        state_intercept: ArrayLike, optional
            Constant vector for the observation equation, denoted 'c' in [1]. Currently not used.
        selection: ArrayLike, optional
            Selection matrix that matches shocks to hidden states, denoted 'R' in [1]. This is the identity
            matrix when k_posdef = k_states.
        state_cov: ArrayLike, optional
            Covariance matrix for state equations, denoted 'Q' in [1]. Null matrix when there is no observation
            noise.
        initial_state: ArrayLike, optional
            Experimental setting to allow for Bayesian estimation of the initial state, denoted `alpha_0` in [1].
            It should potentially be removed in favor of the closed-form diffuse initialization.
        initial_state_cov: ArrayLike, optional
            Experimental setting to allow for Bayesian estimation of the initial state, denoted `P_0` in [1].
            It should potentially be removed in favor of the closed-form diffuse initialization.

        References
        ----------
        .. [1] Durbin, James, and Siem Jan Koopman. 2012.
            Time Series Analysis by State Space Methods: Second Edition.
            Oxford University Press.
        """
        self.data = _preprocess_data(data)
        self.k_states = k_states
        self.k_posdef = k_posdef if k_posdef is not None else k_states

        # Use the preprocessed (always >= 2d) data so that 1d inputs, which are
        # promoted by _preprocess_data, can also be unpacked safely.
        self.n_obs, self.k_endog, *_ = self.data.shape

        # The last dimension is for time varying matrices; it could be n_obs. Not thinking about that now.
        self.shapes = {
            "data": (self.k_endog, self.n_obs, 1),
            "design": (self.k_endog, self.k_states, 1),
            "obs_intercept": (self.k_endog, 1, 1),
            "obs_cov": (self.k_endog, self.k_endog, 1),
            "transition": (self.k_states, self.k_states, 1),
            "state_intercept": (self.k_states, 1, 1),
            "selection": (self.k_states, self.k_posdef, 1),
            "state_cov": (self.k_posdef, self.k_posdef, 1),
            "initial_state": (self.k_states, 1, 1),
            "initial_state_cov": (self.k_states, self.k_states, 1),
        }

        # Initialize the representation matrices from the constructor arguments
        # of the same names; `locals()` is used to look arguments up by name.
        scope = locals()
        for name, shape in self.shapes.items():
            if name == "data":
                continue

            if scope[name] is not None:
                # A value was provided: validate it and convert to a Pytensor tensor
                matrix = self._numpy_to_pytensor(name, scope[name])
                setattr(self, name, matrix)
            else:
                # Default every unspecified matrix to zeros of the right shape
                setattr(self, name, pt.zeros(shape))

    def _validate_key(self, key: KeyLike) -> None:
        """Raise IndexError if ``key`` is not a known state space matrix name."""
        if key not in self.shapes:
            raise IndexError(f"{key} is an invalid state space matrix name")

    def update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.TensorType]) -> None:
        """
        Record the shape of matrix ``key`` after an assignment.

        Shapes stored in ``self.shapes`` are always 3-tuples: a dummy time
        dimension of size 1 is appended when a static (2d) value is stored.
        """
        # TODO: Get rid of these evals
        if isinstance(value, (pt.TensorConstant, pt.TensorVariable)):
            # Normalize to a tuple: .eval() returns an ndarray, and
            # ndarray + (1,) would be arithmetic, not concatenation.
            shape = tuple(value.shape.eval())
        else:
            shape = tuple(value.shape)

        old_shape = self.shapes[key]
        if not all(a == b for a, b in zip(shape[:2], old_shape[:2])):
            raise ValueError(
                f"The first two dimensions of {key} must be {old_shape[:2]}, found {shape[:2]}"
            )

        # Add time dimension dummy if none present. (Previously the padded
        # shape was immediately overwritten by the unpadded one, so 2d
        # assignments stored 2-tuples and broke the shapes[key][-1] checks.)
        if len(shape) == 2:
            shape = shape + (1,)

        self.shapes[key] = shape

    def _add_time_dim_to_slice(
        self, name: str, slice_: Union[List[int], Tuple[int]], n_dim: int
    ) -> Tuple[int]:
        """
        Pad a user-provided slice so it indexes the underlying 3d tensor.

        For static matrices (dummy time dim of size 1) the time dimension is
        sliced away, so users can index as if the matrix were 2d.
        """
        no_time_dim = self.shapes[name][-1] == 1

        # Case 1: All dimensions are sliced
        if len(slice_) == n_dim:
            return slice_

        # Case 2a: There is a time dim. Just return.
        if not no_time_dim:
            return slice_

        # Case 2b: There's no time dim. Slice away the dummy dim.
        if len(slice_) < n_dim:
            empty_slice = (slice(None, None, None),)
            n_omitted = n_dim - len(slice_) - 1
            return tuple(slice_) + empty_slice * n_omitted + (0,)

    @staticmethod
    def _validate_key_and_get_type(key: KeyLike) -> Type[str]:
        """Check that a tuple key starts with a matrix name; return the key's type."""
        if isinstance(key, tuple) and not isinstance(key[0], str):
            raise IndexError("First index must be the name of a valid state space matrix.")

        return type(key)

    def _validate_matrix_shape(self, name: str, X: np.ndarray) -> None:
        """
        Check that array ``X`` is a legal value for matrix ``name``.

        2d arrays must match the declared (row, col) shape exactly; 3d arrays
        must match in the first two dimensions and have a time dimension equal
        to the number of observations.
        """
        *expected_shape, time_dim = self.shapes[name]
        expected_shape = tuple(expected_shape)

        if X.ndim > 3 or X.ndim < 2:
            raise ValueError(
                f"Array provided for {name} has {X.ndim} dimensions, "
                f"expecting 2 (static) or 3 (time-varying)"
            )

        if X.ndim == 2:
            if expected_shape != X.shape:
                raise ValueError(
                    f"Array provided for {name} has shape {X.shape}, expected {expected_shape}"
                )
        if X.ndim == 3:
            if X.shape[:2] != expected_shape:
                raise ValueError(
                    f"First two dimensions of array provided for {name} has shape {X.shape[:2]}, "
                    f"expected {expected_shape}"
                )
            if X.shape[-1] != self.data.shape[0]:
                raise ValueError(
                    f"Last dimension (time dimension) of array provided for {name} has shape "
                    f"{X.shape[-1]}, expected {self.data.shape[0]} (equal to the first dimension of the "
                    f"provided data)"
                )

    def _numpy_to_pytensor(self, name: str, X: np.ndarray) -> pt.TensorVariable:
        """Validate a numpy array and convert it to a named Pytensor tensor (always 3d)."""
        X = X.copy()
        self._validate_matrix_shape(name, X)
        # Add a time dimension if one isn't provided
        if X.ndim == 2:
            X = X[..., None]
        return pt.as_tensor(X, name=name)

    def __getitem__(self, key: KeyLike) -> pt.TensorVariable:
        """
        Retrieve a matrix by name, or a slice of a matrix by (name, *indices).

        Static matrices are returned without their dummy time dimension.
        """
        _type = self._validate_key_and_get_type(key)

        # Case 1: user asked for an entire matrix by name
        if _type is str:
            self._validate_key(key)
            matrix = getattr(self, key)

            # Slice away the time dimension if it's a dummy
            if self.shapes[key][-1] == 1:
                return matrix[(slice(None),) * (matrix.ndim - 1) + (0,)]

            # If it's time varying, return everything
            else:
                return matrix

        # Case 2: user asked for a particular matrix and some slices of it
        elif _type is tuple:
            name, *slice_ = key
            self._validate_key(name)

            matrix = getattr(self, name)
            slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim)

            return matrix[slice_]

        # Case 3: There is only one slice index, but it's not a string
        else:
            raise IndexError("First index must be the name of a valid state space matrix.")

    def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray]) -> None:
        """
        Assign an entire matrix (by name) or a sub-block of one (by (name, *indices)).

        Sub-block assignment is symbolic: it rebuilds the stored tensor with
        ``pt.set_subtensor`` rather than mutating in place.
        """
        _type = type(key)
        # Case 1: key is a string: we are setting an entire matrix.
        if _type is str:
            self._validate_key(key)
            if isinstance(value, np.ndarray):
                value = self._numpy_to_pytensor(key, value)
            setattr(self, key, value)
            self.update_shape(key, value)

        # Case 2: key is a string plus a slice: we are setting a subset of a matrix
        elif _type is tuple:
            name, *slice_ = key
            self._validate_key(name)

            matrix = getattr(self, name)

            slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim)

            matrix = pt.set_subtensor(matrix[slice_], value)
            setattr(self, name, matrix)
--------------------------------------------------------------------------------
/pymc_statespace/filters/__init__.py:
--------------------------------------------------------------------------------
1 | from pymc_statespace.filters.kalman_filter import (
2 | CholeskyFilter,
3 | SingleTimeseriesFilter,
4 | StandardFilter,
5 | SteadyStateFilter,
6 | UnivariateFilter,
7 | )
8 | from pymc_statespace.filters.kalman_smoother import KalmanSmoother
9 |
10 | __all__ = [
11 | "StandardFilter",
12 | "UnivariateFilter",
13 | "SteadyStateFilter",
14 | "KalmanSmoother",
15 | "SingleTimeseriesFilter",
16 | "CholeskyFilter",
17 | ]
18 |
--------------------------------------------------------------------------------
/pymc_statespace/filters/kalman_smoother.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytensor
4 | import pytensor.tensor as pt
5 | from pytensor.compile import get_mode
6 | from pytensor.tensor.nlinalg import matrix_dot
7 |
8 | from pymc_statespace.filters.utilities import split_vars_into_seq_and_nonseq
9 |
10 |
class KalmanSmoother:
    """
    Fixed-interval Kalman smoother, implemented as a backward Pytensor scan.

    Given filtered state means and covariances from a Kalman filter, together
    with the transition (T), selection (R), and state covariance (Q) matrices,
    `build_graph` produces smoothed states and covariances conditioned on the
    full sample.
    """

    def __init__(self, mode: Optional[str] = None):
        # Pytensor compile mode, resolved via `get_mode` when the scan is built.
        self.mode = mode
        # Names among ["T", "R", "Q"] that are time varying (scan sequences)
        # vs. static (scan non-sequences); populated by `build_graph`.
        self.seq_names = []
        self.non_seq_names = []

    def unpack_args(self, args):
        """
        The order of inputs to the inner scan function is not known, since some, all, or none of the input matrices
        can be time varying. The order arguments are fed to the inner function is sequences, outputs_info,
        non-sequences. This function works out which matrices are where, and returns a standardized order expected
        by the kalman_step function.

        The standard order is: a, P, a_smooth, P_smooth, T, R, Q
        """
        # If there are no sequence parameters (all params are static),
        # no changes are needed, params will be in order.
        args = list(args)
        n_seq = len(self.seq_names)
        if n_seq == 0:
            return args

        # The first two args are always a and P
        a = args.pop(0)
        P = args.pop(0)

        # There are always two outputs_info wedged between the seqs and non_seqs
        seqs, (a_smooth, P_smooth), non_seqs = (
            args[:n_seq],
            args[n_seq : n_seq + 2],
            args[n_seq + 2 :],
        )
        # Re-assemble T, R, Q in canonical order, pulling each from whichever
        # group (sequences or non-sequences) it was routed to.
        return_ordered = []
        for name in ["T", "R", "Q"]:
            if name in self.seq_names:
                idx = self.seq_names.index(name)
                return_ordered.append(seqs[idx])
            else:
                idx = self.non_seq_names.index(name)
                return_ordered.append(non_seqs[idx])

        T, R, Q = return_ordered

        return a, P, a_smooth, P_smooth, T, R, Q

    def build_graph(self, T, R, Q, filtered_states, filtered_covariances, mode=None):
        """
        Construct the symbolic graph of the backward smoothing recursion.

        Parameters
        ----------
        T, R, Q: TensorVariable
            Transition, selection, and state covariance matrices. Each may be
            static (2d) or time varying (3d).
        filtered_states, filtered_covariances: TensorVariable
            Outputs of the forward Kalman filter pass, time-major.
        mode: str, optional
            Pytensor compile mode; overwrites the mode passed to the constructor.

        Returns
        -------
        (smoothed_states, smoothed_covariances)
            Smoothed moments in forward time order, same lengths as the inputs.
        """
        self.mode = mode

        # The last filtered values are their own smoothed values; they seed the
        # backward recursion.
        a_last = filtered_states[-1]
        P_last = filtered_covariances[-1]

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
            [T, R, Q], ["T", "R", "Q"]
        )

        self.seq_names = seq_names
        self.non_seq_names = non_seq_names

        smoother_result, updates = pytensor.scan(
            self.smoother_step,
            sequences=[filtered_states[:-1], filtered_covariances[:-1]] + sequences,
            outputs_info=[a_last, P_last],
            non_sequences=non_sequences,
            go_backwards=True,
            name="kalman_smoother",
            mode=get_mode(self.mode),
        )

        # Scan ran backwards: flip the results and append the terminal values
        # so the output is in forward time order.
        smoothed_states, smoothed_covariances = smoother_result
        smoothed_states = pt.concatenate([smoothed_states[::-1], pt.atleast_3d(a_last)], axis=0)
        smoothed_covariances = pt.concatenate(
            [smoothed_covariances[::-1], pt.atleast_3d(P_last)], axis=0
        )

        return smoothed_states, smoothed_covariances

    def smoother_step(self, *args):
        """
        One step of the backward recursion: combine the filtered moments at time t
        with the smoothed moments at time t+1 to produce smoothed moments at t.
        """
        a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)
        a_hat, P_hat = self.predict(a, P, T, R, Q)

        # Use pinv, otherwise P_hat is singular when there is missing data
        smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T
        a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

        P_smooth_next = P + matrix_dot(smoother_gain, P_smooth - P_hat, smoother_gain.T)

        return a_smooth_next, P_smooth_next

    @staticmethod
    def predict(a, P, T, R, Q):
        """One-step-ahead prediction of the state mean and covariance."""
        a_hat = T.dot(a)
        P_hat = matrix_dot(T, P, T.T) + matrix_dot(R, Q, R.T)

        return a_hat, P_hat
--------------------------------------------------------------------------------
/pymc_statespace/filters/utilities.py:
--------------------------------------------------------------------------------
def split_vars_into_seq_and_nonseq(params, param_names):
    """
    Split inputs into those that are time varying and those that are not. This division is required by scan.

    Time varying matrices are 3d (with a trailing time dimension) and must be
    passed to scan as `sequences`; static matrices are 2d and are passed as
    `non_sequences`.

    Parameters
    ----------
    params: list of TensorVariable
        State space matrices to classify.
    param_names: list of str
        Name of each entry of `params`, in the same order.

    Returns
    -------
    (sequences, non_sequences, seq_names, non_seq_names)
        The 3d matrices, the 2d matrices, and their respective names.

    Raises
    ------
    ValueError
        If any matrix is neither 2d nor 3d.
    """
    sequences, non_sequences = [], []
    seq_names, non_seq_names = [], []

    for param, name in zip(params, param_names):
        if param.ndim == 2:
            non_sequences.append(param)
            non_seq_names.append(name)
        elif param.ndim == 3:
            sequences.append(param)
            seq_names.append(name)
        else:
            # Original message dropped words ("has {n}, it should either 2");
            # make the error message grammatical and actionable.
            raise ValueError(
                f"Matrix {name} has {param.ndim} dimensions; it should have either 2 "
                f"(static) or 3 (time varying)."
            )

    return sequences, non_sequences, seq_names, non_seq_names
21 |
--------------------------------------------------------------------------------
/pymc_statespace/models/SARIMAX.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import pytensor.tensor as at
5 |
6 | from pymc_statespace.core.statespace import PyMCStateSpace
7 | from pymc_statespace.utils.pytensor_scipy import solve_discrete_lyapunov
8 |
9 |
class BayesianARMA(PyMCStateSpace):
    """
    An ARMA(p, q) model in state space form, for Bayesian estimation with PyMC.

    The model uses k_states = max(p, q + 1) hidden states with a single shock
    (k_posdef = 1). AR coefficients are placed in the first column of the
    transition matrix and MA coefficients in rows 1..q of the selection matrix.
    """

    def __init__(
        self,
        data,
        order: Tuple[int, int],
        stationary_initialization: bool = True,
        filter_type: str = "standard",
    ):
        """
        Parameters
        ----------
        data: ArrayLike
            Observed time series. NOTE(review): examples pass a 2d (n, 1)
            array — confirm against PyMCStateSpace's expectations.
        order: tuple of int
            (p, q) — number of autoregressive and moving average terms.
        stationary_initialization: bool
            If True, the initial state covariance P0 is computed in `update` by
            solving a discrete Lyapunov equation, and "P0" is dropped from
            `param_names` so the user does not place a prior on it.
        filter_type: str
            Name of the Kalman filter implementation to use.
        """

        # Model order
        self.p, self.q = order

        self.stationary_initialization = stationary_initialization

        k_states = max(self.p, self.q + 1)
        k_posdef = 1

        super().__init__(data, k_states, k_posdef, filter_type)

        # Initialize the matrices
        # Design row selects the first state as the observed series
        self.ssm["design"] = np.r_[[1.0], np.zeros(k_states - 1)][None]

        # Shifted identity; AR coefficients are written into column 0 by `update`
        self.ssm["transition"] = np.eye(k_states, k=1)

        # First row is 1 (the shock enters the first state directly);
        # MA coefficients are written into the following rows by `update`
        self.ssm["selection"] = np.r_[[[1.0]], np.zeros(k_states - 1)[:, None]]

        self.ssm["initial_state"] = np.zeros(k_states)[:, None]

        self.ssm["initial_state_cov"] = np.eye(k_states)

        # Cache some indices
        # Diagonal of the (1, 1) state covariance matrix
        self._state_cov_idx = ("state_cov",) + np.diag_indices(k_posdef)
        # First column of the transition matrix, rows 0..p-1 (AR coefficients)
        self._ar_param_idx = ("transition",) + (
            np.arange(self.p, dtype=int),
            np.zeros(self.p, dtype=int),
        )
        # Column 0 of the selection matrix, rows 1..q (MA coefficients)
        self._ma_param_idx = ("selection",) + (
            np.arange(1, self.q + 1, dtype=int),
            np.zeros(self.q, dtype=int),
        )

    @property
    def param_names(self):
        """Names of the random variables that must be defined in the PyMC model."""
        names = ["x0", "P0", "sigma_state", "rho", "theta"]
        if self.stationary_initialization:
            # P0 is computed from the other parameters; no prior needed
            names.remove("P0")

        return names

    def update(self, theta: at.TensorVariable) -> None:
        """
        Put parameter values from vector theta into the correct positions in the state space matrices.
        TODO: Can this be done using variable names to avoid the need to ravel and concatenate all RVs in the
        PyMC model?

        Parameters
        ----------
        theta: TensorVariable
            Vector of all variables in the state space model
        """
        # `theta` is consumed in the order of `param_names`; `cursor` tracks
        # how much of the vector has been used so far.
        cursor = 0

        # initial states
        param_slice = slice(cursor, cursor + self.k_states)
        cursor += self.k_states
        self.ssm["initial_state", :, 0] = theta[param_slice]

        if not self.stationary_initialization:
            # initial covariance
            param_slice = slice(cursor, cursor + self.k_states**2)
            cursor += self.k_states**2
            self.ssm["initial_state_cov", :, :] = theta[param_slice].reshape(
                (self.k_states, self.k_states)
            )

        # State covariance
        param_slice = slice(cursor, cursor + 1)
        cursor += 1
        self.ssm[self._state_cov_idx] = theta[param_slice]

        # AR parameters
        param_slice = slice(cursor, cursor + self.p)
        cursor += self.p
        self.ssm[self._ar_param_idx] = theta[param_slice]

        # MA parameters
        param_slice = slice(cursor, cursor + self.q)
        cursor += self.q
        self.ssm[self._ma_param_idx] = theta[param_slice]

        if self.stationary_initialization:
            # Solve for matrix quadratic for P0
            T = self.ssm["transition"]
            R = self.ssm["selection"]
            Q = self.ssm["state_cov"]

            P0 = solve_discrete_lyapunov(T, at.linalg.matrix_dot(R, Q, R.T), method="bilinear")
            self.ssm["initial_state_cov", :, :] = P0
--------------------------------------------------------------------------------
/pymc_statespace/models/VARMAX.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import pytensor.tensor as at
5 |
6 | from pymc_statespace.core.statespace import PyMCStateSpace
7 | from pymc_statespace.models.utilities import get_slice_and_move_cursor
8 | from pymc_statespace.utils.pytensor_scipy import solve_discrete_lyapunov
9 |
10 |
11 | class BayesianVARMAX(PyMCStateSpace):
    def __init__(
        self,
        data,
        order: Tuple[int, int],
        stationary_initialization: bool = True,
        filter_type: str = "standard",
        measurement_error: bool = True,
        verbose: bool = True,
    ):
        """
        A VARMA(p, q) model of multiple time series in state space form.

        Parameters
        ----------
        data: ArrayLike
            Observed data, shape (n_obs, k_obs) — the number of observed
            series is read from data.shape[1].
        order: tuple of int
            (p, q) — number of autoregressive and moving average lags.
        stationary_initialization: bool
            If True, P0 is not estimated (its parameter count is zero) and
            "P0" is dropped from `param_names`.
        filter_type: str
            Name of the Kalman filter implementation to use.
        measurement_error: bool
            If True, a diagonal observation covariance matrix is estimated.
        verbose: bool
            Passed through to the PyMCStateSpace constructor.
        """

        self.p, self.q = order
        self.stationary_initialization = stationary_initialization
        self.measurement_error = measurement_error

        # The state vector stacks k_obs series over max(p, 1) AR lags plus q MA lags
        k_order = max(self.p, 1) + self.q
        k_obs = data.shape[1]
        k_states = k_obs * k_order
        k_posdef = data.shape[1]

        super().__init__(data, k_states, k_posdef, filter_type, verbose=verbose)

        # Save counts of the number of parameters in each category
        # (used by `update` to slice the flat theta vector)
        self.param_counts = {
            "x0": k_states,
            "P0": k_states**2 * (1 - self.stationary_initialization),
            "AR": k_obs**2 * self.p,
            "MA": k_obs**2 * self.q,
            "state_cov": k_posdef**2,
            "obs_cov": k_obs * self.measurement_error,
        }

        # Initialize the matrices
        # Design matrix is a truncated identity (first k_obs states observed)
        self.ssm[("design",) + np.diag_indices(k_obs)] = 1

        # Transition matrix has 4 blocks:
        self.ssm["transition"] = np.zeros((k_states, k_states))

        # UL: AR coefs (k_obs, k_obs * min(p, 1))
        # UR: MA coefs (k_obs, k_obs * q)
        # LL: Truncated identity (k_obs * min(p, 1), k_obs * min(p, 1))
        # LR: Shifted identity (k_obs * p, k_obs * q)
        if self.p > 1:
            idx = (slice(k_obs, k_obs * self.p), slice(0, k_obs * (self.p - 1)))
            self.ssm[("transition",) + idx] = np.eye(k_obs * (self.p - 1))

        if self.q > 1:
            idx = (slice(-k_obs * (self.q - 1), None), slice(-k_obs * self.q, -k_obs))
            self.ssm[("transition",) + idx] = np.eye(k_obs * (self.q - 1))

        # The selection matrix is (k_states, k_obs), with two (k_obs, k_obs) identity
        # matrix blocks inside. One is always on top, the other starts after (k_obs * p) rows
        self.ssm["selection"] = np.zeros((k_states, k_obs))
        self.ssm["selection", slice(0, k_obs), :] = np.eye(k_obs)
        if self.q > 0:
            end = -k_obs * (self.q - 1) if self.q > 1 else None
            self.ssm["selection", slice(k_obs * -self.q, end), :] = np.eye(k_obs)

        # self.ssm["initial_state"] = np.zeros(k_states)[:, None]
        # self.ssm["initial_state_cov"] = np.eye(k_states)
        # self.ssm["state_cov"] = np.eye(k_posdef)
        #
        # if self.measurement_error:
        #     self.ssm['obs_cov'] = np.eye(k_obs)

        # Cache some indices
        # Top rows of the transition matrix: AR blocks, then MA blocks
        self._ar_param_idx = ("transition", slice(0, k_obs), slice(0, k_obs * self.p))
        self._ma_param_idx = ("transition", slice(0, k_obs), slice(k_obs * max(1, self.p), None))
        # Diagonal of the observation covariance matrix
        self._obs_cov_idx = ("obs_cov",) + np.diag_indices(k_obs)
81 |
@property
def param_names(self):
    """Names of the parameter blocks ``update`` expects, in consumption order.

    Blocks are dropped when the model configuration makes them unused:
    P0 under stationary initialization, obs_cov without measurement error,
    and the AR/MA blocks when the corresponding order is zero.
    """
    excluded = set()
    if self.stationary_initialization:
        excluded.add("P0")
    if not self.measurement_error:
        excluded.add("obs_cov")
    if self.p == 0:
        excluded.add("ar_params")
    if self.q == 0:
        excluded.add("ma_params")

    all_names = ["x0", "P0", "ar_params", "ma_params", "state_cov", "obs_cov"]
    return [name for name in all_names if name not in excluded]
94 |
    def update(self, theta: at.TensorVariable) -> None:
        """
        Put parameter values from vector theta into the correct positions in the state space matrices.

        Parameters
        ----------
        theta: TensorVariable
            Vector of all variables in the state space model

        Notes
        -----
        The layout of ``theta`` must match ``param_names`` / ``self.param_counts``:
        blocks are consumed left-to-right by advancing ``cursor``, so callers must
        concatenate their parameter vectors in exactly that order.
        """

        cursor = 0
        # initial states
        param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["x0"])
        self.ssm["initial_state", :, 0] = theta[param_slice]

        if not self.stationary_initialization:
            # initial covariance
            # P0 is only a free parameter when not stationary; under stationary
            # initialization it is solved for at the bottom of this method.
            param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["P0"])
            self.ssm["initial_state_cov", :, :] = theta[param_slice].reshape(
                (self.k_states, self.k_states)
            )

        # AR parameters
        # Written row-major into the top-left block of the transition matrix
        # (see self._ar_param_idx, cached in __init__).
        if self.p > 0:
            ar_shape = (self.k_endog, self.k_endog * self.p)
            param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["AR"])
            self.ssm[self._ar_param_idx] = theta[param_slice].reshape(ar_shape)

        # MA parameters
        # Written into the top-right block of the transition matrix.
        if self.q > 0:
            ma_shape = (self.k_endog, self.k_endog * self.q)
            param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["MA"])
            self.ssm[self._ma_param_idx] = theta[param_slice].reshape(ma_shape)

        # State covariance
        # When there is no measurement error this is the final block, so the
        # slice is left open-ended to absorb the rest of theta.
        param_slice, cursor = get_slice_and_move_cursor(
            cursor, self.param_counts["state_cov"], last_slice=not self.measurement_error
        )

        self.ssm["state_cov", :, :] = theta[param_slice].reshape((self.k_posdef, self.k_posdef))

        # Measurement error
        # Only the diagonal of obs_cov is parameterized (see self._obs_cov_idx).
        if self.measurement_error:
            param_slice, cursor = get_slice_and_move_cursor(
                cursor, self.param_counts["obs_cov"], last_slice=True
            )
            self.ssm[self._obs_cov_idx] = theta[param_slice]

        if self.stationary_initialization:
            # Solve for matrix quadratic for P0
            # Stationary covariance solves the discrete Lyapunov equation
            # P0 = T P0 T' + R Q R'.
            T = self.ssm["transition"]
            R = self.ssm["selection"]
            Q = self.ssm["state_cov"]

            P0 = solve_discrete_lyapunov(T, at.linalg.matrix_dot(R, Q, R.T), method="bilinear")
            self.ssm["initial_state_cov", :, :] = P0
151 |
--------------------------------------------------------------------------------
/pymc_statespace/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/pymc_statespace/models/__init__.py
--------------------------------------------------------------------------------
/pymc_statespace/models/local_level.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytensor.tensor as at
3 |
4 | from pymc_statespace.core.statespace import PyMCStateSpace
5 |
6 |
class BayesianLocalLevel(PyMCStateSpace):
    """Two-state structural time series model for a univariate series.

    The design vector [1, 0] observes the first state; the transition matrix
    [[1, 1], [0, 1]] makes the second state accumulate into the first each
    period. Both states receive innovations (selection is the identity).
    """

    def __init__(self, data):
        k_states = k_posdef = 2

        super().__init__(data, k_states, k_posdef)

        # Fixed system matrices for this model.
        self.ssm["design"] = np.array([[1.0, 0.0]])
        self.ssm["transition"] = np.array([[1.0, 1.0], [0.0, 1.0]])
        self.ssm["selection"] = np.eye(k_states)

        # Default initial conditions; overwritten on every call to ``update``.
        self.ssm["initial_state"] = np.zeros((k_states, 1))
        self.ssm["initial_state_cov"] = np.eye(k_states)

        # Diagonal indices of the state covariance, cached for ``update``.
        self._state_cov_idx = ("state_cov",) + np.diag_indices(k_posdef)

    @property
    def param_names(self):
        """Parameter blocks expected in ``theta``, in consumption order."""
        return ["x0", "P0", "sigma_obs", "sigma_state"]

    def update(self, theta: at.TensorVariable) -> None:
        """
        Put parameter values from vector theta into the correct positions in the state space matrices.

        Layout of ``theta``: [0:2] initial state, [2:6] initial covariance
        (row-major 2x2), [6] observation variance, [7:] state variance diagonal.

        Parameters
        ----------
        theta: TensorVariable
            Vector of all variables in the state space model
        """
        # initial states
        self.ssm["initial_state", :, 0] = theta[:2]

        # initial covariance
        self.ssm["initial_state_cov", :, :] = theta[2:6].reshape((2, 2))

        # Observation covariance
        self.ssm["obs_cov", 0, 0] = theta[6]

        # State covariance
        self.ssm[self._state_cov_idx] = theta[7:]
50 |
--------------------------------------------------------------------------------
/pymc_statespace/models/utilities.py:
--------------------------------------------------------------------------------
def get_slice_and_move_cursor(cursor, param_count, last_slice=False):
    """Return a slice for the next parameter block and the advanced cursor.

    Parameters
    ----------
    cursor: int
        Current offset into the flat parameter vector.
    param_count: int
        Number of entries consumed by the current block.
    last_slice: bool, default False
        When True the slice is open-ended (``slice(cursor, None)``) so it
        captures everything remaining in the vector.

    Returns
    -------
    tuple of (slice, int)
        The slice to index the parameter vector with, and the cursor position
        for the next block.
    """
    stop = None if last_slice else cursor + param_count
    return slice(cursor, stop), cursor + param_count
6 |
--------------------------------------------------------------------------------
/pymc_statespace/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/pymc_statespace/utils/__init__.py
--------------------------------------------------------------------------------
/pymc_statespace/utils/numba_linalg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import njit
3 |
4 |
@njit
def numba_block_diagonal(nd_array):
    """Assemble a stack of matrices into one block-diagonal matrix.

    Parameters
    ----------
    nd_array: ndarray
        Array of shape (n, rows, cols); slice i becomes the i-th diagonal block.

    Returns
    -------
    ndarray
        Dense (n * rows, n * cols) float array with the input slices on the
        diagonal and zeros elsewhere.
    """
    n, rows, cols = nd_array.shape

    out = np.zeros((n * rows, n * cols))

    # All blocks share the same shape, so block i starts at (i * rows, i * cols);
    # no need to materialize a list of shapes and track running offsets.
    for i in range(n):
        out[i * rows : (i + 1) * rows, i * cols : (i + 1) * cols] = nd_array[i]
    return out
17 |
--------------------------------------------------------------------------------
/pymc_statespace/utils/pytensor_scipy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytensor
4 | import pytensor.tensor as at
5 | import scipy
6 | from pytensor.tensor import TensorVariable, as_tensor_variable
7 | from pytensor.tensor.nlinalg import matrix_dot
8 | from pytensor.tensor.slinalg import solve_discrete_lyapunov
9 |
10 |
class SolveDiscreteARE(at.Op):
    """Pytensor Op that solves the discrete algebraic Riccati equation (DARE)
    numerically via scipy, with an analytic reverse-mode gradient."""

    __props__ = ("enforce_Q_symmetric",)

    def __init__(self, enforce_Q_symmetric=False):
        # When True, Q is symmetrized as 0.5 * (Q + Q.T) before the solve.
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
        """Build the Apply node; the output matrix upcasts all input dtypes."""
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        Q = as_tensor_variable(Q)
        R = as_tensor_variable(R)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])

    def perform(self, node, inputs, output_storage):
        # Numeric solve: delegates to scipy's DARE solver.
        A, B, Q, R = inputs
        X = output_storage[0]

        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)
        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)

    def infer_shape(self, fgraph, node, shapes):
        # The solution X is square with the same shape as A.
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs

        (dX,) = output_grads
        X = self(A, B, Q, R)

        # K = (R + B'XB)^{-1} B'XA is the associated optimal feedback gain.
        K_inner = R + at.linalg.matrix_dot(B.T, X, B)
        K_inner_inv = at.linalg.solve(K_inner, at.eye(R.shape[0]))
        K = matrix_dot(K_inner_inv, B.T, X, A)

        # Closed-loop transition matrix.
        A_tilde = A - B.dot(K)

        # The incoming cotangent must be symmetrized before the Lyapunov solve,
        # since X itself is symmetric.
        dX_symm = 0.5 * (dX + dX.T)
        S = solve_discrete_lyapunov(A_tilde, dX_symm)

        A_bar = 2 * matrix_dot(X, A_tilde, S)
        B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
        Q_bar = S
        R_bar = matrix_dot(K, S, K.T)

        return [A_bar, B_bar, Q_bar, R_bar]
61 |
62 |
def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: ArrayLike
        Square matrix of shape M x M
    B: ArrayLike
        Matrix of shape M x N (the docstring previously said M x M; scipy's
        solver takes an (M, N) input matrix here)
    Q: ArrayLike
        Symmetric square matrix of shape M x M
    R: ArrayLike
        Square matrix of shape N x N
    enforce_Q_symmetric: bool, default False
        If True, symmetrize Q as 0.5 * (Q + Q.T) before solving; guards
        against small numerical asymmetries from upstream computation.

    Returns
    -------
    X: at.matrix
        Square matrix of shape M x M, representing the solution to the DARE
    """

    return SolveDiscreteARE(enforce_Q_symmetric=enforce_Q_symmetric)(A, B, Q, R)
84 |
--------------------------------------------------------------------------------
/pymc_statespace/utils/simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import njit
3 |
4 | from pymc_statespace.utils.numba_linalg import numba_block_diagonal
5 |
6 |
@njit
def numba_mvn_draws(mu, cov):
    # Draw one multivariate-normal sample with mean ``mu`` and covariance
    # ``cov`` via a Cholesky factor: x = mu + L @ z with z ~ N(0, I).
    # NOTE(review): assumes ``mu`` is a 1-D vector of length cov.shape[0] —
    # confirm against callers.
    samples = np.random.randn(*mu.shape)
    k = cov.shape[0]
    # A tiny random diagonal jitter keeps the Cholesky factorization from
    # failing when ``cov`` is only positive semi-definite numerically.
    jitter = np.eye(k) * np.random.uniform(1e-12, 1e-8)

    L = np.linalg.cholesky(cov + jitter)
    return mu + L @ samples
15 |
16 |
@njit
def conditional_simulation(mus, covs, n, k, n_simulations=100):
    """Draw ``n_simulations`` multivariate-normal samples per posterior draw.

    Parameters
    ----------
    mus: ndarray
        Means, one row per posterior sample; row i is assumed to be a
        flattened (n * k,) mean vector — TODO confirm against callers.
    covs: ndarray
        Per-sample covariance blocks; covs[i] is assembled by
        numba_block_diagonal into one block-diagonal covariance matrix.
    n: int
        Number of time steps each draw is reshaped to.
    k: int
        Number of series each draw is reshaped to.
    n_simulations: int
        Number of draws per posterior sample.

    Returns
    -------
    ndarray
        Simulations of shape (n_samples * n_simulations, n, k).
    """
    n_samples = mus.shape[0]
    simulations = np.empty((n_samples * n_simulations, n, k))

    for i in range(n_samples):
        # The block-diagonal covariance is loop-invariant in j: build it once
        # per posterior sample instead of once per draw.
        cov = numba_block_diagonal(covs[i])
        for j in range(n_simulations):
            sim = numba_mvn_draws(mus[i], cov)
            simulations[(i * n_simulations + j), :, :] = sim.reshape(n, k)
    return simulations
27 |
28 |
@njit
def simulate_statespace(T, Z, R, H, Q, n_steps, x0=None):
    """Simulate one trajectory from a linear-Gaussian state space model.

    State equation:       x[t] = T @ x[t-1] + R @ eta[t],  eta[t] ~ N(0, Q)
    Observation equation: y[t] = Z @ x[t] + eps[t],        eps[t] ~ N(0, H)

    Parameters
    ----------
    T: ndarray
        Transition matrix, (n_states, n_states).
    Z: ndarray
        Design matrix, (n_obs, n_states).
    R: ndarray
        Selection matrix, (n_states, k_posdef).
    H: ndarray
        Observation noise covariance; measurement noise is skipped entirely
        when H is all zeros.
    Q: ndarray
        State innovation covariance, (k_posdef, k_posdef).
    n_steps: int
        Number of periods to simulate.
    x0: ndarray, optional
        Initial state; when given, row 0 of both outputs is seeded from it
        noise-free. Otherwise the simulation starts from zeros.

    Returns
    -------
    tuple of ndarray
        (simulated_states, simulated_obs) of shapes (n_steps, n_states) and
        (n_steps, n_obs).
    """
    n_obs, n_states = Z.shape
    k_posdef = R.shape[1]
    # All-zero H means "no measurement error"; avoids factorizing a singular matrix.
    k_obs_noise = H.shape[0] * (1 - int(np.all(H == 0)))

    # BUG FIX: for iid rows z, cov(z @ L.T) = L @ L.T = Q. The previous
    # ``noise @ chol`` produced covariance L.T @ L, which is wrong whenever
    # Q (or H) has off-diagonal structure.
    state_noise = np.random.randn(n_steps, k_posdef)
    state_chol = np.linalg.cholesky(Q)
    state_innovations = state_noise @ state_chol.T

    if k_obs_noise != 0:
        obs_noise = np.random.randn(n_steps, k_obs_noise)
        obs_chol = np.linalg.cholesky(H)
        obs_innovations = obs_noise @ obs_chol.T

    simulated_states = np.zeros((n_steps, n_states))
    simulated_obs = np.zeros((n_steps, n_obs))

    if x0 is not None:
        simulated_states[0] = x0
        simulated_obs[0] = Z @ x0

    # BUG FIX: the observation at time t comes from the state at time t
    # (y[t] = Z @ x[t]); previously x[t - 1] was used, lagging every
    # observation by one period and contradicting the t = 0 seeding above.
    if k_obs_noise != 0:
        for t in range(1, n_steps):
            simulated_states[t] = T @ simulated_states[t - 1] + R @ state_innovations[t]
            simulated_obs[t] = Z @ simulated_states[t] + obs_innovations[t]
    else:
        for t in range(1, n_steps):
            simulated_states[t] = T @ simulated_states[t - 1] + R @ state_innovations[t]
            simulated_obs[t] = Z @ simulated_states[t]

    return simulated_states, simulated_obs
61 |
62 |
def unconditional_simulations(thetas, update_funcs, n_steps=100, n_simulations=100):
    """Simulate unconditional trajectories for every posterior parameter draw.

    For each of the ``samples`` draws in ``thetas``, the system matrices are
    rebuilt via ``update_funcs`` and ``n_simulations`` trajectories of length
    ``n_steps`` are generated with ``simulate_statespace``.

    Returns
    -------
    tuple of ndarray
        (states, observed) with shapes (samples * n_simulations, n_steps,
        n_states) and (samples * n_simulations, n_steps, n_obs).
    """
    samples = thetas[0].shape[0]

    # Evaluate the matrices once with the first draw just to learn the shapes.
    first_draw = [theta[0] for theta in thetas]
    _, _, T, Z, R, H, Q = [f(*first_draw)[0] for f in update_funcs]
    n_obs, n_states = Z.shape

    states = np.empty((samples * n_simulations, n_steps, n_states))
    observed = np.empty((samples * n_simulations, n_steps, n_obs))

    for draw in range(samples):
        draw_params = [theta[draw] for theta in thetas]
        _, _, T, Z, R, H, Q = [f(*draw_params)[0] for f in update_funcs]
        for rep in range(n_simulations):
            sim_state, sim_obs = simulate_statespace(T, Z, R, H, Q, n_steps=n_steps)
            row = draw * n_simulations + rep
            states[row, :, :] = sim_state
            observed[row, :, :] = sim_obs

    return states, observed
80 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | minversion = "6.0"
3 | xfail_strict=true
4 | filterwarnings = [
5 | "error",
6 | "ignore::DeprecationWarning"]
7 |
env = ["NUMBA_DISABLE_JIT=1"]
9 |
10 | [tool.isort]
11 | profile = 'black'
12 |
13 | [tool.black]
14 | line-length = 100
15 |
16 | [tool.nbqa.mutate]
17 | isort = 1
18 | black = 1
19 | pyupgrade = 1
20 |
21 | [tool.bumpver]
22 | current_version = "0.0.1"
23 | version_pattern = "MAJOR.MINOR.PATCH"
24 | commit_message = "Bump version {old_version} -> {new_version}"
25 | commit = true
26 | tag = true
27 | push = false
28 |
29 | [tool.bumpver.file_patterns]
30 | "pyproject.toml" = ['current_version = "{version}"']
31 | "setup.cfg" = ['version = {version}']
32 | "pymc_statespace/__init__.py" = ["{version}"]
33 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = pymc-statespace
3 | version = 0.0.1
4 | description = A system for Bayesian estimation of state space models
5 | long_description = file: README.md
6 | long_description_content_type = text/markdown
7 | url = https://github.com/jessegrabowski/pymc_statespace
8 | author = Jesse Grabowski
9 | author_email = jessegrabowski@gmail.com
10 |
11 | [options]
12 | packages = find:
13 | install_requires =
14 | pymc>=5.2
15 | pytensor
16 | numba>=0.57
17 | numpy
18 | pandas
19 | xarray
20 | matplotlib
21 | arviz
22 | statsmodels
23 |
24 | [options.packages.find]
25 | exclude = tests*
26 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Shim entry point for legacy tooling: all package metadata and dependencies
# are declared in setup.cfg, so setup() is called with no arguments.
from setuptools import setup

setup()
4 |
--------------------------------------------------------------------------------
/svgs/17b59c002f249204f24e31507dc4957d.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/3b5e41543d7fc8cedf98ec609b343134.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/svgs/523b266d36c270dbbb5daf2c9092ce0f.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/54221efbfb5e69569dfe8ddea785093a.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/92b8c1194757fb3131cda468a34be85f.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/a06c0e58d4d162b0e87d32927c9812db.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/svgs/a13d89295e999545a129b2d412e99f6d.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/cac7e81ebde5e530e639eae5389f149e.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/d523a14b8179ebe46f0ed16895ee46f0.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/svgs/edcff444fd5240add1c47d2de50ebd7e.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/svgs/f566e90ed17c5292db4600846e0ace27.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/svgs/ff25a8f22c7430ca572d33206c0a9176.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_VARMAX.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import warnings
4 | from itertools import product
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import pymc as pm
10 | import pytensor.tensor as pt
11 | import pytest
12 | import statsmodels.api as sm
13 | from numpy.testing import assert_allclose
14 |
15 | from pymc_statespace import BayesianVARMAX
16 |
ROOT = Path(__file__).parent.absolute()
# sys.path entries are documented to be str; append str(ROOT) rather than the
# Path object so the import machinery and path comparisons behave consistently.
sys.path.append(str(ROOT))
19 |
20 |
@pytest.fixture
def data():
    """Load the pre-processed statsmodels macro dataset shared by all tests."""
    csv_path = os.path.join(ROOT, "test_data/statsmodels_macrodata_processed.csv")
    return pd.read_csv(csv_path, index_col=0)
26 |
27 |
# Every (p, q) lag-order combination except (0, 0), which is not a valid model.
ps = [0, 1, 2, 3]
qs = [0, 1, 2, 3]
orders = [(p, q) for p in ps for q in qs if (p, q) != (0, 0)]
ids = [f"p={p}, q={q}" for p, q in orders]
32 |
33 |
@pytest.mark.parametrize("order", orders, ids=ids)
@pytest.mark.parametrize("matrix", ["transition", "selection", "state_cov", "obs_cov", "design"])
def test_VARMAX_init_matches_statsmodels(data, order, matrix):
    # The freshly-initialized system matrices should match those statsmodels
    # builds for the same (p, q) order, matrix by matrix.
    p, q = order

    mod = BayesianVARMAX(data, order=(p, q), verbose=False)
    # statsmodels emits construction-time warnings that are irrelevant here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sm_var = sm.tsa.VARMAX(data, order=(p, q))

    # mod.ssm holds symbolic pytensor matrices, hence the .eval().
    assert_allclose(mod.ssm[matrix].eval(), sm_var.ssm[matrix])
45 |
46 |
@pytest.mark.parametrize("order", orders, ids=ids)
@pytest.mark.parametrize("var", ["AR", "MA", "state_cov"])
def test_VARMAX_param_counts_match_statsmodels(data, order, var):
    """Per-block parameter counts should agree with statsmodels' parameterization."""
    p, q = order

    mod = BayesianVARMAX(data, order=(p, q), verbose=False)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sm_var = sm.tsa.VARMAX(data, order=(p, q))

    count = mod.param_counts[var]
    if var == "state_cov":
        # Statsmodels parameterizes the covariance by its lower Cholesky
        # triangle, which has k * (k + 1) / 2 free entries. The previous
        # expression k * (k - 1) only passed because the test data has k = 3,
        # where both formulas happen to equal 6.
        count = mod.k_posdef * (mod.k_posdef + 1) // 2
    assert count == sm_var.parameters[var.lower()]
62 |
63 |
@pytest.mark.parametrize("order", orders, ids=ids)
@pytest.mark.parametrize("matrix", ["transition", "selection", "state_cov", "obs_cov", "design"])
def test_VARMAX_update_matches_statsmodels(data, order, matrix):
    # Push one fixed parameter vector through both implementations and check
    # that every resulting system matrix agrees with statsmodels.
    p, q = order

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sm_var = sm.tsa.VARMAX(data, order=(p, q))

    # Recover statsmodels' parameter layout: cumulative block counts become
    # slices into its flat param_names list, one group per parameter block.
    param_counts = [None] + np.cumsum(list(sm_var.parameters.values())).tolist()
    param_slices = [slice(a, b) for a, b in zip(param_counts[:-1], param_counts[1:])]
    param_lists = [trend, ar, ma, reg, state_cov, obs_cov] = [
        sm_var.param_names[idx] for idx in param_slices
    ]
    # Random but fixed values for every statsmodels parameter; squared so that
    # variance-like parameters stay positive.
    param_d = {
        k: np.random.normal(scale=0.1) ** 2 for param_list in param_lists for k in param_list
    }

    res = sm_var.fit_constrained(param_d)

    mod = BayesianVARMAX(data, order=(p, q), verbose=False, measurement_error=False)

    with pm.Model() as pm_mod:
        x0 = pm.Deterministic("x0", pt.zeros(mod.k_states))
        ma_params = pm.Deterministic(
            "ma_params", pt.as_tensor_variable(np.array([param_d[var] for var in ma]))
        )
        ar_params = pm.Deterministic(
            "ar_params", pt.as_tensor_variable(np.array([param_d[var] for var in ar]))
        )
        # statsmodels parameterizes the lower Cholesky factor; rebuild the full
        # covariance as chol @ chol.T.
        state_chol = np.zeros((mod.k_posdef, mod.k_posdef))
        state_chol[np.tril_indices(mod.k_posdef)] = np.array([param_d[var] for var in state_cov])
        # NOTE(review): the next line deliberately rebinds ``state_cov`` from
        # the list of statsmodels parameter names (used just above) to the
        # Deterministic — intentional but fragile.
        state_cov = pm.Deterministic("state_cov", pt.as_tensor_variable(state_chol @ state_chol.T))
        mod.build_statespace_graph()

    assert_allclose(mod.ssm[matrix].eval(), sm_var.ssm[matrix])
100 |
--------------------------------------------------------------------------------
/tests/test_data/nile.csv:
--------------------------------------------------------------------------------
1 | "x"
2 | 1120
3 | 1160
4 | 963
5 | 1210
6 | 1160
7 | 1160
8 | 813
9 | 1230
10 | 1370
11 | 1140
12 | 995
13 | 935
14 | 1110
15 | 994
16 | 1020
17 | 960
18 | 1180
19 | 799
20 | 958
21 | 1140
22 | 1100
23 | 1210
24 | 1150
25 | 1250
26 | 1260
27 | 1220
28 | 1030
29 | 1100
30 | 774
31 | 840
32 | 874
33 | 694
34 | 940
35 | 833
36 | 701
37 | 916
38 | 692
39 | 1020
40 | 1050
41 | 969
42 | 831
43 | 726
44 | 456
45 | 824
46 | 702
47 | 1120
48 | 1100
49 | 832
50 | 764
51 | 821
52 | 768
53 | 845
54 | 864
55 | 862
56 | 698
57 | 845
58 | 744
59 | 796
60 | 1040
61 | 759
62 | 781
63 | 865
64 | 845
65 | 944
66 | 984
67 | 897
68 | 822
69 | 1010
70 | 771
71 | 676
72 | 649
73 | 846
74 | 812
75 | 742
76 | 801
77 | 1040
78 | 860
79 | 874
80 | 848
81 | 890
82 | 744
83 | 749
84 | 838
85 | 1050
86 | 918
87 | 986
88 | 797
89 | 923
90 | 975
91 | 815
92 | 1020
93 | 906
94 | 901
95 | 1170
96 | 912
97 | 746
98 | 919
99 | 718
100 | 714
101 | 740
102 |
--------------------------------------------------------------------------------
/tests/test_data/statsmodels_macrodata_processed.csv:
--------------------------------------------------------------------------------
1 | ,realgdp,realcons,realinv
2 | 1959-06-30,0.024942130816387298,0.015286107415635186,0.08021268127441772
3 | 1959-09-30,-0.0011929521106681662,0.010385977737146668,-0.0721310437426288
4 | 1959-12-31,0.003494532654372051,0.0010840109462586511,0.03442511117316993
5 | 1960-03-31,0.022190179514293362,0.0095341508767115,0.10266376814814393
6 | 1960-06-30,-0.004684553282733539,0.012572428049423046,-0.10669384516657932
7 | 1960-09-30,0.0016328801889198274,-0.003967926518265941,-0.0059778791939155695
8 | 1960-12-31,-0.01290635994607836,0.001343033218102363,-0.1318520178798508
9 | 1961-03-31,0.005922591255446363,-0.00027964988017536996,0.025244180753152712
10 | 1961-06-30,0.01853453415389339,0.01476984095507472,0.07183387371281658
11 | 1961-09-30,0.016031639161811384,0.004838630433301461,0.08045270660189097
12 | 1961-12-31,0.020152816046898003,0.019823062007010783,0.016737113362548683
13 | 1962-03-31,0.017777259286836156,0.01059116613253508,0.05791064029750803
14 | 1962-06-30,0.010980515349976017,0.01221623378687653,-0.009715848023952312
15 | 1962-09-30,0.009204067213396172,0.008062026704950043,0.017733971140915017
16 | 1962-12-31,0.002427018714243445,0.014082552171855056,-0.03414697935818811
17 | 1963-03-31,0.01298521045351464,0.006712294307324562,0.05400709679319515
18 | 1963-06-30,0.012452834593128514,0.009504277281920714,0.014467702035278585
19 | 1963-09-30,0.01865404073909538,0.013515416623270937,0.032089340831295665
20 | 1963-12-31,0.0075738617893286175,0.008349119168469699,0.012232500910016597
21 | 1964-03-31,0.022195863868844867,0.01955417478523991,0.04029537549762097
22 | 1964-06-30,0.011419916678280018,0.01741600837232049,-0.004608479556662992
23 | 1964-09-30,0.01349678440620039,0.018195638939121572,0.02348211049511484
24 | 1964-12-31,0.0027684319766052568,0.0028061004289119396,0.008127111263337206
25 | 1965-03-31,0.02426471280607778,0.02198702890900517,0.09587891389684966
26 | 1965-06-30,0.013476920985249663,0.010995612622818562,-6.05874183419175e-05
27 | 1965-09-30,0.02009026988071838,0.01702550085568255,0.035089793025194105
28 | 1965-12-31,0.0238395627245076,0.027732704078884396,0.004599659936751266
29 | 1966-03-31,0.024249417896358594,0.014669560481433308,0.0811651892140155
30 | 1966-06-30,0.003323329257922225,0.002551564216559221,-0.018415529197198133
31 | 1966-09-30,0.00655531757692529,0.011402150557305646,-0.009958807818388316
32 | 1966-12-31,0.008069240525410137,0.0041484272445648784,0.004789900368053601
33 | 1967-03-31,0.008770749498840047,0.005795667399719484,-0.027762768509880686
34 | 1967-06-30,0.00020820851954361785,0.013544411542069312,-0.043574196677693244
35 | 1967-09-30,0.007946288894562059,0.00511384508807744,0.028297460430921184
36 | 1967-12-31,0.0076008372189466655,0.006142850307218062,0.021403487969130275
37 | 1968-03-31,0.020399308411914063,0.02360689149619688,0.02153029334303813
38 | 1968-06-30,0.016836250623496696,0.015212856527097252,0.03963280880042941
39 | 1968-09-30,0.006818187675689202,0.0186292780304953,-0.033002368053595355
40 | 1968-12-31,0.004323535018018632,0.004584535250413246,0.010333919522066637
41 | 1969-03-31,0.015626993226984354,0.011144077123032226,0.06380045990686156
42 | 1969-06-30,0.002908045754733024,0.006350181971833457,-0.007999752832058782
43 | 1969-09-30,0.00630412172871786,0.0048201868198987086,0.02285706056559711
44 | 1969-12-31,-0.004707590235586423,0.00794538968601799,-0.05536353177486397
45 | 1970-03-31,-0.0015699839630070045,0.006120060394082749,-0.031798102136913364
46 | 1970-06-30,0.0018110848665440216,0.004583883784991194,0.0031276717452559666
47 | 1970-09-30,0.008864772287074274,0.008706318768779475,0.016943182250337863
48 | 1970-12-31,-0.010660821696957257,-0.002723956139778494,-0.05967484549706992
49 | 1971-03-31,0.027202168334472532,0.018949376401506512,0.12209449744029222
50 | 1971-06-30,0.0056567889131589055,0.00912956377520402,0.030519869627655183
51 | 1971-09-30,0.007950886761737053,0.007924948579291602,0.013068141575951486
52 | 1971-12-31,0.002774937094654817,0.016492493094995453,-0.03178387805772953
53 | 1972-03-31,0.01772331446789366,0.013266567575952237,0.06832901837757266
54 | 1972-06-30,0.023438898862099933,0.018924178786744683,0.059410089170633285
55 | 1972-09-30,0.009538014207757683,0.015320125879748403,0.01413153537262346
56 | 1972-12-31,0.016336792461098426,0.023192179368109578,0.0051755829879578386
57 | 1973-03-31,0.02525804234288387,0.018129887541462608,0.061563645765462915
58 | 1973-06-30,0.011501097796264403,-0.0005053376394563713,0.04568816612695592
59 | 1973-09-30,-0.005350042942591671,0.0035634922053358054,-0.03988458461383537
60 | 1973-12-31,0.009493238255322112,-0.0029318600468926093,0.037538073304505204
61 | 1974-03-31,-0.008807613756498966,-0.008783808120972125,-0.06593511447881983
62 | 1974-06-30,0.002557212404131093,0.0034656569352140565,-0.004967723257728096
63 | 1974-09-30,-0.009936678481238914,0.004117708773394568,-0.058061447421211554
64 | 1974-12-31,-0.003943318438729193,-0.014743377118552559,0.009461202519150724
65 | 1975-03-31,-0.012237921312634015,0.00833777411551928,-0.19316322686510023
66 | 1975-06-30,0.007613228387052473,0.016532431063058795,-0.035342692176429935
67 | 1975-09-30,0.01670305536159411,0.014167887518730993,0.08128912221366846
68 | 1975-12-31,0.01297845428771538,0.010526251464110459,0.027115371108036967
69 | 1976-03-31,0.02247811058416005,0.01979843719362151,0.09853530316550518
70 | 1976-06-30,0.007492297472733611,0.009116702499950335,0.04176338114011724
71 | 1976-09-30,0.004886706909426053,0.010532158188986784,0.0018540964384721192
72 | 1976-12-31,0.007235398273953919,0.012916949526614374,0.006927681623511539
73 | 1977-03-31,0.011541159204879747,0.011378325577613424,0.048024592961263046
74 | 1977-06-30,0.019678246737301563,0.005512803347329509,0.07444626765491602
75 | 1977-09-30,0.017726137552946497,0.009497286008931738,0.053063063975045566
76 | 1977-12-31,-0.0002069209563817509,0.014954910730345716,-0.02863469545466213
77 | 1978-03-31,0.003408732702791184,0.0058204155446599515,0.01963890276843383
78 | 1978-06-30,0.038585475403756675,0.02116452181230777,0.0664049810036742
79 | 1978-09-30,0.009756162989988937,0.004185066791748682,0.030400604229731343
80 | 1978-12-31,0.013139436292622264,0.008023980684571441,0.022499917479481546
81 | 1979-03-31,0.0016709945702260143,0.005039002863874487,-0.00016226982579325977
82 | 1979-06-30,0.0009382908509696364,-0.0005852464852740269,-0.002316127923967848
83 | 1979-09-30,0.007162420120430113,0.009771159048007405,-0.019659741924709984
84 | 1979-12-31,0.0027476398395105406,0.002657931274367087,-0.0187911675618766
85 | 1980-03-31,0.0032161514595951957,-0.0017360658804363993,-0.007274586891929502
86 | 1980-06-30,-0.02070793158034334,-0.02295523265035193,-0.09455513343548372
87 | 1980-09-30,-0.001860258111918256,0.010664338361117132,-0.07927831360496729
88 | 1980-12-31,0.01832680700930389,0.013238212922260573,0.0968258047063566
89 | 1981-03-31,0.020566824679381313,0.005455894343649348,0.0947431099496372
90 | 1981-06-30,-0.00801140272402101,0.0,-0.048776329370762816
91 | 1981-09-30,0.012077078189896895,0.004046421832148539,0.06021997477221852
92 | 1981-12-31,-0.012535910197692957,-0.007584115017426285,-0.039552608640689435
93 | 1982-03-31,-0.016547233635337832,0.006437116638242202,-0.11001915580959931
94 | 1982-06-30,0.00540438914742758,0.0035930053234061177,-0.0008870179588731375
95 | 1982-09-30,-0.0038627257685650562,0.00763767645005764,-0.011739399570102726
96 | 1982-12-31,0.0007891034952027809,0.018070548333479763,-0.0932678840105785
97 | 1983-03-31,0.012360524736578782,0.00975494533673249,0.034986691966604866
98 | 1983-06-30,0.02222733528868126,0.019647192590248608,0.0921358495310205
99 | 1983-09-30,0.019527814105522623,0.017529979698624132,0.06507760763708603
100 | 1983-12-31,0.020459959942552786,0.01573023035373744,0.10011095968525563
101 | 1984-03-31,0.019210165810124025,0.008528432247278062,0.09954287887801616
102 | 1984-06-30,0.01711776350258809,0.01421719441872682,0.033159907559543456
103 | 1984-09-30,0.009671516584397466,0.00766837614262883,0.02297814070131121
104 | 1984-12-31,0.008108095427875384,0.013091863113150026,-0.016542413296196656
105 | 1985-03-31,0.00939240246972517,0.016827618778211928,-0.03352576107703609
106 | 1985-06-30,0.008431222368500357,0.009052660542270274,0.017114336250190654
107 | 1985-09-30,0.015499913364575235,0.018822015632972366,-0.0111110951439235
108 | 1985-12-31,0.007560947026535203,0.0021976357071959995,0.03835114042144738
109 | 1986-03-31,0.009563067941808612,0.00831155625679969,-0.0020569212508529944
110 | 1986-06-30,0.004009178207507347,0.010591865745748663,-0.02244250448137919
111 | 1986-09-30,0.009595188450980174,0.017337686123433116,-0.03185277828355737
112 | 1986-12-31,0.004821793960115173,0.005997865443140071,0.001549575539982584
113 | 1987-03-31,0.005528935739866014,-0.0015221813917598581,0.030095986982594525
114 | 1987-06-30,0.010577916660121645,0.013308102382179499,0.0013989296878635926
115 | 1987-09-30,0.008635542829132703,0.011078573792220325,0.0010078332176899352
116 | 1987-12-31,0.01696579977682866,0.0024001801486601693,0.0750846836507959
117 | 1988-03-31,0.005159352817525331,0.01656194042911352,-0.057993128623962775
118 | 1988-06-30,0.01276151335619602,0.007247274184601693,0.02403489024554961
119 | 1988-09-30,0.005149580544561161,0.007864565206437746,0.006403454265695885
120 | 1988-12-31,0.013264351168576383,0.011701532598701547,0.013156828932998188
121 | 1989-03-31,0.009344884553845745,0.0036700139253920128,0.037605497844809044
122 | 1989-06-30,0.007454656791988867,0.004467358058686699,-0.011753595376318593
123 | 1989-09-30,0.007899659182339036,0.010308676978862508,-0.011821928251866787
124 | 1989-12-31,0.0021804320484175577,0.004877222036274276,-0.010316302801358646
125 | 1990-03-31,0.010392526978691308,0.007875099856645917,0.009793458713496683
126 | 1990-06-30,0.003966490297244718,0.003294399681363913,0.0002839754625112434
127 | 1990-09-30,-1.5137345979354677e-05,0.003789233369559497,-0.023819672846864037
128 | 1990-12-31,-0.008799970050372252,-0.007800424872460354,-0.06532875105654501
129 | 1991-03-31,-0.004856014563534572,-0.002853392463110893,-0.04156716709528485
130 | 1991-06-30,0.00672662020935455,0.007597287958210686,-0.005040521569013912
131 | 1991-09-30,0.004203639797180969,0.0038051724967349543,0.02459129417067274
132 | 1991-12-31,0.003912442289637497,-0.00044911861228591476,0.03756840115321758
133 | 1992-03-31,0.010916709746682685,0.017055094156383177,-0.02248205323390895
134 | 1992-06-30,0.010569355265042546,0.0059076236833242035,0.06379346786306783
135 | 1992-09-30,0.010265417859880444,0.01098812198103083,0.010284555756322256
136 | 1992-12-31,0.010468628590384554,0.012138581118792402,0.031159828241962728
137 | 1993-03-31,0.0018361379777385167,0.004031472307636008,0.02322640760306438
138 | 1993-06-30,0.006377497246457864,0.009549783338858475,0.00782162022374333
139 | 1993-09-30,0.005250241191419036,0.010803521482550593,-0.0007046835401780527
140 | 1993-12-31,0.013119466990481499,0.00885699235809767,0.051424907521263385
141 | 1994-03-31,0.009688253462533325,0.011073205873081804,0.042240819567175514
142 | 1994-06-30,0.01358569880162186,0.007393812006782241,0.05665002899143978
143 | 1994-09-30,0.006420403956367338,0.00797962003214181,-0.018141383569249214
144 | 1994-12-31,0.01104477749296251,0.009819113744269359,0.04556673100613651
145 | 1995-03-31,0.0024502400068371344,0.0011665112641257025,0.010112657125509017
146 | 1995-06-30,0.0021473245426548715,0.00816073795386174,-0.02726183773355295
147 | 1995-09-30,0.00836938846235924,0.008897280161562549,-0.009672229723568293
148 | 1995-12-31,0.006947981234826983,0.007015740708053997,0.02776802287631952
149 | 1996-03-31,0.0068267227982588,0.009111938423863819,0.013087865506328455
150 | 1996-06-30,0.01714011546601668,0.011246513926669977,0.050498380009807775
151 | 1996-09-30,0.008660860530211423,0.005956322159514471,0.049194061600450034
152 | 1996-12-31,0.010856683405549461,0.008122624642348697,-0.0027222953368513103
153 | 1997-03-31,0.0076617204401348005,0.010018604777197737,0.023078936598619038
154 | 1997-06-30,0.0147176177580981,0.004035267457265235,0.06189844480028306
155 | 1997-09-30,0.01247302253870508,0.016863909657265808,0.01762257993775762
156 | 1997-12-31,0.00764257404176405,0.011372793339003096,0.01584465220509479
157 | 1998-03-31,0.009402375786226713,0.009903963846072728,0.04663007831635557
158 | 1998-06-30,0.008952009121774296,0.017059054612118985,-0.012038222653899311
159 | 1998-09-30,0.013108366865449028,0.013177531481691318,0.028250019650299052
160 | 1998-12-31,0.01716157436950816,0.015216842332248959,0.03165239255225849
161 | 1999-03-31,0.008868795717511091,0.009810532772998926,0.03100057041875015
162 | 1999-06-30,0.007786660186548389,0.01562159957313547,-0.003497282090812348
163 | 1999-09-30,0.01263644360829197,0.011942459850613929,0.024791673771441758
164 | 1999-12-31,0.01780192698963745,0.014009877815061245,0.034874117857821574
165 | 2000-03-31,0.002610475349184682,0.015056864152766636,-0.014060287861793697
166 | 2000-06-30,0.0193185856258129,0.009354481042235463,0.06693782876834042
167 | 2000-09-30,0.0008357335001925037,0.009738736557368455,-0.015765208923020246
168 | 2000-12-31,0.005900007164653331,0.008802484015141943,0.00044731799897590463
169 | 2001-03-31,-0.003302713380032074,0.003985048423478688,-0.05434900101941054
170 | 2001-06-30,0.0065359877030051194,0.003763424514321656,-0.0032138639191021667
171 | 2001-09-30,-0.0027454160608613165,0.004389912833836718,-0.021314183957510835
172 | 2001-12-31,0.0035257644037436364,0.015542609405262198,-0.05936563281615381
173 | 2002-03-31,0.008551982920579349,0.003436710689667777,0.03303111582599705
174 | 2002-06-30,0.005292010252050616,0.0050767575239785145,0.011917567926544415
175 | 2002-09-30,0.004984622513504178,0.006754003917142981,0.0020698926930968753
176 | 2002-12-31,0.0002064215385182422,0.003545619424254909,-0.000723313380029289
177 | 2003-03-31,0.004043517814475095,0.005147238095256412,-4.301834346964739e-05
178 | 2003-06-30,0.00794435538282201,0.009252460293302178,0.005805607370590771
179 | 2003-09-30,0.016622298075029462,0.013846458933041816,0.035648614625020336
180 | 2003-12-31,0.008954497678535844,0.005506879502322093,0.036318830433965665
181 | 2004-03-31,0.007017360710909415,0.009478743954543845,0.005207404106674751
182 | 2004-06-30,0.00708219043205105,0.005389829132301571,0.042516889635136224
183 | 2004-09-30,0.007318523149864475,0.008521678425312373,0.01288186847757089
184 | 2004-12-31,0.008638865661898976,0.01143533669801755,0.020403414183223667
185 | 2005-03-31,0.009928644671413522,0.007459800209780099,0.0210216168488655
186 | 2005-06-30,0.004253071337316783,0.009576660204851706,-0.018054001870183356
187 | 2005-09-30,0.007567538827411013,0.007097406480417234,0.010956113173677728
188 | 2005-12-31,0.00515464977792135,0.0025796872496552936,0.03521745293244205
189 | 2006-03-31,0.013032825454358132,0.010976272570252021,0.01446706222507732
190 | 2006-06-30,0.0035955893813284234,0.005371345093337254,-0.001535141513263838
191 | 2006-09-30,0.0002664262315530408,0.006145988881025133,-0.014078087577432008
192 | 2006-12-31,0.007282045058815356,0.009949568459104441,-0.02897189213120477
193 | 2007-03-31,0.0029985596181916208,0.00905317160290231,-0.015520338524286359
194 | 2007-06-30,0.007913399166394441,0.0028453507443444437,0.01378658394853094
195 | 2007-09-30,0.008831847811689997,0.004735045433360341,0.001976111281232207
196 | 2007-12-31,0.0052515140142759265,0.0029947827636505053,-0.020077985987605906
197 | 2008-03-31,-0.0018225504794315839,-0.0014962702917689086,-0.01927639001132686
198 | 2008-06-30,0.0036144287915647055,0.00014972781649902345,-0.02743538266838197
199 | 2008-09-30,-0.0067813613915230775,-0.00894805285032696,-0.01783623003564294
200 | 2008-12-31,-0.01380482973598518,-0.007842752651521678,-0.06916464982757464
201 | 2009-03-31,-0.016611979742227945,0.0015105004366180452,-0.17559819913074826
202 | 2009-06-30,-0.001851247642470355,-0.00219586786933057,-0.06756146963272425
203 | 2009-09-30,0.006862187581308632,0.007264873372626823,0.02019724281421187
204 |
--------------------------------------------------------------------------------
/tests/test_kalman_filter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import pytensor
5 | import pytensor.tensor as pt
6 | import pytest
7 | from numpy.testing import assert_allclose
8 |
9 | from pymc_statespace.filters import (
10 | CholeskyFilter,
11 | KalmanSmoother,
12 | SingleTimeseriesFilter,
13 | StandardFilter,
14 | SteadyStateFilter,
15 | UnivariateFilter,
16 | )
17 | from pymc_statespace.filters.kalman_filter import BaseFilter
18 | from tests.utilities.test_helpers import (
19 | get_expected_shape,
20 | get_sm_state_from_output_name,
21 | initialize_filter,
22 | make_test_inputs,
23 | nile_test_test_helper,
24 | )
25 |
# Build one compiled pytensor function per Kalman filter implementation.
# Each *_inout pair is unpacked directly into pytensor.function, so
# initialize_filter evidently returns an (inputs, outputs)-style pair.
standard_inout = initialize_filter(StandardFilter())
cholesky_inout = initialize_filter(CholeskyFilter())
univariate_inout = initialize_filter(UnivariateFilter())
single_inout = initialize_filter(SingleTimeseriesFilter())
steadystate_inout = initialize_filter(SteadyStateFilter())

f_standard = pytensor.function(*standard_inout)
f_cholesky = pytensor.function(*cholesky_inout)
f_univariate = pytensor.function(*univariate_inout)
f_single_ts = pytensor.function(*single_inout)
f_steady = pytensor.function(*steadystate_inout)

# Parallel lists: filter_funcs[i] is described by filter_names[i]
filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts, f_steady]

filter_names = [
    "StandardFilter",
    "CholeskyFilter",
    "UnivariateFilter",
    "SingleTimeSeriesFilter",
    "SteadyStateFilter",
]
# Names for each filter output, in the order the compiled functions return them
output_names = [
    "filtered_states",
    "predicted_states",
    "smoothed_states",
    "filtered_covs",
    "predicted_covs",
    "smoothed_covs",
    "log_likelihood",
    "ll_obs",
]
57 |
58 |
def test_base_class_update_raises():
    """BaseFilter.update is abstract and must raise NotImplementedError."""
    # Renamed from `filter` to avoid shadowing the builtin of the same name
    kfilter = BaseFilter()
    inputs = [None] * 8
    with pytest.raises(NotImplementedError):
        kfilter.update(*inputs)
64 |
65 |
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_shapes_one_state_one_observed(filter_func, output_idx, name):
    """Each filter output has the expected shape for a 1-state, 1-observed system."""
    shape_args = (1, 1, 1, 10)  # p, m, r, n
    outputs = filter_func(*make_test_inputs(*shape_args))
    assert outputs[output_idx].shape == get_expected_shape(name, *shape_args)
76 |
77 |
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_shapes_when_all_states_are_stochastic(filter_func, output_idx, name):
    """Each filter output has the expected shape when m == r (every state has a shock)."""
    shape_args = (1, 2, 2, 10)  # p, m, r, n
    outputs = filter_func(*make_test_inputs(*shape_args))
    assert outputs[output_idx].shape == get_expected_shape(name, *shape_args)
88 |
89 |
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_shapes_when_some_states_are_deterministic(filter_func, output_idx, name):
    """Each filter output has the expected shape when r < m (some states carry no shock)."""
    shape_args = (1, 5, 2, 10)  # p, m, r, n
    outputs = filter_func(*make_test_inputs(*shape_args))
    assert outputs[output_idx].shape == get_expected_shape(name, *shape_args)
100 |
101 |
@pytest.fixture
def f_standard_nd():
    """Compile a StandardFilter + KalmanSmoother graph whose system matrices are
    all 3d tensors, so every matrix can vary along a leading time dimension."""
    ksmoother = KalmanSmoother()
    data = pt.dtensor3(name="data")
    a0 = pt.matrix(name="a0")
    P0 = pt.matrix(name="P0")
    Q = pt.dtensor3(name="Q")
    H = pt.dtensor3(name="H")
    T = pt.dtensor3(name="T")
    R = pt.dtensor3(name="R")
    Z = pt.dtensor3(name="Z")

    # This ordering is unpacked into build_graph and must match its signature
    inputs = [data, a0, P0, T, Z, R, H, Q]

    (
        filtered_states,
        predicted_states,
        filtered_covs,
        predicted_covs,
        log_likelihood,
        ll_obs,
    ) = StandardFilter().build_graph(*inputs)

    # The smoother is built on top of the filtered estimates
    smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)

    # Output order matches the module-level `output_names` list
    outputs = [
        filtered_states,
        predicted_states,
        smoothed_states,
        filtered_covs,
        predicted_covs,
        smoothed_covs,
        log_likelihood,
        ll_obs,
    ]

    f_standard = pytensor.function(inputs, outputs)

    return f_standard
141 |
142 |
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_shapes_with_time_varying_matrices(f_standard_nd, output_idx, name):
    """Output shapes are unchanged when every system matrix carries a leading time axis."""
    p, m, r, n = 1, 5, 2, 10
    data, a0, P0, T, Z, R, H, Q = make_test_inputs(p, m, r, n)
    # Tile each static matrix n times along a new leading (time) axis
    T, Z, R, H, Q = (np.repeat(M[None], n, axis=0) for M in (T, Z, R, H, Q))

    outputs = f_standard_nd(data, a0, P0, T, Z, R, H, Q)
    assert outputs[output_idx].shape == get_expected_shape(name, p, m, r, n)
157 |
158 |
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_with_deterministic_observation_equation(filter_func, output_idx, name):
    """Each filter output has the expected shape for the (p=1, m=5, r=1) configuration."""
    shape_args = (1, 5, 1, 10)  # p, m, r, n
    outputs = filter_func(*make_test_inputs(*shape_args))
    assert outputs[output_idx].shape == get_expected_shape(name, *shape_args)
169 |
170 |
@pytest.mark.parametrize(
    ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
def test_output_with_multiple_observed(filter_func, filter_name, output_idx, name):
    """With p > 1 observed series every filter works except SingleTimeSeriesFilter,
    which must refuse multivariate data."""
    p, m, r, n = 5, 5, 1, 10
    inputs = make_test_inputs(p, m, r, n)
    expected_output = get_expected_shape(name, p, m, r, n)

    if filter_name == "SingleTimeSeriesFilter":
        with pytest.raises(
            AssertionError,
            match="UnivariateTimeSeries filter requires data be at most 1-dimensional",
        ):
            filter_func(*inputs)
        return

    outputs = filter_func(*inputs)
    assert outputs[output_idx].shape == expected_output
190 |
191 |
@pytest.mark.parametrize(
    ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
)
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
@pytest.mark.parametrize("p", [1, 5], ids=["univariate (p=1)", "multivariate (p=5)"])
def test_missing_data(filter_func, filter_name, output_idx, name, p):
    """Missing observations must be handled internally: no NaNs may leak into outputs."""
    m, r, n = 5, 1, 10
    inputs = make_test_inputs(p, m, r, n, missing_data=1)

    if filter_name == "SingleTimeSeriesFilter" and p > 1:
        with pytest.raises(
            AssertionError,
            match="UnivariateTimeSeries filter requires data be at most 1-dimensional",
        ):
            filter_func(*inputs)
        return

    outputs = filter_func(*inputs)
    assert not np.any(np.isnan(outputs[output_idx]))
210 |
211 |
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
def test_last_smoother_is_last_filtered(filter_func, output_idx):
    """At the final time step the smoothed estimate equals the filtered estimate."""
    filtered_idx, smoothed_idx = output_idx
    outputs = filter_func(*make_test_inputs(1, 5, 1, 10))

    assert_allclose(outputs[filtered_idx][-1], outputs[smoothed_idx][-1])
223 |
224 |
# TODO: These tests omit the SteadyStateFilter, because it gives different results to StatsModels (reason to dump it?)
@pytest.mark.parametrize("filter_func", filter_funcs[:-1], ids=filter_names[:-1])
@pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names)
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
def test_filters_match_statsmodel_output(filter_func, output_idx, name, n_missing):
    """Each filter output must agree with statsmodels' results on the Nile dataset."""
    fit_sm_mod, inputs = nile_test_test_helper(n_missing)
    outputs = filter_func(*inputs)

    # Squeeze singleton dims so our output lines up with the statsmodels reference
    val_to_test = outputs[output_idx].squeeze()
    ref_val = get_sm_state_from_output_name(fit_sm_mod, name)

    if name == "smoothed_covs":
        # TODO: The smoothed covariance matrices have large errors (1e-2) ONLY in the first few states -- no idea why.
        assert_allclose(val_to_test[5:], ref_val[5:])
    else:
        # Need atol = 1e-7 for smoother tests to pass
        assert_allclose(val_to_test, ref_val, atol=1e-7)
242 |
243 |
# Support running this file directly with `python`; pytest ignores this guard.
if __name__ == "__main__":
    unittest.main()
246 |
--------------------------------------------------------------------------------
/tests/test_local_level.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from pathlib import Path
4 |
5 | import pandas as pd
6 |
7 | from pymc_statespace import BayesianLocalLevel
8 |
# Load and standardize the Nile river dataset shared by the tests below.
ROOT = Path(__file__).parent.absolute()
# ROOT is already a Path, so use pathlib's `/` operator instead of os.path.join
nile = pd.read_csv(ROOT / "test_data" / "nile.csv")
nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan")
nile.rename(columns={"x": "height"}, inplace=True)
nile = (nile - nile.mean()) / nile.std()
14 |
15 |
def test_local_level_model():
    """Smoke test: BayesianLocalLevel can be constructed from the Nile data."""
    BayesianLocalLevel(data=nile.values)
18 |
19 |
# Support running this file directly with `python`; pytest ignores this guard.
if __name__ == "__main__":
    unittest.main()
22 |
--------------------------------------------------------------------------------
/tests/test_numba_linalg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from numpy.testing import assert_allclose
4 |
5 | from pymc_statespace.utils.numba_linalg import numba_block_diagonal
6 |
7 |
def test_numba_block_diagonal():
    """Block-diagonal of five stacked 3x3 identities is the 15x15 identity."""
    stacked_eyes = np.stack([np.eye(3)] * 5)
    assert_allclose(numba_block_diagonal(stacked_eyes), np.eye(15))
12 |
--------------------------------------------------------------------------------
/tests/test_pytensor_scipy.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import unittest
3 |
4 | import numpy as np
5 | from numpy.testing import assert_allclose
6 | from pytensor.configdefaults import config
7 | from pytensor.gradient import verify_grad as orig_verify_grad
8 |
9 | from pymc_statespace.utils.pytensor_scipy import SolveDiscreteARE, solve_discrete_are
10 |
# Op configured with enforce_Q_symmetric=True -- presumably symmetrizes Q before
# solving (TODO confirm against SolveDiscreteARE); used by the gradient test below.
solve_discrete_are_enforce = SolveDiscreteARE(enforce_Q_symmetric=True)
12 |
13 |
def fetch_seed(pseed=None):
    """
    Copied from pytensor.test.unittest_tools
    """
    # Explicit argument wins; otherwise fall back to the pytensor config value
    seed = pseed or config.unittests__rseed
    if seed == "random":
        seed = None

    try:
        seed = int(seed) if seed else None
    except ValueError:
        print(
            "Error: config.unittests__rseed contains invalid seed, using None instead",
            file=sys.stderr,
        )
        seed = None

    return seed
36 |
37 |
def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
    """
    Copied from pytensor.test.unittest_tools
    """
    if rng is None:
        rng = np.random.default_rng(fetch_seed())

    # TODO: Needed to increase tolerance for certain tests when migrating to
    # Generators from RandomStates. Caused flaky test failures. Needs further investigation
    kwargs.setdefault("rel_tol", 0.05)
    kwargs.setdefault("abs_tol", 0.05)

    orig_verify_grad(op, pt, n_tests, rng, *args, **kwargs)
52 |
53 |
class TestSolveDiscreteARE(unittest.TestCase):
    """Forward and gradient checks for the SolveDiscreteARE pytensor Op."""

    def test_forward(self):
        # TEST CASE 4 : darex #1 -- taken from Scipy tests
        a, b, q, r = (
            np.array([[4, 3], [-4.5, -3.5]]),
            np.array([[1], [-1]]),
            np.array([[9, 6], [6, 4]]),
            np.array([[1]]),
        )
        a, b, q, r = (x.astype("float64") for x in [a, b, q, r])

        x = solve_discrete_are(a, b, q, r).eval()
        # Substitute x back into the discrete algebraic Riccati equation;
        # the residual must vanish if x is a solution.
        res = a.T.dot(x.dot(a)) - x + q
        res -= (
            a.conj()
            .T.dot(x.dot(b))
            .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
        )

        assert_allclose(res, np.zeros_like(res), atol=1e-12)

    def test_backward(self):
        # Same darex #1 system as test_forward; check gradients numerically
        a, b, q, r = (
            np.array([[4, 3], [-4.5, -3.5]]),
            np.array([[1], [-1]]),
            np.array([[9, 6], [6, 4]]),
            np.array([[1]]),
        )
        a, b, q, r = (x.astype("float64") for x in [a, b, q, r])

        rng = np.random.default_rng(fetch_seed())
        verify_grad(solve_discrete_are_enforce, pt=[a, b, q, r], rng=rng)
87 |
88 |
# Support running this file directly with `python`; pytest ignores this guard.
if __name__ == "__main__":
    unittest.main()
91 |
--------------------------------------------------------------------------------
/tests/test_representation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import pytensor
5 | import pytensor.tensor as pt
6 | from numpy.testing import assert_allclose
7 |
8 | from pymc_statespace.core.representation import PytensorRepresentation
9 | from tests.utilities.test_helpers import make_test_inputs
10 |
11 |
class BasicFunctionality(unittest.TestCase):
    """Tests for PytensorRepresentation: construction, default shapes, and
    dict-style (__getitem__/__setitem__) access to the system matrices."""

    def setUp(self):
        # 10 observations of a single observed series, shaped (10, 1)
        self.data = np.arange(10)[:, None]

    def test_numpy_to_pytensor(self):
        """_numpy_to_pytensor converts a numpy array into a pytensor TensorVariable."""
        ssm = PytensorRepresentation(data=np.zeros((4, 1)), k_states=5, k_posdef=1)
        X = np.eye(5)
        X_pt = ssm._numpy_to_pytensor("transition", X)
        self.assertTrue(isinstance(X_pt, pt.TensorVariable))

    def test_default_shapes_full_rank(self):
        """Default matrix shapes when every state is stochastic (k_posdef == k_states)."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=5)
        p = ssm.data.shape[1]
        m = ssm.k_states
        r = ssm.k_posdef

        # The (10, 1) input data picks up a trailing singleton dimension
        self.assertTrue(ssm.data.shape == (10, 1, 1))
        self.assertTrue(ssm["design"].eval().shape == (p, m))
        self.assertTrue(ssm["transition"].eval().shape == (m, m))
        self.assertTrue(ssm["selection"].eval().shape == (m, r))
        self.assertTrue(ssm["state_cov"].eval().shape == (r, r))
        self.assertTrue(ssm["obs_cov"].eval().shape == (p, p))

    def test_default_shapes_low_rank(self):
        """Default matrix shapes when only some states are stochastic (k_posdef < k_states)."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2)
        p = ssm.data.shape[1]
        m = ssm.k_states
        r = ssm.k_posdef

        self.assertTrue(ssm.data.shape == (10, 1, 1))
        self.assertTrue(ssm["design"].eval().shape == (p, m))
        self.assertTrue(ssm["transition"].eval().shape == (m, m))
        self.assertTrue(ssm["selection"].eval().shape == (m, r))
        self.assertTrue(ssm["state_cov"].eval().shape == (r, r))
        self.assertTrue(ssm["obs_cov"].eval().shape == (p, p))

    def test_matrix_assignment(self):
        """Scalar and slice assignment via ssm[name, idx...] = value."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2)

        ssm["design", 0, 0] = 3.0
        ssm["transition", 0, :] = 2.7
        ssm["selection", -1, -1] = 9.9

        self.assertTrue(ssm["design"].eval()[0, 0] == 3.0)
        self.assertTrue(np.all(ssm["transition"].eval()[0, :] == 2.7))
        self.assertTrue(ssm["selection"].eval()[-1, -1] == 9.9)

    def test_build_representation_from_data(self):
        """Matrices handed to the constructor are stored and retrievable unchanged."""
        p, m, r, n = 3, 6, 1, 10
        # Chained assignment: `inputs` keeps the full list while the names unpack it
        inputs = [data, a0, P0, T, Z, R, H, Q] = make_test_inputs(p, m, r, n, missing_data=0)
        ssm = PytensorRepresentation(
            data=data,
            k_states=m,
            k_posdef=r,
            design=Z,
            transition=T,
            selection=R,
            state_cov=Q,
            obs_cov=H,
            initial_state=a0,
            initial_state_cov=P0,
        )
        # Ordered to line up with inputs[1:] = (a0, P0, T, Z, R, H, Q)
        names = [
            "initial_state",
            "initial_state_cov",
            "transition",
            "design",
            "selection",
            "obs_cov",
            "state_cov",
        ]
        for name, X in zip(names, inputs[1:]):
            self.assertTrue(np.allclose(X, ssm[name].eval()))

    def test_assign_time_varying_matrices(self):
        """Assignment works for a matrix with a trailing time dimension."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2)
        n = self.data.shape[0]

        ssm["design", 0, 0] = 3.0
        ssm["transition", 0, :] = 2.7
        ssm["selection", -1, -1] = 9.9

        # state_intercept is shaped (5, 1, n) here -- one column per time step
        ssm["state_intercept"] = np.zeros((5, 1, self.data.shape[0]))
        ssm["state_intercept", 0, 0, :] = np.arange(n)

        self.assertTrue(ssm["design"].eval()[0, 0] == 3.0)
        self.assertTrue(np.all(ssm["transition"].eval()[0, :] == 2.7))
        self.assertTrue(ssm["selection"].eval()[-1, -1] == 9.9)
        self.assertTrue(np.allclose(ssm["state_intercept"][0, 0, :].eval(), np.arange(n)))

    def test_invalid_key_name_raises(self):
        """Unknown matrix names raise IndexError with a helpful message."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            X = ssm["invalid_key"]
        msg = str(e.exception)
        self.assertEqual(msg, "invalid_key is an invalid state space matrix name")

    def test_non_string_key_raises(self):
        """Non-string keys raise IndexError."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            X = ssm[0]
        msg = str(e.exception)
        self.assertEqual(msg, "First index must the name of a valid state space matrix.")

    def test_invalid_key_tuple_raises(self):
        """Tuple keys whose first element is not a matrix name raise IndexError."""
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            X = ssm[0, 1, 1]
        msg = str(e.exception)
        self.assertEqual(msg, "First index must the name of a valid state space matrix.")

    def test_slice_statespace_matrix(self):
        """Indexing with a matrix name plus slices returns the sliced matrix."""
        T = np.eye(5)
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1, transition=T)
        T_out = ssm["transition", :3, :]
        assert_allclose(T[:3], T_out.eval())

    def test_update_matrix_via_key(self):
        """Whole matrices can be replaced via ssm[name] = array."""
        T = np.eye(5)
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1)
        ssm["transition"] = T

        assert_allclose(T, ssm["transition"].eval())

    def test_update_matrix_with_invalid_shape_raises(self):
        """Assigning a wrongly-shaped matrix raises ValueError naming both shapes."""
        T = np.eye(10)
        ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1)
        with self.assertRaises(ValueError) as e:
            ssm["transition"] = T
        msg = str(e.exception)
        self.assertEqual(msg, "Array provided for transition has shape (10, 10), expected (5, 5)")
143 |
144 |
# Support running this file directly with `python`; pytest ignores this guard.
if __name__ == "__main__":
    unittest.main()
147 |
--------------------------------------------------------------------------------
/tests/test_simulations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from numpy.testing import assert_allclose
4 |
5 | from pymc_statespace.utils.simulation import (
6 | conditional_simulation,
7 | numba_mvn_draws,
8 | simulate_statespace,
9 | unconditional_simulations,
10 | )
11 | from tests.utilities.test_helpers import make_test_inputs
12 |
13 |
def test_numba_mvn_draws():
    """Empirical covariance of a large MVN draw matches the requested covariance."""
    root = np.random.normal(size=(2, 2))
    cov = root @ root.T  # symmetric positive (semi-)definite by construction
    draws = numba_mvn_draws(mu=np.zeros((2, 1_000_000)), cov=cov)

    assert_allclose(np.cov(draws), cov, atol=0.01, rtol=0.01)
20 |
21 |
def test_simulate_statespace():
    """Simulation returns (n_steps, m) states and (n_steps, p) observations."""
    _, _, _, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100)
    states, observed = simulate_statespace(T, Z, R, H, Q, n_steps=100)

    assert states.shape == (100, 5)
    assert observed.shape == (100, 3)
28 |
29 |
def test_simulate_statespace_with_x0():
    """When x0 is given, the first state and observation are determined by it."""
    _, a0, _, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100)
    states, observed = simulate_statespace(T, Z, R, H, Q, n_steps=100, x0=a0.squeeze())

    assert states.shape == (100, 5)
    assert observed.shape == (100, 3)
    assert np.all(states[0] == a0)
    assert np.all(observed[0] == Z @ a0)
40 |
41 |
def test_simulate_statespace_no_obs_noise():
    """Simulation still works when the observation covariance is exactly zero."""
    data, a0, P0, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100)
    H = np.zeros_like(H)  # deterministic observation equation

    states, observed = simulate_statespace(T, Z, R, H, Q, n_steps=100, x0=a0.squeeze())

    assert states.shape == (100, 5)
    assert observed.shape == (100, 3)
51 |
52 |
def test_conditional_simulation():
    # TODO: conditional_simulation is imported above but not yet covered here.
    pass
55 |
56 |
def test_unconditional_simulation():
    # TODO: unconditional_simulations is imported above but not yet covered here.
    pass
59 |
--------------------------------------------------------------------------------
/tests/test_statespace.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pymc as pm
8 | import pytest
9 | from numpy.testing import assert_allclose
10 |
11 | from pymc_statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
12 | from tests.utilities.test_helpers import make_test_inputs
13 |
# Load and standardize the Nile river dataset shared by the tests below.
ROOT = Path(__file__).parent.absolute()
# ROOT is already a Path, so use pathlib's `/` operator instead of os.path.join
nile = pd.read_csv(ROOT / "test_data" / "nile.csv")
nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan")
nile.rename(columns={"x": "height"}, inplace=True)
nile = (nile - nile.mean()) / nile.std()
19 |
20 |
@pytest.fixture()
def ss_mod():
    """A minimal concrete PyMCStateSpace: 2 states, 1 innovation, where update()
    writes theta = (rho, zeta) into the first row of the transition matrix."""

    class StateSpace(PyMCStateSpace):
        @property
        def param_names(self):
            return ["rho", "zeta"]

        def update(self, theta):
            # theta fills the first row of T
            self.ssm["transition", 0, :] = theta

    T = np.zeros((2, 2)).astype("float64")
    T[1, 0] = 1.0  # second state is a lag of the first
    Z = np.array([[1.0, 0.0]])
    R = np.array([[1.0], [0.0]])
    H = np.array([[0.1]])
    Q = np.array([[0.8]])

    ss_mod = StateSpace(data=nile, k_states=2, k_posdef=1, filter_type="standard")
    for X, name in zip(
        [T, Z, R, H, Q], ["transition", "design", "selection", "obs_cov", "state_cov"]
    ):
        ss_mod.ssm[name] = X

    return ss_mod
45 |
46 |
@pytest.fixture
def pymc_mod(ss_mod):
    """PyMC model containing the two parameters ss_mod requires ("rho", "zeta"),
    with both the statespace and smoother graphs built inside the model context."""
    with pm.Model() as pymc_mod:
        rho = pm.Normal("rho")
        zeta = pm.Deterministic("zeta", 1 - rho)
        ss_mod.build_statespace_graph()
        ss_mod.build_smoother_graph()

    return pymc_mod
56 |
57 |
@pytest.fixture
def idata(pymc_mod):
    """A deliberately tiny MCMC run (100 draws, no tuning, 1 chain) -- fast
    smoke-test samples, not expected to be well mixed."""
    with pymc_mod:
        idata = pm.sample(draws=100, tune=0, chains=1)

    return idata
64 |
65 |
def test_invalid_filter_name_raises():
    """An unknown filter_type must raise and list the valid options."""
    expected = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys()))
    with pytest.raises(NotImplementedError, match=expected):
        PyMCStateSpace(data=nile.values, k_states=5, k_posdef=1, filter_type="invalid_filter")
70 |
71 |
def test_singleseriesfilter_raises_if_data_is_nd():
    """filter_type="single" must reject data with more than one observed series."""
    wide_data = np.random.normal(size=(4, 10))
    with pytest.raises(
        ValueError, match='Cannot use filter_type = "single" with multiple observed time series'
    ):
        PyMCStateSpace(data=wide_data, k_states=5, k_posdef=1, filter_type="single")
77 |
78 |
def test_unpack_matrices():
    """A freshly built model's statespace matrices are all zeros of the right shapes."""
    p, m, r, n = 2, 5, 1, 10
    data, *matrices = make_test_inputs(p, m, r, n, missing_data=0)
    mod = PyMCStateSpace(
        data=data[..., 0], k_states=m, k_posdef=r, filter_type="standard", verbose=False
    )

    for reference, unpacked in zip(matrices, mod.unpack_statespace()):
        assert_allclose(np.zeros_like(reference), unpacked.eval())
89 |
90 |
def test_param_names_raises_on_base_class():
    """param_names is abstract on the PyMCStateSpace base class."""
    mod = PyMCStateSpace(data=nile, k_states=5, k_posdef=1, filter_type="standard", verbose=False)
    with pytest.raises(NotImplementedError):
        mod.param_names
95 |
96 |
def test_update_raises_on_base_class():
    """update() is abstract on the PyMCStateSpace base class."""
    mod = PyMCStateSpace(data=nile, k_states=5, k_posdef=1, filter_type="standard", verbose=False)
    with pytest.raises(NotImplementedError):
        mod.update(np.zeros(4))
102 |
103 |
def test_gather_pymc_variables(ss_mod):
    """gather_required_random_variables stacks rho and zeta from the active model."""
    with pm.Model():
        rho = pm.Normal("rho")
        zeta = pm.Deterministic("zeta", 1 - rho)
        theta = ss_mod.gather_required_random_variables()

    assert_allclose(pm.math.stack([rho, zeta]).eval(), theta.eval())
111 |
112 |
def test_gather_raises_if_variable_missing(ss_mod):
    """A model missing one of ss_mod's required parameters must raise ValueError."""
    expected = "The following required model parameters were not found in the PyMC model: zeta"
    with pm.Model():
        pm.Normal("rho")
        with pytest.raises(ValueError, match=expected):
            ss_mod.gather_required_random_variables()
119 |
120 |
def test_build_smoother_fails_if_statespace_not_built_first(ss_mod):
    """build_smoother_graph requires build_statespace_graph to have been run first."""
    expected_msg = (
        "Couldn't find Kalman filtered time series among model deterministics. Have you run"
        ".build_statespace_graph() ?"
    )

    with pm.Model():
        pm.Normal("rho")
        pm.Normal("zeta")
        with pytest.raises(ValueError, match=expected_msg):
            ss_mod.build_smoother_graph()
132 |
133 |
def test_build_statespace_graph(pymc_mod):
    """Kalman filter outputs must be registered as model deterministics."""
    deterministic_names = [var.name for var in pymc_mod.deterministics]
    expected = [
        "filtered_states",
        "predicted_states",
        "predicted_covariances",
        "filtered_covariances",
    ]
    for name in expected:
        assert name in deterministic_names
142 |
143 |
def test_build_smoother_graph(ss_mod, pymc_mod):
    """Kalman smoother outputs must be registered as model deterministics."""
    deterministic_names = [var.name for var in pymc_mod.deterministics]
    for name in ("smoothed_states", "smoothed_covariances"):
        assert name in deterministic_names
148 |
149 |
@pytest.mark.parametrize(
    "filter_output",
    ["filtered", "predicted", "smoothed", "invalid"],
    ids=["filtered", "predicted", "smoothed", "invalid"],
)
def test_sample_conditional_prior(ss_mod, pymc_mod, filter_output):
    """Valid filter outputs sample cleanly; an invalid one raises ValueError."""
    if filter_output != "invalid":
        with pymc_mod:
            ss_mod.sample_conditional_prior(
                filter_output=filter_output, n_simulations=1, prior_samples=100
            )
        return

    # NOTE: "recieved" reproduces the library's error message spelling.
    msg = "filter_output should be one of filtered, predicted, or smoothed, recieved invalid"
    with pytest.raises(ValueError, match=msg), pymc_mod:
        ss_mod.sample_conditional_prior(filter_output=filter_output)
165 |
166 |
@pytest.mark.parametrize(
    "filter_output",
    ["filtered", "predicted", "smoothed", "invalid"],
    ids=["filtered", "predicted", "smoothed", "invalid"],
)
def test_sample_conditional_posterior(ss_mod, pymc_mod, idata, filter_output):
    """Valid filter outputs sample from the posterior; an invalid one raises ValueError."""
    if filter_output == "invalid":
        # NOTE: "recieved" reproduces the library's error message spelling.
        msg = "filter_output should be one of filtered, predicted, or smoothed, recieved invalid"
        with pytest.raises(ValueError, match=msg), pymc_mod:
            # BUG FIX: this branch previously called sample_conditional_prior,
            # so the posterior method's filter_output validation was never tested.
            ss_mod.sample_conditional_posterior(idata, filter_output=filter_output)
    else:
        with pymc_mod:
            conditional_posterior = ss_mod.sample_conditional_posterior(
                idata, filter_output=filter_output, n_simulations=1, posterior_samples=0.5
            )
182 |
183 |
def test_sample_conditional_posterior_raises_on_invalid_samples(ss_mod, pymc_mod, idata):
    """A fractional posterior_samples outside (0, 1) is rejected."""
    expected_msg = (
        "If posterior_samples is a float, it should be between 0 and 1, representing the "
        "fraction of total posterior samples to re-sample."
    )

    with pymc_mod, pytest.raises(ValueError, match=expected_msg):
        ss_mod.sample_conditional_posterior(
            idata, filter_output="predicted", n_simulations=1, posterior_samples=-0.3
        )
195 |
196 |
def test_sample_conditional_posterior_default_samples(ss_mod, pymc_mod, idata):
    """Omitting posterior_samples should fall back to the method's default."""
    with pymc_mod:
        ss_mod.sample_conditional_posterior(idata, filter_output="predicted", n_simulations=1)
202 |
203 |
def test_sample_unconditional_prior(ss_mod, pymc_mod):
    """Unconditional prior simulation runs without error."""
    with pymc_mod:
        ss_mod.sample_unconditional_prior(n_simulations=1, prior_samples=100)
207 |
208 |
def test_sample_unconditional_posterior(ss_mod, pymc_mod, idata):
    """Unconditional posterior simulation runs without error."""
    with pymc_mod:
        ss_mod.sample_unconditional_posterior(
            idata, n_steps=100, n_simulations=1, posterior_samples=10
        )
214 |
215 |
if __name__ == "__main__":
    # This module's tests are plain pytest functions (no unittest.TestCase
    # subclasses), so unittest.main() would discover nothing. Invoke pytest
    # on this file instead.
    pytest.main([__file__])
218 |
--------------------------------------------------------------------------------
/tests/utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/tests/utilities/__init__.py
--------------------------------------------------------------------------------
/tests/utilities/statsmodel_local_level.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import statsmodels.api as sm
3 |
4 |
class LocalLinearTrend(sm.tsa.statespace.MLEModel):
    """Reference local linear trend model (level + trend states) from statsmodels.

    Used by the test suite as a ground-truth Kalman filter implementation.
    """

    def __init__(self, endog, **kwargs):
        # Two states (level, trend), each with its own innovation.
        n_states = 2

        super().__init__(endog, k_states=n_states, k_posdef=n_states, **kwargs)

        # System matrices: y_t = [1, 0] x_t;  x_t = [[1, 1], [0, 1]] x_{t-1} + eta_t
        self.ssm["design"] = np.array([1, 0])
        self.ssm["transition"] = np.array([[1, 1], [0, 1]])
        self.ssm["selection"] = np.eye(n_states)

        # Cache the index of the state covariance diagonal for fast updates.
        self._state_cov_idx = ("state_cov",) + np.diag_indices(n_states)

    @property
    def param_names(self):
        return ["sigma2.measurement", "sigma2.level", "sigma2.trend"]

    @property
    def start_params(self):
        # Start each variance at the sample standard deviation of the data.
        return [np.std(self.endog)] * 3

    def transform_params(self, unconstrained):
        # Square to keep variances positive during optimization.
        return unconstrained**2

    def untransform_params(self, constrained):
        # Inverse of transform_params.
        return constrained**0.5

    def update(self, params, *args, **kwargs):
        params = super().update(params, *args, **kwargs)

        # Measurement noise variance.
        self.ssm["obs_cov", 0, 0] = params[0]

        # Level and trend innovation variances on the state covariance diagonal.
        self.ssm[self._state_cov_idx] = params[1:]
43 |
--------------------------------------------------------------------------------
/tests/utilities/test_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytensor.tensor as pt
7 |
8 | from pymc_statespace.filters.kalman_smoother import KalmanSmoother
9 | from tests.utilities.statsmodel_local_level import LocalLinearTrend
10 |
ROOT = Path(__file__).parent.parent.absolute()
# Load the Nile test series; use pathlib composition instead of mixing in
# os.path.join, and force a float dtype so filter inputs are not integer-typed.
nile_data = pd.read_csv(ROOT / "test_data" / "nile.csv")
nile_data["x"] = nile_data["x"].astype(float)
14 |
15 |
def initialize_filter(kfilter):
    """Wire a Kalman filter and smoother into a symbolic pytensor graph.

    Parameters
    ----------
    kfilter
        A filter object exposing build_graph(data, a0, P0, T, Z, R, H, Q).

    Returns
    -------
    tuple(list, list)
        The symbolic input tensors [data, a0, P0, T, Z, R, H, Q] and the
        filter/smoother output tensors (states, covariances, likelihoods).
    """
    ksmoother = KalmanSmoother()

    data = pt.dtensor3(name="data")
    a0 = pt.matrix(name="a0")
    P0 = pt.matrix(name="P0")
    Q = pt.matrix(name="Q")
    H = pt.matrix(name="H")
    T = pt.matrix(name="T")
    R = pt.matrix(name="R")
    Z = pt.matrix(name="Z")

    inputs = [data, a0, P0, T, Z, R, H, Q]

    (
        filtered_states,
        predicted_states,
        filtered_covs,
        predicted_covs,
        log_likelihood,
        ll_obs,
    ) = kfilter.build_graph(*inputs)

    # The smoother consumes the filtered estimates to produce smoothed ones.
    smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)

    outputs = [
        filtered_states,
        predicted_states,
        smoothed_states,
        filtered_covs,
        predicted_covs,
        smoothed_covs,
        log_likelihood,
        ll_obs,
    ]

    return inputs, outputs
52 |
53 |
def add_missing_data(data, n_missing):
    """Overwrite `n_missing` randomly chosen rows of `data` with NaN, in place.

    Rows are drawn uniformly without replacement along axis 0. The (mutated)
    input array is also returned for convenience.
    """
    n_rows = data.shape[0]
    nan_rows = np.random.choice(n_rows, n_missing, replace=False)
    data[nan_rows] = np.nan

    return data
60 |
61 |
def make_test_inputs(p, m, r, n, missing_data=None, H_is_zero=False):
    """Construct a deterministic set of statespace matrices for testing.

    Parameters
    ----------
    p : int
        Number of observed series.
    m : int
        Number of hidden states.
    r : int
        Number of shocks (posdef dimension).
    n : int
        Number of time steps.
    missing_data : int, optional
        If not None, randomly NaN-out this many time steps of the data.
    H_is_zero : bool
        If True, the observation noise covariance is all zeros.

    Returns
    -------
    tuple of np.ndarray
        (data, a0, P0, T, Z, R, H, Q)
    """
    data = np.arange(n * p, dtype="float").reshape(-1, p, 1)
    if missing_data is not None:
        data = add_missing_data(data, missing_data)

    a0 = np.zeros((m, 1))
    P0 = np.eye(m)
    Q = np.eye(r)

    if H_is_zero:
        H = np.zeros((p, p))
    else:
        H = np.eye(p)

    # Transition: states shift down one slot; the first state is the mean of all states.
    T = np.eye(m, k=-1)
    T[0] = np.full(m, 1 / m)

    R = np.eye(m)[:, :r]
    Z = np.eye(m)[:p, :]

    return data, a0, P0, T, Z, R, H, Q
77 |
78 |
def get_expected_shape(name, p, m, r, n):
    """Return the expected array shape for a named filter output.

    `name` is "log_likelihood", "ll_obs", or "<filter_type>_<variable>" where
    variable is "states" or "covs". Predicted outputs carry one extra time step
    (the one-step-ahead forecast), hence n + 1.
    """
    if name == "log_likelihood":
        return ()
    if name == "ll_obs":
        return (n,)

    filter_type, variable = name.split("_")
    length = n + 1 if filter_type == "predicted" else n

    if variable == "states":
        return length, m, 1
    if variable == "covs":
        return length, m, m
91 |
92 |
def get_sm_state_from_output_name(res, name):
    """Pull the statsmodels result attribute matching a named filter output.

    Parameters
    ----------
    res
        Fitted statsmodels results object (exposes llf, llf_obs, states, and
        filter_results).
    name : str
        "log_likelihood", "ll_obs", or "<filter_type>_<variable>" where
        variable is "states" or "covs".
    """
    if name == "log_likelihood":
        return res.llf
    elif name == "ll_obs":
        return res.llf_obs

    filter_type, variable = name.split("_")
    # Direct attribute access; getattr with a constant string was redundant.
    sm_states = res.states

    if variable == "states":
        return getattr(sm_states, filter_type)
    if variable == "covs":
        m = res.filter_results.k_states
        # The covariance attribute is singular ("..._cov"): drop the trailing
        # "s" from the name, then reshape the flat covariances to (n, m, m).
        return getattr(sm_states, name[:-1]).reshape(-1, m, m)
108 |
109 |
def nile_test_test_helper(n_missing=0):
    """Fit the reference statsmodels local linear trend model to the Nile data.

    Parameters
    ----------
    n_missing : int
        Number of time steps to randomly replace with NaN before fitting.

    Returns
    -------
    tuple
        (res, inputs): the fitted statsmodels results (variances held fixed
        via fit_constrained) and the matching input list
        [data, a0, P0, T, Z, R, H, Q] for our Kalman filter implementation.
    """
    # Fixed statespace matrices for a local linear trend model.
    a0 = np.zeros((2, 1))
    P0 = np.eye(2) * 1e6  # large "known" initial covariance (diffuse-like)
    Q = np.eye(2) * np.array([0.5, 0.01])
    H = np.eye(1) * 0.8
    T = np.array([[1.0, 1.0], [0.0, 1.0]])
    R = np.eye(2)
    Z = np.array([[1.0, 0.0]])

    # Copy so the module-level nile_data is never mutated by add_missing_data.
    data = nile_data.values.copy()
    if n_missing > 0:
        data = add_missing_data(data, n_missing)

    sm_model = LocalLinearTrend(
        endog=data,
        initialization="known",
        initial_state_cov=P0,
        initial_state=a0.ravel(),
    )

    # Constrain all variances to the same fixed values used in Q and H above,
    # so the statsmodels output is a deterministic reference.
    res = sm_model.fit_constrained(
        constraints={
            "sigma2.measurement": 0.8,
            "sigma2.level": 0.5,
            "sigma2.trend": 0.01,
        }
    )

    inputs = [data[..., None], a0, P0, T, Z, R, H, Q]

    return res, inputs
141 |
--------------------------------------------------------------------------------